Open MatSci ML Toolkit is a framework for prototyping and scaling out deep learning models for materials discovery. It supports widely used materials science datasets and is built on top of PyTorch Lightning, the Deep Graph Library, and PyTorch Geometric.
[Feature request]: Load and Use Wrapped Models 'As Is' From External Pretrained Checkpoint #223
Feature/behavior summary
MatSciML offers various models (M3GNet, TensorNet, MACE) that are wrappers around the upstream implementations; however, there is currently no clean way to load a pretrained checkpoint and use it 'as is' with the default model architecture. The hang-ups arise from:
MatSciML creating output heads which are always added to models
MatSciML expecting Embeddings objects returned from the encoder forward pass
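To make the two hang-ups concrete, here is a minimal sketch of the wrapper pattern described above. The `Embeddings` container and `WrappedEncoder` class are illustrative stand-ins, not the actual matsciml API: the point is that the upstream model's native prediction never leaves the wrapper, so the task layer's output heads become mandatory.

```python
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class Embeddings:
    """Container the task layer expects back from every encoder forward pass."""

    system_embedding: torch.Tensor  # graph-level features
    point_embedding: torch.Tensor   # node-level features


class WrappedEncoder(nn.Module):
    """Hypothetical wrapper: discards the upstream model's native output and
    returns Embeddings instead, which downstream output heads consume."""

    def __init__(self, upstream: nn.Module, hidden_dim: int = 16):
        super().__init__()
        self.upstream = upstream
        self.project = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, node_feats: torch.Tensor) -> Embeddings:
        hidden = self.upstream(node_feats)
        # The upstream model's own prediction (e.g. an energy) is never
        # surfaced here: only embeddings are, so output heads are required
        # to turn them back into a usable prediction.
        return Embeddings(
            system_embedding=hidden.mean(dim=0, keepdim=True),
            point_embedding=self.project(hidden),
        )


encoder = WrappedEncoder(nn.Linear(16, 16))
out = encoder(torch.randn(8, 16))  # an Embeddings object, not a raw tensor
```

A pretrained checkpoint whose forward pass already produces the final prediction has no clean path through this interface, which is what the request below addresses.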
Request attributes
[ ] Would this be a refactor of existing code?
[ ] Does this proposal require new package dependencies?
[ ] Would this change break backwards compatibility?
[ ] Does this proposal include a new model?
[ ] Does this proposal include a new dataset?
[ ] Does this proposal include a new task/workflow?
Related issues
No response
Solution description
Two options to work around this:
1. Modify the existing behavior to toggle on/off the creation of output heads, as well as returning the default output from the wrapped model.
2. Create a new 'task' which removes all of the output head creation and expected forward-pass outputs, and runs the wrapped model 'as is'.
Below is an example of how option 1 was implemented by subclassing MatSciML tasks and model wrappers. Note that this relies on #222 to load the proper MACE submodule (ScaleShiftMACE). The model checkpoint 2023-12-10-mace-128-L0_epoch-199.model may be used with this example.
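The shape of option 1 can be sketched as follows. This is a hypothetical illustration under assumed names (`PassthroughWrapper`, `disable_heads`), not the actual subclass from the issue or the matsciml interface: a single flag skips output-head creation and returns the wrapped model's own prediction untouched.

```python
import torch
import torch.nn as nn


class PassthroughWrapper(nn.Module):
    """Hypothetical sketch of option 1: when disable_heads is set, skip
    output-head creation and return the upstream prediction 'as is'."""

    def __init__(self, upstream: nn.Module, disable_heads: bool = True):
        super().__init__()
        self.upstream = upstream
        self.disable_heads = disable_heads
        # Normally the wrapper would create output heads unconditionally;
        # here they are only built when the default behavior is wanted.
        self.output_head = None if disable_heads else nn.Linear(16, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.upstream(x)
        if self.disable_heads:
            return out  # upstream prediction, untouched
        return self.output_head(out)


# A pretrained upstream model (e.g. the MACE checkpoint mentioned above)
# would be loaded here in place of this placeholder module.
upstream = nn.Linear(16, 16)
model = PassthroughWrapper(upstream, disable_heads=True)
pred = model(torch.randn(4, 16))  # raw upstream output, no head applied
```

With `disable_heads=False` the wrapper falls back to the existing head-based behavior, so the toggle preserves backwards compatibility.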
Additional notes
No response