当我们使用序列到序列(sequence-to-sequence)模型时,在每一步中,我们需要将前一步的输出馈送到下一步作为输入。如果使用基本的 TensorFlow API来进行操作,会非常麻烦。tf.contrib.seq2seq.dynamic_decode()函数则可以简化此过程。
tf.contrib.seq2seq.dynamic_decode函数主要作用是将输入的Decoder RNN(例如LSTM或GRU)在一个循环过程中迭代运行,各个时间步之间共享权值,每个时间步上使用的输入是上一个时间步上的输出。函数的主要输入分为三部分:decoder、initial_state和maximum_iterations。其中,最重要的一个是decoder,因为他是描述整个序列到序列模型的网络结构,这个decoder可以使用tf.contrib.seq2seq中的一些实现类,例如 tf.contrib.seq2seq.BasicDecoder 。initial_state 是RNN的初始隐藏状态,而maximum_iterations定义了模型处理的最长步数。
使用方法:
outputs, final_state, _=tf.contrib.seq2seq.dynamic_decode(decoder=decoder
imitial_state=initial_state,
maximum_iterations=maximum_iterations)
其中outputs是decoder的输出,在大部分情况下是一个形状为[batch_size, max_sequence_length, num_decoder_symbols]的张量;final_state则记录了最后一步的状态,常常作为下一个batch的initial_state;_则是一些辅助信息,例如长度信息等。
举个例子,下面是一个使用方式:
import tensorflow as tf
from tensorflow.contrib.seq2seq import BasicDecoder, dynamic_decode
# 假设我们的目标是使用一个LSTM在一段序列中找到最大值所在的位置
sequence_length = 10 # 输入序列长度
batch_size = 4 # 批大小
input_dim = 5 # 输入维度
hidden_dim = 10 # LSTM隐藏层大小
num_classes = 3 # 序列分类数
tf.reset_default_graph()
# 用于存储所有的LSTM状态
initial_state = tf.zeros((batch_size, hidden_dim))
# 模型的输入
inputs = tf.placeholder(tf.float32, shape=[batch_size, sequence_length, input_dim], name='inputs')
targets = tf.placeholder(tf.int32, shape=[batch_size, sequence_length], name='targets')
# 构建Decoder
decoder_cell = tf.nn.rnn_cell.LSTMCell(num_units=hidden_dim)
decoder = BasicDecoder(cell=decoder_cell, helper=tf.contrib.seq2seq.TrainingHelper(inputs=inputs, sequence_length=[sequence_length]*batch_size), initial_state=initial_state)
# 使用dynamic_decode操作完成模型的运行
outputs, final_state, _ = dynamic_decode(decoder=decoder, output_time_major=False, imitial_state=initial_state)
# 定义损失函数和优化器
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=outputs.rnn_output))
optimizer = tf.train.AdamOptimizer().minimize(loss)
# 模型训练
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(100):
inputs_data = np.random.random(size=[batch_size, sequence_length, input_dim])
targets_data = np.random.randint(0, num_classes, size=[batch_size, sequence_length])
_, loss_val = sess.run([optimizer, loss], feed_dict={inputs: inputs_data, targets: targets_data})
print('step: %d, loss: %f' % (i, loss_val))
另外一个示例是使用 seq2seq 模型进行机器翻译,在这个例子里面,我们可以使用dynamic_decode将RNN从头开始迭代,然后获得输出翻译结果的概率分布。在这种情况下,我们需要定义一个不同的helper对象以提供decoder所需要的输入,另外还需要实现一个对应的输出层对翻译结果进行预测。