tf.contrib.seq2seq.BasicDecoder is a decoder class for sequence-to-sequence models in TensorFlow, used to generate output sequences. It is typically driven by tf.contrib.seq2seq.dynamic_decode, which runs the decoder (usually conditioned on the encoder's output) to produce the final sequence. Note that tf.contrib only exists in TensorFlow 1.x; it was removed in TensorFlow 2, so the examples below target TF 1.x.
Usage is as follows:
1. Signature:
BasicDecoder(
cell,
helper,
initial_state,
output_layer=None
)
Parameters:
– cell: an RNNCell instance, the decoder's RNN layer, e.g. tf.nn.rnn_cell.LSTMCell
– helper: a tf.contrib.seq2seq.Helper instance that supplies the decoder's input at each step
– initial_state: the initial state for the cell; typically the encoder's final state
– output_layer: an optional layer (e.g. tf.layers.Dense) applied to the decoder output at each step
2. Return value:
When run through dynamic_decode, the first returned value is a BasicDecoderOutput with the following fields:
– rnn_output: the decoder output at each step (after output_layer, if one was given)
– sample_id: the token id sampled by the helper at each step
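To make the two output fields concrete, here is a rough NumPy sketch of what a single decoding step produces. This is not TensorFlow code; the shapes and the weight matrix `W` (standing in for output_layer) are purely illustrative.

```python
import numpy as np

# Illustrative shapes only: batch=2, hidden=4, vocab=6.
batch, hidden, vocab = 2, 4, 6
rng = np.random.default_rng(0)

cell_output = rng.standard_normal((batch, hidden))  # what the RNN cell emits
W = rng.standard_normal((hidden, vocab))            # stands in for output_layer

# rnn_output: the (optionally projected) per-step decoder output
rnn_output = cell_output @ W                        # shape (batch, vocab)
# sample_id: the token id chosen at this step (the default helpers take an argmax)
sample_id = np.argmax(rnn_output, axis=-1)          # shape (batch,)
```

Stacked over all decoding steps, these per-step values become the `rnn_output` and `sample_id` tensors of the BasicDecoderOutput.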
Examples:
Two examples follow: an RNN-based language model, and seq2seq-based text generation.
a. RNN-based language model
import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell
from tensorflow.contrib.seq2seq import BasicDecoder, TrainingHelper, dynamic_decode
vocab_size = 100
embedding_size = 50
hidden_size = 128
# Build the RNN cell
cell = LSTMCell(hidden_size)
# Build the inputs
enc_inputs = tf.placeholder(tf.int32, shape=[None, None])
enc_inputs_length = tf.placeholder(tf.int32, shape=[None])
# Build the embedding layer and embed the inputs
embedding = tf.get_variable('embedding', [vocab_size, embedding_size])
enc_inputs_embedded = tf.nn.embedding_lookup(embedding, enc_inputs)
# Set up the training helper (teacher forcing)
dec_sequence_length = tf.placeholder(tf.int32, shape=[None])
teacher_forcing_helper = TrainingHelper(enc_inputs_embedded,
dec_sequence_length)
# Use the cell's zero_state as BasicDecoder's initial_state
initial_state = cell.zero_state(batch_size=tf.shape(enc_inputs)[0], dtype=tf.float32)
# Define the decoder; the Dense output_layer projects each step to vocabulary logits
decoder = BasicDecoder(
    cell=cell,
    helper=teacher_forcing_helper,
    initial_state=initial_state,
    output_layer=tf.layers.Dense(vocab_size)
)
# Run the decoder with dynamic_decode (it returns outputs, final state, and sequence lengths)
outputs, _, _ = dynamic_decode(
    decoder=decoder,
    maximum_iterations=tf.reduce_max(dec_sequence_length),
    impute_finished=True,
    swap_memory=True
)
logits = outputs.rnn_output
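TrainingHelper only works at training time, when the target sequence is known. For inference, tf.contrib.seq2seq provides GreedyEmbeddingHelper, which feeds each step's argmax token back in as the next input. The loop below is a minimal NumPy sketch of that greedy feedback idea; the weights (`E`, `Wx`, `Wh`, `Wo`) and token ids are toy values, not a real trained model.

```python
import numpy as np

# Toy, hypothetical weights standing in for a trained cell + output layer.
vocab, emb, hid = 6, 4, 4
rng = np.random.default_rng(1)
E = rng.standard_normal((vocab, emb))   # embedding table
Wx = rng.standard_normal((emb, hid))    # input-to-hidden weights
Wh = rng.standard_normal((hid, hid))    # hidden-to-hidden weights
Wo = rng.standard_normal((hid, vocab))  # output projection

START, END, MAX_LEN = 0, 1, 10

def greedy_decode():
    h = np.zeros(hid)
    tok, out = START, []
    for _ in range(MAX_LEN):
        h = np.tanh(E[tok] @ Wx + h @ Wh)  # one simple RNN step
        tok = int(np.argmax(h @ Wo))       # argmax, as GreedyEmbeddingHelper does
        if tok == END:                     # stop once the end token is produced
            break
        out.append(tok)
    return out

decoded = greedy_decode()
```

In the real API you would swap TrainingHelper for GreedyEmbeddingHelper(embedding, start_tokens, end_token) and keep the rest of the decoder unchanged.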
b. Seq2seq-based text generation
import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell
from tensorflow.contrib.seq2seq import (BasicDecoder, TrainingHelper,
                                        dynamic_decode,
                                        BahdanauMonotonicAttention,
                                        AttentionWrapper)
vocab_size = 100
embedding_size = 50
hidden_size = 128
# Build the encoder inputs
enc_inputs = tf.placeholder(tf.int32, shape=[None, None])
enc_inputs_length = tf.placeholder(tf.int32, shape=[None])
# Build the embedding layer and embed the inputs
embedding = tf.get_variable('embedding', [vocab_size, embedding_size])
enc_inputs_embedded = tf.nn.embedding_lookup(embedding, enc_inputs)
# Run the encoder RNN; its outputs serve as the attention memory
encoder_cell = LSTMCell(hidden_size)
enc_outputs, enc_state = tf.nn.dynamic_rnn(
    encoder_cell, enc_inputs_embedded,
    sequence_length=enc_inputs_length, dtype=tf.float32)
# Build the decoder inputs
decoder_inputs = tf.placeholder(tf.int32, shape=[None, None])
decoder_inputs_length = tf.placeholder(tf.int32, shape=[None])
# Build the helper (teacher forcing)
helper = TrainingHelper(
    inputs=tf.nn.embedding_lookup(embedding, decoder_inputs),
    sequence_length=decoder_inputs_length)
# Build the attention mechanism over the encoder outputs
attention_mechanism = BahdanauMonotonicAttention(
    num_units=hidden_size,
    memory=enc_outputs,
    memory_sequence_length=enc_inputs_length)
# Wrap the decoder cell with the attention mechanism
decoder_cell = AttentionWrapper(LSTMCell(hidden_size), attention_mechanism)
# Initialize the wrapped cell's state from the encoder's final state
batch_size = tf.shape(enc_inputs)[0]
initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
    cell_state=enc_state)
# Build the decoder
decoder = BasicDecoder(
    cell=decoder_cell,
    helper=helper,
    initial_state=initial_state,
    output_layer=tf.layers.Dense(vocab_size)
)
# Run the decoder with dynamic_decode (it returns outputs, final state, and sequence lengths)
outputs, _, _ = dynamic_decode(
    decoder=decoder,
    maximum_iterations=tf.reduce_max(decoder_inputs_length),
    impute_finished=True,
    swap_memory=True
)
predicted_ids = outputs.sample_id
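For training, the logits in outputs.rnn_output are typically passed to tf.contrib.seq2seq.sequence_loss, which averages per-token cross-entropy while masking out padding positions. A NumPy sketch of that masked loss, with a hypothetical tiny batch (2 sequences, max length 3, vocab 5):

```python
import numpy as np

# Hypothetical logits and targets; the second sequence has two padding steps.
logits = np.random.default_rng(2).standard_normal((2, 3, 5))
targets = np.array([[1, 2, 3], [4, 0, 0]])
lengths = np.array([3, 1])

# weights: 1.0 on real tokens, 0.0 on padding (like tf.sequence_mask)
steps = np.arange(logits.shape[1])
weights = (steps[None, :] < lengths[:, None]).astype(float)

# token-level cross-entropy from the logits
log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
xent = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]

# average over non-padding positions only
loss = (xent * weights).sum() / weights.sum()
```

Without the mask, padding steps would dilute the loss and bias training toward short sequences.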