atomistic-machine-learning / schnetpack

SchNetPack - Deep Neural Networks for Atomistic Systems

Warning for TorchEnvironmentProvider #384

Closed: omidshy closed this issue 2 years ago.

omidshy commented 2 years ago

I receive this warning:

UserWarning: where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.

It is raised at this line:

https://github.com/atomistic-machine-learning/schnetpack/blob/cf962385e32e7335ce3b3f34cd776a4b38c3d985/src/schnetpack/environment.py#L217

The following would probably fix it:

```python
# Cast the uint8 pbc mask to bool so torch.where receives a boolean condition
num_repeats = torch.where(pbc.bool(), num_repeats, torch.zeros_like(num_repeats))
```

stefaanhessmann commented 2 years ago

Hey @omidshy, thanks for your message and sorry for the late reply. You are right about the warning, and the fix works. I will update the master branch!
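
For readers who want to check the change in isolation, here is a minimal sketch of the same pattern outside SchNetPack. It assumes `pbc` is a per-direction 0/1 mask stored as `uint8` and `num_repeats` is an integer tensor of the same shape, mirroring the names in the one-line fix. The function name `mask_non_periodic` and the input values are hypothetical and not part of the SchNetPack API. With a `uint8` condition, older PyTorch releases emit the deprecation warning quoted in the issue, and newer releases may reject it with an error instead; casting with `.bool()` avoids both.

```python
import warnings

import torch


def mask_non_periodic(num_repeats: torch.Tensor, pbc: torch.Tensor) -> torch.Tensor:
    """Zero the repeat count along non-periodic cell directions (hypothetical helper).

    `pbc` may arrive as a uint8 (0/1) mask; casting it to bool gives
    torch.where the boolean condition it expects.
    """
    return torch.where(pbc.bool(), num_repeats, torch.zeros_like(num_repeats))


# Hypothetical inputs: periodic along the first and third directions only.
pbc = torch.tensor([1, 0, 1], dtype=torch.uint8)
num_repeats = torch.tensor([2, 3, 1])

with warnings.catch_warnings():
    # Promote any warning to an exception, so a deprecation warning would fail loudly.
    warnings.simplefilter("error")
    result = mask_non_periodic(num_repeats, pbc)

print(result)  # tensor([2, 0, 1])
```

Turning warnings into errors inside `warnings.catch_warnings()` is a cheap way to guard against this kind of regression in a unit test, without changing the global warning filters for the rest of the program.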