Closed by joelburget.
Hi, I am not sure exactly where you are using auto_hook in the example, but I assume you are thinking of something like this:
```python
class HookedTransformerAdapter(HookedRootModule):
    def __init__(self, model_name, n_ctx=8192):
        super().__init__()
        self.cfg = Cfg(device='cpu')
        self.model = auto_hook(AutoModelForCausalLM.from_pretrained(model_name))
        ...
```
Is this correct?
As for whether there is an easier way, I don't see how this could be done differently. You are right that SaeTrainer assumes that hooked_model has a W_E attribute and the methods run_with_cache and to_tokens. I think this is a useful example, and I will add it to the README and examples.
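To make that contract explicit, here is a minimal sketch that checks an adapter for the three members SaeTrainer is described as relying on, before any training starts. The protocol `HookedModelLike`, the helper `missing_members`, and the toy adapters are hypothetical illustrations, not part of SaeTrainer or Auto_HookPoint, and the check only confirms that the members exist, not that their signatures or behaviour match what SaeTrainer expects.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HookedModelLike(Protocol):
    """Hypothetical protocol: the members SaeTrainer is described as needing."""

    W_E: Any  # token embedding matrix

    def to_tokens(self, text: Any) -> Any: ...

    def run_with_cache(self, tokens: Any, *args: Any, **kwargs: Any) -> Any: ...


REQUIRED_MEMBERS = ("W_E", "to_tokens", "run_with_cache")


def missing_members(model: object) -> list[str]:
    """Return the required member names that `model` does not provide."""
    return [name for name in REQUIRED_MEMBERS if not hasattr(model, name)]


class PartialAdapter:
    """Toy adapter that only exposes an embedding matrix so far."""

    def __init__(self) -> None:
        self.W_E = [[0.0, 1.0], [1.0, 0.0]]


class FullAdapter(PartialAdapter):
    """Toy adapter that also stubs out the two required methods."""

    def to_tokens(self, text: str) -> list[int]:
        return [ord(c) for c in text]

    def run_with_cache(self, tokens, *args, **kwargs):
        return None, {}  # (model output, activation cache)


print(missing_members(PartialAdapter()))  # ['to_tokens', 'run_with_cache']
print(isinstance(FullAdapter(), HookedModelLike))  # True
```

Running a check like this up front reports every missing member at once, instead of surfacing them one failure at a time.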
That's exactly right. I ended up with:
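```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens.hook_points import HookedRootModule
# auto_hook (from Auto_HookPoint) and Cfg are imported/defined elsewhere (not shown)

class HookedTransformerAdapter(HookedRootModule):
    def __init__(self, model_name, n_ctx=8192):
        super().__init__()
        self.cfg = Cfg(device='cpu')
        self.model = auto_hook(AutoModelForCausalLM.from_pretrained(model_name))
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # The tokenizer typically has no pad token, so reuse EOS for padding
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.n_ctx = n_ctx
        # self.model._module: MixtralForCausalLM
        # self.model._module.model._module: MixtralModel
        self.W_E = self.model._module.model._module.embed_tokens._module.weight
        self.setup()  # register the hook points with HookedRootModule
```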
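The comments in that snippet show that each auto_hook wrapper keeps the original module under `_module`, so reaching the embedding weights means unwrapping at every level. A minimal sketch of a helper that does this generically, assuming that `_module` convention holds at every level; `unwrap` and `get_unwrapped` are hypothetical names, not Auto_HookPoint API, and the toy classes only stand in for the wrapped Mixtral modules.

```python
from typing import Any


def unwrap(obj: Any) -> Any:
    """Follow `._module` links until reaching an object without one."""
    while hasattr(obj, "_module"):
        obj = obj._module
    return obj


def get_unwrapped(root: Any, dotted_path: str) -> Any:
    """Resolve a path like "model.embed_tokens.weight", unwrapping at each step."""
    obj = unwrap(root)
    for name in dotted_path.split("."):
        obj = unwrap(getattr(obj, name))
    return obj


# Toy stand-ins mirroring the nesting in the comments above (hypothetical).
class Wrapper:
    def __init__(self, module: Any) -> None:
        self._module = module


class Embedding:
    def __init__(self) -> None:
        self.weight = "embedding-weight"


class Inner:  # plays the role of MixtralModel
    def __init__(self) -> None:
        self.embed_tokens = Wrapper(Embedding())


class CausalLM:  # plays the role of MixtralForCausalLM
    def __init__(self) -> None:
        self.model = Wrapper(Inner())


hooked = Wrapper(CausalLM())
print(get_unwrapped(hooked, "model.embed_tokens.weight"))  # embedding-weight
```

With such a helper, the W_E line could read `self.W_E = get_unwrapped(self.model, "model.embed_tokens.weight")`, which may also carry over to other architectures that keep their embeddings at that path.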
I saw that you added https://github.com/HP2706/Auto_HookPoint/blob/main/examples/sae_lens.py, thanks! Were you able to run it successfully? I hit a snag when I tried: https://gist.github.com/joelburget/bb5118d828713810df27c615ae1724c6.
I think I got it to work; I haven't seen that error before. When I tried it, I used a dummy version of Mixtral to avoid testing on a GPU. I will try running the full example and get back to you.
Hi, I can now confirm that the sae_lens example works; check the examples folder.
Can confirm it seems to be working!
I was able to train an SAE on a `transformers` model using this library with the following code:

I wanted to open this ticket to (a) give people a pointer if they're trying to do the same thing, (b) suggest adding something like this to the docs, and (c) discuss whether there's a better way (this was largely hacked together by running it, seeing what failed, and iterating). Maybe @jbloomaus will also have suggestions.