HP2706 / Auto_HookPoint


SAELens / Transformers adapter #1

Closed joelburget closed 2 months ago

joelburget commented 2 months ago

I was able to train an SAE on a transformers model using this library with the following code:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from Auto_HookPoint import auto_hook
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
from transformer_lens import ActivationCache
from transformer_lens.hook_points import HookedRootModule
from typing import Union, List, Optional, Literal
import transformer_lens.utils as tl_utils
from dataclasses import dataclass

@dataclass
class Cfg:
    # Minimal stand-in for TransformerLens's HookedTransformerConfig; only a device is provided
    device: str

class HookedTransformerAdapter(HookedRootModule):
    def __init__(self, model_name, n_ctx=8192):
        super().__init__()
        self.cfg = Cfg(device='cpu')
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
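        # Many causal-LM tokenizers ship without a pad token; reuse EOS so padding=True works
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.n_ctx = n_ctx
        # Token embedding matrix; this attribute path fits Llama/Mistral/Mixtral-style models
        # (self.model.get_input_embeddings().weight is the architecture-agnostic equivalent)
        self.W_E = self.model.model.embed_tokens.weight
        # HookedRootModule.setup() collects the model's HookPoint modules into hook_dict
        self.setup()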

    def to_tokens(
        self,
        input: Union[str, List[str]],
        prepend_bos: Optional[Union[bool, None]] = True,
        padding_side: Optional[Union[Literal["left", "right"], None]] = None,
        move_to_device: bool = True,
        truncate: bool = True,
    ):
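        # prepend_bos, padding_side and move_to_device are accepted only for signature
        # compatibility with HookedTransformer.to_tokens; they are ignored here
        return self.tokenizer(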
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.n_ctx if truncate else None,
        )["input_ids"]

    def run_with_cache(self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs):
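        # **kwargs are deliberately not forwarded: arguments meant for HookedTransformer
        # would otherwise reach the Hugging Face forward(), which may not accept them
        out, cache_dict = super().run_with_cache(
            *model_args, remove_batch_dim=remove_batch_dim  # , **kwargs
        )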
        if return_cache_object:
            cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
            return out, cache
        else:
            return out, cache_dict

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)

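# model_name (a Hugging Face model id) and cfg (a LanguageModelSAERunnerConfig) are defined elsewhere
hooked_model = HookedTransformerAdapter(model_name)
sparse_autoencoder = SAETrainingRunner(cfg, override_model=hooked_model).run()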
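
For reference, the `to_tokens` call above asks the tokenizer to truncate each sequence to `n_ctx` tokens and to pad the batch to its longest sequence with the EOS id (because `pad_token` is set to `eos_token`). The pure-Python sketch below approximates that on already-tokenized ids to make the returned shape explicit. `pad_batch` is a hypothetical helper, not part of any library; which end gets cut or padded is governed by the tokenizer's `truncation_side` and `padding_side` settings, so the side is left as a parameter here, and a real tokenizer may also add special tokens such as BOS.

```python
from typing import List, Optional


def pad_batch(
    batch: List[List[int]],
    pad_id: int,
    max_length: Optional[int] = None,
    side: str = "right",
) -> List[List[int]]:
    """Truncate each sequence to max_length (keeping its start), then pad every
    sequence with pad_id to the length of the longest remaining one."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    truncated = [seq[:max_length] for seq in batch]  # seq[:None] keeps everything
    width = max((len(seq) for seq in truncated), default=0)
    padded = []
    for seq in truncated:
        fill = [pad_id] * (width - len(seq))
        padded.append(seq + fill if side == "right" else fill + seq)
    return padded
```

For example, `pad_batch([[1, 2, 3], [4]], pad_id=0, max_length=2)` returns `[[1, 2], [4, 0]]`: one row per input, all of equal width, which is the rectangular `input_ids` tensor SAELens receives.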

I wanted to open this ticket (a) to give people a pointer if they're trying to do the same thing, (b) to suggest adding something like this to the docs, and (c) to discuss whether there's a better way (this was largely hacked together by running it, seeing what failed, and iterating). Maybe @jbloomaus will also have suggestions.
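
On point (c), one possible refinement (a sketch, not something SAELens or TransformerLens provides) is to forward only the keyword arguments that `HookedRootModule.run_with_cache` itself declares as keyword-only, such as `names_filter`, instead of dropping `**kwargs` entirely, so that nothing else reaches the Hugging Face `forward()`. The helper name `split_keyword_only` is hypothetical, and the approach assumes those options are keyword-only parameters in your installed TransformerLens version.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def split_keyword_only(
    fn: Callable[..., Any], kwargs: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split kwargs into (accepted, rest): accepted holds the entries whose names
    match keyword-only parameters of fn; rest holds everything else."""
    params = inspect.signature(fn).parameters
    allowed = {
        name for name, p in params.items()
        if p.kind is inspect.Parameter.KEYWORD_ONLY
    }
    accepted = {k: v for k, v in kwargs.items() if k in allowed}
    rest = {k: v for k, v in kwargs.items() if k not in allowed}
    return accepted, rest
```

Inside the adapter this would look something like `accepted, _ = split_keyword_only(HookedRootModule.run_with_cache, kwargs)` followed by `super().run_with_cache(*model_args, remove_batch_dim=remove_batch_dim, **accepted)`. If SAELens passes a `names_filter`, forwarding it would also let the cache hold only the requested hooks rather than every hook point.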

HP2706 commented 2 months ago

Hi, I'm not sure where exactly you are using auto_hook in the example, but I assume you mean something like this:

class HookedTransformerAdapter(HookedRootModule):
    def __init__(self, model_name, n_ctx=8192):
        super().__init__()
        self.cfg = Cfg(device='cpu')
        self.model = auto_hook(AutoModelForCausalLM.from_pretrained(model_name))
        .....

Is this correct?

HP2706 commented 2 months ago

As for whether there is an easier way, I don't see how this could be done differently. You are right that SaeTrainer assumes that hooked_model has the attribute W_E and the methods run_with_cache and to_tokens. I think this is a useful example, and I will add it to the README and the examples.
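
Since the adapter was built by running it and fixing whatever failed, a quick pre-flight check of the members mentioned above (plus the `cfg.device` attribute the adapter also sets) can surface a missing one before a long training run starts. This is a minimal sketch: `REQUIRED_MEMBERS` and `missing_members` are hypothetical names, and the member list comes only from this thread, not from an authoritative SAELens interface.

```python
from typing import Any, List

# Members that, per this discussion, SAELens touches on the model passed as
# override_model. This is an assumption drawn from the thread, not an exhaustive contract.
REQUIRED_MEMBERS = ("cfg", "W_E", "to_tokens", "run_with_cache")


def missing_members(model: Any) -> List[str]:
    """Return the names of expected members that `model` lacks."""
    missing = [name for name in REQUIRED_MEMBERS if not hasattr(model, name)]
    # The adapter's Cfg dataclass exists to provide a device, so check it too.
    if hasattr(model, "cfg") and not hasattr(model.cfg, "device"):
        missing.append("cfg.device")
    return missing
```

Calling `missing_members(hooked_model)` right before constructing `SAETrainingRunner` and asserting that the result is empty turns a mid-run `AttributeError` into an immediate, named failure.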

joelburget commented 2 months ago

That's exactly right. I ended up with:

    def __init__(self, model_name, n_ctx=8192):
        super().__init__()
        self.cfg = Cfg(device='cpu')
        self.model = auto_hook(AutoModelForCausalLM.from_pretrained(model_name))
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.n_ctx = n_ctx
        # self.model._module: MixtralForCausalLM
        # self.model._module.model._module: MixtralModel
        self.W_E = self.model._module.model._module.embed_tokens._module.weight
        self.setup()

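
The `._module` hops in that `W_E` line follow the wrapper objects `auto_hook` appears to place around each submodule; judging from the comments in the snippet, each wrapper keeps the original module in `_module`. A small helper can hide that chain. `unwrap` and `resolve` are hypothetical illustrations, not Auto_HookPoint API, and the sketch assumes `_module` is the only wrapping convention involved.

```python
from typing import Any


def unwrap(obj: Any) -> Any:
    """Follow `_module` attributes until reaching an object without one."""
    while hasattr(obj, "_module") and obj._module is not obj:
        obj = obj._module
    return obj


def resolve(root: Any, path: str) -> Any:
    """Resolve a dotted path such as 'model.embed_tokens.weight',
    unwrapping any `_module` wrapper before and after each attribute lookup."""
    obj = unwrap(root)
    for name in path.split("."):
        obj = unwrap(getattr(obj, name))
    return obj
```

With this, `self.W_E = resolve(self.model, "model.embed_tokens.weight")` would reach the same tensor whether or not the model was passed through `auto_hook`, as long as the wrappers follow the `_module` convention.
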
joelburget commented 2 months ago

I saw that you added https://github.com/HP2706/Auto_HookPoint/blob/main/examples/sae_lens.py, thanks! Were you able to run it successfully? I hit a snag when I tried: https://gist.github.com/joelburget/bb5118d828713810df27c615ae1724c6.

HP2706 commented 2 months ago

I think I got it to work; I haven't seen that error before. When I tried it, I used a dummy version of Mixtral to avoid testing on a GPU. I will try running the full example and get back to you.

HP2706 commented 2 months ago

Hi, I can now confirm that the sae_lens example works; check the examples folder.

joelburget commented 2 months ago

Can confirm it seems to be working!