lsdefine / attention-is-all-you-need-keras

A Keras+TensorFlow Implementation of the Transformer: Attention Is All You Need
708 stars 188 forks source link

seq2seq confused with shape #26

Closed thomasyue closed 5 years ago

thomasyue commented 5 years ago

want to play around with the transformer, but I'm confused with shapes.

print(train[0]) [ 2 4 1 283 51 283 986 6 284 8 226 227 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] train.shape is (1000, 57)


Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 57)           0                                            
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, 57, 300)      865200      input_1[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 57, 300)      90000       embedding_2[0][0]                
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 57, 300)      90000       embedding_2[0][0]                
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 57)           0           input_1[0][0]                    
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, None, None)   0           dense_1[0][0]                    
__________________________________________________________________________________________________
lambda_5 (Lambda)               (None, None, None)   0           dense_2[0][0]                    
__________________________________________________________________________________________________
lambda_7 (Lambda)               (None, 57)           0           lambda_3[0][0]                   
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 57)           0                                            
__________________________________________________________________________________________________
lambda_8 (Lambda)               (None, None, None)   0           lambda_4[0][0]                   
                                                                 lambda_5[0][0]                   
__________________________________________________________________________________________________
lambda_9 (Lambda)               (None, 57)           0           lambda_7[0][0]                   
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 56)           0           input_2[0][0]                    
__________________________________________________________________________________________________
add_1 (Add)                     (None, None, None)   0           lambda_8[0][0]                   
                                                                 lambda_9[0][0]                   
__________________________________________________________________________________________________
embedding_3 (Embedding)         (None, 56, 300)      865200      lambda_1[0][0]                   
__________________________________________________________________________________________________
lambda_12 (Lambda)              (None, 56, 56)       0           lambda_1[0][0]                   
__________________________________________________________________________________________________
lambda_13 (Lambda)              (None, None, None)   0           lambda_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None)   0           add_1[0][0]                      
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 57, 300)      90000       embedding_2[0][0]                
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 56, 300)      90000       embedding_3[0][0]                
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 56, 300)      90000       embedding_3[0][0]                
__________________________________________________________________________________________________
lambda_14 (Lambda)              (None, 56, 56)       0           lambda_12[0][0]                  
                                                                 lambda_13[0][0]                  
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, None, None)   0           activation_1[0][0]               
__________________________________________________________________________________________________
lambda_6 (Lambda)               (None, None, None)   0           dense_3[0][0]                    
__________________________________________________________________________________________________
lambda_16 (Lambda)              (None, None, None)   0           dense_5[0][0]                    
__________________________________________________________________________________________________
lambda_17 (Lambda)              (None, None, None)   0           dense_6[0][0]                    
__________________________________________________________________________________________________
lambda_19 (Lambda)              (None, 56, 56)       0           lambda_14[0][0]                  
__________________________________________________________________________________________________
lambda_10 (Lambda)              (None, None, None)   0           dropout_1[0][0]                  
                                                                 lambda_6[0][0]                   
__________________________________________________________________________________________________
lambda_20 (Lambda)              (None, None, None)   0           lambda_16[0][0]                  
                                                                 lambda_17[0][0]                  
__________________________________________________________________________________________________
lambda_21 (Lambda)              (None, 56, 56)       0           lambda_19[0][0]                  
__________________________________________________________________________________________________
lambda_11 (Lambda)              (None, None, 300)    0           lambda_10[0][0]                  
__________________________________________________________________________________________________
add_4 (Add)                     (None, None, None)   0           lambda_20[0][0]                  
                                                                 lambda_21[0][0]                  
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib (None, None, 300)    90300       lambda_11[0][0]                  
__________________________________________________________________________________________________
activation_2 (Activation)       (None, None, None)   0           add_4[0][0]                      
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 56, 300)      90000       embedding_3[0][0]                
__________________________________________________________________________________________________
dropout_6 (Dropout)             (None, None, 300)    0           time_distributed_1[0][0]         
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, None, None)   0           activation_2[0][0]               
__________________________________________________________________________________________________
lambda_18 (Lambda)              (None, None, None)   0           dense_7[0][0]                    
__________________________________________________________________________________________________
add_2 (Add)                     (None, None, 300)    0           embedding_2[0][0]                
                                                                 dropout_6[0][0]                  
