atomistic-machine-learning / schnetpack

SchNetPack - Deep Neural Networks for Atomistic Systems

Problem with the "Preparing your own data" tutorial #609

Closed mariofp77 closed 4 months ago

mariofp77 commented 4 months ago

I am getting the following error at the end of the second part of the first tutorial (https://github.com/atomistic-machine-learning/schnetpack/blob/master/examples/tutorials/tutorial_01_preparing_data.ipynb):

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[13], line 20
      4 custom_data = spk.data.AtomsDataModule(
      5     './new_dataset.db', 
      6     batch_size=10,
   (...)
     17     pin_memory=True, # set to false, when not using a GPU
     18 )
     19 custom_data.prepare_data()
---> 20 custom_data.setup()

File ~/anaconda3/lib/python3.9/site-packages/schnetpack/data/datamodule.py:198, in AtomsDataModule.setup(self, stage)
    196 self._val_dataset = self.dataset.subset(self.val_idx)
    197 self._test_dataset = self.dataset.subset(self.test_idx)
--> 198 self._setup_transforms()

File ~/anaconda3/lib/python3.9/site-packages/schnetpack/data/datamodule.py:338, in AtomsDataModule._setup_transforms(self)
    336 def _setup_transforms(self):
    337     for t in self.train_transforms:
--> 338         t.datamodule(self)
    339     for t in self.val_transforms:
    340         t.datamodule(self)

File ~/anaconda3/lib/python3.9/site-packages/schnetpack/transform/atomistic.py:126, in RemoveOffsets.datamodule(self, _datamodule)
    123     self.atomref = atrefs[self._property].detach()
    125 if self.remove_mean and not self._mean_initialized:
--> 126     stats = _datamodule.get_stats(
    127        self._property, self.is_extensive, self.remove_atomrefs
    128    )
    129     self.mean = stats[0].detach()

File ~/anaconda3/lib/python3.9/site-packages/schnetpack/data/datamodule.py:354, in AtomsDataModule.get_stats(self, property, divide_by_atoms, remove_atomref)
    351 if key in self._stats:
    352     return self._stats[key]
--> 354 stats = calculate_stats(
    355    self.train_dataloader(),
    356    divide_by_atoms={property: divide_by_atoms},
    357    atomref=self.train_dataset.atomrefs if remove_atomref else None,
    358 )[property]
    359 self._stats[key] = stats
    360 return stats

File ~/anaconda3/lib/python3.9/site-packages/schnetpack/data/stats.py:44, in calculate_stats(dataloader, divide_by_atoms, atomref)
     41 mean = torch.zeros_like(norm_mask)
     42 M2 = torch.zeros_like(norm_mask)
---> 44 for props in tqdm(dataloader):
     45     sample_values = []
     46     for p in property_names:

File ~/anaconda3/lib/python3.9/site-packages/tqdm/std.py:1182, in tqdm.__iter__(self)
   1179 time = self._time
   1181 try:
-> 1182     for obj in iterable:
   1183         yield obj
   1184         # Update and possibly print the progressbar.
   1185         # Note: does not call self.update(1) for speed optimisation.

File ~/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File ~/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1345, in _MultiProcessingDataLoaderIter._next_data(self)
   1343 else:
   1344     del self._task_info[idx]
-> 1345     return self._process_data(data)

File ~/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1371, in _MultiProcessingDataLoaderIter._process_data(self, data)
   1369 self._try_put_index()
   1370 if isinstance(data, ExceptionWrapper):
-> 1371     data.reraise()
   1372 return data

File ~/anaconda3/lib/python3.9/site-packages/torch/_utils.py:694, in ExceptionWrapper.reraise(self)
    690 except TypeError:
    691     # If the exception takes multiple arguments, don't try to
    692     # instantiate since we don't know how to
    693     raise RuntimeError(msg) from None
--> 694 raise exception

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/investigator/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/investigator/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/investigator/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/investigator/anaconda3/lib/python3.9/site-packages/schnetpack/data/atoms.py", line 269, in __getitem__
    props = self._get_properties(
  File "/home/investigator/anaconda3/lib/python3.9/site-packages/schnetpack/data/atoms.py", line 339, in _get_properties
    row = conn.get(idx + 1)
  File "/home/investigator/anaconda3/lib/python3.9/site-packages/ase/db/core.py", line 432, in get
    raise KeyError('no match')
KeyError: 'no match'

The notebook has not been altered. The SchNetPack version is the latest one in the repository, and the ASE version is 3.22.1.

jnsLs commented 4 months ago

Dear @mariofp77 ,

I cannot reproduce your error. Could you please delete all remaining files associated with the custom dataset from the tutorials directory and then run the notebook again? These are:

md17_uracil.npz new_dataset.db split.npz splitting.lock

Best, Jonas

mariofp77 commented 4 months ago

Dear @jnsLs,

Many thanks for your response. I removed the files you mentioned and reran the notebook, but I still get the same error.

Best, Mario

jnsLs commented 4 months ago

I think I found the cause of the error. `custom_data = spk.data.AtomsDataModule(...)` loads the split file that was created for the QM9 dataset in one of the previous cells. The QM9 split file contains 110k training indices, which is more than the uracil dataset has entries. Hence, for any index beyond the size of the uracil dataset, the database lookup fails with the `KeyError: 'no match'` shown above.
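The mismatch described here can be illustrated with a small check that compares the indices stored in a split against the number of entries in the dataset. This is only a sketch: the helper `out_of_range_indices` and the numbers used are hypothetical and not part of SchNetPack.

```python
def out_of_range_indices(split_indices, n_entries):
    """Return the split indices that do not exist in a dataset with n_entries rows."""
    return [i for i in split_indices if i < 0 or i >= n_entries]


# Hypothetical values for illustration only: a few indices from a split
# made for a large dataset, checked against a much smaller dataset.
stale_train_idx = [0, 5, 42, 109_999]
n_small_dataset = 1_000

print(out_of_range_indices(stale_train_idx, n_small_dataset))  # [109999]
```

Any index reported by such a check would have no matching row in the smaller database, which is consistent with the `no match` lookup failure in the traceback.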

Could you please try to verify this by running the notebook again and deleting the split file before running the very last cell?
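A minimal sketch of that cleanup step, using only the standard library, might look like the following. The helper name `remove_stale_split` is hypothetical. The file names `split.npz` and `splitting.lock` are taken from the list earlier in this thread, and the sketch assumes they live in the notebook's working directory.

```python
from pathlib import Path


def remove_stale_split(directory=".", names=("split.npz", "splitting.lock")):
    """Delete leftover split files in `directory` and return the names removed."""
    removed = []
    for name in names:
        path = Path(directory) / name
        if path.exists():
            path.unlink()
            removed.append(name)
    return removed


if __name__ == "__main__":
    # Run this before the cell that creates the data module for the custom dataset.
    print(remove_stale_split())
```

An alternative that avoids deleting anything would be to give the custom dataset its own split file. I believe `AtomsDataModule` accepts a `split_file` argument for this, but please check that against your installed version.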

mariofp77 commented 4 months ago

Yes, with that change it works. Many thanks!

jnsLs commented 4 months ago

Thank you, Mario. We will adapt the tutorial soon to avoid this error. Best, Jonas