materialsvirtuallab / matgl

Graph deep learning library for materials
BSD 3-Clause "New" or "Revised" License

[Bug]: Issue in MGLDataLoader #349

Closed: GurjotDhaliwal closed this issue 2 months ago

GurjotDhaliwal commented 2 months ago

Version

1.1.2

Which OS(es) are you using?
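This field was left blank. For a report like this one, the OS and the relevant package versions can be gathered with a short standard-library script. The sketch below is one way to do that; the helper name `environment_report` is hypothetical, and the list of distribution names is an assumption about what matters for this issue.

```python
import platform
import sys
from importlib.metadata import PackageNotFoundError, version

# Distribution names assumed relevant to this issue; adjust as needed.
PACKAGES = ("matgl", "dgl", "torch", "lightning", "pymatgen")


def environment_report(packages=PACKAGES) -> dict[str, str]:
    """Collect OS, Python and package versions (hypothetical helper, not part of MatGL)."""
    report = {"os": platform.platform(), "python": sys.version.split()[0]}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Including this output makes it easier to compare environments when a maintainer cannot reproduce an error.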

What happened?

I am learning MatGL so that I can apply it to problems with my own data. While iterating over the training loader returned by MGLDataLoader, I get the following KeyError. I would appreciate any help resolving this issue.
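For context, the failing line in the log (`e = torch.tensor([d["energies"] for d in labels])`) builds a batch by looking up the key `"energies"` in each structure's label dict, so the error means at least one item in the batch carries labels without that key. The pure-Python sketch below mimics that lookup without torch. It assumes each dataset item receives one value per key of the column-style `labels` dict passed to MGLDataset; `per_item_labels` and `collect_energies` are hypothetical illustrations, not MatGL code, and the `"eform"` key is only an example of a differently named label.

```python
from typing import Any


def per_item_labels(labels: dict[str, list[Any]], idx: int) -> dict[str, Any]:
    """Slice one structure's labels out of a column-style labels dict (assumed behaviour)."""
    return {key: values[idx] for key, values in labels.items()}


def collect_energies(batch_labels: list[dict[str, Any]]) -> list[Any]:
    """Same lookup as the failing line in the traceback, minus torch.tensor."""
    return [d["energies"] for d in batch_labels]


# Labels keyed as in the snippet: the lookup succeeds.
ok = {"energies": [-10.0, -20.0], "forces": [[], []], "stresses": [[], []]}
print(collect_energies([per_item_labels(ok, i) for i in range(2)]))  # [-10.0, -20.0]

# Labels stored under a different (hypothetical) key: the same KeyError appears.
other = {"eform": [-1.5, -2.5]}
try:
    collect_energies([per_item_labels(other, i) for i in range(2)])
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'energies'
```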

Code snippet

```python
from __future__ import annotations

import os
import shutil
import warnings

import numpy as np
import lightning.pytorch as pl
from functools import partial
from dgl.data.utils import split_dataset
from mp_api.client import MPRester
from lightning.pytorch.loggers import CSVLogger

import matgl
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_pes
from matgl.models import M3GNet
from matgl.utils.training import PotentialLightningModule
from matgl.config import DEFAULT_ELEMENTS

# To suppress warnings for clearer output
warnings.simplefilter("ignore")

mpr = MPRester(api_key="")  # supply your Materials Project API key
entries = mpr.get_entries_in_chemsys(["Si", "O"])
structures = [e.structure for e in entries]
energies = [e.energy for e in entries]
print(energies)
# Placeholder forces and stresses: zeros of the right shape for each structure
forces = [np.zeros((len(s), 3)).tolist() for s in structures]
stresses = [np.zeros((3, 3)).tolist() for _ in structures]
labels = {
    "energies": energies,
    "forces": forces,
    "stresses": stresses,
}

print(f"{len(structures)} structures downloaded from MP.")

element_types = DEFAULT_ELEMENTS
print(f"Default element types: {element_types}")
converter = Structure2Graph(element_types=element_types, cutoff=5.0)
dataset = MGLDataset(
    threebody_cutoff=4.0, structures=structures, converter=converter, labels=labels, include_line_graph=True
)

train_data, val_data, test_data = split_dataset(
    dataset,
    frac_list=[0.8, 0.1, 0.1],
    shuffle=True,
    random_state=42,
)

my_collate_fn = partial(collate_fn_pes, include_line_graph=True)
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    collate_fn=my_collate_fn,
    batch_size=2,
    num_workers=0,
)

# Iterate over a few training batches; the KeyError in the log below is raised here.
for i, batch in enumerate(train_loader):
    print(f"Batch {i}: {batch}")
    if i >= 2:  # Stop after a few batches for brevity
        break
```
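One possibility worth ruling out, offered as an assumption rather than a confirmed diagnosis: MGLDataset can load previously processed data from disk instead of rebuilding it from the `structures` and `labels` passed in. Files left over from an earlier run that used differently named labels could then produce this KeyError on one machine but not on another. The standard-library sketch below checks a cached labels file for the expected keys. The file name `labels.json` and the candidate directories (`.` and `MGLDataset`) are assumptions about MatGL 1.1.2's defaults and should be checked against the installed version; `missing_cached_label_keys` is a hypothetical helper.

```python
from __future__ import annotations

import json
from pathlib import Path

# Assumed name of the cached labels file; verify against your MatGL version.
LABELS_FILE = "labels.json"


def missing_cached_label_keys(
    cache_dir: str | Path, required: tuple[str, ...] = ("energies", "forces", "stresses")
) -> list[str] | None:
    """Return required keys absent from a cached labels file, or None if no such file exists."""
    path = Path(cache_dir) / LABELS_FILE
    if not path.is_file():
        return None
    with path.open() as fh:
        cached = json.load(fh)
    return [key for key in required if key not in cached]


if __name__ == "__main__":
    # Hypothetical candidate locations for the processed dataset.
    for candidate in (".", "MGLDataset"):
        missing = missing_cached_label_keys(candidate)
        if missing:
            print(f"{candidate}: cached labels lack {missing}; remove the cached files and rebuild the dataset.")
```

If stale files are found, deleting them (or running from a clean working directory) should make MGLDataset rebuild the dataset from the labels defined in the snippet.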

Log output

```shell
e = torch.tensor([d["energies"] for d in labels])  # type: ignore
KeyError: 'energies'
```


kenko911 commented 2 months ago

Hi @GurjotDhaliwal, I am unable to reproduce this error in my Python environment. I suggest creating a fresh environment, installing MatGL in it, and testing again.