Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

AssertionError from process_recorded_modifications #1405

Open · IvanYashchuk opened 1 week ago

IvanYashchuk commented 1 week ago

🐛 Bug

I get the following assertion error from Thunder JIT:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1731, in thunder_general_jit(fn, args, kwargs, record_history, sharp_edges, ad_hoc_executor)
   1729         prims.python_return(result)
   1730         computation_trace.set_current_source_location(None, None)
-> 1731         process_recorded_modifications(ctx, epilogue_trace)
   1732         last_interpreter_log = jfn._last_interpreter_log
   1734 pro_to_comp, computation_intermediates = get_computation_inputs_and_intermediates(computation_trace)

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1610, in process_recorded_modifications(ctx, epilogue_trace)
   1608 if inst == PseudoInst.STORE_SUBSCR:
   1609     (value,) = args
-> 1610     assert isinstance(value.value, Proxy)
   1612     assert modified_object.provenance.inst is PseudoInst.LOAD_ATTR
   1613     assert modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT

AssertionError:

The error is raised when running the following script:

import torch
import thunder
from transformers import Qwen2Config, Qwen2ForCausalLM

# https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json
configuration = Qwen2Config(
    # Qwen2.5-7B-Instruct uses Grouped-Query Attention, while the default
    # config uses Multi-Head Attention
    num_attention_heads=28,
    num_key_value_heads=4,
    # Scaled down for testing
    hidden_size=56,
    vocab_size=4096,
)
configuration.num_hidden_layers = 1
with torch.device("cuda"):
    model = Qwen2ForCausalLM(configuration).to(torch.bfloat16)

compiled_model = thunder.jit(model)
input_ids = torch.randint(0, configuration.vocab_size, (1, configuration.max_position_embeddings), device="cuda")
compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)

The transformers version is 4.45.2.
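
Since the behavior differs across versions (see below), the environment can be confirmed with the usual __version__ attributes; this assumes each package exposes one, which torch, transformers, and thunder do:

import torch
import thunder
import transformers

# Confirm the environment the repro was run in
print(torch.__version__)
print(transformers.__version__)  # 4.45.2 here
print(thunder.__version__)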

t-vi commented 1 week ago

Thank you @IvanYashchuk

t-vi commented 1 week ago

@IvanYashchuk what transformers version are you using? I'm getting an indexing error with 4.43-ish and a different assertion error (which also needs fixing) with 4.46.2.

t-vi commented 1 week ago

So with 4.46.2 and the following lookaside, things seem to work:

import thunder
from transformers.modeling_utils import PreTrainedModel

# Execute the loss_function property getter as a single wrapped call
# instead of tracing through its body:
@thunder.core.jit_ext.register_general_jit_lookaside(PreTrainedModel.loss_function.fget)
@thunder.core.jit_ext.interpreter_needs_wrap
def fn(*args, **kwargs):
    return PreTrainedModel.loss_function.fget(*args, **kwargs)
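
(The lookaside needs to be registered before the jitted model is first run; as I read the two decorators, they make the interpreter call the property getter as one opaque wrapped step rather than stepping through its bytecode.)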

The loss_function property uses the Python re module to parse the loss-function config. I wonder whether we should allow marking modules as "treat everything here as opaque". It's a dangerous tool, but I think it might be more reasonable than relying on the internals of transformers. @lantiga for UX thoughts.
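
A rough sketch of what such an opt-in could look like, built on the same lookaside machinery used above; make_opaque and mark_module_opaque are hypothetical helper names, not existing Thunder API:

import re
import types
import thunder

def make_opaque(fn):
    # Hypothetical helper: execute fn as a single opaque, wrapped call
    @thunder.core.jit_ext.register_general_jit_lookaside(fn)
    @thunder.core.jit_ext.interpreter_needs_wrap
    def lookaside(*args, **kwargs):
        return fn(*args, **kwargs)
    return fn

def mark_module_opaque(module):
    # Hypothetical helper: register an opaque lookaside for every
    # public Python-level function in the module
    for name in dir(module):
        member = getattr(module, name)
        if isinstance(member, types.FunctionType) and not name.startswith("_"):
            make_opaque(member)

mark_module_opaque(re)  # e.g. stop interpreting into the stdlib re module

Methods on classes (e.g. re.Pattern.match) would need similar treatment, which is part of why a blanket "opaque module" switch is the dangerous tool mentioned above.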

IvanYashchuk commented 1 week ago

I get this error with transformers version 4.45.2.