atomistic-machine-learning / schnetpack

SchNetPack - Deep Neural Networks for Atomistic Systems

Accessing properties in `.db` files, tutorials 1, 2. #633

Closed CalmScout closed 1 month ago

CalmScout commented 1 month ago

Dear maintainers of SchNetPack,

I am following the tutorials from your repo and adapting them to my data. From tutorial 1 we have a custom_data object of class schnetpack.data.datamodule.AtomsDataModule that was generated for a set of SMILES and their corresponding bioactivity labels like:

from schnetpack.data import ASEAtomsData

atoms_list = []      # ase.Atoms objects built from the SMILES
property_list = []   # one property dict per structure

for smi, bio in zip(smiles, bioactivities):
    try:
        atoms_list.append(smi2ase(smi))  # smi2ase: custom SMILES -> ase.Atoms helper
        properties = {'bioactivity': bio}
        property_list.append(properties)
    except Exception as e:
        # write failed SMILES to a log
        with open("papyrus1k-exceptions.log", "a") as file:
            file.write(f"Exception {type(e)} {e} for smi: {smi}\nbio: {bio}\ntarget: {target}\n")

new_dataset = ASEAtomsData.create(
    str(path_to_db / f"{target}.db"),
    distance_unit='Ang',
    property_unit_dict={'bioactivity': 'bool'}
)

new_dataset.add_systems(property_list, atoms_list)

So new_dataset supposedly has a "bioactivity" property. After saving it to the .db file, I load it with:

custom_data = spk.data.AtomsDataModule(
    str(path_db), 
    batch_size=10,
    distance_unit='Ang',
    split_file=None, # suppresses the warning about mismatched split sizes
    num_train=1500, # total number of samples is 1881
    num_val=381,
    transforms=[
        trn.ASENeighborList(cutoff=5.),
        trn.CastTo32()
    ],
    load_properties=["bioactivity"],
    # property_units={'bioactivity':'bool'},
    num_workers=1,
    pin_memory=True, # set to false, when not using a GPU
)
custom_data.prepare_data()
custom_data.setup()

I am trying to get access to the "bioactivity" property. How should I do that? I want to use this label the way U0 is used in tutorial 2, in the "Setting up the model" section.

I understand that the question should probably be addressed to the authors of the ase package, but I would appreciate any hint on how I can adapt your package (and the tutorials in particular) to my data.
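For reference, this is roughly how I imagine adapting the "Setting up the model" section of tutorial 2 to my label. This is only a rough sketch: the names pred_bio and output_bio, the hyperparameters, and the MSE loss are placeholders (for a binary label something like torch.nn.BCEWithLogitsLoss may be more appropriate):

import torch
import schnetpack as spk

cutoff = 5.
n_atom_basis = 30

# SchNet representation, as in tutorial 2
pairwise_distance = spk.atomistic.PairwiseDistances()
radial_basis = spk.nn.GaussianRBF(n_rbf=20, cutoff=cutoff)
schnet = spk.representation.SchNet(
    n_atom_basis=n_atom_basis,
    n_interactions=3,
    radial_basis=radial_basis,
    cutoff_fn=spk.nn.CosineCutoff(cutoff),
)

# predict the custom property instead of QM9.U0
pred_bio = spk.atomistic.Atomwise(n_in=n_atom_basis, output_key="bioactivity")

nnpot = spk.model.NeuralNetworkPotential(
    representation=schnet,
    input_modules=[pairwise_distance],
    output_modules=[pred_bio],
)

output_bio = spk.task.ModelOutput(
    name="bioactivity",
    loss_fn=torch.nn.MSELoss(),   # placeholder loss, see note above
    loss_weight=1.0,
)

task = spk.task.AtomisticTask(
    model=nnpot,
    outputs=[output_bio],
    optimizer_cls=torch.optim.AdamW,
    optimizer_args={"lr": 1e-4},
)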

Best regards, Anton.

jnsLs commented 1 month ago

Dear Anton,

You can iterate through any subset of your data using the corresponding dataloader.

For the test set this could look like:

from tqdm import tqdm

dataloader = custom_data.test_dataloader()

for sample_idx, sample in enumerate(tqdm(dataloader)):
    print(sample)  # each sample is a dict of batched tensors, including the loaded properties

The fact that you defined bool as the property unit might be a problem. I would recommend using an empty string "" instead. You should use that empty unit both when you define the dataset and when you load it.
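For example (a sketch based on your snippets above, with only the unit changed):

# when creating the dataset
new_dataset = ASEAtomsData.create(
    str(path_to_db / f"{target}.db"),
    distance_unit='Ang',
    property_unit_dict={'bioactivity': ''}   # '' instead of 'bool'
)

# and when loading it, pass the same empty unit, e.g.
# spk.data.AtomsDataModule(..., property_units={'bioactivity': ''}, ...)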

Best, Jonas

CalmScout commented 1 month ago

Dear Jonas,

I am still figuring out how to use dataloaders in schnetpack. Meanwhile, when simply copy-pasting the code you suggested above, I received the following exception:

(pyg) (mambaforge) popova@mi001pc014:~/Projects/citre-benchmarking/scripts$ python tmp.py 
- bioactivity
  0%|          | 0/760 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/popova/Projects/citre-benchmarking/scripts/tmp.py", line 42, in <module>
    for sample_idx, sample in enumerate(tqdm(dataloader)):
  File "/home/popova/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/tqdm/notebook.py", line 250, in __iter__
    for obj in it:
  File "/home/popova/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/home/popova/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/popova/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
    return self._process_data(data)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/popova/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
    data.reraise()
  File "/home/popova/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/_utils.py", line 705, in reraise
    raise exception
AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/popova/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/popova/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/home/popova/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/schnetpack/data/atoms.py", line 269, in __getitem__
    props = self._get_properties(
            ^^^^^^^^^^^^^^^^^^^^^
  File "/home/popova/.pyenv/versions/mambaforge/envs/pyg/lib/python3.12/site-packages/schnetpack/data/atoms.py", line 347, in _get_properties
    torch.tensor(row.data[pname].copy()) * self.conversions[pname]
                 ^^^^^^^^^^^^^^^^^^^^
AttributeError: 'int' object has no attribute 'copy'

I changed the property unit from bool to an empty string as you suggested, thank you.

Best regards, Anton.

jnsLs commented 1 month ago

Hi Anton,

This probably means that your target (bioactivity) does not have the expected dimensions. You could try to expand the properties before writing them to the dataset: each property value should be an np.array with shape (1,).
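Something along these lines (a sketch assuming the loop from your first post; np.float32 is just one possible dtype):

import numpy as np

for smi, bio in zip(smiles, bioactivities):
    atoms_list.append(smi2ase(smi))
    # store the label as a one-element array so SchNetPack can call .copy() on it
    property_list.append({'bioactivity': np.array([bio], dtype=np.float32)})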

Best, Jonas