I am trying to train an M3GNet model from scratch for my own dataset-generation project. However, training appears to be unstable.
Below is the code used to train the model:
from __future__ import annotations
import os
import glob
import shutil
import warnings
import numpy as np
import pytorch_lightning as pl
from functools import partial
from dgl.data.utils import split_dataset
from mp_api.client import MPRester
from pytorch_lightning.loggers import CSVLogger
import matgl
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_pes
from matgl.models import M3GNet
from matgl.utils.training import PotentialLightningModule
warnings.simplefilter("ignore")
from pymatgen.core.structure import Structure
import json
# Load the raw training frames. Each record is expected to hold a serialized
# pymatgen structure under "structure" and its labels under "frame_properties"
# (keys read here: "e_0_energy", "forces", "stresses").
with open("testing_data.json", encoding="utf-8") as f:
    d = json.load(f)

structures = []
energies = []
forces = []
stresses = []
for idx, cur_dict in enumerate(d):
    structure = Structure.from_dict(cur_dict["structure"])
    props = cur_dict["frame_properties"]
    frame_energy = props["e_0_energy"]
    frame_forces = np.asarray(props["forces"], dtype=float)
    # Convert stresses to GPa to be consistent with the original M3GNet.
    # NOTE(review): the factor -0.1 assumes the source stresses are in kBar with
    # the opposite sign convention (VASP-style output) -- confirm for this data;
    # a wrong unit or sign gives the model inconsistent stress targets.
    frame_stress = np.asarray(props["stresses"], dtype=float) * -0.1

    # Reject malformed labels up front. A single NaN/inf target, or a force
    # array that does not match the number of sites, otherwise only surfaces
    # later as a diverging loss or an obscure shape error mid-training.
    if frame_forces.shape != (len(structure), 3):
        raise ValueError(
            f"frame {idx}: forces have shape {frame_forces.shape}, "
            f"expected ({len(structure)}, 3)"
        )
    if not (
        np.isfinite(frame_energy)
        and np.isfinite(frame_forces).all()
        and np.isfinite(frame_stress).all()
    ):
        raise ValueError(f"frame {idx}: non-finite energy, forces or stresses")

    structures.append(structure)
    energies.append(frame_energy)
    forces.append(props["forces"])
    stresses.append(frame_stress.tolist())

labels = {
    "energies": energies,
    "forces": forces,
    "stresses": stresses,
}
print(len(structures))
# Elements present in the training set, in the order returned by
# get_element_list. The same tuple is handed to the graph converter here and to
# the model later, so per-element arrays must follow this order.
element_types = get_element_list(structures)

# Read elemental reference energies and align them with element_types.
# Values are looked up by element rather than taken in file order. Taking them
# in file order only works if the JSON lists exactly these elements in exactly
# this order; otherwise elements silently receive another element's reference
# energy, or the array has the wrong length.
# NOTE(review): assumes the JSON keys are element symbols (e.g. "Fe") and that
# matgl indexes element_refs by position in element_types -- confirm for the
# installed matgl version.
with open("elemental_reference_energies.json", encoding="utf-8") as f:
    ref_by_element = json.load(f)
missing_refs = [el for el in element_types if el not in ref_by_element]
if missing_refs:
    raise ValueError(
        "elemental_reference_energies.json has no reference energy for: "
        f"{missing_refs}"
    )
element_refs = np.array([ref_by_element[el] for el in element_types])

converter = Structure2Graph(element_types=element_types, cutoff=5.0)
dataset = MGLDataset(
    threebody_cutoff=4.0,
    structures=structures,
    converter=converter,
    labels=labels,
    include_line_graph=True,
)
# Shuffled 90/5/5 train/validation/test split. The fixed seed makes the split
# reproducible, which matters because (presumably) every DDP rank re-runs this
# script and must see the same partition.
split_fractions = [0.9, 0.05, 0.05]
train_data, val_data, test_data = split_dataset(
    dataset, frac_list=split_fractions, shuffle=True, random_state=42
)

# Collate function that also batches stresses and the line graphs built by the
# dataset (include_line_graph=True above).
my_collate_fn = partial(
    collate_fn_pes,
    include_stress=True,
    include_line_graph=True,
)
loaders = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    collate_fn=my_collate_fn,
    batch_size=16,
    num_workers=0,
)
train_loader, val_loader, test_loader = loaders
# M3GNet over the dataset's elements, with is_intensive=False (the energy target
# is the frame's total e_0_energy rather than a per-atom value).
model = M3GNet(element_types=element_types, is_intensive=False)

