huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

llama `tie_word_embeddings` ignored on cpu and with auto dtype only #33689

Open kylesayrs opened 2 months ago

kylesayrs commented 2 months ago

System Info

platform: Linux (Ubuntu 22.04)
python version: 3.10.12
transformers version: 4.44.2


Reproduction

import torch
import pytest
from transformers import AutoModelForCausalLM

@pytest.mark.parametrize(
    "torch_dtype,tie_word_embeddings,device_map",
    [
        (torch.float16, False, "cpu"   ),  # passes
        (torch.float32, False, "cpu"   ),  # fails
        (torch.float32, False, "cuda:0"),  # passes

        (torch.float16, True, "cpu"   ),  # passes
        (torch.float32, True, "cpu"   ),  # passes
        (torch.float32, True, "cuda:0"),  # passes
    ],
)
def test_model_shared(torch_dtype, tie_word_embeddings, device_map, tmp_path):
    # load model
    model = AutoModelForCausalLM.from_pretrained(
        "Xenova/llama2.c-stories15M",
        torch_dtype=torch_dtype,
        tie_word_embeddings=tie_word_embeddings,
        device_map=device_map
    )

    # modify lm head
    with torch.no_grad():
        model.lm_head.weight += 1

    # tied weights must still match after the update; untied weights must diverge
    if tie_word_embeddings:
        assert torch.equal(model.lm_head.weight, model.model.embed_tokens.weight)
    else:
        assert not torch.equal(model.lm_head.weight, model.model.embed_tokens.weight)

Expected behavior

I expect the tensors not to be tied when tie_word_embeddings=False. Instead, with torch_dtype=torch.float32 on CPU, they are tied, so modifying lm_head also modifies embed_tokens. This seems to be the root cause of #33688.
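The test above detects tying by mutating `lm_head` and comparing values. A non-mutating alternative is to group a module's `state_dict()` entries by storage address, since tied parameters appear under every key but point at the same memory. This is a minimal sketch assuming only plain PyTorch: `find_shared_tensors` and `TinyLM` are hypothetical names for illustration, not transformers APIs, and comparing `data_ptr()` only catches tensors that start at the same address (not partially overlapping views).

```python
from collections import defaultdict

import torch
from torch import nn


def find_shared_tensors(module: nn.Module) -> list[list[str]]:
    """Hypothetical helper: group state_dict keys whose tensors start at the same address."""
    groups: dict[int, list[str]] = defaultdict(list)
    for name, tensor in module.state_dict().items():
        # A tied Parameter is listed under each owning module's key,
        # and all of those entries share one underlying storage
        groups[tensor.data_ptr()].append(name)
    return [names for names in groups.values() if len(names) > 1]


class TinyLM(nn.Module):
    """Hypothetical toy model with an input embedding and an output head."""

    def __init__(self, vocab_size: int = 10, hidden: int = 4, tie: bool = False):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        if tie:
            # Reuse the same Parameter object in both modules (weight tying)
            self.lm_head.weight = self.embed_tokens.weight


print(find_shared_tensors(TinyLM(tie=True)))   # [['embed_tokens.weight', 'lm_head.weight']]
print(find_shared_tensors(TinyLM(tie=False)))  # []
```

Applied to a model loaded as in the failing case above, one would expect `model.embed_tokens.weight` and `lm_head.weight` to show up in the same group even though tie_word_embeddings=False.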

ydshieh commented 1 month ago

It turns out that `_load_state_dict_into_meta_model` (as well as `_load_state_dict_into_model`) has some issues.

ydshieh commented 1 month ago

To be continued from this narrower reproduction, which calls `_load_state_dict_into_meta_model` directly:

import os

import torch
from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict
from transformers import LlamaForCausalLM, LlamaConfig

checkpoint = "Xenova/llama2.c-stories15M"
local_path = "my_model"

# re-save the checkpoint locally as a pickled pytorch_model.bin (not safetensors)
LlamaForCausalLM.from_pretrained(checkpoint).save_pretrained(local_path, safe_serialization=False)
pytorch_bin = os.path.join(local_path, "pytorch_model.bin")

config = LlamaConfig.from_pretrained(checkpoint)
model = LlamaForCausalLM(config=config)
state_dict = load_state_dict(pytorch_bin)
expected_keys = list(model.state_dict().keys())

for dtype in [torch.float16, torch.float32]:
    _load_state_dict_into_meta_model(
        model=model,
        state_dict=state_dict,
        start_prefix='',
        expected_keys=expected_keys,
        device_map={'': torch.device('cpu')},
        offload_folder=None,
        offload_index=None,
        state_dict_folder=None,
        state_dict_index=None,
        dtype=dtype,
        hf_quantizer=None,
        is_safetensors=False,
        keep_in_fp32_modules=[],
        unexpected_keys=[],
        pretrained_model_name_or_path=None
    )
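    # True means the two weights share storage, i.e. they ended up tied
    embed_ptr = model.state_dict()["model.embed_tokens.weight"].data_ptr()
    head_ptr = model.state_dict()["lm_head.weight"].data_ptr()
    print(embed_ptr == head_ptr)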

which prints:

False
True
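One mechanism that would produce this False/True pattern is offered here as an assumption, not a diagnosis confirmed in the thread. `torch.save`/`torch.load` preserve storage sharing, so a `.bin` written from a model whose weights were already tied loads with both keys aliasing one tensor. `Tensor.to(dtype)` returns the tensor itself when no conversion is needed and a fresh copy otherwise. A loader that casts each entry independently would therefore keep the aliasing for float32 and break it for float16. The sketch below reproduces only that PyTorch behavior on a hypothetical two-key state dict; it does not call transformers, and the helper names are illustrative.

```python
import io

import torch


def load_aliased_state_dict() -> dict:
    """Round-trip a hypothetical state dict whose two entries share one tensor."""
    shared = torch.randn(8, 4)  # stand-in for a tied embedding matrix (float32)
    buffer = io.BytesIO()
    torch.save({"model.embed_tokens.weight": shared, "lm_head.weight": shared}, buffer)
    buffer.seek(0)
    # torch.save / torch.load preserve storage sharing between the saved entries
    return torch.load(buffer)


def still_aliased_after_cast(state_dict: dict, dtype: torch.dtype) -> bool:
    """Mimic a loader that casts each entry on its own with Tensor.to(dtype)."""
    embed = state_dict["model.embed_tokens.weight"].to(dtype)
    head = state_dict["lm_head.weight"].to(dtype)
    # Tensor.to returns `self` when dtype already matches, otherwise a new copy
    return embed.data_ptr() == head.data_ptr()


state_dict = load_aliased_state_dict()
for dtype in (torch.float16, torch.float32):
    print(still_aliased_after_cast(state_dict, dtype))  # False, then True
```

If this is the cause, a fix would need to break the aliasing (for example by copying, or by assigning into the model's existing parameters) regardless of whether a dtype conversion happens.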