laserkelvin closed this 3 months ago.
Looks good overall! My only suggestion would be to add a couple of pytests for the new loss modules. I'll leave it up to you whether you want to do that. Otherwise, feel free to merge when ready.
Added a very superficial parametrized test in 04d359f. Will merge when tests pass!
This PR adds support for MSE and L1 loss functions that are weighted by the number of atoms in each graph, in the `matsciml.models.losses` module.

To enable these functions for tasks that have both scalar (e.g. energy) and vector (e.g. force) targets, I've had to refactor `loss_func`, an argument to all tasks, to support a dictionary mapping: each key corresponds to a `task_key`, and its value is the loss function for that target. As an example:

This does not break previous specifications: if a single loss module is passed by itself (e.g. `loss_func = nn.MSELoss()`), the `task_keys.setter` method copies that function for use with all targets.

Some refactoring was also needed in `_compute_losses` to allow additional arguments to be passed into the loss function, in this case the number of atoms per graph. New loss functions in `losses` should try to keep their signatures consistent with native PyTorch ones (e.g. `input` and `target`) so that arguments are mapped consistently.
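A minimal sketch of the "dictionary or single loss" behaviour described for `loss_func` and the `task_keys.setter`. `ToyTask`, `loss_func` as an attribute, and the `KeyError` check are hypothetical illustrations, not matsciml's task classes; plain callables stand in for loss modules so the sketch runs without PyTorch.

```python
from __future__ import annotations

import copy
from collections.abc import Mapping
from typing import Callable, List, Union

# Hypothetical stand-in for a loss: any callable taking (input, target).
LossFn = Callable[..., float]


class ToyTask:
    """Hypothetical task showing per-target loss mapping (not matsciml's API)."""

    def __init__(self, loss_func: Union[LossFn, Mapping[str, LossFn]], task_keys: List[str]):
        self._loss_func = loss_func
        self.task_keys = task_keys  # routed through the setter below

    @property
    def task_keys(self) -> List[str]:
        return self._task_keys

    @task_keys.setter
    def task_keys(self, values: List[str]) -> None:
        self._task_keys = list(values)
        if isinstance(self._loss_func, Mapping):
            # Dictionary form: every task key must have its own loss.
            missing = set(self._task_keys) - set(self._loss_func)
            if missing:
                raise KeyError(f"No loss function given for task keys: {sorted(missing)}")
            self.loss_func = dict(self._loss_func)
        else:
            # Single loss passed: give every target its own copy of it.
            self.loss_func = {key: copy.deepcopy(self._loss_func) for key in self._task_keys}
```

Passing an `nn.MSELoss()` instance instead of a plain function would behave the same way, since `nn.Module` objects support `copy.deepcopy`.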
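One possible reading of "weighted by the number of atoms in each graph" is to divide each graph's error by its atom count before reducing, which amounts to a per-atom error for graph-level targets such as energy. The sketch below assumes that interpretation, graph-level tensors with a leading `num_graphs` dimension, and a hypothetical third argument `atoms_per_graph`; the class names and exact formulation are illustrative and not necessarily what `matsciml.models.losses` implements. The `input`/`target` argument names follow the native PyTorch losses.

```python
import torch
from torch import nn


class _AtomWeightedLoss(nn.Module):
    """Hypothetical base: scales per-graph errors by 1 / atoms_per_graph."""

    def _scaled_error(
        self, input: torch.Tensor, target: torch.Tensor, atoms_per_graph: torch.Tensor
    ) -> torch.Tensor:
        if input.shape != target.shape:
            raise ValueError(f"Shape mismatch: {tuple(input.shape)} vs {tuple(target.shape)}")
        weights = atoms_per_graph.to(dtype=input.dtype, device=input.device)
        # Add trailing dims so a [num_graphs] count broadcasts against [num_graphs, ...].
        while weights.ndim < input.ndim:
            weights = weights.unsqueeze(-1)
        return (input - target) / weights


class AtomWeightedMSE(_AtomWeightedLoss):
    def forward(self, input, target, atoms_per_graph):
        return self._scaled_error(input, target, atoms_per_graph).pow(2).mean()


class AtomWeightedL1(_AtomWeightedLoss):
    def forward(self, input, target, atoms_per_graph):
        return self._scaled_error(input, target, atoms_per_graph).abs().mean()
```

Node-level targets such as forces have one row per atom rather than per graph, so they would need a different treatment (or an unweighted loss such as `nn.L1Loss()`), which is where a per-key loss mapping becomes useful.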
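The change to `_compute_losses` can be pictured as a loop that calls each target's loss with `input` and `target` as keywords, and forwards an extra value like the number of atoms per graph only to losses whose signature declares it. This is a minimal sketch under that assumption; `compute_losses` and the `atoms_per_graph` keyword are hypothetical and do not reproduce the PR's actual method.

```python
import inspect
from typing import Any, Callable, Dict, Mapping


def compute_losses(
    loss_funcs: Mapping[str, Callable[..., Any]],
    predictions: Mapping[str, Any],
    targets: Mapping[str, Any],
    **extra: Any,
) -> Dict[str, Any]:
    """Hypothetical helper: evaluate each target's loss, forwarding only the
    extra keyword arguments that the loss's signature accepts."""
    losses = {}
    for key, loss_fn in loss_funcs.items():
        # nn.Module instances expose their real signature on .forward;
        # plain functions are inspected directly.
        params = inspect.signature(getattr(loss_fn, "forward", loss_fn)).parameters
        takes_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
        accepted = {
            name: value for name, value in extra.items() if takes_var_kwargs or name in params
        }
        # Keyword names mirror native PyTorch losses (input, target).
        losses[key] = loss_fn(input=predictions[key], target=targets[key], **accepted)
    return losses
```

With this pattern, a native `nn.MSELoss()` receives only `input` and `target`, while a loss whose `forward` also declares `atoms_per_graph` receives the extra tensor, which is why keeping signatures consistent with PyTorch's matters.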