TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/

[Proposal] Add function #754

Open hamind opened 1 week ago

hamind commented 1 week ago

Proposal

Change some of the loading code so that models can be loaded from a local directory.

Motivation

Today I wanted to load a GPT-2 model that I had already downloaded from the Hugging Face website, the same way I do with Llama, but TransformerLens kept trying to connect to Hugging Face to download it. I checked the code and found that:

  1. There is no way to load a model from a local path.
  2. If the Hugging Face model is already available, there is no need to download the model config from Hugging Face again; it can be taken directly from the Hugging Face model object.

Pitch

For models downloaded from Hugging Face manually (or otherwise not in the cache), provide a way to load them locally.

Alternatives

  1. For HookedTransformer.from_pretrained, consider adding a parameter that accepts a local model path (a rough sketch of this is below).
  2. If the Hugging Face model is already available, get the config directly from the Hugging Face model object.
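
A hypothetical sketch of what alternative 1 could look like (the ability to pass a local path is not in the current API, and the path is only a placeholder):

from transformer_lens import HookedTransformer

# Hypothetical: from_pretrained pointed at a local checkpoint directory,
# so that no request to huggingface.co is needed.
model = HookedTransformer.from_pretrained("/data/models/gpt2")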

bryce13950 commented 1 week ago

Could you share the code you are using to load TransformerLens? You should be able to pass in your local version of the model with the hf_model param.
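
Roughly like this (the local path is just a placeholder):

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

local_path = "/data/models/gpt2"  # placeholder: local copy of the HF checkpoint
hf_model = AutoModelForCausalLM.from_pretrained(local_path)
tokenizer = AutoTokenizer.from_pretrained(local_path)

model = HookedTransformer.from_pretrained(
    "gpt2",              # official model name, used to pick the config/weight conversion
    hf_model=hf_model,   # reuse the locally loaded weights instead of downloading them
    tokenizer=tokenizer, # avoid a separate tokenizer download
)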

hamind commented 1 week ago

I only modified a small amount of code, so I've pasted the relevant pieces directly here. For each change I've noted the Python file and line number, kept the original code as a comment, and put the new code below it for easy comparison.

In transformer_lens.HookedTransformer.py line 1257

cfg = loading.get_pretrained_model_config(
    official_model_name,
    # hf_cfg=hf_cfg
    hf_cfg=hf_model.config,
    checkpoint_index=checkpoint_index,
    checkpoint_value=checkpoint_value,
    fold_ln=fold_ln,
    device=device,
    n_devices=n_devices,
    default_prepend_bos=default_prepend_bos,
    dtype=dtype,
    first_n_layers=first_n_layers,
    **from_pretrained_kwargs,
)
In transformer_lens.loading_from_pretrained.py line 1583

# if hf_cfg is not None:
#     cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)

if hf_cfg is not None:
    cfg_dict["load_in_4bit"] = hf_cfg.to_dict().get("quantization_config", {}).get("load_in_4bit", False)
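
The .to_dict() here compensates for the first change: hf_model.config is a transformers PretrainedConfig object rather than the plain dict the dict-style .get() call was written for, so converting it first keeps the lookup working. A standalone illustration (the model name is just an example):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("gpt2")  # a PretrainedConfig object, not a plain dict
cfg_dict = cfg.to_dict()                  # plain dict, so .get() with a default works
load_in_4bit = cfg_dict.get("quantization_config", {}).get("load_in_4bit", False)
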
In transformer_lens.loading_from_pretrained.py line 708

# def convert_hf_model_config(model_name: str, **kwargs):
def convert_hf_model_config(model_name: str, hf_config=None, **kwargs):
    """
    Returns the model config for a HuggingFace model, converted to a dictionary
    in the HookedTransformerConfig format.

    Takes the official_model_name as an input.
    """
    if (Path(model_name) / "config.json").exists():
        logging.info("Loading model config from local directory")
        official_model_name = model_name
    else:
        official_model_name = get_official_model_name(model_name)

    # Load HuggingFace model config
    if "llama" in official_model_name.lower():
        architecture = "LlamaForCausalLM"
    elif "gemma-2" in official_model_name.lower():
        architecture = "Gemma2ForCausalLM"
    elif "gemma" in official_model_name.lower():
        architecture = "GemmaForCausalLM"
    else:
        # huggingface_token = os.environ.get("HF_TOKEN", None)
        # hf_config = AutoConfig.from_pretrained(
        #     official_model_name,
        #     token=huggingface_token,
        #     **kwargs,
        # )
        if hf_config is None:
            huggingface_token = os.environ.get("HF_TOKEN", None)
            hf_config = AutoConfig.from_pretrained(
                official_model_name,
                token=huggingface_token,
                **kwargs,
            )
        architecture = hf_config.architectures[0]
    ...
In transformer_lens.loading_from_pretrained.py line 1525 and line 1543

    if Path(model_name).exists():
        # If the model_name is a path, it's a local model
        # cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
        cfg_dict = convert_hf_model_config(model_name, hf_cfg, **kwargs)
        official_model_name = model_name
    else:
        official_model_name = get_official_model_name(model_name)
    if (
        official_model_name.startswith("NeelNanda")
        or official_model_name.startswith("ArthurConmy")
        or official_model_name.startswith("Baidicoot")
    ):
        cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
    else:
        if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
            "trust_remote_code", False
        ):
            logging.warning(
                f"Loading model {official_model_name} requires setting trust_remote_code=True"
            )
            kwargs["trust_remote_code"] = True
        # cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
        cfg_dict = convert_hf_model_config(official_model_name, hf_cfg, **kwargs)
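
Put together, the fully offline usage these changes are aimed at looks roughly like this (paths are placeholders; a sketch of the intent, not a tested example):

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

local_dir = "/data/models/gpt2"  # placeholder: directory with config.json and the weights
hf_model = AutoModelForCausalLM.from_pretrained(local_dir)
tokenizer = AutoTokenizer.from_pretrained(local_dir)

# With the patch, the config is taken from hf_model.config instead of being
# fetched again, so nothing here should need to reach huggingface.co.
model = HookedTransformer.from_pretrained(
    "gpt2",
    hf_model=hf_model,
    tokenizer=tokenizer,
)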