Lightning-AI / lightning-thunder


grad transform `forward_and_backward_from_trace` is not handling NumberProxy properly in saved_for_backward #541

Closed by jjsjann123 2 months ago

jjsjann123 commented 4 months ago

🐛 Bug

construct_trace used in the grad transform mistakenly converts all saved_for_backward entries into proxies.

https://github.com/Lightning-AI/lightning-thunder/blob/0342223ac7851beb07d7df731389db777b58f1ac/thunder/core/transforms.py#L74-L77 https://github.com/Lightning-AI/lightning-thunder/blob/0342223ac7851beb07d7df731389db777b58f1ac/thunder/core/transforms.py#L3619-L3620

For a simple program like this:

import torch
import thunder

def foo(t, ab):
    return t * ab

jfoo = thunder.jit(foo)

dtype = torch.float32
t = torch.randn(5, 3, device="cuda").to(dtype=dtype)
t.requires_grad_()

ab = 0.5

out = jfoo(t, ab)

The transform gives us:

===fwd trc===
 # Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(t):
  # t: "cuda:0 f32[5, 3]"
  t0 = ltorch.mul(t, 0.5)  # t0: "cuda:0 f32[5, 3]"
    # t0 = prims.mul(t, 0.5)  # t0: "cuda:0 f32[5, 3]"
  return ({'output': t0, 'flat_args': [t], 'flat_output': (t0,)}, (((t, 0.5), None, ([0.5],)),))

=== saved_for_backward=(((t, 0.5), None, ([0.5],)),)

===bwd trc===
 # Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C2, = saved_for_backward
  t1, = cotangents
  _, _, C4, = C2
  C5, = C4
  f1, = C5
  t7 = ltorch.mul(f1, t1)  # t7: "cuda:0 f32[5, 3]"
    # t7 = prims.mul(f1, t1)  # t7: "cuda:0 f32[5, 3]"
  return [t7]

--- bw_flat_saved_for_backward=[t, [FloatProxy name=f0, value=0.5], None, [FloatProxy name=f1, value=0.5]]

We have saved_for_backward=(((t, 0.5), None, ([0.5],)),), which is later translated to bw_flat_saved_for_backward=[t, [FloatProxy name=f0, value=0.5], None, [FloatProxy name=f1, value=0.5]] in the backward trace, and that is not right.

The backward trace treats the saved 0.5 as a proxy f1 instead of baking the static number in.
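To make the expectation concrete, here is a minimal hand-written sketch (not Thunder's actual codegen) of the backward we would expect for this example, with the static 0.5 baked in rather than unpacked from saved_for_backward as a FloatProxy:

import torch

def expected_backward_fn(saved_for_backward, cotangents):
    # The grad of t * 0.5 w.r.t. t is the cotangent scaled by the constant 0.5,
    # so nothing numeric needs to be read from saved_for_backward here.
    t1, = cotangents
    t7 = torch.mul(t1, 0.5)
    return [t7]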

Context

We need to support NumberProxy passed from fwd to bwd. Currently all NumberProxies are baked in as constants here: https://github.com/Lightning-AI/lightning-thunder/blob/0342223ac7851beb07d7df731389db777b58f1ac/thunder/core/transforms.py#L3449
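Conceptually, that baking-in amounts to recursively replacing every NumberProxy in saved_for_backward with its concrete Python value. A rough stand-alone sketch, using a stand-in class rather than Thunder's actual NumberProxy:

class StandInNumberProxy:
    # Stand-in for a NumberProxy: a named placeholder carrying a concrete value.
    def __init__(self, name, value):
        self.name = name
        self.value = value

def bake_in_numbers(obj):
    # Replace number proxies with their static values, recursing into containers.
    if isinstance(obj, StandInNumberProxy):
        return obj.value
    if isinstance(obj, (list, tuple)):
        return type(obj)(bake_in_numbers(x) for x in obj)
    return obj

saved = (StandInNumberProxy("f0", 0.5), [StandInNumberProxy("f1", 0.5)])
print(bake_in_numbers(saved))  # (0.5, [0.5])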

When we remove that line, it breaks the rematerialization pass, since that pass assumes saved_for_backward is consistent between the forward and backward traces. https://github.com/Lightning-AI/lightning-thunder/blob/0342223ac7851beb07d7df731389db777b58f1ac/thunder/core/rematerialization.py#L635 https://github.com/Lightning-AI/lightning-thunder/blob/0342223ac7851beb07d7df731389db777b58f1ac/thunder/core/rematerialization.py#L643

Note that it uses the same new_required_for_backward for both the forward and the backward. In the example above, the fwd trace gets transformed like this:

fw_trace=# Constructed by Transform for execution (took 1 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast  
def augmented_forward_fn(t):
  # t: "cuda:0 f32[5, 3]"
  [t0] = nvFusion0(t)
    # t0 = prims.mul(t, 0.5)  # t0: "cuda:0 f32[5, 3]"
  return {'output': t0, 'flat_args': [t], 'flat_output': (t0,)}, ((), (0.5,))

===
new_fw_trace=# Constructed by Rematerialization
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(t):
  # t: "cuda:0 f32[5, 3]"
  [t0] = nvFusion0(t)
    # t0 = prims.mul(t, 0.5)  # t0: "cuda:0 f32[5, 3]"
  return {'output': t0, 'flat_args': [t], 'flat_output': (t0,)}, ((), (f1,))

Here the new_fw_trace is trying to save an f1 NumberProxy, which never existed as a number proxy in the fwd graph.
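To make the mismatch concrete, here is a hypothetical illustration (not the actual rematerialization code) of why reusing the backward's required names for the forward return cannot work: the name f1 is never defined in the forward trace.

# Symbols produced by the forward trace vs. names the backward consumers need.
fwd_defined_names = {"t", "t0"}
bwd_required_names = {"t", "f1"}  # f1 is the NumberProxy the backward unpacks

new_required_for_backward = sorted(bwd_required_names)
missing_in_fwd = [name for name in new_required_for_backward if name not in fwd_defined_names]
print(missing_in_fwd)  # ['f1'] -> rewriting the forward return to save f1 fails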

jjsjann123 commented 4 months ago

Wondering if @IvanYashchuk has any thoughts/preferences on how we want to fix this?

jjsjann123 commented 4 months ago

cc'ing @wujingyue regarding the resharding issue that we just discussed offline.

wujingyue commented 4 months ago

Thanks for tagging me. This indeed looks like the same symptom I encountered.

For https://github.com/NVIDIA/Fuser/issues/2199, I used to be able to generate a single-nvFusion backprop for a transformer block by enabling linear, disabling bookend, and disabling the cudnn and sdpa executors. It failed on me today, and here's a way to reproduce:

  1. Check out https://github.com/Lightning-AI/lightning-thunder/tree/wjy/proxy
  2. pytest thunder/benchmarks/targets.py -k test_nanogpt_block[backward-thunder] -s
  3. You'll observe the backward trace split into two nvFusions by the following node. i54 is a NumberProxy.
  t711 = torch.sum(t688, i54, True, dtype=None)  # t711: "cuda:0 bf16[16, 25, 128, 1]"
    # t711 = ltorch.sum(t688, i54, True, dtype=None)  # t711: "cuda:0 bf16[16, 25, 128, 1]"
      # b828 = prims.ge(i54, 0)  # b828: "bool False"
      # b829 = prims.lt(i54, 0)  # b829: "bool True"
      # i830 = prims.add(i54, 4)  # i830: "int 3"
      # b831 = prims.ge(i830, 0)  # b831: "bool True"
      # b832 = prims.lt(i830, 4)  # b832: "bool True"
      # t833 = ltorch.to(t688, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None)  # t833: "cuda:0 f32[16, 25, 128, 128]"
        # t833 = prims.convert_element_type(t688, dtypes.float32)  # t833: "cuda:0 f32[16, 25, 128, 128]"
      # t840 = prims.sum(t833, (i830,))  # t840: "cuda:0 f32[16, 25, 128]"
      # b841 = prims.eq(i830, 0)  # b841: "bool False"
      # b842 = prims.eq(i830, 1)  # b842: "bool False"
      # b843 = prims.eq(i830, 2)  # b843: "bool False"
      # b844 = prims.eq(i830, 3)  # b844: "bool True"
      # b845 = prims.eq(i830, 0)  # b845: "bool False"
      # b846 = prims.eq(i830, 1)  # b846: "bool False"
      # b847 = prims.eq(i830, 2)  # b847: "bool False"
      # b848 = prims.eq(i830, 3)  # b848: "bool True"
      # t849 = prims.broadcast_in_dim(t840, [16, 25, 128, 1], [0, 1, 2])  # t849: "cuda:0 f32[16, 25, 128, 1]"
      # t711 = ltorch.to(t849, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None)  # t711: "cuda:0 bf16[16, 25, 128, 1]"
        # t711 = prims.convert_element_type(t849, dtypes.bfloat16)  # t711: "cuda:0 bf16[16, 25, 128, 1]"
Below are the full forward and backward traces in case you are interested: ```py $ pytest thunder/benchmarks/targets.py -k test_nanogpt_block[backward-thunder] -s ========================================================================================================================================================================================================================================= test session starts ========================================================================================================================================================================================================================================= platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0 Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket= benchmark: 4.0.0 (defaults: timer=torch.utils.benchmark.utils.timer.timer disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=True warmup_iterations=100000) rootdir: /opt/pytorch/lightning-thunder configfile: pyproject.toml plugins: timestamper-0.0.10, xdist-3.5.0, random-order-1.1.1, cov-4.1.0, benchmark-4.0.0, hypothesis-6.100.0, timeout-2.2.0, anyio-4.3.0, shard-0.1.2 timeout: 900.0s timeout method: signal timeout func_only: False collected 644 items / 643 deselected / 1 selected Running 1 items in this shard thunder/benchmarks/targets.py # Constructed by Delete Last Used (took 0 milliseconds) import torch from thunder.executors.torchex import no_autocast @torch.no_grad() @no_autocast def augmented_forward_fn(x, t_attn_c_attn_bias, t_attn_c_attn_weight, t_attn_c_proj_bias, t_attn_c_proj_weight, t_ln_1_bias, t_ln_1_weight, t_ln_2_bias, t_ln_2_weight, t_mlp_c_fc_bias, t_mlp_c_fc_weight, t_mlp_c_proj_bias, t_mlp_c_proj_weight): # x: "cuda:0 bf16[16, 128, 1600]" # t_attn_c_attn_bias: "cuda:0 bf16[4800]" # t_attn_c_attn_weight: "cuda:0 bf16[4800, 1600]" # t_attn_c_proj_bias: "cuda:0 bf16[1600]" # t_attn_c_proj_weight: "cuda:0 bf16[1600, 1600]" # t_ln_1_bias: "cuda:0 bf16[1600]" # t_ln_1_weight: "cuda:0 bf16[1600]" # t_ln_2_bias: "cuda:0 bf16[1600]" # t_ln_2_weight: "cuda:0 bf16[1600]" # t_mlp_c_fc_bias: "cuda:0 bf16[6400]" # t_mlp_c_fc_weight: "cuda:0 bf16[6400, 1600]" # t_mlp_c_proj_bias: "cuda:0 bf16[1600]" # t_mlp_c_proj_weight: "cuda:0 bf16[1600, 6400]" [t100, t108, t113, t166, t178, t4, t42, t71, t73, t8, t80, t90] = nvFusion0(t_attn_c_attn_bias, t_attn_c_attn_weight, t_attn_c_proj_bias, t_attn_c_proj_weight, t_ln_1_bias, t_ln_1_weight, t_ln_2_bias, t_ln_2_weight, t_mlp_c_fc_bias, t_mlp_c_fc_weight, t_mlp_c_proj_bias, t_mlp_c_proj_weight, x) # t0 = prims.convert_element_type(x, dtypes.float32) # t0: "cuda:0 f32[16, 128, 1600]" # (t3, t4) = prims.var_mean(t0, (2,), correction=0) # t5 = prims.broadcast_in_dim(t3, [16, 128, 1], [0, 1]) # t5: "cuda:0 f32[16, 128, 1]" # t6 = prims.broadcast_in_dim(t4, [16, 128, 1], [0, 1]) # t6: "cuda:0 f32[16, 128, 1]" # t7 = prims.add(t5, 1e-05) # t7: "cuda:0 f32[16, 128, 1]" # t8 = prims.rsqrt(t7) # t8: "cuda:0 f32[16, 128, 1]" # t9 = prims.broadcast_in_dim(t6, (16, 128, 1600), (0, 1, 2)) # t9: "cuda:0 f32[16, 128, 1600]" # t11 = prims.sub(t0, t9) # t11: "cuda:0 f32[16, 128, 1600]" # t12 = prims.broadcast_in_dim(t8, (16, 128, 1600), (0, 1, 2)) # t12: "cuda:0 f32[16, 128, 1600]" # t13 = prims.mul(t11, t12) # t13: "cuda:0 f32[16, 128, 1600]" # t14 = prims.broadcast_in_dim(t_ln_1_weight, (16, 128, 1600), (2,)) # t14: "cuda:0 bf16[16, 128, 1600]" # t15 = prims.convert_element_type(t14, dtypes.float32) # t15: "cuda:0 f32[16, 128, 
1600]" # t16 = prims.mul(t13, t15) # t16: "cuda:0 f32[16, 128, 1600]" # t17 = prims.broadcast_in_dim(t_ln_1_bias, (16, 128, 1600), (2,)) # t17: "cuda:0 bf16[16, 128, 1600]" # t18 = prims.convert_element_type(t17, dtypes.float32) # t18: "cuda:0 f32[16, 128, 1600]" # t19 = prims.add(t16, t18) # t19: "cuda:0 f32[16, 128, 1600]" # t20 = prims.convert_element_type(t19, dtypes.bfloat16) # t20: "cuda:0 bf16[16, 128, 1600]" # t21 = prims.linear(t20, t_attn_c_attn_weight, t_attn_c_attn_bias) # t21: "cuda:0 bf16[16, 128, 4800]" # t22 = prims.slice_prim(t21, [0, 0, 0], [16, 128, 1600], [1, 1, 1]) # t22: "cuda:0 bf16[16, 128, 1600]" # t23 = prims.slice_prim(t21, [0, 0, 1600], [16, 128, 3200], [1, 1, 1]) # t23: "cuda:0 bf16[16, 128, 1600]" # t24 = prims.slice_prim(t21, [0, 0, 3200], [16, 128, 4800], [1, 1, 1]) # t24: "cuda:0 bf16[16, 128, 1600]" # t25 = prims.reshape(t23, (16, 128, 25, 64)) # t25: "cuda:0 bf16[16, 128, 25, 64]" # t26 = prims.transpose(t25, (0, 2, 1, 3)) # t26: "cuda:0 bf16[16, 25, 128, 64]" # t31 = prims.reshape(t22, (16, 128, 25, 64)) # t31: "cuda:0 bf16[16, 128, 25, 64]" # t34 = prims.transpose(t31, (0, 2, 1, 3)) # t34: "cuda:0 bf16[16, 25, 128, 64]" # t39 = prims.reshape(t24, (16, 128, 25, 64)) # t39: "cuda:0 bf16[16, 128, 25, 64]" # t42 = prims.transpose(t39, (0, 2, 1, 3)) # t42: "cuda:0 bf16[16, 25, 128, 64]" # t43 = prims.convert_element_type(t34, dtypes.float32) # t43: "cuda:0 f32[16, 25, 128, 64]" # t44 = prims.mul(t43, 0.3535533905932738) # t44: "cuda:0 f32[16, 25, 128, 64]" # t45 = prims.convert_element_type(t44, dtypes.bfloat16) # t45: "cuda:0 bf16[16, 25, 128, 64]" # t46 = prims.transpose(t26, (0, 1, 3, 2)) # t46: "cuda:0 bf16[16, 25, 64, 128]" # t47 = prims.convert_element_type(t46, dtypes.float32) # t47: "cuda:0 f32[16, 25, 64, 128]" # t48 = prims.mul(t47, 0.3535533905932738) # t48: "cuda:0 f32[16, 25, 64, 128]" # t49 = prims.convert_element_type(t48, dtypes.bfloat16) # t49: "cuda:0 bf16[16, 25, 64, 128]" # t50 = prims.matmul(t45, t49) # t50: "cuda:0 bf16[16, 25, 128, 128]" # t51 = prims.iota(128, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t51: "cuda:0 i64[128]" # t52 = prims.broadcast_in_dim(t51, [128, 1], [0]) # t52: "cuda:0 i64[128, 1]" # t54 = prims.broadcast_in_dim(t51, [1, 128], [1]) # t54: "cuda:0 i64[1, 128]" # t55 = prims.add(t52, 0) # t55: "cuda:0 i64[128, 1]" # t56 = prims.broadcast_in_dim(t55, (128, 128), (0, 1)) # t56: "cuda:0 i64[128, 128]" # t57 = prims.broadcast_in_dim(t54, (128, 128), (0, 1)) # t57: "cuda:0 i64[128, 128]" # t58 = prims.ge(t56, t57) # t58: "cuda:0 b8[128, 128]" # t59 = prims.broadcast_in_dim(t58, (16, 25, 128, 128), (2, 3)) # t59: "cuda:0 b8[16, 25, 128, 128]" # t60 = prims.where(t59, t50, -float('inf')) # t60: "cuda:0 bf16[16, 25, 128, 128]" # t61 = prims.convert_element_type(t60, dtypes.float32) # t61: "cuda:0 f32[16, 25, 128, 128]" # t62 = prims.amax(t61, (3,)) # t62: "cuda:0 f32[16, 25, 128]" # t63 = prims.broadcast_in_dim(t62, [16, 25, 128, 1], [0, 1, 2]) # t63: "cuda:0 f32[16, 25, 128, 1]" # t64 = prims.broadcast_in_dim(t63, (16, 25, 128, 128), (0, 1, 2, 3)) # t64: "cuda:0 f32[16, 25, 128, 128]" # t65 = prims.sub(t61, t64) # t65: "cuda:0 f32[16, 25, 128, 128]" # t66 = prims.exp(t65) # t66: "cuda:0 f32[16, 25, 128, 128]" # t67 = prims.sum(t66, (3,)) # t67: "cuda:0 f32[16, 25, 128]" # t68 = prims.broadcast_in_dim(t67, [16, 25, 128, 1], [0, 1, 2]) # t68: "cuda:0 f32[16, 25, 128, 1]" # t69 = prims.broadcast_in_dim(t68, (16, 25, 128, 128), (0, 1, 2, 3)) # t69: "cuda:0 f32[16, 25, 128, 128]" # t70 = 
prims.div(t66, t69) # t70: "cuda:0 f32[16, 25, 128, 128]" # t71 = prims.convert_element_type(t70, dtypes.bfloat16) # t71: "cuda:0 bf16[16, 25, 128, 128]" # t72 = prims.uniform((16, 25, 128, 128), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t72: "cuda:0 bf16[16, 25, 128, 128]" # t73 = prims.lt(t72, 0.9) # t73: "cuda:0 b8[16, 25, 128, 128]" # t75 = prims.convert_element_type(t73, dtypes.float32) # t75: "cuda:0 f32[16, 25, 128, 128]" # t76 = prims.mul(t70, t75) # t76: "cuda:0 f32[16, 25, 128, 128]" # t79 = prims.mul(t76, 1.1111111111111112) # t79: "cuda:0 f32[16, 25, 128, 128]" # t80 = prims.convert_element_type(t79, dtypes.bfloat16) # t80: "cuda:0 bf16[16, 25, 128, 128]" # t81 = prims.matmul(t80, t42) # t81: "cuda:0 bf16[16, 25, 128, 64]" # t82 = prims.transpose(t81, (0, 2, 1, 3)) # t82: "cuda:0 bf16[16, 128, 25, 64]" # t83 = prims.stride_order(t82, (3, 2, 1, 0)) # t83: "cuda:0 bf16[16, 128, 25, 64]" # t84 = prims.reshape(t83, (16, 128, 1600)) # t84: "cuda:0 bf16[16, 128, 1600]" # t85 = prims.linear(t84, t_attn_c_proj_weight, t_attn_c_proj_bias) # t85: "cuda:0 bf16[16, 128, 1600]" # t89 = prims.uniform((16, 128, 1600), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t89: "cuda:0 bf16[16, 128, 1600]" # t90 = prims.lt(t89, 0.9) # t90: "cuda:0 b8[16, 128, 1600]" # t91 = prims.convert_element_type(t85, dtypes.float32) # t91: "cuda:0 f32[16, 128, 1600]" # t92 = prims.convert_element_type(t90, dtypes.float32) # t92: "cuda:0 f32[16, 128, 1600]" # t93 = prims.mul(t91, t92) # t93: "cuda:0 f32[16, 128, 1600]" # t96 = prims.mul(t93, 1.1111111111111112) # t96: "cuda:0 f32[16, 128, 1600]" # t100 = prims.add(t0, t96) # t100: "cuda:0 f32[16, 128, 1600]" # (t107, t108) = prims.var_mean(t100, (2,), correction=0) # t109 = prims.broadcast_in_dim(t107, [16, 128, 1], [0, 1]) # t109: "cuda:0 f32[16, 128, 1]" # t110 = prims.broadcast_in_dim(t108, [16, 128, 1], [0, 1]) # t110: "cuda:0 f32[16, 128, 1]" # t112 = prims.add(t109, 1e-05) # t112: "cuda:0 f32[16, 128, 1]" # t113 = prims.rsqrt(t112) # t113: "cuda:0 f32[16, 128, 1]" # t114 = prims.broadcast_in_dim(t110, (16, 128, 1600), (0, 1, 2)) # t114: "cuda:0 f32[16, 128, 1600]" # t116 = prims.sub(t100, t114) # t116: "cuda:0 f32[16, 128, 1600]" # t117 = prims.broadcast_in_dim(t113, (16, 128, 1600), (0, 1, 2)) # t117: "cuda:0 f32[16, 128, 1600]" # t118 = prims.mul(t116, t117) # t118: "cuda:0 f32[16, 128, 1600]" # t119 = prims.broadcast_in_dim(t_ln_2_weight, (16, 128, 1600), (2,)) # t119: "cuda:0 bf16[16, 128, 1600]" # t120 = prims.convert_element_type(t119, dtypes.float32) # t120: "cuda:0 f32[16, 128, 1600]" # t121 = prims.mul(t118, t120) # t121: "cuda:0 f32[16, 128, 1600]" # t122 = prims.broadcast_in_dim(t_ln_2_bias, (16, 128, 1600), (2,)) # t122: "cuda:0 bf16[16, 128, 1600]" # t123 = prims.convert_element_type(t122, dtypes.float32) # t123: "cuda:0 f32[16, 128, 1600]" # t124 = prims.add(t121, t123) # t124: "cuda:0 f32[16, 128, 1600]" # t125 = prims.convert_element_type(t124, dtypes.bfloat16) # t125: "cuda:0 bf16[16, 128, 1600]" # t126 = prims.linear(t125, t_mlp_c_fc_weight, t_mlp_c_fc_bias) # t126: "cuda:0 bf16[16, 128, 6400]" # t127 = prims.convert_element_type(t126, dtypes.float32) # t127: "cuda:0 f32[16, 128, 6400]" # t129 = prims.mul(t127, t127) # t129: "cuda:0 f32[16, 128, 6400]" # t133 = prims.mul(t129, t127) # t133: "cuda:0 f32[16, 128, 6400]" # t136 = prims.mul(0.5, t127) # t136: "cuda:0 f32[16, 128, 6400]" # t139 = prims.mul(0.044715, t133) # t139: "cuda:0 f32[16, 128, 6400]" # t143 = prims.add(t127, t139) # t143: 
"cuda:0 f32[16, 128, 6400]" # t146 = prims.mul(0.7978845608028654, t143) # t146: "cuda:0 f32[16, 128, 6400]" # t149 = prims.tanh(t146) # t149: "cuda:0 f32[16, 128, 6400]" # t152 = prims.add(1.0, t149) # t152: "cuda:0 f32[16, 128, 6400]" # t156 = prims.mul(t136, t152) # t156: "cuda:0 f32[16, 128, 6400]" # t157 = prims.convert_element_type(t156, dtypes.bfloat16) # t157: "cuda:0 bf16[16, 128, 6400]" # t158 = prims.linear(t157, t_mlp_c_proj_weight, t_mlp_c_proj_bias) # t158: "cuda:0 bf16[16, 128, 1600]" # t164 = prims.uniform((16, 128, 1600), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t164: "cuda:0 bf16[16, 128, 1600]" # t166 = prims.lt(t164, 0.9) # t166: "cuda:0 b8[16, 128, 1600]" # t167 = prims.convert_element_type(t158, dtypes.float32) # t167: "cuda:0 f32[16, 128, 1600]" # t168 = prims.convert_element_type(t166, dtypes.float32) # t168: "cuda:0 f32[16, 128, 1600]" # t169 = prims.mul(t167, t168) # t169: "cuda:0 f32[16, 128, 1600]" # t173 = prims.mul(t169, 1.1111111111111112) # t173: "cuda:0 f32[16, 128, 1600]" # t177 = prims.add(t100, t173) # t177: "cuda:0 f32[16, 128, 1600]" # t178 = prims.convert_element_type(t177, dtypes.bfloat16) # t178: "cuda:0 bf16[16, 128, 1600]" return {'output': t178, 'flat_args': [x, t_attn_c_attn_bias, t_attn_c_attn_weight, t_attn_c_proj_bias, t_attn_c_proj_weight, t_ln_1_bias, t_ln_1_weight, t_ln_2_bias, t_ln_2_weight, t_mlp_c_fc_bias, t_mlp_c_fc_weight, t_mlp_c_proj_bias, t_mlp_c_proj_weight], 'flat_output': (t178,)}, ((t100, t108, t113, t166, t4, t42, t71, t73, t8, t80, t90, t_attn_c_attn_bias, t_attn_c_attn_weight, t_attn_c_proj_weight, t_ln_1_bias, t_ln_1_weight, t_ln_2_bias, t_ln_2_weight, t_mlp_c_fc_bias, t_mlp_c_fc_weight, t_mlp_c_proj_weight, x), (1.1111111111111112, 0.3535533905932738, 0.3535533905932738, 1.1111111111111112, 1.1111111111111112, 0.5, 0.044715, 0.7978845608028654, -1, 0, 0, 2)) # Constructed by Delete Last Used (took 0 milliseconds) import operator import torch from thunder.executors.torchex import no_autocast @torch.no_grad() @no_autocast def backward_fn(saved_for_backward, cotangents): # saved_for_backward: "Collection" # cotangents: "Collection" C0, C1, = saved_for_backward clear_mutable_collection(saved_for_backward) del saved_for_backward t27, = cotangents clear_mutable_collection(cotangents) del cotangents t100, t108, t113, t166, t4, t42, t71, t73, t8, t80, t90, t_attn_c_attn_bias, \ t_attn_c_attn_weight, t_attn_c_proj_weight, t_ln_1_bias, t_ln_1_weight, \ t_ln_2_bias, t_ln_2_weight, t_mlp_c_fc_bias, t_mlp_c_fc_weight, \ t_mlp_c_proj_weight, x, = C0 clear_mutable_collection(C0) del C0 f102, f43, f47, f62, f78, f89, f91, f93, i54, i6, i85, i9, = C1 clear_mutable_collection(C1) del C1 i622 = operator.sub(1600, i85) # i622: "int 1600" # i622 = prims.sub(1600, i85) # i622: "int 1600" del i85 i807 = operator.sub(1600, i6) # i807: "int 1600" # i807 = prims.sub(1600, i6) # i807: "int 1600" del i6 [t524, t527, t588, t591, t596, t602, t642, t661, t664, t684, t688, t745] = nvFusion0(f102, f62, f78, f89, f91, f93, i622, t100, t108, t113, t166, t27, t42, t71, t73, t80, t90, t_attn_c_proj_weight, t_ln_2_bias, t_ln_2_weight, t_mlp_c_fc_bias, t_mlp_c_fc_weight, t_mlp_c_proj_weight) # t75 = prims.convert_element_type(t73, dtypes.float32) # t75: "cuda:0 f32[16, 25, 128, 128]" # t81 = prims.matmul(t80, t42) # t81: "cuda:0 bf16[16, 25, 128, 64]" # t82 = prims.transpose(t81, (0, 2, 1, 3)) # t82: "cuda:0 bf16[16, 128, 25, 64]" # t83 = prims.stride_order(t82, (3, 2, 1, 0)) # t83: "cuda:0 bf16[16, 128, 25, 64]" # t84 = 
prims.reshape(t83, (16, 128, 1600)) # t84: "cuda:0 bf16[16, 128, 1600]" # t92 = prims.convert_element_type(t90, dtypes.float32) # t92: "cuda:0 f32[16, 128, 1600]" # t110 = prims.broadcast_in_dim(t108, [16, 128, 1], [0, 1]) # t110: "cuda:0 f32[16, 128, 1]" # t114 = prims.broadcast_in_dim(t110, (16, 128, 1600), (0, 1, 2)) # t114: "cuda:0 f32[16, 128, 1600]" # t116 = prims.sub(t100, t114) # t116: "cuda:0 f32[16, 128, 1600]" # t117 = prims.broadcast_in_dim(t113, (16, 128, 1600), (0, 1, 2)) # t117: "cuda:0 f32[16, 128, 1600]" # t118 = prims.mul(t116, t117) # t118: "cuda:0 f32[16, 128, 1600]" # t119 = prims.broadcast_in_dim(t_ln_2_weight, (16, 128, 1600), (2,)) # t119: "cuda:0 bf16[16, 128, 1600]" # t120 = prims.convert_element_type(t119, dtypes.float32) # t120: "cuda:0 f32[16, 128, 1600]" # t121 = prims.mul(t118, t120) # t121: "cuda:0 f32[16, 128, 1600]" # t122 = prims.broadcast_in_dim(t_ln_2_bias, (16, 128, 1600), (2,)) # t122: "cuda:0 bf16[16, 128, 1600]" # t123 = prims.convert_element_type(t122, dtypes.float32) # t123: "cuda:0 f32[16, 128, 1600]" # t124 = prims.add(t121, t123) # t124: "cuda:0 f32[16, 128, 1600]" # t125 = prims.convert_element_type(t124, dtypes.bfloat16) # t125: "cuda:0 bf16[16, 128, 1600]" # t126 = prims.linear(t125, t_mlp_c_fc_weight, t_mlp_c_fc_bias) # t126: "cuda:0 bf16[16, 128, 6400]" # t127 = prims.convert_element_type(t126, dtypes.float32) # t127: "cuda:0 f32[16, 128, 6400]" # t129 = prims.mul(t127, t127) # t129: "cuda:0 f32[16, 128, 6400]" # t133 = prims.mul(t129, t127) # t133: "cuda:0 f32[16, 128, 6400]" # t136 = prims.mul(0.5, t127) # t136: "cuda:0 f32[16, 128, 6400]" # t139 = prims.mul(0.044715, t133) # t139: "cuda:0 f32[16, 128, 6400]" # t143 = prims.add(t127, t139) # t143: "cuda:0 f32[16, 128, 6400]" # t146 = prims.mul(0.7978845608028654, t143) # t146: "cuda:0 f32[16, 128, 6400]" # t149 = prims.tanh(t146) # t149: "cuda:0 f32[16, 128, 6400]" # t152 = prims.add(1.0, t149) # t152: "cuda:0 f32[16, 128, 6400]" # t156 = prims.mul(t136, t152) # t156: "cuda:0 f32[16, 128, 6400]" # t157 = prims.convert_element_type(t156, dtypes.bfloat16) # t157: "cuda:0 bf16[16, 128, 6400]" # t168 = prims.convert_element_type(t166, dtypes.float32) # t168: "cuda:0 f32[16, 128, 1600]" # t506 = prims.convert_element_type(t27, dtypes.float32) # t506: "cuda:0 f32[16, 128, 1600]" # t511 = prims.mul(f102, t506) # t511: "cuda:0 f32[16, 128, 1600]" # t514 = prims.mul(t168, t511) # t514: "cuda:0 f32[16, 128, 1600]" # t517 = prims.convert_element_type(t514, dtypes.bfloat16) # t517: "cuda:0 bf16[16, 128, 1600]" # t518 = prims.reshape(t517, (2048, 1600)) # t518: "cuda:0 bf16[2048, 1600]" # t519 = prims.matmul(t518, t_mlp_c_proj_weight) # t519: "cuda:0 bf16[2048, 6400]" # t520 = prims.reshape(t519, (16, 128, 6400)) # t520: "cuda:0 bf16[16, 128, 6400]" # t522 = prims.transpose(t518, (1, 0)) # t522: "cuda:0 bf16[1600, 2048]" # t523 = prims.reshape(t157, (2048, 6400)) # t523: "cuda:0 bf16[2048, 6400]" # t524 = prims.matmul(t522, t523) # t524: "cuda:0 bf16[1600, 6400]" # t526 = prims.sum(t514, (0, 1)) # t526: "cuda:0 f32[1600]" # t527 = prims.convert_element_type(t526, dtypes.bfloat16) # t527: "cuda:0 bf16[1600]" # t528 = prims.convert_element_type(t520, dtypes.float32) # t528: "cuda:0 f32[16, 128, 6400]" # t529 = prims.mul(t152, t528) # t529: "cuda:0 f32[16, 128, 6400]" # t530 = prims.mul(t136, t528) # t530: "cuda:0 f32[16, 128, 6400]" # t537 = prims.mul(t149, t149) # t537: "cuda:0 f32[16, 128, 6400]" # t538 = prims.sub(1.0, t537) # t538: "cuda:0 f32[16, 128, 6400]" # t539 = prims.mul(t530, t538) # t539: 
"cuda:0 f32[16, 128, 6400]" # t543 = prims.mul(f93, t539) # t543: "cuda:0 f32[16, 128, 6400]" # t550 = prims.mul(f91, t543) # t550: "cuda:0 f32[16, 128, 6400]" # t554 = prims.mul(f89, t529) # t554: "cuda:0 f32[16, 128, 6400]" # t558 = prims.add(t543, t554) # t558: "cuda:0 f32[16, 128, 6400]" # t561 = prims.mul(t127, t550) # t561: "cuda:0 f32[16, 128, 6400]" # t562 = prims.mul(t129, t550) # t562: "cuda:0 f32[16, 128, 6400]" # t567 = prims.add(t558, t562) # t567: "cuda:0 f32[16, 128, 6400]" # t570 = prims.mul(t127, t561) # t570: "cuda:0 f32[16, 128, 6400]" # t576 = prims.add(t567, t570) # t576: "cuda:0 f32[16, 128, 6400]" # t580 = prims.add(t576, t570) # t580: "cuda:0 f32[16, 128, 6400]" # t581 = prims.convert_element_type(t580, dtypes.bfloat16) # t581: "cuda:0 bf16[16, 128, 6400]" # t582 = prims.reshape(t581, (2048, 6400)) # t582: "cuda:0 bf16[2048, 6400]" # t583 = prims.matmul(t582, t_mlp_c_fc_weight) # t583: "cuda:0 bf16[2048, 1600]" # t584 = prims.reshape(t583, (16, 128, 1600)) # t584: "cuda:0 bf16[16, 128, 1600]" # t586 = prims.transpose(t582, (1, 0)) # t586: "cuda:0 bf16[6400, 2048]" # t587 = prims.reshape(t125, (2048, 1600)) # t587: "cuda:0 bf16[2048, 1600]" # t588 = prims.matmul(t586, t587) # t588: "cuda:0 bf16[6400, 1600]" # t590 = prims.sum(t580, (0, 1)) # t590: "cuda:0 f32[6400]" # t591 = prims.convert_element_type(t590, dtypes.bfloat16) # t591: "cuda:0 bf16[6400]" # t592 = prims.convert_element_type(t584, dtypes.float32) # t592: "cuda:0 f32[16, 128, 1600]" # t595 = prims.sum(t592, (0, 1)) # t595: "cuda:0 f32[1600]" # t596 = prims.convert_element_type(t595, dtypes.bfloat16) # t596: "cuda:0 bf16[1600]" # t597 = prims.mul(t120, t592) # t597: "cuda:0 f32[16, 128, 1600]" # t598 = prims.mul(t118, t592) # t598: "cuda:0 f32[16, 128, 1600]" # t601 = prims.sum(t598, (0, 1)) # t601: "cuda:0 f32[1600]" # t602 = prims.convert_element_type(t601, dtypes.bfloat16) # t602: "cuda:0 bf16[1600]" # t603 = prims.mul(t117, t597) # t603: "cuda:0 f32[16, 128, 1600]" # t604 = prims.mul(t116, t597) # t604: "cuda:0 f32[16, 128, 1600]" # t605 = prims.sum(t604, (2,)) # t605: "cuda:0 f32[16, 128]" # t606 = prims.broadcast_in_dim(t605, [16, 128, 1], [0, 1]) # t606: "cuda:0 f32[16, 128, 1]" # t607 = prims.neg(t603) # t607: "cuda:0 f32[16, 128, 1600]" # t609 = prims.sum(t607, (2,)) # t609: "cuda:0 f32[16, 128]" # t610 = prims.broadcast_in_dim(t609, [16, 128, 1], [0, 1]) # t610: "cuda:0 f32[16, 128, 1]" # t611 = prims.mul(-0.5, t606) # t611: "cuda:0 f32[16, 128, 1]" # t612 = prims.pow(t113, 3.0) # t612: "cuda:0 f32[16, 128, 1]" # t613 = prims.mul(t611, t612) # t613: "cuda:0 f32[16, 128, 1]" # t615 = prims.sum(t610, (2,)) # t615: "cuda:0 f32[16, 128]" # t616 = prims.sum(t613, (2,)) # t616: "cuda:0 f32[16, 128]" # t619 = prims.broadcast_in_dim(t615, [16, 128, 1], [0, 1]) # t619: "cuda:0 f32[16, 128, 1]" # t620 = prims.broadcast_in_dim(t619, (16, 128, 1600), (0, 1, 2)) # t620: "cuda:0 f32[16, 128, 1600]" # t621 = prims.mul(0.000625, t620) # t621: "cuda:0 f32[16, 128, 1600]" # t623 = prims.broadcast_in_dim(t616, [16, 128, 1], [0, 1]) # t623: "cuda:0 f32[16, 128, 1]" # t624 = prims.broadcast_in_dim(t623, (16, 128, 1600), (0, 1, 2)) # t624: "cuda:0 f32[16, 128, 1600]" # t626 = prims.broadcast_in_dim(t108, [16, 128, 1], [0, 1]) # t626: "cuda:0 f32[16, 128, 1]" # t627 = prims.broadcast_in_dim(t626, (16, 128, 1600), (0, 1, 2)) # t627: "cuda:0 f32[16, 128, 1600]" # t628 = prims.mul(2.0, t624) # t628: "cuda:0 f32[16, 128, 1600]" # t629 = prims.sub(t100, t627) # t629: "cuda:0 f32[16, 128, 1600]" # t630 = prims.mul(t628, t629) 
# t630: "cuda:0 f32[16, 128, 1600]" # f631 = prims.convert_element_type(i622, float) # f631: "float 1600.0" # t632 = prims.div(t630, f631) # t632: "cuda:0 f32[16, 128, 1600]" # t633 = prims.add(t621, t632) # t633: "cuda:0 f32[16, 128, 1600]" # t637 = prims.add(t603, t633) # t637: "cuda:0 f32[16, 128, 1600]" # t641 = prims.add(t506, t637) # t641: "cuda:0 f32[16, 128, 1600]" # t642 = prims.convert_element_type(t641, dtypes.bfloat16) # t642: "cuda:0 bf16[16, 128, 1600]" # t648 = prims.mul(f78, t641) # t648: "cuda:0 f32[16, 128, 1600]" # t651 = prims.mul(t92, t648) # t651: "cuda:0 f32[16, 128, 1600]" # t654 = prims.convert_element_type(t651, dtypes.bfloat16) # t654: "cuda:0 bf16[16, 128, 1600]" # t655 = prims.reshape(t654, (2048, 1600)) # t655: "cuda:0 bf16[2048, 1600]" # t656 = prims.matmul(t655, t_attn_c_proj_weight) # t656: "cuda:0 bf16[2048, 1600]" # t657 = prims.reshape(t656, (16, 128, 1600)) # t657: "cuda:0 bf16[16, 128, 1600]" # t659 = prims.transpose(t655, (1, 0)) # t659: "cuda:0 bf16[1600, 2048]" # t660 = prims.reshape(t84, (2048, 1600)) # t660: "cuda:0 bf16[2048, 1600]" # t661 = prims.matmul(t659, t660) # t661: "cuda:0 bf16[1600, 1600]" # t663 = prims.sum(t651, (0, 1)) # t663: "cuda:0 f32[1600]" # t664 = prims.convert_element_type(t663, dtypes.bfloat16) # t664: "cuda:0 bf16[1600]" # t668 = prims.reshape(t657, (16, 128, 25, 64)) # t668: "cuda:0 bf16[16, 128, 25, 64]" # t671 = prims.transpose(t668, (0, 2, 1, 3)) # t671: "cuda:0 bf16[16, 25, 128, 64]" # t672 = prims.transpose(t42, (0, 1, 3, 2)) # t672: "cuda:0 bf16[16, 25, 64, 128]" # t673 = prims.matmul(t671, t672) # t673: "cuda:0 bf16[16, 25, 128, 128]" # t674 = prims.transpose(t80, (0, 1, 3, 2)) # t674: "cuda:0 bf16[16, 25, 128, 128]" # t675 = prims.matmul(t674, t671) # t675: "cuda:0 bf16[16, 25, 128, 64]" # t676 = prims.convert_element_type(t673, dtypes.float32) # t676: "cuda:0 f32[16, 25, 128, 128]" # t678 = prims.mul(f62, t676) # t678: "cuda:0 f32[16, 25, 128, 128]" # t681 = prims.mul(t75, t678) # t681: "cuda:0 f32[16, 25, 128, 128]" # t684 = prims.convert_element_type(t681, dtypes.bfloat16) # t684: "cuda:0 bf16[16, 25, 128, 128]" # t685 = prims.convert_element_type(t71, dtypes.float32) # t685: "cuda:0 f32[16, 25, 128, 128]" # t687 = prims.mul(t685, t681) # t687: "cuda:0 f32[16, 25, 128, 128]" # t688 = prims.convert_element_type(t687, dtypes.bfloat16) # t688: "cuda:0 bf16[16, 25, 128, 128]" # t740 = prims.transpose(t675, (0, 2, 1, 3)) # t740: "cuda:0 bf16[16, 128, 25, 64]" # t745 = prims.reshape(t740, (16, 128, 1600)) # t745: "cuda:0 bf16[16, 128, 1600]" del f102, f62, f78, f89, f91, f93, i622, t100, t108, t113, t166, t27, t42, t73, t80, t90, t_attn_c_proj_weight, t_ln_2_bias, t_ln_2_weight, t_mlp_c_fc_bias, t_mlp_c_fc_weight, t_mlp_c_proj_weight t711 = torch.sum(t688, i54, True, dtype=None) # t711: "cuda:0 bf16[16, 25, 128, 1]" # t711 = ltorch.sum(t688, i54, True, dtype=None) # t711: "cuda:0 bf16[16, 25, 128, 1]" # b828 = prims.ge(i54, 0) # b828: "bool False" # b829 = prims.lt(i54, 0) # b829: "bool True" # i830 = prims.add(i54, 4) # i830: "int 3" # b831 = prims.ge(i830, 0) # b831: "bool True" # b832 = prims.lt(i830, 4) # b832: "bool True" # t833 = ltorch.to(t688, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t833: "cuda:0 f32[16, 25, 128, 128]" # t833 = prims.convert_element_type(t688, dtypes.float32) # t833: "cuda:0 f32[16, 25, 128, 128]" # t840 = prims.sum(t833, (i830,)) # t840: "cuda:0 f32[16, 25, 128]" # b841 = prims.eq(i830, 0) # b841: "bool False" # b842 = prims.eq(i830, 1) # b842: "bool 
False" # b843 = prims.eq(i830, 2) # b843: "bool False" # b844 = prims.eq(i830, 3) # b844: "bool True" # b845 = prims.eq(i830, 0) # b845: "bool False" # b846 = prims.eq(i830, 1) # b846: "bool False" # b847 = prims.eq(i830, 2) # b847: "bool False" # b848 = prims.eq(i830, 3) # b848: "bool True" # t849 = prims.broadcast_in_dim(t840, [16, 25, 128, 1], [0, 1, 2]) # t849: "cuda:0 f32[16, 25, 128, 1]" # t711 = ltorch.to(t849, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t711: "cuda:0 bf16[16, 25, 128, 1]" # t711 = prims.convert_element_type(t849, dtypes.bfloat16) # t711: "cuda:0 bf16[16, 25, 128, 1]" del t688, i54 [t773, t776, t781, t787, t827] = nvFusion1(f43, f47, i807, i9, t4, t642, t684, t71, t711, t745, t8, t_attn_c_attn_bias, t_attn_c_attn_weight, t_ln_1_bias, t_ln_1_weight, x) # t0 = prims.convert_element_type(x, dtypes.float32) # t0: "cuda:0 f32[16, 128, 1600]" # t6 = prims.broadcast_in_dim(t4, [16, 128, 1], [0, 1]) # t6: "cuda:0 f32[16, 128, 1]" # t9 = prims.broadcast_in_dim(t6, (16, 128, 1600), (0, 1, 2)) # t9: "cuda:0 f32[16, 128, 1600]" # t11 = prims.sub(t0, t9) # t11: "cuda:0 f32[16, 128, 1600]" # t12 = prims.broadcast_in_dim(t8, (16, 128, 1600), (0, 1, 2)) # t12: "cuda:0 f32[16, 128, 1600]" # t13 = prims.mul(t11, t12) # t13: "cuda:0 f32[16, 128, 1600]" # t14 = prims.broadcast_in_dim(t_ln_1_weight, (16, 128, 1600), (2,)) # t14: "cuda:0 bf16[16, 128, 1600]" # t15 = prims.convert_element_type(t14, dtypes.float32) # t15: "cuda:0 f32[16, 128, 1600]" # t16 = prims.mul(t13, t15) # t16: "cuda:0 f32[16, 128, 1600]" # t17 = prims.broadcast_in_dim(t_ln_1_bias, (16, 128, 1600), (2,)) # t17: "cuda:0 bf16[16, 128, 1600]" # t18 = prims.convert_element_type(t17, dtypes.float32) # t18: "cuda:0 f32[16, 128, 1600]" # t19 = prims.add(t16, t18) # t19: "cuda:0 f32[16, 128, 1600]" # t20 = prims.convert_element_type(t19, dtypes.bfloat16) # t20: "cuda:0 bf16[16, 128, 1600]" # t21 = prims.linear(t20, t_attn_c_attn_weight, t_attn_c_attn_bias) # t21: "cuda:0 bf16[16, 128, 4800]" # t22 = prims.slice_prim(t21, [0, 0, 0], [16, 128, 1600], [1, 1, 1]) # t22: "cuda:0 bf16[16, 128, 1600]" # t23 = prims.slice_prim(t21, [0, 0, 1600], [16, 128, 3200], [1, 1, 1]) # t23: "cuda:0 bf16[16, 128, 1600]" # t25 = prims.reshape(t23, (16, 128, 25, 64)) # t25: "cuda:0 bf16[16, 128, 25, 64]" # t26 = prims.transpose(t25, (0, 2, 1, 3)) # t26: "cuda:0 bf16[16, 25, 128, 64]" # t31 = prims.reshape(t22, (16, 128, 25, 64)) # t31: "cuda:0 bf16[16, 128, 25, 64]" # t34 = prims.transpose(t31, (0, 2, 1, 3)) # t34: "cuda:0 bf16[16, 25, 128, 64]" # t43 = prims.convert_element_type(t34, dtypes.float32) # t43: "cuda:0 f32[16, 25, 128, 64]" # t44 = prims.mul(t43, 0.3535533905932738) # t44: "cuda:0 f32[16, 25, 128, 64]" # t45 = prims.convert_element_type(t44, dtypes.bfloat16) # t45: "cuda:0 bf16[16, 25, 128, 64]" # t46 = prims.transpose(t26, (0, 1, 3, 2)) # t46: "cuda:0 bf16[16, 25, 64, 128]" # t47 = prims.convert_element_type(t46, dtypes.float32) # t47: "cuda:0 f32[16, 25, 64, 128]" # t48 = prims.mul(t47, 0.3535533905932738) # t48: "cuda:0 f32[16, 25, 64, 128]" # t49 = prims.convert_element_type(t48, dtypes.bfloat16) # t49: "cuda:0 bf16[16, 25, 64, 128]" # t51 = prims.iota(128, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t51: "cuda:0 i64[128]" # t52 = prims.broadcast_in_dim(t51, [128, 1], [0]) # t52: "cuda:0 i64[128, 1]" # t54 = prims.broadcast_in_dim(t51, [1, 128], [1]) # t54: "cuda:0 i64[1, 128]" # t55 = prims.add(t52, 0) # t55: "cuda:0 i64[128, 1]" # t56 = 
prims.broadcast_in_dim(t55, (128, 128), (0, 1)) # t56: "cuda:0 i64[128, 128]" # t57 = prims.broadcast_in_dim(t54, (128, 128), (0, 1)) # t57: "cuda:0 i64[128, 128]" # t58 = prims.ge(t56, t57) # t58: "cuda:0 b8[128, 128]" # t59 = prims.broadcast_in_dim(t58, (16, 25, 128, 128), (2, 3)) # t59: "cuda:0 b8[16, 25, 128, 128]" # t824 = prims.convert_element_type(t642, dtypes.float32) # t824: "cuda:0 f32[16, 128, 1600]" # t712 = prims.broadcast_in_dim(t711, (16, 25, 128, 128), (0, 1, 2, 3)) # t712: "cuda:0 bf16[16, 25, 128, 128]" # t713 = prims.convert_element_type(t684, dtypes.float32) # t713: "cuda:0 f32[16, 25, 128, 128]" # t714 = prims.convert_element_type(t712, dtypes.float32) # t714: "cuda:0 f32[16, 25, 128, 128]" # t715 = prims.sub(t713, t714) # t715: "cuda:0 f32[16, 25, 128, 128]" # t717 = prims.convert_element_type(t71, dtypes.float32) # t717: "cuda:0 f32[16, 25, 128, 128]" # t719 = prims.mul(t717, t715) # t719: "cuda:0 f32[16, 25, 128, 128]" # t720 = prims.convert_element_type(t719, dtypes.bfloat16) # t720: "cuda:0 bf16[16, 25, 128, 128]" # t722 = prims.where(t59, t720, 0.0) # t722: "cuda:0 bf16[16, 25, 128, 128]" # t723 = prims.transpose(t49, (0, 1, 3, 2)) # t723: "cuda:0 bf16[16, 25, 128, 64]" # t724 = prims.matmul(t722, t723) # t724: "cuda:0 bf16[16, 25, 128, 64]" # t725 = prims.transpose(t45, (0, 1, 3, 2)) # t725: "cuda:0 bf16[16, 25, 64, 128]" # t726 = prims.matmul(t725, t722) # t726: "cuda:0 bf16[16, 25, 64, 128]" # t727 = prims.convert_element_type(t726, dtypes.float32) # t727: "cuda:0 f32[16, 25, 64, 128]" # t729 = prims.mul(f47, t727) # t729: "cuda:0 f32[16, 25, 64, 128]" # t730 = prims.convert_element_type(t729, dtypes.bfloat16) # t730: "cuda:0 bf16[16, 25, 64, 128]" # t733 = prims.transpose(t730, (0, 1, 3, 2)) # t733: "cuda:0 bf16[16, 25, 128, 64]" # t734 = prims.convert_element_type(t724, dtypes.float32) # t734: "cuda:0 f32[16, 25, 128, 64]" # t736 = prims.mul(f43, t734) # t736: "cuda:0 f32[16, 25, 128, 64]" # t737 = prims.convert_element_type(t736, dtypes.bfloat16) # t737: "cuda:0 bf16[16, 25, 128, 64]" # t748 = prims.transpose(t737, (0, 2, 1, 3)) # t748: "cuda:0 bf16[16, 128, 25, 64]" # t753 = prims.reshape(t748, (16, 128, 1600)) # t753: "cuda:0 bf16[16, 128, 1600]" # t756 = prims.transpose(t733, (0, 2, 1, 3)) # t756: "cuda:0 bf16[16, 128, 25, 64]" # t761 = prims.reshape(t756, (16, 128, 1600)) # t761: "cuda:0 bf16[16, 128, 1600]" # t766 = prims.cat((t753, t761, t745), i9) # t766: "cuda:0 bf16[16, 128, 4800]" # t767 = prims.reshape(t766, (2048, 4800)) # t767: "cuda:0 bf16[2048, 4800]" # t768 = prims.matmul(t767, t_attn_c_attn_weight) # t768: "cuda:0 bf16[2048, 1600]" # t769 = prims.reshape(t768, (16, 128, 1600)) # t769: "cuda:0 bf16[16, 128, 1600]" # t771 = prims.transpose(t767, (1, 0)) # t771: "cuda:0 bf16[4800, 2048]" # t772 = prims.reshape(t20, (2048, 1600)) # t772: "cuda:0 bf16[2048, 1600]" # t773 = prims.matmul(t771, t772) # t773: "cuda:0 bf16[4800, 1600]" # t774 = prims.convert_element_type(t766, dtypes.float32) # t774: "cuda:0 f32[16, 128, 4800]" # t775 = prims.sum(t774, (0, 1)) # t775: "cuda:0 f32[4800]" # t776 = prims.convert_element_type(t775, dtypes.bfloat16) # t776: "cuda:0 bf16[4800]" # t777 = prims.convert_element_type(t769, dtypes.float32) # t777: "cuda:0 f32[16, 128, 1600]" # t780 = prims.sum(t777, (0, 1)) # t780: "cuda:0 f32[1600]" # t781 = prims.convert_element_type(t780, dtypes.bfloat16) # t781: "cuda:0 bf16[1600]" # t782 = prims.mul(t15, t777) # t782: "cuda:0 f32[16, 128, 1600]" # t783 = prims.mul(t13, t777) # t783: "cuda:0 f32[16, 128, 1600]" # t786 = 
prims.sum(t783, (0, 1)) # t786: "cuda:0 f32[1600]" # t787 = prims.convert_element_type(t786, dtypes.bfloat16) # t787: "cuda:0 bf16[1600]" # t788 = prims.mul(t12, t782) # t788: "cuda:0 f32[16, 128, 1600]" # t789 = prims.mul(t11, t782) # t789: "cuda:0 f32[16, 128, 1600]" # t790 = prims.sum(t789, (2,)) # t790: "cuda:0 f32[16, 128]" # t791 = prims.broadcast_in_dim(t790, [16, 128, 1], [0, 1]) # t791: "cuda:0 f32[16, 128, 1]" # t792 = prims.neg(t788) # t792: "cuda:0 f32[16, 128, 1600]" # t794 = prims.sum(t792, (2,)) # t794: "cuda:0 f32[16, 128]" # t795 = prims.broadcast_in_dim(t794, [16, 128, 1], [0, 1]) # t795: "cuda:0 f32[16, 128, 1]" # t796 = prims.mul(-0.5, t791) # t796: "cuda:0 f32[16, 128, 1]" # t797 = prims.pow(t8, 3.0) # t797: "cuda:0 f32[16, 128, 1]" # t798 = prims.mul(t796, t797) # t798: "cuda:0 f32[16, 128, 1]" # t800 = prims.sum(t795, (2,)) # t800: "cuda:0 f32[16, 128]" # t801 = prims.sum(t798, (2,)) # t801: "cuda:0 f32[16, 128]" # t804 = prims.broadcast_in_dim(t800, [16, 128, 1], [0, 1]) # t804: "cuda:0 f32[16, 128, 1]" # t805 = prims.broadcast_in_dim(t804, (16, 128, 1600), (0, 1, 2)) # t805: "cuda:0 f32[16, 128, 1600]" # t806 = prims.mul(0.000625, t805) # t806: "cuda:0 f32[16, 128, 1600]" # t808 = prims.broadcast_in_dim(t801, [16, 128, 1], [0, 1]) # t808: "cuda:0 f32[16, 128, 1]" # t809 = prims.broadcast_in_dim(t808, (16, 128, 1600), (0, 1, 2)) # t809: "cuda:0 f32[16, 128, 1600]" # t811 = prims.broadcast_in_dim(t4, [16, 128, 1], [0, 1]) # t811: "cuda:0 f32[16, 128, 1]" # t812 = prims.broadcast_in_dim(t811, (16, 128, 1600), (0, 1, 2)) # t812: "cuda:0 f32[16, 128, 1600]" # t813 = prims.mul(2.0, t809) # t813: "cuda:0 f32[16, 128, 1600]" # t814 = prims.sub(t0, t812) # t814: "cuda:0 f32[16, 128, 1600]" # t815 = prims.mul(t813, t814) # t815: "cuda:0 f32[16, 128, 1600]" # f816 = prims.convert_element_type(i807, float) # f816: "float 1600.0" # t817 = prims.div(t815, f816) # t817: "cuda:0 f32[16, 128, 1600]" # t818 = prims.add(t806, t817) # t818: "cuda:0 f32[16, 128, 1600]" # t822 = prims.add(t788, t818) # t822: "cuda:0 f32[16, 128, 1600]" # t826 = prims.add(t824, t822) # t826: "cuda:0 f32[16, 128, 1600]" # t827 = prims.convert_element_type(t826, dtypes.bfloat16) # t827: "cuda:0 bf16[16, 128, 1600]" del f43, f47, i807, i9, t4, t642, t684, t71, t711, t745, t8, t_attn_c_attn_bias, t_attn_c_attn_weight, t_ln_1_bias, t_ln_1_weight, x return (t827, t776, t773, t664, t661, t781, t787, t596, t602, t591, t588, t527, t524) . ------------------------------------------------------ benchmark: 1 tests ----------------------------------------------------- Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations ------------------------------------------------------------------------------------------------------------------------------- test_nanogpt_block[backward-thunder] 3.7064 3.9829 3.7802 0.0415 3.7751 0.0167 32;33 264.5370 288 1 ------------------------------------------------------------------------------------------------------------------------------- Legend: Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile. 
OPS: Operations Per Second, computed as 1 / Mean =========================================================================================================================================================================================================================== 1 passed, 643 deselected, 7 warnings in 16.27s ============================================================================================================================================================================================================================ ```

wujingyue commented 4 months ago

If a proper fix takes a long time, I'm happy to take a workaround to my local branch to unblock myself.

IvanYashchuk commented 3 months ago

construct_trace used in grad transform mistakenly converts all saved_for_backward into proxies. We have saved_for_backward=(((t, 0.5), None, ([0.5],)),), which was later translated to bw_flat_saved_for_backward=[t, [FloatProxy name=f0, value=0.5], None, [FloatProxy name=f1, value=0.5]] in backward trace and that is not right. The backward trace is treating that as a proxy f1, instead of baking static numbers in.

Why do you think static numbers should be used for the backward trace? How does it make a difference today, and what would you like to see when you're confident that Thunder supports NumberProxies properly?

For the example provided in the issue description, it's not this piece of code's fault that ab does not even appear in the function signature. This code was written to be independent of the "frontend" that acquires the trace. The current frontend for some reason chooses to drop these numbers from the function inputs. If we go back in time to 955e45397c5757ef8d4f6e94f70d410d7f8ebe4d, we would see ab there:

import torch
import thunder

def foo(t, ab):
    return t * ab

jfoo = thunder.compile(foo)

dtype = torch.float32
t = torch.randn(5, 3, device="cuda").to(dtype=dtype)
t.requires_grad_()

ab = 0.5

out = jfoo(t, ab)
print(thunder.last_traces(jfoo)[0][0])
# Constructed by Augmented forward pass
import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(t, ab):
  # t: "cuda:0 f32[5, 3]" 
  # ab: "float 0.5" 
  t0 = prims.mul(t, ab)  # t0: "cuda:0 f32[5, 3]"
  return {'output': t0, 'flat_args': [t, ab], 'flat_output': (t0,)}, ((t,), (0.5,))

Let's put aside the implementation of forward_and_backward_from_trace. Is the problem that there are two nvFuser regions created in the backward (https://github.com/Lightning-AI/lightning-thunder/issues/541#issuecomment-2155706083) instead of one? There's no runtime error, right?

jjsjann123 commented 3 months ago

I think my example is a bit misleading. I'm not arguing about the behavior of the initial trace (that is cache logic), nor about the nvFuser segmentation (that is merely a side effect of our number proxy handling).

Why do you think static numbers should be used for backward trace? How does it make a difference today and what would you like to see when you're confident that Thunder supports NumberProxies properly?

Static numbers should be used in the backward trace here because it's a static number in the forward trace. Likewise, if it's indeed a proxy number in the forward trace and used in the backward trace, we should have a proxy in the bwd trace as well. (This is what @kiya00 needed in #481.)

It matters today because we are ramping up number proxy support and we need consistent behavior.

Let's put aside the implementation of forward_and_backward_from_trace

In your posted trace, where ab shows up in the forward trace, it actually looks scarier:

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(t, ab):
  # t: "cuda:0 f32[5, 3]" 
  # ab: "float 0.5" 
  t0 = prims.mul(t, ab)  # t0: "cuda:0 f32[5, 3]"
  return {'output': t0, 'flat_args': [t, ab], 'flat_output': (t0,)}, ((t,), (0.5,))

we are saving 0.5 as a static number for backward, which means that if the backward uses that number directly, it could produce a wrong result when the runtime input ab changes...
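Here is a plain-PyTorch sketch of that concern, assuming (hypothetically) that the cached backward really did bake in the traced value 0.5 and was then reused with a different ab:

import torch

def foo(t, ab):
    return t * ab

t = torch.randn(5, 3, requires_grad=True)
out = foo(t, 1.5)  # rerun with a different ab than the traced 0.5
out.sum().backward()

correct_grad = t.grad  # every element is 1.5
stale_grad = torch.full_like(t, 0.5)  # what a backward with 0.5 baked in would return
assert not torch.allclose(correct_grad, stale_grad)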

This leads back to my original question: shouldn't the grad transform preserve the static/dynamic characteristics of Number/NumberProxy from the forward trace? Otherwise our cached program can't even guarantee correctness.

jjsjann123 commented 3 months ago

Let's put aside the implementation of forward_and_backward_from_trace.

Anyway, maybe forward_and_backward_from_trace is not the root cause. That's what I'm trying to figure out: which function should I be looking at in order to patch the grad transform so that it handles NumberProxy consistently between fwd/bwd? Should I also be cautious about the impact on the rematerialization pass?

jjsjann123 commented 3 months ago

FYI, if we indeed want NumberProxies showing up as inputs, we need to treat them as NumberProxies, i.e. use symbolic values in that case.

I'm using this repro for myself as the target for this issue.

import torch
import thunder

def foo(t, ab):
    return t * ab * 0.5

jfoo = thunder.jit(foo, cache="symbolic values")

dtype = torch.float32
t = torch.randn(5, 3, device="cuda").to(dtype=dtype)

t_ref = t.detach()

t.requires_grad_()
t_ref.requires_grad_()

out = jfoo(t, 1.5)
out_ref = foo(t_ref, 1.5)

print("\n\tprologue:\n", thunder.last_prologue_traces(jfoo)[0])
print("\n\tcompute:\n", thunder.last_traces(jfoo)[0])
print("\n\tcompute last trace:\n", thunder.last_traces(jfoo)[-1])
print("\n\tcompute last backward trace:\n", thunder.last_backward_traces(jfoo)[-1])

assert(out.allclose(out_ref))
out.sum().backward()
out_ref.sum().backward()
assert(t.grad.allclose(t_ref.grad))

t.grad = None
t_ref.grad = None

out = jfoo(t, 2.7)
out_ref = foo(t_ref, 2.7)

assert(out.allclose(out_ref))
out.sum().backward()
out_ref.sum().backward()
print(t.grad)
print(t_ref.grad)
assert(t.grad.allclose(t_ref.grad))

IvanYashchuk commented 3 months ago

what function should I be looking at in order to patch grad transform to consistently handle numberproxy between fwd/bwd?

This is the correct line responsible for unwrapping values from proxies in the forward pass https://github.com/Lightning-AI/lightning-thunder/blob/0342223ac7851beb07d7df731389db777b58f1ac/thunder/core/transforms.py#L3449

If rematerialization breaks when this line is removed, it's a bug that needs to be fixed. I don't remember exactly why this line was necessary. Probably something was adding number proxies to the trace but never recording the operation that produced them. For the example from https://github.com/Lightning-AI/lightning-thunder/issues/541#issuecomment-2159734931 I see this problem:

File thunder.augmented_forward_fn_2:11, in augmented_forward_fn(t, ab)
      5 @torch.no_grad()
      6 @no_autocast
      7 def augmented_forward_fn(t, ab):
      8   # t: "cuda:0 f32[5, 3]"
      9   # ab: "float 1.5"
     10   [t1] = nvFusion0(ab, t)
---> 11   return {'output': t1, 'flat_args': [t, ab], 'flat_output': (t1,)}, ((), (ab, f1))

NameError: name 'f1' is not defined

Yan's fix seems like a good workaround (https://github.com/Lightning-AI/lightning-thunder/pull/244), and it also fixes the assert. What do you think about it, Jie? Should we merge that PR?

Should I also be cautious about impact on rematerialization pass?

No, rematerialization should work with constant numbers as well as symbolic ones. Jie, could you please provide a failing example? I'll help fix the problem.

we are saving 0.5 as a static number for backward, which means if backward uses that number directly, it could generate wrong result when runtime input ab changes...

The hope is that Thunder's current caching doesn't allow this.

Static numbers should be used for backward trace here because it's a static number in forward trace.

Here's my reasoning: forward and backward functions are separate pure functions. Part of the forward result is passed to the backward as input. Any number input to a Thunder function should be proxified, and that's why the backward trace uses number proxies independent of whether the value was static or symbolic in the forward.
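A toy sketch of that viewpoint (using a plain dict as a stand-in for a NumberProxy, not Thunder's tracing machinery): every number crossing the forward/backward boundary gets a fresh proxy name, exactly as a number passed directly to a jitted function would.

from numbers import Number

def make_number_proxy(name, value):
    # Stand-in for a NumberProxy: just a named wrapper around a Python number.
    return {"name": name, "value": value}

def proxify_saved_numbers(saved_for_backward):
    # Give every bare number in saved_for_backward its own proxy name.
    proxified = []
    for i, entry in enumerate(saved_for_backward):
        if isinstance(entry, Number):
            proxified.append(make_number_proxy(f"f{i}", entry))
        else:
            proxified.append(entry)
    return tuple(proxified)

print(proxify_saved_numbers((0.5,)))  # ({'name': 'f0', 'value': 0.5},)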

Unwrapping number proxies for forward function output is bad. We/I will fix it.

Jie, do you think that having numberproxies in the backward trace is bad? Do you think we need to change anything?

IvanYashchuk commented 3 months ago

Running the benchmark following the instructions in https://github.com/Lightning-AI/lightning-thunder/issues/541#issuecomment-2155706083 I hit

Traceback (most recent call last):
  File "/home/iyashchuk/dev/Fuser/nvfuser/__init__.py", line 146, in execute
    result = self._execute(
RuntimeError: h.has_value() INTERNAL ASSERT FAILED at "/home/iyashchuk/dev/Fuser/csrc/fusion_segmenter.cpp":3671, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Can not find a scheduler to schedule fusion segment

I probably need to update the nvFuser installation. Or is this the error that we need to fix?

jjsjann123 commented 3 months ago

Here's my reasoning: forward and backward functions are separate pure functions. Part of the forward result is passed to the backward as input. Any number input to Thunder functions should be proxified and that's why backward trace uses number proxies independent of whether it was static or symbolic value in forward. Unwrapping number proxies for forward function output is bad. We/I will fix it. Jie, do you think that having numberproxies in the backward trace is bad? Do you think we need to change anything?

Glad to see that we are on the same page here.

Part of the forward result is passed to the backward as input. Any number input to Thunder functions should be proxified

Yeah, I think that's where we should fix it. I don't think having NumberProxies in the backward trace is bad; it's necessary. But they need to faithfully reflect what they are in the forward trace.

I don't think Yan's PR is enough as-is. But I think it's in the right direction. We can follow up and work together on that one.

wujingyue commented 3 months ago

Running the benchmark following the instructions in #541 (comment) I hit

Traceback (most recent call last):
  File "/home/iyashchuk/dev/Fuser/nvfuser/__init__.py", line 146, in execute
    result = self._execute(
RuntimeError: h.has_value() INTERNAL ASSERT FAILED at "/home/iyashchuk/dev/Fuser/csrc/fusion_segmenter.cpp":3671, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Can not find a scheduler to schedule fusion segment

I probably need to update the nvFuser installation. Or is this the error that we need to fix?

I forgot which nvFuser version I was using... The benchmark ran fine when I just resynced to https://github.com/NVIDIA/Fuser/commit/b56c3e7960f18e93c2d806d3167c961ebe7f2b20. Can you give it another try?

jjsjann123 commented 2 months ago

Linking the ResNet issue regarding the grad transform: https://github.com/Lightning-AI/lightning-thunder/pull/451#issuecomment-2221056914