kyunghyuncho / NMT


Bug in both Blocks and GroundHog when computing the output probability #19

Open kyunghyuncho opened 9 years ago

kyunghyuncho commented 9 years ago

@bartvm @rizar @sebastien-j @ejls

@orhanf and I just realized that there is a bug in the implementation of the attention-based model that dates back to GroundHog. Somehow the same bug is in Blocks as well for training, but not for sampling (beam search), which is how we noticed.

The bug is that, when computing the output probability (or readout), the code uses the context computed by the attention mechanism for the "next" timestep together with the "previous" symbol as input (which is really weird). For instance, in GH:

        readout = self.repr_readout(contexts[0])
        for level in range(self.num_levels):
            if mode != Decoder.EVALUATION:
                read_from = init_states[level]
            else:
                read_from = hidden_layers[level]

(https://github.com/kyunghyuncho/GroundHog/blob/master/experiments/nmt/encdec.py#L1150-L1155)

contexts[0] is the context for the "next" hidden state (hidden_layers), but during sampling (mode != Decoder.EVALUATION) init_states is used instead. The same thing happens during training in GH. See

        # In hidden_layers we do no have the initial state, but we need it.
        # Instead of it we have the last one, which we do not need.
        # So what we do is discard the last one and prepend the initial one.
        if mode == Decoder.EVALUATION:
            for level in range(self.num_levels):
                hidden_layers[level].out = TT.concatenate([
                    TT.shape_padleft(init_states[level].out),
                        hidden_layers[level].out])[:-1]

(https://github.com/kyunghyuncho/GroundHog/blob/master/experiments/nmt/encdec.py#L1135-L1142)

hidden_layers is shifted to accommodate the initial hidden state (which is not returned by scan), but contexts is not shifted accordingly.
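
The index bookkeeping can be checked with a toy sketch. It labels each quantity by its time index instead of computing real vectors (the integer labels are a hypothetical stand-in, not GroundHog code) and mimics the `TT.concatenate([TT.shape_padleft(init), hidden])[:-1]` shift with NumPy, leaving the contexts unshifted as described above.

```python
import numpy as np

T = 4  # number of decoder steps (arbitrary, for illustration)

# Label each quantity by its time index instead of computing real vectors.
init_state = np.array([0])      # h_0, the initial decoder state
hidden = np.arange(1, T + 1)    # h_1..h_T, as returned by scan (no h_0)
contexts = np.arange(1, T + 1)  # c_1..c_T; c_t is computed by attending with h_{t-1}

# Prepend the initial state and drop the last one, mirroring the GroundHog snippet.
shifted_hidden = np.concatenate([init_state, hidden])[:-1]  # h_0..h_{T-1}

# The readout at each step combines one entry of each sequence; contexts stay unshifted.
pairs = list(zip(shifted_hidden.tolist(), contexts.tolist()))
print(pairs)  # [(0, 1), (1, 2), (2, 3), (3, 4)]
```

So the readout for step t combines h_{t-1} with the context c_t, which was itself computed from h_{t-1}.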

I believe the models we trained with GH worked well regardless because training and sampling followed the same procedure. The downside is that this requires the model to compute the hidden state "twice" each time.

The same thing happens in Blocks during training. See

            if return_initial_states:
                # Undo Subtensor
                for i in range(len(states_given)):
                    assert isinstance(result[i].owner.op,
                                      tensor.subtensor.Subtensor)
                    result[i] = result[i].owner.inputs[0]

(https://github.com/bartvm/blocks/blob/master/blocks/bricks/recurrent.py#L211-L216)

As far as I can tell, this shifts the hidden states but not the computed context, because the context is not given in states_given. This is indirectly confirmed by the matching costs of the GH- and Blocks-based implementations.

However, for beam search, the Blocks implementation is correct. See

    def _compile_logprobs_computer(self):
        # This filtering should return identical variables
        # (in terms of computations) variables, and we do not care
        # which to use.
        probs = VariableFilter(
            applications=[self.generator.readout.emitter.probs],
            roles=[OUTPUT])(self.inner_cg)[0]
        logprobs = -tensor.log(probs)
        self.logprobs_computer = function(
            self.contexts + self.input_states, logprobs,
            on_unused_input='ignore')

(https://github.com/bartvm/blocks/blob/master/blocks/search.py#L119-L129)

This function uses the "current context" (self.contexts).
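
As background, here is a minimal NumPy sketch of how a matrix of negative log-probabilities like the one `logprobs_computer` returns is typically consumed in one beam-search expansion step. `expand_beam` is a hypothetical helper, not Blocks' beam search implementation; it only shows that candidate costs are the accumulated costs plus these per-word negative log-probabilities, so whichever context was fed into `probs` directly decides which hypotheses survive.

```python
import numpy as np

def expand_beam(costs, logprobs, beam_size):
    """Hypothetical single beam-search expansion step.

    costs:    (n_hyps,) accumulated negative log-probabilities of the current hypotheses
    logprobs: (n_hyps, vocab) negative log-probabilities of each next word
    Returns the hypothesis indices, word indices and costs of the
    beam_size cheapest one-word extensions.
    """
    candidate_costs = costs[:, None] + logprobs         # (n_hyps, vocab)
    flat = candidate_costs.ravel()
    best = np.argsort(flat, kind="stable")[:beam_size]  # lowest cost first
    hyp_idx, word_idx = np.unravel_index(best, candidate_costs.shape)
    return hyp_idx, word_idx, flat[best]

probs = np.array([[0.5, 0.3, 0.2],
                  [0.1, 0.1, 0.8]])
costs = np.array([0.0, 1.0])
hyp_idx, word_idx, new_costs = expand_beam(costs, -np.log(probs), beam_size=3)
print(hyp_idx.tolist(), word_idx.tolist())  # [0, 0, 1] [0, 1, 2]
```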

So we now have two options:

(1) cripple the Blocks implementation (beam search) to make it do exactly what GH did: this is a safer bet, since we know it works. However, the model becomes quite weird, and this should only be a temporary solution.

(2) fix the Blocks implementation (training) to do the right thing: I believe this is the way we should go, but in this case, we have to forget about reusing the models trained with GH.

I prefer (2), but what do you think?

kyunghyuncho commented 9 years ago

Oh, one more reason for going with (2): we know that the correct way works well, since the models for image/video caption generation worked well.

kyunghyuncho commented 9 years ago

Apparently, the same thing applies to Blocks' beam search as well. Currently the beam search works. We'll keep it as it is for WMT'15, but this is something to ponder.

sebastien-j commented 9 years ago

I don't understand. When I was checking the cost, I implemented the forward pass in numpy and it seemed fine. If it isn't too much of a mess, could you point out how the following code would be changed?

import numpy

from numpy import dot, tanh, exp, max

def sigmoid(x):
    return 1./(1+exp(-x))

def np_cost(x, labels):
    # x: batchsize x num_categories (or num_categories for a single example)
    # labels : 1d-array of length batchsize (or scalar for single example)
    # Returns 1d-array of costs
    if len(x.shape) == 1:
        x = x[None, :]
    mx = max(x, axis=1)
    scores = exp(x-mx[:, None])
    return  -x[range(len(x)), labels] + mx + numpy.log(numpy.sum(scores, axis=1))

def compute_annotations(x):
    # batchsize x timesteps
    if len(x.shape) == 1:
        x = x[None, :]
    mx = max(x, axis=1)
    scores = exp(x-mx[:, None])
    scores /= numpy.sum(scores, axis=1)[:,None]
    return scores
###

fname = '/data/lisatmp3/jeasebas/nmt/bahdanau/fixed/search_30k_fixed1075_model2.npz'

params = dict(numpy.load(fname))

emb = params['W_0_enc_approx_embdr'] + params['b_0_enc_approx_embdr']

W = params['W_0_enc_input_embdr_0']
Wz = params['W_0_enc_update_embdr_0']
Wr = params['W_0_enc_reset_embdr_0']

enc_b = params['b_0_enc_input_embdr_0']

U = params['W_enc_transition_0']
Uz = params['G_enc_transition_0']
Ur = params['R_enc_transition_0']

###
back_W = params['W_0_back_enc_input_embdr_0']
back_Wz = params['W_0_back_enc_update_embdr_0']
back_Wr = params['W_0_back_enc_reset_embdr_0']

back_enc_b = params['b_0_back_enc_input_embdr_0']

back_U = params['W_back_enc_transition_0']
back_Uz = params['G_back_enc_transition_0']
back_Ur = params['R_back_enc_transition_0']
###

d_I = params['W_0_dec_initializer_0']
d_Ib = params['b_0_dec_initializer_0']

d_emb = params['W_0_dec_approx_embdr'] + params['b_0_dec_approx_embdr']

d_W = params['W_0_dec_input_embdr_0']
d_Wz = params['W_0_dec_update_embdr_0']
d_Wr = params['W_0_dec_reset_embdr_0']
d_b = params['b_0_dec_input_embdr_0']

d_U = params['W_dec_transition_0']
d_Uz = params['G_dec_transition_0']
d_Ur = params['R_dec_transition_0']

d_C = params['W_0_dec_dec_inputter_0']
d_Cz = params['W_0_dec_dec_updater_0']
d_Cr = params['W_0_dec_dec_reseter_0']

d_Oh = params['W_0_dec_hid_readout_0']
d_Oy = params['W_0_dec_prev_readout_0']
d_Oc = params['W_0_dec_repr_readout']
d_Ob = params['b_0_dec_hid_readout_0']

d_W1 = params['W1_dec_deep_softmax']
d_W2 = params['W2_dec_deep_softmax']

d_softmax_b = params['b_dec_deep_softmax']

d_A = params['A_dec_transition_0'] # 2000 x 1000
d_B = params['B_dec_transition_0'] # 1000 x 1000
d_D = params['D_dec_transition_0'].reshape((len(params['D_dec_transition_0']))) # 1000 x 1 -> 1000

d_W12 = dot(d_W1, d_W2)

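####
# Inputs (loaded elsewhere, not shown here): english and french are batch x timesteps
# integer arrays of word indices; english_mask and french_mask are matching 0/1 float masks.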

h = numpy.zeros((len(english), d_U.shape[1])) # eg 64 x 1000
back_h = numpy.zeros((len(english), d_U.shape[1])) # eg 64 x 1000

h_array = []
back_h_array = []

for i in range(english.shape[1]):

    z = sigmoid(dot(emb[english[:,i]], Wz) + dot(h, Uz)) # bs x h_size
    r = sigmoid(dot(emb[english[:,i]], Wr) + dot(h, Ur))

    g = tanh(dot(emb[english[:,i]], W) + dot(r*h, U) + enc_b)
    h = english_mask[:,i][:, None] * ((1 - z) * h + z * g) + (1.- english_mask[:,i][:, None]) * h
    h_array.append(h)

    back_z = sigmoid(dot(emb[english[:,-i-1]], back_Wz) + dot(back_h, back_Uz)) # bs x h_size
    back_r = sigmoid(dot(emb[english[:,-i-1]], back_Wr) + dot(back_h, back_Ur))

    back_g = tanh(dot(emb[english[:,-i-1]], back_W) + dot(back_r*back_h, back_U) + back_enc_b)
    back_h = english_mask[:,-i-1][:, None] * ((1 - back_z) * back_h + back_z * back_g) + (1.- english_mask[:,-i-1][:, None]) * back_h
    back_h_array.append(back_h)

h_array = numpy.asarray(h_array) # timesteps x bs x 1000
back_h_array = numpy.asarray(back_h_array) # timesteps x bs x 1000

back_h_array = back_h_array[::-1,:,:]

both_h_array = numpy.concatenate((h_array, back_h_array), axis=2) # timesteps x bs x 2000

search_enc = dot(both_h_array, d_A) # timesteps x bs x 2000, 2000 x 1000 -> timesteps x bs x 1000

d_cost_array = []
total_cost = 0.

d_h = tanh(dot(back_h_array[0], d_I) + d_Ib)
initial_d_h = d_h[:]

d_h_array = []
d_h_array.append(d_h)

prev_word_emb = numpy.zeros_like(d_emb[french[:,-1]])

wa_array = []

for i in range(french.shape[1]):

    # Get the context

    search_dec = dot(d_h, d_B) # bs x 1000, 1000 x 1000 -> bs x 1000

    search_tanh = tanh(search_enc + search_dec) # timesteps x bs x 1000

    search_energies = dot(search_tanh, d_D).T # bs x timesteps
    search_energies[english_mask<0.5] = -numpy.inf

    annotations = compute_annotations(search_energies) # bs x timesteps
    # Context vector: attention-weighted sum of the annotations (note: this reuses the name h)
    h = numpy.sum(annotations[:,:,None] * numpy.transpose(both_h_array, [1,0,2]), axis=1) # bs x 2000

    wa_array.append(h)

    d_sp = dot(prev_word_emb,d_Oy) + dot(d_h ,d_Oh) + dot(h, d_Oc) + d_Ob # 64 x 1000 -> 64 x 500 x 2
    # Maxout over adjacent pairs of units; integer division keeps the shape valid in Python 3
    d_sp = d_sp.reshape(d_sp.shape[0], d_sp.shape[1] // 2, 2)
    d_s = numpy.max(d_sp, axis = -1)
    energies = dot(d_s, d_W12) + d_softmax_b[None, :] # bs x V
    d_cost = np_cost(energies, french[:,i]) * french_mask[:,i]
    d_cost_array.append(d_cost)
    total_cost += numpy.sum(d_cost)

    prev_word_emb = d_emb[french[:,i]]

    d_z = sigmoid(dot(prev_word_emb,d_Wz) + dot(d_h, d_Uz) + dot(h, d_Cz))
    d_r = sigmoid(dot(prev_word_emb,d_Wr) + dot(d_h, d_Ur) + dot(h, d_Cr))

    d_g = tanh(dot(prev_word_emb,d_W) + dot(d_r*d_h, d_U) + dot(h, d_C) + d_b)
    d_h = french_mask[:,i][:, None] * ((1 - d_z) * d_h + d_z * d_g) + (1.- french_mask[:,i][:, None]) * d_h
    # `i` seems strange but matches current gh implementation
    # It doesn't matter anyway as we use the correct position for the cost

    if i != (french.shape[1] - 1):
        d_h_array.append(d_h)

d_h_array = numpy.asarray(d_h_array)
wa_array = numpy.asarray(wa_array)
d_cost_array = numpy.asarray(d_cost_array)
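
The order of operations in this forward pass can be condensed into one step function. The following is a simplified sketch, not the original script: biases, masks, the factored output matrix `d_W12` and the bidirectional encoder are omitted, the initial state is zeros instead of being computed from the backward encoder, the parameter names in `p` are hypothetical, and all weights are random. It makes explicit that the context c_t (called `h` in the script) is computed from s_{t-1}, that the distribution over y_t uses y_{t-1}, s_{t-1} and c_t, and that only then is s_t updated with y_t.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def decoder_step(p, s_prev, y_prev_emb, y_emb, enc, enc_proj):
    """One simplified decoder step (no biases or masks).

    s_prev: (bs, n_h) previous state s_{t-1}; y_prev_emb: (bs, n_e) embedding of
    y_{t-1} (zeros at the first step); y_emb: (bs, n_e) embedding of the target y_t;
    enc: (T_src, bs, n_c) annotations; enc_proj: (T_src, bs, n_a) projected annotations.
    """
    # 1. Attention: the context c_t is computed from s_{t-1}.
    energies = np.tanh(enc_proj + s_prev @ p["B"]) @ p["D"]  # (T_src, bs)
    weights = softmax(energies.T, axis=1)                    # (bs, T_src)
    c = np.einsum("bt,tbc->bc", weights, enc)                # (bs, n_c)

    # 2. Distribution over y_t from y_{t-1}, s_{t-1} and c_t (maxout over pairs).
    pre = y_prev_emb @ p["Oy"] + s_prev @ p["Oh"] + c @ p["Oc"]
    m = pre.reshape(pre.shape[0], pre.shape[1] // 2, 2).max(axis=-1)
    probs = softmax(m @ p["W_out"])                          # (bs, vocab)

    # 3. GRU update: s_t from s_{t-1}, y_t and the same c_t.
    z = sigmoid(y_emb @ p["Wz"] + s_prev @ p["Uz"] + c @ p["Cz"])
    r = sigmoid(y_emb @ p["Wr"] + s_prev @ p["Ur"] + c @ p["Cr"])
    g = np.tanh(y_emb @ p["W"] + (r * s_prev) @ p["U"] + c @ p["C"])
    s = (1 - z) * s_prev + z * g
    return probs, s, weights

# Tiny random example (all sizes and weights are arbitrary).
rng = np.random.default_rng(0)
bs, T_src, n_h, n_e, n_c, n_a, n_m, vocab = 2, 5, 4, 3, 6, 4, 3, 7
shapes = {
    "B": (n_h, n_a), "D": (n_a,), "W_out": (n_m, vocab),
    "Oy": (n_e, 2 * n_m), "Oh": (n_h, 2 * n_m), "Oc": (n_c, 2 * n_m),
    "Wz": (n_e, n_h), "Uz": (n_h, n_h), "Cz": (n_c, n_h),
    "Wr": (n_e, n_h), "Ur": (n_h, n_h), "Cr": (n_c, n_h),
    "W": (n_e, n_h), "U": (n_h, n_h), "C": (n_c, n_h),
}
p = {name: 0.1 * rng.standard_normal(shape) for name, shape in shapes.items()}

enc = rng.standard_normal((T_src, bs, n_c))
enc_proj = enc @ rng.standard_normal((n_c, n_a))
s0 = np.zeros((bs, n_h))             # zero initial state, for simplicity
y0 = np.zeros((bs, n_e))             # no previous word at the first step
y1 = rng.standard_normal((bs, n_e))  # stand-in embedding of the first target word

probs, s1, weights = decoder_step(p, s0, y0, y1, enc, enc_proj)
probs_alt, s1_alt, _ = decoder_step(p, s0, y0, y1 + 1.0, enc, enc_proj)
print(np.allclose(probs, probs_alt))  # True: y_1 does not affect its own distribution
```

Because the distribution over y_t is computed before y_t enters the state update, perturbing y_1 leaves `probs` unchanged and only changes `s1`. This is the ordering @rizar spells out for Blocks and GroundHog in this thread.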

rizar commented 9 years ago

This will shift the hidden states, but won't shift the computed context as far as I can tell because it's not given in states_given

In fact, the "contexts" (GroundHog terminology), or "glimpses" (Blocks terminology), are states of the AttentionRecurrent, so return_initial_states=True applies to them as well. However, in this line we discard the initial glimpse, because initial glimpses make no sense. So to generate the word y_1, the initial state s_0 and the first computed glimpse g_1 are used.

So far I also do not understand what is wrong.

kyunghyuncho commented 9 years ago

@rizar @sebastien-j The problem is that y_t is computed from g_{t+1} and h_t, while h_t is computed from h_{t-1} and g_t. See Appendix A.2.2 of http://arxiv.org/pdf/1409.0473v5.pdf: y_t should be computed from g_t and h_t, and h_t from h_{t-1} and g_t.

@sebastien-j your NumPy implementation does the same thing.

rizar commented 9 years ago

Hold on, I think you are now proposing this very detrimental change that I tried to introduce in https://github.com/bartvm/blocks/pull/522

Let's look closely at what you propose, using the nice notation @sebastien-j introduced:

g_t <--- h_{t-1}, g_{t-1}
h_t <--- h_{t-1}, g_t, y_{t-1}
y_t <--- g_t

The problem is that now g_t does not depend on y_{t-1}. That causes a big performance drop; I tried it.

But it seems like this is exactly what is written in A.2.2.

What Blocks and GroundHog do instead looks like this in your notation:

g_t <--- s_{t-1}, g_{t-1}
y_t <--- s_{t-1}, g_t
s_t <--- s_{t-1}, g_t, y_t

This way g_t does depend on y_{t-1} through s_{t-1}.
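
The difference between the two orderings can be checked mechanically. The sketch below uses a hypothetical helper, not code from either codebase, that tracks for each glimpse g_t the set of target words y_k it depends on, under the two schemes written out above.

```python
def glimpse_dependencies(T, order):
    """Return {t: sorted list of k such that glimpse g_t depends on y_k}.

    order="appendix": g_t <- h_{t-1}, g_{t-1};  h_t <- h_{t-1}, g_t, y_{t-1}
    order="blocks":   g_t <- s_{t-1}, g_{t-1};  s_t <- s_{t-1}, g_t, y_t
    """
    empty = frozenset()
    state = {0: empty}    # h_0 or s_0 depends on no target word
    glimpse = {0: empty}  # g_0 likewise
    for t in range(1, T + 1):
        glimpse[t] = state[t - 1] | glimpse[t - 1]
        if order == "appendix":
            word = frozenset({t - 1}) if t > 1 else empty  # y_{t-1}; none before y_1
        elif order == "blocks":
            word = frozenset({t})  # y_t
        else:
            raise ValueError(f"unknown order: {order!r}")
        state[t] = state[t - 1] | glimpse[t] | word
    return {t: sorted(glimpse[t]) for t in range(1, T + 1)}

for order in ("appendix", "blocks"):
    print(order, glimpse_dependencies(4, order))
# appendix {1: [], 2: [], 3: [1], 4: [1, 2]}
# blocks {1: [], 2: [1], 3: [1, 2], 4: [1, 2, 3]}
```

Under the appendix ordering g_t only sees words up to y_{t-2}, while under the Blocks/GroundHog ordering it sees y_{t-1} through s_{t-1}.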

kyunghyuncho commented 9 years ago

First, we should eventually update A.2.2 to match the implementation in Blocks and GH.

Second, in this case, what is the alignment of y_t? Is it g_t or g_{t-1}?

rizar commented 9 years ago

First, we should eventually update A.2.2 to match the implementation in Blocks and GH.

Definitely. I can do it as soon as we finish the discussion.

The alignment for y_t is g_t in Blocks. You might want to read the new docs in https://github.com/bartvm/blocks/pull/563, which explain in detail how things work now.