Open nikitaved opened 3 months ago
CUDA Graphs are inherently restricted to static code. Newer CUDA Toolkit versions do support dynamic control flow via conditional nodes, but that feature is not yet usable from PyTorch: https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
> What additional safety does `torch.cuda.make_graphed_callables` provide?
No additional safety; I got it wrong. What I meant is that we could make it a transform that decides which parts of the code are safe to capture.
triage review:
@tfogal, the executor is in good shape, just not complete. There is no "advanced" logic yet for handling data-dependent operations or for forming fusion regions between graph breaks (with dynamic shapes we can sometimes put a tensor into a fusion, and sometimes we have to opt out). This was intentionally left in this state so we could collect issues and shape our understanding of how to handle them in our use cases. To summarize: the original issue is still present, but we can fix it, at least partially, by modifying the fusion logic in the current executor.
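To illustrate the kind of graph-break logic described above, here is a hypothetical sketch (invented names, not Thunder's actual API) of partitioning a linear trace into capturable regions, breaking at data-dependent operations that a replayed graph would get wrong:

```python
# Hypothetical sketch (not Thunder's actual API): split a trace of ops into
# CUDA-graph-capturable regions, breaking at data-dependent ops whose results
# can change between replays (e.g. `.item()`, data-dependent output shapes).

DATA_DEPENDENT = {"item", "nonzero", "masked_select"}  # illustrative set only

def partition_trace(ops):
    """Split `ops` (a list of op-name strings) into (capturable?, region) pairs."""
    regions, current = [], []
    for op in ops:
        if op in DATA_DEPENDENT:
            if current:
                regions.append((True, current))
                current = []
            # Data-dependent op: must run eagerly, outside any captured graph.
            regions.append((False, [op]))
        else:
            current.append(op)
    if current:
        regions.append((True, current))
    return regions

print(partition_trace(["matmul", "relu", "nonzero", "add"]))
# [(True, ['matmul', 'relu']), (False, ['nonzero']), (True, ['add'])]
```

A real implementation would additionally have to decide, per graph break, whether the surrounding ops are still worth fusing, which is the part the current executor leaves unhandled.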
🐛 Bug
With https://github.com/Lightning-AI/lightning-thunder/pull/430 merged, we enable `thunder.jit(..., use_cudagraphs=True)`. This wraps the forward callable in `thunder.cudagraphs.CUDAGraphExecutor`. That executor, however, assumes the code has a static structure. We might consider `torch.cuda.make_graphed_callables` as a safer option.
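For reference, a minimal sketch of how `torch.cuda.make_graphed_callables` (a real PyTorch API) would be used, with an eager fallback on machines without CUDA; the `maybe_graph` helper is invented for illustration:

```python
import torch

def maybe_graph(module, sample_args, warmup_iters=3):
    """Wrap `module` with CUDA graph capture when possible, else return it unchanged.

    `maybe_graph` is a hypothetical helper; torch.cuda.make_graphed_callables
    runs a few warmup iterations and then captures forward/backward into CUDA
    graphs, erroring out on capture-unsafe ops instead of silently replaying
    stale control flow.
    """
    if not torch.cuda.is_available():
        return module  # eager fallback on machines without CUDA
    module = module.cuda()
    sample_args = tuple(a.cuda() for a in sample_args)
    return torch.cuda.make_graphed_callables(
        module, sample_args, num_warmup_iters=warmup_iters
    )

model = torch.nn.Linear(8, 8)
graphed = maybe_graph(model, (torch.randn(2, 8),))
x = torch.randn(2, 8)
if torch.cuda.is_available():
    x = x.cuda()
print(graphed(x).shape)  # torch.Size([2, 8])
```

Note that `make_graphed_callables` requires static input shapes for the captured callables, so it would not by itself solve the dynamic-shape cases discussed above; its value is in failing loudly rather than capturing unsafe code.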