It turns out that `_load_state_dict_into_meta_model` has some issues (as does `_load_state_dict_into_model`). The following script reproduces the problem:
```python
import os
import torch
from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict
from transformers import LlamaForCausalLM, LlamaConfig

checkpoint = "Xenova/llama2.c-stories15M"
local_path = "my_model"

# Save a local pickle (non-safetensors) copy of the checkpoint
LlamaForCausalLM.from_pretrained(checkpoint).save_pretrained(local_path, safe_serialization=False)
pytorch_bin = os.path.join(local_path, "pytorch_model.bin")

# Build a fresh model and load the raw state dict from disk
config = LlamaConfig.from_pretrained(checkpoint)
model = LlamaForCausalLM(config=config)
state_dict = load_state_dict(pytorch_bin)
expected_keys = list(model.state_dict().keys())

for dtype in [torch.float16, torch.float32]:
    _load_state_dict_into_meta_model(
        model=model,
        state_dict=state_dict,
        start_prefix='',
        expected_keys=expected_keys,
        device_map={'': torch.device('cpu')},
        offload_folder=None,
        offload_index=None,
        state_dict_folder=None,
        state_dict_index=None,
        dtype=dtype,
        hf_quantizer=None,
        is_safetensors=False,
        keep_in_fp32_modules=[],
        unexpected_keys=[],
        pretrained_model_name_or_path=None,
    )
    # Check whether the embedding and LM head weights now share storage
    print(model.state_dict()["model.embed_tokens.weight"].data_ptr() == model.state_dict()["lm_head.weight"].data_ptr())
```
which prints:

```
False
True
```
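My guess at the mechanism (an assumption on my part, not confirmed against the loader internals): `torch.load` restores storage sharing present in the pickle, and `Tensor.to(dtype)` returns the same tensor object when the dtype already matches, so the no-op float32 cast preserves the aliasing while the real float16 cast breaks it. A minimal standalone sketch of that PyTorch behavior, matching the False/True pattern above:

```python
import torch

# Two state dict entries aliasing one storage, as torch.load can restore from a .bin file
a = torch.ones(4)
state = {"x": a, "y": a}

# A real cast copies each tensor; a no-op cast returns the same object
x16, y16 = state["x"].to(torch.float16), state["y"].to(torch.float16)
x32, y32 = state["x"].to(torch.float32), state["y"].to(torch.float32)
print(x16.data_ptr() == y16.data_ptr())  # False: fp32 -> fp16 makes fresh copies
print(x32.data_ptr() == y32.data_ptr())  # True: fp32 -> fp32 is a no-op
```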
System Info
platform: Linux (Ubuntu 22.04)
Python version: 3.10.12
transformers version: 4.44.2
Who can help?
No response
Expected behavior
I expect the tensors not to be tied when `tie_word_embeddings=False`. Instead, the tensors end up tied after loading. This seems to be the root cause of #33688.
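For anyone hitting this in the meantime, a possible stopgap (my own sketch, not a fix from the library) is to break the accidental alias after loading by cloning the output embedding weight:

```python
import torch

# Sketch of a workaround: if loading left the two weights aliased even though
# tie_word_embeddings=False, give lm_head its own storage.
embed = model.get_input_embeddings().weight
head = model.get_output_embeddings().weight
if embed.data_ptr() == head.data_ptr():
    model.get_output_embeddings().weight = torch.nn.Parameter(head.detach().clone())
```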