lamoureux-lab / TorchProteinLibrary

PyTorch library of layers acting on protein representations
https://lamoureux-lab.github.io/TorchProteinLibrary/
MIT License

Transposing without explicitly setting contiguous causes a bug in gradient calculation. #21

Closed: egeozsoy closed this issue 5 years ago

egeozsoy commented 5 years ago

Passing a transposed tensor to the layer without explicitly calling contiguous() makes the backward pass misbehave. I tested this many times, and nothing else was changed between the two runs.

Example without contiguous():

angles = x.transpose(1, 2)
L = torch.zeros(1, dtype=torch.int, device="cuda").fill_(angles.shape[2])
coords = a2b(angles, L)

Loss every 10 iterations:
10: 29297735.766414236
20: 29420032.493838556
30: 29592217.76077364
40: 29819240.597087096
50: 30089326.541265268

Example with contiguous():

angles = x.transpose(1, 2).contiguous()
L = torch.zeros(1, dtype=torch.int, device="cuda").fill_(angles.shape[2])
coords = a2b(angles, L)

Loss every 10 iterations:
10: 67705071.70907234
20: 65147285.364302136
30: 62598781.50298739
40: 60059446.29427992
50: 57532192.84414939
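One plausible mechanism, assuming the CUDA extension reads its input through the raw data pointer as a row-major array (the "input_angles must be contiguous" assertion later in this thread is consistent with that): transpose() only swaps the shape and strides of a view and leaves the underlying buffer untouched, so a kernel that ignores strides reads the angles in the wrong order. If the forward pass sees scrambled angles, the gradients it returns are scrambled too, which would fit a loss that drifts upward. The pure-Python sketch below mimics this with a flat list; read_strided, read_assuming_contiguous and make_contiguous are hypothetical helpers, not TorchProteinLibrary code.

```python
# Pure-Python illustration (no PyTorch): a strided view over a flat buffer,
# read correctly via strides and incorrectly by a kernel that assumes row-major order.

def row_major_strides(shape):
    """Strides (in elements) of a C-contiguous array with the given shape."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))

def read_strided(buffer, shape, strides):
    """Read a 2-D view element by element, honouring its strides (as PyTorch does)."""
    rows, cols = shape
    return [[buffer[i * strides[0] + j * strides[1]] for j in range(cols)]
            for i in range(rows)]

def read_assuming_contiguous(buffer, shape):
    """Hypothetical naive kernel: walks the raw buffer as if it were row-major."""
    rows, cols = shape
    return [[buffer[i * cols + j] for j in range(cols)] for i in range(rows)]

def make_contiguous(buffer, shape, strides):
    """Copy a strided view into a fresh row-major buffer (the role of .contiguous())."""
    return [value for row in read_strided(buffer, shape, strides) for value in row]

# A (4, 2) array stored row-major: 4 residues x 2 angles each.
buffer = [10, 11, 20, 21, 30, 31, 40, 41]
shape = (4, 2)
strides = row_major_strides(shape)        # (2, 1)

# Transposing swaps shape and strides; the buffer itself is untouched.
t_shape = (shape[1], shape[0])            # (2, 4)
t_strides = (strides[1], strides[0])      # (1, 2)

correct = read_strided(buffer, t_shape, t_strides)
naive = read_assuming_contiguous(buffer, t_shape)
fixed = read_assuming_contiguous(make_contiguous(buffer, t_shape, t_strides), t_shape)

print(correct)           # [[10, 20, 30, 40], [11, 21, 31, 41]]
print(naive)             # [[10, 11, 20, 21], [30, 31, 40, 41]]  (values misplaced)
print(fixed == correct)  # True
```

Note that the naive read does not fail; it silently returns misplaced values. That is why an explicit contiguity check, or a contiguous() call before the kernel, is the safer design.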

lupoglaz commented 5 years ago

Thank you for noticing this issue. Could you please tell me:

  1. your pytorch version
  2. the repository branch you use (dev or release)
  3. how x is initialized in angles = x.transpose(1, 2)

Thank you in advance; I'll do my best to resolve this issue as soon as possible.

egeozsoy commented 5 years ago

PyTorch version: 1.0

Repository version: this repository doesn't support PyTorch 1.0, so we forked it and made some tweaks (https://github.com/johahi/TorchProteinLibrary). We are using that fork, but we are reporting the bug here because we don't think it is related to the small changes we made for PyTorch 1.0 compatibility. If you can't reproduce the bug, it might be.

x = torch.zeros(1, int(num_backbone_atoms / 3), 2, dtype=torch.float, device='cuda').requires_grad_()

x was initialized as in the example from the TPL documentation, except for its shape: the example uses (1, 2, int(num_backbone_atoms / 3)). We swapped the last two dimensions intentionally so that a transpose call would be necessary, to demonstrate that the bug only occurs when the input is transposed.
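The layouts involved can be inspected directly in PyTorch. This sketch reproduces the initialization above on CPU instead of 'cuda' so it runs without a GPU, and assumes a hypothetical 40 residues. It shows that transpose(1, 2) returns a non-contiguous view with swapped strides, while contiguous() copies the data into row-major order and still propagates gradients back to x.

```python
import torch

num_backbone_atoms = 3 * 40  # hypothetical: 40 residues x 3 backbone atoms
# CPU instead of 'cuda' so the sketch runs anywhere; requires_grad_() marks x as trainable.
x = torch.zeros(1, num_backbone_atoms // 3, 2, dtype=torch.float).requires_grad_()

angles = x.transpose(1, 2)       # view with shape (1, 2, 40); no data is copied
print(angles.shape, angles.stride(), angles.is_contiguous())
# torch.Size([1, 2, 40]) (80, 1, 2) False

angles_c = angles.contiguous()   # fresh row-major copy of the same values
print(angles_c.stride(), angles_c.is_contiguous())
# (80, 40, 1) True

# contiguous() is differentiable: the gradient comes back in x's own layout.
angles_c.sum().backward()
print(x.grad.shape)              # torch.Size([1, 40, 2])
```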

lupoglaz commented 5 years ago

Commit a802e2cccee9fecccace902459f29f1d7771f066: the layer now checks the type, device and contiguity of its input tensors.
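The commit message does not show the checks themselves; the assertion in the traceback in this thread points to the C++ interface (angles2backbone_interface.cpp). As a rough Python-side analogue, a layer could validate its input before handing it to an extension. The helper check_angles and its choice of exception types are hypothetical, and the demo runs on CPU so it needs no GPU.

```python
import torch

def check_angles(angles: torch.Tensor, expected_device: str = "cuda") -> None:
    """Hypothetical guard: reject inputs a raw-pointer kernel cannot read safely."""
    if angles.dtype != torch.float32:
        raise TypeError(f"angles must be float32, got {angles.dtype}")
    if angles.device.type != expected_device:
        raise ValueError(f"angles must be on {expected_device}, got {angles.device.type}")
    if not angles.is_contiguous():
        # A transposed view shares storage with its base but has permuted strides.
        raise ValueError("angles must be contiguous; call .contiguous() first")

# CPU demo so the sketch runs without a GPU.
base = torch.randn(1, 40, 3)
view = base.transpose(1, 2)  # shape (1, 3, 40), non-contiguous

try:
    check_angles(view, expected_device="cpu")
except ValueError as err:
    print(err)  # angles must be contiguous; call .contiguous() first

check_angles(view.contiguous(), expected_device="cpu")  # passes silently
```

Raising instead of silently calling contiguous() inside the layer makes the extra copy visible to the caller; a layer could equally make the call itself, at the cost of a hidden copy.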

import torch
import TorchProteinLibrary
from TorchProteinLibrary import ReducedModel
a2b = ReducedModel.Angles2Backbone()
length = 40
angles = torch.randn(1, length, 3, dtype=torch.float, device='cuda')
num_aa = torch.tensor([length], dtype=torch.int, device='cuda')
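x = angles.transpose(1, 2)  # shape (1, 3, length), a non-contiguous view
out = a2b(x, num_aa)  # now fails the contiguity check instead of silently misbehaving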

Traceback (most recent call last):
  File "", line 1, in
  File "/home/lupoglaz/anaconda3/envs/tf/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/lupoglaz/Projects/MILA/TorchProteinLibrary/TorchProteinLibrary/ReducedModel/Angles2Backbone/Angles2Backbone.py", line 103, in forward
    return Angles2BackboneGPUFunction.apply(input.to(dtype=torch.float32), angles_length)
  File "/home/lupoglaz/Projects/MILA/TorchProteinLibrary/TorchProteinLibrary/ReducedModel/Angles2Backbone/Angles2Backbone.py", line 23, in forward
    _ReducedModel.Angles2BackboneGPU_forward( input, output_coords_gpu, angles_length, ctx.A)
RuntimeError: input_angles.is_contiguous() ASSERT FAILED at Layers/ReducedModel/Angles2Backbone/angles2backbone_interface.cpp:12, please report a bug to PyTorch. input_angles must be contiguous (Angles2BackboneGPU_forward at Layers/ReducedModel/Angles2Backbone/angles2backbone_interface.cpp:12)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7fbb38184dc5 in /home/lupoglaz/anaconda3/envs/tf/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: Angles2BackboneGPU_forward(at::Tensor, at::Tensor, at::Tensor, at::Tensor) + 0x7d8 (0x7fbb216685e8 in /home/lupoglaz/anaconda3/envs/tf/lib/python3.6/site-packages/TorchProteinLibrary-0.1-py3.6-linux-x86_64.egg/_ReducedModel.cpython-36m-x86_64-linux-gnu.so)
frame #2: + 0x2e384 (0x7fbb21671384 in /home/lupoglaz/anaconda3/envs/tf/lib/python3.6/site-packages/TorchProteinLibrary-0.1-py3.6-linux-x86_64.egg/_ReducedModel.cpython-36m-x86_64-linux-gnu.so)
frame #3: + 0x32eba (0x7fbb21675eba in /home/lupoglaz/anaconda3/envs/tf/lib/python3.6/site-packages/TorchProteinLibrary-0.1-py3.6-linux-x86_64.egg/_ReducedModel.cpython-36m-x86_64-linux-gnu.so)

frame #10: THPFunction_apply(_object*, _object*) + 0x691 (0x7fbb67410081 in /home/lupoglaz/anaconda3/envs/tf/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #36: __libc_start_main + 0xe7 (0x7fbb76758b97 in /lib/x86_64-linux-gnu/libc.so.6)

out = a2b(x.contiguous(), num_aa)
print(out.size())

torch.Size([1, 360])
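The printed size is consistent with a flattened layout of 3 backbone atoms per residue and 3 Cartesian coordinates per atom: 40 × 3 × 3 = 360 (the num_backbone_atoms / 3 in the earlier comment points the same way). Assuming that layout, which is inferred from the numbers rather than taken from the library documentation, the flat output can be regrouped into per-atom (x, y, z) triples. The helper names below are hypothetical.

```python
ATOMS_PER_RESIDUE = 3  # assumption: 3 backbone atoms per residue
COORDS_PER_ATOM = 3    # x, y, z

def expected_coords_len(num_residues: int) -> int:
    """Length of the flattened coordinate vector under the assumed layout."""
    return num_residues * ATOMS_PER_RESIDUE * COORDS_PER_ATOM

def to_atom_triples(flat):
    """Regroup a flat [x0, y0, z0, x1, ...] list into (x, y, z) tuples, one per atom."""
    if len(flat) % COORDS_PER_ATOM:
        raise ValueError("length must be a multiple of 3")
    return [tuple(flat[i:i + COORDS_PER_ATOM]) for i in range(0, len(flat), COORDS_PER_ATOM)]

print(expected_coords_len(40))  # 360
atoms = to_atom_triples(list(range(expected_coords_len(40))))
print(len(atoms), atoms[0])     # 120 (0, 1, 2)
```

Under the same assumption, the equivalent regrouping of the (1, 360) output tensor would be out.view(1, -1, 3), giving shape (1, 120, 3).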