IDEA-Research / detrex

detrex is a research platform for DETR-based object detection, segmentation, pose estimation and other visual recognition tasks.
https://detrex.readthedocs.io/en/latest/
Apache License 2.0
2.01k stars 209 forks

[Features] Support FP16 training #198

Closed rentainhe closed 1 year ago

rentainhe commented 1 year ago

TODO

Note

For MultiScaleDeformableAttention, we simply cast the input value to torch.float32 before calling the operator and cast its output from torch.float32 back to torch.float16. In other words, the MultiScaleDeformableAttention operator skips fp16 and always computes in fp32.
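
The cast-in, cast-out pattern described in this note can be sketched roughly as follows. This is only an illustration, not the actual detrex code: `run_in_fp32` and `toy_attention_op` are hypothetical names, and the toy op is a plain weighted sum standing in for the real deformable attention kernel so that the snippet runs on CPU.

```python
import torch


def run_in_fp32(op, value, *args):
    """Call `op` on a float32 copy of `value`, then cast the result back.

    Hypothetical helper: it mirrors the idea in the note (fp32 inside the
    operator, original dtype outside), not the real detrex implementation.
    """
    orig_dtype = value.dtype
    out = op(value.float(), *args)  # the operator itself only sees float32
    return out.to(orig_dtype)       # e.g. back to torch.float16 under AMP


# Records the dtypes the stand-in operator actually receives.
seen_dtypes = []


def toy_attention_op(value, weights):
    """Stand-in for the real operator (assumption): a weighted sum over dim 1."""
    seen_dtypes.append(value.dtype)
    return (value * weights.to(value.dtype)).sum(dim=1)


value = torch.randn(2, 4, 8).half()                   # fp16 activations
weights = torch.softmax(torch.randn(2, 4, 1), dim=1)  # fp32 attention weights
out = run_in_fp32(toy_attention_op, value, weights)
```

Here `out` has the same dtype as `value`, while `seen_dtypes` shows that the operator body only ever received float32 tensors.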

Usage

Start fp16 training by setting train.amp.enabled=True:

python tools/train_net.py \
    --config-file projects/dab_detr/configs/dab_detr_r50_50ep.py \
    --num-gpus 8 \
    train.amp.enabled=True

FabianSchuetze commented 1 year ago

Thanks for this commit!

Did you observe instabilities when using the deformable attention layer with fp16? Is there another reason why the deformable attention layer cannot be used with fp16?