The LSTM tutorial (https://github.com/dmlc/mxnet-notebooks/blob/master/python/rnn/lstm.ipynb) seems to have a bug: it does not support num_embed != num_hidden. In the unrolled network, the first LSTM layer reads the embedding output (num_embed wide), while every layer above it reads the hidden state of the layer below (num_hidden wide), so the size passed to lstm() must depend on the layer index. next_state in lstm() should be defined like below (see also the small sketch after the expanded context):

next_state = lstm(num_embed if i == 0 else num_hidden, indata=hidden,
                  mask=maskvec[seqidx],
                  prev_state=last_states[i],
                  param=param_cells[i],
                  seqidx=seqidx, layeridx=i, dropout=dropout)
Below is an expanded context:
embed = mx.sym.Embedding(data=data, input_dim=input_size,
                         weight=embed_weight, output_dim=num_embed, name='embed')
wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)
maskvec = mx.sym.SliceChannel(data=mask, num_outputs=seq_len, squeeze_axis=1)
# Now we can unroll the network
hidden_all = []
for seqidx in range(seq_len):
    hidden = wordvec[seqidx]  # input to LSTM cell, comes from embedding (num_embed wide)
    # stack LSTM
    for i in range(num_lstm_layer):
        next_state = lstm(num_hidden,  # <-- buggy when num_embed != num_hidden; see fix above
                          indata=hidden,
                          mask=maskvec[seqidx],
                          prev_state=last_states[i],
                          param=param_cells[i],
                          seqidx=seqidx, layeridx=i, dropout=dropout)
        hidden = next_state.h  # hidden state (num_hidden wide) feeds the next layer
        last_states[i] = next_state
    # decoder
    hidden_all.append(hidden)  # last output of stacked LSTM units
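
For reference, here is a minimal plain-Python sketch (hypothetical sizes, not taken from the notebook) of the per-layer input widths in the stack above; it only illustrates why layer 0 needs num_embed while the upper layers need num_hidden:

# Hypothetical sizes chosen only for illustration.
num_embed, num_hidden, num_lstm_layer = 200, 512, 2
for i in range(num_lstm_layer):
    # layer 0 consumes the embedding output, layers 1+ consume the
    # hidden state of the layer below, so their input widths differ.
    input_size = num_embed if i == 0 else num_hidden
    print("layer %d: input width %d -> hidden width %d" % (i, input_size, num_hidden))
# With num_embed != num_hidden, passing num_hidden unconditionally for layer 0
# is what breaks, hence the proposed fix above.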