NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Write a sharded transformer block in nvFuser API. #2199

Open wujingyue opened 2 months ago

wujingyue commented 2 months ago

This is to unblock @cowanmeg's and @samnordmann's distributed matmul experiments.

I'll start with the tensor parallelism proposed by the original Megatron-LM paper.

  1. Only MHA and MLP are sharded.
  2. Activations are sharded in 2D, along the batch and hidden dimensions. However, the batch-dimension sharding is only for data parallelism, and that dimension is never resharded.
  3. Weights are sharded in 1D, along the hidden dimension.
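To make the MLP part of this scheme concrete, here is a toy, pure-Python model (not nvFuser code). It assumes the usual Megatron-LM layout where the first MLP weight is split by columns and the second by rows along the intermediate dimension, and it uses ReLU as a stand-in for GeLU so the integer arithmetic stays exact. It shows that each device can run both layers on its own shard and that a single sum across devices (an allreduce in practice) recovers the unsharded result.

```python
# Toy, pure-Python model of Megatron-style MLP sharding (not nvFuser code).
# Y = act(X @ A) @ B: A is split by columns, B by rows, one shard per device.
# ReLU stands in for GeLU; any element-wise activation commutes with the column split.

def matmul(x, w):
    """(m x k) @ (k x n) for matrices stored as lists of rows."""
    return [[sum(x[i][p] * w[p][j] for p in range(len(w))) for j in range(len(w[0]))]
            for i in range(len(x))]

def relu(x):
    return [[max(v, 0) for v in row] for row in x]

def split_columns(w, num_devices):
    n = len(w[0]) // num_devices
    return [[row[d * n:(d + 1) * n] for row in w] for d in range(num_devices)]

def split_rows(w, num_devices):
    n = len(w) // num_devices
    return [w[d * n:(d + 1) * n] for d in range(num_devices)]

def mlp(x, a, b):
    return matmul(relu(matmul(x, a)), b)

def sharded_mlp(x, a, b, num_devices):
    assert len(a[0]) % num_devices == 0, "intermediate size must divide evenly"
    # Each device runs its slice of both layers with no communication in between...
    partials = [mlp(x, a_d, b_d)
                for a_d, b_d in zip(split_columns(a, num_devices), split_rows(b, num_devices))]
    # ...and a single sum over devices (an allreduce in a real setup) recovers Y.
    return [[sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
            for i in range(len(partials[0]))]

x = [[1, -2, 3], [0, 1, -1]]
a = [[1, -1, 2, 0], [0, 1, -1, 3], [2, 0, 1, -2]]
b = [[1, 0, 2], [-1, 1, 0], [0, 2, -1], [3, -1, 1]]
assert sharded_mlp(x, a, b, num_devices=2) == mlp(x, a, b)
```

Splitting the first weight by columns is what lets the nonlinearity run locally: each device owns whole intermediate features, so no communication is needed until the second matmul's partial sums are combined.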
wujingyue commented 2 months ago

Note to myself: I'll first try to get a single-device nvFuser python definition from Thunder, and then we can manually shard it using nvFuser's API.

@Priya2698 pointed me to the nv_enable_linear flag (https://github.com/Lightning-AI/lightning-thunder/blob/90a0f4c0d0a90d1e94684a847f3adfe2230985b4/thunder/tests/test_nvfuser.py#L875) that I'll need to turn on to enable prims.linear via nvFuser. I'll probably need to set nv_enable_bookend=False as well.

wujingyue commented 2 months ago

Note to myself: I'll start with the following benchmark:

$ pytest thunder/benchmarks/targets.py -k test_nanogpt_block_fwd[thunder] -s

which exercises one transformer layer in nanoGPT: https://github.com/Lightning-AI/lightning-thunder/blob/cab020881765594fd9552d4deb8cc4e0f64410d2/thunder/tests/nanogpt_model.py#L132-L143

wujingyue commented 2 months ago

cc @Priya2698

The a.ndim == 2 check is the first one that fails. Here's how you can reproduce the problem:

Apply the following patch:

diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index c955da06..4767ab9c 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -2201,6 +2201,7 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
         return False

     enable_linear: None | bool = get_compile_option("nv_enable_linear", "Enable nvFuser matmul.")
+    enable_linear = True
     if not enable_linear:
         return False
     # Verify linear inputs and bias (optional) are supported tensors.
@@ -2210,6 +2211,7 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
         return False

     # nvFuser only supports 2D inputs in v0.2.3.
+    import pdb; pdb.set_trace()
     if not a.ndim == 2:
         return False
     return True
Then run:

$ NVFUSER_DUMP=python_definition pytest thunder/benchmarks/targets.py -k test_nanogpt_block_fwd[thunder] -s
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.4.0
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /opt/pytorch/lightning-thunder
configfile: pyproject.toml
plugins: timestamper-0.0.10, xdist-3.5.0, random-order-1.1.1, cov-4.1.0, benchmark-4.0.0, hypothesis-6.100.0, timeout-2.2.0, shard-0.1.2
timeout: 900.0s
timeout method: signal
timeout func_only: False
collected 162 items / 161 deselected / 1 selected
Running 1 items in this shard

thunder/benchmarks/targets.py
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> PDB set_trace >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
> /opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py(2215)_linear_check()
-> if not a.ndim == 2:
(Pdb) p a.ndim
3
(Pdb)

As expected, the printed Python definitions comprise five fusions, none of which contains a matmul or linear.

wujingyue commented 2 months ago

Below is a workaround (WAR) for the Thunder check above, but it runs into an nvFuser issue.

diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index c955da06..137da102 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -4,6 +4,7 @@ from numbers import Number
 from typing import Union, List, Any, Optional, Dict, Set, Tuple, Type
 from types import NoneType
 from collections.abc import Callable, Mapping, Hashable, Sequence
+import math
 import os
 import time
 from copy import copy
@@ -796,7 +797,7 @@ instantiated) this heuristic actually leads to worse code.
             enable_bookend: None | bool = get_compile_option("nv_enable_bookend", bookend_help)
             # Set default value.
             if enable_bookend is None:
-                enable_bookend = True
+                enable_bookend = False
             assert isinstance(enable_bookend, bool)

             if enable_bookend:
@@ -2200,7 +2201,7 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
     if nv_version < LooseVersion("0.2.3"):
         return False

-    enable_linear: None | bool = get_compile_option("nv_enable_linear", "Enable nvFuser matmul.")
+    enable_linear = True
     if not enable_linear:
         return False
     # Verify linear inputs and bias (optional) are supported tensors.
@@ -2209,8 +2210,11 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
     if bias is not None and not is_supported_tensor(bias):
         return False

-    # nvFuser only supports 2D inputs in v0.2.3.
-    if not a.ndim == 2:
+    if a.ndim < 2:
+        return False
+    if b.ndim != 2:
+        return False
+    if bias is not None and bias.ndim != 1:
         return False
     return True

@@ -2226,7 +2230,10 @@ def linear(
     nva = getnv(a, fd, lc_to_nv_map)
     nvb = getnv(b, fd, lc_to_nv_map)
     nvbias = None if bias is None else getnv(bias, fd, lc_to_nv_map)
-    return fd.ops.linear(nva, nvb, nvbias)
+
+    nva_2d = fd.ops.reshape(nva, (math.prod(a.shape[:-1]), a.shape[-1]))
+    nvc_2d = fd.ops.linear(nva_2d, nvb, nvbias)
+    return fd.ops.reshape(nvc_2d, a.shape[:-1] + (b.shape[-2],))

 register_supported(PrimIDs.LINEAR, linear, _linear_check)
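The patch runs an N-D linear as a 2D one by flattening all leading dimensions of a, calling fd.ops.linear, and reshaping the result back. A pure-Python sketch of just the shape arithmetic follows; the helper linear_nd_shapes is hypothetical (not part of Thunder or nvFuser), it treats a missing bias as allowed, and the example shapes are the nanoGPT ones that appear in the dumped fusions (16 x 128 x 1600 activations, a 4800 x 1600 weight).

```python
import math

def linear_nd_shapes(a_shape, b_shape, bias_shape=None):
    """Shapes involved in running an N-D linear as a 2D linear (hypothetical helper).

    a_shape: (..., in_features); b_shape: (out_features, in_features);
    bias_shape: (out_features,) or None. Mirrors the rank checks in the patched _linear_check.
    """
    if len(a_shape) < 2 or len(b_shape) != 2:
        raise ValueError("unsupported operand rank")
    if bias_shape is not None and len(bias_shape) != 1:
        raise ValueError("bias must be 1D")
    rows = math.prod(a_shape[:-1])               # flatten all leading dims
    a_2d = (rows, a_shape[-1])                   # first fd.ops.reshape
    out_2d = (rows, b_shape[-2])                 # result of fd.ops.linear
    out = tuple(a_shape[:-1]) + (b_shape[-2],)   # final fd.ops.reshape
    return a_2d, out_2d, out

print(linear_nd_shapes((16, 128, 1600), (4800, 1600), (4800,)))
```

The 2048 = 16 * 128 row count is the same constant that shows up as the reshape target before fd.ops.linear in the dumped definitions.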
A standalone reproducer built from the dumped Python definition:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.ops.cast(T4, dtype=DataType.Float)
    T6, T7 = fd.ops.var_mean(T5, dims=[2], correction=0, keepdim=False)
    S8 = fd.define_scalar(16, dtype=DataType.Int)
    S9 = fd.define_scalar(128, dtype=DataType.Int)
    S10 = fd.define_scalar(1, dtype=DataType.Int)
    V11 = fd.define_vector([S8, S9, S10], dtype=DataType.Int)
    T12 = fd.ops.broadcast_in_dim(T6, shape=V11, broadcast_dims=[0, 1])
    S13 = fd.define_scalar(16, dtype=DataType.Int)
    S14 = fd.define_scalar(128, dtype=DataType.Int)
    S15 = fd.define_scalar(1, dtype=DataType.Int)
    V16 = fd.define_vector([S13, S14, S15], dtype=DataType.Int)
    T17 = fd.ops.broadcast_in_dim(T7, shape=V16, broadcast_dims=[0, 1])
    S18 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T19 = fd.ops.add(T12, S18)
    T20 = fd.ops.rsqrt(T19)
    S21 = fd.define_scalar(16, dtype=DataType.Int)
    S22 = fd.define_scalar(128, dtype=DataType.Int)
    S23 = fd.define_scalar(1600, dtype=DataType.Int)
    V24 = fd.define_vector([S21, S22, S23], dtype=DataType.Int)
    T25 = fd.ops.broadcast_in_dim(T17, shape=V24, broadcast_dims=[0, 1, 2])
    T26 = fd.ops.sub(T5, T25)
    S27 = fd.define_scalar(16, dtype=DataType.Int)
    S28 = fd.define_scalar(128, dtype=DataType.Int)
    S29 = fd.define_scalar(1600, dtype=DataType.Int)
    V30 = fd.define_vector([S27, S28, S29], dtype=DataType.Int)
    T31 = fd.ops.broadcast_in_dim(T20, shape=V30, broadcast_dims=[0, 1, 2])
    T32 = fd.ops.mul(T26, T31)
    S33 = fd.define_scalar(16, dtype=DataType.Int)
    S34 = fd.define_scalar(128, dtype=DataType.Int)
    S35 = fd.define_scalar(1600, dtype=DataType.Int)
    V36 = fd.define_vector([S33, S34, S35], dtype=DataType.Int)
    T37 = fd.ops.broadcast_in_dim(T3, shape=V36, broadcast_dims=[2])
    T38 = fd.ops.cast(T37, dtype=DataType.Float)
    T39 = fd.ops.mul(T32, T38)
    S40 = fd.define_scalar(16, dtype=DataType.Int)
    S41 = fd.define_scalar(128, dtype=DataType.Int)
    S42 = fd.define_scalar(1600, dtype=DataType.Int)
    V43 = fd.define_vector([S40, S41, S42], dtype=DataType.Int)
    T44 = fd.ops.broadcast_in_dim(T2, shape=V43, broadcast_dims=[2])
    T45 = fd.ops.cast(T44, dtype=DataType.Float)
    T46 = fd.ops.add(T39, T45)
    T47 = fd.ops.cast(T46, dtype=DataType.BFloat16)
    S48 = fd.define_scalar(2048, dtype=DataType.Int)
    S49 = fd.define_scalar(1600, dtype=DataType.Int)
    V50 = fd.define_vector([S48, S49], dtype=DataType.Int)
    T51 = fd.ops.reshape(T47, new_shape=V50)
    T52 = fd.ops.linear(T51, T1, T0)
    fd.add_output(T52)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((4800,), dtype=torch.bfloat16, device='cuda:0').as_strided((4800,), (1,)),
    torch.randn((7680000,), dtype=torch.bfloat16, device='cuda:0').as_strided((4800, 1600), (1600, 1)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((3276800,), dtype=torch.bfloat16, device='cuda:0').as_strided((16, 128, 1600), (204800, 1600, 1)),
]
fd.execute(inputs)
Running it fails with:

Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 146, in execute
    result = self._execute(
RuntimeError: h.has_value() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/fusion_segmenter.cpp":3671, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Can not find a scheduler to schedule fusion segment
Exception raised from deriveHeuristic at /opt/pytorch/nvfuser/csrc/fusion_segmenter.cpp:3671 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7fbf362d8381 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x7fbf365d51b3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x4bde42 (0x7fbf36675e42 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x4c5032 (0x7fbf3667d032 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x4d0c42 (0x7fbf36688c42 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: nvfuser::SegmentCandidateFinder::SegmentCandidateFinder(std::unique_ptr<nvfuser::Fusion, std::default_delete<nvfuser::Fusion> >, nvfuser::KernelArgumentHolder const*, nvfuser::SegmentCandidateFinderOptions) + 0x46f (0x7fbf366897ff in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x4a8082 (0x7fbf36660082 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x4d1a0e (0x7fbf36689a0e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: nvfuser::FusionKernelRuntime::FusionKernelRuntime(std::unique_ptr<nvfuser::Fusion, std::default_delete<nvfuser::Fusion> >, nvfuser::KernelArgumentHolder const&, nvfuser::serde::FusionKernelRuntime const*, std::optional<nvfuser::PrimDataType>, long, long, long, bool) + 0x373 (0x7fbf36799ed3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x5e5b57 (0x7fbf3679db57 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x1e7 (0x7fbf3679e8b7 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, bool, bool, std::optional<signed char>) const + 0x3c8 (0x7fbf36981998 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x18ca25 (0x7fbf36344a25 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #13: <unknown function> + 0x2009c2 (0x7fbf363b89c2 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #14: <unknown function> + 0x288d00 (0x7fbf36440d00 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #15: <unknown function> + 0x15a10e (0x555dc548510e in /usr/bin/python3)
frame #16: _PyObject_MakeTpCall + 0x25b (0x555dc547ba7b in /usr/bin/python3)
frame #17: <unknown function> + 0x168acb (0x555dc5493acb in /usr/bin/python3)
frame #18: _PyEval_EvalFrameDefault + 0x198c (0x555dc546f53c in /usr/bin/python3)
frame #19: <unknown function> + 0x16893e (0x555dc549393e in /usr/bin/python3)
frame #20: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #21: _PyObject_FastCallDictTstate + 0xc4 (0x555dc547ac14 in /usr/bin/python3)
frame #22: _PyObject_Call_Prepend + 0xc1 (0x555dc54908d1 in /usr/bin/python3)
frame #23: <unknown function> + 0x280700 (0x555dc55ab700 in /usr/bin/python3)
frame #24: _PyObject_MakeTpCall + 0x25b (0x555dc547ba7b in /usr/bin/python3)
frame #25: _PyEval_EvalFrameDefault + 0x64e6 (0x555dc5474096 in /usr/bin/python3)
frame #26: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #27: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #28: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #29: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #30: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #31: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #32: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #33: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #34: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #35: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #36: <unknown function> + 0x16893e (0x555dc549393e in /usr/bin/python3)
frame #37: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #38: <unknown function> + 0x16893e (0x555dc549393e in /usr/bin/python3)
frame #39: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #40: _PyObject_FastCallDictTstate + 0xc4 (0x555dc547ac14 in /usr/bin/python3)
frame #41: _PyObject_Call_Prepend + 0x5c (0x555dc549086c in /usr/bin/python3)
frame #42: <unknown function> + 0x280700 (0x555dc55ab700 in /usr/bin/python3)
frame #43: PyObject_Call + 0xbb (0x555dc549442b in /usr/bin/python3)
frame #44: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #45: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #46: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #47: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #48: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #49: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #50: _PyEval_EvalFrameDefault + 0x6bd (0x555dc546e26d in /usr/bin/python3)
frame #51: <unknown function> + 0x1687f1 (0x555dc54937f1 in /usr/bin/python3)
frame #52: _PyEval_EvalFrameDefault + 0x198c (0x555dc546f53c in /usr/bin/python3)
frame #53: <unknown function> + 0x1687f1 (0x555dc54937f1 in /usr/bin/python3)
frame #54: _PyEval_EvalFrameDefault + 0x198c (0x555dc546f53c in /usr/bin/python3)
frame #55: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #56: PyObject_Call + 0x122 (0x555dc5494492 in /usr/bin/python3)
frame #57: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #58: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #59: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #60: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #61: _PyEval_EvalFrameDefault + 0x614a (0x555dc5473cfa in /usr/bin/python3)
frame #62: <unknown function> + 0x1687f1 (0x555dc54937f1 in /usr/bin/python3)
frame #63: _PyEval_EvalFrameDefault + 0x614a (0x555dc5473cfa in /usr/bin/python3)
wujingyue commented 2 months ago

FYI, NVFUSER_DUMP=segmenter_logging prints the following:

**Segmenter** Considering fusion:
T34_g[ iS97{( 16 * 128 )}, iS138{i0}, rS99{1600} ]
   = mma(T32_g[ iS91{( 16 * 128 )}, bS92{1}, iS93{1600} ],
         T33_g[ bS94{1}, iS136{i0}, iS137{1600} ])

Scheduler _no_op_ ***rejected*** because : output has a concrete dimension
Scheduler _matmul_ ***rejected*** because : MmaOp input has unsupported dependency
Scheduler _reduction_ ***rejected*** because : No reduction op to schedule
Scheduler _transpose_ ***rejected*** because : no support for mma ops.
Scheduler _pointwise_ ***rejected*** because : no support for mma ops.
Scheduler _inner_persistent_ ***rejected*** because : needs a reduction op
Scheduler _outer_persistent_ ***rejected*** because : needs a reduction op
Scheduler _inner_outer_persistent_ ***rejected*** because : needs a reduction op
wujingyue commented 2 months ago

The matmul scheduler failed at https://github.com/NVIDIA/Fuser/blob/7f7126d2d3bddaf22e0bcb77292dffa8c4958288/csrc/scheduler/matmul_utils.cpp#L275

It looks like the scheduler assumes both operands are produced by broadcasts. I was under the impression that we removed that assumption for https://github.com/NVIDIA/Fuser/issues/1628. What am I missing? @zasdfgbnm

wujingyue commented 2 months ago

FYI, below is the complete fusion after the pre-segmenter optimizations. The MmaOp is indeed part of the beautiful broadcast+broadcast+mma+add+float2bfloat subgraph, which is good. However, because of the other ops in the fusion, this subgraph is not handed to the matmul scheduler directly. Instead, the fusion is decomposed into singleton segments, and the segmenter has trouble merging them back into the expected subgraph.

$ NVFUSER_DUMP=fusion_ir_preseg python repro.py 
Fusion IR after pre-segmenter optimization passes:
Inputs:
  T0_g[ iS0{i0} ], __bfloat
  T1_g[ iS134{i0}, iS135{1600} ], __bfloat
  T2_g[ iS132{1600} ], __bfloat
  T3_g[ iS130{1600} ], __bfloat
  T4_g[ iS107{16}, iS108{128}, iS109{1600} ], __bfloat
Outputs:
  T38_g[ iS105{2048}, iS140{i0} ], __bfloat

%kernel_math {
T5_l[ iS110{16}, iS111{128}, iS112{1600} ]
   = __bfloat2float(T4_g[ iS107{16}, iS108{128}, iS109{1600} ]);
T6_l[ iS116{16}, iS117{128}, rS118{1600} ](Avg),
T7_l[ iS122{16}, iS123{128}, rS124{1600} ](Var),
T8_l[ iS113{16}, iS114{128}, rS115{1600} ](Count)
 = Welford ( T5_l[ iS110{16}, iS111{128}, iS112{1600} ](Avg), 
  allreduce = false )
T12_l[ iS119{16}, iS120{128}, bS30{1} ]
   = broadcast( T6_l[ iS116{16}, iS117{128}, rS118{1600} ] )
T13_l[ iS31{16}, iS32{128}, bS33{1} ]
   = Set( T12_l[ iS119{16}, iS120{128}, bS30{1} ], cache_op=Streaming )
T16_l[ iS40{16}, iS41{128}, bS42{1} ]
   = Set( T13_l[ iS31{16}, iS32{128}, bS33{1} ], cache_op=Streaming )
T17_l[ iS43{16}, iS44{128}, bS45{1 ex 1600} ] = expand( T16_l[ iS40{16}, iS41{128}, bS42{1} ], {16, 128, 1600} )
T18_l[ iS46{16}, iS47{128}, iS121{1600} ]
   = T5_l[ iS110{16}, iS111{128}, iS112{1600} ]
   - T17_l[ iS43{16}, iS44{128}, bS45{1 ex 1600} ];
d17 = (double)(1600);
d19 = double(1) * d17;
d23 = (double)(0);
d25 = d19 - d23;
d27 = (double)(0);
b29 = d25 >= d27;
d31 = (double)(0);
d33 = where(b29, d25, d31);
d39 = reciprocal(d33);
T9_l[ iS125{16}, iS126{128} ]
   = T7_l[ iS122{16}, iS123{128}, rS124{1600} ]
   * d39;
T10_l[ iS127{16}, iS128{128}, bS24{1} ]
   = broadcast( T9_l[ iS125{16}, iS126{128} ] )
T11_l[ iS25{16}, iS26{128}, bS27{1} ]
   = Set( T10_l[ iS127{16}, iS128{128}, bS24{1} ], cache_op=Streaming )
T14_l[ iS34{16}, iS35{128}, bS36{1} ]
   = T11_l[ iS25{16}, iS26{128}, bS27{1} ]
   + double(1.0000000000000001e-05);
T15_l[ iS37{16}, iS38{128}, bS39{1} ]
   = rsqrtf(T14_l[ iS34{16}, iS35{128}, bS36{1} ]);
T19_l[ iS49{16}, iS50{128}, bS51{1} ]
   = Set( T15_l[ iS37{16}, iS38{128}, bS39{1} ], cache_op=Streaming )
T20_l[ iS52{16}, iS53{128}, bS54{1 ex 1600} ] = expand( T19_l[ iS49{16}, iS50{128}, bS51{1} ], {16, 128, 1600} )
T21_l[ iS55{16}, iS56{128}, iS129{1600} ]
   = T18_l[ iS46{16}, iS47{128}, iS121{1600} ]
   * T20_l[ iS52{16}, iS53{128}, bS54{1 ex 1600} ];
T22_l[ bS58{1}, bS59{1}, iS131{1600} ]
   = broadcast( T3_g[ iS130{1600} ] )
T23_l[ bS61{1 ex 16}, bS62{1 ex 128}, iS63{1600} ] = expand( T22_l[ bS58{1}, bS59{1}, iS131{1600} ], {16, 128, 1600} )
T24_l[ bS64{1 ex 16}, bS65{1 ex 128}, iS66{1600} ]
   = __bfloat2float(T23_l[ bS61{1 ex 16}, bS62{1 ex 128}, iS63{1600} ]);
T25_l[ iS67{16}, iS68{128}, iS69{1600} ]
   = T21_l[ iS55{16}, iS56{128}, iS129{1600} ]
   * T24_l[ bS64{1 ex 16}, bS65{1 ex 128}, iS66{1600} ];
T26_l[ bS70{1}, bS71{1}, iS133{1600} ]
   = broadcast( T2_g[ iS132{1600} ] )
T27_l[ bS73{1 ex 16}, bS74{1 ex 128}, iS75{1600} ] = expand( T26_l[ bS70{1}, bS71{1}, iS133{1600} ], {16, 128, 1600} )
T28_l[ bS76{1 ex 16}, bS77{1 ex 128}, iS78{1600} ]
   = __bfloat2float(T27_l[ bS73{1 ex 16}, bS74{1 ex 128}, iS75{1600} ]);
T29_l[ iS79{16}, iS80{128}, iS81{1600} ]
   = T25_l[ iS67{16}, iS68{128}, iS69{1600} ]
   + T28_l[ bS76{1 ex 16}, bS77{1 ex 128}, iS78{1600} ];
T30_l[ iS82{16}, iS83{128}, iS84{1600} ]
   = __float2bfloat(T29_l[ iS79{16}, iS80{128}, iS81{1600} ]);
T31_l[ iS90{( 16 * 128 )}rf, iS87{1600} ] = view( T30_l[ iS82{16}, iS83{128}, iS84{1600} ] )
T32_l[ iS91{( 16 * 128 )}, bS92{1}, iS93{1600} ]
   = broadcast( T31_l[ iS90{( 16 * 128 )}rf, iS87{1600} ] )
T33_l[ bS94{1}, iS136{i0}, iS137{1600} ]
   = broadcast( T1_g[ iS134{i0}, iS135{1600} ] )
T34_l[ iS97{( 16 * 128 )}, iS138{i0}, rS99{1600} ]
   = mma(T32_l[ iS91{( 16 * 128 )}, bS92{1}, iS93{1600} ],
         T33_l[ bS94{1}, iS136{i0}, iS137{1600} ])
T35_l[ iS100{i0} ]
   = __bfloat2float(T0_g[ iS0{i0} ]);
T36_l[ bS101{1}, iS102{i0} ]
   = broadcast( T35_l[ iS100{i0} ] )
T37_l[ iS103{2048}, iS139{i0} ]
   = T34_l[ iS97{( 16 * 128 )}, iS138{i0}, rS99{1600} ]
   + T36_l[ bS101{1}, iS102{i0} ];
T38_g[ iS105{2048}, iS140{i0} ]
   = __float2bfloat(T37_l[ iS103{2048}, iS139{i0} ]);
}
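Reading the IR above, everything up to T30 is a LayerNorm over the last (1600) dimension, computed in float: a Welford mean/variance, the variance scaled by d39 = 1/1600 (correction=0), plus 1e-5, rsqrt, then the T3 weight and T2 bias. T31 onward flattens to 2048 x 1600 and applies the linear with weight T1 and bias T0. A pure-Python reference for that math, as a sketch that ignores the bf16/float casts (the function name and argument layout are illustrative, not taken from Thunder or nvFuser):

```python
import math

def layer_norm_linear(x_rows, ln_weight, ln_bias, weight, bias, eps=1e-5):
    """Reference math for LayerNorm over the last dim followed by y = x @ W^T + b.

    x_rows: rows with batch and sequence already flattened (the IR's view to 2048 x 1600).
    weight: out_features rows, each of length in_features (nn.Linear layout).
    """
    out = []
    for row in x_rows:
        n = len(row)
        mean = sum(row) / n
        # correction=0: biased variance, i.e. divide by n (d39 = 1/1600 in the IR).
        var = sum((v - mean) ** 2 for v in row) / n
        inv_std = 1.0 / math.sqrt(var + eps)  # the IR's add(eps) followed by rsqrt
        normed = [(v - mean) * inv_std * gamma + beta
                  for v, gamma, beta in zip(row, ln_weight, ln_bias)]
        out.append([sum(h * w for h, w in zip(normed, w_row)) + b_out
                    for w_row, b_out in zip(weight, bias)])
    return out

# The normalized row sums to roughly zero, so this output should be close to [[3.0]].
y = layer_norm_linear([[1.0, 2.0, 3.0, 4.0]], [1.0] * 4, [0.0] * 4,
                      weight=[[1.0, 1.0, 1.0, 1.0]], bias=[3.0])
```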
Priya2698 commented 2 months ago

This issue looks related to https://github.com/NVIDIA/Fuser/issues/2127. The failure stemmed from assuming that the MmaOp inputs are created by a BroadcastOp.

@wujingyue What do you get after #2221?

While the ATen evaluation for matmul/linear will drop these assumptions once the new IR nodes are merged, at present we make the same assumption in pattern matching as well (BroadcastOp -> MmaOp -> CastOp).
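A toy model of why segment placement matters for this pattern (the Op class and both functions are hypothetical and much simpler than nvFuser's actual matching): the pattern only matches when the broadcasts feeding the mma, and a cast downstream of it, live in the same segment. A singleton mma segment, which is what the segmenter logging above shows, fails the check even though the whole fusion contains the pattern.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Op:
    kind: str      # "broadcast", "mma", "add", "cast", ...
    inputs: tuple  # names of the tensors it reads
    output: str    # name of the tensor it writes

def reaches_cast(tensor, segment):
    """True if a cast in `segment` consumes `tensor`, directly or through other ops in the segment."""
    frontier, seen = [tensor], set()
    while frontier:
        t = frontier.pop()
        for op in segment:
            if t in op.inputs and op.output not in seen:
                if op.kind == "cast":
                    return True
                seen.add(op.output)
                frontier.append(op.output)
    return False

def matches_broadcast_mma_cast(segment):
    """Toy check: broadcast operands -> mma -> (element-wise ops) -> cast, all in one segment."""
    producers = {op.output: op for op in segment}
    for op in segment:
        if op.kind != "mma":
            continue
        operands_ok = all(t in producers and producers[t].kind == "broadcast"
                          for t in op.inputs)
        if operands_ok and reaches_cast(op.output, segment):
            return True
    return False

# Tensor names follow the pre-segmentation IR above (bias add included).
whole = [
    Op("broadcast", ("T31",), "T32"),
    Op("broadcast", ("T1",), "T33"),
    Op("mma", ("T32", "T33"), "T34"),
    Op("broadcast", ("T35",), "T36"),
    Op("add", ("T34", "T36"), "T37"),
    Op("cast", ("T37",), "T38"),
]
singleton = [Op("mma", ("T32", "T33"), "T34")]
```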

wujingyue commented 2 months ago

> at present, we assume the same in pattern matching as well (BroadcastOp -> MmaOp-> CastOp)

That's right. I already merged #2221, so you can reproduce this by running the reproducer in https://github.com/NVIDIA/Fuser/issues/2199#issuecomment-2101214930. Anyway, I'll run my experiments with matmul_expr_eval disabled, so #2221 is enough to unblock me for now.

That being said, the ATen evaluation has a variation of the same problem: the segmenter doesn't guarantee that the cast ends up in the same segment as the MmaOp, just as it didn't for the broadcasts. I think the new IR nodes will help with that, but I'm not sure, so I'll leave that to you.

wujingyue commented 2 months ago

With some more hacks (which I'll try to find a way to submit), I'm getting some useful nvFusions to start with, hopefully. The forward pass now runs two nvFusions. The first one has one fd.ops.linear, which I suspect is the input linear layer. The second one has three fd.ops.linear, which I suspect are the output linear layer followed by the two-layer MLP.

I'll confirm this and try to include SDPA as well.
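In the dump that follows, the first fusion's linear produces 4800 = 3 * 1600 features, which are reshaped to 16 x 128 x 4800, sliced into three 1600-wide pieces, reshaped to 16 x 128 x 25 x 64, and permuted with dims=[0, 2, 1, 3]. Assuming this is nanoGPT's fused query/key/value projection (batch 16, sequence length 128, n_embd 1600, 25 heads of size 64, all read off the constants in the dump), the shape bookkeeping can be sketched in plain Python; qkv_shapes is a hypothetical helper, not part of Thunder or nvFuser.

```python
def qkv_shapes(batch, seq, n_embd, n_head):
    """Shape bookkeeping for a fused QKV linear followed by slice, reshape and permute."""
    assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
    head_dim = n_embd // n_head
    qkv = (batch, seq, 3 * n_embd)                               # reshape of the linear output
    bounds = [(i * n_embd, (i + 1) * n_embd) for i in range(3)]  # slice ranges on the last dim
    sliced = (batch, seq, n_embd)                                # shape of each slice
    split_heads = (batch, seq, n_head, head_dim)                 # reshape
    permuted = tuple(split_heads[d] for d in (0, 2, 1, 3))       # permute(dims=[0, 2, 1, 3])
    return qkv, bounds, sliced, split_heads, permuted

print(qkv_shapes(16, 128, 1600, 25))
```

The permuted (batch, heads, seq, head_dim) layout is the one SDPA-style attention kernels usually consume, which is consistent with the three fusion outputs feeding the attention that is not yet included.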

$ NVFUSER_DUMP=python_definition NVFUSER_DISABLE=matmul_expr_eval pytest thunder/benchmarks/targets.py -k test_nanogpt_block_fwd[thunder] -s
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.4.0
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /opt/pytorch/lightning-thunder
configfile: pyproject.toml
plugins: timestamper-0.0.10, xdist-3.5.0, random-order-1.1.1, cov-4.1.0, benchmark-4.0.0, hypothesis-6.100.0, timeout-2.2.0, shard-0.1.2
timeout: 900.0s
timeout method: signal
timeout func_only: False
collected 162 items / 161 deselected / 1 selected
Running 1 items in this shard

thunder/benchmarks/targets.py
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.ops.cast(T4, dtype=DataType.Float)
    T6, T7 = fd.ops.var_mean(T5, dims=[2], correction=0, keepdim=False)
    S8 = fd.define_scalar(16, dtype=DataType.Int)
    S9 = fd.define_scalar(128, dtype=DataType.Int)
    S10 = fd.define_scalar(1, dtype=DataType.Int)
    V11 = fd.define_vector([S8, S9, S10], dtype=DataType.Int)
    T12 = fd.ops.broadcast_in_dim(T6, shape=V11, broadcast_dims=[0, 1])
    S13 = fd.define_scalar(16, dtype=DataType.Int)
    S14 = fd.define_scalar(128, dtype=DataType.Int)
    S15 = fd.define_scalar(1, dtype=DataType.Int)
    V16 = fd.define_vector([S13, S14, S15], dtype=DataType.Int)
    T17 = fd.ops.broadcast_in_dim(T7, shape=V16, broadcast_dims=[0, 1])
    S18 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T19 = fd.ops.add(T12, S18)
    T20 = fd.ops.rsqrt(T19)
    S21 = fd.define_scalar(16, dtype=DataType.Int)
    S22 = fd.define_scalar(128, dtype=DataType.Int)
    S23 = fd.define_scalar(1600, dtype=DataType.Int)
    V24 = fd.define_vector([S21, S22, S23], dtype=DataType.Int)
    T25 = fd.ops.broadcast_in_dim(T17, shape=V24, broadcast_dims=[0, 1, 2])
    T26 = fd.ops.sub(T5, T25)
    S27 = fd.define_scalar(16, dtype=DataType.Int)
    S28 = fd.define_scalar(128, dtype=DataType.Int)
    S29 = fd.define_scalar(1600, dtype=DataType.Int)
    V30 = fd.define_vector([S27, S28, S29], dtype=DataType.Int)
    T31 = fd.ops.broadcast_in_dim(T20, shape=V30, broadcast_dims=[0, 1, 2])
    T32 = fd.ops.mul(T26, T31)
    S33 = fd.define_scalar(16, dtype=DataType.Int)
    S34 = fd.define_scalar(128, dtype=DataType.Int)
    S35 = fd.define_scalar(1600, dtype=DataType.Int)
    V36 = fd.define_vector([S33, S34, S35], dtype=DataType.Int)
    T37 = fd.ops.broadcast_in_dim(T3, shape=V36, broadcast_dims=[2])
    T38 = fd.ops.cast(T37, dtype=DataType.Float)
    T39 = fd.ops.mul(T32, T38)
    S40 = fd.define_scalar(16, dtype=DataType.Int)
    S41 = fd.define_scalar(128, dtype=DataType.Int)
    S42 = fd.define_scalar(1600, dtype=DataType.Int)
    V43 = fd.define_vector([S40, S41, S42], dtype=DataType.Int)
    T44 = fd.ops.broadcast_in_dim(T2, shape=V43, broadcast_dims=[2])
    T45 = fd.ops.cast(T44, dtype=DataType.Float)
    T46 = fd.ops.add(T39, T45)
    T47 = fd.ops.cast(T46, dtype=DataType.BFloat16)
    S48 = fd.define_scalar(2048, dtype=DataType.Int)
    S49 = fd.define_scalar(1600, dtype=DataType.Int)
    V50 = fd.define_vector([S48, S49], dtype=DataType.Int)
    T51 = fd.ops.reshape(T47, new_shape=V50)
    T52 = fd.ops.linear(T51, T1, T0)
    S53 = fd.define_scalar(16, dtype=DataType.Int)
    S54 = fd.define_scalar(128, dtype=DataType.Int)
    S55 = fd.define_scalar(4800, dtype=DataType.Int)
    V56 = fd.define_vector([S53, S54, S55], dtype=DataType.Int)
    T57 = fd.ops.reshape(T52, new_shape=V56)
    T58 = fd.ops.slice(T57, start_indices=[0, 0, 0], end_indices=[16, 128, 1600], strides=[1, 1, 1])
    T59 = fd.ops.slice(T57, start_indices=[0, 0, 1600], end_indices=[16, 128, 3200], strides=[1, 1, 1])
    T60 = fd.ops.slice(T57, start_indices=[0, 0, 3200], end_indices=[16, 128, 4800], strides=[1, 1, 1])
    S61 = fd.define_scalar(16, dtype=DataType.Int)
    S62 = fd.define_scalar(128, dtype=DataType.Int)
    S63 = fd.define_scalar(25, dtype=DataType.Int)
    S64 = fd.define_scalar(64, dtype=DataType.Int)
    V65 = fd.define_vector([S61, S62, S63, S64], dtype=DataType.Int)
    T66 = fd.ops.reshape(T59, new_shape=V65)
    T67 = fd.ops.permute(T66, dims=[0, 2, 1, 3])
    S68 = fd.define_scalar(16, dtype=DataType.Int)
    S69 = fd.define_scalar(128, dtype=DataType.Int)
    S70 = fd.define_scalar(25, dtype=DataType.Int)
    S71 = fd.define_scalar(64, dtype=DataType.Int)
    V72 = fd.define_vector([S68, S69, S70, S71], dtype=DataType.Int)
    T73 = fd.ops.reshape(T58, new_shape=V72)
    T74 = fd.ops.permute(T73, dims=[0, 2, 1, 3])
    S75 = fd.define_scalar(16, dtype=DataType.Int)
    S76 = fd.define_scalar(128, dtype=DataType.Int)
    S77 = fd.define_scalar(25, dtype=DataType.Int)
    S78 = fd.define_scalar(64, dtype=DataType.Int)
    V79 = fd.define_vector([S75, S76, S77, S78], dtype=DataType.Int)
    T80 = fd.ops.reshape(T60, new_shape=V79)
    T81 = fd.ops.permute(T80, dims=[0, 2, 1, 3])
    fd.add_output(T74)
    fd.add_output(T67)
    fd.add_output(T81)

[W509 16:47:44.547956141 matmul_utils.cpp:386] Warning: Scheduling a matmul without heuristic plugin. Specify plugin location like this: NVFUSER_MATMUL_HEURISTIC_PLUGIN=/path/to/libmatmulheuristic.so (function operator())

def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T5 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T6 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T7 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T8 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T9 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T10 = fd.ops.permute(T9, dims=[0, 2, 1, 3])
    T11 = fd.ops.stride_order(T10, stride_order=[3, 2, 1, 0])
    S12 = fd.define_scalar(16, dtype=DataType.Int)
    S13 = fd.define_scalar(128, dtype=DataType.Int)
    S14 = fd.define_scalar(1600, dtype=DataType.Int)
    V15 = fd.define_vector([S12, S13, S14], dtype=DataType.Int)
    T16 = fd.ops.reshape(T11, new_shape=V15)
    S17 = fd.define_scalar(2048, dtype=DataType.Int)
    S18 = fd.define_scalar(1600, dtype=DataType.Int)
    V19 = fd.define_vector([S17, S18], dtype=DataType.Int)
    T20 = fd.ops.reshape(T16, new_shape=V19)
    T21 = fd.ops.linear(T20, T1, T0)
    S22 = fd.define_scalar(16, dtype=DataType.Int)
    S23 = fd.define_scalar(128, dtype=DataType.Int)
    S24 = fd.define_scalar(1600, dtype=DataType.Int)
    V25 = fd.define_vector([S22, S23, S24], dtype=DataType.Int)
    T26 = fd.ops.reshape(T21, new_shape=V25)
    S27 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S28 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S29 = fd.define_scalar(16, dtype=DataType.Int)
    S30 = fd.define_scalar(128, dtype=DataType.Int)
    S31 = fd.define_scalar(1600, dtype=DataType.Int)
    V32 = fd.define_vector([S29, S30, S31], dtype=DataType.Int)
    T33 = fd.ops.uniform(S27, S28, shape=V32, dtype=DataType.BFloat16)
    S34 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T35 = fd.ops.lt(T33, S34)
    T36 = fd.ops.cast(T26, dtype=DataType.Float)
    T37 = fd.ops.cast(T35, dtype=DataType.Float)
    T38 = fd.ops.mul(T36, T37)
    S39 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T40 = fd.ops.mul(T38, S39)
    T41 = fd.ops.cast(T8, dtype=DataType.Float)
    T42 = fd.ops.add(T41, T40)
    T43, T44 = fd.ops.var_mean(T42, dims=[2], correction=0, keepdim=False)
    S45 = fd.define_scalar(16, dtype=DataType.Int)
    S46 = fd.define_scalar(128, dtype=DataType.Int)
    S47 = fd.define_scalar(1, dtype=DataType.Int)
    V48 = fd.define_vector([S45, S46, S47], dtype=DataType.Int)
    T49 = fd.ops.broadcast_in_dim(T43, shape=V48, broadcast_dims=[0, 1])
    S50 = fd.define_scalar(16, dtype=DataType.Int)
    S51 = fd.define_scalar(128, dtype=DataType.Int)
    S52 = fd.define_scalar(1, dtype=DataType.Int)
    V53 = fd.define_vector([S50, S51, S52], dtype=DataType.Int)
    T54 = fd.ops.broadcast_in_dim(T44, shape=V53, broadcast_dims=[0, 1])
    S55 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T56 = fd.ops.add(T49, S55)
    T57 = fd.ops.rsqrt(T56)
    S58 = fd.define_scalar(16, dtype=DataType.Int)
    S59 = fd.define_scalar(128, dtype=DataType.Int)
    S60 = fd.define_scalar(1600, dtype=DataType.Int)
    V61 = fd.define_vector([S58, S59, S60], dtype=DataType.Int)
    T62 = fd.ops.broadcast_in_dim(T54, shape=V61, broadcast_dims=[0, 1, 2])
    T63 = fd.ops.sub(T42, T62)
    S64 = fd.define_scalar(16, dtype=DataType.Int)
    S65 = fd.define_scalar(128, dtype=DataType.Int)
    S66 = fd.define_scalar(1600, dtype=DataType.Int)
    V67 = fd.define_vector([S64, S65, S66], dtype=DataType.Int)
    T68 = fd.ops.broadcast_in_dim(T57, shape=V67, broadcast_dims=[0, 1, 2])
    T69 = fd.ops.mul(T63, T68)
    S70 = fd.define_scalar(16, dtype=DataType.Int)
    S71 = fd.define_scalar(128, dtype=DataType.Int)
    S72 = fd.define_scalar(1600, dtype=DataType.Int)
    V73 = fd.define_vector([S70, S71, S72], dtype=DataType.Int)
    T74 = fd.ops.broadcast_in_dim(T3, shape=V73, broadcast_dims=[2])
    T75 = fd.ops.cast(T74, dtype=DataType.Float)
    T76 = fd.ops.mul(T69, T75)
    S77 = fd.define_scalar(16, dtype=DataType.Int)
    S78 = fd.define_scalar(128, dtype=DataType.Int)
    S79 = fd.define_scalar(1600, dtype=DataType.Int)
    V80 = fd.define_vector([S77, S78, S79], dtype=DataType.Int)
    T81 = fd.ops.broadcast_in_dim(T2, shape=V80, broadcast_dims=[2])
    T82 = fd.ops.cast(T81, dtype=DataType.Float)
    T83 = fd.ops.add(T76, T82)
    T84 = fd.ops.cast(T83, dtype=DataType.BFloat16)
    S85 = fd.define_scalar(2048, dtype=DataType.Int)
    S86 = fd.define_scalar(1600, dtype=DataType.Int)
    V87 = fd.define_vector([S85, S86], dtype=DataType.Int)
    T88 = fd.ops.reshape(T84, new_shape=V87)
    T89 = fd.ops.linear(T88, T5, T4)
    S90 = fd.define_scalar(16, dtype=DataType.Int)
    S91 = fd.define_scalar(128, dtype=DataType.Int)
    S92 = fd.define_scalar(6400, dtype=DataType.Int)
    V93 = fd.define_vector([S90, S91, S92], dtype=DataType.Int)
    T94 = fd.ops.reshape(T89, new_shape=V93)
    T95 = fd.ops.cast(T94, dtype=DataType.Float)
    T96 = fd.ops.mul(T95, T95)
    T97 = fd.ops.mul(T96, T95)
    S98 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T99 = fd.ops.mul(S98, T95)
    S100 = fd.define_scalar(0.0447150, dtype=DataType.Double)
    T101 = fd.ops.mul(S100, T97)
    T102 = fd.ops.add(T95, T101)
    S103 = fd.define_scalar(0.797885, dtype=DataType.Double)
    T104 = fd.ops.mul(S103, T102)
    T105 = fd.ops.tanh(T104)
    S106 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T107 = fd.ops.add(S106, T105)
    T108 = fd.ops.mul(T99, T107)
    T109 = fd.ops.cast(T108, dtype=DataType.BFloat16)
    S110 = fd.define_scalar(2048, dtype=DataType.Int)
    S111 = fd.define_scalar(6400, dtype=DataType.Int)
    V112 = fd.define_vector([S110, S111], dtype=DataType.Int)
    T113 = fd.ops.reshape(T109, new_shape=V112)
    T114 = fd.ops.linear(T113, T7, T6)
    S115 = fd.define_scalar(16, dtype=DataType.Int)
    S116 = fd.define_scalar(128, dtype=DataType.Int)
    S117 = fd.define_scalar(1600, dtype=DataType.Int)
    V118 = fd.define_vector([S115, S116, S117], dtype=DataType.Int)
    T119 = fd.ops.reshape(T114, new_shape=V118)
    S120 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S121 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S122 = fd.define_scalar(16, dtype=DataType.Int)
    S123 = fd.define_scalar(128, dtype=DataType.Int)
    S124 = fd.define_scalar(1600, dtype=DataType.Int)
    V125 = fd.define_vector([S122, S123, S124], dtype=DataType.Int)
    T126 = fd.ops.uniform(S120, S121, shape=V125, dtype=DataType.BFloat16)
    S127 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T128 = fd.ops.lt(T126, S127)
    T129 = fd.ops.cast(T119, dtype=DataType.Float)
    T130 = fd.ops.cast(T128, dtype=DataType.Float)
    T131 = fd.ops.mul(T129, T130)
    S132 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T133 = fd.ops.mul(T131, S132)
    T134 = fd.ops.add(T42, T133)
    T135 = fd.ops.cast(T134, dtype=DataType.BFloat16)
    fd.add_output(T135)
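To make fusion id1 easier to read, here is a NumPy reference sketch of what it appears to compute: the attention output projection, dropout (keep where uniform < 0.9, then scale by 1/0.9 = 1.11111), a residual add, layernorm (biased variance, eps 1e-5), the MLP up-projection to 6400, tanh-approximated GeLU, the down-projection, a second dropout, and a second residual add. Assumptions: T9 is the [B, H, T, D] attention output, T8 is the block input, and the keys in params are made up for illustration. The sketch runs in float64 rather than bf16/float32 and uses NumPy's RNG, so it will not reproduce nvFuser's numbers bit for bit.

```python
import numpy as np


def linear(x, w, b):
    # F.linear semantics (what fd.ops.linear follows): x @ w.T + b, with w stored as [out, in].
    return x @ w.T + b


def dropout(x, p_keep, rng):
    # Same structure as the fusion: keep where uniform(0, 1) < p_keep (0.9 above),
    # then scale by 1 / p_keep (1.11111 above). Random values will NOT match nvFuser's RNG.
    keep = rng.uniform(0.0, 1.0, size=x.shape) < p_keep
    return x * keep / p_keep


def layer_norm(x, w, b, eps=1e-5):
    # var_mean over the last dim with correction=0, i.e. biased variance (ddof=0).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * w + b


def gelu_tanh(x):
    # Tanh approximation of GeLU; the 0.797885 constant in the fusion is sqrt(2 / pi).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def fusion_id1_reference(attn, residual, params, p_keep=0.9, rng=None):
    """attn: [B, H, T, D] attention output (assumed role of T9).
    residual: [B, T, E] block input (assumed role of T8)."""
    if rng is None:
        rng = np.random.default_rng(0)
    B, H, T, D = attn.shape
    E = H * D
    # permute [0, 2, 1, 3] then reshape: [B, H, T, D] -> [B, T, H, D] -> [B * T, E]
    x = attn.transpose(0, 2, 1, 3).reshape(B * T, E)
    y = linear(x, params["w_proj"], params["b_proj"]).reshape(B, T, E)
    h = residual + dropout(y, p_keep, rng)  # first residual add (T42)
    z = layer_norm(h, params["ln_w"], params["ln_b"]).reshape(B * T, E)
    z = gelu_tanh(linear(z, params["w_fc"], params["b_fc"]))  # [B * T, 4E]
    z = linear(z, params["w_out"], params["b_out"]).reshape(B, T, E)
    return h + dropout(z, p_keep, rng)  # second residual add (T134)
```

With dropout disabled (p_keep=1.0) and the two projection weights and biases zeroed, the block reduces to the identity on the residual, which makes a cheap sanity check.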
Priya2698 commented 2 months ago

> > At present, we assume the same in pattern matching as well (BroadcastOp -> MmaOp -> CastOp).
>
> That's right. I already merged #2221, so you can reproduce this by running the reproducer in #2199 (comment). Anyhow, I'll run my experiments with matmul_expr_eval disabled, so #2221 is sufficient to unblock me at this moment.
>
> That being said, there's a variation of the same problem for ATen evaluation: the segmenter doesn't guarantee that the cast lands in the same segment, just as it didn't for the broadcast. I think the new IR nodes will help with that, but I'm not sure and I'll leave that to you.

Yes, the new IR nodes will fix this issue since we won't evaluate a decomposed IR. The pattern matching will be redundant and removed once the API is modified to use the new IR nodes.

wujingyue commented 2 months ago

@Priya2698 Wdyt about the drafted WAR for nvFuser-not-support-3D-linear? Submit that WAR in Thunder or wait for your PRs?

Priya2698 commented 2 months ago

> @Priya2698 Wdyt about the drafted WAR for nvFuser-not-support-3D-linear? Submit that WAR in Thunder or wait for your PRs?

It looks like the WAR will still run into the segmentation issue due to the reshapes.

If you don't necessarily need that change in Thunder to proceed, adding the new nodes will lift that restriction anyway. I expect to have the new PRs up within a couple of days, early next week.

We can go ahead with it if it unblocks you in the interim.
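The 2D-only linear restriction can be worked around by flattening the leading dimensions, which is the reshape -> linear -> reshape pattern visible in the fusions above ([16, 128, 1600] -> [2048, 1600] -> linear -> [16, 128, out]). Those extra reshapes are also what the segmenter can separate from the linear. Below is a minimal NumPy sketch of the idea; it is not the code from the closed Thunder PR, and linear_flattened is a hypothetical name.

```python
import numpy as np


def linear_flattened(x, w, b):
    """F.linear on an [..., in] input using only a 2D matmul.

    w is stored as [out, in] and b as [out], following F.linear.
    """
    lead = x.shape[:-1]
    x2d = x.reshape(-1, x.shape[-1])  # collapse all leading dims into one, e.g. [16, 128, 1600] -> [2048, 1600]
    y2d = x2d @ w.T + b               # plain 2D linear: [N, in] @ [in, out] -> [N, out]
    return y2d.reshape(*lead, w.shape[0])  # restore the leading dims
```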

wujingyue commented 2 months ago

Cool -- I closed https://github.com/Lightning-AI/lightning-thunder/pull/391.

wujingyue commented 2 months ago

@cowanmeg https://github.com/Lightning-AI/lightning-thunder/commit/bf84b04576a13c56df586ea574a328ad85133de7 checked in the forward pass of a single-device transformer block, modulo SDPA. See the message of that commit for more details. With that, we should be able to work on this in parallel: I'll try to include SDPA and backprop, and you'll try to build a sharded version. How does that sound?

cowanmeg commented 2 months ago

Thanks @wujingyue! This is super helpful, I'll start working on the sharding soon!

cowanmeg commented 2 months ago

I annotated the sharding of the MLP layer of the example: https://gist.github.com/cowanmeg/75b4144a3627df74efcfc12dda01a2a3

Some comments:
(1) The two linear layers and GeLU have sharded computation, while the dropout, layernorm, and residual add are computed redundantly (replicated) on each device. (BTW, I don't think it would be too hard to represent SP.) Sharding propagation is relatively straightforward if we annotate only the inputs and outputs of the Linear layers, so I think the current naive propagation will suffice, at least for now.
(2) Now that LinearOp and MatmulOp are part of the compute definition, we should reconsider how we insert resharding expressions and DID leaf parallelization. (cc @Priya2698 @jacobhinkle)
(3) The pointwise scheduler needs to be modified to ignore DID axes. This should be as simple as reordering the DID axes to the front and ignoring them.

While we discuss our design for (2), I will manually translate these programs and decompose the LinearOp myself. This is necessary regardless, because our RFactor restriction requires us to logically split sharded axes in the compute definition. For MLP, this isn't too hard and would let us get a small example working.
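The Megatron-style MLP sharding described above (both linears and the GeLU sharded along the 4x hidden dimension, everything else replicated) can be simulated in a single process to check that it reproduces the unsharded result. This is a NumPy sketch, not nvFuser's DID API: the first weight is split along its output features (column parallel), the second along its input features (row parallel), and the final sum over devices stands in for the allreduce. The reshape into [num_devices, ffn / num_devices] mirrors the logical split of the sharded axis. Function names are hypothetical.

```python
import numpy as np


def gelu_tanh(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def mlp_reference(x, w_fc, b_fc, w_proj, b_proj):
    # Unsharded MLP: linear -> GeLU -> linear, weights stored as [out, in].
    return gelu_tanh(x @ w_fc.T + b_fc) @ w_proj.T + b_proj


def mlp_tensor_parallel(x, w_fc, b_fc, w_proj, b_proj, num_devices):
    ffn = w_fc.shape[0]
    if ffn % num_devices != 0:
        raise ValueError("the MLP hidden size must be divisible by num_devices")
    s = ffn // num_devices
    # Logically split the sharded axis: [ffn] -> [num_devices, ffn / num_devices].
    w_fc_shards = w_fc.reshape(num_devices, s, w_fc.shape[1])  # column parallel: [D, s, E]
    b_fc_shards = b_fc.reshape(num_devices, s)                 # bias follows the output split
    # Row parallel: split w_proj's input features, giving [D, E, s].
    w_proj_shards = w_proj.reshape(w_proj.shape[0], num_devices, s).transpose(1, 0, 2)
    partial_outputs = []
    for d in range(num_devices):
        # Device d holds the full (replicated) x but only its slice of the hidden units.
        # GeLU is elementwise, so it needs no communication.
        h_d = gelu_tanh(x @ w_fc_shards[d].T + b_fc_shards[d])  # [N, s]
        partial_outputs.append(h_d @ w_proj_shards[d].T)        # [N, E] partial sum
    # Simulated allreduce: sum the partial outputs, then add the output bias once.
    return np.sum(partial_outputs, axis=0) + b_proj
```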

wujingyue commented 1 month ago

FYI, https://github.com/Lightning-AI/lightning-thunder/commit/af6bfc10d38bba449a5745e9e09d11359d47feca added the forward pass of the whole transformer block (i.e., with SDPA). Caveat: the speed is probably far from SOL because nvFuser can't fuse matmul+softmax+matmul at the moment. https://github.com/NVIDIA/Fuser/issues/2278 is going to add an SDPA IR node so we can fall back to the existing flash attention implementation in ATen. When that's done, the fusion definition will show just the SDPA node instead of the decomposed form.
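For context, the decomposed form is roughly the chain matmul -> scale -> causal mask -> softmax -> matmul that a single SDPA node would replace. A NumPy sketch, assuming [B, H, T, D] inputs and a causal mask as in nanoGPT; attention dropout is omitted and causal_attention_decomposed is a hypothetical name.

```python
import numpy as np


def causal_attention_decomposed(q, k, v):
    """q, k, v: [B, H, T, D] -> [B, H, T, D]."""
    T, D = q.shape[-2], q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(D)      # first matmul: [B, H, T, T]
    causal = np.tril(np.ones((T, T), dtype=bool))         # query i may only see keys 0..i
    scores = np.where(causal, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(scores)                                # exp(-inf) == 0 for masked keys
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v                                      # second matmul: [B, H, T, D]
```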

wujingyue commented 1 month ago

https://github.com/Lightning-AI/lightning-thunder/commit/b06bf4edeff9b25a5411b38ad3257d142bff803b adds the backprop. It's hard to verify because https://github.com/Lightning-AI/lightning-thunder/blob/bc3925be04e7ec58d9b24fb7ac55fbe007862a65/thunder/core/rematerialization.py#L569 mixes in some ops from the forward pass. However, when I print the backprop trace before rematerialization (see below), I do see 12 prims.matmuls, which looks right: there are 4 linear layers and 2 matmuls in the forward pass, and each of them becomes 2 matmuls in backprop, one per input gradient.

import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  t27, = cotangents
  t0, t4, t8, t11, t12, t13, t15, t20, t_attn_c_attn_weight, t42, t45, t49, t59, \
  t71, t75, t80, t84, t_attn_c_proj_weight, t92, t100, t108, t113, t116, t117, \
  t118, t120, t125, t_mlp_c_fc_weight, t127, t127, t127, t129, t149, t152, t136, \
  t157, t_mlp_c_proj_weight, t168, = C0
  i6, i9, f43, f47, i54, f62, f78, i85, f89, f91, f93, f102, = C1
  i622 = prims.sub(1600, i85)  # i622: "int 1600"
  i807 = prims.sub(1600, i6)  # i807: "int 1600"
  [t524, t527, t588, t591, t596, t602, t661, t664, t773, t776, t781, t787, t827] = nvFusion0(f102, f43, f47, f62, f78, f89, f91, f93, i54, i622, i807, i9, t0, t100, t108, t11, t113, t116, t117, t118, t12, t120, t125, t127, t129, t13, t136, t149, t15, t152, t157, t168, t20, t27, t4, t42, t45, t49, t59, t71, t75, t8, t80, t84, t92, t_attn_c_attn_weight, t_attn_c_proj_weight, t_mlp_c_fc_weight, t_mlp_c_proj_weight)
    # t506 = prims.convert_element_type(t27, dtypes.float32)  # t506: "cuda:0 f32[16, 128, 1600]"
    # t511 = prims.mul(f102, t506)  # t511: "cuda:0 f32[16, 128, 1600]"
    # t514 = prims.mul(t168, t511)  # t514: "cuda:0 f32[16, 128, 1600]"
    # t517 = prims.convert_element_type(t514, dtypes.bfloat16)  # t517: "cuda:0 bf16[16, 128, 1600]"
    # t518 = prims.reshape(t517, (2048, 1600))  # t518: "cuda:0 bf16[2048, 1600]"
    # t519 = prims.matmul(t518, t_mlp_c_proj_weight)  # t519: "cuda:0 bf16[2048, 6400]"
    # t520 = prims.reshape(t519, (16, 128, 6400))  # t520: "cuda:0 bf16[16, 128, 6400]"
    # t522 = prims.transpose(t518, (1, 0))  # t522: "cuda:0 bf16[1600, 2048]"
    # t523 = prims.reshape(t157, (2048, 6400))  # t523: "cuda:0 bf16[2048, 6400]"
    # t524 = prims.matmul(t522, t523)  # t524: "cuda:0 bf16[1600, 6400]"
    # t526 = prims.sum(t514, (0, 1))  # t526: "cuda:0 f32[1600]"
    # t527 = prims.convert_element_type(t526, dtypes.bfloat16)  # t527: "cuda:0 bf16[1600]"
    # t528 = prims.convert_element_type(t520, dtypes.float32)  # t528: "cuda:0 f32[16, 128, 6400]"
    # t529 = prims.mul(t152, t528)  # t529: "cuda:0 f32[16, 128, 6400]"
    # t530 = prims.mul(t136, t528)  # t530: "cuda:0 f32[16, 128, 6400]"
    # t537 = prims.mul(t149, t149)  # t537: "cuda:0 f32[16, 128, 6400]"
    # t538 = prims.sub(1.0, t537)  # t538: "cuda:0 f32[16, 128, 6400]"
    # t539 = prims.mul(t530, t538)  # t539: "cuda:0 f32[16, 128, 6400]"
    # t543 = prims.mul(f93, t539)  # t543: "cuda:0 f32[16, 128, 6400]"
    # t550 = prims.mul(f91, t543)  # t550: "cuda:0 f32[16, 128, 6400]"
    # t554 = prims.mul(f89, t529)  # t554: "cuda:0 f32[16, 128, 6400]"
    # t558 = prims.add(t543, t554)  # t558: "cuda:0 f32[16, 128, 6400]"
    # t561 = prims.mul(t127, t550)  # t561: "cuda:0 f32[16, 128, 6400]"
    # t562 = prims.mul(t129, t550)  # t562: "cuda:0 f32[16, 128, 6400]"
    # t567 = prims.add(t558, t562)  # t567: "cuda:0 f32[16, 128, 6400]"
    # t570 = prims.mul(t127, t561)  # t570: "cuda:0 f32[16, 128, 6400]"
    # t576 = prims.add(t567, t570)  # t576: "cuda:0 f32[16, 128, 6400]"
    # t580 = prims.add(t576, t570)  # t580: "cuda:0 f32[16, 128, 6400]"
    # t581 = prims.convert_element_type(t580, dtypes.bfloat16)  # t581: "cuda:0 bf16[16, 128, 6400]"
    # t582 = prims.reshape(t581, (2048, 6400))  # t582: "cuda:0 bf16[2048, 6400]"
    # t583 = prims.matmul(t582, t_mlp_c_fc_weight)  # t583: "cuda:0 bf16[2048, 1600]"
    # t584 = prims.reshape(t583, (16, 128, 1600))  # t584: "cuda:0 bf16[16, 128, 1600]"
    # t586 = prims.transpose(t582, (1, 0))  # t586: "cuda:0 bf16[6400, 2048]"
    # t587 = prims.reshape(t125, (2048, 1600))  # t587: "cuda:0 bf16[2048, 1600]"
    # t588 = prims.matmul(t586, t587)  # t588: "cuda:0 bf16[6400, 1600]"
    # t590 = prims.sum(t580, (0, 1))  # t590: "cuda:0 f32[6400]"
    # t591 = prims.convert_element_type(t590, dtypes.bfloat16)  # t591: "cuda:0 bf16[6400]"
    # t592 = prims.convert_element_type(t584, dtypes.float32)  # t592: "cuda:0 f32[16, 128, 1600]"
    # t595 = prims.sum(t592, (0, 1))  # t595: "cuda:0 f32[1600]"
    # t596 = prims.convert_element_type(t595, dtypes.bfloat16)  # t596: "cuda:0 bf16[1600]"
    # t597 = prims.mul(t120, t592)  # t597: "cuda:0 f32[16, 128, 1600]"
    # t598 = prims.mul(t118, t592)  # t598: "cuda:0 f32[16, 128, 1600]"
    # t601 = prims.sum(t598, (0, 1))  # t601: "cuda:0 f32[1600]"
    # t602 = prims.convert_element_type(t601, dtypes.bfloat16)  # t602: "cuda:0 bf16[1600]"
    # t603 = prims.mul(t117, t597)  # t603: "cuda:0 f32[16, 128, 1600]"
    # t604 = prims.mul(t116, t597)  # t604: "cuda:0 f32[16, 128, 1600]"
    # t605 = prims.sum(t604, (2,))  # t605: "cuda:0 f32[16, 128]"
    # t606 = prims.broadcast_in_dim(t605, [16, 128, 1], [0, 1])  # t606: "cuda:0 f32[16, 128, 1]"
    # t607 = prims.neg(t603)  # t607: "cuda:0 f32[16, 128, 1600]"
    # t609 = prims.sum(t607, (2,))  # t609: "cuda:0 f32[16, 128]"
    # t610 = prims.broadcast_in_dim(t609, [16, 128, 1], [0, 1])  # t610: "cuda:0 f32[16, 128, 1]"
    # t611 = prims.mul(-0.5, t606)  # t611: "cuda:0 f32[16, 128, 1]"
    # t612 = prims.pow(t113, 3.0)  # t612: "cuda:0 f32[16, 128, 1]"
    # t613 = prims.mul(t611, t612)  # t613: "cuda:0 f32[16, 128, 1]"
    # t615 = prims.sum(t610, (2,))  # t615: "cuda:0 f32[16, 128]"
    # t616 = prims.sum(t613, (2,))  # t616: "cuda:0 f32[16, 128]"
    # t619 = prims.broadcast_in_dim(t615, [16, 128, 1], [0, 1])  # t619: "cuda:0 f32[16, 128, 1]"
    # t620 = prims.broadcast_in_dim(t619, (16, 128, 1600), (0, 1, 2))  # t620: "cuda:0 f32[16, 128, 1600]"
    # t621 = prims.mul(0.000625, t620)  # t621: "cuda:0 f32[16, 128, 1600]"
    # t623 = prims.broadcast_in_dim(t616, [16, 128, 1], [0, 1])  # t623: "cuda:0 f32[16, 128, 1]"
    # t624 = prims.broadcast_in_dim(t623, (16, 128, 1600), (0, 1, 2))  # t624: "cuda:0 f32[16, 128, 1600]"
    # t626 = prims.broadcast_in_dim(t108, [16, 128, 1], [0, 1])  # t626: "cuda:0 f32[16, 128, 1]"
    # t627 = prims.broadcast_in_dim(t626, (16, 128, 1600), (0, 1, 2))  # t627: "cuda:0 f32[16, 128, 1600]"
    # t628 = prims.mul(2.0, t624)  # t628: "cuda:0 f32[16, 128, 1600]"
    # t629 = prims.sub(t100, t627)  # t629: "cuda:0 f32[16, 128, 1600]"
    # t630 = prims.mul(t628, t629)  # t630: "cuda:0 f32[16, 128, 1600]"
    # f631 = prims.convert_element_type(i622, float)  # f631: "float 1600.0"
    # t632 = prims.div(t630, f631)  # t632: "cuda:0 f32[16, 128, 1600]"
    # t633 = prims.add(t621, t632)  # t633: "cuda:0 f32[16, 128, 1600]"
    # t637 = prims.add(t603, t633)  # t637: "cuda:0 f32[16, 128, 1600]"
    # t641 = prims.add(t506, t637)  # t641: "cuda:0 f32[16, 128, 1600]"
    # t648 = prims.mul(f78, t641)  # t648: "cuda:0 f32[16, 128, 1600]"
    # t651 = prims.mul(t92, t648)  # t651: "cuda:0 f32[16, 128, 1600]"
    # t654 = prims.convert_element_type(t651, dtypes.bfloat16)  # t654: "cuda:0 bf16[16, 128, 1600]"
    # t655 = prims.reshape(t654, (2048, 1600))  # t655: "cuda:0 bf16[2048, 1600]"
    # t656 = prims.matmul(t655, t_attn_c_proj_weight)  # t656: "cuda:0 bf16[2048, 1600]"
    # t657 = prims.reshape(t656, (16, 128, 1600))  # t657: "cuda:0 bf16[16, 128, 1600]"
    # t659 = prims.transpose(t655, (1, 0))  # t659: "cuda:0 bf16[1600, 2048]"
    # t660 = prims.reshape(t84, (2048, 1600))  # t660: "cuda:0 bf16[2048, 1600]"
    # t661 = prims.matmul(t659, t660)  # t661: "cuda:0 bf16[1600, 1600]"
    # t663 = prims.sum(t651, (0, 1))  # t663: "cuda:0 f32[1600]"
    # t664 = prims.convert_element_type(t663, dtypes.bfloat16)  # t664: "cuda:0 bf16[1600]"
    # t668 = prims.reshape(t657, (16, 128, 25, 64))  # t668: "cuda:0 bf16[16, 128, 25, 64]"
    # t671 = prims.transpose(t668, (0, 2, 1, 3))  # t671: "cuda:0 bf16[16, 25, 128, 64]"
    # t672 = prims.transpose(t42, (0, 1, 3, 2))  # t672: "cuda:0 bf16[16, 25, 64, 128]"
    # t673 = prims.matmul(t671, t672)  # t673: "cuda:0 bf16[16, 25, 128, 128]"
    # t674 = prims.transpose(t80, (0, 1, 3, 2))  # t674: "cuda:0 bf16[16, 25, 128, 128]"
    # t675 = prims.matmul(t674, t671)  # t675: "cuda:0 bf16[16, 25, 128, 64]"
    # t676 = prims.convert_element_type(t673, dtypes.float32)  # t676: "cuda:0 f32[16, 25, 128, 128]"
    # t678 = prims.mul(f62, t676)  # t678: "cuda:0 f32[16, 25, 128, 128]"
    # t681 = prims.mul(t75, t678)  # t681: "cuda:0 f32[16, 25, 128, 128]"
    # t685 = prims.convert_element_type(t71, dtypes.float32)  # t685: "cuda:0 f32[16, 25, 128, 128]"
    # t687 = prims.mul(t685, t681)  # t687: "cuda:0 f32[16, 25, 128, 128]"
    # i691 = prims.add(i54, 4)  # i691: "int 3"
    # t701 = prims.sum(t687, (i691,))  # t701: "cuda:0 f32[16, 25, 128]"
    # t710 = prims.broadcast_in_dim(t701, [16, 25, 128, 1], [0, 1, 2])  # t710: "cuda:0 f32[16, 25, 128, 1]"
    # t711 = prims.convert_element_type(t710, dtypes.bfloat16)  # t711: "cuda:0 bf16[16, 25, 128, 1]"
    # t712 = prims.broadcast_in_dim(t711, (16, 25, 128, 128), (0, 1, 2, 3))  # t712: "cuda:0 bf16[16, 25, 128, 128]"
    # t714 = prims.convert_element_type(t712, dtypes.float32)  # t714: "cuda:0 f32[16, 25, 128, 128]"
    # t715 = prims.sub(t681, t714)  # t715: "cuda:0 f32[16, 25, 128, 128]"
    # t719 = prims.mul(t685, t715)  # t719: "cuda:0 f32[16, 25, 128, 128]"
    # t720 = prims.convert_element_type(t719, dtypes.bfloat16)  # t720: "cuda:0 bf16[16, 25, 128, 128]"
    # t722 = prims.where(t59, t720, 0.0)  # t722: "cuda:0 bf16[16, 25, 128, 128]"
    # t723 = prims.transpose(t49, (0, 1, 3, 2))  # t723: "cuda:0 bf16[16, 25, 128, 64]"
    # t724 = prims.matmul(t722, t723)  # t724: "cuda:0 bf16[16, 25, 128, 64]"
    # t725 = prims.transpose(t45, (0, 1, 3, 2))  # t725: "cuda:0 bf16[16, 25, 64, 128]"
    # t726 = prims.matmul(t725, t722)  # t726: "cuda:0 bf16[16, 25, 64, 128]"
    # t727 = prims.convert_element_type(t726, dtypes.float32)  # t727: "cuda:0 f32[16, 25, 64, 128]"
    # t729 = prims.mul(f47, t727)  # t729: "cuda:0 f32[16, 25, 64, 128]"
    # t730 = prims.convert_element_type(t729, dtypes.bfloat16)  # t730: "cuda:0 bf16[16, 25, 64, 128]"
    # t733 = prims.transpose(t730, (0, 1, 3, 2))  # t733: "cuda:0 bf16[16, 25, 128, 64]"
    # t734 = prims.convert_element_type(t724, dtypes.float32)  # t734: "cuda:0 f32[16, 25, 128, 64]"
    # t736 = prims.mul(f43, t734)  # t736: "cuda:0 f32[16, 25, 128, 64]"
    # t737 = prims.convert_element_type(t736, dtypes.bfloat16)  # t737: "cuda:0 bf16[16, 25, 128, 64]"
    # t740 = prims.transpose(t675, (0, 2, 1, 3))  # t740: "cuda:0 bf16[16, 128, 25, 64]"
    # t745 = prims.reshape(t740, (16, 128, 1600))  # t745: "cuda:0 bf16[16, 128, 1600]"
    # t748 = prims.transpose(t737, (0, 2, 1, 3))  # t748: "cuda:0 bf16[16, 128, 25, 64]"
    # t753 = prims.reshape(t748, (16, 128, 1600))  # t753: "cuda:0 bf16[16, 128, 1600]"
    # t756 = prims.transpose(t733, (0, 2, 1, 3))  # t756: "cuda:0 bf16[16, 128, 25, 64]"
    # t761 = prims.reshape(t756, (16, 128, 1600))  # t761: "cuda:0 bf16[16, 128, 1600]"
    # t766 = prims.cat((t753, t761, t745), i9)  # t766: "cuda:0 bf16[16, 128, 4800]"
    # t767 = prims.reshape(t766, (2048, 4800))  # t767: "cuda:0 bf16[2048, 4800]"
    # t768 = prims.matmul(t767, t_attn_c_attn_weight)  # t768: "cuda:0 bf16[2048, 1600]"
    # t769 = prims.reshape(t768, (16, 128, 1600))  # t769: "cuda:0 bf16[16, 128, 1600]"
    # t771 = prims.transpose(t767, (1, 0))  # t771: "cuda:0 bf16[4800, 2048]"
    # t772 = prims.reshape(t20, (2048, 1600))  # t772: "cuda:0 bf16[2048, 1600]"
    # t773 = prims.matmul(t771, t772)  # t773: "cuda:0 bf16[4800, 1600]"
    # t774 = prims.convert_element_type(t766, dtypes.float32)  # t774: "cuda:0 f32[16, 128, 4800]"
    # t775 = prims.sum(t774, (0, 1))  # t775: "cuda:0 f32[4800]"
    # t776 = prims.convert_element_type(t775, dtypes.bfloat16)  # t776: "cuda:0 bf16[4800]"
    # t777 = prims.convert_element_type(t769, dtypes.float32)  # t777: "cuda:0 f32[16, 128, 1600]"
    # t780 = prims.sum(t777, (0, 1))  # t780: "cuda:0 f32[1600]"
    # t781 = prims.convert_element_type(t780, dtypes.bfloat16)  # t781: "cuda:0 bf16[1600]"
    # t782 = prims.mul(t15, t777)  # t782: "cuda:0 f32[16, 128, 1600]"
    # t783 = prims.mul(t13, t777)  # t783: "cuda:0 f32[16, 128, 1600]"
    # t786 = prims.sum(t783, (0, 1))  # t786: "cuda:0 f32[1600]"
    # t787 = prims.convert_element_type(t786, dtypes.bfloat16)  # t787: "cuda:0 bf16[1600]"
    # t788 = prims.mul(t12, t782)  # t788: "cuda:0 f32[16, 128, 1600]"
    # t789 = prims.mul(t11, t782)  # t789: "cuda:0 f32[16, 128, 1600]"
    # t790 = prims.sum(t789, (2,))  # t790: "cuda:0 f32[16, 128]"
    # t791 = prims.broadcast_in_dim(t790, [16, 128, 1], [0, 1])  # t791: "cuda:0 f32[16, 128, 1]"
    # t792 = prims.neg(t788)  # t792: "cuda:0 f32[16, 128, 1600]"
    # t794 = prims.sum(t792, (2,))  # t794: "cuda:0 f32[16, 128]"
    # t795 = prims.broadcast_in_dim(t794, [16, 128, 1], [0, 1])  # t795: "cuda:0 f32[16, 128, 1]"
    # t796 = prims.mul(-0.5, t791)  # t796: "cuda:0 f32[16, 128, 1]"
    # t797 = prims.pow(t8, 3.0)  # t797: "cuda:0 f32[16, 128, 1]"
    # t798 = prims.mul(t796, t797)  # t798: "cuda:0 f32[16, 128, 1]"
    # t800 = prims.sum(t795, (2,))  # t800: "cuda:0 f32[16, 128]"
    # t801 = prims.sum(t798, (2,))  # t801: "cuda:0 f32[16, 128]"
    # t804 = prims.broadcast_in_dim(t800, [16, 128, 1], [0, 1])  # t804: "cuda:0 f32[16, 128, 1]"
    # t805 = prims.broadcast_in_dim(t804, (16, 128, 1600), (0, 1, 2))  # t805: "cuda:0 f32[16, 128, 1600]"
    # t806 = prims.mul(0.000625, t805)  # t806: "cuda:0 f32[16, 128, 1600]"
    # t808 = prims.broadcast_in_dim(t801, [16, 128, 1], [0, 1])  # t808: "cuda:0 f32[16, 128, 1]"
    # t809 = prims.broadcast_in_dim(t808, (16, 128, 1600), (0, 1, 2))  # t809: "cuda:0 f32[16, 128, 1600]"
    # t811 = prims.broadcast_in_dim(t4, [16, 128, 1], [0, 1])  # t811: "cuda:0 f32[16, 128, 1]"
    # t812 = prims.broadcast_in_dim(t811, (16, 128, 1600), (0, 1, 2))  # t812: "cuda:0 f32[16, 128, 1600]"
    # t813 = prims.mul(2.0, t809)  # t813: "cuda:0 f32[16, 128, 1600]"
    # t814 = prims.sub(t0, t812)  # t814: "cuda:0 f32[16, 128, 1600]"
    # t815 = prims.mul(t813, t814)  # t815: "cuda:0 f32[16, 128, 1600]"
    # f816 = prims.convert_element_type(i807, float)  # f816: "float 1600.0"
    # t817 = prims.div(t815, f816)  # t817: "cuda:0 f32[16, 128, 1600]"
    # t818 = prims.add(t806, t817)  # t818: "cuda:0 f32[16, 128, 1600]"
    # t822 = prims.add(t788, t818)  # t822: "cuda:0 f32[16, 128, 1600]"
    # t826 = prims.add(t641, t822)  # t826: "cuda:0 f32[16, 128, 1600]"
    # t827 = prims.convert_element_type(t826, dtypes.bfloat16)  # t827: "cuda:0 bf16[16, 128, 1600]"
  return (t827, t776, t773, t664, t661, t781, t787, t596, t602, t591, t588, t527, t524)
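Why each forward matmul becomes two: for y = x @ W.T (F.linear layout, W stored as [out, in]), the backward pass needs one matmul for grad_x and one for grad_W, while the bias gradient is only a sum (e.g. t526 = prims.sum(t514, (0, 1))). In the trace above, t519 and t524 are that pair for the MLP output projection, and the two attention matmuls produce the pairs t673/t675 and t724/t726. A small NumPy sketch with a finite-difference check; linear_backward is a hypothetical helper.

```python
import numpy as np


def linear_backward(x, w, grad_y):
    """Backward of y = x @ w.T with w stored as [out, in].

    grad_x = grad_y @ w    (cf. t519 = prims.matmul(t518, t_mlp_c_proj_weight))
    grad_w = grad_y.T @ x  (cf. t524 = prims.matmul(t522, t523), where t522 is t518 transposed)
    """
    grad_x = grad_y @ w
    grad_w = grad_y.T @ x
    return grad_x, grad_w
```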
wujingyue commented 1 month ago

@cowanmeg Here's how you can get a Thunder trace to help you understand the backprop nvFusion. The Thunder trace tends to be more concise than nvFusion and has shapes annotated. Also, you can dump the intermediate traces to see where the end trace comes from.

  1. Check out the branch wjy/sharded, which disables bookend and the SDPA executor so Thunder gives nvFuser the entire transformer block. It also patches the linear layer to work around https://github.com/NVIDIA/Fuser/issues/2317.
  2. Add prints for whichever traces you want to examine. https://github.com/Lightning-AI/lightning-thunder/blob/27158e62a19e144a3081be9507c81084b702c58e/thunder/executors/torch_autograd.py#L109 is where Thunder generates the forward and backward passes from the original forward-only trace. You can print any TraceCtx to see what it looks like. For example, in my previous comment, I printed bw_trace before https://github.com/Lightning-AI/lightning-thunder/blob/27158e62a19e144a3081be9507c81084b702c58e/thunder/executors/torch_autograd.py#L214 to see what the trace looks like before rematerialization.
  3. Run: pytest thunder/benchmarks/targets.py -k test_nanogpt_block_grad[thunder] -s
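Once a trace is printed, tallying prims is easy to script, e.g. to confirm the 12 prims.matmul calls in the backward trace earlier in this thread. A minimal sketch that works purely on text, including the commented-out lines inside a fusion region. Since print(obj) writes str(obj), str(bw_trace) gives the same text you see printed. count_prims is a hypothetical helper, not part of Thunder.

```python
import re
from collections import Counter

# Matches calls like "t519 = prims.matmul(t518, ...)", including commented-out
# lines that list the ops inside a fusion region of a printed trace.
_PRIM_CALL = re.compile(r"\bprims\.(\w+)\(")


def count_prims(trace_text):
    """Return a Counter mapping prim name -> number of calls found in trace_text."""
    return Counter(_PRIM_CALL.findall(trace_text))
```

For example, count_prims(str(bw_trace))["matmul"] should give 12 for the trace above.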
wujingyue commented 1 month ago

FYI, https://github.com/Lightning-AI/lightning-thunder/commit/e19f6ead412abb1e97c3a46cf214565e63cc6353 tries to update the test case to use the GPT-3 config, the one used in the two most recent Megatron papers: https://arxiv.org/pdf/2104.04473 and https://arxiv.org/pdf/2205.05198. It hits https://github.com/NVIDIA/Fuser/issues/2359 at this moment.
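A quick sanity check when switching configs: with this tensor parallelism, the number of heads has to divide evenly by the TP degree. The sketch below computes per-device weight shapes in F.linear's [out, in] layout, using the weight names from the trace above. With the block config visible in the traces (1600 hidden, 25 heads), only TP degrees of 1, 5, or 25 divide evenly. The GPT-3 175B config (12288 hidden, 96 heads) is used as a second example and may not be the exact config in that commit. tp_shard_shapes is hypothetical.

```python
def tp_shard_shapes(n_embd, n_head, tp, ffn_mult=4):
    """Per-device weight shapes for Megatron-style tensor parallelism of one block."""
    if n_head % tp != 0:
        raise ValueError(f"{n_head} heads cannot be split evenly across {tp} devices")
    ffn = ffn_mult * n_embd
    return {
        "heads_per_device": n_head // tp,
        # Column parallel; q, k and v are each split by heads, so rows must be grouped per head.
        "attn_c_attn_weight": (3 * n_embd // tp, n_embd),
        "attn_c_proj_weight": (n_embd, n_embd // tp),  # row parallel
        "mlp_c_fc_weight": (ffn // tp, n_embd),        # column parallel
        "mlp_c_proj_weight": (n_embd, ffn // tp),      # row parallel
    }


print(tp_shard_shapes(1600, 25, 5)["attn_c_attn_weight"])  # (960, 1600)
```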