materialsvirtuallab / matgl

Graph deep learning library for materials
BSD 3-Clause "New" or "Revised" License

[Bug]: Finetuned model is worse than pretrained model #264

bfocassio closed this issue 1 month ago

bfocassio commented 1 month ago

Email (Optional)

bfocassio@gmail.com

Version

matgl@1.1.1

Which OS(es) are you using?

What happened?

I'm trying to fine-tune a model using the data from 10.1021/acs.jpca.9b08723

Attached is the code I came up with based on the examples. It doesn't seem to make a difference whether I provide the element_refs, data_mean, and data_std for my dataset or keep the ones from the pre-trained model; whether I use a very small or a large learning rate; or whether I train for 1 or 200 epochs. The fine-tuned model seems to be worse than the pre-trained one in all cases.

I think I'm missing something very basic about this. Any ideas?

My best

Code snippet

from __future__ import annotations

import os
import shutil
import warnings

import numpy as np
import pytorch_lightning as pl
from functools import partial
from dgl.data.utils import split_dataset
#from mp_api.client import MPRester
from pytorch_lightning.loggers import CSVLogger

import matgl
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.ext.ase import PESCalculator
from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_pes
from matgl.models import M3GNet
from matgl.utils.training import PotentialLightningModule

from ase.io import read
from pymatgen.io.ase import AseAtomsAdaptor

from sklearn.linear_model import LinearRegression
from sklearn import metrics

import torch

from prettytable import PrettyTable

# To suppress warnings for clearer output
warnings.simplefilter("ignore")

def parse_data(filename,structure_type='pymatgen'):

    data = read(filename,index=':')
    if structure_type=='ase':
        structures = [atoms for atoms in data]
    else:
        structures = [AseAtomsAdaptor.get_structure(atoms) for atoms in data]
    energies = [atoms.get_total_energy() for atoms in data]
    forces = [atoms.get_forces().tolist() for atoms in data]
    stress = [atoms.get_stress(voigt=False).tolist() for atoms in data]

    return structures,energies,forces,stress

def compute_element_refs_dict(filename,elements):

    structures, energies, forces, stress = parse_data(filename,structure_type='ase')
    element_encoder = np.zeros((len(structures),len(elements)))

    for io,atoms in enumerate(structures):
        for jo,el in enumerate(elements):
            if el in atoms.get_chemical_symbols():
                element_encoder[io,jo] = len((np.array(atoms.get_chemical_symbols()) == el).nonzero()[0])

    lin_reg = LinearRegression(fit_intercept=False)
    lin_reg.fit(element_encoder,energies)

    element_refs_lin_reg = lin_reg.coef_

    element_ref_dict = dict(zip(elements,element_refs_lin_reg))

    return element_ref_dict

def eval_train(structures,energies,forces):

    structures_ase = [AseAtomsAdaptor.get_atoms(struct) for struct in structures]

    energy_per_atom = []
    forces_flat = []
    ml_energies = []
    ml_energies_ft = []
    ml_forces = []
    ml_forces_ft = []
    for io,atoms in enumerate(structures_ase):

        energy_per_atom.append(energies[io]/len(atoms))
        for fat in forces[io]:
            for f in fat:
                forces_flat.append(f)

        atoms.calc = m3gnet
        ml_energies.append(atoms.get_total_energy()/len(atoms))
        forces_atoms = atoms.get_forces().ravel()
        for f in forces_atoms:
            ml_forces.append(f)

        atoms.calc = m3gnet_ft
        ml_energies_ft.append(atoms.get_total_energy()/len(atoms))
        forces_atoms = atoms.get_forces().ravel()
        for f in forces_atoms:
            ml_forces_ft.append(f)

    rmse_energy = metrics.root_mean_squared_error(energy_per_atom,ml_energies)
    rmse_energy_ft = metrics.root_mean_squared_error(energy_per_atom,ml_energies_ft)

    mae_energy = metrics.mean_absolute_error(energy_per_atom,ml_energies)
    mae_energy_ft = metrics.mean_absolute_error(energy_per_atom,ml_energies_ft)

    rmse_forces = metrics.root_mean_squared_error(forces_flat,ml_forces)
    rmse_forces_ft = metrics.root_mean_squared_error(forces_flat,ml_forces_ft)

    mae_forces = metrics.mean_absolute_error(forces_flat,ml_forces)
    mae_forces_ft = metrics.mean_absolute_error(forces_flat,ml_forces_ft)

    headers = ['Model','Energy MAE meV/atom', 'Energy RMSE meV/atom', 'Force MAE eV/Ang', 'Force RMSE eV/Ang']
    m3gnet_metrics = ['M3GNET',mae_energy*1000,rmse_energy*1000,mae_forces,rmse_forces]
    m3gnet_ft_metrics = ['M3GNET-FT',mae_energy_ft*1000,rmse_energy_ft*1000,mae_forces_ft,rmse_forces_ft]

    model_metrics_table = [headers,m3gnet_metrics,m3gnet_ft_metrics]

    tab2 = PrettyTable(model_metrics_table[0])
    tab2.add_rows(model_metrics_table[1:])
    tab2.float_format = "7.4"
    print(tab2)

training_data_path = '/home/bruno.focassio/mlips_surface_benchmark/m3gnet_ft/data/training_data.xyz'
test_data_path = '/home/bruno.focassio/mlips_surface_benchmark/m3gnet_ft/data/test_data.xyz'
name = 'finetune_m3gnet'

jpca_elements = ['Ni','Cu','Si','Ge','Li','Mo']

train_structures,train_energies,train_forces,train_stress = parse_data(training_data_path)
test_structures,test_energies,test_forces,test_stress = parse_data(test_data_path)

train_labels = {
    "energies": train_energies,
    "forces": train_forces,
    "stresses": train_stress,
}
test_labels = {
    "energies": test_energies,
    "forces": test_forces,
    "stresses": test_stress,
}

print(f"{len(train_structures)} training structures")
print(f"{len(test_structures)} test structures")

print('Considering following elements:')
print(jpca_elements)

element_ref_dict = compute_element_refs_dict(training_data_path,jpca_elements)

print('Elemental ref energies from dataset:')
print(element_ref_dict)

element_types = get_element_list(train_structures)
converter = Structure2Graph(element_types=element_types, cutoff=5.0)

train_dataset = MGLDataset(
    threebody_cutoff=4.0, structures=train_structures, converter=converter, labels=train_labels, include_line_graph=True
)
test_dataset = MGLDataset(
    threebody_cutoff=4.0, structures=test_structures, converter=converter, labels=test_labels, include_line_graph=True
)

train_data, val_data = split_dataset(
    train_dataset,
    frac_list=[0.9, 0.1],
    shuffle=True,
    random_state=42,
)

my_collate_fn = partial(collate_fn_pes, include_line_graph=True)

train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_dataset,
    collate_fn=my_collate_fn,
    batch_size=8,
    num_workers=0,
)

m3gnet_nnp = matgl.load_model("M3GNet-MP-2021.2.8-PES")
model_pretrained = m3gnet_nnp.model

pretrained_data = np.load('m3gnet_values.npz')

data_mean,data_std,element_refs = pretrained_data['data_mean'],pretrained_data['data_std'],pretrained_data['element_refs']

lit_module_finetune = PotentialLightningModule(
    model=model_pretrained,
    lr=5e-8,
    include_line_graph=True,
    force_weight=1,
    stress_weight=0.1,
    element_refs=element_refs,
    data_mean=data_mean,
    data_std=data_std,
)

# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator="cpu" kwarg.
logger = CSVLogger("logs", name="M3GNet_finetuning")
trainer = pl.Trainer(max_epochs=50, accelerator="cuda", logger=logger, inference_mode=False)
trainer.fit(model=lit_module_finetune, train_dataloaders=train_loader, val_dataloaders=val_loader)

# save trained model
model_save_path = f"./{name}/"
lit_module_finetune.model.save(model_save_path)
# load trained model
trained_model = matgl.load_model(path=model_save_path)

train_metrics_dict = trainer.test(model=lit_module_finetune,dataloaders=train_loader)
val_metrics_dict = trainer.test(model=lit_module_finetune,dataloaders=val_loader)
test_metrics_dict = trainer.test(model=lit_module_finetune,dataloaders=test_loader)

headers = ['Dataset','Energy MAE eV/atom','Energy RMSE eV/atom', 'Force MAE eV/Ang','Force RMSE eV/Ang', 'Stress MAE eV/Ang^2', 'Stress RMSE eV/Ang^2']
keys = ['test_Energy_MAE','test_Energy_RMSE','test_Force_MAE','test_Force_RMSE','test_Stress_MAE','test_Stress_RMSE']

train_metrics = ['Train',*[train_metrics_dict[0][k] for k in keys]]
val_metrics = ['Val',*[val_metrics_dict[0][k] for k in keys]]
test_metrics = ['Test',*[test_metrics_dict[0][k] for k in keys]]

metrics_table = [headers,train_metrics,val_metrics,test_metrics]

tab = PrettyTable(metrics_table[0])
tab.add_rows(metrics_table[1:])
tab.float_format = "7.4"
print(tab)

pot_ft = matgl.load_model(model_save_path)
m3gnet_ft = PESCalculator(pot_ft)

m3gnet = PESCalculator(matgl.load_model("M3GNet-MP-2021.2.8-PES"))

print('TRAINING SET:')
eval_train(train_structures,train_energies,train_forces)
print('TEST SET:')
eval_train(test_structures,test_energies,test_forces)
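
As a side note, the least-squares fit performed by compute_element_refs_dict above can be sanity-checked with plain numpy on toy data (the counts and reference energies below are hypothetical, purely for illustration):

```python
import numpy as np

# Toy check of the fit in compute_element_refs_dict: rows are structures,
# columns are per-element atom counts (say Ni, Cu).
counts = np.array([[2.0, 0.0],
                   [0.0, 4.0],
                   [1.0, 1.0]])
true_refs = np.array([-5.0, -3.0])  # assumed per-element reference energies
energies = counts @ true_refs       # noiseless linear data

# LinearRegression(fit_intercept=False) solves the same least-squares
# problem as np.linalg.lstsq on the count matrix.
refs, *_ = np.linalg.lstsq(counts, energies, rcond=None)
print(refs)  # recovers [-5.0, -3.0]
```

With noiseless, exactly linear toy data the fit recovers the reference energies exactly; on real data the coefficients are only a least-squares estimate.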

Log output

No response


shyuep commented 1 month ago

It would be helpful if you could let us know what is meant by "worse". Higher train MAEs? Higher test MAEs? How much higher? Thanks.

bfocassio commented 1 month ago

Let me put a few cases I've tested here:

For 50 epochs with a 5e-8 LR, using the same data_mean, data_std, and element_refs as the pre-trained model, we have:

TRAINING SET:
+-----------+---------------------+----------------------+------------------+-------------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang |
+-----------+---------------------+----------------------+------------------+-------------------+
|   M3GNET  |       77.9317       |       130.5925       |      0.1562      |       0.2996      |
| M3GNET-FT |       82.4868       |       133.7823       |      0.1561      |       0.2982      |
+-----------+---------------------+----------------------+------------------+-------------------+
TEST SET:
+-----------+---------------------+----------------------+------------------+-------------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang |
+-----------+---------------------+----------------------+------------------+-------------------+
|   M3GNET  |       80.5021       |       133.3436       |      0.1652      |       0.3093      |
| M3GNET-FT |       85.3629       |       136.9792       |      0.1648      |       0.3077      |
+-----------+---------------------+----------------------+------------------+-------------------+

The same scenario, but with 200 epochs and a 1e-4 LR:

TRAINING SET:
+-----------+---------------------+----------------------+------------------+-------------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang |
+-----------+---------------------+----------------------+------------------+-------------------+
|   M3GNET  |       77.9317       |       130.5925       |      0.1562      |       0.2996      |
| M3GNET-FT |      1095.6828      |      1313.5432       |      0.2475      |       0.4338      |
+-----------+---------------------+----------------------+------------------+-------------------+
TEST SET:
+-----------+---------------------+----------------------+------------------+-------------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang |
+-----------+---------------------+----------------------+------------------+-------------------+
|   M3GNET  |       80.5021       |       133.3436       |      0.1652      |       0.3093      |
| M3GNET-FT |      1107.0479      |      1321.9523       |      0.2554      |       0.4418      |
+-----------+---------------------+----------------------+------------------+-------------------+
shyuep commented 1 month ago

I would say the first set of stats does not seem substantially different between FT and non-FT. @kenko911 pls investigate and see if there is any issue here.

bfocassio commented 1 month ago

Another case I've tested is:

For 200 epochs with a 1e-6 LR, using the same data_mean, data_std, and element_refs as the pre-trained model:

TRAINING SET:
+-----------+---------------------+----------------------+------------------+-------------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang |
+-----------+---------------------+----------------------+------------------+-------------------+
|   M3GNET  |       77.9317       |       130.5925       |      0.1562      |       0.2996      |
| M3GNET-FT |       736.2891      |       820.9986       |      0.2070      |       0.3547      |
+-----------+---------------------+----------------------+------------------+-------------------+
TEST SET:
+-----------+---------------------+----------------------+------------------+-------------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang |
+-----------+---------------------+----------------------+------------------+-------------------+
|   M3GNET  |       80.5021       |       133.3436       |      0.1652      |       0.3093      |
| M3GNET-FT |       746.1202      |       828.0520       |      0.2130      |       0.3597      |
+-----------+---------------------+----------------------+------------------+-------------------+
kenko911 commented 1 month ago

Hi @bfocassio and @shyuep, I will investigate this.

kenko911 commented 1 month ago

Hi @bfocassio, I think the main cause of the poor fine-tuning performance is an inconsistency between the element_types used in the MGLDataset and in the pretrained M3GNet model. The element_types defined here are ('Ni', 'Cu', 'Si', 'Ge', 'Li', 'Mo'), whereas the pretrained M3GNet model covers 89 elements. Please refer to the latest example on training and fine-tuning the M3GNet potential. I suggest setting element_types to DEFAULT_ELEMENTS and then fine-tuning the model.

bfocassio commented 1 month ago

Thank you for the reply. I fixed that by replacing

element_types = get_element_list(train_structures)

with

element_types = DEFAULT_ELEMENTS

where DEFAULT_ELEMENTS is imported from matgl.config, as in the updated examples.

However, I tried running 50 epochs with 1e-6 LR, and this is the result:

TRAINING SET:
+-----------+---------------------+----------------------+------------------+-------------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang |
+-----------+---------------------+----------------------+------------------+-------------------+
|   M3GNET  |       77.9317       |       130.5925       |      0.1562      |       0.2996      |
| M3GNET-FT |       216.8817      |       294.9494       |      0.1714      |       0.3017      |
+-----------+---------------------+----------------------+------------------+-------------------+
TEST SET:
+-----------+---------------------+----------------------+------------------+-------------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang |
+-----------+---------------------+----------------------+------------------+-------------------+
|   M3GNET  |       80.5021       |       133.3436       |      0.1652      |       0.3093      |
| M3GNET-FT |       221.6903      |       302.6300       |      0.1789      |       0.3103      |
+-----------+---------------------+----------------------+------------------+-------------------+
kenko911 commented 1 month ago

@bfocassio I think 50 epochs is still far too few to reach convergence, since the element_refs are not the same. Furthermore, I would also note that the M3GNet pretraining dataset is rather noisy and contains very few single-element systems. I would not be surprised if a fine-tuned M3GNet performs worse than an M3GNet trained from scratch on this specific dataset.

bfocassio commented 1 month ago

Ok, I see ...

I'll try something closer to 200 epochs, or even more, and see what happens.

I have also tried training an M3GNet from scratch. Here is the code used (which is very similar to the code above); the final result is not very exciting either.

Code:

from __future__ import annotations

import os
import shutil
import warnings

import numpy as np
import pytorch_lightning as pl
from functools import partial
from dgl.data.utils import split_dataset
#from mp_api.client import MPRester
from pytorch_lightning.loggers import CSVLogger

import matgl
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.ext.ase import PESCalculator
from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_pes
from matgl.models import M3GNet
from matgl.utils.training import PotentialLightningModule

from ase.io import read
from pymatgen.io.ase import AseAtomsAdaptor

from sklearn.linear_model import LinearRegression
from sklearn import metrics

import torch

from prettytable import PrettyTable

# To suppress warnings for clearer output
warnings.simplefilter("ignore")

def parse_data(filename,structure_type='pymatgen'):

    data = read(filename,index=':')
    if structure_type=='ase':
        structures = [atoms for atoms in data]
    else:
        structures = [AseAtomsAdaptor.get_structure(atoms) for atoms in data]
    energies = [atoms.get_total_energy() for atoms in data]
    forces = [atoms.get_forces().tolist() for atoms in data]
    stress = [atoms.get_stress(voigt=False).tolist() for atoms in data]

    return structures,energies,forces,stress

def compute_element_refs_dict(filename,elements):

    structures, energies, forces, stress = parse_data(filename,structure_type='ase')
    element_encoder = np.zeros((len(structures),len(elements)))

    for io,atoms in enumerate(structures):
        for jo,el in enumerate(elements):
            if el in atoms.get_chemical_symbols():
                element_encoder[io,jo] = len((np.array(atoms.get_chemical_symbols()) == el).nonzero()[0])

    lin_reg = LinearRegression(fit_intercept=False)
    lin_reg.fit(element_encoder,energies)

    element_refs_lin_reg = lin_reg.coef_

    element_ref_dict = dict(zip(elements,element_refs_lin_reg))

    return element_ref_dict

def eval_train(structures,energies,forces):

    structures_ase = [AseAtomsAdaptor.get_atoms(struct) for struct in structures]

    energy_per_atom = []
    forces_flat = []
    ml_energies = []
    ml_energies_ft = []
    ml_forces = []
    ml_forces_ft = []
    for io,atoms in enumerate(structures_ase):

        energy_per_atom.append(energies[io]/len(atoms))
        for fat in forces[io]:
            for f in fat:
                forces_flat.append(f)

        atoms.calc = m3gnet
        ml_energies.append(atoms.get_total_energy()/len(atoms))
        forces_atoms = atoms.get_forces().ravel()
        for f in forces_atoms:
            ml_forces.append(f)

        atoms.calc = m3gnet_ft
        ml_energies_ft.append(atoms.get_total_energy()/len(atoms))
        forces_atoms = atoms.get_forces().ravel()
        for f in forces_atoms:
            ml_forces_ft.append(f)

    rmse_energy = metrics.root_mean_squared_error(energy_per_atom,ml_energies)
    rmse_energy_ft = metrics.root_mean_squared_error(energy_per_atom,ml_energies_ft)

    mae_energy = metrics.mean_absolute_error(energy_per_atom,ml_energies)
    mae_energy_ft = metrics.mean_absolute_error(energy_per_atom,ml_energies_ft)

    rmse_forces = metrics.root_mean_squared_error(forces_flat,ml_forces)
    rmse_forces_ft = metrics.root_mean_squared_error(forces_flat,ml_forces_ft)

    mae_forces = metrics.mean_absolute_error(forces_flat,ml_forces)
    mae_forces_ft = metrics.mean_absolute_error(forces_flat,ml_forces_ft)

    headers = ['Model','Energy MAE meV/atom', 'Energy RMSE meV/atom', 'Force MAE eV/Ang', 'Force RMSE eV/Ang']
    m3gnet_metrics = ['M3GNET',mae_energy*1000,rmse_energy*1000,mae_forces,rmse_forces]
    m3gnet_ft_metrics = ['M3GNET-FT',mae_energy_ft*1000,rmse_energy_ft*1000,mae_forces_ft,rmse_forces_ft]

    model_metrics_table = [headers,m3gnet_metrics,m3gnet_ft_metrics]

    tab2 = PrettyTable(model_metrics_table[0])
    tab2.add_rows(model_metrics_table[1:])
    tab2.float_format = "7.4"
    print(tab2)

training_data_path = '/home/bruno.focassio/mlips_surface_benchmark/m3gnet_ft/data/training_data.xyz'
test_data_path = '/home/bruno.focassio/mlips_surface_benchmark/m3gnet_ft/data/test_data.xyz'
name = 'm3gnet_jpca'

jpca_elements = ['Ni','Cu','Si','Ge','Li','Mo']

train_structures,train_energies,train_forces,train_stress = parse_data(training_data_path)
test_structures,test_energies,test_forces,test_stress = parse_data(test_data_path)

train_labels = {
    "energies": train_energies,
    "forces": train_forces,
    "stresses": train_stress,
}
test_labels = {
    "energies": test_energies,
    "forces": test_forces,
    "stresses": test_stress,
}

print(f"{len(train_structures)} training structures")
print(f"{len(test_structures)} test structures")

print('Considering following elements:')
print(jpca_elements)

element_ref_dict = compute_element_refs_dict(training_data_path,jpca_elements)

print('Elemental ref energies from dataset:')
print(element_ref_dict)

data_mean = np.mean(train_energies)
data_std = np.std(train_energies)

element_types = get_element_list(train_structures)
converter = Structure2Graph(element_types=element_types, cutoff=5.0)

train_dataset = MGLDataset(
    threebody_cutoff=4.0, structures=train_structures, converter=converter, labels=train_labels, include_line_graph=True
)
test_dataset = MGLDataset(
    threebody_cutoff=4.0, structures=test_structures, converter=converter, labels=test_labels, include_line_graph=True
)

train_data, val_data = split_dataset(
    train_dataset,
    frac_list=[0.9, 0.1],
    shuffle=True,
    random_state=42,
)

my_collate_fn = partial(collate_fn_pes, include_line_graph=True)

train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_dataset,
    collate_fn=my_collate_fn,
    batch_size=8,
    num_workers=0,
)

model = M3GNet(
    element_types=element_types,
    is_intensive=False,
)

element_refs = [element_ref_dict[el] if el in jpca_elements else 0 for el in element_types]

lit_module = PotentialLightningModule(
    model=model,
    lr=1e-6,
    include_line_graph=True,
    force_weight=1,
    stress_weight=0.1,
    element_refs=element_refs,
    data_mean=data_mean,
    data_std=data_std,
)

# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator="cpu" kwarg.
logger = CSVLogger("logs", name="M3GNet")
trainer = pl.Trainer(max_epochs=200, accelerator="cuda", logger=logger, inference_mode=False)
trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)

# save trained model
model_save_path = f"./{name}/"
lit_module.model.save(model_save_path)
# load trained model
trained_model = matgl.load_model(path=model_save_path)

train_metrics_dict = trainer.test(model=lit_module,dataloaders=train_loader)
val_metrics_dict = trainer.test(model=lit_module,dataloaders=val_loader)
test_metrics_dict = trainer.test(model=lit_module,dataloaders=test_loader)

headers = ['Dataset','Energy MAE eV/atom','Energy RMSE eV/atom', 'Force MAE eV/Ang','Force RMSE eV/Ang', 'Stress MAE eV/Ang^2', 'Stress RMSE eV/Ang^2']
keys = ['test_Energy_MAE','test_Energy_RMSE','test_Force_MAE','test_Force_RMSE','test_Stress_MAE','test_Stress_RMSE']

train_metrics = ['Train',*[train_metrics_dict[0][k] for k in keys]]
val_metrics = ['Val',*[val_metrics_dict[0][k] for k in keys]]
test_metrics = ['Test',*[test_metrics_dict[0][k] for k in keys]]

metrics_table = [headers,train_metrics,val_metrics,test_metrics]

tab = PrettyTable(metrics_table[0])
tab.add_rows(metrics_table[1:])
tab.float_format = "7.4"
print(tab)

pot_ft = matgl.load_model(model_save_path)
m3gnet_ft = PESCalculator(pot_ft)

m3gnet = PESCalculator(matgl.load_model("M3GNet-MP-2021.2.8-PES"))

print('TRAINING SET:')
eval_train(train_structures,train_energies,train_forces)
print('TEST SET:')
eval_train(test_structures,test_energies,test_forces)

The result is

TRAINING SET:
+-----------+---------------------+----------------------+------------------+-------------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang |
+-----------+---------------------+----------------------+------------------+-------------------+
|   M3GNET  |       77.9317       |       130.5925       |      0.1562      |       0.2996      |
| M3GNET-FT |      2478.6676      |      9925.5555       |      0.4928      |       3.1393      |
+-----------+---------------------+----------------------+------------------+-------------------+
TEST SET:
+-----------+---------------------+----------------------+------------------+-------------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang |
+-----------+---------------------+----------------------+------------------+-------------------+
|   M3GNET  |       80.5021       |       133.3436       |      0.1652      |       0.3093      |
| M3GNET-FT |      2332.3704      |      7027.6740       |      0.5813      |       7.5631      |
+-----------+---------------------+----------------------+------------------+-------------------+
JiQi535 commented 1 month ago

@bfocassio I believe the major problem is the unit conversion of stress. The data from the JPCA paper comes directly from VASP, where stress has units of kBar and is actually the negative of the stress. See this discussion: https://github.com/materialsproject/pymatgen/issues/1388

To train M3GNet, we need to pre-process the negative stress in kBar into stress in GPa, as described in the original M3GNet paper. We just need to multiply all stress values by -0.1.
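
A minimal sketch of that conversion (assuming, as described above, that the raw stresses were parsed from VASP-style negative-kBar output; the numbers below are hypothetical):

```python
import numpy as np

# VASP reports the negative of the stress in kBar; M3GNet training expects
# stress in GPa. Since 1 kBar = 0.1 GPa and the sign must be flipped,
# the conversion is a single multiplication by -0.1.
def vasp_stress_to_gpa(stress_kbar):
    return -0.1 * np.asarray(stress_kbar)

raw = np.array([[10.0, 0.0, 0.0],
                [0.0, 10.0, 0.0],
                [0.0, 0.0, 10.0]])  # hypothetical VASP kBar values
print(vasp_stress_to_gpa(raw))      # diagonal becomes -1.0 GPa
```

In the scripts above, this would be applied to each entry of train_stress and test_stress before building the label dictionaries.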

@kenko911 It seems that we don't have clear documentation of the stress unit and definition in matgl for now? Also, the tutorial example of training the M3GNet potential with Lightning seems to use stress data directly from MP (VASP, DFT), where this preprocessing should be added.

bfocassio commented 1 month ago

@JiQi535 Thanks very much for that info. I'll make this change and post the results here as soon as possible. Cheers

kenko911 commented 1 month ago

@JiQi535 good catch, I thought this was just energy and force training... I agree that we should add some documentation regarding this.

bfocassio commented 1 month ago

Running the fine-tuning with the GPa stress for 100 epochs with a 1e-6 LR, I get something like:

TRAINING SET:
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang | Stress MAE GPa | Stress RMSE GPa |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   M3GNET  |       77.9317       |       130.5925       |      0.1562      |       0.2996      |     0.6639     |      1.6248     |
| M3GNET-FT |       359.9006      |       518.9627       |      0.1870      |       0.3161      |     1.1269     |      2.3768     |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
TEST SET:
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang | Stress MAE GPa | Stress RMSE GPa |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   M3GNET  |       80.5021       |       133.3436       |      0.1652      |       0.3093      |     0.6858     |      1.5642     |
| M3GNET-FT |       367.4000      |       528.0001       |      0.1934      |       0.3242      |     1.1205     |      2.3177     |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+

For 200 epochs I get something like this:

TRAINING SET:
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang | Stress MAE GPa | Stress RMSE GPa |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   M3GNET  |       77.9317       |       130.5925       |      0.1562      |       0.2996      |     0.6639     |      1.6248     |
| M3GNET-FT |       733.8190      |       819.4782       |      0.2068      |       0.3540      |     1.9984     |      5.2956     |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
TEST SET:
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang | Stress MAE GPa | Stress RMSE GPa |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   M3GNET  |       80.5021       |       133.3436       |      0.1652      |       0.3093      |     0.6858     |      1.5642     |
| M3GNET-FT |       743.6299      |       826.5217       |      0.2128      |       0.3591      |     1.9387     |      5.1322     |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+

Another question I have regarding units: what are the units during model training, and why are they different from those I get when evaluating the model through the PESCalculator? Is it possible that something is getting lost or mixed up between model training, serialization, and reloading the model as a calculator?

For example, for 100 epochs with a 1e-6 LR, the trainer.test() function gives something like this (assuming the units in the header are not correct):

+---------+--------------------+---------------------+------------------+-------------------+----------------+-----------------+
| Dataset | Energy MAE eV/atom | Energy RMSE eV/atom | Force MAE eV/Ang | Force RMSE eV/Ang | Stress MAE GPa | Stress RMSE GPa |
+---------+--------------------+---------------------+------------------+-------------------+----------------+-----------------+
|  Train  |       1.5227       |        2.1154       |      0.3193      |       0.5887      |     0.4423     |      0.8860     |
|   Val   |       1.5977       |        2.1901       |      0.3048      |       0.5585      |     0.3940     |      0.7608     |
|   Test  |       1.5303       |        1.5440       |      0.3285      |       0.4618      |     0.4374     |      0.7453     |
+---------+--------------------+---------------------+------------------+-------------------+----------------+-----------------+

For 200 epochs with 1e-6 LR, the metrics from trainer.test() are:

+---------+--------------------+---------------------+------------------+-------------------+----------------+-----------------+
| Dataset | Energy MAE eV/atom | Energy RMSE eV/atom | Force MAE eV/Ang | Force RMSE eV/Ang | Stress MAE GPa | Stress RMSE GPa |
+---------+--------------------+---------------------+------------------+-------------------+----------------+-----------------+
|  Train  |       0.7551       |        1.0415       |      0.2327      |       0.4210      |     0.4495     |      0.9065     |
|   Val   |       0.7259       |        0.9837       |      0.2204      |       0.3920      |     0.4492     |      0.8826     |
|   Test  |       0.7522       |        0.7583       |      0.2485      |       0.3387      |     0.4495     |      0.7527     |
+---------+--------------------+---------------------+------------------+-------------------+----------------+-----------------+
kenko911 commented 1 month ago

@bfocassio I strongly recommend doing the training step by step, starting from a single elemental system trained from scratch. I don't understand why you use such a low learning rate; in the PyTorch documentation, the default learning rate for Adam is 1e-3. Also, it is not meaningful to have a data_mean here, since the total energy is an extensive property. The default units in PESCalculator are the same as those used in model training. A possible reason you get different results is that you loaded the wrong model.
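To make the extensivity point concrete: rather than subtracting a scalar data_mean, per-element reference energies are typically fit by regressing total energies on element counts. A minimal sketch with made-up numbers (this is not matgl's internal implementation):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# rows = structures, columns = count of each element species (e.g. [Si, O])
counts = np.array([[2.0, 0.0],   # Si2
                   [1.0, 1.0],   # SiO
                   [0.0, 2.0]])  # O2
total_energies = np.array([-10.0, -8.5, -7.0])  # made-up totals in eV

# no intercept: a structure with zero atoms has zero energy
reg = LinearRegression(fit_intercept=False).fit(counts, total_energies)
print(reg.coef_)  # per-element reference energies: [-5.0, -3.5]
```

Because the reference energy scales with the number of atoms of each species, this offset stays meaningful across cells of different sizes, which a single mean over total energies does not.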

bfocassio commented 1 month ago

For training a model from scratch, for instance for a single element, should I use element_types = ['Ni'] (for the Ni dataset) or should I keep DEFAULT_ELEMENTS? The first case makes more sense here; however, it raises an out-of-range error:

Traceback (most recent call last):
  File "/home/bruno.focassio/Documents/lnnano/mlips/benchmark_paper/benchmark/fine_tunning_m3gnet/element_jpca/Ni/run_train.py", line 210, in <module>
    trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1031, in _run_stage
    self._run_sanity_check()
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1060, in _run_sanity_check
    val_loop.run()
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 412, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/matgl/utils/training.py", line 59, in validation_step
    results, batch_size = self.step(batch)  # type: ignore
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/matgl/utils/training.py", line 405, in step
    e, f, s, _ = self(g=g, lat=lat, state_attr=state_attr, l_g=l_g)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/matgl/utils/training.py", line 380, in forward
    e, f, s, h = self.model(g=g, lat=lat, l_g=l_g, state_attr=state_attr)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/matgl/apps/pes.py", line 107, in forward
    total_energies = self.model(g=g, state_attr=state_attr, l_g=l_g)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/matgl/models/_m3gnet.py", line 254, in forward
    node_feat, edge_feat, state_feat = self.embedding(node_types, g.edata["rbf"], state_attr)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/matgl/layers/_embedding.py", line 89, in forward
    node_feat = self.layer_node_embedding(node_attr)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/home/bruno.focassio/miniconda3/envs/chgnet24/lib/python3.10/site-packages/torch/nn/functional.py", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self
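The IndexError above comes from the embedding table receiving a node-type index outside its range, e.g. graphs built (or cached) with the full DEFAULT_ELEMENTS mapping being fed to a model whose embedding was sized for element_types=['Ni']. A minimal reproduction of the mechanism (index value is illustrative, not Ni's actual position):

```python
import torch

# embedding sized for a single element (element_types=['Ni'])
emb = torch.nn.Embedding(num_embeddings=1, embedding_dim=64)

# node type coming from graphs converted with the full element list,
# where Ni has a much larger index (27 here is illustrative)
node_types = torch.tensor([27])

try:
    emb(node_types)
except IndexError as e:
    print(e)  # "index out of range in self"
```

This is consistent with the cache explanation later in the thread: graphs saved under one element mapping are silently reloaded for a model built with another.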
bfocassio commented 1 month ago

Hi there, thanks for the help so far.

I'm still having trouble with the units of the models. For instance, trainer.test still produces very different metrics from evaluating the model through PESCalculator. In the following example, I'm constructing the PESCalculator directly from the PotentialLightningModule as PESCalculator(lit_module_finetune.model):

from __future__ import annotations

import os
import shutil
import warnings

import numpy as np
import pytorch_lightning as pl
from functools import partial
from dgl.data.utils import split_dataset
from mp_api.client import MPRester
from pytorch_lightning.loggers import CSVLogger

import matgl
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.ext.ase import PESCalculator, Atoms2Graph
from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_pes
from matgl.models import M3GNet
from matgl.utils.training import PotentialLightningModule
from matgl.config import DEFAULT_ELEMENTS

from pymatgen.io.ase import AseAtomsAdaptor
from sklearn import metrics

from prettytable import PrettyTable

# To suppress warnings for clearer output
warnings.simplefilter("ignore")

def eval_train(model,structures,energies,forces):

    energy_per_atom = []
    ml_energies = []
    forces_flat = []
    ml_forces = []
    for io,atoms in enumerate(structures):

        energy_per_atom.append(energies[io]/len(atoms))
        for fat in forces[io]:
            for f in fat:
                forces_flat.append(f)

        atoms.calc = model
        ml_energies.append(atoms.get_total_energy()/len(atoms))
        forces_atoms = atoms.get_forces().ravel()
        for f in forces_atoms:
            ml_forces.append(f)

    rmse_energy = metrics.root_mean_squared_error(energy_per_atom,ml_energies)   
    mae_energy = metrics.mean_absolute_error(energy_per_atom,ml_energies)

    rmse_forces = metrics.root_mean_squared_error(forces_flat,ml_forces)
    mae_forces = metrics.mean_absolute_error(forces_flat,ml_forces)

    headers = ['Model','Energy MAE meV/atom', 'Energy RMSE meV/atom', 'Force MAE eV/Ang', 'Force RMSE eV/Ang']
    m3gnet_metrics = ['M3GNET',mae_energy*1000,rmse_energy*1000,mae_forces,rmse_forces]

    model_metrics_table = [headers,m3gnet_metrics]

    tab2 = PrettyTable(model_metrics_table[0])
    tab2.add_rows(model_metrics_table[1:])
    tab2.float_format = "7.4"
    print(tab2)

# Obtain your API key here: https://next-gen.materialsproject.org/api
mpr = MPRester(api_key="Tun3lglrsqRhJeIxqiX95NPuL9oQtIi8")
entries = mpr.get_entries_in_chemsys(["Si","O"])
structures = [AseAtomsAdaptor.get_atoms(e.structure) for e in entries]
energies = [e.energy for e in entries]
forces = [np.zeros((len(s), 3)).tolist() for s in structures]
stresses = [np.zeros((3, 3)).tolist() for s in structures]
labels = {
    "energies": energies,
    "forces": forces,
    "stresses": stresses,
}

print(f"{len(structures)} downloaded from MP.")

element_types = DEFAULT_ELEMENTS
converter = Atoms2Graph(element_types=element_types, cutoff=5.0)
dataset = MGLDataset(
    threebody_cutoff=4.0, structures=structures, converter=converter, labels=labels, include_line_graph=True
)
train_data, val_data, test_data = split_dataset(
    dataset,
    frac_list=[0.8, 0.1, 0.1],
    shuffle=True,
    random_state=42,
)
my_collate_fn = partial(collate_fn_pes, include_line_graph=True)
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    collate_fn=my_collate_fn,
    batch_size=2,
    num_workers=0,
)

# download a pre-trained M3GNet
m3gnet_nnp = matgl.load_model("M3GNet-MP-2021.2.8-PES")
model_pretrained = m3gnet_nnp.model
element_refs = m3gnet_nnp.element_refs.property_offset
data_std = m3gnet_nnp.data_std
lit_module_finetune = PotentialLightningModule(model=model_pretrained, lr=1e-3, include_line_graph=True,data_std=data_std,element_refs=element_refs)

# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator="cpu" kwarg.
logger = CSVLogger("logs", name="M3GNet_finetuning")
trainer = pl.Trainer(max_epochs=1, accelerator="cuda", logger=logger, inference_mode=False)
trainer.fit(model=lit_module_finetune, train_dataloaders=train_loader, val_dataloaders=val_loader)

train_metrics_dict = trainer.test(lit_module_finetune,train_loader)
val_metrics_dict = trainer.test(lit_module_finetune,val_loader)
test_metrics_dict = trainer.test(lit_module_finetune,test_loader)

headers = ['Dataset','Energy MAE eV/atom','Energy RMSE eV/atom', 'Force MAE eV/Ang','Force RMSE eV/Ang', 'Stress MAE GPa', 'Stress RMSE GPa']
keys = ['test_Energy_MAE','test_Energy_RMSE','test_Force_MAE','test_Force_RMSE','test_Stress_MAE','test_Stress_RMSE']

train_metrics = ['Train',*[train_metrics_dict[0][k] for k in keys]]
val_metrics = ['Val',*[val_metrics_dict[0][k] for k in keys]]
test_metrics = ['Test',*[test_metrics_dict[0][k] for k in keys]]

metrics_table = [headers,train_metrics,val_metrics,test_metrics]

tab = PrettyTable(metrics_table[0])
tab.add_rows(metrics_table[1:])
tab.float_format = "7.4"
print(tab)

m3gnet_ft = PESCalculator(lit_module_finetune.model)

eval_train(m3gnet_ft,structures,energies,forces)

This results in the following from trainer.test:

+---------+--------------------+---------------------+------------------+-------------------+----------------+-----------------+
| Dataset | Energy MAE eV/atom | Energy RMSE eV/atom | Force MAE eV/Ang | Force RMSE eV/Ang | Stress MAE GPa | Stress RMSE GPa |
+---------+--------------------+---------------------+------------------+-------------------+----------------+-----------------+
|  Train  |       0.0917       |        0.1079       |      0.1811      |       0.2879      |     0.0000     |      0.0000     |
|   Val   |       0.1049       |        0.1187       |      0.1879      |       0.2899      |     0.0000     |      0.0000     |
|   Test  |       0.0791       |        0.0924       |      0.1776      |       0.2750      |     0.0000     |      0.0000     |
+---------+--------------------+---------------------+------------------+-------------------+----------------+-----------------+

and the following from PESCalculator:

+--------+---------------------+----------------------+------------------+-------------------+
| Model  | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang |
+--------+---------------------+----------------------+------------------+-------------------+
| M3GNET |       442.0444      |       826.1631       |      0.1653      |       0.5268      |
+--------+---------------------+----------------------+------------------+-------------------+
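One thing to rule out when comparing the two tables: ASE calculators report stress in eV/Å³, while the training labels here are in GPa. A sketch of the conversion (the constant follows from 1 eV = 1.602176634e-19 J and 1 Å³ = 1e-30 m³):

```python
# 1 eV/Å^3 = 1.602176634e-19 J / 1e-30 m^3 = 1.602176634e11 Pa = 160.2176634 GPa
EV_PER_A3_TO_GPA = 160.2176634

def ev_a3_to_gpa(stress_ev_a3: float) -> float:
    """Convert a stress component from eV/Å^3 to GPa."""
    return stress_ev_a3 * EV_PER_A3_TO_GPA

print(ev_a3_to_gpa(0.01))  # ≈ 1.602 GPa
```

A factor-of-160 discrepancy between two stress tables is a strong hint that one side is still in eV/Å³.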
bfocassio commented 1 month ago

Sorry to keep posting one right after the other, but @kenko911 and @JiQi535, thank you very much for the help so far. I ran a couple more tests with fixed units for the stress data; could you please take a look?

The fine-tuning script is the following:

from __future__ import annotations

import os
import shutil
import warnings

import numpy as np
import pytorch_lightning as pl
from functools import partial
from dgl.data.utils import split_dataset
#from mp_api.client import MPRester
from pytorch_lightning.loggers import CSVLogger

import matgl
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.ext.ase import PESCalculator, Atoms2Graph
from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_pes
from matgl.models import M3GNet
from matgl.utils.training import PotentialLightningModule
from matgl.config import DEFAULT_ELEMENTS

from ase.io import read
from pymatgen.io.ase import AseAtomsAdaptor

from sklearn.linear_model import LinearRegression
from sklearn import metrics

import torch

from prettytable import PrettyTable

# To suppress warnings for clearer output
warnings.simplefilter("ignore")

def parse_data(filename,structure_type='ase'):

    data = read(filename,index=':')
    if structure_type=='ase':
        structures = [atoms for atoms in data]
    else:
        structures = [AseAtomsAdaptor.get_structure(atoms) for atoms in data]
    energies = [atoms.get_total_energy() for atoms in data]
    forces = [atoms.get_forces().tolist() for atoms in data]
    stress = [atoms.get_stress(voigt=False).tolist() for atoms in data]

    return structures,energies,forces,stress

def eval_train(structures,energies,forces,stress):

    structures_ase = structures 

    energy_per_atom = []
    forces_flat = []
    stress_flat = []
    ml_energies = []
    ml_energies_ft = []
    ml_forces = []
    ml_forces_ft = []
    ml_stress = []
    ml_stress_ft = []
    for io,atoms in enumerate(structures_ase):

        energy_per_atom.append(energies[io]/len(atoms))
        for fat in forces[io]:
            for f in fat:
                forces_flat.append(f)
        for sat in stress[io]:
            for s in sat:
                stress_flat.append(s)

        atoms.calc = m3gnet
        ml_energies.append(atoms.get_total_energy()/len(atoms))
        forces_atoms = atoms.get_forces().ravel()
        for f in forces_atoms:
            ml_forces.append(f)
        stress_atoms = atoms.get_stress(voigt=False).ravel()
        for s in stress_atoms:
            ml_stress.append(s)

        atoms.calc = m3gnet_ft
        ml_energies_ft.append(atoms.get_total_energy()/len(atoms))
        forces_atoms = atoms.get_forces().ravel()
        for f in forces_atoms:
            ml_forces_ft.append(f)
        stress_atoms = atoms.get_stress(voigt=False).ravel()
        for s in stress_atoms:
            ml_stress_ft.append(s)

    rmse_energy = metrics.root_mean_squared_error(energy_per_atom,ml_energies)
    rmse_energy_ft = metrics.root_mean_squared_error(energy_per_atom,ml_energies_ft)

    mae_energy = metrics.mean_absolute_error(energy_per_atom,ml_energies)
    mae_energy_ft = metrics.mean_absolute_error(energy_per_atom,ml_energies_ft)

    rmse_forces = metrics.root_mean_squared_error(forces_flat,ml_forces)
    rmse_forces_ft = metrics.root_mean_squared_error(forces_flat,ml_forces_ft)

    mae_forces = metrics.mean_absolute_error(forces_flat,ml_forces)
    mae_forces_ft = metrics.mean_absolute_error(forces_flat,ml_forces_ft)

    rmse_stress = metrics.root_mean_squared_error(stress_flat,ml_stress)
    rmse_stress_ft = metrics.root_mean_squared_error(stress_flat,ml_stress_ft)

    mae_stress = metrics.mean_absolute_error(stress_flat,ml_stress)
    mae_stress_ft = metrics.mean_absolute_error(stress_flat,ml_stress_ft)

    headers = ['Model','Energy MAE meV/atom', 'Energy RMSE meV/atom', 'Force MAE eV/Ang', 'Force RMSE eV/Ang', 'Stress MAE GPa', 'Stress RMSE GPa']
    m3gnet_metrics = ['M3GNET',mae_energy*1000,rmse_energy*1000,mae_forces,rmse_forces,mae_stress,rmse_stress]
    m3gnet_ft_metrics = ['M3GNET-FT',mae_energy_ft*1000,rmse_energy_ft*1000,mae_forces_ft,rmse_forces_ft,mae_stress_ft,rmse_stress_ft]

    model_metrics_table = [headers,m3gnet_metrics,m3gnet_ft_metrics]

    tab2 = PrettyTable(model_metrics_table[0])
    tab2.add_rows(model_metrics_table[1:])
    tab2.float_format = "7.4"
    print(tab2)

training_data_path = 'data/Ni/training_gpa.xyz'
test_data_path = 'data/Ni/test_gpa.xyz'
name = 'finetune_m3gnet'

train_structures,train_energies,train_forces,train_stress = parse_data(training_data_path)
test_structures,test_energies,test_forces,test_stress = parse_data(test_data_path)

train_labels = {
    "energies": train_energies,
    "forces": train_forces,
    "stresses": train_stress,
}
test_labels = {
    "energies": test_energies,
    "forces": test_forces,
    "stresses": test_stress,
}

print(f"{len(train_structures)} training structures")
print(f"{len(test_structures)} test structures")

element_types = DEFAULT_ELEMENTS
converter = Atoms2Graph(element_types=element_types, cutoff=5.0)

train_dataset = MGLDataset(
    threebody_cutoff=4.0, structures=train_structures, converter=converter, labels=train_labels, include_line_graph=True
)
test_dataset = MGLDataset(
    threebody_cutoff=4.0, structures=test_structures, converter=converter, labels=test_labels, include_line_graph=True
)

train_data, val_data = split_dataset(
    train_dataset,
    frac_list=[0.9, 0.1],
    shuffle=True,
    random_state=42,
)

my_collate_fn = partial(collate_fn_pes, include_line_graph=True)

train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_dataset,
    collate_fn=my_collate_fn,
    batch_size=8,
    num_workers=0,
)

m3gnet_nnp = matgl.load_model("M3GNet-MP-2021.2.8-PES")
model_pretrained = m3gnet_nnp.model

element_refs = m3gnet_nnp.element_refs.property_offset
data_std = m3gnet_nnp.data_std

lit_module_finetune = PotentialLightningModule(model=model_pretrained, lr=1e-3, include_line_graph=True,force_weight=1.0,stress_weight=0.1,element_refs=element_refs,data_std=data_std,decay_steps=100,decay_alpha=0.01)

# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator="cpu" kwarg.
logger = CSVLogger("logs", name="M3GNet_finetuning")
trainer = pl.Trainer(max_epochs=200, accelerator="cuda", logger=logger, inference_mode=False)
trainer.fit(model=lit_module_finetune, train_dataloaders=train_loader, val_dataloaders=val_loader)

# save trained model
model_save_path = f"./{name}/"
lit_module_finetune.model.save(model_save_path)

train_metrics_dict = trainer.test(lit_module_finetune,train_loader)
val_metrics_dict = trainer.test(lit_module_finetune,val_loader)
test_metrics_dict = trainer.test(lit_module_finetune,test_loader)

headers = ['Dataset','Energy MAE eV/atom','Energy RMSE eV/atom', 'Force MAE eV/Ang','Force RMSE eV/Ang', 'Stress MAE GPa', 'Stress RMSE GPa']
keys = ['test_Energy_MAE','test_Energy_RMSE','test_Force_MAE','test_Force_RMSE','test_Stress_MAE','test_Stress_RMSE']

train_metrics = ['Train',*[train_metrics_dict[0][k] for k in keys]]
val_metrics = ['Val',*[val_metrics_dict[0][k] for k in keys]]
test_metrics = ['Test',*[test_metrics_dict[0][k] for k in keys]]

metrics_table = [headers,train_metrics,val_metrics,test_metrics]

tab = PrettyTable(metrics_table[0])
tab.add_rows(metrics_table[1:])
tab.float_format = "7.4"
print(tab)

m3gnet_ft = PESCalculator(lit_module_finetune.model)

m3gnet = PESCalculator(matgl.load_model("M3GNet-MP-2021.2.8-PES"))

print('TRAINING SET:')
eval_train(train_structures,train_energies,train_forces,train_stress)
print('TEST SET:')
eval_train(test_structures,test_energies,test_forces,test_stress)

For just Ni (still fine-tuning, since I'm still getting that error when I set a single element to train from scratch):

TRAINING SET:
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang | Stress MAE GPa | Stress RMSE GPa |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   M3GNET  |       37.9139       |       38.7512        |      0.0668      |       0.1735      |     1.0705     |      2.4741     |
| M3GNET-FT |       983.5188      |       987.2318       |      0.2663      |       0.4768      |     2.6252     |      4.9839     |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
TEST SET:
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang | Stress MAE GPa | Stress RMSE GPa |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   M3GNET  |       37.7323       |       38.8390        |      0.0852      |       0.2082      |     1.0630     |      2.2040     |
| M3GNET-FT |       996.2045      |      1001.1609       |      0.2724      |       0.4962      |     2.9724     |      5.4261     |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+

For the complete set of elements I'm interested in:

TRAINING SET:
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang | Stress MAE GPa | Stress RMSE GPa |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   M3GNET  |       77.9317       |       130.5925       |      0.1562      |       0.2996      |     0.8172     |      1.7327     |
| M3GNET-FT |      1257.6639      |      1561.9301       |      0.6262      |       1.2701      |    10.4599     |     24.4154     |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
TEST SET:
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   Model   | Energy MAE meV/atom | Energy RMSE meV/atom | Force MAE eV/Ang | Force RMSE eV/Ang | Stress MAE GPa | Stress RMSE GPa |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
|   M3GNET  |       80.5021       |       133.3436       |      0.1652      |       0.3093      |     0.8150     |      1.6303     |
| M3GNET-FT |      1252.5992      |      1562.3821       |      0.6642      |       1.3439      |    10.8342     |     26.0205     |
+-----------+---------------------+----------------------+------------------+-------------------+----------------+-----------------+
bfocassio commented 1 month ago

All my problems were solved by deleting ~/.dgl. The same can be achieved by setting save_cache=False on MGLDataset and/or specifying a different location for the DGL cached files for each new experiment (using raw_dir and save_dir on MGLDataset).

Would it be good practice to set save_cache=False by default?
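A sketch of the per-experiment cache approach described above (the MGLDataset kwarg names are taken from this thread; the helper itself is hypothetical):

```python
import tempfile

def fresh_cache_kwargs(prefix: str = "mgl_cache_") -> dict:
    """Build MGLDataset kwargs that isolate this run's graph cache,
    so stale graphs from a previous experiment are never reloaded."""
    cache_dir = tempfile.mkdtemp(prefix=prefix)
    return {"save_cache": False, "raw_dir": cache_dir, "save_dir": cache_dir}

kwargs = fresh_cache_kwargs()
print(sorted(kwargs))  # ['raw_dir', 'save_cache', 'save_dir']
```

These kwargs would then be passed through to MGLDataset alongside the converter and labels, e.g. `MGLDataset(..., **fresh_cache_kwargs())`.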

matthewkuner commented 1 week ago

Hi @bfocassio. Do you know why your stresses were all printed as 0 by the code in this previous comment? https://github.com/materialsvirtuallab/matgl/issues/264#issuecomment-2122978210

bfocassio commented 1 week ago

That example was copied from one of the example notebooks; as you can see, stress_weight is set to zero by default in PotentialLightningModule. Per https://github.com/materialsvirtuallab/matgl/blob/5abfde6b90f73b80265b5364f229ff0d7fded9e8/src/matgl/utils/training.py#L333, if stress_weight is zero, the stress is not calculated by the model.
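A simplified sketch of the behavior described above (not matgl's actual code): with a zero stress_weight, the stress term never enters the loss, so the reported stress metrics stay at zero.

```python
def combined_loss(e_loss, f_loss, s_loss, force_weight=1.0, stress_weight=0.0):
    """Weighted sum of energy/force/stress losses; a zero stress_weight
    skips the stress contribution entirely (0.0 is falsy)."""
    loss = e_loss + force_weight * f_loss
    if stress_weight:
        loss += stress_weight * s_loss
    return loss

print(combined_loss(1.0, 2.0, 100.0))                     # 3.0, stress ignored
print(combined_loss(1.0, 2.0, 100.0, stress_weight=0.1))  # 13.0
```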

matthewkuner commented 1 week ago

@bfocassio Oh my god... this is very helpful, thanks. I assume using a stress_weight of 1 is standard for those looking to train stresses?

bfocassio commented 1 week ago

Actually, if you check this example https://github.com/materialsvirtuallab/matgl/issues/264#issuecomment-2123133886, force_weight=1 and stress_weight=0.1, as in the original M3GNet.

matthewkuner commented 1 week ago

Thanks again @bfocassio . One more question--in your original comment, you mentioned that including/excluding elemental_refs seemed to have no effect. Is that still the case after all of the bug-fixing you've done in this thread?

bfocassio commented 1 week ago

After fixing the bug, it does have an effect, since the model would have no reference energies stored if you did not provide them. So it is a good idea to include them.