spro / practical-pytorch

Go to https://github.com/pytorch/tutorials - this repo is deprecated and no longer maintained
MIT License

Best way to save the model in "RNN Classification" #38

Open: herleeyandi opened this issue 7 years ago

herleeyandi commented 7 years ago

Hi guys, I am a newbie in PyTorch. I found that PyTorch has a simple way to save a model, but while practicing the RNN Classification tutorial I ran into a problem with it. To save the model I simply executed torch.save(rnn, 'char-rnn-classification.pt'), and then, as in the predict.py file, I loaded it with rnn = torch.load('char-rnn-classification.pt'). This mechanism should save the entire model, from the network structure to the weights. The model file is saved successfully, but when I predict an input in the testing phase I get the error below. Does anybody know how to save the model correctly?
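For comparison, here is a minimal sketch of the approach PyTorch's serialization docs recommend: saving only the model's state_dict and rebuilding the module before loading it. The RNN class below is a hypothetical, simplified stand-in for the tutorial's model.py (assumed to concatenate input and hidden state before i2h/i2o), and the sizes (57 characters, 128 hidden units, 18 categories) are assumptions chosen so that 57 + 128 = 185 matches the saved weight in the traceback.

```python
import io

import torch
import torch.nn as nn


class RNN(nn.Module):
    """Hypothetical stand-in for the tutorial's RNN: i2h and i2o both read
    the concatenation of the current input and the previous hidden state."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.softmax(self.i2o(combined))
        return output, hidden

    def init_hidden(self):
        return torch.zeros(1, self.hidden_size)


# Assumed sizes: 57 input characters, 128 hidden units, 18 output categories.
rnn = RNN(57, 128, 18)

# Save only the learned parameters. A BytesIO buffer stands in for the
# .pt file so the sketch is self-contained.
buffer = io.BytesIO()
torch.save(rnn.state_dict(), buffer)
buffer.seek(0)

# Loading: rebuild the module with the same sizes, then copy the weights in.
restored = RNN(57, 128, 18)
restored.load_state_dict(torch.load(buffer))
restored.eval()
```

Note that neither this nor torch.save(rnn, ...) stores the character list used to build the inputs: a whole-model save pickles the module and refers to its class by import path, so it still depends on model.py, and the input encoding still depends on whatever dataset predict.py loads.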

python predict.py Satoshi

Traceback (most recent call last):
  File "predict.py", line 32, in <module>
    predict(sys.argv[1])
  File "predict.py", line 17, in predict
    output = evaluate(Variable(lineToTensor(line)))
  File "predict.py", line 12, in evaluate
    output, hidden = rnn(line_tensor[i], hidden)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/mspl/ext1/Desktop/Andi/pytorch/practical-pytorch-master/char-rnn-classification/model.py", line 17, in forward
    hidden = self.i2h(combined)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/linear.py", line 54, in forward
    return self._backend.Linear()(input, self.weight, self.bias)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/_functions/linear.py", line 10, in forward
    output.addmm_(0, 1, input, weight.t())
RuntimeError: size mismatch, m1: [1 x 186], m2: [185 x 128] at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1237
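The numbers in this error can be reproduced in isolation. The sketch below assumes, as in the tutorial's model.py, that i2h reads the concatenation of a one-hot character vector and a 128-unit hidden state; m2 is the transposed i2h weight, so the saved layer expects 185 = 57 + 128 inputs, while the tensor built at prediction time is 186 = 58 + 128 wide. The two vocabulary sizes are derived from the traceback, not read from the dataset.

```python
import torch
import torch.nn as nn

n_hidden = 128
saved_letters = 57    # 185 - 128: input size when the model was saved
current_letters = 58  # 186 - 128: input size produced at prediction time

# i2h as saved: in_features = n_letters + n_hidden = 185
i2h = nn.Linear(saved_letters + n_hidden, n_hidden)

# combined = cat(one-hot character, hidden) built from the current vocabulary
letter = torch.zeros(1, current_letters)
hidden = torch.zeros(1, n_hidden)
combined = torch.cat((letter, hidden), 1)  # shape [1, 186]

try:
    i2h(combined)
    mismatch = False
except RuntimeError as err:
    mismatch = True
    print(err)  # wording differs between PyTorch versions
```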
spro commented 7 years ago

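Make sure your dataset is exactly the same when you train and when you predict, because lineToTensor relies on the number and order of characters in all_characters to create the input tensors. If the character list changes size, so does the one-hot input: in the traceback above, the combined input is 186 wide while the saved i2h layer expects 185. (A change in order alone would not raise an error, but it would silently scramble the inputs.) Another solution is to attach that function and the character list directly to the model, so they are saved and loaded together with it.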
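One way to attach the character list to the model is to save it in the same checkpoint as the weights, as in this minimal sketch. The helper line_to_tensor is a hypothetical variant of lineToTensor that takes the character list as an argument; the character set, the checkpoint keys, and the single nn.Linear standing in for the RNN's i2h layer are all assumptions made to keep the example short.

```python
import io
import string

import torch
import torch.nn as nn


def line_to_tensor(line, all_characters):
    """Hypothetical variant of lineToTensor that takes the character list as
    an argument: returns a [len(line), 1, n_characters] one-hot tensor."""
    tensor = torch.zeros(len(line), 1, len(all_characters))
    for i, ch in enumerate(line):
        tensor[i][0][all_characters.index(ch)] = 1
    return tensor


# --- Training side (hypothetical values) ---
all_characters = string.ascii_letters + " .,;'"
n_hidden = 128
# A single Linear layer stands in for the RNN's i2h to keep the sketch short.
model = nn.Linear(len(all_characters) + n_hidden, n_hidden)

checkpoint = {
    "all_characters": all_characters,  # saved together with the weights
    "n_hidden": n_hidden,
    "state_dict": model.state_dict(),
}
buffer = io.BytesIO()  # stands in for 'char-rnn-classification.pt'
torch.save(checkpoint, buffer)
buffer.seek(0)

# --- Prediction side: everything comes from the checkpoint, not the dataset ---
loaded = torch.load(buffer)
chars = loaded["all_characters"]
restored = nn.Linear(len(chars) + loaded["n_hidden"], loaded["n_hidden"])
restored.load_state_dict(loaded["state_dict"])

line_tensor = line_to_tensor("Satoshi", chars)
hidden = torch.zeros(1, loaded["n_hidden"])
combined = torch.cat((line_tensor[0], hidden), 1)  # [1, 57 + 128]
hidden = restored(combined)
```

Because predict.py rebuilds both the layer sizes and the encoding from the saved list, the input width always matches the saved weights. A character that is missing from the saved list now fails loudly with a ValueError from str.index instead of producing a size mismatch deep inside the model.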