Seanny123 / da-rnn

Dual-Stage Attention-Based Recurrent Neural Net for Time Series Prediction

Error while executing main_predict.py #2

Closed: notonlyvandalzzz closed this issue 5 years ago

notonlyvandalzzz commented 5 years ago

Got this error when just trying to run the code from GitHub:

```
Traceback (most recent call last):
  File "main_predict.py", line 76, in <module>
    final_y_pred = predict(enc, dec, data, **da_rnn_kwargs)
  File "main_predict.py", line 51, in predict
    y_pred[y_slc] = decoder(input_encoded, y_history).cpu().data.numpy()
  File "/home/halcyon/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/halcyon/darnn/modules.py", line 106, in forward
    y_tilde = self.fc(torch.cat((context, y_history[:, t]), dim=1)) # (batch_size, out_size)
  File "/home/halcyon/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/halcyon/anaconda3/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 55, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/halcyon/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py", line 1024, in linear
    return torch.addmm(bias, input, weight.t())
RuntimeError: size mismatch, m1: [128 x 40624], m2: [65 x 1] at /opt/conda/conda-bld/pytorch_1533672544752/work/aten/src/THC/generic/THCTensorMathBlas.cu:249
```

Running main.py gives no error, so it looks like the tensor sizes somehow become mismatched at the prediction stage.

The only modification I've made to the code is adding this line at the beginning: `torch.set_default_tensor_type('torch.cuda.FloatTensor')`. Without that line, I get a 'type mismatch' error from torch.
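For anyone decoding the message: `F.linear` here calls `torch.addmm(bias, input, weight.t())`, so m1 is the shape of the layer's input and m2 is the shape of the transposed weight, and their inner dimensions (40624 vs 65) must agree. The pure-Python sketch below models only that shape arithmetic, so it runs without PyTorch. It assumes the decoder's `fc` has `in_features = 65` and `out_features = 1` (read off m2), and that 65 splits into an encoder hidden size of 64 plus one target value; that split is an assumption, not something taken from the repo. `cat_dim1` and `linear_out` are hypothetical helper names.

```python
# Pure-Python model of the shape rules behind the traceback above.
# Assumption: decoder fc is Linear(65, 1), and 65 = 64 (encoder hidden size) + 1;
# the 64 + 1 split is inferred, not read from the repository.


def cat_dim1(a, b):
    """Shape of torch.cat((a, b), dim=1) for two 2-D shapes."""
    if len(a) != 2 or len(b) != 2:
        raise ValueError(f"expected 2-D shapes, got {a} and {b}")
    if a[0] != b[0]:
        raise ValueError(f"batch sizes differ: {a[0]} vs {b[0]}")
    return (a[0], a[1] + b[1])


def linear_out(m1, m2):
    """Shape of addmm(bias, input, weight.t()), with m1 = input, m2 = weight.t()."""
    if m1[1] != m2[0]:
        raise ValueError(
            f"size mismatch, m1: [{m1[0]} x {m1[1]}], m2: [{m2[0]} x {m2[1]}]"
        )
    return (m1[0], m2[1])


batch, hidden = 128, 64
weight_t = (hidden + 1, 1)  # weight is (out_features, in_features) = (1, 65)

# Expected input: context (batch, hidden) next to one target value per row (batch, 1).
ok_input = cat_dim1((batch, hidden), (batch, 1))
print(ok_input, linear_out(ok_input, weight_t))  # (128, 65) (128, 1)

# The failing call from the traceback: 40624 columns reached the layer.
try:
    linear_out((batch, 40624), weight_t)
except ValueError as err:
    print(err)  # size mismatch, m1: [128 x 40624], m2: [65 x 1]
```

Under those assumptions the layer accepts exactly 65 columns per row, so a 40624-column input means one of the concatenated tensors (`context` or `y_history[:, t]`) has a different shape in main_predict.py than it does in main.py.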

notonlyvandalzzz commented 5 years ago

Resolved. Removed unsqueeze(1) from eqn. 15 in the decoder.
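This fix works by changing the rank of a tensor on the path into the eqn. 15 layer. In PyTorch, indexing `x[:, t]` with an integer removes dimension 1, and `x.unsqueeze(1)` inserts a size-1 dimension at position 1, so whether an `unsqueeze(1)` is correct depends on how many dimensions `y_history` has when the decoder is called. The pure-Python sketch below tracks shapes only; the batch size of 128 comes from the traceback, while `T - 1 = 9` and both candidate layouts of `y_history` are hypothetical, not read from main.py or main_predict.py.

```python
# Pure-Python model of how y_history's rank changes what y_history[:, t]
# returns, and therefore whether .unsqueeze(1) is needed before torch.cat.
# The layouts below are illustrative assumptions, not taken from the repo.


def select_dim1(shape):
    """Shape of x[:, t]: an integer index on dim 1 removes that dimension."""
    return shape[:1] + shape[2:]


def unsqueeze(shape, dim):
    """Shape of x.unsqueeze(dim): inserts a size-1 dimension at `dim`."""
    return shape[:dim] + (1,) + shape[dim:]


batch, steps = 128, 9  # batch from the traceback; steps (T - 1) is hypothetical

# Layout A, y_history as (batch, T - 1): the slice is 1-D, so unsqueeze(1) is needed.
print(select_dim1((batch, steps)))                   # (128,)
print(unsqueeze(select_dim1((batch, steps)), 1))     # (128, 1)

# Layout B, y_history as (batch, T - 1, 1): the slice is already 2-D,
# and an extra unsqueeze(1) turns it into a 3-D tensor.
print(select_dim1((batch, steps, 1)))                # (128, 1)
print(unsqueeze(select_dim1((batch, steps, 1)), 1))  # (128, 1, 1)
```

Because the same decoder line can be right for one caller and wrong for another, printing `y_history.shape` just before the decoder call in both main.py and main_predict.py is a quick way to check which layout each script actually passes.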

gao27024037 commented 4 years ago

> Resolved. Removed unsqueeze(1) from eqn. 15 in the decoder.

Hello, I'm running into the same error, but I cannot find unsqueeze(1) in eqn. 15. Can you tell me how you solved it?

Thanks!