__________________________________________________________________________________________________
lambda_22 (Lambda)              (None, None, None)   0           dropout_3[0][0]                  
                                                                 lambda_18[0][0]                  
__________________________________________________________________________________________________
layer_normalization_2 (LayerNor (None, None, 300)    600         add_2[0][0]                      
__________________________________________________________________________________________________
lambda_23 (Lambda)              (None, None, 300)    0           lambda_22[0][0]                  
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, None, 512)    154112      layer_normalization_2[0][0]      
__________________________________________________________________________________________________
time_distributed_2 (TimeDistrib (None, None, 300)    90300       lambda_23[0][0]                  
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, None, 300)    153900      conv1d_1[0][0]                   
__________________________________________________________________________________________________
dropout_7 (Dropout)             (None, None, 300)    0           time_distributed_2[0][0]         
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, None, 300)    0           conv1d_2[0][0]                   
__________________________________________________________________________________________________
add_5 (Add)                     (None, None, 300)    0           embedding_3[0][0]                
                                                                 dropout_7[0][0]                  
__________________________________________________________________________________________________
add_3 (Add)                     (None, None, 300)    0           dropout_2[0][0]                  
                                                                 layer_normalization_2[0][0]      
__________________________________________________________________________________________________
layer_normalization_4 (LayerNor (None, None, 300)    600         add_5[0][0]                      
__________________________________________________________________________________________________
layer_normalization_1 (LayerNor (None, None, 300)    600         add_3[0][0]                      
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, None, 300)    90000       layer_normalization_4[0][0]      
__________________________________________________________________________________________________
dense_10 (Dense)                (None, None, 300)    90000       layer_normalization_1[0][0]      
__________________________________________________________________________________________________
lambda_15 (Lambda)              (None, 56, 57)       0           lambda_1[0][0]                   
                                                                 input_1[0][0]                    
__________________________________________________________________________________________________
lambda_24 (Lambda)              (None, None, None)   0           dense_9[0][0]                    
__________________________________________________________________________________________________
lambda_25 (Lambda)              (None, None, None)   0           dense_10[0][0]                   
__________________________________________________________________________________________________
lambda_27 (Lambda)              (None, 56, 57)       0           lambda_15[0][0]                  
__________________________________________________________________________________________________
lambda_28 (Lambda)              (None, None, None)   0           lambda_24[0][0]                  
                                                                 lambda_25[0][0]                  
__________________________________________________________________________________________________
lambda_29 (Lambda)              (None, 56, 57)       0           lambda_27[0][0]                  
__________________________________________________________________________________________________
add_6 (Add)                     (None, None, None)   0           lambda_28[0][0]                  
                                                                 lambda_29[0][0]                  
__________________________________________________________________________________________________
activation_3 (Activation)       (None, None, None)   0           add_6[0][0]                      
__________________________________________________________________________________________________
dense_11 (Dense)                (None, None, 300)    90000       layer_normalization_1[0][0]      
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, None, None)   0           activation_3[0][0]               
__________________________________________________________________________________________________
lambda_26 (Lambda)              (None, None, None)   0           dense_11[0][0]                   
__________________________________________________________________________________________________
lambda_30 (Lambda)              (None, None, None)   0           dropout_4[0][0]                  
                                                                 lambda_26[0][0]                  
