lamoureux-lab / TorchProteinLibrary

PyTorch library of layers acting on protein representations
https://lamoureux-lab.github.io/TorchProteinLibrary/
MIT License

RMSD batch_size issue #31

Closed: akabiraka closed this issue 4 years ago

akabiraka commented 4 years ago

An example of how to use RMSD is given here.

If I set the batch size (the first dimension of `src` and `ref`) to anything greater than 1, as in the example below, it raises `Exception: Coords2CenterFunction: forward Nan`.

Example code:

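```python
import torch
from TorchProteinLibrary import RMSD

rmsd = RMSD.Coords2RMSD().cuda()
src = torch.randn(3, 4*3, dtype=torch.double, device='cuda')  # 3 structures x 4 atoms x (x, y, z)
ref = torch.randn(3, 4*3, dtype=torch.double, device='cuda')
num_atoms = torch.tensor([4], dtype=torch.int, device='cuda')  # a single count for a batch of 3
L = rmsd(src, ref, num_atoms)
```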

Output exception:

```
  File "models/rmsd_loss.py", line 52, in <module>
    L = rmsd(src, ref, num_atoms)
  File "path/to/python3_venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "path/to/New Volume/python3_venv/lib/python3.6/site-packages/TorchProteinLibrary-0.1-py3.6-linux-x86_64.egg/TorchProteinLibrary/RMSD/Coords2RMSD/Coords2RMSD.py", line 73, in forward
    input_center = self.c2c(input, num_atoms)
  File "path/to/New Volume/python3_venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "path/to/New Volume/python3_venv/lib/python3.6/site-packages/TorchProteinLibrary-0.1-py3.6-linux-x86_64.egg/TorchProteinLibrary/FullAtomModel/CoordsTransform/CoordsTransform.py", line 213, in forward
    return Coords2CenterFunction.apply(input_coords, num_atoms)
  File "path/to/New Volume/python3_venv/lib/python3.6/site-packages/TorchProteinLibrary-0.1-py3.6-linux-x86_64.egg/TorchProteinLibrary/FullAtomModel/CoordsTransform/CoordsTransform.py", line 180, in forward
    raise(Exception('Coords2CenterFunction: forward Nan'))
Exception: Coords2CenterFunction: forward Nan
```

Can anyone tell me what is going on? Any help is highly appreciated.
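A minimal sketch of a pre-flight check that fails with a readable message instead of the NaN exception. `check_batch` is a hypothetical helper, not part of TorchProteinLibrary, and it assumes the layout used in the example: each row of `src`/`ref` holds one structure flattened as `[x1, y1, z1, x2, ...]`, and `num_atoms` is expected to hold one count per row.

```python
import torch


def check_batch(coords: torch.Tensor, num_atoms: torch.Tensor) -> None:
    """Hypothetical pre-flight check; not part of TorchProteinLibrary.

    Assumes coords is (batch, 3 * max_atoms), flattened as
    [x1, y1, z1, x2, y2, z2, ...], with one num_atoms entry per row.
    """
    if coords.dim() != 2 or coords.shape[1] % 3 != 0:
        raise ValueError(f"coords should be (batch, 3*max_atoms), got {tuple(coords.shape)}")
    batch, max_atoms = coords.shape[0], coords.shape[1] // 3
    # One atom count per structure (row) in the batch.
    if num_atoms.dim() != 1 or num_atoms.shape[0] != batch:
        raise ValueError(f"num_atoms has shape {tuple(num_atoms.shape)}, expected ({batch},)")
    # Each count must fit inside the padded row.
    if bool((num_atoms < 1).any()) or bool((num_atoms > max_atoms).any()):
        raise ValueError(f"every num_atoms entry must lie in [1, {max_atoms}]")


src = torch.randn(3, 4 * 3, dtype=torch.double)
try:
    check_batch(src, torch.tensor([4], dtype=torch.int))  # the inputs from the example above
except ValueError as err:
    print(err)  # num_atoms has shape (1,), expected (3,)
```

The check runs on CPU tensors here; it works the same way on CUDA tensors, since it only inspects shapes and compares values.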

lupoglaz commented 4 years ago

`num_atoms` needs one entry per structure in the batch. Change `num_atoms = torch.tensor([4], dtype=torch.int, device='cuda')` to `num_atoms = torch.tensor([4,4,4], dtype=torch.int, device='cuda')`.
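When the structures in a batch have different lengths, the same rule should apply: pad every structure to the longest one and record each real length in `num_atoms`. The sketch below assumes zero padding is acceptable, i.e. that the layer only reads the first `num_atoms[i]` atoms of row `i`. `pack_batch` is a hypothetical helper, so verify that assumption against the library before relying on it, and move the outputs to the GPU with `.cuda()` before passing them to `Coords2RMSD`.

```python
import torch


def pack_batch(structures):
    """Hypothetical helper: pad a list of (n_i, 3) coordinate tensors into a flat
    (batch, 3 * max_n) tensor plus an int32 num_atoms tensor of shape (batch,)."""
    lengths = [s.shape[0] for s in structures]
    max_n = max(lengths)
    coords = torch.zeros(len(structures), max_n, 3, dtype=structures[0].dtype)
    for i, s in enumerate(structures):
        coords[i, : s.shape[0]] = s  # real atoms first, zeros after them
    num_atoms = torch.tensor(lengths, dtype=torch.int)  # one count per structure
    # Flatten each structure row-wise to [x1, y1, z1, x2, y2, z2, ...].
    return coords.reshape(len(structures), max_n * 3), num_atoms


structs = [torch.randn(4, 3, dtype=torch.double) for _ in range(3)]
coords, num_atoms = pack_batch(structs)
print(coords.shape, num_atoms)  # torch.Size([3, 12]) tensor([4, 4, 4], dtype=torch.int32)
```

For the equal-length case from this issue, it produces exactly the `[4, 4, 4]` tensor from the fix.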

akabiraka commented 4 years ago

Thanks a lot. I misunderstood: `num_atoms` holds the number of atoms for each structure in the batch, not a single count.
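For checking results on the CPU, a plain PyTorch reference can help. The sketch below computes, per structure, the RMSD after optimal rigid superposition (the Kabsch method via SVD), using only the first `num_atoms[i]` atoms of each row. `kabsch_rmsd` is hypothetical and is not TorchProteinLibrary's implementation; it assumes `Coords2RMSD` reports the RMSD after superposition, so compare against the library's output before treating the two as interchangeable.

```python
import torch


def kabsch_rmsd(src: torch.Tensor, ref: torch.Tensor, num_atoms: torch.Tensor) -> torch.Tensor:
    """Hypothetical CPU reference, not TorchProteinLibrary's code.

    src, ref: (batch, 3 * max_atoms), flattened as [x1, y1, z1, x2, ...].
    num_atoms: (batch,), number of real atoms in each row.
    Returns the per-structure RMSD after optimal rigid superposition.
    """
    batch = src.shape[0]
    if num_atoms.shape != (batch,):
        raise ValueError("num_atoms needs exactly one entry per structure")
    out = torch.empty(batch, dtype=src.dtype)
    for i in range(batch):
        n = int(num_atoms[i])
        x = src[i].reshape(-1, 3)[:n]
        y = ref[i].reshape(-1, 3)[:n]
        # Remove translation: center each structure on its own centroid.
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)
        # 3x3 cross-covariance H = x^T y and its SVD, H = U diag(s) Vh.
        u, s, vh = torch.linalg.svd(x.T @ y)
        # d = -1 when the best orthogonal fit is a reflection; exclude it.
        d = torch.sign(torch.linalg.det(u @ vh))
        # Residual of the optimal rotation, computed from the singular values alone.
        e = (x * x).sum() + (y * y).sum() - 2.0 * (s[0] + s[1] + d * s[2])
        out[i] = torch.sqrt(torch.clamp(e, min=0.0) / n)
    return out


src = torch.randn(3, 4 * 3, dtype=torch.double)
ref = torch.randn(3, 4 * 3, dtype=torch.double)
num_atoms = torch.tensor([4, 4, 4], dtype=torch.int)
print(kabsch_rmsd(src, ref, num_atoms))  # three values, one per structure
```

The residual is read off the singular values, E = Σ|x|² + Σ|y|² − 2(s1 + s2 + d·s3), so the rotation matrix never has to be built. The clamp guards against tiny negative values from floating-point cancellation when the two structures are nearly identical.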