TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/
MIT License

Add Redwood 2L model #180

Closed · ArthurConmy closed this 1 year ago

ArthurConmy commented 1 year ago

Since Redwood's 2L model has been open-sourced, anyone should be able to download the weights: https://drive.google.com/file/d/1CEZF9QEY2VEtOatgVh7VRL4SkDYyIFLL/view?usp=sharing

and use them in this library:

# %%

import torch
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.HookedTransformer import HookedTransformer

cfg = HookedTransformerConfig(
    n_layers=2, 
    d_model=256,
    n_ctx=2048,
    n_heads=8,
    d_head=32,
    # model_name : str = "custom"
    # d_mlp: Optional[int] = None
    # act_fn: Optional[str] = None
    d_vocab=50259,
    # eps: float = 1e-5
    use_attn_result=True,
    use_attn_scale=True,  # divide attention scores by sqrt(d_head)
    # use_local_attn: bool = False
    # original_architecture: Optional[str] = None
    # from_checkpoint: bool = False
    # checkpoint_index: Optional[int] = None
    # checkpoint_label_type: Optional[str] = None
    # checkpoint_value: Optional[int] = None
    # tokenizer_name: Optional[str] = None
    # window_size: Optional[int] = None
    # attn_types: Optional[List] = None
    # init_mode: str = "gpt2"
    # normalization_type: Optional[str] = "LN"
    # device: Optional[str] = None
    # attention_dir: str = "causal"
    attn_only=True,
    # seed: Optional[int] = None
    # initializer_range: float = -1.0
    # init_weights: bool = True
    # scale_attn_by_inverse_layer_idx: bool = False
    positional_embedding_type="shortformer",
    # final_rms: bool = False
    # d_vocab_out: int = -1
    # parallel_attn_mlp: bool = False
    # rotary_dim: Optional[int] = None
    # n_params: Optional[int] = None
    # use_hook_tokens: bool = False
)

et_model = HookedTransformer(cfg)
et_model.load_state_dict(torch.load("et_model_state_dict.pt"))
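
As a quick smoke test (a sketch of mine; it only checks output shape, since a matching tokenizer isn't wired up here), a forward pass on random token ids should give logits of shape [batch, pos, d_vocab]:

# Smoke-test sketch: random token ids in the ordinary GPT-2 range.
# Assumes the model is on CPU; move `tokens` to the model's device otherwise.
tokens = torch.randint(0, 50257, (1, 10))
logits = et_model(tokens)
assert logits.shape == (1, 10, 50259)  # d_vocab is 50259 per the config above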

It would be great if someone could check that this works, and then maybe someone at @redwoodresearch should add the model weights to HF or something?
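
If it helps, here's a minimal sketch of pushing the weights to the Hub with huggingface_hub (the repo id "your-org/redwood_attn_2l" is a placeholder, and it assumes you're logged in via huggingface-cli login):

from huggingface_hub import HfApi

repo_id = "your-org/redwood_attn_2l"  # placeholder repo id
api = HfApi()
api.create_repo(repo_id=repo_id, exist_ok=True)
api.upload_file(
    path_or_fileobj="et_model_state_dict.pt",  # local weights file
    path_in_repo="et_model_state_dict.pt",     # filename inside the repo
    repo_id=repo_id,
)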

neelnanda-io commented 1 year ago

I'm happy to put the weights on HuggingFace if you think Redwood would be fine with it?


daniel-ziegler commented 1 year ago

Absolutely, go ahead!

neelnanda-io commented 1 year ago

Huh, I tried downloading it, and can't load the file in Python:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-d867ee4a87b8> in <module>
----> 1 torch.load("/workspace/_scratch/et_model_state_dict.pt")

/opt/conda/lib/python3.7/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    703             # reset back to the original position.
    704             orig_position = opened_file.tell()
--> 705             with _open_zipfile_reader(opened_file) as opened_zipfile:
    706                 if _is_torchscript_zip(opened_zipfile):
    707                     warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"

/opt/conda/lib/python3.7/site-packages/torch/serialization.py in __init__(self, name_or_buffer)
    240 class _open_zipfile_reader(_opener):
    241     def __init__(self, name_or_buffer) -> None:
--> 242         super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
    243 
    244 

RuntimeError: PytorchStreamReader failed reading zip archive: failed finding central directory

If I try pickle, I get the errors below (the pasted session log also includes an unrelated 404 for pythia-160m-seed1 from the same session):

Type 'copyright', 'credits' or 'license' for more information
IPython 7.31.1 -- An enhanced Interactive Python. Type '?' for help.

In IPython
In IPython
Set autoreload
Imported everything!
Using pad_token, but it is not set yet.
Loaded pretrained model pythia-160m into HookedTransformer
---------------------------------------------------------------------------
HTTPError                                 Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/huggingface_hub/utils/_errors.py in hf_raise_for_status(response, endpoint_name)
    212     try:
--> 213         response.raise_for_status()
    214     except HTTPError as e:

/opt/conda/lib/python3.7/site-packages/requests/models.py in raise_for_status(self)
   1020         if http_error_msg:
-> 1021             raise HTTPError(http_error_msg, response=self)
   1022 

HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/EleutherAI/pythia-160m-seed1/resolve/main/config.json

The above exception was the direct cause of the following exception:

EntryNotFoundError                        Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/transformers/utils/hub.py in cached_file(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, use_auth_token, revision, local_files_only, subfolder, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash)
    419             use_auth_token=use_auth_token,
--> 420             local_files_only=local_files_only,
    421         )

/opt/conda/lib/python3.7/site-packages/huggingface_hub/file_download.py in hf_hub_download(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, user_agent, force_download, force_filename, proxies, etag_timeout, resume_download, use_auth_token, local_files_only, legacy_cache_layout)
   1056                     proxies=proxies,
-> 1057                     timeout=etag_timeout,
   1058                 )

/opt/conda/lib/python3.7/site-packages/huggingface_hub/file_download.py in get_hf_file_metadata(url, use_auth_token, proxies, timeout)
   1358     )
-> 1359     hf_raise_for_status(r)
   1360 

/opt/conda/lib/python3.7/site-packages/huggingface_hub/utils/_errors.py in hf_raise_for_status(response, endpoint_name)
    230             )
--> 231             raise EntryNotFoundError(message, response) from e
    232 

EntryNotFoundError: 404 Client Error. (Request ID: Root=1-63f9e721-789105633a9097b1307af268)

Entry Not Found for url: https://huggingface.co/EleutherAI/pythia-160m-seed1/resolve/main/config.json.

During handling of the above exception, another exception occurred:

OSError                                   Traceback (most recent call last)
<ipython-input-2-971ae5ec2179> in <module>
      1 model0 = HookedTransformer.from_pretrained("pythia-160m")
----> 2 model1 = HookedTransformer.from_pretrained("pythia-160m-1")
      3 model2 = HookedTransformer.from_pretrained("pythia-160m-2")
      4 print(evals.sanity_check(model0))
      5 print(evals.sanity_check(model1))

~/TransformerLens/transformer_lens/HookedTransformer.py in from_pretrained(cls, model_name, fold_ln, center_writing_weights, center_unembed, refactor_factored_attn_matrices, checkpoint_index, checkpoint_value, hf_model, device, move_state_dict_to_device, **model_kwargs)
    680             checkpoint_value=checkpoint_value,
    681             fold_ln=fold_ln,
--> 682             device=device,
    683         )
    684 

~/TransformerLens/transformer_lens/loading_from_pretrained.py in get_pretrained_model_config(model_name, checkpoint_index, checkpoint_value, fold_ln, device)
    523         cfg_dict = convert_neel_model_config(official_model_name)
    524     else:
--> 525         cfg_dict = convert_hf_model_config(official_model_name)
    526     # Processing common to both model types
    527     # Remove any prefix, saying the organization who made a model.

~/TransformerLens/transformer_lens/loading_from_pretrained.py in convert_hf_model_config(official_model_name)
    355     official_model_name = get_official_model_name(official_model_name)
    356     # Load HuggingFace model config
--> 357     hf_config = AutoConfig.from_pretrained(official_model_name)
    358     architecture = hf_config.architectures[0]
    359     if architecture == "GPTNeoForCausalLM":

/opt/conda/lib/python3.7/site-packages/transformers/models/auto/configuration_auto.py in from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
    807         kwargs["name_or_path"] = pretrained_model_name_or_path
    808         trust_remote_code = kwargs.pop("trust_remote_code", False)
--> 809         config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
    810         if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
    811             if not trust_remote_code:

/opt/conda/lib/python3.7/site-packages/transformers/configuration_utils.py in get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
    557         original_kwargs = copy.deepcopy(kwargs)
    558         # Get config dict associated with the base config file
--> 559         config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
    560         if "_commit_hash" in config_dict:
    561             original_kwargs["_commit_hash"] = config_dict["_commit_hash"]

/opt/conda/lib/python3.7/site-packages/transformers/configuration_utils.py in _get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
    624                     revision=revision,
    625                     subfolder=subfolder,
--> 626                     _commit_hash=commit_hash,
    627                 )
    628                 commit_hash = extract_commit_hash(resolved_config_file, commit_hash)

/opt/conda/lib/python3.7/site-packages/transformers/utils/hub.py in cached_file(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, use_auth_token, revision, local_files_only, subfolder, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash)
    453             revision = "main"
    454         raise EnvironmentError(
--> 455             f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
    456             f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
    457         )

OSError: EleutherAI/pythia-160m-seed1 does not appear to have a file named config.json. Checkout 'https://huggingface.co/EleutherAI/pythia-160m-seed1/main' for available files.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-82f40539d1d1> in <module>
     44 )
     45 et_model = HookedTransformer(cfg)
---> 46 et_model.load_state_dict(torch.load("/workspace/_scratch/et_model_state_dict.pt"))

/opt/conda/lib/python3.7/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    703             # reset back to the original position.
    704             orig_position = opened_file.tell()
--> 705             with _open_zipfile_reader(opened_file) as opened_zipfile:
    706                 if _is_torchscript_zip(opened_zipfile):
    707                     warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"

/opt/conda/lib/python3.7/site-packages/torch/serialization.py in __init__(self, name_or_buffer)
    240 class _open_zipfile_reader(_opener):
    241     def __init__(self, name_or_buffer) -> None:
--> 242         super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
    243 
    244 

RuntimeError: PytorchStreamReader failed reading zip archive: failed finding central directory
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-c0870d9efc9b> in <module>
     46 # et_model.load_state_dict(torch.load("/workspace/_scratch/et_model_state_dict.pt"))
     47 import pickle5
---> 48 pickle5.load("/workspace/_scratch/et_model_state_dict.pt")

TypeError: file must have 'read' and 'readline' attributes
---------------------------------------------------------------------------
UnpicklingError                           Traceback (most recent call last)
<ipython-input-5-5fcec0b276db> in <module>
     46 # et_model.load_state_dict(torch.load("/workspace/_scratch/et_model_state_dict.pt"))
     47 import pickle5
---> 48 pickle5.load(open("/workspace/_scratch/et_model_state_dict.pt", "rb"))

UnpicklingError: A load persistent id instruction was encountered,
but no persistent_load function was specified.
---------------------------------------------------------------------------
UnpicklingError                           Traceback (most recent call last)
<ipython-input-6-d3a13a278885> in <module>
     46 # et_model.load_state_dict(torch.load("/workspace/_scratch/et_model_state_dict.pt"))
     47 import pickle
---> 48 pickle.load(open("/workspace/_scratch/et_model_state_dict.pt", "rb"))

UnpicklingError: A load persistent id instruction was encountered,
but no persistent_load function was specified.

Did you do anything special to convert the weights to a file? This is PyTorch 1.12.0, so one guess is that there's something weird going on with compression algorithms.
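
One way to test the corruption hypothesis (a sketch, assuming the checkpoint was saved with PyTorch's default zip-based serialization, which torch.save has used since 1.6): check whether the file is a complete zip archive at all.

import os
import zipfile

# A complete torch checkpoint should be a valid zip archive; a truncated
# download loses the central directory at the end of the file, which is
# exactly the "failed finding central directory" error above.
path = "/workspace/_scratch/et_model_state_dict.pt"
print("size on disk:", os.path.getsize(path), "bytes")
print("complete zip archive:", zipfile.is_zipfile(path))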

ArthurConmy commented 1 year ago

Uh, on my local machine (PyTorch 1.13.0) I downloaded the file, and

torch.load(os.path.expanduser("~/Downloads/et_model_state_dict.pt"), map_location="cpu")

worked. Also

fname = huggingface_hub.hf_hub_download(repo_id="ArthurConmy/redwood_attn_2l", filename="et_model_state_dict.pt")
torch.load(fname, map_location="cpu")

worked (getting the full model to load is still in progress...). Maybe the first error you got was because the download hadn't finished? Per Stack Overflow, a truncated zip gives exactly that "failed finding central directory" error.
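
Putting the pieces together, a sketch of loading end-to-end (the config is the one from the top of the thread; whether the state-dict keys line up with HookedTransformer's naming is the part still being checked):

import huggingface_hub
import torch
from transformer_lens.HookedTransformer import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=256,
    n_ctx=2048,
    n_heads=8,
    d_head=32,
    d_vocab=50259,
    use_attn_result=True,
    use_attn_scale=True,
    attn_only=True,
    positional_embedding_type="shortformer",
)

# Fetch the weights from the HF repo mentioned above and load them.
fname = huggingface_hub.hf_hub_download(
    repo_id="ArthurConmy/redwood_attn_2l", filename="et_model_state_dict.pt"
)
et_model = HookedTransformer(cfg)
et_model.load_state_dict(torch.load(fname, map_location="cpu"))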

neelnanda-io commented 1 year ago

Thanks, redownloading fixed it. I now have the issue that I cannot get the model to do induction...

To check the basic model details: does it use LayerNorm, and shortformer positional embeddings?
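
For context, a minimal induction check (a sketch of mine, not the eval used here): repeat a random token sequence and compare per-token loss on the two halves; a model doing induction should do far better on the repeated half.

import torch
import torch.nn.functional as F

def induction_loss_gap(model, seq_len=50, batch=4):
    # Random tokens in the ordinary GPT-2 range (avoids ids 50257/50258).
    rand = torch.randint(100, 20000, (batch, seq_len))
    tokens = torch.cat([rand, rand], dim=1)  # move to the model's device if needed
    logits = model(tokens)  # [batch, 2 * seq_len, d_vocab]
    log_probs = F.log_softmax(logits[:, :-1], dim=-1)
    nll = -log_probs.gather(-1, tokens[:, 1:, None]).squeeze(-1)
    first = nll[:, : seq_len - 1].mean().item()  # loss within the first pass
    second = nll[:, seq_len:].mean().item()      # loss within the repeated pass
    return first, second

On a model with working induction heads, the second number should come out much lower than the first.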

ArthurConmy commented 1 year ago

@neelnanda-io yes, this model has LayerNorms and shortformer positional embeddings.

Tokens 50257 and 50258 are "[END]" and "[BEGIN]" respectively (which probably explains the lack of induction, sorry!). The rest are the same as GPT-2.

I've also uploaded validation_data.pt, the data on which the results were collected, to HF.
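
Since the vocab is just GPT-2 plus those two special tokens, here is a sketch of rebuilding a matching tokenizer (whether training sequences actually started with [BEGIN] is my assumption, not stated above):

from transformers import GPT2TokenizerFast

# GPT-2's 50257 tokens plus the two extras; add_special_tokens appends
# new tokens in order, so "[END]" -> 50257 and "[BEGIN]" -> 50258.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.add_special_tokens({"additional_special_tokens": ["[END]", "[BEGIN]"]})
assert tokenizer.convert_tokens_to_ids("[END]") == 50257
assert tokenizer.convert_tokens_to_ids("[BEGIN]") == 50258

# If training sequences began with [BEGIN] (an assumption), prepend it:
ids = [tokenizer.convert_tokens_to_ids("[BEGIN]")] + tokenizer.encode("The quick brown fox")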