Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

[RFC] Option to make a trace easier to interpret #507

Open crcrpar opened 5 months ago

crcrpar commented 5 months ago

🚀 Feature

When we run the following snippet with torchrun --nproc-per-node=2 sample.py, we get the trace below (as of c89fb1bc5114161aed78e9eb19b8f68df523d192).

import os

import torch
from torch.distributed import distributed_c10d as c10d

import thunder
from thunder.tests.distributed.helper import ToyModel
from thunder.distributed import column_parallel, row_parallel

def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")
    c10d.init_process_group()

    model = ToyModel(bias=True).to(device)
    if local_rank == 0:
        print("# Convert `net1` and `net2` into column-wise parallel and row-wise parallel, respectively")
        print(model)
    tp_model = thunder.jit(model)
    tp_model = column_parallel(tp_model, ["net1"])
    tp_model = row_parallel(tp_model, ["net2"])

    x = torch.randn((1, ToyModel.N_IN), device=device)
    tp_model(x)

    if local_rank == 0:
        fw_extrace = thunder.last_traces(tp_model)[-1]

        print(fw_extrace)

    c10d.destroy_process_group()

if __name__ == "__main__":
    main()
# Convert `net1` and `net2` into column-wise parallel and row-wise parallel, respectively
# ToyModel(
#   (net1): Linear(in_features=12, out_features=16, bias=True)
#   (net2): Linear(in_features=16, out_features=8, bias=True)
# )

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from torch import Tensor
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(x, t_net1_bias, t_net1_weight, t_net2_bias, t_net2_weight):
  # x: "cuda:0 f32[1, 12]"
  # t_net1_bias: "cuda:0 f32[8]"
  # t_net1_weight: "cuda:0 f32[8, 12]"
  # t_net2_bias: "cuda:0 f32[8]"
  # t_net2_weight: "cuda:0 f32[8, 8]"
  t0 = torch.nn.functional.linear(x, t_net1_weight, t_net1_bias)  # t0: "cuda:0 f32[1, 8]"
    # t0 = ltorch.linear(x, t_net1_weight, t_net1_bias)  # t0: "cuda:0 f32[1, 8]"
      # t0 = prims.linear(x, t_net1_weight, t_net1_bias)  # t0: "cuda:0 f32[1, 8]"
  p1 = torch_all_gather_prim_impl(t0, _torch_distributed_distributed_c10d_ProcessGroup_0, True, 0)  # p1: "FUTURE cuda:0 f32[2, 8]"
  del t0
  t2 = torch_wait_prim_impl(p1)  # t2: "cuda:0 f32[2, 8]"
  del p1
  (t3, t4) = torch.chunk(t2, 2, 0)
    # (t3, t4) = ltorch.chunk(t2, 2, 0)
      # t3 = prims.slice_prim(t2, [0, 0], [1, 8], [1, 1])  # t3: "cuda:0 f32[1, 8]"
      # t4 = prims.slice_prim(t2, [1, 0], [2, 8], [1, 1])  # t4: "cuda:0 f32[1, 8]"
  del t2
  [t13, t5] = nvFusion0(t3, t4)
    # t5 = prims.cat((t3, t4), -1)  # t5: "cuda:0 f32[1, 16]"
    # t6 = prims.mul(0.5, t5)  # t6: "cuda:0 f32[1, 16]"
    # t7 = prims.pow(t5, 3.0)  # t7: "cuda:0 f32[1, 16]"
    # t8 = prims.mul(0.044715, t7)  # t8: "cuda:0 f32[1, 16]"
    # t9 = prims.add(t5, t8)  # t9: "cuda:0 f32[1, 16]"
    # t10 = prims.mul(0.7978845608028654, t9)  # t10: "cuda:0 f32[1, 16]"
    # t11 = prims.tanh(t10)  # t11: "cuda:0 f32[1, 16]"
    # t12 = prims.add(1.0, t11)  # t12: "cuda:0 f32[1, 16]"
    # t13 = prims.mul(t6, t12)  # t13: "cuda:0 f32[1, 16]"
  del t3, t4
  t14 = torch_slice_prim_impl(t13, [0, 0], [1, 8], [1, 1])  # t14: "cuda:0 f32[1, 8]"
  del t13
  t15 = torch.nn.functional.linear(t14, t_net2_weight, None)  # t15: "cuda:0 f32[1, 8]"
    # t15 = ltorch.linear(t14, t_net2_weight, None)  # t15: "cuda:0 f32[1, 8]"
      # t15 = prims.linear(t14, t_net2_weight, None)  # t15: "cuda:0 f32[1, 8]"
  p16 = torch_all_reduce_prim_impl(t15, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True, True)  # p16: "FUTURE cuda:0 f32[1, 8]"
  del t15
  t17 = torch_wait_prim_impl(p16)  # t17: "cuda:0 f32[1, 8]"
  del p16
  t29 = torch.unsqueeze(t_net2_bias, 0)  # t29: "cuda:0 f32[1, 8]"
    # t29 = ltorch.unsqueeze(t_net2_bias, 0)  # t29: "cuda:0 f32[1, 8]"
      # t29 = prims.broadcast_in_dim(t_net2_bias, [1, 8], [1])  # t29: "cuda:0 f32[1, 8]"
  t18 = Tensor.expand(t29, (1, 8))  # t18: "cuda:0 f32[1, 8]"
    # t18 = ltorch.expand(t29, (1, 8))  # t18: "cuda:0 f32[1, 8]"
      # t18 = prims.broadcast_in_dim(t29, (1, 8), (0, 1))  # t18: "cuda:0 f32[1, 8]"
  del t29
  [t19] = nvFusion1(t17, t18)
    # t19 = prims.add(t17, t18)  # t19: "cuda:0 f32[1, 8]"
  del t17, t18
  return {'output': t19, 'flat_args': [x, t_net1_bias, t_net1_weight, t_net2_bias, t_net2_weight], 'flat_output': (t19,)}, ((t14, t5, t_net2_weight, x), (0.5, 3.0, 0.044715, 0.7978845608028654))

In this trace, some BoundSymbols are derived from tensor-parallel pre/post-processing. In my humble opinion, the trace would be easier to interpret if the corresponding bsyms had a label or a name indicating that they are pre/post-processing.

One example would be:

  p16 = torch_all_reduce_prim_impl(t15, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True, True)  # p16: "FUTURE cuda:0 f32[1, 8]"
  del t15
  t17 = torch_wait_prim_impl(p16)  # t17: "cuda:0 f32[1, 8]"
  del p16
  t29 = torch.unsqueeze(t_net2_bias, 0)  # t29: "cuda:0 f32[1, 8]"
    # t29 = ltorch.unsqueeze(t_net2_bias, 0)  # t29: "cuda:0 f32[1, 8]"
      # t29 = prims.broadcast_in_dim(t_net2_bias, [1, 8], [1])  # t29: "cuda:0 f32[1, 8]"
  t18 = Tensor.expand(t29, (1, 8))  # t18: "cuda:0 f32[1, 8]"
    # t18 = ltorch.expand(t29, (1, 8))  # t18: "cuda:0 f32[1, 8]"
      # t18 = prims.broadcast_in_dim(t29, (1, 8), (0, 1))  # t18: "cuda:0 f32[1, 8]"
  del t29
  [t19] = nvFusion1(t17, t18)
    # t19 = prims.add(t17, t18)  # t19: "cuda:0 f32[1, 8]"

If it were easy to guess that these lines are the post-processing of a row-wise parallel linear, I'd say the trace would be more perspicuous. An alternative is to call a method named row_linear_postprocessing in a visitor_transform, so that a bsym would just call it.
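To illustrate the labeling option, the start of this block could for example render as follows (the label text is made up for this sketch):

  # [tensor parallel] post-processing of row-wise parallel Linear "net2"
  p16 = torch_all_reduce_prim_impl(t15, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True, True)  # p16: "FUTURE cuda:0 f32[1, 8]"
  del t15
  t17 = torch_wait_prim_impl(p16)  # t17: "cuda:0 f32[1, 8]"

with the rest of the block unchanged.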

Motivation

I have long wanted to be able to have a BoundSymbol call a sequence of primitive ops without decomposition or executor-specific optimizations, mainly for the sake of readability and clarity in a trace.

Pitch

Alternatives

Additional context

t-vi commented 5 months ago

Hi Masaki! Thank you for this!

Marking the origin of instructions is indeed a topic we are keen to improve on. We recently merged #40 to add source locations to the initial trace. The next step is that @IvanYashchuk will look into propagating these, but your proposal (as I understand it) of marking instructions such as the ones inserted by distributed as originating from there would be a great complement to making it easier to track the source of instructions from user code in the interpretation phase.

crcrpar commented 5 months ago

So right after column_parallel is applied, the trace looks like this:

# Constructed by transform into column-wise tensor parallel
import thunder
import thunder.distributed.prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, t_net1_bias, t_net1_weight, t_net2_bias, t_net2_weight):
  # x: "cuda:0 f32[1, 12]"
  # t_net1_bias: "cuda:0 f32[8]"
  # t_net1_weight: "cuda:0 f32[8, 12]"
  # t_net2_bias: "cuda:0 f32[8]"
  # t_net2_weight: "cuda:0 f32[8, 16]"
  t18 = thunder.distributed.prims.synchronize_tensor_parallel_input(x, _torch_distributed_distributed_c10d_ProcessGroup_0, _TensorParallelLayerType_1)  # t18: "cuda:0 f32[1, 12]"

  # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:116:             return F.linear(input, self.weight, self.bias)
  b = ltorch.linear(t18, t_net1_weight, t_net1_bias)  # b: "cuda:0 f32[1, 16]"
    # b = prims.linear(t18, t_net1_weight, t_net1_bias)  # b: "cuda:0 f32[1, 16]"
  t19 = thunder.distributed.prims.synchronize_tensor_parallel_output(b, _torch_distributed_distributed_c10d_ProcessGroup_0, _TensorParallelLayerType_1)  # t19: "cuda:0 f32[1, 32]"

  # /opt/pytorch/lightning-thunder/thunder/tests/distributed/helper.py:28:          return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
  result = ltorch.mul(0.5, t19)  # result: "cuda:0 f32[1, 16]"
    # result = prims.mul(0.5, t19)  # result: "cuda:0 f32[1, 16]"
  t7 = ltorch.pow(t19, 3.0)  # t7: "cuda:0 f32[1, 16]"
    # t7 = prims.pow(t19, 3.0)  # t7: "cuda:0 f32[1, 16]"
  t8 = ltorch.mul(0.044715, t7)  # t8: "cuda:0 f32[1, 16]"
    # t8 = prims.mul(0.044715, t7)  # t8: "cuda:0 f32[1, 16]"
  t9 = ltorch.add(t19, t8, alpha=None)  # t9: "cuda:0 f32[1, 16]"
    # t9 = prims.add(t19, t8)  # t9: "cuda:0 f32[1, 16]"
  t10 = ltorch.mul(0.7978845608028654, t9)  # t10: "cuda:0 f32[1, 16]"
    # t10 = prims.mul(0.7978845608028654, t9)  # t10: "cuda:0 f32[1, 16]"
  t11 = ltorch.tanh(t10)  # t11: "cuda:0 f32[1, 16]"
    # t11 = prims.tanh(t10)  # t11: "cuda:0 f32[1, 16]"
  t12 = ltorch.add(1.0, t11, alpha=None)  # t12: "cuda:0 f32[1, 16]"
    # t12 = prims.add(1.0, t11)  # t12: "cuda:0 f32[1, 16]"
  input = ltorch.mul(result, t12)  # input: "cuda:0 f32[1, 16]"
    # input = prims.mul(result, t12)  # input: "cuda:0 f32[1, 16]"

  # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:116:             return F.linear(input, self.weight, self.bias)
  t17 = ltorch.linear(input, t_net2_weight, t_net2_bias)  # t17: "cuda:0 f32[1, 8]"
    # t17 = prims.linear(input, t_net2_weight, t_net2_bias)  # t17: "cuda:0 f32[1, 8]"
  return t17

The bsym producing t19 does not have the source location info, as you explained.

t-vi commented 5 months ago

We'll need to decide whether a source location is the right format for this, but calling trace.set_current_source_location(filename, positions) would add it (though you need to set it back to None, None afterwards), or we could add a context manager for doing so when you bind a new symbol.
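Very roughly, something like this (just a sketch; get_tracectx as the way to grab the active trace and the exact shape of the positions argument would need checking):

from thunder.core.trace import get_tracectx

trc = get_tracectx()  # the trace we are currently adding bound symbols to
trc.set_current_source_location("thunder/distributed/tensor_parallel.py", None)  # placeholder filename / positions
try:
    ...  # bind the tensor-parallel pre/post-processing symbols here
finally:
    trc.set_current_source_location(None, None)  # reset so later bsyms do not pick this up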

IvanYashchuk commented 5 months ago

I don't think the source location is the right format here. It should do one thing, and I think that should be showing the original location used to construct the initial trace.

There's also a BoundSymbol.header that allows printing anything above the BoundSymbol in the generated Python code: https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/core/symbol.py#L305-L306

When using BoundSymbol.__call__ to construct symbols, a special context manager can be used to set the header: https://github.com/Lightning-AI/lightning-thunder/blob/77d0fbd328d56d971ee6234c5593de979bf6c887/thunder/core/symbol.py#L54-L55

There is no example code yet for how best to preserve header or source location information. Masaki, do you have ideas for this?
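For illustration, one possible shape would be a small post-pass over the trace produced by the tensor parallel transform (just a sketch; whether from_bsym accepts a header override and whether matching the sync symbols like this is robust would need checking, and the header would still have to survive into the execution trace):

import thunder.distributed.prims as dist_prims
from thunder.core.trace import from_trace

def label_tensor_parallel_bsyms(trc):
    # Clone the trace metadata, then re-add bound symbols while attaching a
    # header comment to the tensor-parallel synchronization prims.
    new_trc = from_trace(trc)
    for bsym in trc.bound_symbols:
        if bsym.sym in (
            dist_prims.synchronize_tensor_parallel_input,
            dist_prims.synchronize_tensor_parallel_output,
        ):
            bsym = bsym.from_bsym(header="tensor parallel pre/post-processing")
        new_trc.bound_symbols.append(bsym)
    return new_trc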

crcrpar commented 5 months ago

"source location" for bsyms added by visit_transform or whatever doesn't feel ideal. I want more like a mechanism that calls a sequence of torch ops (if fusion executors fail to optimize) in a bsym with the name of a primitive especially when looking at this block

  p16 = torch_all_reduce_prim_impl(t15, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True, True)  # p16: "FUTURE cuda:0 f32[1, 8]"
  del t15
  t17 = torch_wait_prim_impl(p16)  # t17: "cuda:0 f32[1, 8]"
  del p16
  t29 = torch.unsqueeze(t_net2_bias, 0)  # t29: "cuda:0 f32[1, 8]"
    # t29 = ltorch.unsqueeze(t_net2_bias, 0)  # t29: "cuda:0 f32[1, 8]"
      # t29 = prims.broadcast_in_dim(t_net2_bias, [1, 8], [1])  # t29: "cuda:0 f32[1, 8]"
  t18 = Tensor.expand(t29, (1, 8))  # t18: "cuda:0 f32[1, 8]"
    # t18 = ltorch.expand(t29, (1, 8))  # t18: "cuda:0 f32[1, 8]"
      # t18 = prims.broadcast_in_dim(t29, (1, 8), (0, 1))  # t18: "cuda:0 f32[1, 8]"
  del t29
  [t19] = nvFusion1(t17, t18)
    # t19 = prims.add(t17, t18)  # t19: "cuda:0 f32[1, 8]"

This would be much clearer if it read t19 = synchronize_tensor_parallel_output(...), with subsymbols defining what it does. A mock-up is sketched below.
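That is, something like the following (the exact signature is made up for this mock-up; whether the bias add belongs inside the prim is an open question):

  t19 = synchronize_tensor_parallel_output(t15, t_net2_bias, _torch_distributed_distributed_c10d_ProcessGroup_0)  # t19: "cuda:0 f32[1, 8]"
    # p16 = torch_all_reduce_prim_impl(t15, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True, True)  # p16: "FUTURE cuda:0 f32[1, 8]"
    # t17 = torch_wait_prim_impl(p16)  # t17: "cuda:0 f32[1, 8]"
    # t29 = prims.broadcast_in_dim(t_net2_bias, [1, 8], [1])  # t29: "cuda:0 f32[1, 8]"
    # t18 = prims.broadcast_in_dim(t29, (1, 8), (0, 1))  # t18: "cuda:0 f32[1, 8]"
    # t19 = prims.add(t17, t18)  # t19: "cuda:0 f32[1, 8]"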