IntelLabs / matsciml

Open MatSci ML Toolkit is a framework for prototyping and scaling out deep learning models for materials discovery. It supports widely used materials science datasets and is built on top of PyTorch Lightning, the Deep Graph Library, and PyTorch Geometric.
MIT License

[Feature request]: Support using intermediate embeddings #99

Closed: laserkelvin closed this issue 9 months ago

laserkelvin commented 9 months ago

Feature/behavior summary

Refactor the Embeddings data structure, and the way OutputBlocks use it, so that intermediate embeddings can be used for modeling. In some models, such as MACE, the model output is a reduction over projections of intermediate layers, e.g. $E_f = E_0 + E_1 + \ldots + E_l$ for $l$ layers.

The current implementation makes this mode of usage awkward: we need to store the intermediate embeddings and then use them correctly once every layer has been computed. That would be straightforward if the intermediate embeddings all had the same shape, since they could be concatenated along a single new dimension and the output blocks could simply broadcast over it. In MACE and other equivariance-preserving models, however, the intermediate layers may have different shapes and cannot be concatenated.
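To make the shape constraint concrete, here is a minimal sketch that only tracks shapes. It assumes a stack-style operation (like torch.stack) that requires every input to have exactly the same shape. The helper name stacked_shape and the feature sizes are hypothetical and purely illustrative, not taken from matsciml or MACE.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def stacked_shape(layer_shapes: Sequence[Shape]) -> Shape:
    """Shape obtained by stacking per-layer embeddings along a new dim 0.

    Mirrors the rule of stack-style ops (e.g. torch.stack): every input
    must have exactly the same shape, otherwise stacking is impossible.
    """
    if not layer_shapes:
        raise ValueError("need at least one layer")
    first = tuple(layer_shapes[0])
    for s in layer_shapes[1:]:
        if tuple(s) != first:
            raise ValueError(f"cannot stack shapes {first} and {tuple(s)}")
    return (len(layer_shapes),) + first


# Invariant model: every layer emits 128 features per node (10 nodes), so
# the layers stack into a (layers, nodes, features) block that a single
# output head can broadcast over.
print(stacked_shape([(10, 128), (10, 128), (10, 128)]))  # (3, 10, 128)

# Equivariant model (illustrative sizes only): an intermediate layer with
# 128 scalars + 128 vectors (128 + 3 * 128 = 512 components) next to a
# final layer that keeps only 128 scalars.
try:
    stacked_shape([(10, 512), (10, 128)])
except ValueError as err:
    print(err)  # cannot stack shapes (10, 512) and (10, 128)
```

In the second case there is no single tensor for one head to broadcast over, which is why per-layer handling is needed.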

Request attributes

Related issues

#72 is the main issue, but #83 is a related WIP PR.

Solution description

One option is to refactor Embeddings to allow intermediate embeddings, i.e. at the node and system levels we would accept either a Tensor or a list[Tensor]. The OutputBlock logic would then need to change so that a separate output head is applied to each intermediate embedding, and the results are reduced to produce the final output.
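The proposal above can be sketched roughly as follows. This is a hypothetical illustration, not matsciml's actual Embeddings or OutputBlock API: plain Python lists stand in for tensors, a "head" is any callable mapping a system-level vector to a scalar, and the reduction is assumed to be a sum, matching $E_f = E_0 + E_1 + \ldots + E_l$. The names MultiHeadOutput and linear_head are invented for this example.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence, Union

# Hypothetical stand-ins: a "tensor" is a list of floats (one system-level
# feature vector) and a head maps such a vector to a scalar prediction.
Vector = List[float]
Head = Callable[[Vector], float]


@dataclass
class Embeddings:
    # Either a single final embedding (current behaviour) or one embedding
    # per layer; the per-layer vectors may have different lengths.
    system_embedding: Union[Vector, List[Vector]]

    @property
    def layers(self) -> List[Vector]:
        """Always return a list, so downstream code has a single code path."""
        emb = self.system_embedding
        if emb and isinstance(emb[0], list):
            return list(emb)
        return [emb]


class MultiHeadOutput:
    """Apply one head per embedding, then reduce (here: sum) the outputs."""

    def __init__(self, heads: Sequence[Head]):
        self.heads = list(heads)

    def __call__(self, embeddings: Embeddings) -> float:
        layers = embeddings.layers
        if len(layers) != len(self.heads):
            raise ValueError(
                f"got {len(layers)} embeddings but {len(self.heads)} heads"
            )
        # Each head only sees its own layer, so layers never share a shape.
        return sum(head(emb) for head, emb in zip(self.heads, layers))


def linear_head(weights: Vector, bias: float = 0.0) -> Head:
    """Toy linear readout: dot(weights, x) + bias."""
    def head(x: Vector) -> float:
        if len(x) != len(weights):
            raise ValueError("weight/feature length mismatch")
        return sum(w * xi for w, xi in zip(weights, x)) + bias
    return head


# Two layers with different feature sizes (4 and 2), one head each.
emb = Embeddings(system_embedding=[[1.0, 2.0, 3.0, 4.0], [0.5, -0.5]])
output = MultiHeadOutput(
    [linear_head([1.0, 1.0, 1.0, 1.0]), linear_head([2.0, 2.0], bias=1.0)]
)
print(output(emb))  # 11.0
```

Because a bare vector is wrapped into a one-element list, a model that only provides a final embedding keeps working with a single head. In a real PyTorch version, the heads would presumably live in an nn.ModuleList, and the sum could be swapped for another reduction.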

For now, I don't think this would be a backwards-incompatible change.

Additional notes

No response

laserkelvin commented 9 months ago

Closing this for now, as we've agreed that it's not strictly necessary for MACE at the moment. We can reopen this issue if new use cases demand this mode of usage, but having thought about it for a bit, it's not an entirely trivial task.