bminixhofer / zett

Code for Zero-Shot Tokenizer Transfer
https://arxiv.org/abs/2405.07883

How to use the hyper network for Llama3-8b #10

Open gushu333 opened 1 month ago

gushu333 commented 1 month ago

Hi, I want to use the hypernetwork for Llama-3-8B and tried to load it with the following command:

hypernet = AutoModel.from_pretrained('zett-hypernetwork-Meta-Llama-3-8B-experimental', trust_remote_code=True, from_flax=True)

but ran into the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py", line 563, in from_pretrained
    return model_class.from_pretrained(
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py", line 3404, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 923, in __init__
    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py", line 134, in __init__
    assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
AssertionError: Padding_idx must be within num_embeddings
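The assertion at the bottom of the traceback can be reproduced in isolation. A minimal sketch (assuming stock PyTorch semantics, with small stand-in sizes): `nn.Embedding` requires `padding_idx < num_embeddings`, so a checkpoint config whose `pad_token_id` is not smaller than `vocab_size` fails at model construction time, before any weights are loaded.

```python
import torch.nn as nn

# Stand-in sizes; the real model uses vocab_size ~128k and hidden_size 4096.
vocab_size = 128
hidden_size = 16

# A padding_idx equal to (or larger than) vocab_size is rejected by
# nn.Embedding's constructor, which is exactly the failure in the traceback.
try:
    nn.Embedding(vocab_size, hidden_size, padding_idx=vocab_size)
except AssertionError as e:
    print("construction failed:", e)

# A padding_idx strictly below vocab_size constructs fine.
emb = nn.Embedding(vocab_size, hidden_size, padding_idx=vocab_size - 1)
print("ok:", tuple(emb.weight.shape))
```

This suggests the config shipped with the checkpoint declares a `pad_token_id` outside its `vocab_size`, which is worth checking before attempting the PyTorch conversion.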

So, could you please tell me how to run inference with the hypernetwork for Llama-3-8B? Thanks!