hao-ai-lab / LookaheadDecoding


Can't run minimal.py on A100 #27

Closed jiqing-feng closed 7 months ago

jiqing-feng commented 7 months ago

It seems that the code does not support the newest version of transformers, so I installed transformers==4.34.0, the version pinned in requirements.txt. I got this error when I ran python minimal.py:

Traceback (most recent call last):
  File "/workspace/jiqing/LookaheadDecoding/minimal.py", line 18, in <module>
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    model_class = _get_model_class(config, cls._model_mapping)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py", line 387, in _get_model_class
    supported_models = model_mapping[type(config)]
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py", line 739, in __getitem__
    return self._load_attr_from_module(model_type, model_name)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py", line 753, in _load_attr_from_module
    return getattribute_from_module(self._modules[module_name], attr)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py", line 697, in getattribute_from_module
    if hasattr(module, attr):
  File "/usr/local/lib/python3.10/dist-packages/transformers/utils/import_utils.py", line 1272, in __getattr__
    module = self._get_module(self._class_to_module[name])
  File "/usr/local/lib/python3.10/dist-packages/transformers/utils/import_utils.py", line 1284, in _get_module
    raise RuntimeError(
RuntimeError: Failed to import transformers.models.llama.modeling_llama because of the following error (look up to see its traceback):
cannot import name 'flash_attn_func' from 'flash_attn' (/usr/local/lib/python3.10/dist-packages/flash_attn/__init__.py)

When I use the newest version of transformers, the following error occurs:

Traceback (most recent call last):
  File "/workspace/jiqing/LookaheadDecoding/minimal.py", line 28, in <module>
    greedy_output = model.generate(**model_inputs, max_new_tokens=1, do_sample=False)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/jiqing/transformers/src/transformers/generation/utils.py", line 1717, in generate
    return self.greedy_search(
  File "/workspace/jiqing/LookaheadDecoding/lade/decoding.py", line 23, in greedy_search_proxy
    return jacobi_greedy_search_multilevel(self, chat=False, *args, **kwargs)
  File "/workspace/jiqing/LookaheadDecoding/lade/decoding.py", line 278, in jacobi_greedy_search_multilevel
    outputs = self.jforward_multilevel(
  File "/workspace/jiqing/LookaheadDecoding/lade/models/llama.py", line 383, in jforward_multilevel
    outputs = self.model.LlamaModeljforward(
  File "/workspace/jiqing/LookaheadDecoding/lade/models/llama.py", line 198, in LlamaModeljforward
    attention_mask = self.j_prepare_decoder_attention_mask(
  File "/workspace/jiqing/LookaheadDecoding/lade/models/llama.py", line 119, in j_prepare_decoder_attention_mask
    expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
  File "/workspace/jiqing/transformers/src/transformers/models/llama/modeling_llama.py", line 83, in _expand_mask
    return AttentionMaskConverter._prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
AttributeError: type object 'AttentionMaskConverter' has no attribute '_prepare_4d_attention_mask'
Viol2000 commented 7 months ago

Hi, I think transformers==4.34.0 is OK. Newer transformers changed their API, so supporting them needs extra effort on our side. Our currently released version does not support flash_attn, so please do not load the model with flash-attn. As for your problem, I think you can stay on 4.34.0 and use pip install flash_attn to solve it.
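
For reference, a minimal sketch of that suggestion, assuming the versions named above; the checkpoint and device below are placeholders and may not match what minimal.py actually uses:

```python
# Assumed environment, per the comment above (not verified here):
#   transformers==4.34.0 plus a flash_attn build installed via pip
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint
torch_device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load with the default attention implementation; do not pass any flash-attention
# option, since the released LookaheadDecoding code path does not support it.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map=torch_device
)
```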

jiqing-feng commented 7 months ago

> Hi, I think transformers==4.34.0 is OK. Newer transformers changed their API, so supporting them needs extra effort on our side. Our currently released version does not support flash_attn, so please do not load the model with flash-attn. As for your problem, I think you can stay on 4.34.0 and use pip install flash_attn to solve it.

Thanks! I have fixed the problem with the newest version of transformers, see here. Lade should run well with this change.
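
For context, a minimal sketch of the kind of compatibility shim such a change implies, assuming transformers >= 4.35 where the 4-D mask helper lives in transformers.modeling_attn_mask_utils as a module-level function; the actual patch linked above may differ:

```python
# Sketch only: in newer transformers the 4-D mask helper is a module-level function,
# not an attribute of AttentionMaskConverter, so lade's j_prepare_decoder_attention_mask
# can call it directly instead of the removed class attribute.
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask


def expand_mask_compat(mask, dtype, tgt_len=None):
    # Expands a [bsz, src_len] padding mask into a [bsz, 1, tgt_len, src_len] float mask.
    return _prepare_4d_attention_mask(mask, dtype, tgt_len=tgt_len)
```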

By the way, have you ever tried Lade on CPU? It seems to suffer a large performance degradation there. Do you have any clue about the cause?

Viol2000 commented 7 months ago

Yes, running on the CPU should be very slow. But maybe you can set the hyper-parameters (LEVEL, WINDOW_SIZE, and GUESS_SET_SIZE in the config_lade line) to small values and see if there are minor speedups. The main reason is that our method trades extra FLOPs for fewer decoding steps; you can refer to our blog about this. A CPU usually has very little spare FLOP capacity, so we cannot expect a speedup.
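
For anyone trying this, a hedged sketch of shrinking those hyper-parameters; the call shape follows the LookaheadDecoding README, and the specific values here are only an assumed starting point:

```python
# Smaller LEVEL / WINDOW_SIZE / GUESS_SET_SIZE mean fewer speculative tokens generated
# and verified per step, i.e. fewer extra FLOPs traded for saved decoding steps --
# the relevant knob on a CPU with little spare compute.
import lade

lade.augment_all()
lade.config_lade(LEVEL=3, WINDOW_SIZE=5, GUESS_SET_SIZE=5, DEBUG=0)
```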