graykode / nlp-tutorial

Natural Language Processing Tutorial for Deep Learning Researchers
https://www.reddit.com/r/MachineLearning/comments/amfinl/project_nlptutoral_repository_who_is_studying/
MIT License
14.31k stars 3.95k forks source link

3-3-bilstm-torch comment error #47

Open Tonybb9089 opened 4 years ago

Tonybb9089 commented 4 years ago

class BiLSTM(nn.Module):
    """Single-layer bidirectional LSTM followed by a linear projection to n_class scores.

    Construction reads the module-level names ``n_class``, ``n_hidden`` and
    ``dtype``, which must be defined elsewhere in the script before the model
    is built.
    """

    def __init__(self):
        super(BiLSTM, self).__init__()

        self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden, bidirectional=True)
        # Forward and backward outputs are concatenated along the feature axis,
        # so the projection consumes n_hidden * 2 features.
        self.W = nn.Parameter(torch.randn([n_hidden * 2, n_class]).type(dtype))
        self.b = nn.Parameter(torch.randn([n_class]).type(dtype))

    def forward(self, X):
        """Project the LSTM output at the last time step to n_class scores.

        Args:
            X: input batch of shape [batch_size, n_step, n_class].

        Returns:
            Tensor of shape [batch_size, n_class].
        """
        seq = X.transpose(0, 1)  # seq : [n_step, batch_size, n_class]

        num_directions = 2  # bidirectional=True
        state_shape = (
            self.lstm.num_layers * num_directions,  # num_layers(=1) * num_directions(=2)
            X.size(0),                              # batch_size
            self.lstm.hidden_size,                  # n_hidden
        )
        # new_zeros keeps the initial states on the same device/dtype as the input.
        hidden_state = X.new_zeros(state_shape)  # [num_layers * num_directions, batch_size, n_hidden]
        cell_state = X.new_zeros(state_shape)    # [num_layers * num_directions, batch_size, n_hidden]

        outputs, (_, _) = self.lstm(seq, (hidden_state, cell_state))
        # outputs : [n_step, batch_size, n_hidden * 2]
        # NOTE: for the backward direction, position -1 is its first step, so
        # that half of the features has only seen the final input token.
        last = outputs[-1]  # [batch_size, n_hidden * 2] (corrected from [batch_size, n_hidden])
        model = torch.mm(last, self.W) + self.b  # model : [batch_size, n_class]
        return model

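Error: the shape comment in `outputs = outputs[-1]  # [batch_size, n_hidden]` is wrong. Because the LSTM is bidirectional, the forward and backward outputs are concatenated, so the shape should be `[batch_size, 2 * n_hidden]`.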

wmathor commented 4 years ago

Hey bro, I found this error too; I think you are right.