Lightning-AI / lightning-thunder


type inference: mismatched dtype in cat operator #750

Closed tfogal closed 1 month ago

tfogal commented 1 month ago

🚀 Model / language coverage

The following code results in a

RuntimeError: Expected dtype thunder.dtypes.float32 but found thunder.dtypes.int64_!

error in cat's implementation. It seems we end up confused about the proper dtype of the second tensor.

#!python3
import torch
import thunder
import einops

def foo(input_ids, inputs_embeds):
  batch_size, sequence_length, hidden_size = inputs_embeds.shape

  media_features = torch.randn((2,1,1,256,5120), dtype=torch.float16)
  num_images_per_sample = media_features.size(1)
  num_patches = media_features.size(3) * media_features.size(2)

  media_end_id = 32005
  sorted_media_end_positions_mask, media_end_positions_mask_sort_idx = (
      # NOTE: to(torch.long) is needed because PyTorch does not have sort for boolean tensors on CUDA
      (input_ids == media_end_id).to(torch.long).sort(dim=-1, descending=True, stable=True)
  )

  padded_media_indices = torch.where(
    sorted_media_end_positions_mask.to(torch.bool),
    media_end_positions_mask_sort_idx - num_patches + 1,
    sequence_length
  )
  padded_media_indices = padded_media_indices.unsqueeze(-1) + torch.arange(
    num_patches, device=padded_media_indices.device
  ).repeat(*padded_media_indices.shape, 1)
  padded_media_indices = padded_media_indices.reshape(batch_size, -1)
  padded_media_indices = einops.repeat(padded_media_indices, 'b s -> b s h', h=hidden_size)

  second = torch.zeros((batch_size, num_patches, hidden_size), device=inputs_embeds.device)
  # Note: thunder can be made to work by explicitly setting the dtype:
  #   second = torch.zeros((batch_size, num_patches, hidden_size), dtype=torch.float32, device=inputs_embeds.device)
  #print(f"ii dt:shape={inputs_embeds.dtype}:{inputs_embeds.shape}")
  #print(f"2nd dt:shape={second.dtype}:{second.shape}")
  updated_input_embeds = torch.cat(
    (inputs_embeds, second), dim=1
  )
  return updated_input_embeds

at = torch.zeros((2,384), dtype=torch.int64)
bt = torch.randn((2,384, 5120), dtype=torch.float32)

foo(at, bt)

thfoo = thunder.jit(foo)
thfoo(at, bt)

As the comment in the zeros line indicates, thunder can be coerced into compiling this by explicitly adding a dtype to the zeros call. However, it seems the bug is broader than just zeros, as our zeros works perfectly on its own:

>>> def z1(x: torch.Tensor) -> torch.Tensor:
...     return torch.zeros([2,1,2], device=x.device)
...
>>> abc = torch.ones(2, 2, 2)
>>> abc
tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])
>>> abc.shape
torch.Size([2, 2, 2])
>>> abc.dtype
torch.float32
>>> z1(abc).dtype
torch.float32
>>> th_z1 = thunder.jit(z1)
>>> th_z1(abc).dtype
torch.float32
>>> assert z1(abc).dtype == th_z1(abc).dtype
>>> 

Pitch

This came about while using Nik's patch to try to get #343 to work. Nik and I still need some iteration on his patch, so there's no guarantee that this will be the next bug after #124, but it's plausibly a blocker.

cc @apaz-cli @tfogal

t-vi commented 1 month ago

Note that PyTorch upcasts automatically when given tensors of varying dtype, while Thunder currently errors. When I tried to add this (clumsily) in #41, it seemed that I hit some inconsistency between torch eager and compile.
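
A minimal sketch (mine, not from the thread) of the behavior described above: in eager mode torch.cat type-promotes the mixed inputs, while, per the comment, running the same function under thunder.jit currently raises a dtype-mismatch error.

import torch
import thunder

def cat_mixed(a, b):
    # a is float32, b is int64; eager PyTorch promotes b to float32 before concatenating.
    return torch.cat((a, b), dim=0)

a = torch.randn(2, 3, dtype=torch.float32)
b = torch.zeros(2, 3, dtype=torch.int64)

print(cat_mixed(a, b).dtype)  # torch.float32 in eager mode

jcat = thunder.jit(cat_mixed)
jcat(a, b)  # per the comment above, thunder currently errors on the mixed dtypes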

tfogal commented 1 month ago

> Note that PyTorch upcasts automatically when given tensors of varying dtype while Thunder currently errors.

Ahh, yeah, I suspected something is off there; thanks for the confirmation!

But I think something more insidious is going on here: when run in eager mode, the types match. That is, print(f"dtypes: {inputs_embeds.dtype}, {second.dtype}") prints 'float32' twice in eager mode, but 'float32, int64' under thunder.

If we were to actually do #41, it would get us past this error, but it would end up masking the deeper bug.

kshitij12345 commented 1 month ago

The actual issue here is that factory functions like zeros and ones rely on full, which infers its dtype from the fill value when dtype is not passed explicitly:

https://github.com/Lightning-AI/lightning-thunder/blob/a3e432f7174019b2eda85865890d5f7342a993c2/thunder/clang/__init__.py#L275-L277

Also, this is hidden during execution with torchex, which does the correct thing and reads the value from torch.get_default_dtype: https://github.com/Lightning-AI/lightning-thunder/blob/a3e432f7174019b2eda85865890d5f7342a993c2/thunder/executors/torchex.py#L486
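
For reference, a small eager-PyTorch sketch (not from the original comment) of the distinction being described, which is roughly what the torchex executor recovers by consulting torch.get_default_dtype(). This assumes a recent PyTorch where full infers its dtype from the fill value:

import torch

# zeros follows the global default dtype, while full with a Python int fill
# value infers an integer dtype from that value.
print(torch.zeros((2, 1, 2)).dtype)    # torch.float32, i.e. torch.get_default_dtype()
print(torch.full((2, 1, 2), 0).dtype)  # torch.int64, inferred from the fill value 0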

Minimal repro (the output is float, but in the trace we see that the proxy has an integer dtype):

import torch
import thunder

def foo(x: torch.Tensor) -> torch.Tensor:
    o = torch.zeros((2,1,2), device=x.device)
    return o

jfoo = thunder.jit(foo)
o = jfoo(torch.randn(3, 3))
print(o.dtype)
print(thunder.last_traces(jfoo)[0])

Output

torch.float32

import thunder
import thunder.core.devices as devices
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation():
  # /home/kkalambarkar/lightning-thunder/scratchpad/test.py:63:             o = torch.zeros((2,1,2), device=x.device)
  o = ltorch.zeros((2, 1, 2), device=devices.Device("cpu"), dtype=None)  # o: "cpu i64[2, 1, 2]"
    # o = ltorch.full((2, 1, 2), 0, device=devices.Device("cpu"), dtype=None)  # o: "cpu i64[2, 1, 2]"
      # o = prims.full((2, 1, 2), 0, device=devices.Device("cpu"), dtype=dtypes.int64_)  # o: "cpu i64[2, 1, 2]"
  return o

I think this is a duplicate of https://github.com/Lightning-AI/lightning-thunder/issues/621

tfogal commented 1 month ago

triage review: