IntelLabs / matsciml

Open MatSci ML Toolkit is a framework for prototyping and scaling out deep learning models for materials discovery. It supports widely used materials science datasets and is built on top of PyTorch Lightning, the Deep Graph Library, and PyTorch Geometric.
MIT License

Atom weighted loss functions and `loss_func` argument refactor #256

Closed: laserkelvin closed this pull request 3 months ago

laserkelvin commented 3 months ago

This PR adds MSE and L1 loss functions, weighted by the number of atoms in each graph, to the matsciml.models.losses module.
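A minimal sketch of what atom-weighted MSE and L1 losses can look like is below. It is not the matsciml implementation: the names `per_atom_mse`, `per_atom_l1`, `PerAtomMSE` and `PerAtomL1` are hypothetical, and the weighting scheme (dividing each graph's prediction and target by its atom count, so the error is measured per atom) is an assumption about what "weighted by the number of atoms" means. It applies to graph-level targets such as energy, not to per-atom targets such as forces.

```python
import torch
import torch.nn.functional as F
from torch import nn


def _per_atom(x: torch.Tensor, atoms_per_graph: torch.Tensor) -> torch.Tensor:
    # Reshape the (num_graphs,) atom counts so they broadcast over any
    # trailing dimensions of a per-graph tensor of shape (num_graphs, ...).
    counts = atoms_per_graph.to(x.dtype).reshape(-1, *([1] * (x.dim() - 1)))
    return x / counts


def per_atom_mse(input, target, atoms_per_graph):
    # Assumed scheme: mean squared error between count-normalized values.
    return F.mse_loss(_per_atom(input, atoms_per_graph), _per_atom(target, atoms_per_graph))


def per_atom_l1(input, target, atoms_per_graph):
    # Assumed scheme: mean absolute error between count-normalized values.
    return F.l1_loss(_per_atom(input, atoms_per_graph), _per_atom(target, atoms_per_graph))


class PerAtomMSE(nn.Module):
    """Hypothetical module wrapper; keeps the PyTorch-style (input, target) order."""

    def forward(self, input, target, atoms_per_graph):
        return per_atom_mse(input, target, atoms_per_graph)


class PerAtomL1(nn.Module):
    """Hypothetical module wrapper; keeps the PyTorch-style (input, target) order."""

    def forward(self, input, target, atoms_per_graph):
        return per_atom_l1(input, target, atoms_per_graph)
```

For two graphs with 2 and 3 atoms, predicted energies `[4.0, 9.0]` and targets `[2.0, 3.0]` become per-atom values `[2.0, 3.0]` and `[1.0, 1.0]`, so `PerAtomMSE()` returns 2.5 and `PerAtomL1()` returns 1.5. Keeping `input` and `target` as the first two arguments means these can be called exactly like `nn.MSELoss`, with the atom counts as an extra argument.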

To use these functions with tasks that have both scalar (e.g. energy) and vector (e.g. force) targets, I refactored the loss_func argument of all tasks so that it also accepts a dictionary mapping: each key is a task_key, and each value is the loss function for that target. For example:

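import matsciml.models.losses
from torch import nn
from matsciml.models.base import ForceRegressionTask

ForceRegressionTask(
    ...,
    loss_func={"energy": matsciml.models.losses.AtomWeightedMSE, "force": nn.MSELoss},
    ...
)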

This does not break existing configurations: if a single loss module is passed (e.g. loss_func = nn.MSELoss()), the task_keys.setter method copies it so that the same function is used for every target.
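The normalization described above (a dict is used per key, a single loss is copied to every key) can be sketched as follows. This is a hypothetical stand-alone helper, `normalize_loss_func`, not the actual `task_keys.setter` code; instantiating classes such as `nn.MSELoss` with default arguments is an assumption, added because the example passes loss classes rather than instances.

```python
import copy
from typing import Any, Callable, Dict, Iterable, Mapping, Union

LossSpec = Union[Callable[..., Any], Mapping[str, Callable[..., Any]]]


def normalize_loss_func(
    loss_func: LossSpec, task_keys: Iterable[str]
) -> Dict[str, Callable[..., Any]]:
    """Hypothetical helper: turn a loss_func argument into {task_key: loss}."""
    keys = list(task_keys)

    def materialize(fn):
        # Assumption: a loss class (e.g. nn.MSELoss) is instantiated with defaults;
        # an already-constructed loss object is returned unchanged.
        return fn() if isinstance(fn, type) else fn

    if isinstance(loss_func, Mapping):
        # Dictionary form: every task key needs its own entry.
        missing = [key for key in keys if key not in loss_func]
        if missing:
            raise KeyError(f"no loss function given for task keys {missing}")
        return {key: materialize(loss_func[key]) for key in keys}

    # Single loss: give every target its own copy, mirroring the
    # copy-to-all-targets behaviour of task_keys.setter.
    return {key: copy.deepcopy(materialize(loss_func)) for key in keys}
```

Deep-copying the single loss, rather than sharing one object, keeps any internal state of the loss module separate per target. Both `loss_func=nn.MSELoss()` and the dictionary from the example resolve to the same `{"energy": ..., "force": ...}` shape.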

_compute_losses also needed some refactoring so that additional arguments, in this case the number of atoms per graph, can be passed to the loss function. New loss functions in losses should keep their signatures consistent with native PyTorch losses (e.g. input and target as the first two arguments) so that arguments are mapped the same way for every loss.
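One way to pass extra arguments only to the losses that accept them is to inspect each loss's signature. The sketch below assumes that approach; `call_loss` and `compute_losses` are hypothetical stand-ins, not the PR's `_compute_losses`. It relies on every loss taking `(input, target)` first, which is why consistent PyTorch-style signatures matter.

```python
import inspect
from typing import Any, Callable, Dict, Mapping

import torch
from torch import nn


def call_loss(loss_fn: Callable[..., torch.Tensor], input, target, **extra: Any) -> torch.Tensor:
    """Call loss_fn(input, target, ...), forwarding only the extra kwargs it accepts."""
    # nn.Module.__call__ takes *args/**kwargs, so inspect forward() instead.
    fn = loss_fn.forward if isinstance(loss_fn, nn.Module) else loss_fn
    params = inspect.signature(fn).parameters
    accepts_any = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    kwargs = extra if accepts_any else {k: v for k, v in extra.items() if k in params}
    return loss_fn(input, target, **kwargs)


def compute_losses(
    preds: Mapping[str, torch.Tensor],
    targets: Mapping[str, torch.Tensor],
    loss_funcs: Mapping[str, Callable[..., torch.Tensor]],
    **extra: Any,
) -> Dict[str, torch.Tensor]:
    # One loss per task key; e.g. atoms_per_graph reaches only losses that declare it.
    losses = {
        key: call_loss(loss_fn, preds[key], targets[key], **extra)
        for key, loss_fn in loss_funcs.items()
    }
    losses["total"] = sum(losses.values())
    return losses
```

With `loss_funcs={"energy": some_atom_weighted_loss, "force": nn.MSELoss()}` and `atoms_per_graph=...` passed as an extra keyword, the atom counts reach the energy loss, while `nn.MSELoss` is still called with just `(input, target)`.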

laserkelvin commented 3 months ago

Looks good overall! My only suggestion would be to add a couple of pytest tests for the new loss modules, but I'll leave that up to you. Otherwise, feel free to merge when ready.

Added a very superficial parametrized test in 04d359f. Will merge when tests pass!