Open vmtmxmf5 opened 3 years ago
# LSTM은 어떤 식으로 output이 구성되어 있을까?
input_dim, units = len(char_dic), 15 lstm = nn.LSTM(input_dim, units, batch_first=True)
hidden_states, last_states = lstm(X) last_hidden_state, last_cell_state = last_states
여기서 input_dim == emb_dim 임
참고로 seq2seq에서
rnn cell에 nn.GRU(emb_dim + hid_dim, hid_dim) 해주면
첫번째 hid_dim은 encoder의 rnn cell의 units 개수고, 두번째 hid_dim은 decoder의 rnn cell의 units 개수임
scalar축을 늘리는 이유는 모든 셀에 context vector의 정보를 넣기 위해서임 source sequence 정보를 cell마다 계속 추가해 줌으로써, decoder에 입력된 토큰이 어디서 온 것인지 정보를 살려준다
중요
epoch 구현 시 유의사항 rnn은 output이 2종류다