antoyang / TubeDETR

[CVPR 2022 Oral] TubeDETR: Spatio-Temporal Video Grounding with Transformers

an inplace operation #22

Open zhl98 opened 1 year ago

zhl98 commented 1 year ago

Hello, I encountered this problem during the training process. Do you know where the problem is? The dimension of the torch.cuda.LongTensor is [1, 25].

[screenshot of the error]

antoyang commented 1 year ago

Hi, can you give more context on the issue so that I can help you?

SkylerSuen commented 1 year ago

> Hello, I encountered this problem during the training process. Do you know where the problem is? The dimension of the torch.cuda.LongTensor is [1, 25].
>
> [screenshot of the error]

Hi, I have the same problem, did you figure out how to fix this?

Infinitywxh commented 11 months ago

I encountered a similar error: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [1, 13]] is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I set torch.autograd.set_detect_anomaly(True) and got the following trace:


/scratch/miniconda3/lib/python3.8/site-packages/torch/autograd/__init__.py:200: UserWarning: Error detected in EmbeddingBackward0. Traceback of forward call that caused the error:
  File "main_new.py", line 651, in <module>
    main(args)
  File "main_new.py", line 591, in main
    train_stats = train_one_epoch(
  File "/scratch/TubeDETR-main/engine.py", line 67, in train_one_epoch
    memory_cache = model(
  File "/scratch/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/miniconda3/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/scratch/miniconda3/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/scratch/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/TubeDETR-main/models/tubedetr.py", line 190, in forward
    memory_cache = self.transformer(
  File "/scratch/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/TubeDETR-main/models/transformer.py", line 256, in forward
    encoded_text = self.text_encoder(**tokenized)
  File "/scratch/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/miniconda3/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py", line 828, in forward
    embedding_output = self.embeddings(
  File "/scratch/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/miniconda3/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py", line 126, in forward
    token_type_embeddings = self.token_type_embeddings(token_type_ids)
  File "/scratch/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/miniconda3/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/scratch/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "main_new.py", line 651, in <module>
    main(args)
  File "main_new.py", line 591, in main
    train_stats = train_one_epoch(
  File "/scratch/TubeDETR-main/engine.py", line 148, in train_one_epoch
    losses.backward()
  File "/scratch/miniconda3/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/scratch/miniconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [1, 13]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
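
For reference, anomaly detection only needs to be enabled once, before the training loop starts; a minimal sketch (the exact placement in main.py is an assumption):

```python
import torch

# Enable anomaly detection before training so that backward() reports the
# forward-pass traceback of the operation that failed, as in the trace
# above, instead of only the generic RuntimeError.
torch.autograd.set_detect_anomaly(True)
```

Note that anomaly detection slows training down noticeably, so it is best disabled again once the offending operation has been located.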

furkancoskun commented 5 months ago

I encountered the same problem. Has anyone solved it?

furkancoskun commented 5 months ago

I solved the problem by adding broadcast_buffers=False to torch.nn.parallel.DistributedDataParallel.

Change main.py line 373 as follows:

        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            broadcast_buffers=False,  # stop DDP from re-broadcasting buffers on every forward
        )
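
This fix is consistent with the trace above: with broadcast_buffers=True (the default), DDP re-broadcasts every module buffer from rank 0 at the start of each forward call, and the Hugging Face RoBERTa embeddings register token_type_ids as exactly the kind of LongTensor buffer named in the error. Since autograd saves those indices for EmbeddingBackward0, an in-place overwrite of the buffer between forward and backward bumps its version counter and triggers the mismatch. A minimal sketch of the same mechanism, with no DDP or TubeDETR code involved:

```python
import torch
import torch.nn as nn

# F.embedding saves its index tensor for the backward pass, so an in-place
# write to the indices between forward and backward bumps their version
# counter and raises the same error as in the trace above.
emb = nn.Embedding(10, 4)
token_type_ids = torch.zeros(1, 13, dtype=torch.long)  # mimics the [1, 13] buffer

loss = emb(token_type_ids).sum()
token_type_ids.fill_(1)  # stands in for DDP's in-place buffer broadcast
loss.backward()  # RuntimeError: ... modified by an inplace operation
```

The trade-off of broadcast_buffers=False is that buffers such as BatchNorm running statistics are no longer synchronized from rank 0 on each forward; DETR-style models typically freeze their backbone's BatchNorm layers, so this is usually harmless here, but it is worth checking for other stateful buffers in your setup.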