materialsvirtuallab / matgl

Graph deep learning library for materials
BSD 3-Clause "New" or "Revised" License
243 stars 58 forks source link

Issue with unstable m3gnet training #285

Closed matthewkuner closed 1 week ago

matthewkuner commented 3 weeks ago

I am trying to train an M3GNet model from scratch for my own dataset generation project. However, the model training appears to be unstable. image

Below is the code used to train the model:

from __future__ import annotations

import os
import glob
import shutil
import warnings

import numpy as np
import pytorch_lightning as pl
from functools import partial
from dgl.data.utils import split_dataset
from mp_api.client import MPRester
from pytorch_lightning.loggers import CSVLogger

import matgl
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_pes
from matgl.models import M3GNet
from matgl.utils.training import PotentialLightningModule
warnings.simplefilter("ignore")

from pymatgen.core.structure import Structure
import json

# Load the raw training frames from disk; each entry holds a serialized
# structure plus its "frame_properties" labels.
with open("testing_data.json") as frames_file:
    frames = json.load(frames_file)

# Build the parallel lists consumed by the dataset below.
structures = [Structure.from_dict(frame["structure"]) for frame in frames]
energies = [frame["frame_properties"]["e_0_energy"] for frame in frames]
forces = [frame["frame_properties"]["forces"] for frame in frames]

# Convert stresses to GPa to be consistent with original m3gnet: every
# component is scaled by -0.1 and stored back as plain nested lists.
# NOTE(review): presumably the raw values are in kBar with the opposite sign
# convention -- confirm against how the data was generated.
stresses = [
    (np.array(frame["frame_properties"]["stresses"]) * -0.1).tolist()
    for frame in frames
]

labels = {
    "energies": energies,
    "forces": forces,
    "stresses": stresses,
}

# Quick sanity check on how many structures were loaded.
print(len(structures))

# Read in the elemental reference energies and keep only the values, in the
# order in which they appear in the JSON file.
# NOTE(review): the array order follows the JSON key order; confirm that it
# matches the ordering of `element_types` used for the model.
with open('elemental_reference_energies.json') as refs_file:
    reference_energies = json.load(refs_file)
element_refs = np.array(list(reference_energies.values()))

# Element vocabulary derived from the structures, and the structure-to-graph
# converter that uses a 5.0 bond cutoff.
element_types = get_element_list(structures)
converter = Structure2Graph(element_types=element_types, cutoff=5.0)

# Graph dataset with line graphs included (three-body cutoff of 4.0).
dataset = MGLDataset(
    structures=structures,
    labels=labels,
    converter=converter,
    threebody_cutoff=4.0,
    include_line_graph=True,
)

# Shuffled 90/5/5 train/validation/test split with a fixed seed.
data_splits = split_dataset(
    dataset,
    frac_list=[0.9, 0.05, 0.05],
    shuffle=True,
    random_state=42,
)
train_data, val_data, test_data = data_splits

# Collate function that batches line graphs and stress labels as well.
my_collate_fn = partial(collate_fn_pes, include_line_graph=True, include_stress=True)

data_loaders = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    collate_fn=my_collate_fn,
    batch_size=16,
    num_workers=0,
)
train_loader, val_loader, test_loader = data_loaders
# M3GNet model over the dataset's element vocabulary, configured as
# non-intensive (is_intensive=False).
model = M3GNet(is_intensive=False, element_types=element_types)

# Lightning wrapper for potential training: line graphs enabled, a stress
# weight of 0.1, and the elemental reference energies loaded earlier.
lit_module = PotentialLightningModule(
    model=model,
    element_refs=element_refs,
    stress_weight=0.1,
    include_line_graph=True,
)

# Checkpoint callback configured so that every epoch is saved.
checkpoint_callback = pl.callbacks.ModelCheckpoint(save_top_k=-1, every_n_epochs=1)

# CSV logs are written under ./logs/M3GNet_training.
logger = CSVLogger("logs", name="M3GNet_training")

# Trainer options: 140 epochs on one node with four CUDA devices using DDP.
# inference_mode is disabled here; the test step later in the script relies
# on that setting.
trainer_options = {
    "max_epochs": 140,
    "accelerator": "cuda",
    "num_nodes": 1,
    "devices": 4,
    "strategy": "ddp",
    "logger": logger,
    "inference_mode": False,
    "callbacks": [checkpoint_callback],
}
trainer = pl.Trainer(**trainer_options)

def checkpoint_epoch(path):
    """Return the epoch number encoded in a checkpoint filename.

    The name is expected to look like ``epoch=12-step=345.ckpt``: the epoch is
    the text after the last ``=`` in the part of the basename that precedes
    the first ``-``.

    Args:
        path: Path to a checkpoint file.

    Returns:
        The epoch as an ``int``, or ``None`` when the filename does not follow
        that pattern (for example ``last.ckpt``).
    """
    # os.path.basename handles the platform's path separator instead of
    # assuming "/".
    leading_field = os.path.basename(path).split("-")[0]
    try:
        return int(leading_field.split("=")[-1])
    except ValueError:
        return None


# find most recent checkpoint file to restart from!
checkpoint_files = glob.glob("./logs/**/*.ckpt", recursive=True)
most_recent_ckpt_file = None  # stays None when no usable checkpoint exists
best_epoch_num = -1
for filename in checkpoint_files:
    cur_epoch_num = checkpoint_epoch(filename)
    if cur_epoch_num is None:
        # Skip checkpoints whose name carries no epoch number rather than
        # aborting the whole run.
        continue
    # Strict comparison: on ties the first file found is kept.
    if cur_epoch_num > best_epoch_num:
        best_epoch_num = cur_epoch_num
        most_recent_ckpt_file = filename

# Train, resuming from the most recent checkpoint when one was found
# (ckpt_path is None otherwise).
trainer.fit(
    model=lit_module,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    ckpt_path=most_recent_ckpt_file,
)

# Evaluate on the held-out split; the trainer was created with
# inference_mode=False for this step.
trainer.test(dataloaders=test_loader)

# Export the trained model, then load it back from the same directory.
# NOTE(review): the trainer uses strategy="ddp" with 4 devices, so this export
# presumably runs once per process -- confirm whether it should be limited to
# a single rank.
model_export_path = "./trained_model/"
lit_module.model.save(model_export_path)
model = matgl.load_model(path=model_export_path)

Note that, while my code implies checkpointing, no checkpointing seems to have occurred in the training documented above.

Other relevant info:

-Linux
dgl                       2.2a240410+cu121          pypi_0    pypi
matgl                     1.1.1                    pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-ml-py3             7.352.0                  pypi_0    pypi
nvidia-nccl-cu12          2.18.1                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.3.101                 pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
pytorch-lightning         2.2.1                    pypi_0    pypi
torch                     2.1.1                    pypi_0    pypi
torch-ema                 0.3                      pypi_0    pypi
torchdata                 0.7.1                    pypi_0    pypi
torchmetrics              1.2.1                    pypi_0    pypi

Any advice would be much appreciated!