ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

[Bug?] Training the MACE model slows down gradually #169

Closed mjhong0708 closed 9 months ago

mjhong0708 commented 9 months ago

First of all, thanks to the developers for sharing such a great model with the community.

Description

When I tried to train a MACE model on a large bulk dataset (containing DFT optimization trajectories from the Materials Project), I experienced a gradual slowdown of training over epochs. The slowdown occurs regardless of batch size, dataloader shuffling, etc. It does not happen with small molecule datasets (e.g. MD17).

Code and data for reproduction

Below is the full script to reproduce my problem. To exclude any possible source of error other than the model itself, I wrote a minimal energy-only training loop.

import json

import ase.data
import ase.io  # needed for ase.io.read below
import numpy as np
import torch
from e3nn import o3
from torch.nn import functional as F
from tqdm import tqdm

from mace.data.atomic_data import AtomicNumberTable, AtomicData
from mace.data.utils import config_from_atoms_list
from mace.modules.blocks import RealAgnosticInteractionBlock, RealAgnosticResidualInteractionBlock
from mace.modules.models import MACE
from mace.tools.torch_geometric import DataLoader

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model definition
with open("atomic_energies.json", "r") as f:
    atomic_energies = json.load(f)

species = list(atomic_energies.keys())
atomic_numbers = [ase.data.atomic_numbers[s] for s in species]
atomic_energies = np.array([atomic_energies[s] for s in species])
model = MACE(
    r_max=5.0,
    num_bessel=8,
    num_polynomial_cutoff=5,
    max_ell=2,
    interaction_cls=RealAgnosticResidualInteractionBlock,
    interaction_cls_first=RealAgnosticInteractionBlock,
    num_interactions=2,
    num_elements=len(atomic_numbers),
    hidden_irreps=o3.Irreps("128x0e + 128x1o"),
    MLP_irreps=o3.Irreps("16x0e"),
    atomic_energies=atomic_energies,
    avg_num_neighbors=32,
    atomic_numbers=atomic_numbers,
    correlation=2,
    gate=F.silu,
).to(DEVICE)

# Dataset
train_images = ase.io.read("train_dataset.traj", ":")
for i, atoms in enumerate(train_images):
    energy = atoms.get_potential_energy()
    forces = atoms.get_forces()
    atoms.info["energy"] = energy
    atoms.arrays["forces"] = forces
table = AtomicNumberTable(zs=atomic_numbers)
configs = config_from_atoms_list(train_images)
train_dataset = [
    AtomicData.from_config(config, table, 5.0)
    for config in tqdm(configs, desc="Building dataset")
]
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# Training
def train_loop(model, batch, optimizer):
    optimizer.zero_grad()
    output = model(batch.to(DEVICE), compute_force=False)  # move the batch to the model's device
    loss = F.mse_loss(output["energy"].squeeze(), batch.energy.squeeze())
    loss.backward()
    optimizer.step()
    return loss.item()

optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
for i in range(1000):
    num_batches = len(train_loader)
    loss = 0.0
    for batch in tqdm(train_loader, total=num_batches, desc=f"Epoch {i}"):
        loss += train_loop(model, batch, optimizer) / num_batches

    print(f"Epoch {i} loss: {loss:.6f}")

I also attach the training dataset (train_dataset.traj) and atomic_energies.json.

Screenshots

[Screenshot: slow_down]

After a few hundred epochs, the time per epoch increases by more than 2x.
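
For reference, a minimal sketch of how the per-epoch wall time can be logged on top of the training loop above (the epoch_times bookkeeping is only for illustration and is not part of the original script):

import time

epoch_times = []
for i in range(1000):
    start = time.perf_counter()
    num_batches = len(train_loader)
    epoch_loss = 0.0
    for batch in train_loader:
        epoch_loss += train_loop(model, batch, optimizer) / num_batches
    epoch_times.append(time.perf_counter() - start)
    # Report epoch time relative to the first epoch to quantify the slowdown
    print(f"Epoch {i}: loss={epoch_loss:.6f}, time={epoch_times[-1]:.1f}s "
          f"({epoch_times[-1] / epoch_times[0]:.2f}x epoch 0)")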

Running environment:

What could be the reason for this problem? I'm curious whether it originates in e3nn or in MACE. Thanks in advance!
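
In case it helps to narrow this down, a single epoch could be profiled with torch.profiler to see where the extra time accumulates, e.g. in the e3nn tensor-product kernels versus data loading. This is only a sketch, assuming CUDA is available; one would run it once early and once late in training and compare the tables:

from torch.profiler import ProfilerActivity, profile

# Profile one full epoch; the table shows which operators dominate the runtime.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for batch in train_loader:
        train_loop(model, batch, optimizer)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))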

ilyes319 commented 9 months ago

Hey,

I see that you are not using the official MACE training script. I highly recommend switching to it, as it applies several regularizations that are essential to MACE performance.

We are not seeing the same behavior with our trainer. As for the slowdown, I suspect you might be caching information on your GPU and it is accumulating over time. I don't think it is coming from MACE.
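
One quick way to test that hypothesis is to log the GPU memory counters at the end of every epoch, e.g. something along these lines (a sketch, assuming CUDA; if the allocated memory keeps climbing across epochs, tensors are being retained somewhere in the loop, otherwise the slowdown comes from elsewhere):

import torch

# Call at the end of each epoch.
def log_gpu_memory(epoch: int) -> None:
    allocated_mb = torch.cuda.memory_allocated() / 1e6
    peak_mb = torch.cuda.max_memory_allocated() / 1e6
    print(f"Epoch {epoch}: allocated {allocated_mb:.1f} MB, peak {peak_mb:.1f} MB")
    torch.cuda.reset_peak_memory_stats()  # track the peak per epoch rather than globally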