IntelLabs / matsciml

Open MatSci ML Toolkit is a framework for prototyping and scaling out deep learning models for materials discovery supporting widely used materials science datasets, and built on top of PyTorch Lightning, the Deep Graph Library, and PyTorch Geometric.
MIT License
144 stars, 20 forks

[Feature request]: Improve interfacing between encoders and output heads #265

Status: Open. laserkelvin opened this issue 2 months ago.


Feature/behavior summary

Currently, two things make configuring output heads unpredictable:

  1. It is difficult to match the embedding dimensionality produced by the encoder with the input layer of the output head.
  2. Lazy layers are used inconsistently in the MLP blocks. This was partly meant to resolve (1), but it has added engineering complexity: because lazy output heads are only created just in time, the code is hard to maintain and distributed training is sometimes difficult to start.

Request attributes

Related issues

No response

Solution description

I don't really have a perfect solution, but my suggestions are:

  1. Remove the option to use lazy modules. This will break examples and probably some tests, but it means there is only one way to initialize models, which makes maintenance easier and removes ambiguity in setup.
  2. In the abstract models (e.g. AbstractPyGModel), add an abstract property such as encoder_output_dim that makes output heads easier to create: the head would use this property to determine its input dimension, so the only configuration it needs is hidden_dim and possibly output_dim.
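A minimal, framework-free sketch of the bookkeeping that suggestion 2 implies, assuming the encoder exposes encoder_output_dim as a plain integer. The helper output_head_layer_shapes, its num_hidden parameter, and the FakeEncoder stand-in are hypothetical and not part of matsciml; in a real head each (in, out) pair would become an eagerly constructed torch.nn.Linear rather than a lazy layer.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeEncoder:
    """Hypothetical stand-in for a concrete encoder, for illustration only."""
    encoder_output_dim: int


def output_head_layer_shapes(
    encoder_output_dim: int, hidden_dim: int, output_dim: int, num_hidden: int = 2
) -> list[tuple[int, int]]:
    """Return (in_features, out_features) for each linear layer of an MLP head.

    Only hidden_dim and output_dim are user-configured; the width of the
    first layer comes from the encoder, so the two cannot disagree.
    num_hidden is the number of hidden layers (hypothetical parameter).
    """
    if num_hidden < 1:
        raise ValueError("num_hidden must be at least 1")
    shapes = [(encoder_output_dim, hidden_dim)]          # encoder -> first hidden
    shapes += [(hidden_dim, hidden_dim)] * (num_hidden - 1)  # hidden -> hidden
    shapes.append((hidden_dim, output_dim))              # last hidden -> output
    return shapes


encoder = FakeEncoder(encoder_output_dim=128)
print(output_head_layer_shapes(encoder.encoder_output_dim, hidden_dim=64, output_dim=1))
# → [(128, 64), (64, 64), (64, 1)]
```

Because every shape is known before the first forward pass, all parameters exist at construction time, which is what distributed wrappers expect.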

For the second item, it could be something as simple as:

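from abc import ABC, abstractmethod

class AbstractEncoder(ABC):

    @property
    @abstractmethod
    def encoder_output_dim(self) -> int:
        """Dimensionality of the embeddings this encoder hands to output heads."""
        raise NotImplementedError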

In a concrete encoder, the property might simply return the dimension of the final layer or, for more complicated outputs (e.g. concatenated tensors), compute the expected dimension arithmetically. We can then refactor OutputHead to rely on this property for its input dimension.
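As a sketch of the concrete case, here is a variant of the AbstractEncoder above that uses @property together with @abstractmethod (abc.abstractproperty has been deprecated since Python 3.3), re-declared so the block runs on its own. One encoder reports the width of its last layer; the other, which concatenates the outputs of every block, sums their widths. SimpleEncoder, ConcatEncoder, and their constructor arguments are hypothetical, not existing matsciml classes.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class AbstractEncoder(ABC):
    @property
    @abstractmethod
    def encoder_output_dim(self) -> int:
        """Width of the embedding the encoder hands to output heads."""
        raise NotImplementedError


class SimpleEncoder(AbstractEncoder):
    """Hypothetical encoder whose output is its final layer's activations."""

    def __init__(self, layer_dims: list[int]) -> None:
        self.layer_dims = layer_dims

    @property
    def encoder_output_dim(self) -> int:
        # Output is just the last layer, so report its width.
        return self.layer_dims[-1]


class ConcatEncoder(AbstractEncoder):
    """Hypothetical encoder that concatenates every block's output."""

    def __init__(self, block_dims: list[int]) -> None:
        self.block_dims = block_dims

    @property
    def encoder_output_dim(self) -> int:
        # Concatenating along the feature axis sums the block widths.
        return sum(self.block_dims)


print(SimpleEncoder([64, 128, 256]).encoder_output_dim)  # → 256
print(ConcatEncoder([64, 128, 256]).encoder_output_dim)  # → 448
```

Because the property is abstract, a new encoder that forgets to implement it raises a TypeError when instantiated, rather than surfacing later as a shape mismatch inside the output head.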

Additional notes

No response