4gatepylon closed this issue 4 months ago.
Really appreciate you writing the issue here!
I've taken a look and I think I understand the issue. I built most of the tests around the `.from_pretrained` method, which runs through downloaded SAEs, and we're missing a direct test of SAEs saved by a trainer (a lot of stuff got refactored recently, so it's maybe unsurprising that we missed a path). As a result, we have a bunch of code handling the defaulting of parameters that doesn't get used if you are loading from a local path. I'm not sure exactly when we lost that test coverage, but that's on me and I'll get on it ASAP. Hopefully I'll have a fix out by tomorrow. In the meantime, if you are otherwise blocked, you can probably add the missing parameter to the config manually or edit the source code to default it.
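For anyone blocked in the meantime, the "add it manually to the config" workaround can be scripted. This is a minimal sketch, assuming `cfg.json` is a flat JSON object: `add_missing_defaults` is a hypothetical helper, and the key names and default values below are placeholders, since the thread does not say which parameter is missing or what its default should be.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict


def add_missing_defaults(cfg_path: Path, defaults: Dict[str, Any]) -> Dict[str, Any]:
    """Add keys from `defaults` that cfg.json lacks; existing values are kept.

    Hypothetical helper. Returns only the keys that were actually added.
    """
    cfg = json.loads(cfg_path.read_text())
    added = {k: v for k, v in defaults.items() if k not in cfg}
    if added:
        cfg.update(added)
        cfg_path.write_text(json.dumps(cfg, indent=2))
    return added


# Usage with a throwaway file; "some_missing_key" is a placeholder name.
with tempfile.TemporaryDirectory() as d:
    path = Path(d) / "cfg.json"
    path.write_text(json.dumps({"d_in": 768}))
    print(add_missing_defaults(path, {"d_in": 512, "some_missing_key": None}))
    # {'some_missing_key': None}
```

Existing keys are never overwritten, so rerunning it on an already-patched config is a no-op.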
If there's any value in doing some sort of refactor to organize things in a less error-prone way, I might be happy to contribute a PR in the future. I'm not sure if this is just a one-off or if there is a better way to architect this sort of stuff.
The way I would imagine doing storage to be really easy to understand is:
- Have some sort of centralized pydantic configuration definitions file/folder. This way all JSONs, etc. are easy to find.
- Have generic `to_json` and `from_json` as well as a generic "save this thing's config". Whether the thing is an SAE, an SAE training run, some future architectural change like swapping out a layer, a finetune, or something else, it should IMO always work the same way and be differentiated by a type string in the config.
- Kind of implied by (and implies) 1 & 2, but logically configs form a pretty flat tree where each config is just a root config with a type and some additional information. This may allow for smaller configs (easier to understand) and more scalable configs (more types will be easy to use), minimize these errors, and make correctness more automated (e.g. you can automatically save multiple configs in a non-overlapping way by having a generic "give me the name of this config" method).
I think this will be useful if you want to add more features in the future, and I think it will become more of a thing, because people are going to want to train solutions to superposition not only in more ways but also potentially with slightly different architectures. Features can also include stuff like low-memory training or performance optimizations (important to me because I've run into OOM problems). Features naturally split into different classes (e.g. performance optimization vs. SAE training parameters and strategy vs. dataset and model vs. hook strategy vs. miscellaneous and niche stuff). This is why the hierarchical approach seems right, but on the other hand it can become hellish if the nesting is deep, so a bit of thought should go into how much nesting is actually necessary.
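A minimal sketch of the type-tagged idea, using only the standard library (the proposal mentions pydantic; plain dataclasses are swapped in here to keep the sketch dependency-free). Every name in it (`BaseConfig`, `register_config`, `CONFIG_REGISTRY`, `SAEConfig`, `TrainingRunConfig`, and the `"type"` key) is hypothetical and not part of SAELens.

```python
import json
from dataclasses import asdict, dataclass
from typing import ClassVar, Dict, Type

# Hypothetical registry mapping a "type" string to the config class that owns it.
CONFIG_REGISTRY: Dict[str, Type["BaseConfig"]] = {}


def register_config(cls):
    """Class decorator: record the class under its CONFIG_TYPE string."""
    CONFIG_REGISTRY[cls.CONFIG_TYPE] = cls
    return cls


@dataclass
class BaseConfig:
    # ClassVar is not a dataclass field, so it is never serialized by asdict().
    CONFIG_TYPE: ClassVar[str] = "base"

    def to_json(self) -> str:
        # asdict() only walks declared fields, so ad-hoc attributes are never saved.
        return json.dumps({"type": self.CONFIG_TYPE, **asdict(self)})

    @staticmethod
    def from_json(text: str) -> "BaseConfig":
        data = json.loads(text)
        cls = CONFIG_REGISTRY[data.pop("type")]  # dispatch on the type string
        return cls(**data)


@register_config
@dataclass
class SAEConfig(BaseConfig):
    CONFIG_TYPE: ClassVar[str] = "sae"
    d_in: int = 512
    d_sae: int = 4096


@register_config
@dataclass
class TrainingRunConfig(BaseConfig):
    CONFIG_TYPE: ClassVar[str] = "training_run"
    lr: float = 3e-4
    training_tokens: int = 1_000_000


cfg = SAEConfig(d_in=768)
text = cfg.to_json()  # '{"type": "sae", "d_in": 768, "d_sae": 4096}'
assert BaseConfig.from_json(text) == cfg
```

One loader handles every config kind, and adding a new kind is just another registered subclass, which keeps the tree flat as suggested.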
Hmm, I'm not sure if this is a bug or if I'm misusing something, but here is something else to maybe look into along with this: https://github.com/jbloomAus/SAELens/blob/9dacd4a9672c138b7c900ddd9a28d1b3b3a0870c/sae_lens/config.py#L375. `__dict__` includes whatever attributes you set, even ones that are not dataclass fields, for example `tokens_per_buffer`: https://github.com/jbloomAus/SAELens/blob/9dacd4a9672c138b7c900ddd9a28d1b3b3a0870c/sae_lens/config.py#L240.
I have not done an ablation for this one yet, but I can make a hotfix for myself, so don't worry about that. Just pointing out that this might be of interest.
IPython example:
In [13]: @dataclass
    ...: class A:
    ...:     x: int = 3
    ...:
In [14]: a = A(x=5)
In [15]: a
Out[15]: A(x=5)
In [16]: a.__dict__
Out[16]: {'x': 5}
In [17]: a.y = 120
In [18]: a.__dict__
Out[18]: {'x': 5, 'y': 120}
In [19]: A(**a.__dict__)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[19], line 1
----> 1 A(**a.__dict__)
TypeError: A.__init__() got an unexpected keyword argument 'y'
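A quick way to see the difference (a sketch, not the SAELens code): `dataclasses.asdict()` walks only the declared fields, so serializing with it instead of `__dict__` would keep ad-hoc attributes like `y` out of the saved JSON. Note that it would also drop undeclared attributes computed in `__post_init__`, which is fine only if they are recomputed on load.

```python
from dataclasses import asdict, dataclass


@dataclass
class A:
    x: int = 3


a = A(x=5)
a.y = 120  # ad-hoc attribute, not a dataclass field

print(a.__dict__)  # {'x': 5, 'y': 120} -> includes the extra attribute
print(asdict(a))   # {'x': 5}           -> only declared fields

# Round-tripping through asdict() works where A(**a.__dict__) raised TypeError.
restored = A(**asdict(a))
print(restored)    # A(x=5)
```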
If I try to load the `cfg.json`, I get this error:
Traceback (most recent call last):
  File "/my/folder/main.py", line 410, in <module>
    cli()
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/my/folder/.venv/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/my/folder/main.py", line 249, in check_trojan
    return check_trojan_f(query, run_dir, generation_path, with_sae)
  File "/my/folder/trojans.py", line 684, in check_trojan
    sae = load_sae_from_dir(__with_sae, device=device)
  File "/my/folder/saes.py", line 45, in load_sae_from_dir
    runner_config = LanguageModelSAERunnerConfig.from_json(sae_dir.as_posix() + "/")
  File "/my/folder/.venv/lib/python3.10/site-packages/sae_lens/config.py", line 396, in from_json
    return cls(**cfg)
TypeError: LanguageModelSAERunnerConfig.__init__() got an unexpected keyword argument 'tokens_per_buffer'
Calling code:
import json
import os
import tempfile
from pathlib import Path

from sae_lens import SAE
from sae_lens.config import LanguageModelSAERunnerConfig


def load_sae_from_dir(sae_dir: Path | str, device: str = "cpu") -> SAE:
    """
    Due to a bug (https://github.com/jbloomAus/SAELens/issues/168) in the SAE save implementation for SAE Lens we need to make
    a specialized workaround.
    WARNING this will be creating a directory where the files are LINKED with the exception of "cfg.json" which is copied. This is NOT efficient
    and you should not be calling it many times!
    This wraps: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae.py#L284.
    SPECIFICALLY fix cfg.json.
    """
    sae_dir = Path(sae_dir)
    if not all([x.is_file() for x in sae_dir.iterdir()]):
        raise ValueError(
            "Not all files are present in the directory! Only files allowed for loading SAE Directory."
        )
    # https://github.com/jbloomAus/SAELens/blob/9dacd4a9672c138b7c900ddd9a28d1b3b3a0870c/sae_lens/config.py#L188
    runner_config = LanguageModelSAERunnerConfig.from_json(sae_dir.as_posix() + "/")  # <--------------------------- FAIL
    sae_config = runner_config.get_training_sae_cfg_dict()
    sae = None
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir = Path(temp_dir)
        # Copy in the CFG
        sae_config_f = temp_dir / "cfg.json"
        with open(sae_config_f, "w") as f:
            json.dump(sae_config, f)
        # LINK (zero disk space) all the other files; link the loop variable name_f,
        # not f (the already-closed cfg.json file handle from above)
        for name_f in sae_dir.iterdir():
            if name_f.name == "cfg.json":
                continue
            os.link(name_f.as_posix(), (temp_dir / name_f.name).as_posix())
        # Load SAE (while the temporary directory still exists)
        sae = SAE.load_from_pretrained(temp_dir, device=device)
    assert sae is not None and isinstance(sae, SAE)
    return sae
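The temp-directory trick in the calling code above can be factored into a small helper. This is a sketch under assumptions: `stage_with_new_config` is a hypothetical name, and the copy fallback is an addition not in the original code, for the case where the temporary directory sits on a different filesystem than the SAE directory (hard links cannot cross filesystems, so `os.link` raises `OSError` there).

```python
import json
import os
import shutil
import tempfile
from pathlib import Path


def stage_with_new_config(src_dir: Path, dst_dir: Path, cfg: dict) -> None:
    """Write `cfg` as dst_dir/cfg.json and link every other file from src_dir.

    Hypothetical helper: hard links cost no extra disk space, but if linking
    fails (e.g. across filesystems) the file is copied instead.
    """
    with open(dst_dir / "cfg.json", "w") as f:
        json.dump(cfg, f)
    for src in src_dir.iterdir():
        if src.name == "cfg.json":
            continue  # replaced by the rewritten config written above
        dst = dst_dir / src.name
        try:
            os.link(src, dst)
        except OSError:
            shutil.copy2(src, dst)


# Usage with throwaway directories standing in for an SAE save directory.
with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as dst:
    src, dst = Path(src), Path(dst)
    (src / "cfg.json").write_text('{"old": true}')
    (src / "weights.bin").write_bytes(b"\x00\x01")
    stage_with_new_config(src, dst, {"new": True})
    print(json.loads((dst / "cfg.json").read_text()))  # {'new': True}
    print((dst / "weights.bin").read_bytes())           # b'\x00\x01'
```

The caller would then load the SAE from `dst` before the `with` block exits, exactly as the original code does.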
Reasonable suggestions! I experimented with hierarchical approaches and it was very messy. I like the idea of cleaning this up but suspect we don't need a major rethink just yet. I'd accept a small contained PR that addresses this exact issue but am not open to any major overhauls right now. I expect a slight refactor so there is less differentiation between HF and local SAEs than we have currently. Sorry I don't have capacity to engage much more deeply at present. Please keep sharing any issues like these though!
Below is a fixed band-aid, in case anyone needs to load these configs for SAEs. I've kept it external to SAELens for now.
import json
import os
import tempfile
from pathlib import Path

from sae_lens import SAE
from sae_lens.config import LanguageModelSAERunnerConfig


def load_sae_from_dir(sae_dir: Path | str, device: str = "cpu") -> SAE:
    """
    Due to a bug (https://github.com/jbloomAus/SAELens/issues/168) in the SAE save implementation for SAE Lens we need to make
    a specialized workaround.
    WARNING this will be creating a directory where the files are LINKED with the exception of "cfg.json" which is copied. This is NOT efficient
    and you should not be calling it many times!
    This wraps: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae.py#L284.
    SPECIFICALLY fix cfg.json.
    """
    sae_dir = Path(sae_dir)
    if not all([x.is_file() for x in sae_dir.iterdir()]):
        raise ValueError(
            "Not all files are present in the directory! Only files allowed for loading SAE Directory."
        )
    # https://github.com/jbloomAus/SAELens/blob/9dacd4a9672c138b7c900ddd9a28d1b3b3a0870c/sae_lens/config.py#L188
    # Load the JSON ourselves instead of using from_json: cfg.json contains derived __dict__ entries
    # that are not __init__ arguments, so cls(**cfg) fails on them.
    # They should ALL be enumerated in `derivatives`; they are recomputed when the config is constructed.
    ##### BEGIN #####
    cfg_f = sae_dir / "cfg.json"
    with open(cfg_f, "r") as f:
        cfg = json.load(f)
    derivatives = [
        "tokens_per_buffer",
    ]
    derivative_values = [cfg[x] for x in derivatives]
    for x in derivatives:
        del cfg[x]
    runner_config = LanguageModelSAERunnerConfig(**cfg)
    # Sanity check: the recomputed derived values match what was saved
    assert all(
        [
            d in runner_config.__dict__ and runner_config.__dict__[d] == dv
            for d, dv in zip(derivatives, derivative_values)
        ]
    )
    del derivative_values
    del derivatives
    ##### END #####
    # Load the SAE
    sae_config = runner_config.get_training_sae_cfg_dict()
    sae = None
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir = Path(temp_dir)
        # Copy in the CFG
        sae_config_f = temp_dir / "cfg.json"
        with open(sae_config_f, "w") as f:
            json.dump(sae_config, f)
        # LINK (zero disk space) all the other files
        # Note: hard links only work if temp_dir is on the same filesystem as sae_dir
        for name_f in sae_dir.iterdir():
            if name_f.name == "cfg.json":
                continue
            else:
                os.link(name_f.as_posix(), (temp_dir / name_f.name).as_posix())
        # Load SAE (while the temporary directory still exists)
        sae = SAE.load_from_pretrained(temp_dir, device=device)
    assert sae is not None and isinstance(sae, SAE)
    return sae
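A more general variant of the band-aid, as a sketch: instead of hardcoding the `derivatives` list, keep only the keys that the dataclass `__init__` accepts and hand the rest back for inspection. `from_dict_ignoring_extras` and `ToyRunnerConfig` are hypothetical (the toy's field names and formula are made up); with SAELens you would pass `LanguageModelSAERunnerConfig` and the loaded `cfg.json` dict instead, assuming it stays a plain dataclass as the thread treats it.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Tuple, Type, TypeVar

T = TypeVar("T")


def from_dict_ignoring_extras(cls: Type[T], data: Dict[str, Any]) -> Tuple[T, Dict[str, Any]]:
    """Build a dataclass from `data`, dropping keys its __init__ does not accept.

    Returns the instance plus the dropped keys so callers can check or log them.
    """
    init_names = {f.name for f in fields(cls) if f.init}  # skips init=False fields too
    kwargs = {k: v for k, v in data.items() if k in init_names}
    dropped = {k: v for k, v in data.items() if k not in init_names}
    return cls(**kwargs), dropped


# Toy stand-in for a runner config; field names and the formula are made up.
@dataclass
class ToyRunnerConfig:
    batch_size: int = 8
    n_batches_in_buffer: int = 4

    def __post_init__(self):
        # Undeclared derived attribute: lands in __dict__ (and so in a __dict__-based
        # JSON dump) but is not an __init__ argument.
        self.tokens_per_buffer = self.batch_size * self.n_batches_in_buffer


saved = dict(ToyRunnerConfig(batch_size=16).__dict__)  # what a __dict__-based save writes
print(saved)    # {'batch_size': 16, 'n_batches_in_buffer': 4, 'tokens_per_buffer': 64}
cfg, dropped = from_dict_ignoring_extras(ToyRunnerConfig, saved)
print(dropped)  # {'tokens_per_buffer': 64}
assert cfg.tokens_per_buffer == dropped["tokens_per_buffer"]  # recomputed value matches
```

Returning the dropped keys rather than silently discarding them keeps the sanity check from the band-aid possible, and also surfaces typos in hand-edited configs.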
Describe the bug
Saved models cannot be loaded because some parameters are missing from the config. Please look at the additional context at the bottom first, since it includes a likely cause and implies the solution.
Code example + error
Here is the `tree` of the model save directory:
Here is the config JSON:
And here is the code that saves to the path: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/training/sae_trainer.py#L176
And here is the code generating the error (loading):
And here is the error:
System Info
This should NOT matter (please look at the additional context first).
Requirements:
Additional context
Perhaps I have made some sort of usage mistake, but I have reason to believe this is a genuine bug, because AFAIK the SAE should have its `cfg` file overwritten. Look:
- `cfg.json`: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae.py#L276 (JSON def here)
- `cfg.json`: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae_training_runner.py#L176
I have not done an ablation test for this yet.
Checklist