I am learning MatGL to apply it to problems with my own data. While iterating over the training data loader returned by MGLDataLoader, I get the following KeyError. I would appreciate any help resolving this issue.
Code snippet
from __future__ import annotations
import os
import shutil
import warnings
import numpy as np
#import pytorch_lightning as pl
import lightning.pytorch as pl
from functools import partial
from dgl.data.utils import split_dataset
from mp_api.client import MPRester
from pytorch_lightning.loggers import CSVLogger
import matgl
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_pes
from matgl.models import M3GNet
from matgl.utils.training import PotentialLightningModule
from matgl.config import DEFAULT_ELEMENTS
# Suppress library warnings for clearer output.
warnings.simplefilter("ignore")

# Read the Materials Project API key from the MP_API_KEY environment variable
# instead of hard-coding it in the script. Using MPRester as a context manager
# closes its HTTP session once the query is done.
with MPRester(api_key=os.environ.get("MP_API_KEY")) as mpr:
    entries = mpr.get_entries_in_chemsys(["Si", "O"])

structures = [e.structure for e in entries]
energies = [e.energy for e in entries]
# Placeholder labels: only energies come from the MP entries, so forces and
# stresses are all-zero arrays shaped (n_atoms, 3) and (3, 3) per structure.
# collate_fn_pes reads "energies" from every label dict (see the traceback);
# presumably it also reads "forces"/"stresses" -- TODO confirm for your version.
forces = [np.zeros((len(s), 3)).tolist() for s in structures]
stresses = [np.zeros((3, 3)).tolist() for s in structures]
labels = {
    "energies": energies,
    "forces": forces,
    "stresses": stresses,
}
print(f"{len(structures)} downloaded from MP.")

element_types = DEFAULT_ELEMENTS
print(f"Default element types: {element_types}")
converter = Structure2Graph(element_types=element_types, cutoff=5.0)

# NOTE(review): MGLDataset caches processed graphs and labels on disk. On a
# later run it appears to reload that cache rather than process the structures
# and labels passed in. A leftover cache from an earlier run with different
# label keys (e.g. a property-prediction run) would then make collate_fn_pes
# fail with KeyError: 'energies'. These filenames match the cleanup step of the
# MatGL M3GNet training tutorial; confirm they match your installed MatGL version.
for cache_file in (
    "dgl_graph.bin",
    "lattice.pt",
    "dgl_line_graph.bin",
    "state_attr.pt",
    "labels.json",
):
    try:
        os.remove(cache_file)
    except FileNotFoundError:
        pass  # nothing cached from a previous run

dataset = MGLDataset(
    threebody_cutoff=4.0,
    structures=structures,
    converter=converter,
    labels=labels,
    include_line_graph=True,
)
train_data, val_data, test_data = split_dataset(
    dataset,
    frac_list=[0.8, 0.1, 0.1],
    shuffle=True,
    random_state=42,
)

# Collate with include_line_graph=True, consistent with the dataset above.
my_collate_fn = partial(collate_fn_pes, include_line_graph=True)
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    collate_fn=my_collate_fn,
    batch_size=2,
    num_workers=0,
)

# Peek at the first few training batches to check that collation works.
for i, batch in enumerate(train_loader):
    print(f"Batch {i}: {batch}")
    if i >= 2:  # stop after three batches for brevity
        break
### Log output
```shell
    e = torch.tensor([d["energies"] for d in labels]) # type: ignore
KeyError: 'energies'
```
Code of Conduct
[X] I agree to follow this project's Code of Conduct
Hi @GurjotDhaliwal, I am unable to reproduce this error in my Python environment. I suggest creating a fresh environment, reinstalling MatGL there, and testing again.
Email (Optional)
No response
Version
1.1.2
Which OS(es) are you using?
What happened?
I am learning MatGL to apply it to problems with my own data. While iterating over the training data loader returned by MGLDataLoader, I get the following KeyError. I would appreciate any help resolving this issue.
Code snippet
Code of Conduct