Open crcrpar opened 5 months ago
Hi Masaki! Thank you for this!
Marking the origin of instructions is indeed a topic we are keen to improve on. We recently merged #40 to add source locations to the initial trace. The next step is for @IvanYashchuk to look into propagating these, and your proposal (as I understand it) of marking instructions such as the ones inserted by distributed as originating from there would be a great complement, making it easier to track the source of instructions from user code through the interpretation phase.
So, right after column_parallel is applied, a trace looks like this:
# Constructed by transform into column-wise tensor parallel
import thunder
import thunder.distributed.prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(x, t_net1_bias, t_net1_weight, t_net2_bias, t_net2_weight):
# x: "cuda:0 f32[1, 12]"
# t_net1_bias: "cuda:0 f32[8]"
# t_net1_weight: "cuda:0 f32[8, 12]"
# t_net2_bias: "cuda:0 f32[8]"
# t_net2_weight: "cuda:0 f32[8, 16]"
t18 = thunder.distributed.prims.synchronize_tensor_parallel_input(x, _torch_distributed_distributed_c10d_ProcessGroup_0, _TensorParallelLayerType_1) # t18: "cuda:0 f32[1, 12]"
# /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:116: return F.linear(input, self.weight, self.bias)
b = ltorch.linear(t18, t_net1_weight, t_net1_bias) # b: "cuda:0 f32[1, 16]"
# b = prims.linear(t18, t_net1_weight, t_net1_bias) # b: "cuda:0 f32[1, 16]"
t19 = thunder.distributed.prims.synchronize_tensor_parallel_output(b, _torch_distributed_distributed_c10d_ProcessGroup_0, _TensorParallelLayerType_1) # t19: "cuda:0 f32[1, 32]"
# /opt/pytorch/lightning-thunder/thunder/tests/distributed/helper.py:28: return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
result = ltorch.mul(0.5, t19) # result: "cuda:0 f32[1, 16]"
# result = prims.mul(0.5, t19) # result: "cuda:0 f32[1, 16]"
t7 = ltorch.pow(t19, 3.0) # t7: "cuda:0 f32[1, 16]"
# t7 = prims.pow(t19, 3.0) # t7: "cuda:0 f32[1, 16]"
t8 = ltorch.mul(0.044715, t7) # t8: "cuda:0 f32[1, 16]"
# t8 = prims.mul(0.044715, t7) # t8: "cuda:0 f32[1, 16]"
t9 = ltorch.add(t19, t8, alpha=None) # t9: "cuda:0 f32[1, 16]"
# t9 = prims.add(t19, t8) # t9: "cuda:0 f32[1, 16]"
t10 = ltorch.mul(0.7978845608028654, t9) # t10: "cuda:0 f32[1, 16]"
# t10 = prims.mul(0.7978845608028654, t9) # t10: "cuda:0 f32[1, 16]"
t11 = ltorch.tanh(t10) # t11: "cuda:0 f32[1, 16]"
# t11 = prims.tanh(t10) # t11: "cuda:0 f32[1, 16]"
t12 = ltorch.add(1.0, t11, alpha=None) # t12: "cuda:0 f32[1, 16]"
# t12 = prims.add(1.0, t11) # t12: "cuda:0 f32[1, 16]"
input = ltorch.mul(result, t12) # input: "cuda:0 f32[1, 16]"
# input = prims.mul(result, t12) # input: "cuda:0 f32[1, 16]"
# /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:116: return F.linear(input, self.weight, self.bias)
t17 = ltorch.linear(input, t_net2_weight, t_net2_bias) # t17: "cuda:0 f32[1, 8]"
# t17 = prims.linear(input, t_net2_weight, t_net2_bias) # t17: "cuda:0 f32[1, 8]"
return t17
The bsym producing t19 does not have the source location info, as you explain. We'll need to decide whether a source location is the right format for this, but calling trace.set_current_source_location(filename, positions) would add it (you need to reset it to None, None afterwards), or we could add a context manager for doing so when binding a new symbol.
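For instance, a minimal sketch of such a context manager, assuming the trace object exposes set_current_source_location(filename, positions) as described above (the helper name source_location is hypothetical):

from contextlib import contextmanager

@contextmanager
def source_location(trace, filename, positions):
    # Sketch only: bsyms bound inside the with-block would pick up this location;
    # "resetting" is assumed to mean passing (None, None), as mentioned above.
    trace.set_current_source_location(filename, positions)
    try:
        yield
    finally:
        trace.set_current_source_location(None, None)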
I don't think the source location is the right format here. It should do one thing, and I think that should be showing the original location used to construct the initial trace.
There's also a BoundSymbol.header that allows printing anything above the BoundSymbol in the generated Python code:
https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/core/symbol.py#L305-L306
When using BoundSymbol.__call__ to construct symbols, a special context manager can be used to set the header:
https://github.com/Lightning-AI/lightning-thunder/blob/77d0fbd328d56d971ee6234c5593de979bf6c887/thunder/core/symbol.py#L54-L55
There is no example code yet for how best to preserve header or source location information. Masaki, do you have ideas for this?
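For illustration only, here is a rough sketch (not existing code) of what labeling tensor-parallel bsyms via headers could look like; apart from BoundSymbol.header itself, the attribute names (bound_symbols, sym, name) are assumptions about the trace structure:

# Hypothetical sketch: attach a human-readable header to bsyms inserted by the
# tensor-parallel transform so it gets printed above them in the generated code.
TP_SYM_NAMES = {
    "synchronize_tensor_parallel_input",
    "synchronize_tensor_parallel_output",
}

def label_tensor_parallel_bsyms(trace):
    for bsym in trace.bound_symbols:  # assumed: the trace exposes its bound symbols as a list
        if bsym.sym.name in TP_SYM_NAMES:  # assumed: a bsym references its Symbol via .sym
            bsym.header = "tensor-parallel pre/post-processing inserted by the transform"
    return trace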
"source location" for bsyms added by visit_transform
or whatever doesn't feel ideal.
I want more like a mechanism that calls a sequence of torch ops (if fusion executors fail to optimize) in a bsym with the name of a primitive especially when looking at this block
p16 = torch_all_reduce_prim_impl(t15, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True, True) # p16: "FUTURE cuda:0 f32[1, 8]"
del t15
t17 = torch_wait_prim_impl(p16) # t17: "cuda:0 f32[1, 8]"
del p16
t29 = torch.unsqueeze(t_net2_bias, 0) # t29: "cuda:0 f32[1, 8]"
# t29 = ltorch.unsqueeze(t_net2_bias, 0) # t29: "cuda:0 f32[1, 8]"
# t29 = prims.broadcast_in_dim(t_net2_bias, [1, 8], [1]) # t29: "cuda:0 f32[1, 8]"
t18 = Tensor.expand(t29, (1, 8)) # t18: "cuda:0 f32[1, 8]"
# t18 = ltorch.expand(t29, (1, 8)) # t18: "cuda:0 f32[1, 8]"
# t18 = prims.broadcast_in_dim(t29, (1, 8), (0, 1)) # t18: "cuda:0 f32[1, 8]"
del t29
[t19] = nvFusion1(t17, t18)
# t19 = prims.add(t17, t18) # t19: "cuda:0 f32[1, 8]"
This would be much clearer if it read t19 = synchronize_tensor_parallel_output(...), with subsymbols defining what it does.
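For concreteness, a mock-up of the rendering being asked for here, in the same style as the nvFusion1 entry above (this is not output produced by thunder today, and the argument list is illustrative):

t19 = synchronize_tensor_parallel_output(t15, t_net2_bias, _torch_distributed_distributed_c10d_ProcessGroup_0) # t19: "cuda:0 f32[1, 8]"
  # p16 = torch_all_reduce_prim_impl(t15, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True, True)
  # t17 = torch_wait_prim_impl(p16)
  # t29 = torch.unsqueeze(t_net2_bias, 0)
  # t18 = Tensor.expand(t29, (1, 8))
  # t19 = prims.add(t17, t18)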
🚀 Feature
When we run the following snippet with torchrun --nproc-per-node=2 sample.py, we get the trace below, as of c89fb1bc5114161aed78e9eb19b8f68df523d192. In this trace, some BoundSymbols are derived from tensor-parallel pre/post-processing. In my humble opinion, the trace would be easier to interpret if the corresponding bsyms had a label or a name indicating that they are pre/post-processing. One example would be:
If it were easy to guess that these lines are the post-processing of a row-wise parallel linear, I'd say the trace would be more perspicuous. An alternative is that we call a method named row_linear_postprocessing in a visitor_transform and a bsym would just call it.

Motivation
I've long wanted to have a BoundSymbol call a sequence of primitive ops without decomposition or executor-specific optimizations, mainly for the sake of readability and clarity in a trace.