Status: Open. Opened by t-vi 3 months ago.
This is the same issue as #1174: it seems the first step here would be to implement the bool ops on tensors (e.g. `__bool__`, which currently raises `NotImplementedError`). Some repro snippets:
import torch
import thunder
def foo(x):
    """Return the logical negation of ``x``.

    The ``not`` operator evaluates ``bool(x)``. Under ``thunder.jit`` with a
    tensor argument, that call is routed to the proxy's ``__bool__``, which
    raises ``NotImplementedError``. This function is the minimal repro for
    that failure.
    """
    return not x
# Compile ``foo`` with thunder and call it on ``torch.tensor(0)``.
jf = thunder.jit(foo)
a = torch.tensor(0)
# This call raises NotImplementedError from ``__bool__`` in
# thunder/core/proxies.py (see the stack trace below).
jf(a)
And also something like this:
def bar(x):
    """Return ``x`` if it is truthy, otherwise ``False``.

    The ``or`` operator also evaluates ``bool(x)``. Per the issue text, this
    is expected to hit the same missing tensor ``__bool__`` support under
    ``thunder.jit`` as ``not x`` does.
    """
    return x or False
Stack trace for reference:
File "/opt/pytorch/lightning-thunder/test.py", line 10, in <module>
jf(a)
File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 717, in fn_
cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 136, in _fn
result = fn(*args, **kwargs)
File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 219, in cache_info_wrapper
res = fn(*args, **kwargs)
File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 506, in get_computation_and_inputs
jit_results: TraceResults = thunder_general_jit(
File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 1635, in thunder_general_jit
result = jfn(*args, **kwargs)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 7189, in fn_
raise e
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 7150, in fn_2
return fn(*args, **kwargs)
File "/opt/pytorch/lightning-thunder/test.py", line 5, in foo
return not x
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 5976, in impl
if bool(tos):
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1387, in impl
return dunder_bool(x)
File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1292, in wrapping_wrapper
res = ufn(*uargs, **ukwargs)
File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 387, in wrapper
return fn(*args, **kwargs)
File "/opt/pytorch/lightning-thunder/thunder/core/proxies.py", line 1646, in __bool__
raise NotImplementedError
NotImplementedError
HF BERT data-dependent control flow: `input_ids` is a tensor, which ultimately makes us fail on `__bool__` for tensors.
Repro: