piEsposito / pytorch-lstm-by-hand

A small and simple tutorial on how to craft a LSTM nn.Module by hand on PyTorch.

Hi, is there any support for Bi-LSTM? #1

blueardour opened this issue 3 years ago (status: Open)

blueardour commented 3 years ago

Hi, thanks for the helpful work.

Could I ask if there is any plan to support bidirectional LSTM with custom stacks?

piEsposito commented 3 years ago

Hi. I wasn't planning on it, but if it's helpful I might as well support it.

Also, feel free to PR with the feature if you want.

blueardour commented 3 years ago

Hi, I tried to implement it myself. However, I cannot figure out the output format.


import torch
import torch.nn as nn
import pdb

class CustomLSTM(nn.LSTM):
    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
        # proj_size is only available from PyTorch 1.8, so it is not passed here
        super(CustomLSTM, self).__init__(input_size, hidden_size, num_layers=num_layers,
                bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)

    def forward(self, x, init_states=None, exporting_onnx=False):
        if exporting_onnx:
            assert self.num_layers == 1
            bs, seq, _ = x.size() if self.batch_first else (x.size(1), x.size(0), x.size(2))
            sz = self.hidden_size

            if init_states is None:
                h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device), torch.zeros(bs, self.hidden_size).to(x.device))
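            # note: this manual path only handles zero initial states; a non-None
            # init_states is not unpacked here and would leave h_t / c_t undefined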
            hidden_seq_forward = []
            for t in range(seq):
                x_t = x[:, t, :] if self.batch_first else x[t, :, :]
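                # nn.LSTM packs the four gates into weight_ih_l0 / weight_hh_l0 (and the
                # matching biases) in the order (input, forget, cell, output), hence the slices below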
                i_t = x_t @ self.weight_ih_l0[sz*0:sz*1,:].transpose(0, 1) + self.bias_ih_l0[sz*0:sz*1] + \
                      h_t @ self.weight_hh_l0[sz*0:sz*1,:].transpose(0, 1) + self.bias_hh_l0[sz*0:sz*1]
                f_t = x_t @ self.weight_ih_l0[sz*1:sz*2,:].transpose(0, 1) + self.bias_ih_l0[sz*1:sz*2] + \
                      h_t @ self.weight_hh_l0[sz*1:sz*2,:].transpose(0, 1) + self.bias_hh_l0[sz*1:sz*2]
                g_t = x_t @ self.weight_ih_l0[sz*2:sz*3,:].transpose(0, 1) + self.bias_ih_l0[sz*2:sz*3] + \
                      h_t @ self.weight_hh_l0[sz*2:sz*3,:].transpose(0, 1) + self.bias_hh_l0[sz*2:sz*3]
                o_t = x_t @ self.weight_ih_l0[sz*3:sz*4,:].transpose(0, 1) + self.bias_ih_l0[sz*3:sz*4] + \
                      h_t @ self.weight_hh_l0[sz*3:sz*4,:].transpose(0, 1) + self.bias_hh_l0[sz*3:sz*4]
                i_t = torch.sigmoid(i_t)
                f_t = torch.sigmoid(f_t)
                g_t = torch.tanh(g_t)
                o_t = torch.sigmoid(o_t)
                c_t = f_t * c_t + i_t * g_t
                h_t = o_t * torch.tanh(c_t)
                hidden_seq_forward.append(h_t.unsqueeze(0))

            if init_states is None:
                h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device), torch.zeros(bs, self.hidden_size).to(x.device))
            hidden_seq_reverse = []
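            # the reverse direction walks from t = seq-1 down to 0, so h_t is appended
            # to hidden_seq_reverse in reverse time order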
            for t in list(reversed(range(seq))):
                x_t = x[:, t, :] if self.batch_first else x[t, :, :]
                i_t = x_t @ self.weight_ih_l0_reverse[sz*0:sz*1,:].transpose(0, 1) + self.bias_ih_l0_reverse[sz*0:sz*1] + \
                      h_t @ self.weight_hh_l0_reverse[sz*0:sz*1,:].transpose(0, 1) + self.bias_hh_l0_reverse[sz*0:sz*1]
                f_t = x_t @ self.weight_ih_l0_reverse[sz*1:sz*2,:].transpose(0, 1) + self.bias_ih_l0_reverse[sz*1:sz*2] + \
                      h_t @ self.weight_hh_l0_reverse[sz*1:sz*2,:].transpose(0, 1) + self.bias_hh_l0_reverse[sz*1:sz*2]
                g_t = x_t @ self.weight_ih_l0_reverse[sz*2:sz*3,:].transpose(0, 1) + self.bias_ih_l0_reverse[sz*2:sz*3] + \
                      h_t @ self.weight_hh_l0_reverse[sz*2:sz*3,:].transpose(0, 1) + self.bias_hh_l0_reverse[sz*2:sz*3]
                o_t = x_t @ self.weight_ih_l0_reverse[sz*3:sz*4,:].transpose(0, 1) + self.bias_ih_l0_reverse[sz*3:sz*4] + \
                      h_t @ self.weight_hh_l0_reverse[sz*3:sz*4,:].transpose(0, 1) + self.bias_hh_l0_reverse[sz*3:sz*4]
                i_t = torch.sigmoid(i_t)
                f_t = torch.sigmoid(f_t)
                g_t = torch.tanh(g_t)
                o_t = torch.sigmoid(o_t)
                c_t = f_t * c_t + i_t * g_t
                h_t = o_t * torch.tanh(c_t) # [bs, self.hidden_size]
                hidden_seq_reverse.append(h_t.unsqueeze(0))

            # stack hidden_seq_forward and hidden_seq_reverse to hidden_seq
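            # hidden_seq is not defined yet -- combining the two direction lists into it
            # is the part asked about in this issue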
            hidden_seq = torch.cat(hidden_seq, dim=0) # [seq, bs, self.hidden_size]
            if self.batch_first:
                hidden_seq = hidden_seq.transpose(0, 1).contiguous()
            return hidden_seq, (_, _)

        else:
            return super().forward(x)

if __name__ == "__main__":
    model = CustomLSTM(100, 60, bidirectional=True)
    x = torch.rand(512, 10, 100)

    model.eval()
    y1, (hn, cn) = model(x, None, False)
    print(y1.shape)

    y2, (hn, cn) = model(x, None, True)
    print(y2.shape)
    pdb.set_trace()

Could I ask for a suggestion about the part marked # stack hidden_seq_forward and hidden_seq_reverse to hidden_seq?
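For reference (not part of the original comment), this is the output format a built-in single-layer bidirectional nn.LSTM produces, which the manual path would need to match:

import torch
import torch.nn as nn

# shape check for a single-layer bidirectional nn.LSTM with batch_first=False
lstm = nn.LSTM(100, 60, bidirectional=True)
out, (h_n, c_n) = lstm(torch.rand(512, 10, 100))
print(out.shape)  # torch.Size([512, 10, 120])  -> [seq, bs, 2 * hidden_size]
print(h_n.shape)  # torch.Size([2, 10, 60])     -> [num_directions, bs, hidden_size]
print(c_n.shape)  # torch.Size([2, 10, 60])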

blueardour commented 3 years ago

If I use the following:

            # stack hidden_seq_forward and hidden_seq_reverse to hidden_seq
            hidden_seq_forward = torch.cat(hidden_seq_forward, dim=0) # [seq, bs, self.hidden_size]
            hidden_seq_reverse = torch.cat(hidden_seq_reverse, dim=0) # [seq, bs, self.hidden_size]
            print(hidden_seq_forward.shape, hidden_seq_reverse.shape)
            hidden_seq = torch.cat([hidden_seq_forward, hidden_seq_reverse], dim=2)
            print(hidden_seq.shape)

it seems that y1 == y2 in main gives a lot of False entries.
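For what it's worth, the mismatch is expected if the reverse-direction outputs are left in visit order: nn.LSTM stores the backward pass in chronological order before concatenating along the feature dimension, and the final states are stacked per direction. Below is a minimal, self-contained sketch along those lines (manual_bilstm is a hypothetical helper, assuming num_layers=1, zero initial states, and batch_first=False), not the repo's implementation:

import torch
import torch.nn as nn

def manual_bilstm(lstm, x):
    # unroll a single-layer bidirectional nn.LSTM by hand (zero initial states)
    seq, bs, _ = x.shape
    sz = lstm.hidden_size

    def run_direction(w_ih, w_hh, b_ih, b_hh, time_steps):
        h_t = torch.zeros(bs, sz, device=x.device)
        c_t = torch.zeros(bs, sz, device=x.device)
        outs = []
        for t in time_steps:
            gates = x[t] @ w_ih.t() + b_ih + h_t @ w_hh.t() + b_hh
            i_t, f_t, g_t, o_t = gates.chunk(4, dim=1)  # gate order: input, forget, cell, output
            c_t = torch.sigmoid(f_t) * c_t + torch.sigmoid(i_t) * torch.tanh(g_t)
            h_t = torch.sigmoid(o_t) * torch.tanh(c_t)
            outs.append(h_t)
        return torch.stack(outs), h_t, c_t  # outs: [seq, bs, sz] in visit order

    fwd, h_f, c_f = run_direction(lstm.weight_ih_l0, lstm.weight_hh_l0,
                                  lstm.bias_ih_l0, lstm.bias_hh_l0, range(seq))
    rev, h_r, c_r = run_direction(lstm.weight_ih_l0_reverse, lstm.weight_hh_l0_reverse,
                                  lstm.bias_ih_l0_reverse, lstm.bias_hh_l0_reverse,
                                  reversed(range(seq)))
    rev = torch.flip(rev, dims=[0])     # back to chronological order before stacking
    out = torch.cat([fwd, rev], dim=2)  # [seq, bs, 2 * sz]
    h_n = torch.stack([h_f, h_r])       # [2, bs, sz], forward final state then reverse final state
    c_n = torch.stack([c_f, c_r])
    return out, (h_n, c_n)

if __name__ == "__main__":
    lstm = nn.LSTM(100, 60, bidirectional=True)
    x = torch.rand(512, 10, 100)
    with torch.no_grad():
        y_ref, (hn_ref, cn_ref) = lstm(x)
        y_man, (hn_man, cn_man) = manual_bilstm(lstm, x)
    print(torch.allclose(y_ref, y_man, atol=1e-5))    # expected: True
    print(torch.allclose(hn_ref, hn_man, atol=1e-5))  # expected: True

The key step is torch.flip on the reverse-direction sequence: without it, the concatenated output disagrees with nn.LSTM at every time step except the middle of the sequence, which would show up exactly as "a lot of False" in an elementwise comparison.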