fundamentalvision / Deformable-DETR

Deformable DETR: Deformable Transformers for End-to-End Object Detection.

Mixed precision #64

Open Lg955 opened 3 years ago

Lg955 commented 3 years ago

I want to use mixed precision (from torch.cuda.amp import autocast, GradScaler) when training the model, but I get the following error:

File "/dataset/Deformable_DETR_mix/models/deformable_transformer.py", line 221, in forward
    src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
  File "/home/anaconda3/envs/detr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/dataset/Deformable_DETR_mix/models/ops/modules/ms_deform_attn.py", line 113, in forward
    value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
  File "/dataset/Deformable_DETR_mix/models/ops/functions/ms_deform_attn_func.py", line 26, in forward
    value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
RuntimeError: "ms_deform_attn_forward_cuda" not implemented for 'Half'

I want to modify it, but it involves CUDA code. Has anyone encountered the same problem?
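
For reference, the setup described above is the standard torch.cuda.amp recipe, roughly like the sketch below (build_deformable_detr, build_criterion, and data_loader are hypothetical placeholders, not names from this repo). Under autocast the attention module receives float16 tensors, which is what triggers the "not implemented for 'Half'" error in the custom kernel.

    import torch
    from torch.cuda.amp import autocast, GradScaler

    # Hypothetical placeholders: build_deformable_detr, build_criterion and
    # data_loader stand in for the real model, loss and loader used in training.
    model = build_deformable_detr().cuda()
    criterion = build_criterion().cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    scaler = GradScaler()

    for samples, targets in data_loader:
        optimizer.zero_grad()
        with autocast():                        # eligible ops run in float16
            outputs = model(samples)            # MSDeformAttn receives Half tensors here
            loss = criterion(outputs, targets)
        scaler.scale(loss).backward()           # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimizer)
        scaler.update()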

gautamsreekumar commented 3 years ago

Yes, I tried the same thing. I found here that you can change AT_DISPATCH_FLOATING_TYPES to AT_DISPATCH_FLOATING_TYPES_AND_HALF in the CUDA source, but now I get a different error:

File "ms_deform_attn_func.py", line 26, in forward
    value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
RuntimeError: expected scalar type Half but found Float

noahcao commented 3 years ago

I'm having the same problem. Has anyone fixed this to support PyTorch mixed-precision training?

Lg955 commented 3 years ago

@noahcao Sorry, I haven't fixed it.

Lg955 commented 3 years ago

@gautamsreekumar I get the same error; I have given up on it …>_<…

zyong812 commented 2 years ago

An easy way is to disable mixed precision for the custom operation; see @custom_fwd and @custom_bwd in https://pytorch.org/docs/stable/notes/amp_examples.html.
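
Concretely, that looks roughly like the sketch below. It is only a sketch, assuming the argument order shown in the traceback above and that the compiled extension is imported as MultiScaleDeformableAttention (MSDA) with ms_deform_attn_forward / ms_deform_attn_backward entry points, as in the repo's ms_deform_attn_func.py; details may differ slightly from the actual file. With custom_fwd(cast_inputs=torch.float32), autocast casts the incoming Half tensors back to float32 before they reach the unmodified CUDA kernel:

    import torch
    from torch.autograd import Function
    from torch.cuda.amp import custom_fwd, custom_bwd

    import MultiScaleDeformableAttention as MSDA  # the repo's compiled CUDA extension

    class MSDeformAttnFunction(Function):
        @staticmethod
        @custom_fwd(cast_inputs=torch.float32)  # under autocast, floating CUDA tensors are cast back to float32
        def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                    sampling_locations, attention_weights, im2col_step):
            ctx.im2col_step = im2col_step
            output = MSDA.ms_deform_attn_forward(
                value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, ctx.im2col_step)
            ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index,
                                  sampling_locations, attention_weights)
            return output

        @staticmethod
        @custom_bwd  # runs backward with the same autocast state as forward
        def backward(ctx, grad_output):
            (value, value_spatial_shapes, value_level_start_index,
             sampling_locations, attention_weights) = ctx.saved_tensors
            grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
                value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, grad_output.contiguous(),
                ctx.im2col_step)
            return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None

The custom op then always runs in float32 and returns float32 output, while autocast handles the surrounding ops as usual, so the rest of the training loop (GradScaler included) needs no changes.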

Wastoon commented 1 year ago

Has anyone been able to solve this hard problem? (2023-04-12) Help!!!!