facebookresearch / 3detr

Code & Models for 3DETR - an End-to-end transformer model for 3D object detection
Apache License 2.0

Doesn't Work with Pytorch 1.10 #12

Closed stanleyshly closed 2 years ago

stanleyshly commented 2 years ago

With PyTorch 1.9, I get no errors. However, with PyTorch 1.10, I get this error:

```
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [256, 1, 256]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
```

Even though PyTorch 1.10 isn't supported, I was wondering: did any large behavior change happen between PyTorch 1.9 and 1.10? It seems odd that this error wouldn't have been raised with 1.9.
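The error's own hint points at the usual way to localize this: run the failing step under anomaly detection so PyTorch prints the forward-pass traceback of the op whose saved tensor was later modified in place. A minimal, self-contained sketch (the ReLU/`add_` pair here is illustrative, not 3DETR code):

```python
import torch

# Illustrative repro: ReLU's backward saves its output tensor, and an
# in-place edit of that output bumps its version counter, so backward()
# raises the same "modified by an inplace operation" RuntimeError.
x = torch.randn(3, requires_grad=True)
msg = ""
with torch.autograd.detect_anomaly():
    y = torch.relu(x)   # output of relu is saved for the backward pass
    y.add_(1.0)         # in-place edit of that saved tensor
    try:
        y.sum().backward()
    except RuntimeError as err:
        msg = str(err)
print(msg.splitlines()[0])
```

With anomaly detection enabled, the warning printed alongside the exception includes the forward traceback, which is how the offending layer in 3DETR was isolated below.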

imisra commented 2 years ago

Hi @stanleyshly

I haven't tested the code with any PyTorch version older than 1.5. I'm not sure why this is happening.

stanleyshly commented 2 years ago

@imisra I located the cause of the error above: an inplace operation in the dropout layer. I created a pull request that "fixes" the issue when running with PyTorch 1.10. Unfortunately, it is slower by a not-insignificant amount (1.5x to 2x time per batch). The problem appears to be with `inplace=True`; setting `inplace=False` makes it work.
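The workaround described above can be sketched as follows (module names are illustrative, not the actual 3DETR layers): an in-place dropout after an in-place activation overwrites the tensor that ReLU's backward still needs, and switching both to out-of-place avoids that at some memory/speed cost.

```python
import torch
import torch.nn as nn

x = torch.randn(4, 8, requires_grad=True)

# Problematic pattern (sketch):
#   relu = nn.ReLU(inplace=True); drop = nn.Dropout(0.5, inplace=True)
#   drop(relu(...))  # backward() raises the RuntimeError above,
#                    # because dropout mutates the tensor ReLU saved

# The fix as described: make the ops out-of-place.
relu = nn.ReLU(inplace=False)
drop = nn.Dropout(p=0.5, inplace=False)
out = drop(relu(x))
out.sum().backward()   # succeeds; gradients flow normally
print(x.grad.shape)    # torch.Size([4, 8])
```

The out-of-place version allocates an extra tensor per op, which is consistent with the slowdown reported above.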

mjlbach commented 2 years ago

This happens with the latest PyTorch (1.10), not an older version. The exact error (with `set_detect_anomaly(True)`) is (minor changes because I broke apart a chained layer call to isolate which layer was causing the issue):

```
[W python_anomaly_mode.cpp:104] Warning: Error detected in ReluBackward0. Traceback of forward call that caused the error:
  File "main.py", line 427, in <module>
    launch_distributed(args)
  File "main.py", line 415, in launch_distributed
    main(local_rank=0, args=args)
  File "main.py", line 400, in main
    do_train(
  File "main.py", line 176, in do_train
    aps = train_one_epoch(
  File "/root/Repositories/3detr/engine.py", line 89, in train_one_epoch
    outputs = model(inputs)
  File "/root/.local/mambaforge/envs/3detr/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/Repositories/3detr/models/model_3detr.py", line 332, in forward
    box_features = self.decoder(
  File "/root/.local/mambaforge/envs/3detr/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/Repositories/3detr/models/transformer.py", line 118, in forward
    output, attn = layer(output, memory, tgt_mask=tgt_mask,
  File "/root/.local/mambaforge/envs/3detr/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/Repositories/3detr/models/transformer.py", line 403, in forward
    return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
  File "/root/Repositories/3detr/models/transformer.py", line 386, in forward_pre
    tgt4 = self.activation(tgt3)
  File "/root/.local/mambaforge/envs/3detr/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.local/mambaforge/envs/3detr/lib/python3.8/site-packages/torch/nn/modules/activation.py", line 98, in forward
    return F.relu(input, inplace=self.inplace)
  File "/root/.local/mambaforge/envs/3detr/lib/python3.8/site-packages/torch/nn/functional.py", line 1299, in relu
    result = torch.relu(input)
 (function _print_stack)
Traceback (most recent call last):
  File "main.py", line 427, in <module>
    launch_distributed(args)
  File "main.py", line 415, in launch_distributed
    main(local_rank=0, args=args)
  File "main.py", line 400, in main
    do_train(
  File "main.py", line 176, in do_train
    aps = train_one_epoch(
  File "/root/Repositories/3detr/engine.py", line 101, in train_one_epoch
    loss.backward(retain_graph=True)
  File "/root/.local/mambaforge/envs/3detr/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/root/.local/mambaforge/envs/3detr/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [256, 8, 256]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
```

See some generic discussion about this: https://discuss.pytorch.org/t/solved-pytorch1-5-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/90256/17

Similar-ish issue fixed by Hugging Face: https://github.com/huggingface/transformers/pull/13613

I think it's due to this issue/PR in PyTorch 1.10: https://github.com/pytorch/pytorch/pull/63089 https://github.com/pytorch/pytorch/issues/63027
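If the cause really is the stricter check introduced in PyTorch 1.10, one way to keep the faster in-place path on older releases is to gate the `inplace` flag on the runtime version. A hedged sketch (the helper and the 1.10 threshold are my assumptions, not part of 3DETR):

```python
import torch
import torch.nn as nn

def make_activation():
    # Hypothetical helper: use the faster in-place ReLU only on
    # PyTorch < 1.10, assuming that is where the stricter saved-tensor
    # check that trips the 3DETR decoder was introduced.
    major, minor = (int(p) for p in torch.__version__.split(".")[:2])
    return nn.ReLU(inplace=(major, minor) < (1, 10))

act = make_activation()
print(act)
```

This keeps the old behavior on 1.9 while defaulting to the safe out-of-place op on 1.10 and later.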

There are also some trivial changes to pointnet required to get it to compile with the latest pytorch/cuda.

I saw @stanleyshly posted in the pytorch forums and there is some relevant discussion there: https://discuss.pytorch.org/t/inplace-errors-with-dropout-layers-with-pytorch-1-9-but-not-with-pytorch-1-10/137544