huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

llama3 not fx traced #33966

Open myungjin opened 1 month ago

myungjin commented 1 month ago

System Info

Who can help?

No response

Reproduction

from transformers.utils.fx import symbolic_trace
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
# Including "past_key_values" in input_names triggers the failure below
gm = symbolic_trace(model, input_names=["input_ids", "attention_mask", "past_key_values"])
Running this prints a deprecation warning and then raises:

We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/fx.py", line 1503, in symbolic_trace
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/fx.py", line 1326, in trace
    self.graph = super().trace(root, concrete_args=concrete_args)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 822, in trace
    (self.create_arg(fn(*args)),),
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 800, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/fx.py", line 1190, in call_module
    return super().call_module(m, forward, args, kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 518, in call_module
    ret_val = forward(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 793, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1000, in forward
    layer_outputs = decoder_layer(
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 800, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/fx.py", line 1190, in call_module
    return super().call_module(m, forward, args, kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 518, in call_module
    ret_val = forward(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 793, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 729, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 800, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/fx.py", line 1190, in call_module
    return super().call_module(m, forward, args, kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 518, in call_module
    ret_val = forward(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 793, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 646, in forward
    if query_states.device.type == "cuda" and causal_mask is not None:
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/transformers/utils/fx.py", line 669, in __bool__
    return super().__bool__()
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/proxy.py", line 447, in __bool__
    return self.tracer.to_bool(self)
  File "~/miniconda3/envs/pytorch-dev/lib/python3.10/site-packages/torch/fx/proxy.py", line 307, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
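
For context, the final error is torch.fx's generic restriction that a Proxy created during tracing cannot drive Python control flow. A minimal standalone sketch (plain torch.fx, nothing Llama-specific) reproduces the same TraceError:

import torch
from torch import fx

class Toy(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: during tracing, x.sum() > 0 is a Proxy,
        # and calling bool() on a Proxy raises TraceError
        if x.sum() > 0:
            return x + 1
        return x - 1

fx.symbolic_trace(Toy())
# torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow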

Expected behavior

It should run without error. symbolic_trace() with input_names=["input_ids", "attention_mask"] runs fine; the error occurs only when ["input_ids", "attention_mask", "past_key_values"] is passed as input_names. If using past_key_values is not supported, the call should warn and abort before attempting to trace the model.
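
For comparison, here is the two-input call that traces successfully (a minimal sketch of the working case):

from transformers.utils.fx import symbolic_trace
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
# Without past_key_values in input_names, tracing completes
gm = symbolic_trace(model, input_names=["input_ids", "attention_mask"])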

While a fix for a related error (https://github.com/huggingface/transformers/issues/29923) is included in the released version, some bug evidently remains.

ArthurZucker commented 1 month ago

Hey! Thanks for reporting! cc @michaelbenayoun. I can reproduce, but I have no idea how to avoid this. It's also not new, so I think the past_key_values path was never tested!

ziyueluocs commented 1 month ago

@ArthurZucker I've experienced the same issue. Although the fix for https://github.com/huggingface/transformers/issues/29923 resolves the cache-related problem, https://github.com/huggingface/transformers/blob/70b07d97cf2c5f61fff55700b65528a1b6845cd2/src/transformers/utils/fx.py#L1052-L1068 still creates the dummy input for past_key_values using the old tuple-of-tuples format. I replaced that part with:

elif "past_key_values" in input_name:
    if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
        raise NotImplementedError(
            f"Symbolic tracing with past_key_values input is not supported yet for the model {model.config.model_type}. Please open an issue or a PR in the Transformers repository if you'd like to see this support added."
        )
    inputs_dict[input_name] = DynamicCache()

It seems we can now trace Llama. However, I’m not sure if this is the correct way to fix the issue.
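
With that local change applied, the original reproduction goes through on my end (a sketch assuming the patched fx.py; DynamicCache lives in transformers.cache_utils):

from transformers.utils.fx import symbolic_trace
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
# With the DynamicCache dummy input, this no longer hits the TraceError
gm = symbolic_trace(model, input_names=["input_ids", "attention_mask", "past_key_values"])
print(gm.graph)  # inspect the traced graph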

ArthurZucker commented 1 month ago

I think there are fixes for direct tracing in optimum, but happy to have something more general here!

myungjin commented 1 month ago

@ArthurZucker Thanks for looking into this. Can you provide a link or pointer to the fixes you're referring to? I went through the Optimum documentation, but its examples still rely on symbolic_trace, so I don't see how Optimum helps address this issue.

ArthurZucker commented 1 month ago

The linked PR should help!