jjsjann123 closed this issue 1 month ago.
I'm prototyping this in #1201, but I'm running into issues with slice via trace.
See this comment: https://github.com/Lightning-AI/lightning-thunder/pull/1201#issuecomment-2378213557
```python
import torch
import thunder

def foo(a, key):
    return thunder.clang.getitem(a, key)

a = torch.randn(2, 2, 9, device="cuda")
tracer = thunder.trace(inline_trace=False)
foo_trace = tracer(foo, a, (..., slice(None, 8, None)))
print(foo_trace)
```
```python
# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def foo(a, key):
  # a: "cuda:0 f32[2, 2, 9]"
  # key: "Collection"
  _, _, = key
  t12 = prims.slice_prim(a, [0, 0, 0], [2, 2, i0], [1, 1, 1]) # t12: "cuda:0 f32[2, 2, 8]"
  return t12
```
The trace here isn't functional: the slice input isn't unpacked properly (`i0` is used but never defined). Is this going back to needing a `SliceProxy`, or did I miss something else again?
BTW, in the PR, the failure is encountered during the grad transform: https://github.com/Lightning-AI/lightning-thunder/blob/59467aa5e401ea1c4d2c4e6f9e94d1dbe147a19e/thunder/core/transforms.py#L2280
Note: I think the issue is coming from the `tree_flatten` change I made (I think I needed it to explicitly mark the dependency for the entries in a slice):
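```python
import optree

# Register slice as a pytree node so its start/stop/step are flattened as children.
optree.register_pytree_node(
    slice,
    lambda s: ([s.start, s.stop, s.step], None, None),
    lambda _, children: slice(*children),
    namespace=OPTREE_NAMESPACE,
)
```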
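To make the effect of that registration concrete, here is a toy flattener in plain Python. It is a sketch, not optree: the `flatten` helper is hypothetical, and it mimics optree's default (`none_is_leaf=False`) of treating `None` as a node with no leaves. Once `slice` is a container node, its `stop` becomes a separate leaf that can be seen and replaced, instead of the whole slice being one opaque leaf.

```python
from typing import Any, Callable, List, Tuple

Rebuild = Callable[[List[Any]], Any]


def flatten(tree: Any) -> Tuple[List[Any], Rebuild]:
    """Toy pytree flatten (hypothetical): returns (leaves, rebuild)."""
    if tree is None:
        # Like optree's default, None contributes no leaves.
        return [], lambda leaves: None
    if isinstance(tree, slice):
        # Mirrors the registration above: children are [start, stop, step].
        children: List[Any] = [tree.start, tree.stop, tree.step]
        make_node: Callable[[List[Any]], Any] = lambda c: slice(*c)
    elif isinstance(tree, (list, tuple)):
        children, make_node = list(tree), type(tree)
    else:
        # Anything else (numbers, Ellipsis, tensors, proxies) is a leaf.
        return [tree], lambda leaves: leaves[0]

    leaves: List[Any] = []
    parts: List[Tuple[Rebuild, int]] = []
    for child in children:
        child_leaves, child_rebuild = flatten(child)
        leaves.extend(child_leaves)
        parts.append((child_rebuild, len(child_leaves)))

    def rebuild(new_leaves: List[Any]) -> Any:
        # Hand each child back exactly the leaves it produced.
        out, pos = [], 0
        for child_rebuild, n in parts:
            out.append(child_rebuild(new_leaves[pos:pos + n]))
            pos += n
        return make_node(out)

    return leaves, rebuild


leaves, rebuild = flatten((..., slice(None, 8, None)))
print(leaves)             # [Ellipsis, 8]
print(rebuild([..., 3]))  # (Ellipsis, slice(None, 3, None))
```

The same function flattens a plain tuple such as `(2, 6, 3)` into `[2, 6, 3]`, which matches the per-element unpacking seen in the reshape trace.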
Somehow nothing special is done to flatten lists, yet a list is unpacked properly, e.g.:
```python
def foo2(a, newshape):
    return thunder.clang.reshape(a, newshape)

foo2_trace = tracer(foo2, a, (2, 6, 3))
print(foo2_trace)
```
```python
# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def foo2(a, newshape):
  # a: "cuda:0 f32[2, 2, 9]"
  # newshape: "Collection"
  i0, i1, i2, = newshape
  t15 = prims.reshape(a, (i0, i1, i2)) # t15: "cuda:0 f32[2, 6, 3]"
  return t15
```
Not sure if this is because of `ListProxy`... I'll try to dig deeper.
Jie, is this a problem with the default `cache` argument to `thunder.jit`, or only with symbolic values?
It took me a while to figure out how to access GitHub. 🙇
Yeah this is only with symbolic values. I have it mostly figured out in my prototype PR #1201. There might be some remaining issues with pytree, but at least I should be able to close this issue with that PR. I should have it cleaned up shortly and ready for review.
🚀 Feature
`slice` currently shows up in the trace differently from list/tuple: the NumberProxy inside it is printed explicitly, resulting in an invalid Python program, e.g. in the script below:
On PR #1027, commit 4d260aed2b0939cebdeeeb4f04cf47358d3d9c8b, we have a trace like this:
This line is not a valid Python program.
We should instead print out something like `slice(None, i10, None)`.
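A minimal sketch of printing `slice` inline, in the same recursive way as list/tuple. The `NumberProxy` class and `prettyprint` function here are hypothetical stand-ins, not thunder's printer: the stand-in's default repr is deliberately not valid Python, to show why a naive `repr` of the slice breaks the trace while recursing into `start`/`stop`/`step` yields the desired output.

```python
class NumberProxy:
    """Hypothetical stand-in for a symbolic number in a trace."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        # Deliberately not valid Python, to mimic the bug.
        return f"<NumberProxy {self.name}>"


def prettyprint(x) -> str:
    """Render x as Python source, printing proxies by their names."""
    if isinstance(x, NumberProxy):
        return x.name
    if x is Ellipsis:
        return "..."
    if isinstance(x, slice):
        # Handle slice inline, like list/tuple, so nested proxies print by name.
        parts = (prettyprint(x.start), prettyprint(x.stop), prettyprint(x.step))
        return "slice(" + ", ".join(parts) + ")"
    if isinstance(x, tuple):
        inner = ", ".join(prettyprint(e) for e in x)
        return f"({inner},)" if len(x) == 1 else f"({inner})"
    if isinstance(x, list):
        return "[" + ", ".join(prettyprint(e) for e in x) + "]"
    return repr(x)


stop = NumberProxy("i10")
print(repr(slice(None, stop, None)))         # slice(None, <NumberProxy i10>, None)
print(prettyprint(slice(None, stop, None)))  # slice(None, i10, None)
```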
We plan on having `slice` printing handled inline, similar to list/tuple, thanks to a suggestion by @t-vi and @mruberry.
Alternative
My initial thought was that we needed a `SliceProxy`, like `ListProxy`/`TupleProxy`. But as @t-vi pointed out, "I don't think we need `SliceProxy` for now," so I'll proceed with the less invasive approach first to unblock myself.