Open t-vi opened 6 days ago
Not updating the shapes is odd; curious why this is an 'enhancement' vs. a real bug. Marking triage review to discuss!
I don't object to having it labeled as a bug, but I don't think users will typically hit it today: for things needing gradients, the augmented forward pass will fix it, and I'm not sure that Tensor Parallel sees heavy use in inference just yet. Among the things I found while doing #1164, it seems one of the milder issues.
I thought the logic implemented with visitor_transform in https://github.com/Lightning-AI/lightning-thunder/blob/f70b0fe22d27a2e2a0a6ac31cf2c52c9d4e4e35d/thunder/distributed/tensor_parallel/common.py#L118, more specifically https://github.com/Lightning-AI/lightning-thunder/blob/f70b0fe22d27a2e2a0a6ac31cf2c52c9d4e4e35d/thunder/distributed/tensor_parallel/common.py#L145, updates shapes automatically. But it doesn't?
Yeah, @crcrpar, this is why I mentioned better tooling; maybe using the visitor transform pattern more is the solution. Looking at the code, I would probably have the same expectation as you, but if I take out this
there seem to be inconsistencies. Maybe it is some other part that does not update shapes fully (and certainly anything I wrote will have the problem).
Some transforms, notably the FSDP and TensorParallel ones, change shapes, but currently do not completely update them (they do for the linear that follows, but not for the activation etc.). We might consider making it easy for the transform to update the shapes, either on the fly or at the end of the transform, by providing a mechanism similar to / based on the interpret_trace_to_trace introduced in #1164 (which would then become stricter by default).
cc: @crcrpar
cc @carmocca @borda
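To make the proposal in the issue description concrete, here is a minimal, self-contained sketch of "re-derive shapes by replaying the trace after a transform". It is a toy model, not thunder code: `Op`, `infer_shape`, `reinterpret_shapes`, and `column_parallel` are hypothetical names made up for illustration, and the real interpret_trace_to_trace from #1164 may work quite differently.

```python
from dataclasses import dataclass


# Toy stand-in for a trace entry; none of these names are thunder APIs.
@dataclass
class Op:
    name: str              # "linear" or "gelu"
    inp: str               # name of the input tensor
    out: str               # name of the output tensor
    out_features: int = 0  # only used by "linear"


def infer_shape(op, in_shape):
    """Shape rule for each toy op."""
    if op.name == "linear":
        return in_shape[:-1] + (op.out_features,)
    if op.name == "gelu":  # elementwise: shape is unchanged
        return in_shape
    raise ValueError(f"unknown op {op.name}")


def reinterpret_shapes(ops, input_shapes):
    """Replay the ops in order and recompute every output shape from scratch."""
    shapes = dict(input_shapes)
    for op in ops:
        shapes[op.out] = infer_shape(op, shapes[op.inp])
    return shapes


def column_parallel(ops, shapes, world_size):
    """Shard each linear's output dim, but (deliberately) patch only the
    linear's own output shape, leaving downstream shapes stale."""
    new_ops, new_shapes = [], dict(shapes)
    for op in ops:
        if op.name == "linear":
            op = Op(op.name, op.inp, op.out, op.out_features // world_size)
            new_shapes[op.out] = infer_shape(op, new_shapes[op.inp])
        new_ops.append(op)
    return new_ops, new_shapes


ops = [Op("linear", "x", "h", 8), Op("gelu", "h", "a")]
shapes = reinterpret_shapes(ops, {"x": (4, 16)})

tp_ops, stale = column_parallel(ops, shapes, world_size=2)
fixed = reinterpret_shapes(tp_ops, {"x": (4, 16)})

print(stale)  # the activation output "a" still carries the unsharded shape
print(fixed)  # replaying the transformed ops makes "a" consistent with "h"
```

The design choice illustrated here is the "at the end of the transform" variant: the transform itself stays simple, and a single replay pass afterwards restores consistency. A stricter version could compare the replayed shapes against the recorded ones and raise on mismatch instead of silently overwriting them.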