jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

[Proposal] Simple get_sae_config method for pretrained SAEs #315

Closed: hijohnnylin closed this 1 month ago

hijohnnylin commented 1 month ago

Proposal

I would like a simple get_sae_config function that takes the release and the SAE id (as listed under saes) and returns the config for that SAE. This should not require downloading the SAE weights.

Motivation

Use case: if someone wants the config to answer a quick question (e.g. what was the width of that one SAE?), they shouldn't need to know about these conversion functions or be forced to load the weights into memory.

Currently, PretrainedSaeLoader loads four different "types" of SAEs (each with a different conversion function): default (no conversion func), connor_rob_hook_z, gemma_2, and dictionary_learning_1.

To get just the config, I have to special-case each type, like so:

    # Excerpt: item, folder_name, conversion_func and FORCE_DOWNLOAD come from the surrounding code.
    if conversion_func is None:
        config = get_sae_config_from_hf(
            item.repo_id, folder_name, force_download=FORCE_DOWNLOAD
        )
    elif conversion_func == "connor_rob_hook_z":
        print("connor_rob_hook_z")
        # There is no config-only function, so the whole SAE has to be downloaded
        config, _, _ = connor_rob_hook_z_loader(
            item.repo_id, folder_name, force_download=FORCE_DOWNLOAD
        )
    elif conversion_func == "dictionary_learning_1":
        print("dictionary_learning_1")
        # There is no config-only function, so the whole SAE has to be downloaded
        config, _, _ = dictionary_learning_sae_loader_1(
            item.repo_id, folder_name, force_download=FORCE_DOWNLOAD
        )
    elif conversion_func == "gemma_2":
        print("gemma_2")
        config = get_gemma_2_config(item.repo_id, folder_name)
    else:
        # Without this branch, a new conversion type would leave `config` undefined
        raise ValueError(f"Unknown conversion_func: {conversion_func}")

There are some problems with this:
1) Difficult to maintain: it requires updating my function every time a new conversion function is added.
2) Unnecessary weight loading: the connor_rob_hook_z and dictionary_learning_1 conversion types require downloading (and sometimes loading) the SAE weights. I don't want to do this when all I want is the config.

Pitch

A simple get_sae_config method that takes two arguments, release_id and sae_id (plus an optional force_download, in case we suspect the config has changed), and returns the config without downloading or loading the weights.

Alternatives

The current alternative is to write the special cases as above.


hijohnnylin commented 1 month ago

Looks like #305 is related, but not exactly the same. Possibly we can solve both problems with one fix?

anthonyduong9 commented 1 month ago

@hijohnnylin I pushed changes that implement a get_sae_config() according to https://github.com/jbloomAus/SAELens/issues/305#issuecomment-2384123586 and was about to open a PR to fix https://github.com/jbloomAus/SAELens/issues/305. Do you want me to open a PR so you can review it (to see if we can solve both issues at once, or at least not make it difficult to fix this issue), or just discuss here?

hijohnnylin commented 1 month ago

@anthonyduong9

This looks great! Thank you!

I think it's highly likely that a slight modification to your change will satisfy both issues. I've skimmed your changes (not a full review, since I'm not aware of all the edge cases in play here). Opening a PR would be great; in the meantime, I'll leave my suggestion here so it isn't blocked on creating a PR.

Feedback below is based on my needs - it's possible that the other PR has different requirements! Also, I defer to @chanind and @jbloomAus for final review as I'm not super familiar with conventions in this repo.

I hope that's helpful, and apologies for missing your branch when I created this issue. Thank you!

anthonyduong9 commented 1 month ago

@hijohnnylin thanks! I looked at what I'd need to do to implement your proposal, and it looks like I'd need to make additional changes about as large as my original changes. Basically, I'd need to change the interface

https://github.com/jbloomAus/SAELens/blob/7a1315b882af8f1822aa0d014725900e1f617bc1/sae_lens/toolkit/pretrained_sae_loaders.py#L12-L22

and in all of the implementations, use release and sae_id to perform lookups on an instance of

https://github.com/jbloomAus/SAELens/blob/7a1315b882af8f1822aa0d014725900e1f617bc1/sae_lens/toolkit/pretrained_saes_directory.py#L9-L19

which we can obtain by calling get_pretrained_saes_directory().
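The interface change described above could look roughly like this sketch: loaders receive (release, sae_id) instead of (repo_id, folder_name) and each implementation resolves the folder itself. ConfigOnlyLoader, directory_stub and connor_rob_hook_z_config are hypothetical names; directory_stub only stands in for get_pretrained_saes_directory() and returns made-up data, and the real SAELens signatures may differ.

```python
from typing import Any, Protocol


class ConfigOnlyLoader(Protocol):
    """Hypothetical interface: loaders take (release, sae_id) rather than (repo_id, folder_name)."""

    def __call__(self, release: str, sae_id: str, force_download: bool = False) -> dict[str, Any]:
        ...


def directory_stub() -> dict[str, dict[str, Any]]:
    # Hypothetical stand-in for get_pretrained_saes_directory(); the data is invented.
    return {
        "example-release": {
            "repo_id": "example-org/example-saes",
            "saes_map": {"blocks.0.hook_resid_pre": "layer_0"},
        }
    }


def connor_rob_hook_z_config(
    release: str, sae_id: str, force_download: bool = False
) -> dict[str, Any]:
    # Each implementation looks up repo_id and folder itself, so callers never see them.
    entry = directory_stub()[release]
    folder_name = entry["saes_map"][sae_id]
    # Placeholder: a real loader would fetch just the config for this format.
    return {"repo_id": entry["repo_id"], "folder": folder_name}


# A static type checker accepts this assignment because the signatures match the Protocol.
loader: ConfigOnlyLoader = connor_rob_hook_z_config
```

Because every loader shares the same (release, sae_id) signature, a single get_sae_config can dispatch to any of them without knowing repo-specific details.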

@chanind the proposal seems a little different from what you described in https://github.com/jbloomAus/SAELens/issues/305#issuecomment-2384123586, but I think we should implement it, as it abstracts as much detail away from the user as possible and also simplifies the code. Please let me know if this sounds good to you, and if so, whether you'd prefer I open one PR per issue or one PR to fix both issues.

chanind commented 1 month ago

Seems reasonable! We should be able to look up the full SAE info from a release and a repo_id anyway, which can give us the folder name. IMO these issues can both be fixed in the same PR, since the same change seems to fix both.