huggingface / transformers

πŸ€— Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

End-to-end generation compile stopped working #33794

Closed binkjakub closed 1 month ago

binkjakub commented 1 month ago

### System Info

### Who can help?

@ArthurZucker @gante

### Information

### Tasks

### Reproduction

1. Install `transformers>=4.45.0` and `accelerate`
2. Run the example script presented in the release notes (https://github.com/huggingface/transformers/releases/tag/v4.44.0):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import copy

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")

# compile generate
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")

# compiled generate does NOT accept parameterization except a) model inputs b) a generation config
generation_config = copy.deepcopy(model.generation_config)
generation_config.pad_token_id = model.config.eos_token_id

model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt")
model_inputs = model_inputs.to(model.device)
output_compiled = compiled_generate(**model_inputs, generation_config=generation_config)
print(output_compiled)
```

3. Observe the following error:

```
Loading checkpoint shards: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 4/4 [00:04<00:00, 1.18s/it]
Traceback (most recent call last):
  File "my_script.py", line 19, in <module>
    output_compiled = compiled_generate(**model_inputs, generation_config=generation_config)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
    super().run()
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 132, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 344, in call_function
    return super().call_function(tx, args, kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1459, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 344, in call_function
    return super().call_function(tx, args, kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1459, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 301, in call_function
    unimplemented(f"call_function {self} {args} {kwargs}")
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 221, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_function ConstantVariable(method: <bound method PretrainedConfig.get_text_config of LlamaConfig {
  "_name_or_path": "meta-llama/Meta-Llama-3.1-8B",
  "architectures": ["LlamaForCausalLM"],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 8.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.45.1",
  "use_cache": true,
  "vocab_size": 128256
}
) [] {}

from user code:
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 38, in inner
    return fn(*args, **kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/transformers/generation/utils.py", line 1922, in generate
    self._prepare_cache_for_generation(
  File "~/miniconda3/envs/my_env/lib/python3.10/site-packages/transformers/generation/utils.py", line 1605, in _prepare_cache_for_generation
    num_hidden_layers = self.config.get_text_config().num_hidden_layers

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```



### Expected behavior

The script should perform full-graph compilation and generate LLM output. It behaves as expected with previous versions of transformers, i.e. `<4.45.0`.
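In the meantime, a possible workaround is to compile only `model.forward` and leave `generate` itself in eager mode with a static cache. This is only a sketch based on the generally documented static-cache plus compiled-forward recipe, not something verified against this exact setup:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")

# Use a static KV cache so the compiled forward pass sees fixed shapes,
# and compile only the forward pass; generate stays in eager mode.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")

model_inputs = tokenizer(
    ["Write a poem about the market crashing in summer"], return_tensors="pt"
).to(model.device)
print(tokenizer.batch_decode(model.generate(**model_inputs, max_new_tokens=32)))
```
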
ArthurZucker commented 1 month ago

cc @gante and @zucchini-nlp !

Warra07 commented 1 month ago

I tried to reproduce your error and to downgrade transformers to check whether it works on previous versions as stated, but I don't think the transformers version is the cause.

It seems to depend on the model architecture you're using: torch dynamo does not appear to fully support Llama 3.1, and working around that would require disabling the specific functions that can't be compiled.

I get exactly the same error with torch 2.4.1 and both transformers 4.44.x and 4.45.x. I then tried upgrading torch to the pre-release version 2.6.0 (which implements more functions for dynamo).

With that, the initial error disappears, but other errors can occur when mutable attributes aren't registered; for example, I got an error from the DynamicCache `key_cache` attribute. A quick fix for those attributes is to initialize them with `torch.jit.Attribute`, as in the sketch below.
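A minimal sketch of that pattern, applied to a toy module rather than to transformers' actual DynamicCache (where exactly it would hook into the cache class is an assumption on my part):

```python
import torch
from typing import List

class ToyCache(torch.jit.ScriptModule):
    """Toy stand-in for a cache object, not transformers' DynamicCache."""

    def __init__(self):
        super().__init__()
        # torch.jit.Attribute attaches an explicit type to a mutable attribute,
        # so the compiler does not have to infer it from an empty list.
        self.key_cache = torch.jit.Attribute([], List[torch.Tensor])

    @torch.jit.script_method
    def update(self, key_states: torch.Tensor) -> None:
        self.key_cache.append(key_states)

cache = ToyCache()
cache.update(torch.zeros(1, 8))
```
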

However, other "missing implementation" errors from torch dynamo still appear after that, and I haven't dug any further for now.

So maybe avoid compiling with Llama 3 for now? :>

ArthurZucker commented 1 month ago

No no, compile support should work for Llama 2, 3, 3.1 and 3.2 (just not the vision models)! @gante will work on a fix!

gante commented 1 month ago

Hi @binkjakub @Warra07 πŸ‘‹ Thank you for raising the issue!

#33861 is part of the fix; I'm still investigating the complete fix :) I'm also trying to figure out a way to add a test for this feature that isn't super slow, to avoid regressions πŸ€—

gante commented 1 month ago

#33861 now fixes this issue πŸ€— It also adds a fast test (run against all relevant commits) to confirm we don't regress on end-to-end compilation.
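For reference, a fast end-to-end compile check could look roughly like the sketch below. This is only an illustration, not the actual test added in #33861, and the tiny hub checkpoint id is an assumption:

```python
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_generate_compiles_end_to_end():
    # A tiny random checkpoint keeps the test fast; the exact repo id is an assumption.
    model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    compiled_generate = torch.compile(model.generate, fullgraph=True)

    # Same constraint as the repro: only model inputs plus a generation config.
    generation_config = copy.deepcopy(model.generation_config)
    generation_config.pad_token_id = model.config.eos_token_id
    generation_config.max_new_tokens = 2

    inputs = tokenizer(["hello"], return_tensors="pt")
    output = compiled_generate(**inputs, generation_config=generation_config)

    # The output should contain the prompt plus at least one newly generated token.
    assert output.shape[-1] > inputs["input_ids"].shape[-1]
```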