songyingxin / Bert-TextClassification

Implementation of some baseline models on top of BERT for text classification

The handling of hidden in BertLSTM's forward function seems to be wrong #16

Closed bluesea0 closed 4 years ago

bluesea0 commented 4 years ago
# BertLSTM.py, line 19:
self.rnn = nn.LSTM(config.hidden_size, rnn_hidden_size, num_layers, bidirectional=bidirectional, batch_first=True, dropout=dropout)
.....

# BertLSTM.py, lines 31-36:
_, (hidden, cell) = self.rnn(encoded_layers)
# outputs: [batch_size, seq_len, rnn_hidden_size * 2]
hidden = self.dropout(
    torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))  # concatenate the last layer's bidirectional outputs

logits = self.classifier(hidden)

Since batch_first=True was set earlier, the shape of hidden should be [batch_size, num_layers * num_directions, rnn_hidden_size], so concatenating the last layer's bidirectional outputs should be (hidden[:, -2, :], hidden[:, -1, :]), shouldn't it? Or, alternatively, first do hidden = hidden.permute([1, 0, 2]).

bluesea0 commented 4 years ago

Got it now: setting batch_first=True only affects the shape of output; it has no effect on the shapes of hidden and cell, so the original code is correct.
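
A minimal self-contained check of this behavior (the sizes below are arbitrary, chosen only for illustration): nn.LSTM keeps hidden and cell in the [num_layers * num_directions, batch_size, hidden_size] layout even when batch_first=True, so indexing hidden[-2] and hidden[-1] does select the last layer's forward and backward states.

import torch
import torch.nn as nn

# Illustrative sizes: 768 matches BERT-base's hidden size; 384 is a hypothetical rnn_hidden_size.
rnn = nn.LSTM(input_size=768, hidden_size=384, num_layers=2,
              bidirectional=True, batch_first=True)

x = torch.randn(8, 128, 768)  # [batch_size, seq_len, input_size] because batch_first=True
output, (hidden, cell) = rnn(x)

# batch_first=True affects only the input/output tensors:
print(output.shape)  # torch.Size([8, 128, 768]) -> [batch_size, seq_len, 2 * hidden_size]
# hidden and cell always stay [num_layers * num_directions, batch_size, hidden_size]:
print(hidden.shape)  # torch.Size([4, 8, 384])
# So concatenating hidden[-2] and hidden[-1] yields the last layer's bidirectional state:
last = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
print(last.shape)    # torch.Size([8, 768]) -> ready for the classifier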