ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

during-training valid loss is wrong #444

Closed bernstei closed 3 weeks ago

bernstei commented 4 weeks ago

Looks to me like the per-head validation loss printed in the log during fitting is actually the running sum over all heads so far:

https://github.com/ACEsuit/mace/blob/e4ac49818c9dc581c5167155b7b659bca0c9064e/mace/tools/train.py#L217-L233

Is the solution just that the quantity passed to valid_err_log should be valid_loss_head?
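
For illustration, here is a self-contained sketch of the suspected pattern; the head names, numbers, and print call are hypothetical stand-ins, not the actual mace/tools/train.py code:

```python
# Hypothetical sketch of the suspected logging bug; not the actual train.py code.
per_head_losses = {"mp": 0.10, "Default": 0.25}

running_total = 0.0
for head, loss_head in per_head_losses.items():
    running_total += loss_head  # partial sum over the heads processed so far
    # Bug: the partial sum is logged as if it were this head's validation loss.
    print(f"head={head} valid_loss={running_total:.2f}")  # prints 0.10, then 0.35
    # Fix: log loss_head (the per-head value) instead of running_total.
```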

bernstei commented 3 weeks ago

Pretty sure that this is indeed a bug; separately, the validation batches in the multi-head-interface branch are also not deterministic.

LarsSchaaf commented 3 weeks ago

What would you say the best behaviour should be? Should the user supply a list of weights for each head, defaulting to 0 for mp and 1 for Default?

Given that the main impact this has is which checkpoint (and model) to save:

  1. If you're fine-tuning, you want the situation above.
  2. If you're training a model on multiple reference datasets, you want to be able to select which method you care most about (hence the ability to supply weights?)

gabor1 commented 3 weeks ago

I think the validation losses and rmses should be printed separately for each head

bernstei commented 3 weeks ago

> I think the validation losses and rmses should be printed separately for each head

They already are.

bernstei commented 3 weeks ago

> What would you say the best behaviour should be? Should the user supply a list of weights for each head, defaulting to 0 for mp and 1 for Default?
>
> Given that the main impact this has is which checkpoint (and model) to save:
>
> 1. If you're fine-tuning, you want the situation above.
> 2. If you're training a model on multiple reference datasets, you want to be able to select which method you care most about (hence the ability to supply weights?)

This has nothing to do with weights. That's a separate issue. The code claims to print a loss for each head, but actually prints the cumulative loss that it's computing as it calculates the total loss by looping over heads. That's all. I'll do a PR for this issue, now that it seems pretty clear (from the slack) that the validation loss not being deterministic is a separate bug.

LarsSchaaf commented 3 weeks ago

The point is that checkpoints only get saved if the loss decreases. Now that we have multiple heads, and therefore multiple validation losses, how do we decide when to save a checkpoint? My suggestion was having a main_loss that is a combination of the per-head losses. Which combination depends on the use case, hence the user should be able to supply a weighting over the head losses. If the main_loss decreases, a new checkpoint is saved.
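
For concreteness, a minimal sketch of that suggestion, assuming a user-supplied weighting; head_weights, main_loss, and the numbers are hypothetical, not existing MACE options:

```python
# Hypothetical sketch of a weighted main_loss checkpoint criterion;
# not an existing MACE feature.
head_weights = {"mp": 0.0, "Default": 1.0}            # e.g. the fine-tuning case above
per_head_valid_losses = {"mp": 0.12, "Default": 0.03}

main_loss = sum(head_weights[h] * per_head_valid_losses[h] for h in per_head_valid_losses)

best_so_far = float("inf")  # would persist across epochs in a real training loop
if main_loss < best_so_far:
    best_so_far = main_loss
    # save_checkpoint(model, epoch)  # only save when the combined loss improves
```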

bernstei commented 3 weeks ago

> The point is that checkpoints only get saved if the loss decreases. Now that we have multiple heads, and therefore multiple validation losses, how do we decide when to save a checkpoint? My suggestion was having a main_loss that is a combination of the per-head losses. Which combination depends on the use case, hence the user should be able to supply a weighting over the head losses. If the main_loss decreases, a new checkpoint is saved.

A fine suggestion, but independent of this issue. I agree that the way the "total" loss, which is used to save checkpoints, is calculated could use further thought. And I don't even mind making that part of the PR I created for this issue. But the issue was really only about how valid_err_log needs to get the head-specific loss, rather than the (currently naively computed) partial sum that's constructed to calculate a total loss.

[edited] @LarsSchaaf I think you should perhaps open a new issue (an enhancement request) to make the loss-based checkpoint-saving logic less naive.
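
For illustration, a sketch of the corrected loop shape: the running total can still be accumulated for the checkpoint decision while the log receives the head-specific value. The names and numbers are hypothetical, not the actual fix:

```python
# Hypothetical sketch of the corrected pattern; not the actual PR diff.
per_head_losses = {"mp": 0.10, "Default": 0.25}

total_valid_loss = 0.0
for head, loss_head in per_head_losses.items():
    total_valid_loss += loss_head                       # kept for the checkpoint decision
    print(f"head={head} valid_loss={loss_head:.2f}")    # log the head-specific value
print(f"total valid_loss={total_valid_loss:.2f}")       # the total drives checkpoint saving
```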

ilyes319 commented 3 weeks ago

The user can already provide weights for different heads, just to be clear.

bernstei commented 3 weeks ago

> The user can already provide weights for different heads, just to be clear.

Are those used when printing the validation loss? Or only when computing the gradient of the training loss w.r.t. shared parameters?

ilyes319 commented 3 weeks ago

They would currently also be used for printing.

bernstei commented 3 weeks ago

This issue was supposed to be closed by #449, but seems to still be open. Do we want to continue this discussion here, or close it and open a new one having to do with weights?