Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Returning `dtype` or `device` from jitted function returns thunder's dtype or device (not torch.{dtype/device}). #573

Closed kshitij12345 closed 1 week ago

kshitij12345 commented 1 month ago

I think this can be a problem when part of the program is called with thunder.jit and its output is used by non-compiled parts (e.g. https://github.com/Lightning-AI/lightning-thunder/issues/461). These non-compiled parts may expect these objects to be torch types.

import thunder
import torch

def foo(x):
    return x.device, x.dtype

jfoo = thunder.jit(foo)
x = torch.randn(4,)
o = jfoo(x)
print(o)  # (cpu, float32)
print(type(o[0]), type(o[1]))  # <class 'thunder.core.devices.Device'> <class 'thunder.core.dtypes.floating'>

print(thunder.last_traces(jfoo)[-1])

# # Constructed by Delete Last Used (took 0 milliseconds)
# import thunder.core.devices as devices
# import thunder.core.dtypes as dtypes
# import torch
# from thunder.executors.torchex import no_autocast

# @torch.no_grad()
# @no_autocast
# def computation():
#   return (devices.Device("cpu"), dtypes.float32)
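Until this is fixed, one possible caller-side workaround (a sketch only; the helper names `to_torch_device` and `to_torch_dtype` are hypothetical, not part of thunder's API) is to normalize the returned objects back to torch types via their string forms, relying on the fact shown above that thunder's device and dtype print as `cpu` and `float32`:

```python
import torch

def to_torch_device(dev):
    # Hypothetical helper: pass a torch.device through unchanged, otherwise
    # rebuild one from the string form (thunder's Device prints as e.g. "cpu").
    return dev if isinstance(dev, torch.device) else torch.device(str(dev))

def to_torch_dtype(dt):
    # Hypothetical helper: look the dtype up by name on the torch module
    # (thunder's dtypes print as e.g. "float32").
    return dt if isinstance(dt, torch.dtype) else getattr(torch, str(dt))

dev = to_torch_device("cpu")       # stand-in for the thunder Device above
dt = to_torch_dtype("float32")     # stand-in for the thunder dtype above
print(type(dev), type(dt))  # <class 'torch.device'> <class 'torch.dtype'>
```

This only covers objects whose `str()` happens to match torch's naming, which the repro output suggests is the case here.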

k223kim commented 1 month ago

Hi @kshitij12345, can I ask why you think #461 is related to this issue (just trying to get a better understanding)? When I reproduced the recursive error by jitting only the encoder of CLIPTextTransformer, it threw an error during the forward pass of the encoder. Doesn't that mean this error is not related to passing non-compiled parts? I would appreciate it if you could correct me if I am interpreting this incorrectly.

kshitij12345 commented 1 month ago

Ah, sorry for the confusion - I didn't mean this was the cause of #461. What I meant was that the approach in #461 of jitting a section of the model may have issues due to this bug.

For example:

import thunder
import torch

def foo(x):
    return x + 1, x.device

def bar(x, device):
    return x + torch.ones_like(x, device=device)

x = torch.ones(3, 3)

# Eager works
bar(*foo(x))

jfoo = thunder.jit(foo)
# This doesn't work because `jfoo` returns `thunder.Device` which is passed to
# `torch.ones_like`
# TypeError: ones_like(): argument 'device' must be torch.device, not Device
bar(*jfoo(x))
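A temporary mitigation at the call site (a sketch, assuming only that `str()` of the returned device is a valid device string, as the earlier repro output shows) is to coerce the device before passing it on:

```python
import torch

def bar(x, device):
    return x + torch.ones_like(x, device=device)

x = torch.ones(3, 3)
# Stand-in for the thunder Device that jfoo would return; thunder's Device
# prints as "cpu", so rebuilding a torch.device from its str() works.
dev = "cpu"
out = bar(x, torch.device(str(dev)))
print(out)  # a 3x3 tensor of 2s
```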

I hope this makes more sense now. Thanks for looking into this!

k223kim commented 1 month ago

@kshitij12345 Thanks for the clarification. It makes better sense now :)

mruberry commented 1 month ago

triage review —