Closed: athitten closed this issue 3 months ago
FYI, I just figured out that `self.encoder` consists of an `nn.ModuleList` built in a for loop (shown below), which probably caused the recursion error:
```python
self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
```
Applying `thunder.jit` to the individual modules in the for loop, instead of to the entire `nn.ModuleList`, fixed the `RecursionError`.
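A minimal sketch of that workaround, with an identity stand-in for the compile function so it runs without thunder installed (in the real setting `compile_fn` would be `thunder.jit`; `compile_each_layer` is a hypothetical helper, not thunder API):

```python
import torch
from torch import nn

def compile_each_layer(module_list: nn.ModuleList, compile_fn) -> nn.ModuleList:
    # Wrap each submodule individually instead of jitting the whole
    # nn.ModuleList, per the workaround described above.
    return nn.ModuleList([compile_fn(layer) for layer in module_list])

layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
compiled = compile_each_layer(layers, lambda m: m)  # identity stand-in for thunder.jit

x = torch.randn(2, 4)
for layer in compiled:
    x = layer(x)
```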
Thanks @athitten ! That's really helpful.
Tagging triage review. Triage team, beyond the obvious "add support for control flow", I'm curious what our options are here.
Staring down the traceback (rather than running it myself), it does not look like the problem is the modules themselves (litgpt also uses a for loop over a ModuleList); rather, we seem to have a trace that fails to print itself because of some reference cycle (which might be caused by the interpreter erroneously inserting that cycle into the trace).
Hi Team ⚡️, I am currently working on this issue and would like to share how I reproduced the same error (as a reference for anyone else working on it). It is quite similar to the code shown above, just smaller :)
```shell
git clone https://github.com/huggingface/transformers.git
cd transformers
git checkout tags/v4.41.2
pip install -e .
```
Then, in the `CLIPTextTransformer` class, wrap each submodule with `thunder.jit`:

```python
self.embeddings = thunder.jit(self.embeddings)
self.encoder = thunder.jit(self.encoder)
self.final_layer_norm = thunder.jit(self.final_layer_norm)
```
Then run the model:

```python
from transformers import CLIPTokenizer, CLIPTextModel

model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled (EOS token) states
```
(cc. @t-vi )
@k223kim debugged this more: the infinite recursion comes from `to_printable` assuming that `tree_flatten` will "simplify" the input, when in reality it produces the original input as part of the flattened objects for `BaseModelOutput` (from transformers, a dataclass: https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/modeling_outputs.py#L24-L47).
```python
def to_printable(
    trace: Optional,
    x: Any,
    *,
    import_ctx: Optional[dict] = None,
    object_ctx: Optional[dict] = None,
) -> Printable:
    # Short-circuits if x is a Proxy
    if isinstance(x, ProxyInterface):
        return x

    if is_collection(x):
        flat, spec = tree_flatten(x)
        printables = []
        for f in flat:
            printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))
        printable = tree_unflatten(printables, spec)
        return printable
```
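The failure mode can be illustrated with torch's pytree utilities on a type they do not know how to decompose (a plain class here as a stand-in for `BaseModelOutput`; the real case goes through thunder's own `tree_flatten`):

```python
from torch.utils import _pytree as pytree

class Unregistered:
    # Stand-in for a type pytree cannot decompose, like transformers'
    # BaseModelOutput in the thunder trace printer.
    pass

obj = Unregistered()
flat, spec = pytree.tree_flatten(obj)

# The "flattened" result is just the original object as a single leaf,
# so a printer that flattens and then recurses on each leaf never terminates.
assert len(flat) == 1 and flat[0] is obj
```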
A more minimal repro that could serve as a regression test for the fix:
```python
import transformers
import torch
import thunder

def fn(x):
    return transformers.modeling_outputs.BaseModelOutput(x)

jfn = thunder.jit(fn)
x = torch.randn(5, 5)
print(jfn(x))
```
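One possible guard, sketched here as an assumption rather than as the fix the maintainers actually chose: detect when `tree_flatten` fails to decompose its input and render that input as an opaque leaf instead of recursing (`to_printable_sketch` is a hypothetical simplification, not the thunder function):

```python
from torch.utils import _pytree as pytree

def to_printable_sketch(x):
    # Hypothetical recursion guard for a to_printable-style function.
    flat, spec = pytree.tree_flatten(x)
    if len(flat) == 1 and flat[0] is x:
        # tree_flatten did not decompose x; recursing on the leaf would
        # loop forever, so render x as an opaque leaf instead.
        return repr(x)
    return pytree.tree_unflatten([to_printable_sketch(f) for f in flat], spec)

# Nested collections are still traversed, while undecomposable
# objects no longer trigger infinite recursion:
print(to_printable_sketch([1, [2, 3]]))  # ['1', ['2', '3']]
```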
🐛 Bug
NeMo's Stable Diffusion uses CLIPTextModel from HuggingFace transformers. Using thunder.jit with the CLIPTextModel is causing a RecursionError.
To Reproduce
Steps to reproduce the behavior:
Partial stack trace below:
CC: @tfogal
cc @apaz-cli @tfogal