lucidrains / invariant-point-attention

Implementation of Invariant Point Attention, used for coordinate refinement in the structure module of Alphafold2, as a standalone Pytorch module

In-place rotation detach not allowed #9

Closed · sidnarayanan closed this 1 year ago

sidnarayanan commented 2 years ago

Hi, this is probably highly version-dependent (I have pytorch=1.11.0, pytorch3d=0.7.0 nightly), but I thought I'd report it. Torch doesn't like the in-place detach of the rotation tensor. Full stack trace (from denoise.py):

```
Traceback (most recent call last):
  File "denoise.py", line 56, in <module>
    denoised_coords = net(
  File "/home/pi-user/miniconda3/envs/piai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pi-user/invariant-point-attention/invariant_point_attention/invariant_point_attention.py", line 336, in forward
    rotations.detach_()
RuntimeError: Can't detach views in-place. Use detach() instead. If you are using DistributedDataParallel (DDP) for training, and gradient_as_bucket_view is set as True, gradients are views of DDP buckets, and hence detach_() cannot be called on these gradients. To fix this error, please refer to the Optimizer.zero_grad() function in torch/optim/optimizer.py as the solution.
```
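For context, I think the reason is that autograd treats the rotations tensor as a view of another tensor at that point, and detach_() refuses to run in-place on views. A minimal sketch that seems to trigger the same error (assumed reproduction, not the library's code):

```python
import torch

# Assumed minimal reproduction: calling detach_() on a tensor that autograd
# tracks as a view of another tensor raises this RuntimeError.
base = torch.randn(3, 3, requires_grad=True)
rotations = base[None].expand(8, 3, 3)  # expand() returns a view of `base`

try:
    rotations.detach_()  # in-place detach of a view
except RuntimeError as err:
    print(err)  # Can't detach views in-place. Use detach() instead. ...
```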

Switching to rotations = rotations.detach() seems to behave correctly (tested in denoise.py and my own code). I'm not totally sure if this allocates a separate tensor, or just creates a new node pointing to the same data.
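For what it's worth, my understanding is that detach() returns a new tensor object that shares the same underlying storage as the original, so no extra memory is allocated; it only drops the autograd history. A quick sketch to check this (just an illustration, not from the repo):

```python
import torch

x = torch.randn(8, 3, 3, requires_grad=True)
y = x.detach()  # out-of-place detach

print(y.requires_grad)               # False -> detached from the autograd graph
print(y.data_ptr() == x.data_ptr())  # True  -> same underlying storage, no copy
```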

lucidrains commented 1 year ago

@sidnarayanan oh oops, should be fixed!