salesforce / awd-lstm-lm

LSTM and QRNN Language Model Toolkit for PyTorch
BSD 3-Clause "New" or "Revised" License

The model behaves normally during training, but during prediction the WeightDrop mechanism cannot be switched off and an error is reported #122

Closed: zhourui-xihu closed this issue 4 years ago

zhourui-xihu commented 4 years ago

Hi,

I saw the code you shared on GitHub (https://github.com/salesforce/awd-lstm-lm) and used your weight-dropout (`WeightDrop`) code to train my LSTM model:

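```python
import torch
import torch.nn as nn

from weight_drop import WeightDrop  # weight_drop.py from salesforce/awd-lstm-lm

# Assumption: `device` is defined elsewhere in the original script; shown here so the snippet runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class reverb(nn.Module):  # class name taken from the error message below
    def __init__(self, input_dim, hidden1_dim, hidden2_dim, num_layers=1, target='re', biFlag=True):
        super().__init__()

        self.input_dim = input_dim
        self.hidden1_dim = hidden1_dim
        self.hidden2_dim = hidden2_dim

        # Bidirectional LSTMs concatenate both directions, doubling the feature size.
        self.output1_dim = self.hidden1_dim * 2
        self.output2_dim = self.hidden2_dim * 2

        self.num_layers = num_layers
        self.target = target
        self.biFlag = biFlag

        # WeightDrop applies DropConnect (p=0.5) to 'weight_hh_l0' (layer 0, forward direction;
        # 'weight_hh_l0_reverse' is not listed). Note: nn.LSTM's own dropout=0.3 only acts
        # between stacked layers, so it has no effect when num_layers=1.
        self.wdrnn1 = WeightDrop(
            nn.LSTM(input_size=self.input_dim, hidden_size=self.hidden1_dim,
                    num_layers=self.num_layers, batch_first=True,
                    dropout=0.3, bidirectional=biFlag),
            ['weight_hh_l0'], dropout=0.5)
        self.wdrnn2 = WeightDrop(
            nn.LSTM(input_size=self.output1_dim, hidden_size=self.hidden2_dim,
                    num_layers=self.num_layers, batch_first=True,
                    dropout=0.3, bidirectional=biFlag),
            ['weight_hh_l0'], dropout=0.5)

        if target == 're':  # regression target: one output value per time step
            outsize = 1
        else:
            raise ValueError(f"unsupported target: {target!r}")
        self.ReLU = nn.ReLU()
        self.linearTimeDistributed = nn.Linear(self.output2_dim, outsize)

    def init_hidden(self, batch_size, hidden):
        # num_layers * 2 because the LSTMs are bidirectional.
        h0 = torch.zeros(self.num_layers * 2, batch_size, hidden, dtype=torch.float64).to(device)
        c0 = torch.zeros(self.num_layers * 2, batch_size, hidden, dtype=torch.float64).to(device)
        return (h0, c0)

    def forward(self, inputsignal):
        rnn1out, _ = self.wdrnn1(inputsignal)
        rnn2out, _ = self.wdrnn2(rnn1out)
        # ... (the rest of forward() is not shown in the original post)
```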

The model behaves normally during training, but during prediction the weight-drop mechanism cannot be switched off, and an error is reported:

```
Applying weight drop of 0.5 to weight_hh_l0
Applying weight drop of 0.5 to weight_hh_l0
Traceback (most recent call last):
  File "D:\0730waid11\prediction.py", line 101, in <module>
    load_network.load_state_dict(torch.load(modelname, map_location=lambda storage, loc: storage))
  File "C:\Users\admin\.conda\envs\tf\lib\site-packages\torch\nn\modules\module.py", line 845, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for reverb:
	Unexpected key(s) in state_dict: "rnn1.weight_hh_l0", "rnn2.weight_hh_l0", "wdrnn1.module.weight_hh_l0", "wdrnn2.module.weight_hh_l0".
```

How should I deal with it? Thank you very much.
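A minimal sketch of one possible workaround, under stated assumptions. As far as we can tell from awd-lstm-lm's `weight_drop.py`, `WeightDrop` prints "Applying weight drop of ..." once per weight when it is *constructed*, deletes the wrapped LSTM's `weight_hh_l0` parameter, registers `weight_hh_l0_raw` in its place, and recomputes the dropped-out `weight_hh_l0` on every forward pass (with `training=self.training` in the non-variational path). So the printed message alone does not mean dropout is active at prediction time, and a checkpoint entry named `...module.weight_hh_l0` is most likely a derived copy rather than a trained parameter. The `rnn1.*`/`rnn2.*` entries are assumed to be leftovers from an earlier version of the class. The helper `drop_unexpected_keys` is hypothetical, not part of PyTorch or the repository.

```python
from collections import OrderedDict
from typing import Iterable, List, Mapping, Tuple


def drop_unexpected_keys(state_dict: Mapping, expected_keys: Iterable[str]) -> Tuple[OrderedDict, List[str]]:
    """Split a checkpoint into entries the model expects and entries it does not.

    Hypothetical helper: `expected_keys` would normally be `model.state_dict().keys()`.
    """
    expected = set(expected_keys)
    # Keep matching entries, preserving the checkpoint's original key order.
    kept = OrderedDict((k, v) for k, v in state_dict.items() if k in expected)
    # Record what was skipped so it can be checked by eye.
    dropped = [k for k in state_dict if k not in expected]
    return kept, dropped


if __name__ == "__main__":
    # Toy stand-in for the checkpoint in the traceback (tensors replaced by ints).
    checkpoint = OrderedDict([
        ("wdrnn1.module.weight_hh_l0_raw", 1),
        ("wdrnn1.module.weight_hh_l0", 2),
        ("rnn1.weight_hh_l0", 3),
    ])
    kept, dropped = drop_unexpected_keys(checkpoint, ["wdrnn1.module.weight_hh_l0_raw"])
    print(list(kept))  # ['wdrnn1.module.weight_hh_l0_raw']
    print(dropped)     # ['wdrnn1.module.weight_hh_l0', 'rnn1.weight_hh_l0']
```

With the real model this would be `state = torch.load(modelname, map_location="cpu")`, then `kept, dropped = drop_unexpected_keys(state, load_network.state_dict().keys())` and `load_network.load_state_dict(kept)`, followed by `load_network.eval()` before predicting so that `WeightDrop` sees `training=False`. PyTorch's `load_state_dict(state, strict=False)` skips unexpected keys too, but it also silently tolerates missing ones; filtering explicitly and keeping the default strict load still raises an error if a trained parameter such as `weight_hh_l0_raw` is genuinely absent.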