Open · kevglynn opened this issue 6 years ago
@kevglynn Were you able to make any headway? I'm running into the same problem.
I can look into enabling this. In the meantime, note that pickling the whole model works!
For reference, for anyone who stumbles upon this issue and wants a direct link to the current workaround (saving the whole model with `torch.save`, which pickles it):
https://pytorch.org/docs/stable/torch.html?highlight=torch%20save#torch.save
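To illustrate why pickling sidesteps the missing `state_dict`: pickle serializes the whole Python object, including the `nn.Module` it wraps. The sketch below uses a hypothetical `ToyWrapper` class as a stand-in for a Spotlight model (an assumption: it only mirrors the private `_net` attribute, not Spotlight's real code), with the standard-library `pickle` module and an in-memory buffer in place of a file.

```python
import io
import pickle

import torch


class ToyWrapper:
    """Hypothetical stand-in for a Spotlight model: a plain Python object
    (not an nn.Module) that holds a network in a private `_net` attribute."""

    def __init__(self, num_items, embedding_dim=8):
        self.num_items = num_items
        self._net = torch.nn.Embedding(num_items, embedding_dim)


model = ToyWrapper(num_items=10)

# Pickle the whole object, network weights included, into an in-memory buffer.
buffer = io.BytesIO()
pickle.dump(model, buffer)

# Unpickling rebuilds the wrapper and its network; no re-initialization needed.
buffer.seek(0)
restored = pickle.load(buffer)
```

One trade-off: a pickle refers to classes by their import path, so the process that loads it needs the same code importable (for a real model, the same Spotlight version). Also, recent PyTorch releases default `torch.load` to `weights_only=True`, so loading a fully pickled model through `torch.load` rather than `pickle` may require passing `weights_only=False`.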
@kevglynn @jaskiratr @maciejkula
For a temporary workaround, you can reach the `nn.Module` methods through the model's private `_net` attribute and call the method you need on that. For example:
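```python
import torch
from spotlight.sequence.implicit import ImplicitSequenceModel

torch.save(model._net.state_dict(), PATH)  # saves only the wrapped nn.Module's weights
model = ImplicitSequenceModel(n_iter=3, representation='lstm', loss='bpr')
model._initialize(dataset)  # builds model._net; dataset needs the same num_items as in training
model._net.load_state_dict(torch.load(PATH))
```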
The trick is to call `_initialize` with an `Interactions` object first; otherwise `model._net` has not been built yet and the PyTorch methods are unavailable. Its `num_items` must match the `num_items` of the data you trained on. Not ideal, but I managed to serve my model through a REST API on my GCP instance this way.
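A minimal sketch of the mechanics described above, assuming only plain PyTorch: the hypothetical `LazyModel` class mimics the pattern where the network is created in `_initialize`, so `_net` is `None` until then. It also shows why `num_items` must match: a different item count builds a differently shaped embedding, and `load_state_dict` rejects the saved weights. An in-memory buffer stands in for `PATH`, and `num_items` is passed directly rather than through an `Interactions` object.

```python
import io

import torch


class LazyModel:
    """Hypothetical wrapper (not Spotlight's code): the nn.Module is only
    built once _initialize() is called."""

    def __init__(self, embedding_dim=8):
        self.embedding_dim = embedding_dim
        self._net = None  # no network, so no PyTorch methods, until initialized

    def _initialize(self, num_items):
        # The real model reads num_items from the data; here it is passed in directly.
        self._net = torch.nn.Embedding(num_items, self.embedding_dim)


# Training side: initialize, (train,) then save only the network's weights.
trained = LazyModel()
trained._initialize(num_items=100)
buffer = io.BytesIO()  # stands in for PATH
torch.save(trained._net.state_dict(), buffer)

# Serving side: same hyperparameters and the same num_items, then load.
served = LazyModel()
served._initialize(num_items=100)
buffer.seek(0)
served._net.load_state_dict(torch.load(buffer))

# A different num_items builds a differently shaped embedding, so loading fails.
mismatch_error = None
wrong = LazyModel()
wrong._initialize(num_items=50)
buffer.seek(0)
try:
    wrong._net.load_state_dict(torch.load(buffer))
except RuntimeError as err:
    mismatch_error = err
```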
Saving a model as follows:

```python
torch.save(model.state_dict(), 'path/to/model')
```

yields the following error:

```
AttributeError: 'ImplicitSequenceModel' object has no attribute 'state_dict'
```
This means I currently cannot follow the best practice outlined by PyTorch: https://pytorch.org/docs/stable/notes/serialization.html#recommend-saving-models
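For comparison, a hypothetical sketch (not Spotlight's actual API) of what supporting that best practice could look like: a wrapper that delegates `state_dict` and `load_state_dict` to its inner `_net`, so a call of the form used in the report succeeds. The class name `StatefulWrapper` and the in-memory buffer standing in for `'path/to/model'` are illustrative assumptions.

```python
import io

import torch


class StatefulWrapper:
    """Hypothetical sketch: expose the PyTorch state_dict interface on a
    plain wrapper object by delegating to the wrapped network."""

    def __init__(self, num_items, embedding_dim=8):
        self._net = torch.nn.Embedding(num_items, embedding_dim)

    def state_dict(self):
        return self._net.state_dict()

    def load_state_dict(self, state_dict):
        return self._net.load_state_dict(state_dict)


model = StatefulWrapper(num_items=10)
buffer = io.BytesIO()  # stands in for 'path/to/model'
torch.save(model.state_dict(), buffer)  # no AttributeError with delegation in place

# Rebuild with the same num_items, then load the saved weights.
restored = StatefulWrapper(num_items=10)
buffer.seek(0)
restored.load_state_dict(torch.load(buffer))
```

With this approach the saved file holds only tensors, so the loading code still has to rebuild the wrapper with the same `num_items` before calling `load_state_dict`.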
Thanks!