Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Expected dtype thunder.dtypes.bfloat16 but found thunder.dtypes.float32 for Dynamo+Thunder and Mixtral-8x7B-v0.1 #1093

Closed: mpatel31415 closed this issue 2 weeks ago

mpatel31415 commented 3 weeks ago

🐛 Bug

When running the benchmark, we get:

RuntimeError: Expected dtype thunder.dtypes.bfloat16 but found thunder.dtypes.float32!

To Reproduce

Please use 8 nodes with 8 GPUs each and the image "INTERNAL_IMAGE:pjnl-20240830".

Training script:

python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
    --model_name Mixtral-8x7B-v0.1 \
    --distributed_mode fsdp \
    --shard_mode zero3 \
    --compile dynamo_thunder \
    --checkpoint_activations True \
    --low_precision_mode fp8-delayed-te \
    --micro_batch_size 1

Expected behavior

We should not get this error :)

Environment

system.device_product_name: DGXH100
system.gpu_driver_version: 535.129.03
libraries.cuda: 12.6.1.005
libraries.pip.lightning: 2.4.0.dev20240728
libraries.pip.lightning-thunder: 0.2.0.dev0
libraries.pip.lightning-utilities: 0.11.6
libraries.pip.litgpt: 0.4.11
libraries.pip.nvfuser: 0.2.10+git58dfdc1
libraries.pip.pytorch-lightning: 2.4.0
libraries.pip.torch: 2.5.0a0+git578b8d7
libraries.pip.torchmetrics: 1.4.1
libraries.pip.torchvision: 0.19.0a0+d23a6e1

kshitij12345 commented 3 weeks ago

Applying the patch from https://github.com/Lightning-AI/lightning-thunder/pull/1075 allows this to run.

cc: @IvanYashchuk

IvanYashchuk commented 3 weeks ago

What are the unsupported operations that are executed with a fallback in the splitter?

kshitij12345 commented 3 weeks ago

These are the reasons for the splits:

Subgraph and Split Reason

Dynamo Graph 1

```python
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "bf16[1, 32768, 4096]", L_self_weight: "bf16[4096]"):
        l_x_ = L_x_
        l_self_weight = L_self_weight

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:429 in forward, code: x = x.float()
        x: "f32[1, 32768, 4096]" = l_x_.float();  l_x_ = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:431 in forward, code: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        mul: "f32[1, 32768, 4096]" = x * x
        norm_x: "f32[1, 32768, 1]" = torch.mean(mul, dim = -1, keepdim = True);  mul = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:432 in forward, code: x_normed = x * torch.rsqrt(norm_x + self.eps)
        add: "f32[1, 32768, 1]" = norm_x + 1e-05;  norm_x = None
        rsqrt: "f32[1, 32768, 1]" = torch.rsqrt(add);  add = None
        x_normed: "f32[1, 32768, 4096]" = x * rsqrt;  x = rsqrt = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:433 in forward, code: x_normed = x_normed.to(dtype=dtype)
        x_normed_1: "bf16[1, 32768, 4096]" = x_normed.to(dtype = torch.bfloat16);  x_normed = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:438 in forward, code: return x_normed * self.weight
        mul_2: "bf16[1, 32768, 4096]" = x_normed_1 * l_self_weight;  x_normed_1 = l_self_weight = None
        return (mul_2,)
```

Split Reason

```
[SplitReason(type=, info="node with name: x and target: float didn't have any mapping "
             'in thunder.', exception=None)]
```

Dynamo Graph 2

```python
class GraphModule(torch.nn.Module):
    def forward(self, L_stack0_: "bf16[1, 32768, 6144]", L_cos_: "bf16[32768, 128]", L_sin_: "bf16[32768, 128]"):
        l_stack0_ = L_stack0_
        l_cos_ = L_cos_
        l_sin_ = L_sin_

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:220 in torch_dynamo_resume_in_forward_at_215, code: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
        qkv: "bf16[1, 32768, 8, 6, 128]" = l_stack0_.view(1, 32768, 8, 6, 128);  l_stack0_ = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:221 in torch_dynamo_resume_in_forward_at_215, code: qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)
        qkv_1: "bf16[1, 8, 6, 32768, 128]" = qkv.permute(0, 2, 3, 1, 4);  qkv = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:224 in torch_dynamo_resume_in_forward_at_215, code: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
        split = qkv_1.split((4, 1, 1), dim = 2);  qkv_1 = None
        q: "bf16[1, 8, 4, 32768, 128]" = split[0]
        k: "bf16[1, 8, 1, 32768, 128]" = split[1]
        v: "bf16[1, 8, 1, 32768, 128]" = split[2];  split = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:230 in torch_dynamo_resume_in_forward_at_215, code: k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
        k_1: "bf16[1, 8, 4, 32768, 128]" = k.expand(1, 8, 4, 32768, 128);  k = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:231 in torch_dynamo_resume_in_forward_at_215, code: v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
        v_1: "bf16[1, 8, 4, 32768, 128]" = v.expand(1, 8, 4, 32768, 128);  v = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:233 in torch_dynamo_resume_in_forward_at_215, code: q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
        q_1: "bf16[1, 32, 32768, 128]" = q.reshape(1, -1, 32768, 128);  q = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:234 in torch_dynamo_resume_in_forward_at_215, code: k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
        k_2: "bf16[1, 32, 32768, 128]" = k_1.reshape(1, -1, 32768, 128);  k_1 = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:235 in torch_dynamo_resume_in_forward_at_215, code: v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)
        v_2: "bf16[1, 32, 32768, 128]" = v_1.reshape(1, -1, 32768, 128);  v_1 = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:237 in torch_dynamo_resume_in_forward_at_215, code: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
        getitem_3: "bf16[1, 32, 32768, 128]" = q_1[(Ellipsis, slice(None, 128, None))]

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:375 in apply_rope, code: x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
        x1: "bf16[1, 32, 32768, 64]" = getitem_3[(Ellipsis, slice(None, 64, None))]

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:376 in apply_rope, code: x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
        x2: "bf16[1, 32, 32768, 64]" = getitem_3[(Ellipsis, slice(64, None, None))]

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:377 in apply_rope, code: rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
        neg: "bf16[1, 32, 32768, 64]" = -x2;  x2 = None
        rotated: "bf16[1, 32, 32768, 128]" = torch.cat((neg, x1), dim = -1);  neg = x1 = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:378 in apply_rope, code: roped = (x * cos) + (rotated * sin)
        mul: "bf16[1, 32, 32768, 128]" = getitem_3 * l_cos_;  getitem_3 = None
        mul_1: "bf16[1, 32, 32768, 128]" = rotated * l_sin_;  rotated = None
        roped: "bf16[1, 32, 32768, 128]" = mul + mul_1;  mul = mul_1 = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:379 in apply_rope, code: return roped.to(dtype=x.dtype)
        q_roped: "bf16[1, 32, 32768, 128]" = roped.to(dtype = torch.bfloat16);  roped = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:238 in torch_dynamo_resume_in_forward_at_215, code: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
        getitem_6: "bf16[1, 32, 32768, 128]" = k_2[(Ellipsis, slice(None, 128, None))]

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:375 in apply_rope, code: x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
        x1_1: "bf16[1, 32, 32768, 64]" = getitem_6[(Ellipsis, slice(None, 64, None))]

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:376 in apply_rope, code: x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
        x2_1: "bf16[1, 32, 32768, 64]" = getitem_6[(Ellipsis, slice(64, None, None))]

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:377 in apply_rope, code: rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
        neg_1: "bf16[1, 32, 32768, 64]" = -x2_1;  x2_1 = None
        rotated_1: "bf16[1, 32, 32768, 128]" = torch.cat((neg_1, x1_1), dim = -1);  neg_1 = x1_1 = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:378 in apply_rope, code: roped = (x * cos) + (rotated * sin)
        mul_2: "bf16[1, 32, 32768, 128]" = getitem_6 * l_cos_;  getitem_6 = l_cos_ = None
        mul_3: "bf16[1, 32, 32768, 128]" = rotated_1 * l_sin_;  rotated_1 = l_sin_ = None
        roped_1: "bf16[1, 32, 32768, 128]" = mul_2 + mul_3;  mul_2 = mul_3 = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:379 in apply_rope, code: return roped.to(dtype=x.dtype)
        k_roped: "bf16[1, 32, 32768, 128]" = roped_1.to(dtype = torch.bfloat16);  roped_1 = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:239 in torch_dynamo_resume_in_forward_at_215, code: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
        getitem_9: "bf16[1, 32, 32768, 0]" = q_1[(Ellipsis, slice(128, None, None))];  q_1 = None
        q_2: "bf16[1, 32, 32768, 128]" = torch.cat((q_roped, getitem_9), dim = -1);  q_roped = getitem_9 = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:240 in torch_dynamo_resume_in_forward_at_215, code: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
        getitem_10: "bf16[1, 32, 32768, 0]" = k_2[(Ellipsis, slice(128, None, None))];  k_2 = None
        k_3: "bf16[1, 32, 32768, 128]" = torch.cat((k_roped, getitem_10), dim = -1);  k_roped = getitem_10 = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:258 in scaled_dot_product_attention, code: y = torch.nn.functional.scaled_dot_product_attention(
        y: "bf16[1, 32, 32768, 128]" = torch._C._nn.scaled_dot_product_attention(q_2, k_3, v_2, attn_mask = None, dropout_p = 0.0, scale = 0.08838834764831843, is_causal = True);  q_2 = k_3 = v_2 = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:261 in scaled_dot_product_attention, code: return y.transpose(1, 2)
        y_1: "bf16[1, 32768, 32, 128]" = y.transpose(1, 2);  y = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:249 in torch_dynamo_resume_in_forward_at_215, code: y = y.reshape(B, T, self.config.head_size * self.config.n_head)  # re-assemble all head outputs side by side
        y_2: "bf16[1, 32768, 4096]" = y_1.reshape(1, 32768, 4096);  y_1 = None
        return (y_2,)
```

Split Reasons

```
[SplitReason(type=, info='Failed while running meta for node with name: rotated and '
             'target: , see exception field', exception=ValueError("neg had an unexpected type . Supported types are (, , )")),
 SplitReason(type=, info='Failed while running meta for node with name: rotated_1 and '
             'target: , see exception field', exception=ValueError("neg_1 had an unexpected type . Supported types are (, , )")),
 SplitReason(type=, info='Failed while running meta for node with name: q_2 and '
             'target: , see exception field', exception=ValueError("q_roped had an unexpected type . Supported types are (, , )")),
 SplitReason(type=, info='Failed while running meta for node with name: k_3 and '
             'target: , see exception field', exception=ValueError("k_roped had an unexpected type . Supported types are (, , )"))]
```

Dynamo Graph 3

```python
class GraphModule(torch.nn.Module):
    def forward(self, L_stack0_: "bf16[32768, 8]", L_x_: "bf16[32768, 4096]"):
        l_stack0_ = L_stack0_
        l_x_ = L_x_

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:341 in torch_dynamo_resume_in_forward_at_340, code: probs, indices = torch.topk(router, self.config.n_expert_per_token)  # (B*T, n_expert_per_token)
        topk = torch.topk(l_stack0_, 2);  l_stack0_ = None
        probs: "bf16[32768, 2]" = topk[0]
        indices: "i64[32768, 2]" = topk[1];  topk = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:342 in torch_dynamo_resume_in_forward_at_340, code: probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
        softmax: "f32[32768, 2]" = probs.softmax(dim = 1, dtype = torch.float32);  probs = None
        probs_1: "bf16[32768, 2]" = softmax.to(dtype = torch.bfloat16);  softmax = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:343 in torch_dynamo_resume_in_forward_at_340, code: masks = indices.unsqueeze(-1) == torch.arange(self.config.n_expert, device=x.device)
        unsqueeze: "i64[32768, 2, 1]" = indices.unsqueeze(-1);  indices = None
        arange: "i64[8]" = torch.arange(8, device = device(type='cuda', index=1))
        masks: "b8[32768, 2, 8]" = unsqueeze == arange;  unsqueeze = arange = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:344 in torch_dynamo_resume_in_forward_at_340, code: masks = masks.permute(2, 0, 1)  # (n_expert, B*T, n_expert_per_token)
        masks_1: "b8[8, 32768, 2]" = masks.permute(2, 0, 1);  masks = None

        # File: /home/kkalambarkar/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/litgpt/model.py:345 in torch_dynamo_resume_in_forward_at_340, code: y = torch.zeros_like(x)  # (B*T, C)
        y: "bf16[32768, 4096]" = torch.zeros_like(l_x_);  l_x_ = None
        return (masks_1, probs_1, y)
```

Split Reason

```
[SplitReason(type=, info="node with name: softmax and target: softmax didn't have any "
             'mapping in thunder.', exception=None)]
```

Honestly, I think the splits are caused by a bug in is_node_supported, and the splitting is what lets an otherwise failing graph run 😅

Digging a bit further, I found the culprit subgraph. Here is a minimal repro:

import torch
import thunder
from torch import device

def forward(L_stack0_: "bf16[32768, 8]", L_x_: "bf16[32768, 4096]"):
    l_stack0_ = L_stack0_
    l_x_ = L_x_

    topk = torch.topk(l_stack0_, 2);  l_stack0_ = None
    probs: "bf16[32768, 2]" = topk[0]
    indices: "i64[32768, 2]" = topk[1];  topk = None
    softmax: "f32[32768, 2]" = probs.softmax(dim = 1, dtype = torch.float32);  probs = None    
    return softmax

jforward = thunder.jit(forward)

L_stack0_ = torch.randn([32768, 8], dtype=torch.bfloat16, device='cuda', requires_grad=True)
L_x_ = torch.randn([32768, 4096], dtype=torch.bfloat16, device='cuda', requires_grad=True)

jforward(L_stack0_, L_x_)
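For reference, the subgraph in this repro is the Mixtral router from litgpt (model.py:341-342 in the dump above): per token, take the top-k router logits, apply softmax to them in float32, and, in the full model, cast the result back to the activation dtype. A pure-Python sketch of that computation, assuming plain lists of floats instead of tensors (so the bf16/float32 distinction that triggers the error is not modeled):

```python
import math


def topk_row(row, k):
    """Return (values, indices) of the k largest entries, largest first (like torch.topk)."""
    order = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
    return [row[i] for i in order], order


def softmax_row(row):
    """Numerically stable softmax: subtract the row max before exponentiating."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def route(router_logits, k):
    """For each token, pick the top-k experts and normalize their logits with softmax."""
    probs, indices = [], []
    for row in router_logits:
        vals, idx = topk_row(row, k)
        probs.append(softmax_row(vals))
        indices.append(idx)
    return probs, indices


if __name__ == "__main__":
    # Two tokens, four experts, top-2 routing (made-up logits)
    print(route([[0.1, 2.0, -1.0, 1.0], [3.0, 0.0, 1.0, -2.0]], 2))
```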

kshitij12345 commented 2 weeks ago

@IvanYashchuk

While investigating this, I think I have stumbled upon a rematerialization issue. Is my understanding correct that rematerialization implicitly expects the forward and backward traces not to have tensor proxy name collisions, except for the proxies passed as saved_for_backward?

Context: for the script below, I applied the patch from https://github.com/Lightning-AI/lightning-thunder/pull/1067 (without it, I see a different remat error).

import torch
import thunder
from torch import device

def forward(L_stack0_: "bf16[32768, 8]", L_x_: "bf16[32768, 4096]"):
    l_stack0_ = L_stack0_

    probs = l_stack0_
    softmax: "f32[32768, 2]" = probs.softmax(dim = 1, dtype=torch.float);  probs = None    
    return softmax

jforward = thunder.jit(forward)

L_stack0_ = torch.randn([32768, 8], dtype=torch.bfloat16, device='cuda', requires_grad=True)
L_x_ = torch.randn([32768, 4096], dtype=torch.bfloat16, device='cuda', requires_grad=True)

jforward(L_stack0_, L_x_)

This is the generated joint trace (before remat):

def joint_fn(args, kwargs, cotangents):
  # L_stack0_: "cuda:0 bf16[32768, 8]"
  [t9] = nvFusion0(L_stack0_)
    # t0 = prims.convert_element_type(L_stack0_, dtypes.float32)  # t0: "cuda:0 f32[32768, 8]"
    # t1 = prims.amax(t0, (1,))  # t1: "cuda:0 f32[32768]"
    # t2 = prims.broadcast_in_dim(t1, [32768, 1], [0])  # t2: "cuda:0 f32[32768, 1]"
    # t3 = prims.broadcast_in_dim(t2, (32768, 8), (0, 1))  # t3: "cuda:0 f32[32768, 8]"
    # t4 = prims.sub(t0, t3)  # t4: "cuda:0 f32[32768, 8]"
    # t5 = prims.exp(t4)  # t5: "cuda:0 f32[32768, 8]"
    # t6 = prims.sum(t5, (1,))  # t6: "cuda:0 f32[32768]"
    # t7 = prims.broadcast_in_dim(t6, [32768, 1], [0])  # t7: "cuda:0 f32[32768, 1]"
    # t8 = prims.broadcast_in_dim(t7, (32768, 8), (0, 1))  # t8: "cuda:0 f32[32768, 8]"
    # t9 = prims.div(t5, t8)  # t9: "cuda:0 f32[32768, 8]"
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  t10, = cotangents
  t9, = C0
  [t7] = nvFusion0(t9, t10)
    # t2 = prims.mul(t9, t10)  # t2: "cuda:0 f32[32768, 8]"
    # t3 = prims.sum(t2, (1,))  # t3: "cuda:0 f32[32768]"
    # t4 = prims.broadcast_in_dim(t3, [32768, 1], [0])  # t4: "cuda:0 f32[32768, 1]"
    # t5 = prims.broadcast_in_dim(t4, (32768, 8), (0, 1))  # t5: "cuda:0 f32[32768, 8]"
    # t6 = prims.sub(t10, t5)  # t6: "cuda:0 f32[32768, 8]"
    # t7 = prims.mul(t9, t6)  # t7: "cuda:0 f32[32768, 8]"
  return {'output': t9, 'flat_args': [L_stack0_], 'flat_output': (t9,)}, ((t7,),)
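The backward fusion above computes the standard softmax vector-Jacobian product, grad = y * (g - sum(y * g)) along the softmax dimension, which is the same expression as softmax_backward in the patch further down. It also appears that t7 is float32 while the input L_stack0_ is bfloat16, which matches the dtype error; the patch casts the result back to the input dtype. A pure-Python sketch of the math only (one row, plain floats, no dtypes), mirroring the t2..t7 steps and checking them against a central finite difference:

```python
import math


def softmax(xs):
    # Forward (t1..t9): subtract the row max, exponentiate, divide by the row sum
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def softmax_vjp(y, g):
    # t2, t3: row-wise dot product of the softmax output y and the cotangent g
    dot = sum(yi * gi for yi, gi in zip(y, g))
    # t6, t7: y * (g - dot)
    return [yi * (gi - dot) for yi, gi in zip(y, g)]


def numeric_vjp(xs, g, eps=1e-6):
    # Central finite difference of the scalar sum(g * softmax(x)) with respect to each x_i
    grads = []
    for i in range(len(xs)):
        plus = list(xs)
        plus[i] += eps
        minus = list(xs)
        minus[i] -= eps
        f_plus = sum(gi * yi for gi, yi in zip(g, softmax(plus)))
        f_minus = sum(gi * yi for gi, yi in zip(g, softmax(minus)))
        grads.append((f_plus - f_minus) / (2 * eps))
    return grads


x = [0.5, -1.0, 2.0, 0.0]   # one row of logits (made-up values)
g = [1.0, 0.0, -0.5, 0.25]  # cotangent for that row (made-up values)
analytic = softmax_vjp(softmax(x), g)
numeric = numeric_vjp(x, g)
print(max(abs(a - n) for a, n in zip(analytic, numeric)))
```

Because the softmax outputs sum to one, the entries of this VJP sum to zero, which is a cheap sanity check.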

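Remat finds a cut at t4. Note that t4 appears in both the forward and the backward sections, but the two are produced by different computations: prims.sub(t0, t3) in the forward and prims.broadcast_in_dim(t3, [32768, 1], [0]) in the backward.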

It errors with:
  File "/home/kkalambarkar/lightning-thunder/thunder/core/rematerialization.py", line 610, in rematerialize_forward_and_backward
    joint_extrace = rematerialize(joint_extrace)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/rematerialization.py", line 559, in rematerialize
    updated_consumer = apply_rematerialization_for_consumer(current_producer, current_consumer, cut)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/rematerialization.py", line 193, in apply_rematerialization_for_consumer
    new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))
  File "/home/kkalambarkar/lightning-thunder/thunder/core/rematerialization.py", line 193, in <lambda>
    new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))
KeyError: 't4'

To verify whether the name collision is the issue, I added a patch that prevents names used in forward_trace from appearing in backward_trace:

diff --git a/thunder/common.py b/thunder/common.py
index 7683f60c..b93987c0 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -1,5 +1,5 @@
 import dis
-from typing import Any
+from typing import Any, Set
 from collections.abc import Callable, Generator, Hashable, Sequence
 from collections import deque, defaultdict
 import time
@@ -511,6 +511,7 @@ def trace(
     include_return_statement: bool = True,
     use_dce: bool = True,
     insert_ddp_syncs: bool = False,
+    used_names: Set[str] | None = None
 ) -> Callable:
     @make_opaque
     def _trace(
@@ -533,6 +534,9 @@ def trace(
                 return fn(*args, **kwargs)

             trace = TraceCtx(fn)
+            if used_names is not None:
+                for name in used_names:
+                    trace.add_name(name)
             tracectx_tok = set_tracectx(trace)

             proxyargs, proxykwargs = args, kwargs
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 5b056da8..f9b3679c 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -2178,13 +2178,13 @@ def softmax_aug_fwd(a: Proxy, dim: int, dtype: dtypes.dtype | None = None) -> VJ
     from thunder.torch import softmax

     primal = softmax(a, dim, dtype=dtype)
-    residuals = (primal, dim)
+    residuals = (primal, dim, a.dtype)
     return VJPDual(primal, residuals)

 @register_backward("torch.softmax")
-def softmax_backward(primal, dim, g):
-    return primal * (g - (primal * g).sum(dim, keepdim=True))
+def softmax_backward(primal, dim, original_dtype, g):
+    return (primal * (g - (primal * g).sum(dim, keepdim=True))).to(original_dtype)

 def iter_bound_symbols(bound_symbols):
@@ -2998,7 +2998,7 @@ def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> Forwa
             out = tree_flatten(out)[0]
         return out

-    backward_trace = construct_trace(rename_proxies=False)(backward_fn, saved_for_backward, cotangents)
+    backward_trace = construct_trace(rename_proxies=False, used_names=forward_trace.names)(backward_fn, saved_for_backward, cotangents)

     # We are done with constructing the forward and backward passes at this
     # stage. The following is not strictly necessary, but it's good to filter

With this, everything works, and the patch from https://github.com/Lightning-AI/lightning-thunder/pull/1067 is no longer required (this change also fixes the issue that #1067 tries to solve).

EDIT: I ran test_networks.py, test_nvfuser_remat.py, and some tests from test_grad.py, and they all passed.
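The idea behind the patch can be illustrated without Thunder. If the backward trace draws proxy names from a fresh counter, it can reuse a forward name such as t4 for an unrelated value, and any bookkeeping keyed by name (like proxy_order[x.name] in the traceback) can then conflate the two. Seeding the backward trace with the forward trace's names, as the patch does via used_names and trace.add_name, avoids the clash. The NameGenerator class below is hypothetical and only mimics that behavior; it is not Thunder's TraceCtx.

```python
class NameGenerator:
    """Toy name allocator: hands out t0, t1, ... and skips names already taken."""

    def __init__(self, used_names=None):
        # Names reserved up front, analogous to calling add_name for each one
        self.names = set(used_names or ())
        self.counter = 0

    def add_name(self, name):
        if name in self.names:
            raise ValueError(f"name {name!r} is already in use")
        self.names.add(name)

    def make_name(self, prefix="t"):
        # Advance the counter until we find a name nobody has used yet
        while True:
            name = f"{prefix}{self.counter}"
            self.counter += 1
            if name not in self.names:
                self.add_name(name)
                return name


def trace_names(n, used_names=None):
    """Simulate a trace that creates n new proxies."""
    gen = NameGenerator(used_names)
    return [gen.make_name() for _ in range(n)]


forward = trace_names(10)                            # t0 .. t9
backward_clash = trace_names(6)                      # t0 .. t5, reusing forward names
backward_fresh = trace_names(6, used_names=forward)  # t10 .. t15

# A map keyed by name silently merges the clashing entries
order_clash = {name: i for i, name in enumerate(forward + backward_clash)}
order_fresh = {name: i for i, name in enumerate(forward + backward_fresh)}
print(len(order_clash), len(order_fresh))
```

With the reserved names, every proxy in the joint trace keeps a distinct key, so name-based lookups can no longer mix up a forward value with an unrelated backward one.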

kshitij12345 commented 2 weeks ago

@riccardofelluga do you think https://github.com/Lightning-AI/lightning-thunder/pull/1067 could be related to the above comment?

riccardofelluga commented 2 weeks ago

It might be related. However, even if this is the real issue, it does not make #1067 obsolete: if that check is valid, we can return early and save some cycles. I find it interesting that you discovered it's about naming; I couldn't have guessed. My question now would be: what caused the duplicate name in the first place?

Could you also test the following with your patch and without #1067?

pytest thunder/benchmarks/targets.py -k test_nanogpt_cross_entropy[forward-thunder]

kshitij12345 commented 2 weeks ago

I am not very confident of my understanding of rematerialization, so it might be that my understanding is incorrect.

"What caused the duplicate name in the first place?"

I think we don't constrain the names during backward trace creation (assuming remat requires names to be unique across traces). Currently, I don't think there is any infrastructure to specify names that a trace shouldn't use.

"Could you also test the following with your patch and without https://github.com/Lightning-AI/lightning-thunder/pull/1067?"

Yes, it works. I can run pytest thunder/benchmarks/targets.py -k test_nanogpt_cross_entropy[forward-thunder] successfully with the patch, while it fails on main with a KeyError.

IvanYashchuk commented 2 weeks ago

"Is my understanding correct that for rematerialization, there is an implicit expectation that the forward and backward trace shouldn't have tensor proxy name collision except for the proxies passed as saved_for_backward?"

When the pass was developed, we probably had a global name counter. I think your patch to generate the backward trace without name clashes with the forward trace is a good one! Please submit a pull request.