__________________________________________________________________________________________________
lambda_31 (Lambda)              (None, None, 300)    0           lambda_30[0][0]                  
__________________________________________________________________________________________________
time_distributed_3 (TimeDistrib (None, None, 300)    90300       lambda_31[0][0]                  
__________________________________________________________________________________________________
dropout_8 (Dropout)             (None, None, 300)    0           time_distributed_3[0][0]         
__________________________________________________________________________________________________
add_7 (Add)                     (None, None, 300)    0           layer_normalization_4[0][0]      
                                                                 dropout_8[0][0]                  
__________________________________________________________________________________________________
layer_normalization_5 (LayerNor (None, None, 300)    600         add_7[0][0]                      
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               (None, None, 512)    154112      layer_normalization_5[0][0]      
__________________________________________________________________________________________________
conv1d_4 (Conv1D)               (None, None, 300)    153900      conv1d_3[0][0]                   
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, None, 300)    0           conv1d_4[0][0]                   
__________________________________________________________________________________________________
add_8 (Add)                     (None, None, 300)    0           dropout_5[0][0]                  
                                                                 layer_normalization_5[0][0]      
__________________________________________________________________________________________________
layer_normalization_3 (LayerNor (None, None, 300)    600         add_8[0][0]                      
__________________________________________________________________________________________________
time_distributed_4 (TimeDistrib (None, None, 57)     17100       layer_normalization_3[0][0]      
==================================================================================================
Total params: 3,447,424
Trainable params: 3,447,424
Non-trainable params: 0
__________________________________________________________________________________________________```

I wanna input the train data and output the exact same sentence as input.
how do I do it?
thomasyue commented 5 years ago

code is as follow

max_features = 2885
embed_size = 300
max_len = 57
d_emb = embed_size

import random, os, sys
import numpy as np
from keras.models import *
from keras.layers import *
from keras.callbacks import *
from keras.initializers import *
import tensorflow as tf
from keras.engine.topology import Layer

keras.backend.clear_session()
class LayerNormalization(Layer):
    def __init__(self, eps=1e-6, **kwargs):
        self.eps = eps
        super(LayerNormalization, self).__init__(**kwargs)
    def build(self, input_shape):
        self.gamma = self.add_weight(name='gamma', shape=input_shape[-1:],
                                     initializer=Ones(), trainable=True)
        self.beta = self.add_weight(name='beta', shape=input_shape[-1:],
                                    initializer=Zeros(), trainable=True)
        super(LayerNormalization, self).build(input_shape)
    def call(self, x):
        mean = K.mean(x, axis=-1, keepdims=True)
        std = K.std(x, axis=-1, keepdims=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta
    def compute_output_shape(self, input_shape):
        return input_shape

class ScaledDotProductAttention():
    def __init__(self, attn_dropout=0.1):
        self.dropout = Dropout(attn_dropout)
    def __call__(self, q, k, v, mask):   # mask_k or mask_qk
        temper = tf.sqrt(tf.cast(tf.shape(k)[-1], dtype='float32'))
        attn = Lambda(lambda x:K.batch_dot(x[0],x[1],axes=[2,2])/temper)([q, k])  # shape=(batch, q, k)
        if mask is not None:
            mmask = Lambda(lambda x:(-1e+9)*(1.-K.cast(x, 'float32')))(mask)
            attn = Add()([attn, mmask])
        attn = Activation('softmax')(attn)
        attn = self.dropout(attn)
        output = Lambda(lambda x:K.batch_dot(x[0], x[1]))([attn, v])
        return output, attn

class MultiHeadAttention():
    # mode 0 - big martixes, faster; mode 1 - more clear implementation
    def __init__(self, n_head, d_model, dropout, mode=0):
        self.mode = mode
        self.n_head = n_head
        self.d_k = self.d_v = d_k = d_v = d_model // n_head
        self.dropout = dropout
        if mode == 0:
            self.qs_layer = Dense(n_head*d_k, use_bias=False)
            self.ks_layer = Dense(n_head*d_k, use_bias=False)
            self.vs_layer = Dense(n_head*d_v, use_bias=False)
        elif mode == 1:
            self.qs_layers = []
            self.ks_layers = []
            self.vs_layers = []
            for _ in range(n_head):
                self.qs_layers.append(TimeDistributed(Dense(d_k, use_bias=False)))
                self.ks_layers.append(TimeDistributed(Dense(d_k, use_bias=False)))
                self.vs_layers.append(TimeDistributed(Dense(d_v, use_bias=False)))
        self.attention = ScaledDotProductAttention()
        self.w_o = TimeDistributed(Dense(d_model))

    def __call__(self, q, k, v, mask=None):
        d_k, d_v = self.d_k, self.d_v
        n_head = self.n_head

        if self.mode == 0:
            qs = self.qs_layer(q)  # [batch_size, len_q, n_head*d_k]
            ks = self.ks_layer(k)
            vs = self.vs_layer(v)

            def reshape1(x):
                s = tf.shape(x)   # [batch_size, len_q, n_head * d_k]
                x = tf.reshape(x, [s[0], s[1], n_head, s[2]//n_head])
                x = tf.transpose(x, [2, 0, 1, 3])  
                x = tf.reshape(x, [-1, s[1], s[2]//n_head])  # [n_head * batch_size, len_q, d_k]
                return x
            qs = Lambda(reshape1)(qs)
            ks = Lambda(reshape1)(ks)
            vs = Lambda(reshape1)(vs)

            if mask is not None:
                mask = Lambda(lambda x:K.repeat_elements(x, n_head, 0))(mask)
            head, attn = self.attention(qs, ks, vs, mask=mask)  

            def reshape2(x):
                s = tf.shape(x)   # [n_head * batch_size, len_v, d_v]
                x = tf.reshape(x, [n_head, -1, s[1], s[2]]) 
                x = tf.transpose(x, [1, 2, 0, 3])
                x = tf.reshape(x, [-1, s[1], n_head*d_v])  # [batch_size, len_v, n_head * d_v]
                return x
            head = Lambda(reshape2)(head)
        elif self.mode == 1:
            heads = []; attns = []
            for i in range(n_head):
                qs = self.qs_layers[i](q)   
                ks = self.ks_layers[i](k) 
                vs = self.vs_layers[i](v) 
                head, attn = self.attention(qs, ks, vs, mask)
                heads.append(head); attns.append(attn)
            head = Concatenate()(heads) if n_head > 1 else heads[0]
            attn = Concatenate()(attns) if n_head > 1 else attns[0]

        outputs = self.w_o(head)
        outputs = Dropout(self.dropout)(outputs)
        return outputs, attn

class PositionwiseFeedForward():
    def __init__(self, d_hid, d_inner_hid, dropout=0.1):
        self.w_1 = Conv1D(d_inner_hid, 1, activation='relu')
        self.w_2 = Conv1D(d_hid, 1)
        self.layer_norm = LayerNormalization()
        self.dropout = Dropout(dropout)
    def __call__(self, x):
        output = self.w_1(x) 
        output = self.w_2(output)
        output = self.dropout(output)
        output = Add()([output, x])
        return self.layer_norm(output)

class EncoderLayer():
    def __init__(self, d_model, d_inner_hid, n_head, dropout=0.1):
        self.self_att_layer = MultiHeadAttention(n_head, d_model, dropout=dropout)
        self.pos_ffn_layer  = PositionwiseFeedForward(d_model, d_inner_hid, dropout=dropout)
        self.norm_layer = LayerNormalization()
    def __call__(self, enc_input, mask=None):
        output, slf_attn = self.self_att_layer(enc_input, enc_input, enc_input, mask=mask)
        output = self.norm_layer(Add()([enc_input, output]))
        output = self.pos_ffn_layer(output)
        return output, slf_attn

class DecoderLayer():
    def __init__(self, d_model, d_inner_hid, n_head, dropout=0.1):
        self.self_att_layer = MultiHeadAttention(n_head, d_model, dropout=dropout)
        self.enc_att_layer  = MultiHeadAttention(n_head, d_model, dropout=dropout)
        self.pos_ffn_layer  = PositionwiseFeedForward(d_model, d_inner_hid, dropout=dropout)
        self.norm_layer1 = LayerNormalization()
        self.norm_layer2 = LayerNormalization()
    def __call__(self, dec_input, enc_output, self_mask=None, enc_mask=None, dec_last_state=None):
        if dec_last_state is None: dec_last_state = dec_input
        output, slf_attn = self.self_att_layer(dec_input, dec_last_state, dec_last_state, mask=self_mask)
        x = self.norm_layer1(Add()([dec_input, output]))
        output, enc_attn = self.enc_att_layer(x, enc_output, enc_output, mask=enc_mask)
        x = self.norm_layer2(Add()([x, output]))
        output = self.pos_ffn_layer(x)
        return output, slf_attn, enc_attn

def GetPosEncodingMatrix(max_len, d_emb):
    pos_enc = np.array([
        [pos / np.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)] 
        if pos != 0 else np.zeros(d_emb) 
            for pos in range(max_len)
            ])
    pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2]) # dim 2i
    pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2]) # dim 2i+1
    return pos_enc

def GetPadMask(q, k):
    ones = K.expand_dims(K.ones_like(q, 'float32'), -1)
    mask = K.cast(K.expand_dims(K.not_equal(k, 0), 1), 'float32')
    mask = K.batch_dot(ones, mask, axes=[2,1])
    return mask

def GetSubMask(s):
    len_s = tf.shape(s)[1]
    bs = tf.shape(s)[:1]
    mask = K.cumsum(tf.eye(len_s, batch_shape=bs), 1)
    return mask

class Encoder():
    def __init__(self, d_model, d_inner_hid, n_head, layers=6, dropout=0.1):
        self.layers = [EncoderLayer(d_model, d_inner_hid, n_head, dropout) for _ in range(layers)]
    def __call__(self, src_emb, src_seq, return_att=False, active_layers=999):
        if return_att: atts = []
        mask = Lambda(lambda x:K.cast(K.greater(x, 0), 'float32'))(src_seq)
        x = src_emb     
        for enc_layer in self.layers[:active_layers]:
            x, att = enc_layer(x, mask)
            if return_att: atts.append(att)
        return (x, atts) if return_att else x

class Decoder():
    def __init__(self, d_model, d_inner_hid, n_head, layers=6, dropout=0.1):
        self.layers = [DecoderLayer(d_model, d_inner_hid, n_head, dropout) for _ in range(layers)]

    def __call__(self, tgt_emb, tgt_seq, src_seq, enc_output, return_att=False, active_layers=999):
        x = tgt_emb
        self_pad_mask = Lambda(lambda x:GetPadMask(x, x))(tgt_seq)
        self_sub_mask = Lambda(GetSubMask)(tgt_seq)
        self_mask = Lambda(lambda x:K.minimum(x[0], x[1]))([self_pad_mask, self_sub_mask])
        enc_mask = Lambda(lambda x:GetPadMask(x[0], x[1]))([tgt_seq, src_seq])
        if return_att: self_atts, enc_atts = [], []
        for dec_layer in self.layers[:active_layers]:
            x, self_att, enc_att = dec_layer(x, enc_output, self_mask, enc_mask)
            if return_att: 
                self_atts.append(self_att)
                enc_atts.append(enc_att)
        return (x, self_atts, enc_atts) if return_att else x

class Transformer():
    def __init__(self, len_limit, embedding_matrix, d_model=embed_size, \
              d_inner_hid=512, n_head=4, d_k=64, d_v=64, layers=2, dropout=0.1, \
              share_word_emb=False, **kwargs):
        self.name = 'Transformer'
        self.len_limit = len_limit
        self.src_loc_info = True
        self.d_model = d_model
        self.layers = layers
        self.decode_model = None
        d_emb = d_model

        pos_emb = Embedding(len_limit, d_emb, trainable=False, \
                            weights=[GetPosEncodingMatrix(len_limit, d_emb)])

        self.i_word_emb = Embedding(max_features, d_emb, weights=[embedding_matrix])
        if share_word_emb: 
            self.o_word_emb = i_word_emb
        else: self.o_word_emb = Embedding(max_features, d_emb, weights=[embedding_matrix])

        self.encoder = Encoder(d_model, d_inner_hid, n_head, layers, dropout)
        self.decoder = Decoder(d_model, d_inner_hid, n_head, layers, dropout)
        self.target_layer = TimeDistributed(Dense(avg_size, use_bias=False))

    def get_pos_seq(self, x):
        mask = K.cast(K.not_equal(x, 0), 'int32')
        pos = K.cumsum(K.ones_like(x, 'int32'), 1)
        return pos * mask

    def compile(self, active_layers=999):
        src_seq_input = Input(shape=(avg_size,))
        tgt_seq_input = Input(shape=(avg_size,))

        src_seq = src_seq_input
        tgt_seq  = Lambda(lambda x:x[:,:-1])(tgt_seq_input)

        src_emb = self.i_word_emb(src_seq)
        tgt_emb = self.o_word_emb(tgt_seq)

#         if self.pos_emb: 
#             src_emb = add_layer([src_emb, self.pos_emb(src_seq)])
#             tgt_emb = add_layer([tgt_emb, self.pos_emb(tgt_seq)])

#         src_emb = self.emb_dropout(src_emb)

        src_pos = Lambda(self.get_pos_seq)(src_seq)

        enc_output = self.encoder(src_emb, src_seq, active_layers=active_layers)
        dec_output = self.decoder(tgt_emb, tgt_seq, src_seq, enc_output, active_layers=active_layers)
        final_output = self.target_layer(dec_output)

        adadelta = optimizers.Adadelta(lr=1, rho=0.95, epsilon=None, decay=0.0)
        self.model = Model([src_seq_input, tgt_seq_input], final_output)
        self.model.compile(loss='mse', optimizer=adadelta)