choderalab / modelforge

Infrastructure to implement and train NNPs
https://modelforge.readthedocs.io/en/latest/
MIT License

Use diagonal batching in TorchDataset #27

Closed (by ArnNag, 8 months ago)

ArnNag commented 9 months ago

Description

Currently, TorchDataset reads from an npz file containing atom feature arrays of shape (n_conformers, n_atoms, *), where every conformer is padded to the same constant n_atoms. This PR will revise HDF5Dataset to instead store each atom feature as a flattened array of shape (total number of atoms across all conformers, *), together with an array of shape (n_conformers,) that records the number of atoms in each conformer. TorchDataset.__getitem__ will still return the atom features of one conformer with shape (n_atoms, *), but it will obtain them by computing that conformer's index range into the flattened array rather than by direct indexing. Conformer features will still be stored in arrays of shape (n_conformers, *) as before.
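A minimal NumPy sketch of the layout described above, assuming plain in-memory arrays. `FlatAtomDataset` and `collate_diagonal` are hypothetical names for illustration, not modelforge's API. Each conformer's start offset is the exclusive cumulative sum of the atom counts, so `__getitem__` is a slice instead of an index into a padded array. The collate function shows one common reading of "diagonal batching" (an assumption here): conformers in a batch are concatenated along the atom axis and an atom-to-conformer index replaces padding, so any pairwise interaction matrix over the batch is block-diagonal.

```python
import numpy as np


class FlatAtomDataset:
    """Hypothetical sketch: per-conformer access into flattened atom features.

    atom_features:  shape (total_atoms, *), atoms of all conformers concatenated
    n_atoms:        shape (n_conformers,), number of atoms in each conformer
    conf_features:  shape (n_conformers, *), one row per conformer
    """

    def __init__(self, atom_features, n_atoms, conf_features):
        self.atom_features = np.asarray(atom_features)
        self.n_atoms = np.asarray(n_atoms, dtype=np.int64)
        self.conf_features = np.asarray(conf_features)
        if self.n_atoms.sum() != len(self.atom_features):
            raise ValueError("sum of n_atoms must equal the number of atom rows")
        # Start offset of each conformer: exclusive cumulative sum of n_atoms.
        self.starts = np.concatenate(([0], np.cumsum(self.n_atoms)[:-1]))

    def __len__(self):
        return len(self.n_atoms)

    def __getitem__(self, idx):
        start = self.starts[idx]
        stop = start + self.n_atoms[idx]
        # Slice this conformer's index range; no padded array is involved.
        return self.atom_features[start:stop], self.conf_features[idx]


def collate_diagonal(samples):
    """Concatenate conformers along the atom axis (no padding).

    Returns atom features of shape (atoms in batch, *), an atom-to-conformer
    index of shape (atoms in batch,), and stacked conformer features.
    """
    atoms = [a for a, _ in samples]
    batch_index = np.concatenate(
        [np.full(len(a), i, dtype=np.int64) for i, a in enumerate(atoms)]
    )
    return np.concatenate(atoms), batch_index, np.stack([c for _, c in samples])
```

For example, with `n_atoms = [2, 3]` the second conformer is rows 2 to 4 of the flattened array. Storing counts instead of padding keeps memory proportional to the true number of atoms, at the cost of one cumulative sum when the dataset is loaded.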

ArnNag commented 9 months ago

The SchNet tests fail because diagonal batching is not yet implemented within SchNet. test_dataset.py needs a test for a dataset with an arbitrary number of conformers per record (e.g. SPICE): the existing tests passed even with incorrect implementations because QM9 has only one conformer per record.
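A sketch of the kind of test described above, assuming each record is a NumPy array of shape (n_conformers_in_record, n_atoms_in_record, 3), e.g. positions for every conformer of one molecule. `flatten_records` and `conformer_atoms` are hypothetical helpers, not functions from modelforge or test_dataset.py. Because the synthetic records hold 1, 3, and 2 conformers, an implementation that confuses record indices with conformer indices fails here, whereas it would pass on QM9-style data with one conformer per record.

```python
import numpy as np


def flatten_records(records):
    """Hypothetical helper: flatten records into the flattened storage layout.

    Returns atom features of shape (total_atoms, 3) and n_atoms of shape
    (n_conformers,), with one entry per conformer rather than per record.
    """
    atom_rows, n_atoms = [], []
    for record in records:
        for conformer in record:  # iterate conformers, not records
            atom_rows.append(conformer)
            n_atoms.append(len(conformer))
    return np.concatenate(atom_rows), np.array(n_atoms, dtype=np.int64)


def conformer_atoms(atom_features, n_atoms, idx):
    """Slice the atoms of conformer `idx` using cumulative offsets."""
    bounds = np.concatenate(([0], np.cumsum(n_atoms)))
    return atom_features[bounds[idx]:bounds[idx + 1]]


def test_multiple_conformers_per_record():
    rng = np.random.default_rng(0)
    # Records with 1, 3, and 2 conformers of 4, 2, and 5 atoms respectively.
    # With one conformer per record (as in QM9), record and conformer indices
    # coincide, so that case cannot catch indexing bugs.
    records = [
        rng.normal(size=(1, 4, 3)),
        rng.normal(size=(3, 2, 3)),
        rng.normal(size=(2, 5, 3)),
    ]
    atom_features, n_atoms = flatten_records(records)

    expected = [conf for record in records for conf in record]
    assert len(n_atoms) == len(expected) == 6
    assert n_atoms.tolist() == [4, 2, 2, 2, 5, 5]
    for idx, conf in enumerate(expected):
        np.testing.assert_array_equal(
            conformer_atoms(atom_features, n_atoms, idx), conf
        )
```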