Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It lets you use different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

__bool__ (and data dependent control flow) #735

Open: t-vi opened this issue 3 months ago

t-vi commented 3 months ago

HF BERT contains data-dependent control flow in its padding check:

if self.config.pad_token_id in input_ids[:, [-1, 0]]:
    warn_string = (
        "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See "
        "https://huggingface.co/docs/transformers/troubleshooting"
        "#incorrect-output-when-padding-tokens-arent-masked."
    )

    # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
    # attention_mask or not. In this case, we should still show a warning because this is a rare case.

input_ids is a tensor, so evaluating this condition ultimately calls __bool__ on a tensor, which is where we fail.
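Why a membership test reaches __bool__: Python's `in` operator calls `__contains__` and then coerces whatever it returns to a Python bool. If a symbolic tensor's `__contains__` returns another symbolic value (as an elementwise comparison would), that coercion lands on `__bool__`. The sketch below is a pure-Python illustration under that assumption; `FakeProxy` is a hypothetical stand-in, not Thunder's actual `TensorProxy`, and it omits the `[:, [-1, 0]]` indexing from the BERT code.

```python
class FakeProxy:
    """Hypothetical stand-in for a symbolic tensor (not Thunder's TensorProxy)."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        # Elementwise comparison yields another symbolic value, not a Python bool.
        return FakeProxy(f"({self.name} == {other!r})")

    def any(self):
        return FakeProxy(f"{self.name}.any()")

    def __contains__(self, element):
        # Loosely modeled on torch.Tensor.__contains__, minus the .item() call:
        # with no concrete data, the result stays symbolic.
        return (self == element).any()

    def __bool__(self):
        # A tracer cannot decide truthiness without the runtime values.
        raise NotImplementedError(f"bool() of symbolic value {self.name}")


input_ids = FakeProxy("input_ids")
pad_token_id = 0

try:
    # `in` calls __contains__, then coerces its result with bool().
    if pad_token_id in input_ids:
        print("padding found")
except NotImplementedError as err:
    print("NotImplementedError:", err)
```

Running it prints `NotImplementedError: bool() of symbolic value (input_ids == 0).any()`: the failure comes from the implicit bool() coercion, not from `__contains__` itself.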

Repro:

import torch, thunder, transformers

m = transformers.BertForSequenceClassification(transformers.BertConfig())
jm = thunder.jit(m)
a = torch.randint(1, 20, (1, 25))
jm(a)
riccardofelluga commented 2 weeks ago

This is the same issue as #1174. It seems that the first step here would be to implement the bool ops on tensors. Some repro snippets:

import torch
import thunder

def foo(x):
    return not x

jf = thunder.jit(foo)
a = torch.tensor(0)

jf(a)

And also something like this:

def bar(x):
    return x or False

thunder.jit(bar)(a)
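Both snippets fail for the same reason: in Python, `not x` and `x or False` each evaluate the truthiness of `x`, i.e. call `x.__bool__()`, before anything else happens. Note also that `or` returns the operand itself when it is truthy, so `bar` would hand back the tensor `x` rather than `True`. The sketch below shows this in plain Python; `RecordsBool` is a hypothetical class, with ints standing in for tensors.

```python
class RecordsBool:
    """Hypothetical tensor stand-in that counts how often its truthiness is checked."""

    def __init__(self, value):
        self.value = value
        self.bool_calls = 0

    def __bool__(self):
        self.bool_calls += 1
        return bool(self.value)


def foo(x):
    return not x  # calls x.__bool__() once, returns the negated Python bool


def bar(x):
    return x or False  # calls x.__bool__() once; returns x itself if it is truthy


a = RecordsBool(0)
print(foo(a), a.bool_calls)  # True 1

b = RecordsBool(1)
print(bar(b) is b, b.bool_calls)  # True 1
```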

Stack trace for reference:

  File "/opt/pytorch/lightning-thunder/test.py", line 10, in <module>
    jf(a)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 717, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 219, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 506, in get_computation_and_inputs
    jit_results: TraceResults = thunder_general_jit(
  File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 1635, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 7189, in fn_
    raise e
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 7150, in fn_2
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/test.py", line 5, in foo
    return not x
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 5976, in impl
    if bool(tos):
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1387, in impl
    return dunder_bool(x)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1292, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 387, in wrapper
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/proxies.py", line 1646, in __bool__
    raise NotImplementedError
NotImplementedError
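The traceback ends in `TensorProxy.__bool__` raising `NotImplementedError`. One possible way to implement it, sketched here as a hypothetical design and not as Thunder's actual approach, is value specialization with guards: while tracing, read the concrete input to decide the branch, record that decision as a guard, and on later calls reuse the specialization only if the guard still holds, retracing otherwise. `Traced` and `jit_with_bool_guards` are illustrative names; ints stand in for tensors, and only the input itself can be bool'd in this toy.

```python
class Traced:
    """Hypothetical traced value that still carries the concrete input seen while tracing."""

    def __init__(self, concrete, guards):
        self.concrete = concrete
        self.guards = guards

    def __bool__(self):
        # Specialize: read the real value now instead of raising NotImplementedError...
        result = bool(self.concrete)
        # ...and record a check that later inputs must pass to reuse this trace.
        self.guards.append(lambda value, expected=result: bool(value) == expected)
        return result


def jit_with_bool_guards(fn):
    """Toy single-argument 'jit' that retraces whenever a recorded bool guard fails."""
    cache = []  # one list of guards per specialization
    stats = {"traces": 0}

    def wrapper(x):
        for guards in cache:
            if all(check(x) for check in guards):
                # Guards hold: reuse this specialization. In this toy, "running the
                # trace" is just calling fn on the concrete input.
                return fn(x)
        # No cached specialization matches: trace again with a value-carrying input.
        stats["traces"] += 1
        guards = []
        result = fn(Traced(x, guards))
        cache.append(guards)
        return result

    wrapper.stats = stats
    return wrapper


def foo(x):
    return not x


jf = jit_with_bool_guards(foo)
print(jf(0), jf(0), jf(5), jf.stats["traces"])  # True True False 2
```

The trade-off is that every distinct branch outcome costs a retrace; the alternative is to stop tracing at the branch and fall back to eager execution.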