Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Recursion error in transformer module with NeMo Stable Diffusion #461

Closed: athitten closed this issue 3 months ago

athitten commented 4 months ago

🐛 Bug

NeMo's Stable Diffusion uses CLIPTextModel from HuggingFace transformers. Using thunder.jit with the CLIPTextModel is causing a RecursionError.

To Reproduce

Steps to reproduce the behavior:

  1. Add the following lines to transformers/models/clip/modeling_clip.py, inside the CLIPTextTransformer class, wherever transformers is installed in your container (an equivalent out-of-tree patch is sketched after these steps):
        ## thunder.jit (requires `import thunder` at the top of the file)
        self.embeddings = thunder.jit(self.embeddings)
        self.encoder = thunder.jit(self.encoder)
        self.final_layer_norm = thunder.jit(self.final_layer_norm)
  2. Run NeMo Stable Diffusion with the command below:
    python examples/multimodal/text_to_image/stable_diffusion/sd_train.py trainer.precision=16 trainer.num_nodes=1 trainer.devices=1 ++exp_manager.max_time_per_run=00:00:03:00 trainer.max_steps=20 model.micro_batch_size=1 model.global_batch_size=1 model.data.synthetic_data=True exp_manager.exp_dir=/workspace/TestData/multimodal/stable_diffusion_train model.inductor=False model.cond_stage_config._target_=nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder ++model.cond_stage_config.version=openai/clip-vit-large-patch14 ++model.cond_stage_config.max_length=77 ~model.cond_stage_config.restore_from_path ~model.cond_stage_config.freeze ~model.cond_stage_config.layer model.unet_config.from_pretrained=null model.first_stage_config.from_pretrained=null model.unet_config.use_flash_attention=False model.unet_config.attention_resolutions=\[1\] model.unet_config.channel_mult=\[1\]
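
For reference, the same patch can be applied without editing the installed file, by wrapping the submodules from outside (a sketch; the attribute paths follow CLIPTextModel's structure in transformers):

    import thunder
    from transformers import CLIPTextModel

    model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    text_model = model.text_model  # the CLIPTextTransformer instance
    text_model.embeddings = thunder.jit(text_model.embeddings)
    text_model.encoder = thunder.jit(text_model.encoder)
    text_model.final_layer_norm = thunder.jit(text_model.final_layer_norm)
    # calling the jitted encoder is what triggers the RecursionError below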

Partial stack trace below:

  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/software/lightning-thunder/thunder/core/module.py", line 49, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/workspace/software/lightning-thunder/thunder/__init__.py", line 617, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/workspace/software/lightning-thunder/thunder/__init__.py", line 202, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/workspace/software/lightning-thunder/thunder/__init__.py", line 540, in get_computation_and_inputs
    autocast(computation_trc.python_callable(), dtype=autocast_thunder_dtype), *inps
  File "/workspace/software/lightning-thunder/thunder/core/trace.py", line 437, in python_callable
    python_str = self.python(**kwargs)
  File "/workspace/software/lightning-thunder/thunder/core/trace.py", line 318, in python
    import_ctx, call_ctx, object_ctx = self._gather_ctxs()
  File "/workspace/software/lightning-thunder/thunder/core/trace.py", line 281, in _gather_ctxs
    bsym_import_ctx, bsym_call_ctx, bsym_object_ctx = bsym.gather_ctxs()
  File "/workspace/software/lightning-thunder/thunder/core/symbol.py", line 580, in gather_ctxs
    return self.import_ctx(), self._get_call_ctx(), self.object_ctx()
  File "/workspace/software/lightning-thunder/thunder/core/symbol.py", line 520, in import_ctx
    self._out_printables, self._arg_printables, self._kwarg_printables  # type: ignore
  File "/workspace/software/lightning-thunder/thunder/core/symbol.py", line 472, in _out_printables
    return codeutils.to_printable(trace, self.output, import_ctx=self._import_ctx, object_ctx=self._object_ctx)
  File "/workspace/software/lightning-thunder/thunder/core/codeutils.py", line 128, in to_printable
    printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))
  File "/workspace/software/lightning-thunder/thunder/core/codeutils.py", line 128, in to_printable
    printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))
  File "/workspace/software/lightning-thunder/thunder/core/codeutils.py", line 128, in to_printable
    printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))
  [Previous line repeated 2899 more times]
  File "/workspace/software/lightning-thunder/thunder/core/codeutils.py", line 123, in to_printable
    if is_collection(x):
  File "/workspace/software/lightning-thunder/thunder/core/baseutils.py", line 153, in is_collection
    return isinstance(x, collections.abc.Collection) and not isinstance(x, (str, torch.Tensor, np.ndarray))
  File "/usr/lib/python3.10/abc.py", line 117, in __instancecheck__
    def __instancecheck__(cls, instance):
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_trace_dispatch_regular.py", line 469, in __call__
    return None if event == 'call' else NO_FTRACE
RecursionError: maximum recursion depth exceeded in comparison

CC: @tfogal

cc @apaz-cli @tfogal

athitten commented 3 months ago

FYI: I just figured out that self.encoder consists of an nn.ModuleList that is iterated over with a for loop (shown below), which probably caused the recursion error.

    self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])

Applying thunder.jit to the individual modules inside the for loop instead of to the entire nn.ModuleList fixed the RecursionError.
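
A minimal sketch of that workaround, assuming the CLIP text model from transformers, whose encoder keeps its layers in self.layers as in the snippet above:

    import thunder
    from transformers import CLIPTextModel

    model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    encoder = model.text_model.encoder

    # jit each CLIPEncoderLayer individually rather than the whole nn.ModuleList
    for i, layer in enumerate(encoder.layers):
        encoder.layers[i] = thunder.jit(layer)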

tfogal commented 3 months ago

Thanks @athitten! That's really helpful.

Tagging triage review. Triage team, beyond the obvious "add support for control flow", I'm curious what our options are here.

t-vi commented 3 months ago

Staring down the traceback (rather than running it myself), it does not look like the modules themselves are the problem (litgpt also uses a for loop over a ModuleList); rather, it looks as if we have a trace that fails to print itself because of some reference cycle (which might be caused by the interpreter erroneously inserting one into the trace).

k223kim commented 3 months ago

Hi Team ⚡️, I am currently working on this issue and would like to share how I reproduced the same error (as a reference for anyone else working on it). It is quite similar to the code shown above, just smaller :)

  1. Clone and install Hugging Face transformers from source:
    git clone https://github.com/huggingface/transformers.git
    cd transformers
    git checkout tags/v4.41.2
    pip install -e .
  2. As mentioned above, add this to the CLIPTextTransformer class:
        self.embeddings = thunder.jit(self.embeddings)
        self.encoder = thunder.jit(self.encoder)
        self.final_layer_norm = thunder.jit(self.final_layer_norm)
  3. Run the following script (taken from the Hugging Face repo):

        from transformers import CLIPTokenizer, CLIPTextModel

        model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        outputs = model(**inputs)
        last_hidden_state = outputs.last_hidden_state
        pooled_output = outputs.pooler_output  # pooled (EOS token) states

(cc @t-vi)

t-vi commented 3 months ago

@k223kim debugged this more: the infinite recursion comes from to_printable assuming that tree_flatten will "simplify" the input, when in reality it returns the original input as part of the flattened objects for BaseModelOutput (a dataclass from transformers: https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/modeling_outputs.py#L24-L47).

def to_printable(
    trace: Optional,
    x: Any,
    *,
    import_ctx: Optional[dict] = None,
    object_ctx: Optional[dict] = None,
) -> Printable:
    # Short-circuits if x is a Proxy
    if isinstance(x, ProxyInterface):
        return x

    if is_collection(x):
        flat, spec = tree_flatten(x)
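        # for BaseModelOutput, x itself comes back among the leaves here,
        # so the recursive call below never terminates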

        printables = []
        for f in flat:
            printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx))

        printable = tree_unflatten(printables, spec)
        return printable
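
A small self-contained illustration of that failure mode (a sketch; it assumes thunder.core.pytree.tree_flatten, the helper used by to_printable, is importable directly):

    import torch
    from transformers.modeling_outputs import BaseModelOutput
    from thunder.core.pytree import tree_flatten

    out = BaseModelOutput(last_hidden_state=torch.randn(1, 2))

    # BaseModelOutput is a Collection (it subclasses OrderedDict) but not a
    # pytree node thunder knows how to decompose, so flattening does not
    # simplify it: per the diagnosis above, the object itself comes back as
    # a leaf and to_printable keeps recursing on it.
    flat, spec = tree_flatten(out)
    print(any(f is out for f in flat))  # expected True per the diagnosis above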

t-vi commented 3 months ago

A more minimal repro, to be used as a test in a fix:

import transformers
import torch
import thunder

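# Returning a transformers ModelOutput dataclass from the jitted function puts
# a BaseModelOutput into the trace, which to_printable then fails to print.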
def fn(x):
    return transformers.modeling_outputs.BaseModelOutput(x)

jfn = thunder.jit(fn)

x = torch.randn(5, 5)

print(jfn(x))
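
Running this hits the same RecursionError in to_printable, so it can double as a compact regression test once the fix lands.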