## Minimal Repro

```
$ cat cat-dtype.py
import torch
import thunder

def foo():
    x = torch.randn((5,3), dtype=torch.bfloat16)
    y = torch.randn((2,3), dtype=torch.float16)
    z = torch.cat((x,y), dim=0)
    return z

foo()
thfoo = thunder.jit(foo)
thfoo()
$ python3 cat-dtype.py
Traceback (most recent call last):
  File "/tmp/cat-dtype.py", line 12, in <module>
    thfoo()
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 683, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 225, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 503, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 213, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
  File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 1768, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6769, in fn_
    raise e
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6737, in fn_2
    return fn(*args, **kwargs)
  File "/tmp/cat-dtype.py", line 7, in foo
    z = torch.cat((x,y), dim=0)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 1272, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
  File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 704, in wrapper
    return fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/symbol.py", line 276, in __call__
    result = self.meta(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/torch/__init__.py", line 812, in cat
    return clang.cat(tensors, dim)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/clang/__init__.py", line 1289, in cat
    return prims.cat(tensors, dim)
  File "/home/tfogal/dev/thunder/thunder/core/symbol.py", line 272, in __call__
    result = self.meta(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/prims.py", line 2983, in cat_meta
    utils.check_same_dtype(*tensors)
  File "/home/tfogal/dev/thunder/thunder/core/utils.py", line 240, in check_same_dtype
    check(
  File "/home/tfogal/dev/thunder/thunder/core/baseutils.py", line 103, in check
    raise exception_type(s())
RuntimeError: Expected dtype thunder.dtypes.bfloat16 but found thunder.dtypes.float16!
```
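For context, the traceback bottoms out in `utils.check_same_dtype`, called from the `cat` prim's meta function. A minimal sketch of what that check is doing, inferred from the error message rather than from thunder's actual source, is:

```python
# Sketch inferred from the traceback/error message, NOT thunder's real code:
# the cat prim's meta requires a single dtype across all inputs and raises
# instead of type-promoting the way eager torch.cat does.
def check_same_dtype(*tensors):
    expected = tensors[0].dtype
    for t in tensors[1:]:
        if t.dtype != expected:
            raise RuntimeError(f"Expected dtype {expected} but found {t.dtype}!")
```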
## 🚀 Model / language coverage
First, I applied this diff to thunder:

The diff was necessary to get the beginning of the output below, which conveys that a `cat` operator is what is at fault:

Full log of the run
Instructions on how to run NeVA are in #343.
## Pitch

This is for the NeVA model, #343.
## Alternatives / Potential work-arounds

It seems like our `cat` checks are too stringent, in that torch allows mismatched dtypes here. I suppose torch semantics are to cast each type to the first type?
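As a quick sanity check (added here, not part of the original report), eager PyTorch appears to promote to a common dtype rather than casting to the first input's dtype; `torch.promote_types` makes this visible:

```python
# Checking eager torch.cat semantics for mixed dtypes: on recent PyTorch,
# cat type-promotes its inputs rather than casting to the first dtype.
import torch

x = torch.randn((5, 3), dtype=torch.bfloat16)
y = torch.randn((2, 3), dtype=torch.float16)

print(torch.promote_types(torch.bfloat16, torch.float16))  # torch.float32
print(torch.cat((x, y), dim=0).dtype)                      # torch.float32
```

If that is the right reading of torch's semantics, relaxing the same-dtype check in favor of torch-style type promotion would be the matching fix.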
Note this is very similar to #750. It seems like the issue in #750 just appeared in `cat` even though the error was earlier, but now we are finding the issue in `cat` through some other code.
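A possible user-side work-around, sketched below under the assumption that `Tensor.to` is supported through `thunder.jit`, is to promote both inputs to a common dtype before the `cat` so the prim never sees mismatched dtypes:

```python
# Hypothetical user-side work-around: promote inputs manually before
# torch.cat so thunder's same-dtype check is satisfied.
import torch
import thunder

def foo():
    x = torch.randn((5, 3), dtype=torch.bfloat16)
    y = torch.randn((2, 3), dtype=torch.float16)
    common = torch.promote_types(x.dtype, y.dtype)  # torch.float32
    z = torch.cat((x.to(common), y.to(common)), dim=0)
    return z

thfoo = thunder.jit(foo)
thfoo()  # no dtype error: both inputs are float32 by the time cat runs
```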
cc @tfogal