Open IvanYashchuk opened 1 month ago
The goal of the renaming has been to use more of the user's variable names, not to get consecutive generic names. Is that what you want? (Also, with any bit of luck, #954 will change the trace above.)
I understand the goal, but in the example above the renaming doesn't happen when it should. In the example above, the trace object returned by thunder.trace() plays the role of a user script, and variable names should be preserved.
I think what often happens is that it tries to rename something to an internal name ("tos" etc.) and then finds that the name is already taken for whatever reason. With #954 (which is not the end of the story), you'll get t7.
@no_autocast
def computation(x, y):
  # x: "cpu f32[3, 4]"
  # y: "cpu f32[3, 4]"
  # thunder.func_0:15: t0 = ltorch.mul(x, y) # t0: "cpu f32[3, 4]"
  t0 = ltorch.mul(x, y) # t0: "cpu f32[3, 4]"
    # t0 = prims.mul(x, y) # t0: "cpu f32[3, 4]"
  # thunder.func_0:16: t1 = ltorch.true_divide(y, x) # t1: "cpu f32[3, 4]"
  t1 = ltorch.true_divide(y, x) # t1: "cpu f32[3, 4]"
    # t1 = prims.div(y, x) # t1: "cpu f32[3, 4]"
  # thunder.func_0:17: t2 = ltorch.mul(x, y) # t2: "cpu f32[3, 4]"
  t2 = ltorch.mul(x, y) # t2: "cpu f32[3, 4]"
    # t2 = prims.mul(x, y) # t2: "cpu f32[3, 4]"
  # thunder.func_0:18: t3 = ltorch.true_divide(y, x) # t3: "cpu f32[3, 4]"
  t3 = ltorch.true_divide(y, x) # t3: "cpu f32[3, 4]"
    # t3 = prims.div(y, x) # t3: "cpu f32[3, 4]"
  # thunder.func_0:19: t4 = ltorch.mul(t0, t1) # t4: "cpu f32[3, 4]"
  t4 = ltorch.mul(t0, t1) # t4: "cpu f32[3, 4]"
    # t4 = prims.mul(t0, t1) # t4: "cpu f32[3, 4]"
  # thunder.func_0:20: t5 = ltorch.mul(t2, t3) # t5: "cpu f32[3, 4]"
  t5 = ltorch.mul(t2, t3) # t5: "cpu f32[3, 4]"
    # t5 = prims.mul(t2, t3) # t5: "cpu f32[3, 4]"
  # thunder.func_0:21: t6 = ltorch.mul(t5, 1) # t6: "cpu f32[3, 4]"
  t6 = ltorch.mul(t5, 1) # t6: "cpu f32[3, 4]"
    # _ = prims.convert_element_type(1, float)
    # t6 = prims.mul(t5, 1.0) # t6: "cpu f32[3, 4]"
  # thunder.func_0:22: t7 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cpu"), dtype=dtypes.float16) # t7: "cpu f16[3, 4]"
  t7 = prims.uniform((3, 4), 0.0, 1.0, device=devices.Device("cpu"), dtype=dtypes.float16) # t7: "cpu f16[3, 4]"
  # /usr/local/lib/python3.12/dist-packages/torch/autograd/grad_mode.py:186: torch._C._set_grad_enabled(mode)
  return (t4, t5, t6, t7)
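To make the collision explanation from the comment above concrete, here is a minimal sketch of how a rename can be silently skipped when the target name is already registered. This is not thunder's actual code: `try_rename` and the `used_names` set are hypothetical names introduced only for illustration, and the assumption that a skipped rename leaves the proxy with its generated name is mine.

```python
# Hypothetical sketch, NOT thunder's implementation: it only illustrates the
# "desired name is already taken" scenario described in the comment above.

def try_rename(used_names: set[str], current: str, desired: str) -> str:
    """Return the name the proxy ends up with after a rename attempt."""
    if desired in used_names:
        # Assumption: on a collision the rename is skipped and the proxy
        # keeps its generated name.
        return current
    # The desired name is free: release the old name and claim the new one.
    used_names.discard(current)
    used_names.add(desired)
    return desired


# Case 1: "t7" was already claimed earlier, so "t14" stays "t14".
taken = {"x", "y", "t0", "t7", "t14"}
print(try_rename(taken, "t14", "t7"))

# Case 2: "t7" is free, so the rename goes through.
free = {"x", "y", "t0", "t14"}
print(try_rename(free, "t14", "t7"))
```

If something like this is what happens, the interesting question is which earlier step claimed the target name before the renaming ran.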
🐛 Bug
Proxy renaming in the initial trace sometimes doesn't work. Let's check what the initial trace looks like for the following example (taken from test_core.py::test_cse):
The initial trace in thunder.jit is
Why is t14 not renamed to t7 when all other variables are renamed? The renaming happens at https://github.com/Lightning-AI/lightning-thunder/blob/9f6e5b14e7a0fc6c96cca254540666d899df60b2/thunder/core/jit_ext.py#L1822-L1823