Closed kshitij12345 closed 1 week ago
Hi @kshitij12345, can I ask why you think #461 is related to this issue? (I'm just trying to get a better understanding.) When I reproduced the recursive error by jitting only the encoder of CLIPTextTransformer, it threw an error during the forward pass of the encoder. Doesn't that mean the error is not related to passing outputs to non-compatible parts? I would appreciate a correction if I am interpreting this incorrectly.
Ah, sorry for the confusion - I didn't mean this was the cause of #461. What I meant was that the approach in #461 of jitting only a section of the model may run into issues because of this bug.
For example:
import thunder
import torch

def foo(x):
    return x + 1, x.device

def bar(x, device):
    return x + torch.ones_like(x, device=device)

x = torch.ones(3, 3)

# Eager works
bar(*foo(x))

jfoo = thunder.jit(foo)

# This doesn't work because `jfoo` returns `thunder.Device` which is passed to
# `torch.ones_like`
# TypeError: ones_like(): argument 'device' must be torch.device, not Device
bar(*jfoo(x))
I hope this makes more sense now. Thanks for looking into this!
@kshitij12345 Thanks for the clarification. It makes better sense now :)
triage review —
I think this can be a problem when part of the program is called with thunder.jit and its output is used by non-compatible parts (e.g. https://github.com/Lightning-AI/lightning-thunder/issues/461). These non-compatible parts may expect these objects to be torch types.
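To make the boundary problem above concrete, here is a minimal, library-free sketch of one possible workaround: converting a jitted function's outputs back to native types before they reach code that expects them. Everything in it is hypothetical. `TracedDevice` and `NativeDevice` are stand-ins for a jit-internal device object and a native one such as `torch.device`, and `to_native` and `with_native_outputs` are illustrative helpers, not thunder APIs. It does not describe how thunder itself should fix the bug.

```python
import functools
from dataclasses import dataclass


@dataclass(frozen=True)
class TracedDevice:
    """Hypothetical stand-in for a jit-internal device object."""
    name: str


@dataclass(frozen=True)
class NativeDevice:
    """Hypothetical stand-in for the native device type downstream code expects."""
    name: str


def to_native(obj):
    """Recursively replace TracedDevice objects with NativeDevice objects."""
    if isinstance(obj, TracedDevice):
        return NativeDevice(obj.name)
    if isinstance(obj, (tuple, list)):
        # Preserve the container type while converting its elements.
        return type(obj)(to_native(item) for item in obj)
    return obj


def with_native_outputs(fn):
    """Wrap fn so that its outputs are converted at the boundary."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return to_native(fn(*args, **kwargs))
    return wrapper


# Stand-in for a jitted `foo` that leaks an internal device object.
def fake_jitted_foo(x):
    return x + 1, TracedDevice("cpu")


safe_foo = with_native_outputs(fake_jitted_foo)
value, device = safe_foo(1)
```

With real thunder and torch objects, the conversion step would have to use whatever mapping the library actually provides; none is assumed here. The wrapper only shows where such a conversion would sit when just a section of a model is jitted.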