Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

python `slice` is not represented properly in thunder #1182

Closed · jjsjann123 closed this 1 month ago

jjsjann123 commented 1 month ago

🚀 Feature

slice currently shows up in traces differently from list/tuple: the NumberProxy inside it is printed explicitly via its repr, resulting in an invalid Python program.

e.g. in the script below:

import thunder
import torch

def foo(a):
    # The slice stop is computed from the input's shape, so under
    # cache="symbolic values" it becomes a NumberProxy in the trace.
    return a[..., : (a.shape[-1] // 2)]

jfoo = thunder.jit(foo, cache="symbolic values")

a = torch.randn(2, 2, device="cuda")
out = jfoo(a)

On PR #1027 (commit 4d260aed2b0939cebdeeeb4f04cf47358d3d9c8b), we get a trace like this:

def computation(a, i1):
  # a: "cuda:0 f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=2, static=CONSTRAINT.CONSTRAINABLE]]"
  # i1: "int 2"

  # /volume/thunder_dynamic/t.py:7:         return a[..., : (a.shape[-1] // 2)]
  b2 = prims.signbit(i1)  # b2: "bool False"
  _ = prims.signbit(2)
  b3 = prims.ne(b2, False)  # b3: "bool False"
  f4 = prims.fmod(i1, 2)  # f4: "float 0.0"
  i5 = prims.convert_element_type(f4, int)  # i5: "int 0"
  b6 = prims.ne(i5, 0)  # b6: "bool False"
  b7 = prims.bitwise_and(b3, b6)  # b7: "bool False"
  i8 = prims.div(i1, 2)  # i8: "int 1"
  i9 = prims.convert_element_type(b7, int)  # i9: "int 0"
  i10 = prims.sub(i8, i9)  # i10: "int 1"
  t37 = ltorch.getitem(a, (..., slice(None, [IntegerProxy name=i10, value=1, static=CONSTRAINT.CONSTRAINABLE], None)))  # t37: "cuda:0 f32[[IntegerProxy name=i29, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i36, value=1, static=CONSTRAINT.CONSTRAINABLE]]"

This line is not valid Python:

t37 = ltorch.getitem(a, (..., slice(None, [IntegerProxy name=i10, value=1, static=CONSTRAINT.CONSTRAINABLE], None)))

We should instead print out something like `slice(None, i10, None)`.

We plan on having the slice printout handled inline, similar to how list/tuple are handled. Thanks to @t-vi & @mruberry for the suggestion.
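
To make the inline approach concrete, here is a minimal sketch (the NumberProxy class and print_value helper below are simplified stand-ins, not Thunder's actual printer):

class NumberProxy:
    # Simplified stand-in for Thunder's NumberProxy.
    def __init__(self, name, value):
        self.name, self.value = name, value

def print_value(x):
    # Proxies print as their trace-local names; slices recurse into
    # their members; everything else falls back to repr.
    if isinstance(x, NumberProxy):
        return x.name
    if isinstance(x, slice):
        members = (x.start, x.stop, x.step)
        return f"slice({', '.join(print_value(m) for m in members)})"
    return repr(x)

i10 = NumberProxy("i10", 1)
print(print_value(slice(None, i10, None)))  # prints: slice(None, i10, None)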

Alternative

My initial thought was that we needed a SliceProxy, like ListProxy / TupleProxy, but as @t-vi pointed out:

> I think that TupleProxy, DictProxy, and ListProxy are currently not used at all; they are relics of the functional JIT, which tried to proxy everything. I think CollectionProxy is used exclusively for autograd. Having a slice proxy is certainly possible and may be needed if we run into trouble putting NumberProxies into slices. (But I wondered why you have lists in your slices above until... )

I don't think we need a SliceProxy for now, so I'll proceed with the less invasive approach first to unblock myself.

jjsjann123 commented 1 month ago

Prototyping this in #1201, but I'm running into issues with slice via trace.

See this comment: https://github.com/Lightning-AI/lightning-thunder/pull/1201#issuecomment-2378213557

import thunder
import torch

def foo(a, key):
    return thunder.clang.getitem(a, key)

a = torch.randn(2, 2, 9, device="cuda")
tracer = thunder.trace(inline_trace=False)
foo_trace = tracer(foo, a, (..., slice(None, 8, None)))
print(foo_trace)
# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def foo(a, key):
  # a: "cuda:0 f32[2, 2, 9]"
  # key: "Collection"
  _, _, = key
  t12 = prims.slice_prim(a, [0, 0, 0], [2, 2, i0], [1, 1, 1])  # t12: "cuda:0 f32[2, 2, 8]"
  return t12

The trace here isn't functional: the slice input isn't unpacked properly, so `i0` in `slice_prim` is never bound. Is this going back to SliceProxy, or did I miss something else again?
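
For reference, a functional trace would need i0 bound from the slice before slice_prim consumes it, something like this hypothetical output (the member unpacking shown here is illustrative, not what Thunder actually prints):

def foo(a, key):
  # a: "cuda:0 f32[2, 2, 9]"
  # key: "Collection"
  _, s1, = key  # keep the slice rather than discarding it
  i0 = s1.stop  # hypothetical: bind the slice's stop member
  t12 = prims.slice_prim(a, [0, 0, 0], [2, 2, i0], [1, 1, 1])  # t12: "cuda:0 f32[2, 2, 8]"
  return t12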

BTW, in the PR, the failure is encountered during the grad transform: https://github.com/Lightning-AI/lightning-thunder/blob/59467aa5e401ea1c4d2c4e6f9e94d1dbe147a19e/thunder/core/transforms.py#L2280

jjsjann123 commented 1 month ago

Note: I think the issue is coming from the tree_flatten registration I added. (I think I needed this change to explicitly mark the dependency on the entries in a slice.)

# Register slice as a pytree node so its start/stop/step show up as
# children (and hence as dependencies the tracer can see).
optree.register_pytree_node(
    slice,
    lambda s: ([s.start, s.stop, s.step], None, None),  # children, metadata, entries
    lambda _, children: slice(*children),
    namespace=OPTREE_NAMESPACE,
)
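
For illustration, this is what that registration does in isolation (a standalone sketch; the "demo" namespace is illustrative, and none_is_leaf=True keeps the None members as leaves):

import optree

optree.register_pytree_node(
    slice,
    lambda s: ([s.start, s.stop, s.step], None, None),
    lambda _, children: slice(*children),
    namespace="demo",
)

leaves, spec = optree.tree_flatten(
    slice(None, 8, None), none_is_leaf=True, namespace="demo"
)
print(leaves)                               # [None, 8, None]
print(optree.tree_unflatten(spec, leaves))  # slice(None, 8, None)
# (list and tuple are built-in pytree node types in optree, which is why
# they flatten without any registration.)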

Somehow nothing is registered for flattening list, yet lists are unpacked properly, e.g.:

def foo2(a, newshape):
    return thunder.clang.reshape(a, newshape)

# reusing `tracer` and `a` from the snippet above
foo2_trace = tracer(foo2, a, (2, 6, 3))
print(foo2_trace)
# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def foo2(a, newshape):
  # a: "cuda:0 f32[2, 2, 9]"
  # newshape: "Collection"
  i0, i1, i2, = newshape
  t15 = prims.reshape(a, (i0, i1, i2))  # t15: "cuda:0 f32[2, 6, 3]"
  return t15

Not sure if this is because of ListProxy... I'll try to dig deeper.

tfogal commented 1 month ago

Jie, is this a problem with the default cache argument to thunder.jit, or only with symbolic values?

jjsjann123 commented 1 month ago

> Jie, is this a problem with the default cache argument to thunder.jit, or only with symbolic values?

Took me a while to figure out how to access GitHub. 🙇

Yeah, this is only with symbolic values. I have it mostly figured out in my prototype PR #1201. There might be some remaining issues with pytree, but I should at least be able to close this issue with that PR. I'll have it cleaned up shortly and ready for review.
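
For anyone following along, a quick way to confirm the default cache is unaffected (a sketch; the exact trace text may differ, but with static shapes the slice stop is a plain Python int rather than a NumberProxy):

import thunder
import torch

def foo(a):
    return a[..., : (a.shape[-1] // 2)]

jfoo = thunder.jit(foo)  # default cache: shapes are treated as constants
a = torch.randn(2, 2, device="cuda")
out = jfoo(a)
print(thunder.last_traces(jfoo)[-1])
# Expect the getitem line to contain a concrete stop, e.g.
# slice(None, 1, None), so the printed trace remains valid Python.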