Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0

improve shape accuracy in transform output by providing better tooling #1179

Open t-vi opened 6 days ago

t-vi commented 6 days ago

Some transforms, notably the FSDP and TensorParallel ones, change shapes but currently do not update them completely (they do for the linear that follows, but not for the activation, etc.). We might consider making it easy for a transform to update the shapes, either on the fly or at the end of the transform, by providing a mechanism similar to or based on the interpret_trace_to_trace introduced in #1164 (which would then become stricter by default).

cc: @crcrpar

cc @carmocca @borda
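
To make the idea in the issue description concrete, here is a minimal, self-contained sketch of re-propagating shapes through a trace after a transform has changed an input shape. It does not use Thunder's API: `Proxy`, `Op`, `infer_shape`, and `repropagate_shapes` are hypothetical names invented for illustration, and the two-op "trace" is a toy stand-in for a real one. The `strict` flag mimics the "stricter by default" behavior mentioned above by raising on stale shapes instead of silently fixing them.

```python
from dataclasses import dataclass


@dataclass
class Proxy:
    """Hypothetical stand-in for a tensor proxy: a name plus recorded shape."""
    name: str
    shape: tuple


@dataclass
class Op:
    """Hypothetical stand-in for one operation recorded in a trace."""
    kind: str      # "linear" or "relu" in this toy example
    inputs: list   # names of the input proxies
    output: Proxy  # output proxy with its recorded (possibly stale) shape


def infer_shape(kind, in_shapes):
    """Toy shape rules; weights use the (out_features, in_features) layout."""
    if kind == "linear":
        x, w = in_shapes
        if x[-1] != w[1]:
            raise ValueError(f"in_features mismatch: {x} vs {w}")
        return (*x[:-1], w[0])
    if kind == "relu":
        (x,) = in_shapes
        return x
    raise ValueError(f"unknown op kind: {kind}")


def repropagate_shapes(inputs, ops, strict=False):
    """Walk the ops in order and recompute every output shape.

    With strict=True, a recorded shape that disagrees with the recomputed
    one raises instead of being corrected.
    """
    env = {p.name: p.shape for p in inputs}
    new_ops = []
    for op in ops:
        shape = infer_shape(op.kind, [env[n] for n in op.inputs])
        if strict and shape != op.output.shape:
            raise ValueError(
                f"{op.output.name}: recorded {op.output.shape}, inferred {shape}"
            )
        env[op.output.name] = shape
        new_ops.append(Op(op.kind, list(op.inputs), Proxy(op.output.name, shape)))
    return new_ops


# A weight of shape (32, 16) sharded column-wise over 2 devices becomes (16, 16).
# The transform updated the linear's output but left the activation stale.
trace_inputs = [Proxy("x", (4, 16)), Proxy("w", (16, 16))]
stale_ops = [
    Op("linear", ["x", "w"], Proxy("t0", (4, 16))),
    Op("relu", ["t0"], Proxy("t1", (4, 32))),  # stale: should be (4, 16)
]

fixed_ops = repropagate_shapes(trace_inputs, stale_ops)
print([op.output.shape for op in fixed_ops])
```

Running the same call with `strict=True` raises a `ValueError` that names `t1`, which is roughly how a stricter re-interpretation pass could surface transforms that leave shapes out of date.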

tfogal commented 6 days ago

Not updating the shapes is odd; curious why this is an 'enhancement' vs. a real bug. Marking triage review to discuss!

t-vi commented 6 days ago

I don't object to having it labeled a bug, but I don't think users will typically hit it today: for things needing gradients, the augmented forward pass will fix it. I'm not sure that Tensor Parallel sees heavy use in inference just yet. Among the things I found while doing #1164, it seems one of the milder issues.

crcrpar commented 22 hours ago

Some transforms, notably and TensorParallel ones, change shapes, but currently do not completely update them (it does for the linear that follows, but not for the activation etc.).

I thought the logic implemented with visitor_transform in https://github.com/Lightning-AI/lightning-thunder/blob/f70b0fe22d27a2e2a0a6ac31cf2c52c9d4e4e35d/thunder/distributed/tensor_parallel/common.py#L118, more specifically https://github.com/Lightning-AI/lightning-thunder/blob/f70b0fe22d27a2e2a0a6ac31cf2c52c9d4e4e35d/thunder/distributed/tensor_parallel/common.py#L145, updates shapes automatically, but it doesn't?

t-vi commented 18 hours ago

Yeah, @crcrpar, this is why I mentioned better tooling; maybe using the visitor transform pattern more is the solution. Looking at the code, I would probably have the same expectation as you, but if I take out this

https://github.com/Lightning-AI/lightning-thunder/blob/f70b0fe22d27a2e2a0a6ac31cf2c52c9d4e4e35d/thunder/core/trace_interpreter.py#L120-L127

there seem to be inconsistencies. Maybe it is some other part that does not update shapes fully (and certainly anything I wrote will have the problem).