Closed: kkovary closed this issue 1 month ago
Hi Kyle,
Many thanks for your message! Saving the model's state_dict and reinitializing would be the standard solution in GPyTorch. @InfProbSciX put SIGP together, if I recall correctly. In the meantime, it would be great to get a full reproduction of exactly how you're reinitializing from the state_dict!
Ryan
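For context, below is a minimal sketch of the standard GPyTorch state_dict save/load pattern Ryan describes. It uses a plain ExactGP with toy data rather than the SIGP/GraphGP model from this issue, so the model class and data here are illustrative assumptions; the key point is that the model is re-created with the same training data and then the saved state_dict is loaded into it.

```python
import torch
import gpytorch

class ExactGPModel(gpytorch.models.ExactGP):
    """Stand-in model; the real issue uses a SIGP/GraphGP model with graph inputs."""
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# Toy training data standing in for the real inputs.
train_x = torch.linspace(0, 1, 20)
train_y = torch.sin(train_x * 6.283)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)
# ... fit hyperparameters here ...

# Save the learned hyperparameters (likelihood parameters are included,
# since the likelihood is a submodule of the ExactGP model).
torch.save(model.state_dict(), "model_state.pth")

# Later: re-create the model with the *same* training data, then load the weights.
likelihood2 = gpytorch.likelihoods.GaussianLikelihood()
model2 = ExactGPModel(train_x, train_y, likelihood2)
model2.load_state_dict(torch.load("model_state.pth"))
model2.eval()
likelihood2.eval()
```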
Hey, thanks for getting back to me. I put together a custom serializer and was able to get the state to save and load successfully. I'll post an update here soon when I get some free time. Again, great work on this!
Great to hear, Kyle! Would be great to see your custom serializer; feel free to open a PR!
Hey @kkovary, any word on your custom serializer? It would be great to have it as a contribution!
Thanks for the reminder, I'll submit a PR this weekend!
First off, great work, this is a really cool package!
I've been playing with graph representation inputs built with graphein, feeding them into a model that builds off of SIGP (some examples in your codebase call it GraphGP), and I've been getting some really great performance out of it. However, I'm struggling to understand how to correctly save the model and then load it back into memory for inference after training. If I save the state dict and then re-init the model using that state dict, the model performs as if it had been randomly initialized. I also tried pickling the model (not the ideal solution), but I get the following exception:

I tried setting train_inputs to None before saving. This took care of the exception; however, I'm back to the original issue where the model seems to be randomly initialized.

I was wondering if you had any guidance here, or if there is something in the docs that I missed. Thanks!
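For reference, one common way around this kind of problem is to persist the training inputs and targets alongside the state_dict so the model can be rebuilt exactly at load time. The sketch below illustrates that idea with hypothetical helper names (save_checkpoint, load_checkpoint, model_factory); it is an assumption about the general approach, not the custom serializer mentioned above and not part of the SIGP/GraphGP API.

```python
import torch

def save_checkpoint(model, train_inputs, train_targets, path):
    """Persist the learned parameters together with the training data
    the model was conditioned on, so it can be reconstructed later."""
    torch.save(
        {
            "state_dict": model.state_dict(),
            "train_inputs": train_inputs,
            "train_targets": train_targets,
        },
        path,
    )

def load_checkpoint(model_factory, path):
    """Rebuild the model from the saved training data, then load the weights.

    model_factory is any callable that constructs a fresh model from
    (train_inputs, train_targets), e.g.
    lambda x, y: ExactGPModel(x, y, gpytorch.likelihoods.GaussianLikelihood()).
    """
    ckpt = torch.load(path)
    model = model_factory(ckpt["train_inputs"], ckpt["train_targets"])
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    return model
```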