dearleiii / PIRM-2018-SISR-Challenge

Super Resolution
https://www.pirm2018.org/PIRM-SR.html

save_state_dict #11

Closed: dearleiii closed this issue 6 years ago

dearleiii commented 6 years ago

Epoch 5, 85% train_loss: 4.83 took: 18.79s
18 score:: torch.Size([40, 1]) outputs: torch.Size([40, 1])
19 score:: torch.Size([40, 1]) outputs: torch.Size([40, 1])
20 score:: torch.Size([1, 1]) outputs: torch.Size([1, 1])
Epoch 5, 100% train_loss: 3.22 took: 4.16s
Traceback (most recent call last):
  File "scatter_kernel3.py", line 156, in <module>
    approximator.save_state_dict('APXM_4conv.pt')
  File "/home/home2/leichen/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 532, in __getattr__
    type(self).__name__, name))
AttributeError: 'DataParallel' object has no attribute 'save_state_dict'
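The error arises because `nn.DataParallel` wraps the original network (exposing it as `.module`), and neither the wrapper nor `nn.Module` defines a `save_state_dict` method. A minimal sketch of the usual fix, assuming a hypothetical stand-in `Approximator` class rather than the repo's actual network: call `state_dict()` on the wrapped module and hand the result to `torch.save`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Approximator(nn.Module):
    """Hypothetical stand-in for the network in scatter_kernel3.py."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 1)

    def forward(self, x):
        return self.fc(x)


approximator = nn.DataParallel(Approximator())

# nn.Module has no save_state_dict(), so this reproduces the AttributeError.
try:
    approximator.save_state_dict('APXM_4conv.pt')
except AttributeError as err:
    print(err)

# Save the parameters of the wrapped model; a temp dir stands in for the real path.
path = os.path.join(tempfile.mkdtemp(), 'APXM_4conv.pt')
torch.save(approximator.module.state_dict(), path)
print(sorted(torch.load(path).keys()))  # ['fc.bias', 'fc.weight']
```

Saving `approximator.module.state_dict()` rather than `approximator.state_dict()` keeps the keys free of the wrapper's `module.` prefix, so the file loads directly into an unwrapped model.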

dearleiii commented 6 years ago

From the PyTorch GitHub repo page on saving a model:

Recommended approach for saving a model

There are two main approaches for serializing and restoring a model.

The first (recommended) saves and loads only the model parameters:

torch.save(the_model.state_dict(), PATH)

Then later:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
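A runnable sketch of this recommended round trip, assuming a small hypothetical `TheModelClass` that takes no constructor arguments and a temporary file standing in for `PATH`:

```python
import os
import tempfile

import torch
import torch.nn as nn


class TheModelClass(nn.Module):
    """Hypothetical model used only to illustrate the save/load round trip."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.out = nn.Linear(8, 1)

    def forward(self, x):
        return self.out(torch.relu(self.hidden(x)))


PATH = os.path.join(tempfile.mkdtemp(), 'model_params.pt')

the_model = TheModelClass()
torch.save(the_model.state_dict(), PATH)  # parameters and buffers only

# Later: rebuild the architecture in code, then copy the saved tensors into it.
restored = TheModelClass()
restored.load_state_dict(torch.load(PATH))
restored.eval()  # inference mode; matters for layers such as dropout or batchnorm

# Every saved tensor should now match the original model's tensors.
same = all(
    torch.equal(a, b)
    for a, b in zip(the_model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

Because only tensors are stored, the class definition lives in code and can be refactored freely as long as the parameter names and shapes stay the same.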

The second saves and loads the entire model:

torch.save(the_model, PATH)

Then later:

the_model = torch.load(PATH)

However, in this case the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects or after some serious refactors.
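Since the traceback in this issue comes from a model wrapped in `nn.DataParallel`, one practical detail of the state_dict approach is worth noting: calling `state_dict()` on the wrapper itself yields keys prefixed with `module.`, which a plain (unwrapped) model rejects in `load_state_dict`. A minimal sketch, assuming a hypothetical `Net` class and a hypothetical helper `strip_module_prefix`, that removes the prefix before loading:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical model standing in for the real network."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def strip_module_prefix(state_dict, prefix='module.'):
    """Hypothetical helper: copy state_dict, dropping a leading prefix from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


wrapped = nn.DataParallel(Net())
wrapped_sd = wrapped.state_dict()
print(sorted(wrapped_sd))  # ['module.fc.bias', 'module.fc.weight']

# Load the wrapper's parameters into an unwrapped model.
plain = Net()
plain.load_state_dict(strip_module_prefix(wrapped_sd))

matches = all(
    torch.equal(value, wrapped.module.state_dict()[key])
    for key, value in plain.state_dict().items()
)
print(matches)  # True
```

An alternative is to save `wrapped.module.state_dict()` in the first place, so the keys never carry the prefix.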