huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Torch.compile fail during inference with meta-llama/Meta-Llama-3.1-8B-Instruct #34604

Open prasiyer opened 3 weeks ago

prasiyer commented 3 weeks ago

Who can help?

@gante, @ArthurZucker While using torch.compile(), I get the following error. The sample code is included under "Reproduction" below.

Error:
Traceback (most recent call last):
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/queueing.py", line 536, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/route_utils.py", line 276, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/blocks.py", line 1923, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/blocks.py", line 1506, in call_function
    prediction = await fn(*processed_input)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/utils.py", line 785, in async_wrapper
    response = await f(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/chat_interface.py", line 607, in _submit_fn
    response = await anyio.to_thread.run_sync(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2134, in run_sync_in_worker_thread
    return await future
           ^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 851, in run
    result = context.run(func, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vp899/projects/Agent_System/Code/Agent_Launch_UI_v2_Experiments.py", line 253, in contract_analyst_chat
    outputs = model.generate(input_ids, max_new_tokens=500, eos_token_id=terminators, do_sample=True, temperature=0.6, top_p=0.9)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 703, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 482, in transform
    tracer = InstructionTranslator(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2085, in __init__
    self._throw_if_in_functorch()
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2126, in _throw_if_in_functorch
    eager = torch._dynamo.lookup_backend("eager")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/backends/registry.py", line 58, in lookup_backend
    _lazy_import()
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/backends/registry.py", line 91, in _lazy_import
    import_submodule(backends)
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1866, in import_submodule
    importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
  File "/anaconda/envs/pi2_py311/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/backends/cudagraphs.py", line 10, in <module>
    from torch._inductor.cudagraph_trees import cudagraphify_impl
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 71, in <module>
    from torch._inductor.compile_fx import (
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 57, in <module>
    from .fx_passes.joint_graph import joint_graph_passes
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py", line 12, in <module>
    from ..pattern_matcher import (
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py", line 46, in <module>
    from .lowering import fallback_node_due_to_unsupported_type
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/lowering.py", line 6002, in <module>
    import_submodule(kernel)
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1866, in import_submodule
    importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
  File "/anaconda/envs/pi2_py311/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/kernel/flex_attention.py", line 155, in <module>
    flex_attention_template = TritonTemplate(
                              ^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 453, in __init__
    self.template = self._template_from_string(source)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/codegen/common.py", line 1720, in _template_from_string
    return env.from_string(source)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/jinja2/environment.py", line 1108, in from_string
    return cls.from_code(self, self.compile(source), gs, None)
                               ^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/jinja2/environment.py", line 768, in compile
    self.handle_exception(source=source_hint)
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/jinja2/environment.py", line 939, in handle_exception
    raise rewrite_traceback_stack(source=source)
  File "<unknown>", line 104, in template
torch._dynamo.exc.InternalTorchDynamoError: No filter named 'indent_except_first'.

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Reproduction

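import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# llama31_hf_token is the reporter's Hugging Face access token, defined elsewhere.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"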
tokenizer = AutoTokenizer.from_pretrained(model_id, token = llama31_hf_token)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", token = llama31_hf_token, attn_implementation="flash_attention_2",)
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

...
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
outputs = model.generate(input_ids, max_new_tokens=500, eos_token_id=terminators, do_sample=True, temperature=0.6, top_p=0.9)

Expected behavior

Model should compile and model.generate should yield the answer

Vishal-Padia commented 3 weeks ago

Can you rerun the code after updating PyTorch to 2.5.1 and Transformers to 4.46.1? You can do so by running the following command, which updates both packages to their latest versions:

pip install -U torch transformers
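
The two tracebacks in this thread come from different conda environments (pi2_py311 and pi4_py311), so it can be worth confirming which versions the active environment actually picks up before rerunning. A minimal sketch using only the standard library; the helper name installed_versions is hypothetical and not part of torch or transformers:

```python
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The package is not installed in the interpreter running this script.
            versions[name] = None
    return versions


# Report the two packages discussed in this thread.
for name, version in installed_versions(["torch", "transformers"]).items():
    print(f"{name}: {version or 'not installed'}")
```

Running this with the same interpreter that launches the Gradio app should show whether the upgrade took effect there.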

prasiyer commented 2 weeks ago

@Vishal-Padia I did the update. Still getting the following error:

Traceback (most recent call last):
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/gradio/queueing.py", line 622, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/gradio/route_utils.py", line 323, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/gradio/blocks.py", line 2014, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/gradio/blocks.py", line 1565, in call_function
    prediction = await fn(*processed_input)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/gradio/utils.py", line 813, in async_wrapper
    response = await f(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/gradio/chat_interface.py", line 638, in _submit_fn
    response = await anyio.to_thread.run_sync(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2441, in run_sync_in_worker_thread
    return await future
           ^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 943, in run
    result = context.run(func, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vp899/projects/Agent_System/Code/Agent_Launch_UI_v2_Torch_Compile_Expmt.py", line 262, in contract_analyst_chat
    outputs = model.generate(input_ids, max_new_tokens=500, eos_token_id=terminators, do_sample=True, temperature=0.6, top_p=0.9)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/transformers/generation/utils.py", line 2215, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/transformers/generation/utils.py", line 3206, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
           ^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 904, in call_function
    return self.func.call_function(tx, merged_args, merged_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
    self._call(inst)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
    self.call_function(fn, args, kwargs)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
    self._call(inst)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
    self.call_function(fn, args, kwargs)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2102, in CONTAINS_OP
    self.push(right.call_method(self, "__contains__", [left], {}))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 527, in call_method
    result = handler_method(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 900, in method___contains__
    return result.call_method(tx, "item", [], {})
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 527, in call_method
    result = handler_method(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 769, in method_item
    unimplemented("Tensor.item")
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 297, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Tensor.item

from user code:
   File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 40, in inner
    return fn(*args, **kwargs)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1190, in forward
    outputs = self.model(
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 915, in forward
    causal_mask = self._update_causal_mask(
  File "/anaconda/envs/pi4_py311/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 992, in _update_causal_mask
    if attention_mask is not None and 0.0 in attention_mask:

Vishal-Padia commented 2 weeks ago

Try this: instead of compiling with fullgraph=True, use dynamic mode:

model.forward = torch.compile(model.forward, mode="reduce-overhead", dynamic=True)
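
For context on why dropping fullgraph=True may matter: the second traceback ends at the check `0.0 in attention_mask`, which Dynamo handles through Tensor.item and reports as Unsupported. The sketch below reproduces that pattern on a toy function, not on Llama. The function name toy_mask_check and the use of backend="eager" (to avoid needing a GPU or Triton) are assumptions made for illustration, and the exact behavior can differ between PyTorch versions:

```python
import torch
import torch._dynamo


def toy_mask_check(attention_mask: torch.Tensor) -> torch.Tensor:
    # Same shape of check as in the traceback: a Python `in` test on a tensor,
    # whose result depends on the tensor's values.
    if 0.0 in attention_mask:
        return attention_mask + 1.0
    return attention_mask


mask = torch.tensor([[1.0, 1.0, 0.0]])
eager_out = toy_mask_check(mask)

# Without fullgraph=True, Dynamo is allowed to break the graph at code it
# cannot trace and run that part in regular eager mode.
relaxed = torch.compile(toy_mask_check, backend="eager")
relaxed_out = relaxed(mask)

# With fullgraph=True, a graph break is reported as an error instead.
torch._dynamo.reset()  # drop cached compilations before recompiling
strict = torch.compile(toy_mask_check, backend="eager", fullgraph=True)
try:
    strict(mask)
    print("fullgraph=True compiled without error on this PyTorch version")
except Exception as exc:
    print(f"fullgraph=True failed with {type(exc).__name__}")

print(torch.equal(relaxed_out, eager_out))
```

Whether the relaxed setting is fast enough for generation is a separate question, since every graph break gives up part of the benefit of compiling.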

ArthurZucker commented 1 week ago

Hey! If you are using flash-attention, you need to be a bit more careful about how you use compile. #33932 shows how you can use flash attention. My best recommendation is for you to use sdpa instead! 🤗

prasiyer commented 6 days ago

@ArthurZucker - for sdpa, are you recommending the default attn_implementation while using AutoModelForCausalLM.from_pretrained or instead, the sdpa_kernel from torch.nn.attention? If you are referring to the sdpa_kernel, can you pls share example usage?

@Vishal-Padia - I will test with dynamic mode and report the results.

ArthurZucker commented 1 day ago

> are you recommending the default attn_implementation while using AutoModelForCausalLM.from_pretrained or instead, the sdpa_kernel from torch.nn.attention? If you are referring to the sdpa_kernel, can you pls share example usage?

I recommend the default attn_implementation. If you don't specify an attention implementation, SDPA is used by default! We have now removed the graph breaks for flash_attention in #33932, but it's not super native to use!