Closed C-K-Loan closed 1 year ago
Hi @C-K-Loan, you are right: the database was not formatted correctly. I have opened a pull request to fix this.
For now, this code should fix your databases:
import os
import shutil
import tempfile
from tqdm import tqdm
import numpy as np
from ase.db import connect
from schnetpack.datasets import ISO17
# Root directory of the ISO17 download; each fold's database is looked up at
# <data_path>/iso17/<fold>.db (see the os.path.join call below).
data_path = "iso17.db"

# fix databases
# Every fold is rewritten into a scratch database with the metadata block and
# the energy/forces entries stored in each row's data dict; the scratch file
# then replaces the original database.
tmpdir = tempfile.mkdtemp("iso17")
try:
    for fold in ISO17.existing_folds:
        dbpath = os.path.join(data_path, "iso17", fold + ".db")
        tmp_dbpath = os.path.join(tmpdir, "tmp.db")
        with connect(dbpath) as conn, connect(tmp_dbpath) as tmp_conn:
            tmp_conn.metadata = {
                "_property_unit_dict": {
                    ISO17.energy: "eV",
                    ISO17.forces: "eV/Ang",
                },
                "_distance_unit": "Ang",
                "atomrefs": {},
            }
            # add energy to data dict in db
            for idx in tqdm(range(len(conn)), f"parsing database file {dbpath}"):
                # Rows are fetched by id idx + 1 (presumably ASE row ids
                # start at 1 -- confirm against the ASE db docs).
                atmsrw = conn.get(idx + 1)
                data = atmsrw.data
                data[ISO17.forces] = np.array(data[ISO17.forces])
                data[ISO17.energy] = np.array([atmsrw.total_energy])
                tmp_conn.write(atmsrw.toatoms(), data=data)

        # The scratch directory may live on a different filesystem than the
        # data directory, where a plain os.rename fails. Move the rewritten
        # file next to the target first (shutil.move falls back to copying),
        # then swap it in with os.replace so the original database is only
        # overwritten once its replacement is fully in place.
        staged_dbpath = dbpath + ".tmp"
        shutil.move(tmp_dbpath, staged_dbpath)
        os.replace(staged_dbpath, dbpath)
finally:
    # Always drop the scratch directory, even if a fold failed; errors during
    # cleanup are ignored so they cannot mask the original exception.
    shutil.rmtree(tmpdir, ignore_errors=True)
# check database
from schnetpack.transform import ASENeighborList
# Build the ISO17 data module on the "reference_eq" fold and print the split
# sizes and the properties stored in the database as a sanity check.
neighbor_list = ASENeighborList(cutoff=5.0)
iso17data = ISO17(
    "./iso17.db",
    fold="reference_eq",
    batch_size=10,
    num_train=11000,
    num_val=10000,
    transforms=[neighbor_list],
)
iso17data.prepare_data()
iso17data.setup()

# (label, attribute) pairs; the attribute is resolved inside the loop so each
# dataset is only accessed right before its line is printed.
size_report = (
    ("Number of reference calculations:", "dataset"),
    ("Number of train data:", "train_dataset"),
    ("Number of validation data:", "val_dataset"),
    ("Number of test data:", "test_dataset"),
)
for label, attr_name in size_report:
    print(label, len(getattr(iso17data, attr_name)))

print("Available properties:")
for prop_name in iso17data.dataset.available_properties:
    print("-", prop_name)
Let me know if this solves the issue!
This fixes the issue, thank you @Stefaanhess!
It looks like there is a bug when loading iso17data into torch tensors; see this notebook: https://colab.research.google.com/drive/1LkySYJWJoFScx1w4DKD9SJ5XOXtsmYU1?usp=sharing
throws