fundamentalvision / Deformable-DETR

Deformable DETR: Deformable Transformers for End-to-End Object Detection.
Apache License 2.0

Does MSDeformable support torch.compile? #226

Open Lwzzzzzz opened 7 months ago

Lwzzzzzz commented 7 months ago

Initially, I tried whole-network capture with torch.cuda.graph, but I ran into problems and could not get it to run.

So I switched to partial-network capture of individual modules. However, as soon as I applied torch.cuda.make_graphed_callables() to a module containing MSDeformableAttention, an error occurred. Since that was not feasible, I had to manually exclude the module containing MSDeformableAttention from capture.

Recently, PyTorch introduced torch.compile, which is said to skip such modules automatically, but I still encountered errors. Is there a way to make MSDeformableAttention support torch.compile?