Note that PyTorch upcasts automatically when given tensors of varying dtype, while Thunder currently errors. When I (clumsily) tried to add this in #41, I seemed to hit an inconsistency between torch eager and compile.
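To make the contrast concrete, here is a rough pure-Python model of the two behaviors. It is a sketch under simplifying assumptions: the function names and dtype tables are hypothetical, invented for illustration, and it ignores complex dtypes, unsigned/signed mixing, bfloat16, and scalar operands. In real PyTorch the promoted dtype of two tensors can be queried with `torch.result_type`. Thunder's actual check and error message may differ from the strict variant shown here.

```python
# Illustrative model only: NOT PyTorch's or Thunder's actual code.
# Ignores complex dtypes, unsigned/signed mixing, bfloat16, and scalars.

# Category rank: bool < integer < floating point.
_CATEGORY = {"bool": 0, "int32": 1, "int64": 1, "float16": 2, "float32": 2, "float64": 2}
# Bit width, used to pick the widest dtype within the winning category.
_WIDTH = {"bool": 1, "int32": 32, "int64": 64, "float16": 16, "float32": 32, "float64": 64}


def promoting_cat_dtype(dtypes):
    """PyTorch-style: pick the highest category, then the widest dtype within it."""
    top = max(_CATEGORY[d] for d in dtypes)
    # Only inputs in the winning category affect the width, so an int64
    # input does not widen a float32 result.
    return max((d for d in dtypes if _CATEGORY[d] == top), key=_WIDTH.__getitem__)


def strict_cat_dtype(dtypes):
    """Strict behavior (hypothetical): every input must already share one dtype."""
    if len(set(dtypes)) != 1:
        raise TypeError(f"expected matching dtypes, got {sorted(set(dtypes))}")
    return dtypes[0]


mixed = ["float32", "int64"]
print(promoting_cat_dtype(mixed))  # float32
try:
    strict_cat_dtype(mixed)
except TypeError as err:
    print("strict:", err)
```

The promoting variant silently produces `float32` for a `float32`/`int64` pair, which is why adding promotion would let a wrongly inferred `int64` input slip through unnoticed.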
Ahh, yeah, I suspected something is off there; thanks for the confirmation!
But I think something more insidious is going on here: when run in eager mode, the types match. That is, `print(f"dtypes: {inputs_embeds.dtype}, {second.dtype}")` reports `float32` twice in eager mode, but `float32, int64` in Thunder.
Doing #41 should get us past this error, but it would end up masking the deeper bug.
The actual issue here is that factory functions like `zeros` and `ones` rely on `full`, which infers its dtype from the fill value when `dtype` is not passed explicitly. This is also hidden during execution with `torchex`, since it does the correct thing and reads the default from `torch.get_default_dtype`:
https://github.com/Lightning-AI/lightning-thunder/blob/a3e432f7174019b2eda85865890d5f7342a993c2/thunder/executors/torchex.py#L486
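A minimal pure-Python sketch of this failure mode, assuming (as the trace comments in the repro suggest) that `zeros` forwards a fill value of `0` and `dtype=None` to a `full`-style helper. The helper names and `DEFAULT_DTYPE` are hypothetical stand-ins, not Thunder's code; `DEFAULT_DTYPE` plays the role of `torch.get_default_dtype()`. The "fixed" variant is one possible approach: resolve the default before delegating, which matches the documented `torch.zeros` behavior of using the global default dtype when `dtype` is `None`.

```python
# Hypothetical sketch of the dtype-inference bug; NOT Thunder's implementation.

DEFAULT_DTYPE = "float32"  # stand-in for torch.get_default_dtype()


def full_dtype(fill_value, dtype=None):
    """full-style inference: honor an explicit dtype, else infer from the fill value."""
    if dtype is not None:
        return dtype
    if isinstance(fill_value, bool):  # check bool first: bool is a subclass of int
        return "bool"
    if isinstance(fill_value, int):
        return "int64"
    if isinstance(fill_value, float):
        return DEFAULT_DTYPE
    raise TypeError(f"unsupported fill value: {fill_value!r}")


def zeros_dtype_buggy(dtype=None):
    # Forwards fill value 0 with dtype=None, so full_dtype infers int64.
    return full_dtype(0, dtype)


def zeros_dtype_fixed(dtype=None):
    # Resolve the default dtype first, matching torch.zeros semantics.
    if dtype is None:
        dtype = DEFAULT_DTYPE
    return full_dtype(0, dtype)


print(zeros_dtype_buggy())  # int64
print(zeros_dtype_fixed())  # float32
```

Passing a dtype explicitly, e.g. `zeros_dtype_buggy("float32")`, returns `"float32"` from either version, which mirrors the explicit-`dtype` workaround mentioned in the issue body.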
Minimal repro (the output is float32, but the trace shows the proxy with an integer dtype):
```python
import torch
import thunder

def foo(x: torch.Tensor) -> torch.Tensor:
    o = torch.zeros((2,1,2), device=x.device)
    return o

jfoo = thunder.jit(foo)
o = jfoo(torch.randn(3, 3))
print(o.dtype)
print(thunder.last_traces(jfoo)[0])
```
Output:

```
torch.float32
import thunder
import thunder.core.devices as devices
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation():
  # /home/kkalambarkar/lightning-thunder/scratchpad/test.py:63: o = torch.zeros((2,1,2), device=x.device)
  o = ltorch.zeros((2, 1, 2), device=devices.Device("cpu"), dtype=None) # o: "cpu i64[2, 1, 2]"
    # o = ltorch.full((2, 1, 2), 0, device=devices.Device("cpu"), dtype=None) # o: "cpu i64[2, 1, 2]"
      # o = prims.full((2, 1, 2), 0, device=devices.Device("cpu"), dtype=dtypes.int64_) # o: "cpu i64[2, 1, 2]"
  return o
```
I think this is a duplicate of https://github.com/Lightning-AI/lightning-thunder/issues/621
triage review:
🚀 Model / language coverage
The following code results in an error in `cat`'s implementation. It seems we end up confused about the proper dtype of the `second` tensor. As the comment in the `zeros` line indicates, Thunder can be coerced into compiling this by explicitly adding a `dtype` to the `zeros` call. However, it seems the bug is more global than just `zeros`, as our `zeros` works perfectly on its own:

Pitch
This came about while using Nik's patch to try to get #343 to work. Nik and I still need some iteration on his patch, so there's no guarantee that this will be the next bug after #124, but it's plausibly a blocker.
cc @apaz-cli @tfogal