Open hamind opened 1 week ago
Could you share the code you are using to load TransformerLens? You should be able to pass in your local version of the model with the param hf_model
I only changed a small amount of code, so I've pasted the relevant parts directly here. For each change I give the Python file and line number; the original code is shown as a comment, with the new code directly below it, to make checking easier.
In transformer_lens.HookedTransformer.py line 1257
cfg = loading.get_pretrained_model_config(
official_model_name,
# hf_cfg=hf_cfg
hf_cfg=hf_model.config,
checkpoint_index=checkpoint_index,
checkpoint_value=checkpoint_value,
fold_ln=fold_ln,
device=device,
n_devices=n_devices,
default_prepend_bos=default_prepend_bos,
dtype=dtype,
first_n_layers=first_n_layers,
**from_pretrained_kwargs,
)
In transformer_lens.loading_from_pretrained.py line 1583
# if hf_cfg is not None:
# cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
if hf_cfg is not None:
cfg_dict["load_in_4bit"] = hf_cfg.to_dict().get("quantization_config", {}).get("load_in_4bit", False)
In transformer_lens.loading_from_pretrained.py line 708
# def convert_hf_model_config(model_name: str, **kwargs):
def convert_hf_model_config(model_name: str, hf_config = None, **kwargs):
"""
Returns the model config for a HuggingFace model, converted to a dictionary
in the HookedTransformerConfig format.
Takes the official_model_name as an input.
"""
if (Path(model_name) / "config.json").exists():
logging.info("Loading model config from local directory")
official_model_name = model_name
else:
official_model_name = get_official_model_name(model_name)
# Load HuggingFace model config
if "llama" in official_model_name.lower():
architecture = "LlamaForCausalLM"
elif "gemma-2" in official_model_name.lower():
architecture = "Gemma2ForCausalLM"
elif "gemma" in official_model_name.lower():
architecture = "GemmaForCausalLM"
else:
# huggingface_token = os.environ.get("HF_TOKEN", None)
# hf_config = AutoConfig.from_pretrained(
# official_model_name,
# token=huggingface_token,
# **kwargs,
# )
if hf_config is None:
huggingface_token = os.environ.get("HF_TOKEN", None)
hf_config = AutoConfig.from_pretrained(
official_model_name,
token=huggingface_token,
**kwargs,
)
architecture = hf_config.architectures[0]
...
In transformer_lens.loading_from_pretrained.py line 1525 and line 1543
if Path(model_name).exists():
# If the model_name is a path, it's a local model
# cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
cfg_dict = convert_hf_model_config(model_name, hf_cfg, **kwargs)
official_model_name = model_name
else:
official_model_name = get_official_model_name(model_name)
if (
official_model_name.startswith("NeelNanda")
or official_model_name.startswith("ArthurConmy")
or official_model_name.startswith("Baidicoot")
):
cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
else:
if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
"trust_remote_code", False
):
logging.warning(
f"Loading model {official_model_name} requires setting trust_remote_code=True"
)
kwargs["trust_remote_code"] = True
# cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
cfg_dict = convert_hf_model_config(official_model_name, hf_cfg, **kwargs)
Proposal
Change the loading code so that a model can be loaded from a local directory.
Motivation
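Today I wanted to load a GPT-2 model that I had downloaded from the Hugging Face website from a local directory, as is already possible for Llama, but TransformerLens kept trying to connect to Hugging Face to download it. I then checked the code and found the cause (see the modified code above).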
Pitch
For models that were downloaded manually from Hugging Face, or that are not in the local cache, provide a way to load the model from a local directory.
Alternatives
Checklist