Initially, I tried whole-network capture with `torch.cuda.graph`, but I ran into problems and could not get it to run.
So I switched to partial-network capture of individual modules. However, as soon as I applied `torch.cuda.make_graphed_callables()` to a module containing `MSDeformableAttention`, an error occurred. Since graphing that module was not feasible, I had to manually bypass it and graph only the remaining modules.
Recently, PyTorch introduced `torch.compile`, which is said to automatically avoid such modules, but I still encountered errors. Is there a way to make `MSDeformableAttention` support `torch.compile`?