choderalab / modelforge

Infrastructure to implement and train NNPs
https://modelforge.readthedocs.io/en/latest/
MIT License

Unstable multitask learning behavior #271

Open wiederm opened 1 week ago

wiederm commented 1 week ago

Currently, we assign a fixed weight to each loss component to balance its contribution to the total loss. This leads to inefficient training and a trial-and-error search for suitable weights for each network and dataset. I suggest that we use techniques from multitask learning to balance the different learning tasks more efficiently.
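For reference, a minimal sketch of the fixed-weight scheme next to one commonly used dynamic alternative: a simplified form of the uncertainty weighting of Kendall, Gal and Cipolla (2018), where each task gets a learnable log-variance s_k and the total loss is sum_k exp(-s_k) * L_k + s_k (constant factors dropped). This is only one example of a dynamic weighting method, not necessarily one of the approaches proposed here, and the function and class names are hypothetical, not part of modelforge's API.

```python
import torch

def fixed_weighted_loss(losses, weights):
    # Fixed scheme: a hand-tuned constant weight per loss component.
    return sum(weights[name] * value for name, value in losses.items())

class UncertaintyWeightedLoss(torch.nn.Module):
    """Hypothetical dynamic weighting with one learnable log-variance s_k per task:
    total = sum_k exp(-s_k) * L_k + s_k  (constant factors dropped)."""

    def __init__(self, task_names):
        super().__init__()
        self.task_names = list(task_names)
        # s_k = 0 at the start, i.e. every task begins with weight exp(0) = 1.
        self.log_vars = torch.nn.Parameter(torch.zeros(len(self.task_names)))

    def forward(self, losses):
        total = self.log_vars.new_zeros(())
        for i, name in enumerate(self.task_names):
            s = self.log_vars[i]
            # exp(-s) down-weights a task as s grows; the +s term keeps s
            # from growing without bound.
            total = total + torch.exp(-s) * losses[name] + s
        return total

losses = {"energy": torch.tensor(4.0), "force": torch.tensor(0.5)}
print(float(fixed_weighted_loss(losses, {"energy": 1.0, "force": 10.0})))  # 9.0
weighting = UncertaintyWeightedLoss(["energy", "force"])
print(float(weighting(losses)))  # 4.5 (both weights start at 1)
```

In training, `weighting.parameters()` would be handed to the optimizer together with the model parameters so that the s_k are learned instead of tuned by hand.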

I will start by adding dynamic loss weighting approaches:

chrisiacovella commented 7 hours ago

While normalizing by the overall value will ensure that forces and energy are on the same scale, and the gradient information is still preserved, we lose all information about the actual objective of the loss, which is minimizing the actual value (the normalized loss itself will never change).
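A small illustration of this point, assuming "normalizing by the overall value" means dividing a loss by its own detached value: the normalized loss always reports 1, while its gradient is the original gradient rescaled by 1/MSE, so the direction is preserved. The tensors are made-up numbers.

```python
import torch

# Made-up predicted and reference energies, for illustration only.
pred = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
ref = torch.tensor([1.5, 2.0, 2.0])

mse = ((pred - ref) ** 2).mean()
# Divide the loss by its own value, detached so it is treated as a constant.
self_normalized = mse / mse.detach()
self_normalized.backward()

print(float(mse))              # ≈ 0.4167, the quantity we actually want to reduce
print(float(self_normalized))  # 1.0, whatever the predictions are
# Gradient = dMSE/dpred divided by the MSE value: same direction, rescaled.
expected = (2.0 * (pred.detach() - ref) / ref.numel()) / mse.detach()
print(torch.allclose(pred.grad, expected))  # True
```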

If we are to do any normalization, I think we should normalize the individual values by the magnitude of the total tensor. Currently, for energy we use the mean squared error:

$$\mathrm{Loss} = \frac{1}{N}\sum_{i=1}^{N}\left(X_i - Y_i\right)^2$$

where X_i is the value we calculate (e.g., energy), Y_i is the "true" value, and N is the total number of atoms.

Of course, the loss doesn't need to be the MSE; it could be any function. For example, we could make it the normalized Euclidean distance:

$$\mathrm{Loss} = \frac{\sqrt{\sum_{i=1}^{N}\left(X_i - Y_i\right)^2}}{|X| + |Y|}$$

where |X| and |Y| are the magnitudes (L2 norms) of the tensors of all per_atom_energies for the calculated and known energies, respectively.

We could also normalize the MSE by the magnitudes of the vectors.
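One way to see how these options differ is to rescale the known and calculated values by the same factor c. The normalized Euclidean distance, sqrt(sum_i (X_i - Y_i)^2) / (|X| + |Y|), is unchanged; the MSE divided by (|X| + |Y|) grows linearly with c; the plain MSE grows as c^2. A sketch with made-up values (function names are illustrative):

```python
import torch

def norm_ed(x, y):
    # sqrt(sum_i (x_i - y_i)^2) / (|x| + |y|), |.| being the L2 norm
    return torch.linalg.vector_norm(x - y) / (torch.linalg.vector_norm(x) + torch.linalg.vector_norm(y))

def norm_mse(x, y):
    # mean squared error divided by (|x| + |y|)
    return ((x - y) ** 2).mean() / (torch.linalg.vector_norm(x) + torch.linalg.vector_norm(y))

def mse(x, y):
    return ((x - y) ** 2).mean()

# Made-up "known" (x) and "calculated" (y) values
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
y = torch.tensor([1.1, 1.8, 3.3], dtype=torch.float64)

for c in (1.0, 10.0, 100.0):
    # Rescale both tensors by c: norm_ed is unchanged, norm_mse grows ~c, mse grows ~c^2
    print(f"c={c:>5}: norm_ed={norm_ed(c * x, c * y).item():.4f}  "
          f"norm_mse={norm_mse(c * x, c * y).item():.4f}  mse={mse(c * x, c * y).item():.4f}")
```

By the triangle inequality |X - Y| <= |X| + |Y|, so the normalized Euclidean distance always lies between 0 and 1, whereas the normalized MSE still inherits the scale of the data.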

I did some quick tests of what the loss would look like in each case. One problem I think we run into is the spread of the values. This might be especially true for forces (a few large forces could really dominate the loss).

So let us assume a simple tensor of "known" values, where each entry is an order of magnitude larger than the previous one (to simulate a case where we have very different scales). Our "calculated" values are random perturbations of these values, where the size of the deviation is a set fraction of the known value (so the magnitude of the perturbation scales with the value itself). In the code below the perturbation is Gaussian with a standard deviation equal to that fraction times the value, so "max frac dev" acts as a scale rather than a strict maximum. This was repeated 100 times to get an idea of the variability.

I calculated the normalized Euclidean distance (norm_ed), a "normalized" version of the MSE (the MSE divided by the sum of the vector magnitudes, norm_MSE), and the standard MSE. The code I used to generate these numbers is below.

| max frac dev | norm_ed | norm_MSE | MSE |
|---|---|---|---|
| 0.1 | 0.04 +/- 0.03 | 11.4 +/- 24.7 | 210020.3 +/- 4147488.6 |
| 0.2 | 0.08 +/- 0.07 | 34.6 +/- 46.9 | 717427.8 +/- 978088.0 |
| 0.5 | 0.24 +/- 0.22 | 206.8 +/- 331.1 | 3833511.2 +/- 590990.5 |
| 1.0 | 0.34 +/- 0.28 | 845.8 +/- 1122.2 | 21725386.0 +/- 31761810.0 |
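A quick consistency check on the MSE column, assuming the perturbation used in the test: each calculated value is y_i = x_i + s x_i z_i, with z_i drawn from a standard normal distribution and s the fractional scale, so the expected MSE follows directly:

$$\mathbb{E}[\mathrm{MSE}] = \frac{1}{N}\sum_{i=1}^{N}\mathbb{E}\left[(s\,x_i z_i)^2\right] = \frac{s^2}{N}\sum_{i=1}^{N}x_i^2$$

With x = (1, 10, 10^2, 10^3, 10^4) and N = 5, the sum of squares is 101010101, so E[MSE] ≈ 2.02 × 10^7 s^2, i.e. about 2.02 × 10^5, 8.08 × 10^5, 5.05 × 10^6 and 2.02 × 10^7 for s = 0.1, 0.2, 0.5 and 1.0, the same order of magnitude as the sampled means in the table. The largest entry alone contributes 10^8 / 101010101 ≈ 99% of this, which illustrates how a few large values (e.g. large forces) can dominate an unnormalized loss.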

If nothing else, to get the energy and force losses on the same scale, we could normalize by the magnitude of the "known" values. The weighting factors we define would then only express how much we want each term to contribute to the loss (i.e., its importance), rather than having to both encode its importance and bring the terms to a comparable scale.

```python
import torch

# Fractional perturbation scale; set to 0.1, 0.2, 0.5 and 1.0 for the rows of the table above.
max_scale = 0.1
# "Known" values spanning five orders of magnitude
x = torch.tensor([1.0, 10.0, 100.0, 1000.0, 10000.0], dtype=torch.float32)

norm_ed_list = []
norm_mse_list = []
mse_list = []
for _ in range(100):
    # Gaussian perturbation whose standard deviation is max_scale * x (entry-wise)
    perturb = torch.randn_like(x) * (x * max_scale)
    y = x + perturb

    # L2 norms ("magnitudes") of the known and calculated tensors
    x_norm = torch.linalg.vector_norm(x)
    y_norm = torch.linalg.vector_norm(y)

    sq_err = (x - y) ** 2
    norm_ed = torch.sqrt(sq_err.sum()) / (x_norm + y_norm)  # normalized Euclidean distance
    norm_mse = sq_err.mean() / (x_norm + y_norm)              # MSE / (|x| + |y|)
    mse = sq_err.mean()                                       # plain MSE

    norm_ed_list.append(float(norm_ed))
    norm_mse_list.append(float(norm_mse))
    mse_list.append(float(mse))

# Mean +/- standard deviation over the 100 repetitions
for name, values in [("norm_ed", norm_ed_list), ("norm_mse", norm_mse_list), ("mse", mse_list)]:
    t = torch.tensor(values)
    print(f"{name}_mean: {t.mean().item()} +/- {t.std().item()}")
```