lsdefine / attention-is-all-you-need-keras

A Keras+TensorFlow Implementation of the Transformer: Attention Is All You Need

'nan' loss function when using layer normalization #13

Open McKracken opened 5 years ago

McKracken commented 5 years ago

Hi,

I am using only the LayerNormalization layer from your code in my own project. I didn't change anything in it, apart from overriding the compute_mask function, because my input is an Embedding with mask_zero=True.

Code

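# Imports assumed here (standalone Keras); they were not part of the posted snippet.
from keras import backend as K
from keras.layers import Layer
from keras.initializers import Ones, Zeros

class LayerNormalization(Layer):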

    def __init__(self, eps=1e-6, **kwargs):
        self.eps = eps
        super(LayerNormalization, self).__init__(**kwargs)

    def build(self, input_shape):
        self.gamma = self.add_weight(name='gamma', shape=input_shape[-1:],
                                     initializer=Ones(), trainable=True)
        self.beta = self.add_weight(name='beta', shape=input_shape[-1:],
                                    initializer=Zeros(), trainable=True)
        super(LayerNormalization, self).build(input_shape)

    def call(self, x):
        mean = K.mean(x, axis=-1, keepdims=True)
        std = K.std(x, axis=-1, keepdims=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

    def compute_output_shape(self, input_shape):
        return input_shape

    def compute_mask(self, inputs, input_mask=None):
        return input_mask

Strangely, I get nan for every metric I measure while training and tuning (the loss function and others). When I use other implementations of the LayerNormalization layer (e.g. https://github.com/CyberZHG/keras-layer-normalization), everything works without problems. Do you have any idea what might be causing this?

lsdefine commented 5 years ago

CyberZHG's

variance = K.mean(K.square(inputs - mean), axis=-1, keepdims=True)
std = K.sqrt(variance + self.epsilon)

Mine

std = K.std(x, axis=-1, keepdims=True)

I think there may be input sequences of length 0, so that the whole sequence is masked. In any case, you can safely use his LayerNormalization.
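
One plausible reading of the difference between the two snippets above: when every feature of a position is identical (for example an all-zero row at a masked position), the variance is exactly 0. The derivative of sqrt(var) at 0 is infinite, and multiplying it by the zero factors around it in the chain rule yields nan. Adding epsilon inside the square root keeps that derivative finite. The following is a minimal pure-Python sketch of that arithmetic only, not of either library's actual gradient code; the helper names and the masked_row example are hypothetical.

```python
import math

def row_variance(row):
    """Population variance of one feature vector (the last axis)."""
    mean = sum(row) / len(row)
    return sum((v - mean) ** 2 for v in row) / len(row)

def dstd_dvar_eps_outside(row):
    """Local derivative of std = sqrt(var); eps is only added afterwards."""
    var = row_variance(row)
    # Mimic IEEE-754 float division, where 0.5 / 0.0 evaluates to inf.
    return 0.5 / math.sqrt(var) if var > 0.0 else math.inf

def dstd_dvar_eps_inside(row, eps=1e-6):
    """Local derivative of std = sqrt(var + eps); finite even when var == 0."""
    return 0.5 / math.sqrt(row_variance(row) + eps)

# Hypothetical fully masked position: every feature is zero.
masked_row = [0.0, 0.0, 0.0, 0.0]

# d(var)/d(x_i) = 2 * (x_i - mean) / n, which is 0 for a constant row.
dvar_dx = 0.0

grad_outside = dstd_dvar_eps_outside(masked_row) * dvar_dx  # inf * 0.0
grad_inside = dstd_dvar_eps_inside(masked_row) * dvar_dx    # finite * 0.0

print(math.isnan(grad_outside), grad_inside)
```

If this reading is right, a single nan gradient would be enough to turn gamma and beta into nan after one update, which would match the symptom of every metric becoming nan.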