Minimal Repro:
import torch
import thunder
def f(ids):
    for ix, t in enumerate(ids):
        pass
    return ids
jf = thunder.jit(f)
ids = torch.randn(2, 2).to(torch.long)
jf(ids)
Great repro. I think we would want a Tensor.__iter__ here. I'm not 100% sure what the best strategy is, given that iterators are not first-class objects; one option might be a torch.Tensor.__iter__ lookaside in jit_ext.py. Does that make sense?
I tried adding a lookaside for torch.Tensor.__iter__ with the following patch:
diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index ac9e127..2da42d7 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -890,6 +890,18 @@ def _general_jit_named_buffers_lookaside(obj: Any, *args, **kwargs):
         model, model.named_buffers, model.get_buffer, *unwrapped_args, **unwrapped_kwargs
     )
+@general_jit_lookaside(torch.Tensor.__iter__)
+def _general_tensor_iter_lookaside(obj: Any, *args, **kwargs):
+
+    # NOTE: This will be interpreted.
+    def _tensor_iter_impl(t):
+        for t_slice in t.unbind():
+            yield t_slice
+
+    pr = ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[wrap_const(torch.Tensor.__iter__).provenance])
+
+    return _interpret_call(_tensor_iter_impl, wrap(unwrap(obj), pr))
+
 @general_jit_lookaside(torch.autograd.function.Function.apply.__func__)
 def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
But it never gets called. Printing what the iter lookaside receives, I see that it gets a TensorProxy, and because TensorProxy doesn't define __iter__, iteration ends up falling back to __getitem__ (which does something unintended).
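For context, that fallback is standard Python behavior rather than anything thunder-specific: when a class has no __iter__, iter() uses the legacy sequence protocol and repeatedly calls __getitem__ with 0, 1, 2, ... until IndexError. A minimal sketch with a hypothetical class (unrelated to TensorProxy) illustrating the fallback:

class NoIter:
    # No __iter__ defined, only __getitem__.
    def __getitem__(self, key):
        print(f"__getitem__({key})")
        if key >= 3:
            raise IndexError  # ends the legacy iteration protocol
        return key

for _ in NoIter():  # prints __getitem__(0) through __getitem__(3), then stops
    pass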
Adding __iter__ to TensorProxy with the following patch works (I think other iterable proxies like TupleProxy and ListProxy already work because they inherit from tuple and list, which provide __iter__):
diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py
index df6dce4..b440091 100644
--- a/thunder/core/proxies.py
+++ b/thunder/core/proxies.py
@@ -1340,6 +1340,9 @@ class TensorProxy(Proxy, TensorProxyInterface):
         method = resolve_method("getitem", self, key)
         return method(self, key)
+    def __iter__(self):
+        return iter(self.unbind())
+
     #
     # Elementwise unary operators
     #
What do you think about this (or maybe there is still a way with a lookaside)?
Oops, sorry @kshitij12345, I accidentally stepped on your toes... But I am not sure about my approach. Is it fine to have a lookaside for TensorProxies?
Ah, no worries, as long as the issue is fixed :) Thanks for looking into this.
Is it fine to have a lookaside for TensorProxies?
@t-vi, what are your thoughts?
Seems good to have one. The reason to do it as a lookaside is to not handle iter objects in the trace.
But yeah, if defining the iter on the TensorProxy works, let's just have that; I think it gives the same result, just with a slightly different execution model. We should comment the iter method. What is the trace we get from that?
@t-vi, something like this:
In [8]: def f(x):
   ...:     res = x
   ...:     for xi in x:
   ...:         res = res + xi.unsqueeze(0)
   ...:     return res

f(torch.rand(3, 2, 2))
@torch.no_grad()
@no_autocast
def computation(x):
# x: "cpu f32[3, 2, 2]"
(xi, t4, t5) = torch.unbind(x, 0)
# (xi, t4, t5) = ltorch.unbind(x, 0)
# (t15, t16, t17) = ltorch.tensor_split(x, 3, 0)
# t15 = prims.slice_prim(x, [0, 0, 0], [1, 2, 2], [1, 1, 1]) # t15: "cpu f32[1, 2, 2]"
# t16 = prims.slice_prim(x, [1, 0, 0], [2, 2, 2], [1, 1, 1]) # t16: "cpu f32[1, 2, 2]"
# t17 = prims.slice_prim(x, [2, 0, 0], [3, 2, 2], [1, 1, 1]) # t17: "cpu f32[1, 2, 2]"
# xi = ltorch.squeeze(t15, 0) # xi: "cpu f32[2, 2]"
# xi = prims.squeeze(t15, (0,)) # xi: "cpu f32[2, 2]"
# t4 = ltorch.squeeze(t16, 0) # t4: "cpu f32[2, 2]"
# t4 = prims.squeeze(t16, (0,)) # t4: "cpu f32[2, 2]"
# t5 = ltorch.squeeze(t17, 0) # t5: "cpu f32[2, 2]"
# t5 = prims.squeeze(t17, (0,)) # t5: "cpu f32[2, 2]"
b = torch.unsqueeze(xi, 0) # b: "cpu f32[1, 2, 2]"
# b = ltorch.unsqueeze(xi, 0) # b: "cpu f32[1, 2, 2]"
# b = prims.broadcast_in_dim(xi, [1, 2, 2], [1, 2]) # b: "cpu f32[1, 2, 2]"
del xi
t9 = torch.unsqueeze(t4, 0) # t9: "cpu f32[1, 2, 2]"
# t9 = ltorch.unsqueeze(t4, 0) # t9: "cpu f32[1, 2, 2]"
# t9 = prims.broadcast_in_dim(t4, [1, 2, 2], [1, 2]) # t9: "cpu f32[1, 2, 2]"
del t4
t12 = torch.unsqueeze(t5, 0) # t12: "cpu f32[1, 2, 2]"
# t12 = ltorch.unsqueeze(t5, 0) # t12: "cpu f32[1, 2, 2]"
# t12 = prims.broadcast_in_dim(t5, [1, 2, 2], [1, 2]) # t12: "cpu f32[1, 2, 2]"
del t5
result = torch.add(x, b) # result: "cpu f32[3, 2, 2]"
# result = ltorch.add(x, b, alpha=None) # result: "cpu f32[3, 2, 2]"
# t22 = prims.broadcast_in_dim(b, (3, 2, 2), (0, 1, 2)) # t22: "cpu f32[3, 2, 2]"
# result = prims.add(x, t22) # result: "cpu f32[3, 2, 2]"
del x, b
res = torch.add(result, t9) # res: "cpu f32[3, 2, 2]"
# res = ltorch.add(result, t9, alpha=None) # res: "cpu f32[3, 2, 2]"
# t25 = prims.broadcast_in_dim(t9, (3, 2, 2), (0, 1, 2)) # t25: "cpu f32[3, 2, 2]"
# res = prims.add(result, t25) # res: "cpu f32[3, 2, 2]"
del result, t9
t14 = torch.add(res, t12) # t14: "cpu f32[3, 2, 2]"
# t14 = ltorch.add(res, t12, alpha=None) # t14: "cpu f32[3, 2, 2]"
# t28 = prims.broadcast_in_dim(t12, (3, 2, 2), (0, 1, 2)) # t28: "cpu f32[3, 2, 2]"
# t14 = prims.add(res, t28) # t14: "cpu f32[3, 2, 2]"
del res, t12
return t14
...
I'd say not overly pretty but not too terrible.
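For reference, a sketch of how a trace like the one above can be reproduced and inspected, assuming the TensorProxy.__iter__ patch above is applied; I believe thunder.last_traces(jf)[-1] is the usual way to get the final computation trace:

import torch
import thunder

def f(x):
    res = x
    for xi in x:  # relies on TensorProxy.__iter__ / unbind under thunder.jit
        res = res + xi.unsqueeze(0)
    return res

jf = thunder.jit(f)
jf(torch.rand(3, 2, 2))
print(thunder.last_traces(jf)[-1])  # prints the final (executed) computation trace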
NOTE: For a minimal repro, see the comment below.
Full log: error.log
(Steps to repro are the same as in #678 and copied from there, except for the addition of the megatron_core commit details in the environment.) To repro:
Note you'll need the referenced ./data directory; ping @tfogal privately for now.
Environment
cc @tfogal