Reading the Match-LSTM Code for QA Systems


Background

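Among QA models, Match-LSTM was proposed relatively early, and it uses the boundary model of Ptr-Net (Pointer Network). This post summarizes what I learned from reading its implementation. The approach is to take the key structures of the model described in the paper and, for each one, look at how the code actually implements it. The reference code is the implementation by MurtyShikhar.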

Model Overview

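The model's input is (Passage, Question) and its output is (start_idx, end_idx). On the input side, the Passage is the body text of the QA task and the Question is the question being asked; both are converted into padded id lists before they are fed to the model. On the output side, start_idx is the position in the Passage where the answer starts, and end_idx is the position where it ends.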
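To make the input format concrete, here is a minimal pure-Python sketch of turning a batch of variable-length id lists into padded lists plus their true lengths. The helper pad_batch and the choice of 0 as the padding id are hypothetical, not taken from the repository; the returned lengths play the role of what the code later passes around as self.question_lengths and self.passage_lengths.

```python
from typing import List, Tuple

PAD_ID = 0  # hypothetical: assumes id 0 is reserved for padding


def pad_batch(seqs: List[List[int]], pad_id: int = PAD_ID) -> Tuple[List[List[int]], List[int]]:
    """Pad every id list to the longest length in the batch; also return the true lengths."""
    lengths = [len(s) for s in seqs]
    max_len = max(lengths) if lengths else 0
    padded = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    return padded, lengths


passages = [[5, 8, 13, 21], [3, 4]]
padded, lengths = pad_batch(passages)
print(padded)   # [[5, 8, 13, 21], [3, 4, 0, 0]]
print(lengths)  # [4, 2]
```

The real code may pad to a fixed maximum length rather than the batch maximum; either way, the lengths are what lets later layers ignore the padding.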

The Match-LSTM model for QA consists of three main layers:

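  1. LSTM preprocessing layer.
    The Passage and the Question are each run through an LSTM, so that the representation at every position carries some contextual information.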
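  2. Match-LSTM layer.
    Match-LSTM was first used for textual entailment: given a premise and a hypothesis, decide whether the premise entails the hypothesis. In the QA setting, the Question is treated as the premise and the Passage as the hypothesis. The layer walks through the Passage position by position, computes the attention of each Passage position over the Question, and from that an attend vector over the Question. This attend vector is concatenated with the output of the first layer and fed into an LSTM; this whole procedure is what is called Match-LSTM.
    The attention is BahdanauAttention. Its input (the query) is the concatenation of the Match-LSTM output at the previous step and the Passage representation at the current position; its keys and its values are both the representations of the Question positions. The attend vector is the sum of the attention values weighted by the attention alignments.
    So Match-LSTM is essentially an LSTM cell plus an attention unit. The output of the LSTM cell is the output of the Match-LSTM layer. The LSTM cell's state, concatenated with the input at the next position, forms the input (query) of the attention unit at that position, and the attention unit's output (the attend vector), concatenated with the input at the current position, forms the input of the LSTM cell. Equivalently, it is an LSTM with attention added on top that changes the LSTM's input: the original input is augmented with the current position's attention over the Question.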
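  3. Pointer-Net layer.
    In the code, the Pointer-Net layer is very close to Match-LSTM and differs only in a few places involving inputs and outputs. Conceptually, it is also a sequential, iterated attention process. First, zero_state is used as the query to compute attention over all outputs of the Match-LSTM layer, which gives the logits for the first answer token. Then the hidden state of the LSTM inside the AttentionWrapper becomes the query for the next step, again attending over all Match-LSTM outputs, and so on. The boundary model only has to compute start_index and end_index, so this iteration runs only twice.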

The following sections walk through the implementation of the model's key structures, side by side with the paper.

Model

The entry point for building the model graph is the def setup_system(self) method of class QASystem in qa_model.py. This section mainly unpacks the details of that method.

LSTM Preprocessing Layer

All of the logic is in qa_model.py. The entry point is the def setup_system(self) method of class QASystem, and the actual logic is in the def encode(self, inputs, masks, encoder_state_input = None) method of class Encoder.

Inside def setup_system(self), the following statements call the def encode(self, inputs, masks, encoder_state_input = None) method of class Encoder.

encoder = self.encoder
decoder = self.decoder
encoded_question, encoded_passage, q_rep, p_rep = encoder.encode([self.question, self.passage],
                 [self.question_lengths, self.passage_lengths], encoder_state_input = None)

Now let's look at the implementation of the encode method.

def encode(self, inputs, masks, encoder_state_input = None):
    """
    :param inputs: vector representations of question and passage (a tuple) 
    :param masks: masking sequences for both question and passage (a tuple)
    :param encoder_state_input: (Optional) pass this as initial hidden state to tf.nn.dynamic_rnn to build conditional representations
    :return: an encoded representation of the question and passage.
    """
    
    question, passage = inputs
    masks_question, masks_passage = masks

    # read passage conditioned upon the question
    with tf.variable_scope("encoded_question"):
        lstm_cell_question = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
        encoded_question, (q_rep, _) = tf.nn.dynamic_rnn(lstm_cell_question, question, masks_question, dtype=tf.float32) # (-1, Q, H)

    with tf.variable_scope("encoded_passage"):
        lstm_cell_passage  = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
        encoded_passage, (p_rep, _) =  tf.nn.dynamic_rnn(lstm_cell_passage, passage, masks_passage, dtype=tf.float32) # (-1, P, H)
    # outputs beyond sequence lengths are masked with 0s
    return encoded_question, encoded_passage , q_rep, p_rep

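The code shows that preprocessing simply passes the Passage and the Question through two separate unidirectional LSTM layers (they do not share parameters, since each lives in its own variable scope), and the LSTM output at every position serves as the preprocessed representation.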
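The comment at the end of encode says that outputs beyond the sequence lengths are masked with 0s. That is how tf.nn.dynamic_rnn treats its sequence_length argument (here passed positionally as masks_question / masks_passage): past each sequence's length the output is zero and the state stops being updated. Below is a minimal pure-Python sketch of those semantics; the toy running_sum step is hypothetical and only stands in for the LSTM cell.

```python
from typing import Callable, List, Tuple

Vector = List[float]
Step = Callable[[Vector, Vector], Tuple[Vector, Vector]]


def run_rnn(step: Step, batch: List[List[Vector]], lengths: List[int],
            init_state: Vector) -> Tuple[List[List[Vector]], List[Vector]]:
    """Unroll `step` over a padded batch, honouring per-sequence lengths.

    Outputs at t >= length are zero vectors, and the returned final state is the
    state after the last valid step (it is simply not updated afterwards).
    Assumes the output size equals the state size, as for the toy step below.
    """
    all_outputs, final_states = [], []
    for seq, length in zip(batch, lengths):
        state = list(init_state)
        outputs = []
        for t, x in enumerate(seq):
            if t < length:
                out, state = step(x, state)
            else:
                out = [0.0] * len(init_state)  # padding step: zero output, state frozen
            outputs.append(out)
        all_outputs.append(outputs)
        final_states.append(state)
    return all_outputs, final_states


def running_sum(x: Vector, state: Vector) -> Tuple[Vector, Vector]:
    """Hypothetical toy cell: the new state (and the output) is state + x."""
    new_state = [s + xi for s, xi in zip(state, x)]
    return new_state, new_state


batch = [[[1.0], [2.0], [3.0]],
         [[4.0], [5.0], [0.0]]]  # the second sequence is padded after 2 steps
outputs, finals = run_rnn(running_sum, batch, [3, 2], [0.0])
print(outputs)  # [[[1.0], [3.0], [6.0]], [[4.0], [9.0], [0.0]]]
print(finals)   # [[6.0], [9.0]]
```

Note that the two dynamic_rnn calls in run_match_lstm pass no sequence length at all, so there the padding positions are processed like ordinary steps; the code's own comment says that masking is left to the pointer layer.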

Match-LSTM Layer

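The Match-LSTM logic lives mainly in two files, qa_model.py and attention_wrapper.py. Although TensorFlow's contrib library now also has an attention_wrapper module, the two differ in their implementation details. The entry point is the decode method of class Decoder in qa_model.py.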

First, the outermost entry point. As with the LSTM preprocessing layer, it is in the def setup_system(self) method of class QASystem.

if self.config.use_match:
    self.logger.info("\n========Using Match LSTM=========\n")
    logits= decoder.decode([encoded_question, encoded_passage], q_rep, [self.question_lengths, self.passage_lengths], self.labels)

Next, the decode method of class Decoder. Its logic is very clear: first the Match-LSTM layer, then the Ptr-Net layer.

def decode(self, encoded_rep, q_rep, masks, labels):
    output_attender = self.run_match_lstm(encoded_rep, masks)
    logits = self.run_answer_ptr(output_attender, masks, labels)

    return logits

Then the run_match_lstm method.

def run_match_lstm(self, encoded_rep, masks):
    encoded_question, encoded_passage = encoded_rep
    masks_question, masks_passage = masks

    match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)
    query_depth = encoded_question.get_shape()[-1]


    # output attention is false because we want to output the cell output and not the attention values
    with tf.variable_scope("match_lstm_attender"):
        attention_mechanism_match_lstm = BahdanauAttention(query_depth, encoded_question, memory_sequence_length = masks_question)
        cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
        lstm_attender  = AttentionWrapper(cell, attention_mechanism_match_lstm, output_attention = False, attention_input_fn = match_lstm_cell_attention_fn)

        # we don't mask the passage because masking the memories will be handled by the pointerNet
        reverse_encoded_passage = _reverse(encoded_passage, masks_passage, 1, 0)

        output_attender_fw, _ = tf.nn.dynamic_rnn(lstm_attender, encoded_passage, dtype=tf.float32, scope ="rnn")
        output_attender_bw, _ = tf.nn.dynamic_rnn(lstm_attender, reverse_encoded_passage, dtype=tf.float32, scope = "rnn")

        output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)


    output_attender = tf.concat([output_attender_fw, output_attender_bw], axis = -1) # (-1, P, 2*H)
    return output_attender

The method's input encoded_rep is a tuple holding the Question and Passage representations; masks is also a tuple, holding the Question and Passage lengths.

match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)

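This statement defines the input function of the AttentionMechanism inside the Match-LSTM cell; it is passed to the AttentionWrapper constructor as the attention_input_fn argument. AttentionWrapper is itself an RNN cell: it combines an RNN cell and an AttentionMechanism into a higher-level RNN cell. The function defines how the attention query is produced: the input at the current step is concatenated with the state from the previous step.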

attention_mechanism_match_lstm = BahdanauAttention(query_depth, encoded_question, memory_sequence_length = masks_question)

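This statement defines an AttentionMechanism, i.e. an attention unit. The class has a __call__ method, and calling the object computes the alignments; its signature is def __call__(self, query, previous_alignments). Read together with the statement above, this query is exactly the attention query described there. (The version in this repository returns raw_scores alongside the alignments, as the call method further down shows.)
How BahdanauAttention is implemented is not covered in detail here; a class of that name now lives at tf.contrib.seq2seq.BahdanauAttention and is part of the TensorFlow library.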
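Since the internals are skipped here, a rough sketch may still help. Below is a minimal pure-Python version of additive (Bahdanau-style) attention with length masking: the query and each key are projected to a common size, added, passed through tanh, and reduced by a vector v to one score per position; padded positions get -inf, so their alignment is exactly 0. The weight names w_q, w_k, v are hypothetical, and this is not the repository's or TensorFlow's exact code (for instance, TensorFlow creates the projection layers inside the class). The sketch returns both alignments and raw_scores to mirror what the repository's call method receives.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major: m[i] is row i


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def softmax(xs: Vector) -> Vector:
    m = max(xs)  # assumes at least one finite score (length >= 1)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def bahdanau_attention(query: Vector, keys: List[Vector], length: int,
                       w_q: Matrix, w_k: Matrix, v: Vector) -> Tuple[Vector, Vector]:
    """Additive attention over one example's memory.

    raw_scores[j] = v . tanh(W_k k_j + W_q q) for valid positions, -inf for padding;
    alignments = softmax(raw_scores), so padded positions get weight 0.
    """
    q_proj = matvec(w_q, query)
    raw_scores = []
    for j, key in enumerate(keys):
        if j >= length:
            raw_scores.append(float("-inf"))
            continue
        k_proj = matvec(w_k, key)
        raw_scores.append(sum(vi * math.tanh(a + b)
                              for vi, a, b in zip(v, k_proj, q_proj)))
    return softmax(raw_scores), raw_scores


# Query of size 4 (e.g. [passage_t; h_prev]), keys of size 2, 2 attention units.
query = [1.0, 0.0, 0.0, 0.0]
keys = [[0.5, 0.0], [0.0, 0.5], [9.0, 9.0]]  # the last key is padding
w_q = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
w_k = [[1.0, 0.0], [0.0, 1.0]]
v = [1.0, 1.0]
alignments, raw_scores = bahdanau_attention(query, keys, 2, w_q, w_k, v)
print(alignments[2], raw_scores[2])  # 0.0 -inf
```

Because the query and the keys are projected separately, the query may have a different size from the keys, which is what lets Match-LSTM use a concatenated query.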

cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)

This statement defines an ordinary LSTM cell.

lstm_attender  = AttentionWrapper(cell, attention_mechanism_match_lstm, output_attention = False, attention_input_fn = match_lstm_cell_attention_fn)

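This statement assembles the AttentionMechanism and the LSTM cell from the previous two steps into a higher-level RNN cell. Its arguments also include the function defined at the very start of run_match_lstm, which produces the query for the AttentionMechanism.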

reverse_encoded_passage = _reverse(encoded_passage, masks_passage, 1, 0)
output_attender_fw, _ = tf.nn.dynamic_rnn(lstm_attender, encoded_passage, dtype=tf.float32, scope ="rnn")
output_attender_bw, _ = tf.nn.dynamic_rnn(lstm_attender, reverse_encoded_passage, dtype=tf.float32, scope = "rnn")
output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)

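Match-LSTM is applied to the Passage representation once forward and once in reverse. The backward output is reversed back so that position t of both outputs refers to the same Passage token, and the two are concatenated along the last dimension to form the output of the Match-LSTM layer.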
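The helper _reverse(encoded_passage, masks_passage, 1, 0) is not shown in this post. Judging from its arguments (sequence lengths, sequence axis 1, batch axis 0), it presumably behaves like tf.reverse_sequence, reversing only the valid prefix of each sequence so the padding stays at the end; that is an assumption. A minimal pure-Python sketch of the resulting bidirectional scheme, with a hypothetical prefix-sum "RNN" standing in for the Match-LSTM cell:

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")
Vector = List[float]


def reverse_by_length(seq: List[T], length: int) -> List[T]:
    """Reverse the first `length` steps only; padding keeps its place at the end."""
    return seq[:length][::-1] + seq[length:]


def prefix_sum_rnn(seq: List[Vector]) -> List[Vector]:
    """Hypothetical toy 'RNN' on 1-d vectors: output at t is the sum of inputs 0..t."""
    outputs, acc = [], 0.0
    for x in seq:
        acc += x[0]
        outputs.append([acc])
    return outputs


def bidirectional(rnn: Callable[[List[Vector]], List[Vector]],
                  seq: List[Vector], length: int) -> List[Vector]:
    fw = rnn(seq)
    # run over the reversed valid prefix, then reverse the outputs back so that
    # position t of fw and bw refer to the same token
    bw = reverse_by_length(rnn(reverse_by_length(seq, length)), length)
    return [f + b for f, b in zip(fw, bw)]  # concat on the last axis -> twice the size


passage = [[1.0], [2.0], [3.0], [0.0]]  # 3 real tokens + 1 padding step
print(bidirectional(prefix_sum_rnn, passage, 3))
# [[1.0, 6.0], [3.0, 5.0], [6.0, 3.0], [6.0, 6.0]]
```

At position 0 the forward half has seen only the first token while the backward half has seen all three, which is why the concatenated (-1, P, 2*H) output carries context from both sides.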

We can also take a closer look at how the LSTM cell and the AttentionMechanism work together. That means digging into AttentionWrapper's call method, which every RNN cell has to implement.

def call(self, inputs, state):
    output_prev_step = state.cell_state.h # get hr_(i-1)
    attention_input = self._attention_input_fn(inputs, output_prev_step) # get input to BahdanauAttention to get alpha_i
    alignments, raw_scores = self._attention_mechanism(
        attention_input, previous_alignments=state.alignments)

    expanded_alignments = array_ops.expand_dims(alignments, 1)

    attention_mechanism_values = self._attention_mechanism.values
    context = math_ops.matmul(expanded_alignments, attention_mechanism_values)
    context = array_ops.squeeze(context, [1])


    cell_inputs = self._cell_input_fn(inputs, context) #concatenate input with alpha*memory and feed into root LSTM
    cell_state = state.cell_state
    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

    if self._attention_layer is not None:
      attention = self._attention_layer(
          array_ops.concat([cell_output, context], 1))
    else:
      attention = context

    if self._alignment_history:
      alignment_history = state.alignment_history.write(
          state.time, alignments)
    else:
      alignment_history = ()

    next_state = AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        alignments=alignments,
        alignment_history=alignment_history)

    if self._output_attention:
      return raw_scores, next_state
    else:
      return cell_output, next_state

output_prev_step = state.cell_state.h # get hr_(i-1)
attention_input = self._attention_input_fn(inputs, output_prev_step)

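The LSTM cell's state from the previous step (its h) and the AttentionWrapper's input at the current step are passed through self._attention_input_fn to produce the attention query. Here self._attention_input_fn is the attention_input_fn argument of the AttentionWrapper constructor above.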

alignments, raw_scores = self._attention_mechanism(attention_input, previous_alignments=state.alignments)

The AttentionMechanism object is called to compute the attention alignments. Here self._attention_mechanism is the attention_mechanism_match_lstm argument of the AttentionWrapper constructor, i.e. a BahdanauAttention object.

expanded_alignments = array_ops.expand_dims(alignments, 1)       # [batch_size, 1, ques_size]
attention_mechanism_values = self._attention_mechanism.values   # [batch_size, ques_size, value_dims]
context = math_ops.matmul(expanded_alignments, attention_mechanism_values) # [batch_size, 1, value_dims]
context = array_ops.squeeze(context, [1])   # [batch_size, value_dims]

The attend vector is computed from the alignments and the attention values: it is the sum of the values weighted by the alignments.
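The expand_dims / matmul / squeeze sequence is just a batched weighted sum. Here is a minimal pure-Python equivalent using the same shapes as the comments above (alignments [batch_size, ques_size], values [batch_size, ques_size, value_dims]); the function name attend is hypothetical.

```python
from typing import List


def attend(alignments: List[List[float]],
           values: List[List[List[float]]]) -> List[List[float]]:
    """context[b][d] = sum_j alignments[b][j] * values[b][j][d].

    Same result as expand_dims(alignments, 1) -> [batch, 1, ques],
    matmul with values [batch, ques, dims] -> [batch, 1, dims], squeeze -> [batch, dims].
    """
    contexts = []
    for a, vals in zip(alignments, values):
        dims = len(vals[0])
        contexts.append([sum(a_j * v_j[d] for a_j, v_j in zip(a, vals))
                         for d in range(dims)])
    return contexts


alignments = [[0.25, 0.75]]           # batch of 1, question of length 2
values = [[[1.0, 0.0], [0.0, 4.0]]]   # value_dims = 2
print(attend(alignments, values))     # [[0.25, 3.0]]
```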

cell_inputs = self._cell_input_fn(inputs, context) #concatenate input with alpha*memory and feed into root LSTM
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

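_cell_input_fn combines the input at the current step with the attend vector to form the LSTM input for this step. The LSTM cell is then called to compute its output and state for the current step.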
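In the Match-LSTM layer both input functions concatenate: the query is [current input; previous h], and the LSTM input is [current input; attend vector]. One consequence worth checking when comparing with the paper: a linear projection of the concatenated query equals the sum of separate projections of its two parts, which is how I read the paper's attention term W^p h^p_i + W^r h^r_{i-1}. A small pure-Python check with hypothetical toy weights; the two Python functions only mirror the TensorFlow lambdas and are not the repository's code.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def match_lstm_attention_input(curr_input: Vector, prev_h: Vector) -> Vector:
    # mirrors match_lstm_cell_attention_fn: concat([curr_input, state], axis=-1)
    return curr_input + prev_h


def default_cell_input(inputs: Vector, context: Vector) -> Vector:
    # mirrors the default cell_input_fn: concat([inputs, attention], -1)
    return inputs + context


# Hypothetical toy weights (hidden size 2, attention units 2).
w_p = [[1.0, 2.0], [0.0, 1.0]]    # applied to the passage representation h^p_i
w_r = [[0.5, 0.0], [1.0, -1.0]]   # applied to the previous Match-LSTM output h^r_{i-1}
w_query = [rp + rr for rp, rr in zip(w_p, w_r)]  # [W^p | W^r], shape 2 x 4

h_p = [1.0, 2.0]
h_r_prev = [4.0, 1.0]
context = [0.3, 0.7]  # attend vector over the question

query = match_lstm_attention_input(h_p, h_r_prev)  # length 4
projected = matvec(w_query, query)
split = [a + b for a, b in zip(matvec(w_p, h_p), matvec(w_r, h_r_prev))]
print(projected, split)                           # [7.0, 5.0] [7.0, 5.0]
print(len(default_cell_input(h_p, context)))      # 4
```

So concatenating first and projecting once is the same computation as the paper's two separate projections; the concatenated form just fits the generic attention_input_fn hook.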

if self._attention_layer is not None:
    attention = self._attention_layer(
        array_ops.concat([cell_output, context], 1))
else:
    attention = context

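This decides whether cell_output and the attend vector are concatenated and passed through one more linear layer (self._attention_layer) to produce attention. In this case no such layer is used, so the attend vector itself serves as attention.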

next_state = AttentionWrapperState(
      time=state.time + 1,
      cell_state=next_cell_state,
      attention=attention,
      alignments=alignments,
      alignment_history=alignment_history)

This is the state of the AttentionWrapper, viewed as an RNN, for the next step.

if self._output_attention:
    return raw_scores, next_state
else:
    return cell_output, next_state

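Depending on the constructor argument, the AttentionWrapper outputs either the attention scores or the LSTM output. The attention scores (raw_scores) are the values before they are normalized into the alignment probabilities.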

Pointer-Net Layer

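The code below is the logic of the Pointer-Net layer. It is very close to the Match-LSTM layer but differs in a few details. What is similar: the core of the Pointer-Net layer is again an AttentionWrapper that combines an LSTM cell with a BahdanauAttention unit. What differs: the input functions of the LSTM cell and of the BahdanauAttention unit, what the AttentionWrapper outputs, and the fact that the Pointer-Net layer runs it with a static RNN (tf.nn.static_rnn).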

def run_answer_ptr(self, output_attender, masks, labels):
    batch_size = tf.shape(output_attender)[0]
    masks_question, masks_passage = masks
    labels = tf.unstack(labels, axis=1) 
    #labels = tf.ones([batch_size, 2, 1])


    answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question
    query_depth_answer_ptr = output_attender.get_shape()[-1]

    with tf.variable_scope("answer_ptr_attender"):
        attention_mechanism_answer_ptr = BahdanauAttention(query_depth_answer_ptr , output_attender, memory_sequence_length = masks_passage)
        # output attention is true because we want to output the attention values
        cell_answer_ptr = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True )
        answer_ptr_attender = AttentionWrapper(cell_answer_ptr, attention_mechanism_answer_ptr, cell_input_fn = answer_ptr_cell_input_fn)
        logits, _ = tf.nn.static_rnn(answer_ptr_attender, labels, dtype = tf.float32)

        return logits 

Now let's go through this code piece by piece.

batch_size = tf.shape(output_attender)[0]       # output_attender: [batch_size, passage_length, 2 * hidden_size]
masks_question, masks_passage = masks
labels = tf.unstack(labels, axis=1)     # labels : [batch_size, 2]

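output_attender is the output of the previous layer, i.e. the Match-LSTM layer, with shape [batch_size, passage_length, 2 * hidden_size]. labels has shape [batch_size, 2]. masks_question and masks_passage are the question lengths and the passage lengths, respectively.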

answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question
query_depth_answer_ptr = output_attender.get_shape()[-1]

answer_ptr_cell_input_fn defines the input function of the LSTM cell inside the AttentionWrapper. Judging by its name, query_depth_answer_ptr is the query depth of the attention unit in the Answer-Ptr layer.

with tf.variable_scope("answer_ptr_attender"):
   attention_mechanism_answer_ptr = BahdanauAttention(query_depth_answer_ptr , output_attender, memory_sequence_length = masks_passage)
   # output attention is true because we want to output the attention values
   cell_answer_ptr = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True )
   answer_ptr_attender = AttentionWrapper(cell_answer_ptr, attention_mechanism_answer_ptr, cell_input_fn = answer_ptr_cell_input_fn)

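Next the AttentionWrapper is assembled, and here it differs from the Match-LSTM layer. The Match-LSTM layer did not pass a cell_input_fn to AttentionWrapper explicitly and used the default, whereas here one is given explicitly. Conversely, the Match-LSTM layer explicitly specified attention_input_fn, whereas here it is omitted and the default is used. One more difference: in the Match-LSTM layer the AttentionWrapper's output_attention argument is False, whereas here it keeps its default value, True.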

Comparing cell_input_fn in the Match-LSTM layer and the Pointer-Net layer:

The default cell_input_fn, which the Match-LSTM layer uses, is defined below. It concatenates the current input with the attention output to form the LSTM cell's input.

if cell_input_fn is None:
   cell_input_fn = ( 
       lambda inputs, attention: array_ops.concat([inputs, attention], -1))

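The cell_input_fn used by the Pointer-Net layer was already shown in the code above; here it is again for comparison. It uses only the attention unit's output as the LSTM cell's input, so the LSTM cell's input no longer depends on the RNN's input.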

answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question

Comparing attention_input_fn in the Match-LSTM layer and the Pointer-Net layer:

The Match-LSTM layer uses a non-default attention_input_fn, shown in the previous section; here it is again for comparison.

match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)

The Pointer-Net layer uses the default attention_input_fn, defined as follows.

if attention_input_fn is None:
   attention_input_fn = ( 
       lambda _, state: state)

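So in the Match-LSTM layer the attention unit's input is the concatenation of the current input and the previous state, while in the Pointer-Net layer it is only the previous state and does not depend on the current input.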

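Taken together, the two comparisons show the difference. In the Match-LSTM layer, the inputs of both the attention unit and the LSTM cell include the input at the current step. In the Pointer-Net layer, the inputs of both are independent of the input at the current step. This also resolved a question I had when I first read the code: why does the function that computes the logits take labels as an argument, when labels should only be needed to compute the loss? Although there is a labels parameter, its contents are never actually used; after tf.unstack it only determines how many steps the static RNN runs (two, one per boundary). For prediction, passing any tensor of the same shape is enough.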

Finally, the last difference: the output_attention argument in the Match-LSTM layer versus the Pointer-Net layer.

if self._output_attention:
    return raw_scores, next_state
else:
    return cell_output, next_state

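raw_scores is the raw output of the attention unit, i.e. the scores before the softmax that turns them into alignments. cell_output is the output of the LSTM cell, i.e. its state h. In the Match-LSTM layer, the AttentionWrapper outputs the output of its inner LSTM cell; in the Pointer-Net layer, it outputs the raw_scores of its inner attention unit.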

logits, _ = tf.nn.static_rnn(answer_ptr_attender, labels, dtype = tf.float32)

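The last step computes the logits. Because labels is a list of length 2, logits is also a list of length 2. However, the elements of the two lists have different shapes: with labels of shape [batch_size, 2], each element produced by tf.unstack(labels, axis=1) has shape [batch_size], while each element of logits has shape [batch_size, passage_length].
In terms of the code: first, zero_state is used as the query to compute attention, where the keys and values of the attention unit both come from the output of the Match-LSTM layer; the raw_scores of this attention are the first logits. The attend vector computed from these alignments and the values is the input of the LSTM cell, and the LSTM cell's output becomes the query for the attention at the next step. This produces the two logits.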
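The two-step loop just described can be sketched in pure Python for a single example. For illustration only, it assumes a plain tanh RNN cell in place of the LSTM, additive attention with hypothetical weights w_q, w_k, v, w_x, w_h, and -inf scores at padded positions. What it does reproduce from the code is the wiring: the query is the previous hidden state (starting from zeros), the raw scores are emitted as logits, the cell input is the context alone, and the label inputs only decide how many steps run.

```python
import math
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def softmax(xs: Vector) -> Vector:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def additive_scores(query: Vector, memory: List[Vector], length: int,
                    w_q: Matrix, w_k: Matrix, v: Vector) -> Vector:
    """Bahdanau-style raw scores over the memory; padded positions get -inf."""
    q = matvec(w_q, query)
    scores = []
    for j, m_j in enumerate(memory):
        if j >= length:
            scores.append(float("-inf"))
        else:
            k = matvec(w_k, m_j)
            scores.append(sum(vi * math.tanh(a + b) for vi, a, b in zip(v, k, q)))
    return scores


def answer_ptr(memory: List[Vector], length: int, inputs: Sequence,
               w_q: Matrix, w_k: Matrix, v: Vector,
               w_x: Matrix, w_h: Matrix) -> List[Vector]:
    """Boundary pointer: one step per element of `inputs` (stand-in for the unstacked labels).

    Only len(inputs) matters; the contents are never read.
    """
    h = [0.0] * len(w_h)                 # zero_state
    logits = []
    for _ in inputs:
        raw = additive_scores(h, memory, length, w_q, w_k, v)  # query = previous h
        logits.append(raw)               # output_attention=True: emit the raw scores
        align = softmax(raw)
        dims = len(memory[0])
        context = [sum(a * m[d] for a, m in zip(align, memory)) for d in range(dims)]
        # the cell input is the context only (like answer_ptr_cell_input_fn)
        h = [math.tanh(p + r) for p, r in zip(matvec(w_x, context), matvec(w_h, h))]
    return logits


eye = [[1.0, 0.0], [0.0, 1.0]]
memory = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]  # Match-LSTM outputs; last step is padding
logits = answer_ptr(memory, 3, [0, 0], eye, eye, [1.0, 1.0], eye, eye)
same = answer_ptr(memory, 3, [5, 9], eye, eye, [1.0, 1.0], eye, eye)
print(len(logits), logits == same)  # 2 True
```

The second call passes different "labels" and gets identical logits, which is the point made earlier: the labels argument only fixes the number of steps.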

At this point the logits have been computed and the prediction part is complete. logits is a list of length 2, each element a tensor of shape [batch_size, passage_length].

Loss Function

With the logits in hand, the loss can be computed.

losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits[0], labels=self.labels[:,0])
losses += tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits[1], labels=self.labels[:,1])
self.loss = tf.reduce_mean(losses)

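The only function to understand here is tf.nn.sparse_softmax_cross_entropy_with_logits. Its logits argument has rank one higher than labels, and the extra axis has size num_classes. labels are given in sparse form: each element is an integer in the range [0, num_classes).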

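As established above, the logits from the Pointer-Net layer form a list whose elements each have shape [batch_size, passage_length] (so num_classes is passage_length), and the input labels have shape [batch_size, 2]. Calling the function as in the code above, once with the start column and once with the end column of labels, therefore gives the loss.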
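For reference, here is what the two calls compute, as a minimal pure-Python sketch rather than TensorFlow's implementation: per-example cross entropy -log softmax(logits)[label] for the start and for the end position, summed, then averaged over the batch as tf.reduce_mean does. The function names are hypothetical.

```python
import math
from typing import List


def sparse_softmax_cross_entropy(logits: List[List[float]], labels: List[int]) -> List[float]:
    """Per-example loss: logits [batch, num_classes], labels [batch] with ints in [0, num_classes)."""
    losses = []
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_z - row[y])
    return losses


def boundary_loss(logits: List[List[List[float]]], labels: List[List[int]]) -> float:
    """logits: 2 tensors of shape [batch, passage_length]; labels: [batch, 2] as (start, end)."""
    start = sparse_softmax_cross_entropy(logits[0], [l[0] for l in labels])
    end = sparse_softmax_cross_entropy(logits[1], [l[1] for l in labels])
    per_example = [s + e for s, e in zip(start, end)]
    return sum(per_example) / len(per_example)  # like tf.reduce_mean


uniform = [[0.0, 0.0, 0.0, 0.0]]  # batch of 1, passage_length = 4
print(boundary_loss([uniform, uniform], [[1, 2]]))  # 2 * log(4), about 2.7726
```

With uniform logits each boundary contributes log(passage_length), which is a handy sanity check for the initial loss value.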