Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
use `torch.get_default_dtype` and `torch.get_default_device` for factory methods in `thunder/torch/__init__.py` #621
🐛 Bug
Thunder's compiled function produces output with a different dtype than the original function.
To Reproduce
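The issue's original reproduction snippet was not preserved here; the following is a hypothetical minimal repro illustrating the reported pattern (the function name `foo` and the shapes are assumptions, and the `thunder.jit` call is shown only in comments).

```python
# Hypothetical minimal repro (assumption: this is not the issue's original
# snippet, which was not preserved in this page).
import torch

def foo():
    # A factory op with no explicit dtype should honor torch's defaults.
    return torch.full((2, 2), 1.0)

torch.set_default_dtype(torch.float64)
print(foo().dtype)  # eager PyTorch honors the default: torch.float64

# With thunder, the compiled function reportedly bakes in a different dtype:
#   import thunder
#   jfoo = thunder.jit(foo)
#   jfoo().dtype  # reported to disagree with eager mode (the bug)
```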
Pitch
`thunder/torch/__init__.py` isn't properly pulling torch's default dtype/device in ops like `torch.full` and `torch.empty`, resulting in wrong behavior. The trace from the above function is:
The lowered `prims.full` is executed by nvFuser, which receives an explicit dtype that differs from the one the vanilla function would use.

cc @apaz-cli