ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

Stress / num_atoms in loss? #437

Closed: JPDarby closed this issue 4 weeks ago

JPDarby commented 4 weeks ago

Why is stress divided by num_atoms here? Maybe a bug?

https://github.com/ACEsuit/mace/blob/dee204f1f9d587f28fd792fdad1f45039ef71e94/mace/modules/loss.py#L30C1-L39C12

def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor:
    # stress: [n_graphs, 3, 3]
    configs_weight = ref.weight.view(-1, 1, 1)  # [n_graphs, 1, 1]
    configs_stress_weight = ref.stress_weight.view(-1, 1, 1)  # [n_graphs, 1, 1]
    num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1)  # [n_graphs, 1, 1]
    return torch.mean(
        configs_weight
        * configs_stress_weight
        * torch.square((ref["stress"] - pred["stress"]) / num_atoms)
    )  # []
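To see what the division does: dividing the stress residual by num_atoms before squaring scales each configuration's squared error by 1/N², so larger cells are weighted far less in the loss. The sketch below uses two hypothetical configurations (8 and 64 atoms) with an identical, made-up residual; only the `ptr`-based atom count mirrors the quoted code.

```python
import torch

# Hypothetical batch of two configurations with 8 and 64 atoms. As in the
# quoted code, ptr holds cumulative atom counts, so num_atoms = ptr diffs.
ptr = torch.tensor([0, 8, 72])
num_atoms = (ptr[1:] - ptr[:-1]).view(-1, 1, 1)  # [n_graphs, 1, 1]

# Identical (made-up) 3x3 stress residual (ref - pred) for both configurations.
residual = torch.full((2, 3, 3), 0.01, dtype=torch.float64)

# Per-configuration mean squared error without and with the num_atoms scaling.
plain = torch.square(residual).mean(dim=(1, 2))
scaled = torch.square(residual / num_atoms).mean(dim=(1, 2))

print(plain)                 # both entries equal
ratio = scaled[0] / scaled[1]
print(ratio)                 # (64 / 8) ** 2, i.e. about 64
```

With the scaling, the same stress error in the 64-atom cell contributes 64 times less than in the 8-atom cell, even though stress is already a per-volume quantity.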

The Huber loss, by contrast, doesn't seem to divide the stress by num_atoms.

class WeightedHuberEnergyForcesStressLoss(torch.nn.Module):
    def __init__(
        self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01
    ) -> None:
        super().__init__()
        self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta)
        self.register_buffer(
            "energy_weight",
            torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
        )
        self.register_buffer(
            "forces_weight",
            torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
        )
        self.register_buffer(
            "stress_weight",
            torch.tensor(stress_weight, dtype=torch.get_default_dtype()),
        )

    def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor:
        num_atoms = ref.ptr[1:] - ref.ptr[:-1]
        return (
            self.energy_weight
            * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms)
            + self.forces_weight * self.huber_loss(ref["forces"], pred["forces"])
            + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"])
        )

bernstei commented 4 weeks ago

I agree that the virial should be divided by natoms, but not the stress.
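The distinction is that the virial is extensive (it grows with system size) while the stress, being the virial divided by the volume, is intensive. A minimal sketch, assuming a toy 1D periodic harmonic chain (not a MACE model) and the sign convention σ = (1/V) ∂E/∂ε, where ε is a uniform strain; all function names here are hypothetical:

```python
def chain_energy(n_atoms, spacing, k=1.0, r0=1.0):
    # Periodic chain: n_atoms harmonic bonds, each of length `spacing`.
    return n_atoms * 0.5 * k * (spacing - r0) ** 2


def virial_and_stress(n_atoms, spacing, k=1.0, r0=1.0, h=1e-6):
    # Virial-like quantity W = dE/d(strain), by central finite difference
    # under a uniform strain that stretches every bond by a factor (1 + eps).
    e_plus = chain_energy(n_atoms, spacing * (1 + h), k, r0)
    e_minus = chain_energy(n_atoms, spacing * (1 - h), k, r0)
    virial = (e_plus - e_minus) / (2 * h)
    volume = n_atoms * spacing  # the 1D "volume" is the cell length
    return virial, virial / volume


w_small, s_small = virial_and_stress(4, 1.1)
w_large, s_large = virial_and_stress(64, 1.1)

print(w_large / w_small)        # about 16: the virial grows with system size
print(w_small / 4, w_large / 64)  # about 0.11 each: virial per atom is intensive
print(s_small, s_large)         # about 0.1 each: the stress is already intensive
```

Analytically, W = N k (a - r0) a and σ = k (a - r0), so dividing W by N gives a size-independent target, while dividing σ by N reintroduces a size dependence.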

gabor1 commented 4 weeks ago

Yes, it must be a bug. That explains the poor stress behaviour before, and especially recently. Can someone please just put in a PR deleting this scaling?
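Deleting the scaling, as requested above, leaves a plain weighted mean squared error on the stress. The sketch below is not the merged PR: it assumes a hypothetical signature that takes plain tensors instead of the `Batch` / `TensorDict` arguments of the quoted function, so it can run on its own.

```python
import torch


def weighted_mean_squared_stress(
    ref_stress: torch.Tensor,     # [n_graphs, 3, 3]
    pred_stress: torch.Tensor,    # [n_graphs, 3, 3]
    weight: torch.Tensor,         # [n_graphs]
    stress_weight: torch.Tensor,  # [n_graphs]
) -> torch.Tensor:
    # Reshape per-configuration weights so they broadcast over the 3x3 stress.
    configs_weight = weight.view(-1, 1, 1)                # [n_graphs, 1, 1]
    configs_stress_weight = stress_weight.view(-1, 1, 1)  # [n_graphs, 1, 1]
    # Stress is already intensive, so the residual is squared without
    # any division by the number of atoms.
    return torch.mean(
        configs_weight * configs_stress_weight * torch.square(ref_stress - pred_stress)
    )  # scalar


# Usage: two configurations; the second has its stress weight switched off.
ref = torch.zeros(2, 3, 3)
pred = torch.full((2, 3, 3), 0.1)
loss = weighted_mean_squared_stress(ref, pred, torch.ones(2), torch.tensor([1.0, 0.0]))
print(loss)  # 9 * 0.1**2 / 18, i.e. about 0.005
```

This also makes the MSE stress term consistent with the Huber loss quoted above, which compares `ref["stress"]` and `pred["stress"]` directly.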

ilyes319 commented 4 weeks ago

Merged! Thanks James.