salaniz / pytorch-gve-lrcn

PyTorch implementations for "Generating Visual Explanations" (GVE) and "Long-term Recurrent Convolutional Networks" (LRCN)
MIT License

Training GVE causes error in state_dict() #3

Closed (dfdazac closed this issue 6 years ago)

dfdazac commented 6 years ago

I followed the steps to train a GVE model on the CUB dataset. After the first epoch, it starts evaluating and then proceeds to save the model. The following error is thrown:

Traceback (most recent call last):
  File "main.py", line 114, in <module>
    torch.save(model.state_dict(), checkpoint_path)
  File "/home/daniel/projects/pytorch-vision-language/models/lrcn.py", line 80, in state_dict
    state_dict = super().state_dict()
  File "/home/daniel/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 616, in state_dict
    module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
TypeError: state_dict() got an unexpected keyword argument 'keep_vars'

According to lrcn.py, the model's state_dict() calls super().state_dict(), which resolves to the state_dict() method of the nn.Module class. The last lines of that method are:

for name, module in self._modules.items():
    if module is not None:
        module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
return destination

These lines call the state_dict method of every submodule in the model, passing destination and prefix positionally and keep_vars as a keyword. One of those submodules is a SentenceClassifier, whose state_dict method is:

def state_dict(self, full_dict=False):
    return super().state_dict()

Since this override accepts only full_dict, its signature doesn't match the call made from nn.Module.state_dict() (destination, prefix, keep_vars), so the TypeError is raised.
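For illustration, here is a minimal sketch of the mismatch and one possible way around it. It does not use PyTorch: MiniModule is a hypothetical stand-in that only mimics the recursive call shown in the traceback, and BrokenClassifier / FixedClassifier are made-up names, not classes from this repository. The idea is that an override which forwards *args and **kwargs to the parent keeps working no matter which arguments the parent passes down.

```python
from collections import OrderedDict


class MiniModule:
    """Hypothetical, simplified stand-in for nn.Module (not the real class)."""

    def __init__(self):
        self._params = OrderedDict()
        self._modules = OrderedDict()

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        if destination is None:
            destination = OrderedDict()
        for name, value in self._params.items():
            destination[prefix + name] = value
        # Same recursive call pattern as in the traceback.
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
        return destination


class BrokenClassifier(MiniModule):
    def __init__(self):
        super().__init__()
        self._params['weight'] = 1.0

    # Narrow signature: cannot receive destination, prefix, or keep_vars.
    def state_dict(self, full_dict=False):
        return super().state_dict()


class FixedClassifier(MiniModule):
    def __init__(self):
        super().__init__()
        self._params['weight'] = 1.0

    # Forward whatever the parent passes; full_dict stays keyword-only.
    def state_dict(self, *args, full_dict=False, **kwargs):
        return super().state_dict(*args, **kwargs)


def build(child):
    """Wrap a child module in a parent, as the full model wraps its classifier."""
    parent = MiniModule()
    parent._modules['sentence_classifier'] = child
    return parent


try:
    build(BrokenClassifier()).state_dict()
    error = None
except TypeError as exc:
    error = str(exc)

fixed_keys = list(build(FixedClassifier()).state_dict().keys())
print(error)
print(fixed_keys)
```

With the narrow signature the parent's recursive call fails with a TypeError, while the forwarding version lets the parent collect the child's entries under the sentence_classifier. prefix. Whether the actual fix in the repository takes this form is not stated in this thread.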

salaniz commented 6 years ago

Thanks for the detailed error report. I will fix it momentarily.