vmtmxmf5 / Pytorch-

pytorch로 머신러닝~딥러닝 구현
3 stars 0 forks source link

RNN - Pytorch 핵심 (seq2seq 변형에 관한 해설도 조금) #7

Open vmtmxmf5 opened 3 years ago

vmtmxmf5 commented 3 years ago

중요

# nn.RNN(input_dim, units, batch_first=True)
rnn = nn.RNN(len(char_dic), 15, batch_first=True)

# X가 포함하고 있는 정보 : (N, X의 timestep, input_dim)
# rnn(X)[0].shape : (N, X의 timestep, units) # 출력 텐서
# rnn(X)[1].shape : (N, last timestep, units) # last hidden state
hidden_states, last_hidden_state = rnn(X)

# W_x : 뉴런개수 * input_dim
# W_h : 뉴런개수 * 뉴런개수
# b_x : 뉴런개수 
# b_h : 뉴런개수 (위와 값이 다름)
list(rnn.named_parameters())
vmtmxmf5 commented 3 years ago

input_dim, units = len(char_dic), 15 lstm = nn.LSTM(input_dim, units, batch_first=True)

hidden_states, last_states = lstm(X) last_hidden_state, last_cell_state = last_states

vmtmxmf5 commented 3 years ago

여기서 input_dim == emb_dim 임

참고로 seq2seq에서

rnn cell에 nn.GRU(emb_dim + hid_dim, hid_dim) 해주면

첫번째 hid_dim은 encoder의 rnn cell의 units 개수고, 두번째 hid_dim은 decoder의 rnn cell의 units 개수임

scalar축을 늘리는 이유는 모든 셀에 context vector의 정보를 넣기 위해서임 source sequence 정보를 cell마다 계속 추가해 줌으로써, decoder에 입력된 토큰이 어디서 온 것인지 정보를 살려준다