kensho-technologies / bubs

Keras Implementation of Flair's Contextualized Embeddings

On forward and backward index input #21

Closed pinesnow72 closed 3 years ago

pinesnow72 commented 3 years ago

I understand that the forward and backward index inputs are used to gather token embeddings from the char-LSTM sequence output. This is done by computing, for each token, a (sentence id in batch, forward/backward character index) pair, so each index input has shape (batch_size, MAX_TOKEN_SEQUENCE_LEN, 2), and then calling tf.gather_nd() in embedding_layer.py. However, tf.gather_nd() also accepts a batch_dims argument (the number of batch dimensions; the default is 0). With batch_dims=1 the sentence id is unnecessary: token embeddings can be gathered from the forward/backward character indices alone, and the index inputs shrink to shape (batch_size, MAX_TOKEN_SEQUENCE_LEN, 1). This would be better especially when using model.fit() directly (rather than a generator).
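For reference, here is a minimal standalone sketch (toy shapes and values, not bubs code) showing that gathering with explicit (sentence_id, char_index) pairs and the default batch_dims=0 produces the same result as gathering with the character indices alone and batch_dims=1:

import numpy as np
import tensorflow as tf

batch_size, max_char_seq_len, embed_dim = 2, 5, 3
embeddings = tf.random.uniform((batch_size, max_char_seq_len, embed_dim))

# Current style: last dimension is 2, holding (sentence_id_in_batch, char_index) pairs.
paired_indices = tf.constant([[[0, 1], [0, 4]],
                              [[1, 2], [1, 3]]], dtype=tf.int32)

# Proposed style: last dimension is 1, holding only char_index; the batch dimension
# is aligned automatically by batch_dims=1.
plain_indices = tf.constant([[[1], [4]],
                             [[2], [3]]], dtype=tf.int32)

gathered_pairs = tf.gather_nd(embeddings, paired_indices)               # (2, 2, 3)
gathered_plain = tf.gather_nd(embeddings, plain_indices, batch_dims=1)  # (2, 2, 3)
print(np.allclose(gathered_pairs.numpy(), gathered_plain.numpy()))      # True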

So, I changed the code related to them as follows:

    def _prepare_index_array(self, index_list):
        """Make a (batch_size, max_token_sequence_len, 1) array of padded character-level
        token-end indices."""
        pad_len = self.max_token_sequence_len
        batch_size = len(index_list)
        padding_index = 0
        # padded_sentences = np.full((batch_size, pad_len, 2), padding_index, dtype=np.int32)
        padded_sentences = np.full((batch_size, pad_len, 1), padding_index, dtype=np.int32)
        for i in range(batch_size):
            clipped_len = min(len(index_list[i]), pad_len)
            # padded_sentences[i, :, 0] = i
            if self.prepad:
                # padded_sentences[i, pad_len - clipped_len:, 1] = index_list[i][:clipped_len]
                padded_sentences[i, pad_len - clipped_len:, 0] = index_list[i][:clipped_len]
            else:
                # padded_sentences[i, :clipped_len, 1] = index_list[i][:clipped_len]
                padded_sentences[i, :clipped_len, 0] = index_list[i][:clipped_len]
        return padded_sentences
def batch_indexing(inputs):
    """Index a character-level embedding matrix at token end locations.

    Args:
        inputs: a list of two tensors:
            tensor1: tensor of (batch_size, max_char_seq_len, char_embed_dim*2) of all char-level
                embeddings
        tensor2: tensor of (batch_size, max_token_seq_len, 1) of indices of token ends.
            Something like [[[1], [5]], [[2], [3]], ...]. The last dimension is 1 because
            only the character index of each token end is needed; the in-batch sentence
            index is handled by batch_dims=1 in tf.gather_nd()

    Returns:
        A tensor of (batch_size, max_token_seq_len, char_embed_dim*2) of char-level embeddings
            at ends of tokens
    """
    embeddings, indices = inputs
    # this will break on deserialization if we simply import tensorflow
    # we have to use keras.backend.tf instead of tensorflow
    # return tf.gather_nd(embeddings, indices)
    return tf.gather_nd(embeddings, indices, batch_dims=1)
forward_index_input = Input(
    # batch_shape=(None, MAX_TOKEN_SEQUENCE_LEN, 2), name="forward_index_input", dtype="int32"
    batch_shape=(None, MAX_TOKEN_SEQUENCE_LEN, 1), name="forward_index_input", dtype="int32"
)
backward_index_input = Input(
    # batch_shape=(None, MAX_TOKEN_SEQUENCE_LEN, 2), name="backward_index_input", dtype="int32"
    batch_shape=(None, MAX_TOKEN_SEQUENCE_LEN, 1), name="backward_index_input", dtype="int32"
)
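To illustrate why this plays nicely with model.fit(), here is a minimal self-contained sketch (toy constants and wiring, not the actual bubs model; MAX_CHAR_SEQUENCE_LEN and CHAR_EMBED_DIM are made-up values) that feeds the (..., 1)-shaped index input through the modified batch_indexing via a Lambda layer and runs it on plain numpy arrays:

import numpy as np
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras.models import Model

MAX_TOKEN_SEQUENCE_LEN = 4  # toy value for the sketch
MAX_CHAR_SEQUENCE_LEN = 10  # hypothetical char-sequence length
CHAR_EMBED_DIM = 3          # hypothetical per-direction embedding size

# char_embeddings stands in for the char-LSTM output; batch_indexing is the
# modified function shown above.
char_embeddings = Input(shape=(MAX_CHAR_SEQUENCE_LEN, CHAR_EMBED_DIM * 2), name="char_embeddings")
forward_index = Input(shape=(MAX_TOKEN_SEQUENCE_LEN, 1), name="forward_index_input", dtype="int32")
token_embeddings = Lambda(batch_indexing)([char_embeddings, forward_index])
model = Model(inputs=[char_embeddings, forward_index], outputs=token_embeddings)

# Plain numpy arrays of matching shapes can be passed straight to fit()/predict(),
# with no need to encode the in-batch sentence id into the index array.
fake_embeddings = np.random.rand(2, MAX_CHAR_SEQUENCE_LEN, CHAR_EMBED_DIM * 2).astype("float32")
fake_indices = np.random.randint(
    0, MAX_CHAR_SEQUENCE_LEN, size=(2, MAX_TOKEN_SEQUENCE_LEN, 1)
).astype("int32")
print(model.predict([fake_embeddings, fake_indices]).shape)  # (2, 4, 6)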
ydovzhenko commented 3 years ago

Sounds good to me! If you open a PR, I'll happily review it :-)