# Lightning training wrapper. element_refs are the elemental reference energies
# loaded above; stress_weight scales the stress contribution to the loss.
# NOTE(review): if the loss keeps spiking, the learning rate and loss weights
# accepted by PotentialLightningModule are the first knobs to try -- check the
# signature of the installed matgl version.
lit_module = PotentialLightningModule(
    model=model,
    element_refs=element_refs,
    stress_weight=0.1,
    include_line_graph=True,
)
# Write a checkpoint at the end of every epoch and never prune old ones
# (save_top_k=-1), so an interrupted run can be resumed from the newest file.
checkpoint_callback = pl.callbacks.ModelCheckpoint(save_top_k=-1, every_n_epochs=1)

# Metrics go to logs/M3GNet_training/version_*/. By Lightning's default the
# checkpoints land in a "checkpoints" folder under the same version directory.
logger = CSVLogger("logs", name="M3GNet_training")

# One node with 4 CUDA devices using distributed data parallel.
# inference_mode stays False so that testing runs with autograd enabled
# (presumably forces/stresses are obtained by differentiating the energy).
trainer = pl.Trainer(
    accelerator="cuda",
    strategy="ddp",
    num_nodes=1,
    devices=4,
    max_epochs=140,
    callbacks=[checkpoint_callback],
    logger=logger,
    inference_mode=False,
)
def checkpoint_epoch(path):
    """Return the epoch number encoded in a checkpoint filename, or None.

    Lightning's default ModelCheckpoint filenames look like
    ``epoch=12-step=3400.ckpt`` (possibly with a ``-v1`` suffix). Files that do
    not start with ``epoch=<int>``, such as ``last.ckpt``, return None instead
    of raising.
    """
    # basename handles both "/" and the platform separator; strip ".ckpt" so
    # that a bare "epoch=7.ckpt" still parses.
    stem = os.path.splitext(os.path.basename(path))[0]
    key, sep, value = stem.split("-", 1)[0].partition("=")
    if key != "epoch" or not sep or not value.isdigit():
        return None
    return int(value)


def find_latest_checkpoint(pattern="./logs/**/*.ckpt"):
    """Return the checkpoint matching *pattern* with the highest epoch.

    Searches recursively. Returns None when no epoch checkpoint exists, which
    makes ``trainer.fit(ckpt_path=None)`` start from scratch. On equal epochs,
    the first file in glob order wins.
    """
    latest_path = None
    latest_epoch = -1
    for path in glob.glob(pattern, recursive=True):
        epoch = checkpoint_epoch(path)
        if epoch is not None and epoch > latest_epoch:
            latest_path, latest_epoch = path, epoch
    return latest_path


# Find the most recent checkpoint file to restart from.
most_recent_ckpt_file = find_latest_checkpoint("./logs/**/*.ckpt")
# Train, resuming from the newest checkpoint if one was found (None = fresh run).
trainer.fit(
    model=lit_module,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    ckpt_path=most_recent_ckpt_file,
)

# Evaluate on the held-out test split. This relies on the Trainer having been
# created with inference_mode=False.
trainer.test(dataloaders=test_loader)

# Export and reload the trained model on the global-zero process only. With
# strategy="ddp" every rank executes this script, so without the guard all
# ranks would write the same directory at the same time.
model_export_path = "./trained_model/"
if trainer.is_global_zero:
    lit_module.model.save(model_export_path)
    model = matgl.load_model(path=model_export_path)
Note that, although my code is set up to save a checkpoint after every epoch, no checkpoints appear to have been written during the training run shown above.
I am trying to train an M3GNet model from scratch for my own dataset-generation project. However, training appears to be unstable, as shown below.
![image](https://github.com/materialsvirtuallab/matgl/assets/82329282/f4733a77-e34a-4d47-be4e-def925dd49eb)
Below is the code used to train the model:
Note that, although my code is set up to save a checkpoint after every epoch, no checkpoints appear to have been written during the training run shown above.
Other relevant info:
Any advice would be much appreciated!