jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

[Bug Report] Saved Models Cannot be Loaded #168

Closed 4gatepylon closed 4 months ago

4gatepylon commented 4 months ago

Describe the bug Saved models cannot be loaded because some of the parameters in the config are missing. Please look at the additional context at the bottom first, since it includes a likely cause and implies the solution.

Code example + error Here is the tree of the model save directory:

<my folder>/
├── cfg.json
├── sae_weights.safetensors
└── sparsity.safetensor

Here is the config JSON:

{
    "model_name": "gpt2-small",
    "model_class_name": "HookedTransformer",
    "hook_name": "blocks.7.attn.hook_z",
    "hook_eval": "NOT_IN_USE",
    "hook_layer": 7,
    "hook_head_index": null,
    "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
    "streaming": true,
    "is_dataset_tokenized": true,
    "context_size": 512,
    "use_cached_activations": false,
    "cached_activations_path": null,
    "d_in": 768,
    "d_sae": 12288,
    "b_dec_init_method": "geometric_median",
    "expansion_factor": 16,
    "activation_fn": "relu",
    "normalize_sae_decoder": true,
    "noise_scale": 0.0,
    "from_pretrained_path": null,
    "apply_b_dec_to_input": true,
    "decoder_orthogonal_init": false,
    "decoder_heuristic_init": false,
    "init_encoder_as_decoder_transpose": false,
    "n_batches_in_buffer": 64,
    "training_tokens": 1000000,
    "finetuning_tokens": 0,
    "store_batch_size_prompts": 16,
    "train_batch_size_tokens": 128,
    "normalize_activations": false,
    "device": "cuda",
    "act_store_device": "cuda",
    "seed": 42,
    "dtype": "float32",
    "prepend_bos": true,
    "autocast": false,
    "autocast_lm": false,
    "compile_llm": false,
    "llm_compilation_mode": null,
    "compile_sae": false,
    "sae_compilation_mode": null,
    "adam_beta1": 0,
    "adam_beta2": 0.999,
    "mse_loss_normalization": null,
    "l1_coefficient": 0.001,
    "lp_norm": 1.0,
    "scale_sparsity_penalty_by_decoder_norm": false,
    "l1_warm_up_steps": 0,
    "lr": 0.0004,
    "lr_scheduler_name": "constant",
    "lr_warm_up_steps": 10000,
    "lr_end": 4e-05,
    "lr_decay_steps": 0,
    "n_restart_cycles": 1,
    "finetuning_method": null,
    "use_ghost_grads": false,
    "feature_sampling_window": 1000,
    "dead_feature_window": 1000,
    "dead_feature_threshold": 0.0001,
    "n_eval_batches": 10,
    "eval_batch_size_prompts": null,
    "log_to_wandb": true,
    "log_activations_store_to_wandb": false,
    "log_optimizer_state_to_wandb": false,
    "wandb_project": "sae_lens_tutorial",
    "wandb_id": null,
    "run_name": "12288-L1-0.001-LR-0.0004-Tokens-1.000e+06",
    "wandb_entity": null,
    "wandb_log_frequency": 10,
    "eval_every_n_wandb_logs": 100,
    "resume": false,
    "n_checkpoints": 2,
    "checkpoint_path": "<my checkpoint path>",
    "verbose": true,
    "model_kwargs": {},
    "model_from_pretrained_kwargs": {},
    "sae_lens_version": "3.0.0",
    "sae_lens_training_version": "3.0.0",
    "tokens_per_buffer": 4194304
}

and here is the code that saves the model to that path: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/training/sae_trainer.py#L176

and here is the code generating the error (loading):

sae, cfg_dict, _ = SAE.load_from_pretrained(
    path=Path(with_sae).as_posix(),
    device=device,
)

and here is the ERROR:

Traceback (most recent call last):
  File "/my/folder/main.py", line 410, in <module>
    cli()
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/my/folder/main.py", line 249, in check_trojan
    return check_trojan_f(query, run_dir, generation_path, with_sae)
  ################################ THE ACTUAL ERROR STARTS BELOW ################################
  File "/my/folder/trojans.py", line 683, in check_trojan
    sae, cfg_dict, _ = SAE.load_from_pretrained(
  File "/my/folder/.venv/lib/python3.10/site-packages/sae_lens/sae.py", line 301, in load_from_pretrained
    sae_cfg = SAEConfig.from_dict(cfg_dict)
  File "/my/folder/.venv/lib/python3.10/site-packages/sae_lens/sae.py", line 70, in from_dict
    return cls(**config_dict)
TypeError: SAEConfig.__init__() missing 1 required positional argument: 'finetuning_scaling_factor'

System Info This should NOT matter (please look at the additional context first)

Requirements:

# ML and Viz, SWE
transformers[torch]
circuitsvis==1.41.0
datasets==2.19.1
einops==0.7.0
evaluate==0.4.2
fancy-einsum==0.0.3
jaxtyping==0.2.29
numpy==1.26.4
plotly==5.22.0
pydantic==2.7.2
pytest==8.2.1
pyyaml
requests==2.32.2
sae-lens==3.0.0
torch==2.3.0
transformer-lens==1.17.0
# Jupyter
jupyterlab==4.2.1
notebook==7.2.0
ipywidgets==8.1.3
ipython[all]
# CLI
click==8.1.7
# fire==0.6.0 # Not used yet
# Dependencies for cloud storage (etc...)
azure-identity==1.16.0
azure-storage-blob==12.20.0
boto3==1.34.114
google-cloud-storage==2.16.0
# Formatting, linting, typechecking, etc..
black==24.4.2

Additional context Perhaps I have made some sort of usage mistake, but I have reason to believe this is a genuine bug, because AFAIK the SAE should have its cfg file overwritten on save (see the saving code linked above).

I have not done an ablation test for this yet.


jbloomAus commented 4 months ago

Really appreciate you writing the issue here!

I've taken a look and I think I understand the issue. I built most of the tests around the .from_pretrained method, which runs through downloaded SAEs, and we're missing a direct test of SAEs saved by a trainer (a lot of stuff got refactored recently, so maybe it's unsurprising we missed a path). As a result, we have a bunch of code handling defaulting of parameters that doesn't get used if you are loading from a local path. I'm not sure exactly when we lost that test coverage, but that's on me and I'll get on it ASAP. Hopefully I'll have a fix out by tomorrow. In the meantime, if you are otherwise blocked, you can probably add the missing key to the config manually or edit the source code to default it.
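
Something along these lines should work for the manual edit (untested sketch; it assumes false is the right value for finetuning_scaling_factor here, since finetuning_method is null in the config above):

import json
from pathlib import Path

# Untested sketch: add the missing key to the saved cfg.json before loading.
# Assumes False is appropriate because "finetuning_method" is null in the config.
cfg_path = Path("<my folder>") / "cfg.json"  # the saved SAE directory from the report
cfg = json.loads(cfg_path.read_text())
cfg.setdefault("finetuning_scaling_factor", False)
cfg_path.write_text(json.dumps(cfg, indent=4))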

4gatepylon commented 4 months ago

If there's any value in doing some sort of refactor to organize things in a less error-prone way, I might be happy to contribute a PR in the future. I'm not sure if this is just a one-off or if there is a better way to architect this sort of stuff.

The way I would imagine making storage really easy to understand is:

  1. Have some sort of centralized Pydantic configuration definitions file/folder. This way all JSONs, etc. are easy to find.
  2. Have generic to_json and from_json methods as well as a generic "save this thing's config". Whether the thing is an SAE, an SAE training run, some sort of future architectural change like swapping out a layer, a finetune, or something else, it should IMO always work the same way and be differentiated by a type string in the config (see the sketch after this list).
  3. This is kind of implied by (and implies) 1 & 2, but logically configs form a pretty flat tree structure where each config is just a root config with a type and some additional information. This may allow for smaller configs (easier to understand), more scalable configs (more types will be easy to add), and it minimizes these errors and makes correctness more automated (e.g. you can automatically save multiple configs in a non-overlapping way by having a generic "give me the name of this config" method).
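
A rough, hypothetical sketch of what 2 could look like (illustrative names only, not the current SAELens API; Pydantic v2 assumed since it's already in my requirements):

from pathlib import Path

from pydantic import BaseModel


class BaseConfig(BaseModel):
    # Type string so generic tooling knows what kind of config it is loading/saving.
    config_type: str  # e.g. "sae", "sae_training_run", "finetune"

    def save_json(self, path: Path | str) -> None:
        Path(path).write_text(self.model_dump_json(indent=2))

    @classmethod
    def load_json(cls, path: Path | str):
        return cls.model_validate_json(Path(path).read_text())


class SAEConfigSketch(BaseConfig):
    config_type: str = "sae"
    d_in: int = 768
    d_sae: int = 12288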

I think this is useful in the future if you want to add more features. I think this will become more of a thing, because people are going to want to train solutions to superposition not only in more ways but also potentially with slightly different architectures. Features can also include stuff like low-memory training or performance optimizations (important to me because I've run into OOM problems). Features naturally split into different classes (i.e. performance optimization vs. SAE training parameters and strategy vs. dataset and model vs. hook strategy vs. other miscellaneous and niche stuff). This is why the hierarchical approach seems right, but on the other hand it can become hellish if the nesting is deep, so a bit of thought should be put into it depending on the necessity.

4gatepylon commented 4 months ago

Hmm, I'm not sure if this is an error or if I'm misusing something, but it's also worth looking into along with this: https://github.com/jbloomAus/SAELens/blob/9dacd4a9672c138b7c900ddd9a28d1b3b3a0870c/sae_lens/config.py#L375. __dict__ includes whatever you set on the instance, even if it's not one of the dataclass fields, for example tokens_per_buffer: https://github.com/jbloomAus/SAELens/blob/9dacd4a9672c138b7c900ddd9a28d1b3b3a0870c/sae_lens/config.py#L240.

I have not done an ablation for this one yet, but I can write a hotfix for myself, so don't worry about that. Just pointing out that this might be of interest.

IPython Example:

In [12]: from dataclasses import dataclass

In [13]: @dataclass
    ...: class A:
    ...:     x: int = 3
    ...: 

In [14]: a = A(x=5)

In [15]: a
Out[15]: A(x=5)

In [16]: a.__dict__
Out[16]: {'x': 5}

In [17]: a.y = 120

In [18]: a.__dict__
Out[18]: {'x': 5, 'y': 120}
In [19]: A(**a.__dict__)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[19], line 1
----> 1 A(**a.__dict__)

TypeError: A.__init__() got an unexpected keyword argument 'y'
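
For what it's worth, one possible library-side direction (just a sketch, not tested against the SAELens code) is to serialize only the declared dataclass fields rather than __dict__, so derived attributes like tokens_per_buffer never end up in cfg.json:

from dataclasses import asdict, dataclass, fields


@dataclass
class A:
    x: int = 3


a = A(x=5)
a.y = 120  # derived attribute set after __init__, like tokens_per_buffer

print(a.__dict__)  # {'x': 5, 'y': 120} -- includes the extra attribute
print(asdict(a))   # {'x': 5}           -- declared fields only
print({f.name: getattr(a, f.name) for f in fields(a)})  # also {'x': 5}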

If I try to load the cfg.json I get this error:

Traceback (most recent call last):
  File "/my/folder/main.py", line 410, in <module>
    cli()
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/my/folder/main.py", line 249, in check_trojan
    return check_trojan_f(query, run_dir, generation_path, with_sae)
  File "/my/folder/trojans.py", line 684, in check_trojan
    sae = load_sae_from_dir(__with_sae, device=device)
  File "/my/folder/saes.py", line 45, in load_sae_from_dir
    runner_config = LanguageModelSAERunnerConfig.from_json(sae_dir.as_posix() + "/")
  File "/my/folder/.venv/lib/python3.10/site-packages/sae_lens/config.py", line 396, in from_json
    return cls(**cfg)
TypeError: LanguageModelSAERunnerConfig.__init__() got an unexpected keyword argument 'tokens_per_buffer'

Calling code:

def load_sae_from_dir(sae_dir: Path | str, device: str = "cpu") -> SAE:
    """
    Due to a bug (https://github.com/jbloomAus/SAELens/issues/168) in the SAE save implementation for SAE Lens we need to make
    a specialized workaround.

    WARNING: this creates a temporary directory in which the files are LINKED, with the exception of "cfg.json", which is rewritten.
    This is NOT efficient and you should not be calling it many times!

    This wraps: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae.py#L284.

    SPECIFICALLY fix cfg.json.
    """
    sae_dir = Path(sae_dir)

    if not all([x.is_file() for x in sae_dir.iterdir()]):
        raise ValueError(
            "Not all files are present in the directory! Only files allowed for loading SAE Directory."
        )

    # https://github.com/jbloomAus/SAELens/blob/9dacd4a9672c138b7c900ddd9a28d1b3b3a0870c/sae_lens/config.py#L188
    runner_config = LanguageModelSAERunnerConfig.from_json(sae_dir.as_posix() + "/") # <--------------------------- FAIL
    sae_config = runner_config.get_training_sae_cfg_dict()
    sae = None
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir = Path(temp_dir)
        # Copy in the CFG
        sae_config_f = temp_dir / "cfg.json"
        with open(sae_config_f, "w") as f:
            json.dump(sae_config, f)
        # LINK (zero disk space) all the other files
        for name_f in sae_dir.iterdir():
            if name_f.name == "cfg.json":
                continue
            os.link(name_f.as_posix(), (temp_dir / name_f.name).as_posix())
        # Load SAE
        sae = SAE.load_from_pretrained(temp_dir, device=device)
    assert sae is not None and isinstance(sae, SAE)
    return sae
jbloomAus commented 4 months ago

Reasonable suggestions! I experimented with hierarchical approaches and it was very messy. I like the idea of cleaning this up but suspect we don't need a major rethink just yet. I'd accept a small contained PR that addresses this exact issue but am not open to any major overhauls right now. I expect a slight refactor so there is less differentiation between HF and local SAEs than we have currently. Sorry I don't have capacity to engage much more deeply at present. Please keep sharing any issues like these though!


4gatepylon commented 4 months ago

Fixed band-aid below, in case anyone needs to load these configs for saved SAEs. I've kept this external (outside the library) for now.

import json
import os
import tempfile
from pathlib import Path

from sae_lens import SAE
from sae_lens.config import LanguageModelSAERunnerConfig

def load_sae_from_dir(sae_dir: Path | str, device: str = "cpu") -> SAE:
    """
    Due to a bug (https://github.com/jbloomAus/SAELens/issues/168) in the SAE save implementation for SAE Lens we need to make
    a specialized workaround.

    WARNING: this creates a temporary directory in which the files are LINKED, with the exception of "cfg.json", which is rewritten.
    This is NOT efficient and you should not be calling it many times!

    This wraps: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae.py#L284.

    SPECIFICALLY fix cfg.json.
    """
    sae_dir = Path(sae_dir)

    if not all([x.is_file() for x in sae_dir.iterdir()]):
        raise ValueError(
            "Not all files are present in the directory! Only files allowed for loading SAE Directory."
        )

    # https://github.com/jbloomAus/SAELens/blob/9dacd4a9672c138b7c900ddd9a28d1b3b3a0870c/sae_lens/config.py#L188
    # Load the JSON ourselves instead of using from_json, because cfg.json contains some
    # derived __dict__ attributes that are not dataclass fields.
    # They should ALL be enumerated in `derivatives`.
    ##### BEGIN #####
    cfg_f = sae_dir / "cfg.json"
    with open(cfg_f, "r") as f:
        cfg = json.load(f)
    derivatives = [
        "tokens_per_buffer",
    ]
    derivative_values = [cfg[x] for x in derivatives]
    for x in derivatives:
        del cfg[x]
    runner_config = LanguageModelSAERunnerConfig(**cfg)
    assert all(
        [
            d in runner_config.__dict__ and runner_config.__dict__[d] == dv
            for d, dv in zip(derivatives, derivative_values)
        ]
    )
    del derivative_values
    del derivatives
    ##### END #####

    # Load the SAE
    sae_config = runner_config.get_training_sae_cfg_dict()
    sae = None
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir = Path(temp_dir)

        # Copy in the CFG
        sae_config_f = temp_dir / "cfg.json"
        with open(sae_config_f, "w") as f:
            json.dump(sae_config, f)
        # LINK (zero disk space) all the other files
        for name_f in sae_dir.iterdir():
            if name_f.name == "cfg.json":
                continue
            else:
                os.link(name_f.as_posix(), (temp_dir / name_f.name).as_posix())
        # Load SAE
        sae = SAE.load_from_pretrained(temp_dir, device=device)
    assert sae is not None and isinstance(sae, SAE)
    return sae
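
Hypothetical usage (directory path is just the example from the report):

sae = load_sae_from_dir("<my folder>", device="cuda")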