dmlc / mxnet-notebooks

Notebooks for MXNet
Apache License 2.0

dimension mismatch in LSTM tutorial when num_embed != num_hidden #2

Closed bikestra closed 8 years ago

bikestra commented 8 years ago

The LSTM tutorial (https://github.com/dmlc/mxnet-notebooks/blob/master/python/rnn/lstm.ipynb) seems to have a bug: it does not support num_embed != num_hidden. next_state in lstm() should be defined as below.

            next_state = lstm(num_embed if i == 0 else num_hidden, indata=hidden,
                              mask=maskvec[seqidx],
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i, dropout=dropout)

Below is the expanded context:

    embed = mx.sym.Embedding(data=data, input_dim=input_size,
                             weight=embed_weight, output_dim=num_embed, name='embed')
    wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)
    maskvec = mx.sym.SliceChannel(data=mask, num_outputs=seq_len, squeeze_axis=1)

    # Now we can unroll the network
    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx] # input to LSTM cell, comes from embedding

        # stack LSTM
        for i in range(num_lstm_layer):
            next_state = lstm(num_hidden, indata=hidden,
                              mask=maskvec[seqidx],
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i, dropout=dropout)
            hidden = next_state.h
            last_states[i] = next_state
        # decoder
        hidden_all.append(hidden) # last output of stack LSTM units
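To see why the first stacked layer needs num_embed rather than num_hidden, here is a minimal NumPy sketch of the shape constraint (the sizes and the make_cell_weights helper are hypothetical, not from the tutorial): an LSTM cell's input-to-gate weight matrix has shape (4 * num_hidden, input_dim), and only layer 0 receives the embedding as input.

```python
import numpy as np

# Hypothetical sizes chosen so that num_embed != num_hidden.
num_embed, num_hidden, num_lstm_layer = 64, 128, 2

def make_cell_weights(input_dim, num_hidden):
    # i2h maps the cell's input to the four gates; its column count
    # must match the width of whatever this layer actually consumes.
    return {"i2h": np.zeros((4 * num_hidden, input_dim)),
            "h2h": np.zeros((4 * num_hidden, num_hidden))}

# Layer 0 consumes the embedding (num_embed wide); deeper layers
# consume the previous layer's hidden state (num_hidden wide).
cells = [make_cell_weights(num_embed if i == 0 else num_hidden, num_hidden)
         for i in range(num_lstm_layer)]

x = np.zeros(num_embed)      # one embedded token from wordvec[seqidx]
gates = cells[0]["i2h"] @ x  # (4*num_hidden, num_embed) @ (num_embed,) -> ok
assert gates.shape == (4 * num_hidden,)

# Building layer 0 with num_hidden as its input width, as the original
# unrolling loop effectively did, fails as soon as num_embed != num_hidden:
bad = make_cell_weights(num_hidden, num_hidden)
try:
    bad["i2h"] @ x
except ValueError as e:
    print("dimension mismatch:", e)
```

This is the same reason the proposed fix passes `num_embed if i == 0 else num_hidden` into lstm(): the per-layer input width has to follow the actual input, not a single global size.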
bikestra commented 8 years ago

My bad, sorry.