Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Add a transform to add nvtx range on the optimized trace #600

Closed. kshitij12345 closed this 3 months ago

kshitij12345 commented 3 months ago

This PR adds a post-optimization transform that wraps compute symbols in NVTX ranges. This makes it easy to profile the trace with Nsight Systems and to map trace operations onto the GPU execution timeline.
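
For context, an NVTX range simply brackets a span of host code so that it shows up as a named interval on the Nsight Systems timeline. A minimal hand-written equivalent in plain PyTorch (using torch.cuda.nvtx directly, independent of this transform) would be:

import torch

x = torch.randn(1024, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")

# Everything executed between push and pop shows up as one named
# range ("linear") on the Nsight Systems timeline.
torch.cuda.nvtx.range_push("linear")
y = torch.nn.functional.linear(x, w)
torch.cuda.nvtx.range_pop()

The transform automates exactly this bracketing for every compute symbol in the optimized trace.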

In the future, we should allow the user to:

  1. select specific operations from the trace to profile.
  2. select specific regions of the trace to profile.

Usage

import os
import torch
import torch.distributed as tdist
import thunder
import thunder.distributed
from thunder.tests.litgpt_model import GPT, Config
from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform

if __name__ == "__main__":
    tdist.init_process_group(backend="nccl")
    LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", LOCAL_RANK)
    torch.set_default_device(device)

    config = Config.from_name("open_llama_3b")
    config.n_layer = 2  # shrink the model so the example runs quickly
    with device:
        model = GPT(config)

    # jit with the NVTX profile transform, then shard with FSDP
    nvtx_profile_transform = NvtxProfileTransform()
    model = thunder.distributed.fsdp(thunder.jit(model, post_optimization_transforms=[nvtx_profile_transform]))

    input_ids = torch.randint(1, 30010, (128, 256), dtype=torch.long, device=device)

    # restrict profiling to the forward + backward below
    if LOCAL_RANK == 0:
        torch.cuda.cudart().cudaProfilerStart()
    logits = model(input_ids)
    logits.sum().backward()

    if LOCAL_RANK == 0:
        torch.cuda.cudart().cudaProfilerStop()
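
To actually collect a profile, the script is launched under Nsight Systems. Since the script brackets the measured region with cudaProfilerStart/Stop, a capture-range launch along these lines should work (the exact flags and the script name profile_script.py are placeholders, not taken from this PR; check your nsys and torchrun versions):

    nsys profile --capture-range=cudaProfilerApi --trace=cuda,nvtx \
        torchrun --nproc-per-node=2 profile_script.py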

Example of the transformed trace, in which each compute symbol is bracketed by nvtx_range_push / nvtx_range_pop (for brevity, generated from a different script than the one above):

# Constructed by NVTX Profile Transform (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(x, t_fc1_bias, t_fc1_weight, t_fc2_bias, t_fc2_weight):
  # x: "cuda:0 f32[4, 4096, 4096]"
  # t_fc1_bias: "cuda:0 f32[4096]"
  # t_fc1_weight: "cuda:0 f32[4096, 4096]"
  # t_fc2_bias: "cuda:0 f32[4096]"
  # t_fc2_weight: "cuda:0 f32[4096, 4096]"
  nvtx_range_push('t0 = torch.nn.functional.linear(x, t_fc1_weight, t_fc1_bias)  # t0: "cuda:0 f32[4, 4096, 4096]"')
  t0 = torch.nn.functional.linear(x, t_fc1_weight, t_fc1_bias)  # t0: "cuda:0 f32[4, 4096, 4096]"
    # t0 = ltorch.linear(x, t_fc1_weight, t_fc1_bias)  # t0: "cuda:0 f32[4, 4096, 4096]"
      # t0 = prims.linear(x, t_fc1_weight, t_fc1_bias)  # t0: "cuda:0 f32[4, 4096, 4096]"
  nvtx_range_pop()
  nvtx_range_push('[t1, t2] = nvFusion0(t0)')
  [t1, t2] = nvFusion0(t0)
    # t1 = prims.gt(t0, 0.0)  # t1: "cuda:0 b8[4, 4096, 4096]"
    # t2 = prims.where(t1, t0, 0.0)  # t2: "cuda:0 f32[4, 4096, 4096]"
  nvtx_range_pop()
  del t0
  nvtx_range_push('t3 = torch.nn.functional.linear(t2, t_fc2_weight, t_fc2_bias)  # t3: "cuda:0 f32[4, 4096, 4096]"')
  t3 = torch.nn.functional.linear(t2, t_fc2_weight, t_fc2_bias)  # t3: "cuda:0 f32[4, 4096, 4096]"
    # t3 = ltorch.linear(t2, t_fc2_weight, t_fc2_bias)  # t3: "cuda:0 f32[4, 4096, 4096]"
      # t3 = prims.linear(t2, t_fc2_weight, t_fc2_bias)  # t3: "cuda:0 f32[4, 4096, 4096]"
  nvtx_range_pop()
  return {'output': t3, 'flat_args': [x, t_fc1_bias, t_fc1_weight, t_fc2_bias, t_fc2_weight], 'flat_output': (t3,)}, ((t1, t2, t_fc2_weight, x), ())

Example in the Nsight Systems GUI: [screenshot]

Alternative: An alternative is to use the nvtx package with automatic annotation, but this produces a very dense (and large) profile report, with more information than is needed. E.g.:

[screenshot of the automatically annotated profile]
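
For comparison, manual annotation with the nvtx package looks like the sketch below. Automatic annotation, which instruments every Python function and causes the density mentioned above, is typically enabled by running the script through python -m nvtx. Both usages are my reading of the nvtx package's documented API, not part of this PR:

import nvtx

# annotate() can be used as a decorator ...
@nvtx.annotate("forward", color="green")
def forward(model, batch):
    return model(batch)

# ... or as a context manager around an arbitrary region.
with nvtx.annotate("training_step", color="blue"):
    pass  # training code goes here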


crcrpar commented 3 months ago
    model = thunder.distributed.fsdp(model)
    # use the transform
    model = thunder.jit(model, executors=["torch"], post_optimization_transforms=[nvtx_profile_transform])

I think there are a few other ways of setting up this model.

One would be

jitted = thunder.jit(model, ..., post_optimization_transforms=[nvtx_profile_transform])
fsdp = thunder.distributed.fsdp(jitted)

The other would be to use thunder.core.transforms.add_transform

jitted = thunder.jit(model, ..., post_optimization_transforms=[nvtx_profile_transform])
fsdp = thunder.distributed.fsdp(jitted)
fsdp = add_transform(fsdp, nvtx_profile_transform)

Oops, it doesn't seem to support post_optimization_transforms.

Anyway, could you check whether these or the first one work?

kshitij12345 commented 3 months ago

The first one works correctly (I will update the example in the PR description with this):

    nvtx_profile_transform = NvtxProfileTransform()
    model = thunder.distributed.fsdp(thunder.jit(model, post_optimization_transforms=[nvtx_profile_transform]))

As for the second, there is add_post_optimization_transform instead, which works:

    nvtx_profile_transform = NvtxProfileTransform()
    from thunder.core.transforms import add_post_optimization_transform
    model = thunder.distributed.fsdp(thunder.jit(model, executors=["torch"]))
    model = add_post_optimization_transform(model, nvtx_profile_transform)

kshitij12345 commented 3 months ago

@t-vi agreed, I have added a simple test. Thanks!