pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.
BSD 3-Clause "New" or "Revised" License

Missing Keys in state_dict #172

Open bjohn22 opened 6 months ago

bjohn22 commented 6 months ago

I downloaded nvidia/Llama3-ChatQA-1.5-8B manually from Hugging Face into a local directory and ran scripts/convert_hf_checkpoint.py. Then I tried to run generate.py against the local checkpoint directory and got:

    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Transformer:
    Missing key(s) in state_dict: "tok_embeddings.weight",
    "layers.0.attention.wqkv.weight", "layers.0.attention.wo.weight",
    "layers.0.feed_forward.w1.weight", "layers.0.feed_forward.w3.weight",
    "layers.0.feed_forward.w2.weight", "layers.0.ffn_norm.weight",
    "layers.0.attention_norm.weight",
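For context, the keys in the error message are gpt-fast's own parameter names, not the Hugging Face ones, which is why a raw HF checkpoint cannot be loaded directly. The sketch below is an illustrative (not exact) reconstruction of what scripts/convert_hf_checkpoint.py does conceptually: rename HF-style keys to the gpt-fast names seen in the error, and fuse the separate q/k/v projections into a single wqkv entry. The HF key names and the fusion detail are assumptions here; plain lists stand in for tensors, where the real script would use torch.cat.

```python
def convert_checkpoint(hf_sd):
    """Illustrative remap of a few Hugging Face keys (layer 0 only) to the
    gpt-fast names from the error message. Hypothetical subset, not the
    real conversion script."""
    rename = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.layers.0.self_attn.o_proj.weight": "layers.0.attention.wo.weight",
        "model.layers.0.mlp.gate_proj.weight": "layers.0.feed_forward.w1.weight",
        "model.layers.0.mlp.up_proj.weight": "layers.0.feed_forward.w3.weight",
        "model.layers.0.mlp.down_proj.weight": "layers.0.feed_forward.w2.weight",
        "model.layers.0.post_attention_layernorm.weight": "layers.0.ffn_norm.weight",
        "model.layers.0.input_layernorm.weight": "layers.0.attention_norm.weight",
    }
    out = {rename[k]: v for k, v in hf_sd.items() if k in rename}
    # Fuse the q/k/v projections into one "wqkv" weight by stacking their
    # rows (list concatenation stands in for torch.cat along dim 0).
    qkv = [hf_sd[f"model.layers.0.self_attn.{p}_proj.weight"] for p in "qkv"]
    out["layers.0.attention.wqkv.weight"] = qkv[0] + qkv[1] + qkv[2]
    return out
```

If the converted checkpoint still lacks these keys, the conversion step likely never matched the model name, which is what the next comment addresses.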

Here is my weight directory: [screenshot of the checkpoint directory attached]

yanboliang commented 2 months ago

Actually, Llama3-ChatQA-1.5-8B is not supported. Please check the list of supported models at: https://github.com/pytorch-labs/gpt-fast/blob/c9f683edd4f89d3e81ed8f52387e866a245e3226/model.py#L60-L81

But I think you can replace llama-3-8b in that list with Llama3-ChatQA-1.5-8B and experiment with it; the two models should share the same architecture.
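As a sketch of what that suggestion looks like: gpt-fast keeps per-model hyperparameters in a dict in model.py (the `transformer_configs` linked above), and an entry is selected by matching the checkpoint directory name. Adding an alias for the ChatQA fine-tune that reuses the Llama 3 8B configuration should let the name match succeed. The hyperparameter values below are my reconstruction of the Llama 3 8B entry and should be checked against the actual model.py:

```python
# Hypothetical excerpt of the config registry in model.py; the real dict
# maps model-name substrings to architecture hyperparameters.
transformer_configs = {
    "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8,
                       dim=4096, intermediate_size=14336, vocab_size=128256,
                       rope_base=500000),
}

# Reuse the same architecture for the ChatQA fine-tune, since it only
# changes the weights, not the model shape.
transformer_configs["Llama3-ChatQA-1.5-8B"] = transformer_configs["llama-3-8b"]
```

With the alias in place, both the conversion script and generate.py should resolve Llama3-ChatQA-1.5-8B to the same architecture as Llama 3 8B.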

bjohn22 commented 2 months ago

Thank you for this comment.
