jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

[Bug Report] Unable to train Mamba SAE #311

Open joelburget opened 1 week ago

joelburget commented 1 week ago

Describe the bug

Error running with Mamba: 'HookedMamba' object has no attribute 'W_E'.

Code example

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

cfg = LanguageModelSAERunnerConfig(
    model_name="state-spaces/mamba-2.8b",
    model_class_name="HookedMamba",
    ...
)
sae = SAETrainingRunner(cfg).run()

Full code: https://github.com/joelburget/mamba-sae/blob/2f87fb99660516c47121aa7a0f65d8944c42778b/hyperparam_sweep.py

Traceback (most recent call last):
  File "/workspace/mamba-sae/hyperparam_sweep.py", line 70, in <module>
    sae = SAETrainingRunner(cfg).run()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 82, in __init__
    self._init_sae_group_b_decs()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 170, in _init_sae_group_b_decs
    layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 376, in storage_buffer
    self._storage_buffer = self.get_buffer(self.half_buffer_size)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 545, in get_buffer
    refill_batch_tokens = self.get_batch_tokens(
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 410, in get_batch_tokens
    return torch.stack(sequences, dim=0).to(self.model.W_E.device)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1688, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'HookedMamba' object has no attribute 'W_E'. Did you mean: 'W_K'?
Full output:

```
root@39393c1a11a5:/workspace/mamba-sae# wandb agent PEAR-ML/mamba-sae-sweep/i0mhj40q
wandb: Starting wandb agent 🕵️
2024-10-01 15:28:00,877 - wandb.wandb_agent - INFO - Running runs: []
2024-10-01 15:28:01,254 - wandb.wandb_agent - INFO - Agent received command: run
2024-10-01 15:28:01,254 - wandb.wandb_agent - INFO - Agent starting run with config:
    learning_rate: 0.0012748220954754614
    sparsity_penalty: 0.0922662545654348
2024-10-01 15:28:01,259 - wandb.wandb_agent - INFO - About to run command: /usr/bin/env python hyperparam_sweep.py --learning_rate=0.0012748220954754614 --sparsity_penalty=0.0922662545654348
Resolving data files:   0%|          | 0/37 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/workspace/mamba-sae/hyperparam_sweep.py", line 70, in <module>
    sae = SAETrainingRunner(cfg).run()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 82, in __init__
    self._init_sae_group_b_decs()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 170, in _init_sae_group_b_decs
    layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 376, in storage_buffer
    self._storage_buffer = self.get_buffer(self.half_buffer_size)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 545, in get_buffer
    refill_batch_tokens = self.get_batch_tokens(
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 410, in get_batch_tokens
    return torch.stack(sequences, dim=0).to(self.model.W_E.device)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1688, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'HookedMamba' object has no attribute 'W_E'. Did you mean: 'W_K'?
```
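For what it's worth, the immediate failure is the `.to(self.model.W_E.device)` call in `ActivationsStore.get_batch_tokens` (see the traceback above), which assumes a HookedTransformer-style `W_E` embedding matrix that `HookedMamba` doesn't expose. A minimal sketch of a more model-agnostic device lookup is below; the helper name and structure are mine, not existing SAELens API:

```
import torch


def resolve_model_device(model: torch.nn.Module) -> torch.device:
    """Hypothetical helper: pick a device for token batches without
    assuming a HookedTransformer-style `W_E` attribute."""
    # HookedTransformer exposes its embedding matrix as W_E.
    if hasattr(model, "W_E"):
        return model.W_E.device
    # Many hooked models carry a cfg with a device field.
    cfg = getattr(model, "cfg", None)
    if cfg is not None and getattr(cfg, "device", None) is not None:
        return torch.device(cfg.device)
    # Generic fallback: wherever the first parameter lives.
    return next(model.parameters()).device
```

Falling back to `next(model.parameters()).device` alone would also work for any `nn.Module`, though it ignores multi-device placements.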

System Info

root@39393c1a11a5:/workspace/mamba-sae# cat requirements.txt
torch>=2.2.0
transformers[sentencepiece]>=4.39.2
accelerate>=0.27.2
datasets>=2.15.0
wandb
sae-lens[mamba]
root@39393c1a11a5:/workspace/mamba-sae# uname -a
Linux 39393c1a11a5 5.15.0-89-generic #99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux
root@39393c1a11a5:/workspace/mamba-sae# python --version
Python 3.10.12

joelburget commented 1 week ago

I was also able to repro with tutorials/mamba_train_example.py after a fresh clone:

Fixing a couple other errors first:

Error 1

```
root@39393c1a11a5:/workspace/SAELens# python3 tutorials/mamba_train_example.py
Run name: 131072-L1-1.2e-05-LR-0.0004-Tokens-3.000e+08
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 73242
Total wandb updates: 732
n_tokens_per_feature_sampling_window (millions): 524.288
n_tokens_per_dead_feature_window (millions): 2621.44
We will reset the sparsity calculation 73 times.
Number tokens in sparsity calculation window: 4.10e+06
Using Ghost Grads.
Traceback (most recent call last):
  File "/workspace/SAELens/tutorials/mamba_train_example.py", line 57, in <module>
    SAETrainingRunner(cfg).run()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 57, in __init__
    self.model = load_model(
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/load_model.py", line 39, in load_model
    HookedMamba.from_pretrained(
TypeError: HookedMamba.from_pretrained() got an unexpected keyword argument 'center_writing_weights'
```

Diff

```
diff --git a/tutorials/mamba_train_example.py b/tutorials/mamba_train_example.py
index d76a673..a055581 100644
--- a/tutorials/mamba_train_example.py
+++ b/tutorials/mamba_train_example.py
@@ -52,6 +52,7 @@ cfg = LanguageModelSAERunnerConfig(
         "fast_ssm": True,
         "fast_conv": True,
     },
+    model_from_pretrained_kwargs={}
 )
```

Error 2

```
Traceback (most recent call last):
  File "/workspace/SAELens/tutorials/mamba_train_example.py", line 58, in <module>
    SAETrainingRunner(cfg).run()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 66, in __init__
    self.activations_store = ActivationsStore.from_config(
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 69, in from_config
    return cls(
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 218, in __init__
    raise ValueError(
ValueError: pretokenized dataset has context_size 1024, but the provided context_size is 128.
```

Diff

```
--- a/tutorials/mamba_train_example.py
+++ b/tutorials/mamba_train_example.py
@@ -27,7 +27,7 @@ cfg = LanguageModelSAERunnerConfig(
     l1_coefficient=0.00006 * 0.2,
     lr_scheduler_name="cosineannealingwarmrestarts",
     train_batch_size_tokens=4096,
-    context_size=128,
+    context_size=1024,
     lr_warm_up_steps=5000,
     # Activation Store Parameters
     n_batches_in_buffer=128,
@@ -52,6 +52,7 @@ cfg = LanguageModelSAERunnerConfig(
         "fast_ssm": True,
         "fast_conv": True,
     },
+    model_from_pretrained_kwargs={}
 )
 
 SAETrainingRunner(cfg).run()
```
root@39393c1a11a5:/workspace/SAELens# python3 tutorials/mamba_train_example.py
Run name: 131072-L1-1.2e-05-LR-0.0004-Tokens-3.000e+08
n_tokens_per_buffer (millions): 4.194304
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 73242
Total wandb updates: 732
n_tokens_per_feature_sampling_window (millions): 4194.304
n_tokens_per_dead_feature_window (millions): 20971.52
We will reset the sparsity calculation 73 times.
Number tokens in sparsity calculation window: 4.10e+06
Using Ghost Grads.
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(
Moving model to device:  cuda
Resolving data files: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 37/37 [00:00<00:00, 254.39it/s]
Traceback (most recent call last):
  File "/workspace/SAELens/tutorials/mamba_train_example.py", line 58, in <module>
    SAETrainingRunner(cfg).run()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 82, in __init__
    self._init_sae_group_b_decs()
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/sae_training_runner.py", line 170, in _init_sae_group_b_decs
    layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 376, in storage_buffer
    self._storage_buffer = self.get_buffer(self.half_buffer_size)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 545, in get_buffer
    refill_batch_tokens = self.get_batch_tokens(
  File "/usr/local/lib/python3.10/dist-packages/sae_lens/training/activations_store.py", line 410, in get_batch_tokens
    return torch.stack(sequences, dim=0).to(self.model.W_E.device)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1688, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'HookedMamba' object has no attribute 'W_E'. Did you mean: 'W_K'?
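Until there's a proper fix, one possible stopgap might be to alias a `W_E`-like attribute onto `HookedMamba` before constructing `SAETrainingRunner(cfg)`, so the `.device` lookup in `get_batch_tokens` succeeds. This is an untested sketch; `embedding` is my guess at the name of mamba_lens's token-embedding submodule and may need adjusting:

```
from mamba_lens import HookedMamba  # installed via sae-lens[mamba]

# Untested workaround sketch: expose a W_E alias so that
# ActivationsStore.get_batch_tokens can read `.device` from it.
# `self.embedding` is an assumption about mamba_lens internals;
# change the attribute name if it differs.
if not hasattr(HookedMamba, "W_E"):
    HookedMamba.W_E = property(lambda self: self.embedding.weight)
```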
jbloomAus commented 1 week ago

Thanks for flagging this. This might be an issue with MambaLens, or it may be that we weren't testing something we needed to when making recent changes. Will follow up shortly.
