Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0

use `torch.get_default_dtype` and `torch.get_default_device` for factory method in `thunder/torch/__init__.py` #621

Open jjsjann123 opened 1 week ago

jjsjann123 commented 1 week ago

🐛 Bug

thunder's compiled function produces output with a different dtype than the original function

To Reproduce

import thunder
import torch

def foo():
    return torch.ones((1,), device="cuda")

jfoo = thunder.jit(foo)  # works

print("thunder output: ", jfoo())  # integer type
print("ref output: ", foo())  # float type
print(thunder.last_traces(jfoo)[0])

Pitch

thunder/torch/__init__.py isn't pulling torch's default dtype/device in factory ops like torch.full and torch.empty when they are not passed explicitly, resulting in wrong behavior.
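For reference, eager PyTorch resolves a missing dtype differently depending on the factory function: torch.full infers the dtype from its fill value, while torch.ones follows the global default. A small check of that behavior:

```python
import torch

# With no explicit dtype, torch.full infers from the fill value:
# an integer fill value yields int64.
print(torch.full((1,), 1).dtype)   # torch.int64

# torch.ones (and similar factories) instead follow the global default.
print(torch.ones((1,)).dtype)      # matches torch.get_default_dtype()
```

This is exactly the mismatch in the trace below: ltorch.ones lowers to a full with fill value 1 and dtype=None, so the dtype is inferred as int64 instead of the default float dtype.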

The output and trace from the above function are:

thunder output:  tensor([1], device='cuda:0')
ref output:  tensor([1.], device='cuda:0')
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation():
  # /volume/thunder_463_part2.py:5:         return torch.ones((1,), device="cuda")
  t0 = ltorch.ones((1,), device='cuda', dtype=None)  # t0: "cuda:0 i64[1]"
    # t0 = ltorch.full((1,), 1, device='cuda', dtype=None)  # t0: "cuda:0 i64[1]"
      # t0 = prims.full((1,), 1, device=devices.Device("cuda:0"), dtype=dtypes.int64_)  # t0: "cuda:0 i64[1]"
  return t0

The lowered prims.full is executed by nvFuser, which receives an explicit dtype that is different from the one the vanilla function would use.

cc @apaz-cli

mruberry commented 4 days ago

triage review:

lantiga commented 4 days ago

added high-priority since it's needed to support transformers