jjsjann123 opened 5 months ago
Seems I've found another issue with basic indexing.
```python
import torch
import thunder

def foo(a, size0, size1):
    return a[:, :, size0:size1]

size0 = 8
size1 = 16
a = torch.randn(2, 4, 30, device="cuda")
a_ref = a.detach()
outs_ref = foo(a_ref, size0, size1)

jfoo = thunder.jit(foo, cache="symbolic values")
outs = jfoo(a, size0, size1)
```
Even though the trace uses size0 and size1, they are not taken as inputs from the prologue trace into the computation trace.
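To illustrate why the bounds must be real inputs rather than baked-in constants, here is a torch-free sketch (the helper `sliced_extent` is hypothetical, not part of Thunder): for non-negative bounds, the output extent along the sliced dimension is `stop - start` after clamping to the dimension length, so a cached trace that hard-codes 8 and 16 would produce wrong shapes for any other bounds.

```python
def sliced_extent(size0, size1, dim_len=30):
    # Clamp the bounds to [0, dim_len], mirroring Python slice
    # semantics for non-negative start/stop values.
    start = min(max(size0, 0), dim_len)
    stop = min(max(size1, 0), dim_len)
    return max(stop - start, 0)

sliced_extent(8, 16)   # → 8, matching t18: "cuda:0 f32[2, 4, 8]"
sliced_extent(4, 20)   # → 16, a shape a baked-in trace would get wrong
```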
```python
def computation(a):
  # a: "cuda:0 f32[2, 4, 30]"

  # /volume/thunder_jit_reshape.py:52: return a[:, :, size0:size1]
  t18 = ltorch.getitem(a, (slice(None, None, None), slice(None, None, None), slice([IntegerProxy name=i0, value=8, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=16, static=CONSTRAINT.CONSTRAINABLE], None)))  # t18: "cuda:0 f32[2, 4, 8]"
    # b2 = prims.lt(size0, 0)  # b2: "bool False"
    # b3 = prims.ge(size0, 0)  # b3: "bool True"
    # b4 = prims.lt(size1, 0)  # b4: "bool False"
    # b5 = prims.ge(size1, 0)  # b5: "bool True"
    # b6 = prims.gt(size0, size1)  # b6: "bool False"
    # b7 = prims.ge(size0, 30)  # b7: "bool False"
    # b8 = prims.ge(size1, 30)  # b8: "bool False"
    # b9 = prims.eq(size0, 0)  # b9: "bool False"
    # t18 = prims.slice_prim(a, [0, 0, size0], [2, 4, size1], [1, 1, 1])  # t18: "cuda:0 f32[2, 4, 8]"
  return t18
```
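The commented `prims` lines b2 through b9 are the bound checks Thunder records while normalizing the slice. A plain-Python sketch of the same checks (the helper name is made up for illustration) shows they all evaluate exactly as the trace comments report for start=8, stop=16 on a dimension of length 30:

```python
def record_slice_checks(start, stop, dim_len):
    # Mirror the boolean checks recorded in the trace (b2-b9).
    return {
        "b2": start < 0,
        "b3": start >= 0,
        "b4": stop < 0,
        "b5": stop >= 0,
        "b6": start > stop,
        "b7": start >= dim_len,
        "b8": stop >= dim_len,
        "b9": start == 0,
    }

record_slice_checks(8, 16, 30)
# → b3 and b5 are True, everything else is False, as in the trace
```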
🐛 Bug
As the title suggests, with the "symbolic values" cache, NumberProxy operations performed via math.xxx calls are not traced.
To Reproduce
Looking at the trace below, prims.fmod does not appear in the trace; the result was baked in instead, which is wrong.
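For reference, here is a minimal eager-mode sketch of the kind of function affected (this repro is my assumption, not taken from the issue): under the "symbolic values" cache, the math.fmod call should appear in the trace as prims.fmod so the result is recomputed for new inputs, rather than being constant-folded to the value from the first call.

```python
import math

def bar(x):
    # In a correct symbolic trace, this call would be recorded as
    # prims.fmod(x, 3.0) instead of being baked into a constant.
    return math.fmod(x, 3.0)

bar(7.0)  # → 1.0 eagerly; a trace that bakes in 1.0 is wrong for other x
bar(8.0)  # → 2.0
```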