jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

[Bug Report] Unsupported scheduler error when training SAE #105

Closed · dtch1997 closed this issue 5 months ago

dtch1997 commented 5 months ago

If you are submitting a bug report, please fill in the following details and use the tag [bug].

Describe the bug
Example SAE training code encounters an error in sae-lens==0.7.0.

Code example

import os

import torch

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_lens import LanguageModelSAERunnerConfig, language_model_sae_runner

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="gpt2-small",
    hook_point="blocks.2.hook_resid_pre",
    hook_point_layer=2,
    d_in=768,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    # SAE Parameters
    expansion_factor=64,
    b_dec_init_method="geometric_median",
    # Training Parameters
    lr=0.0004,
    l1_coefficient=0.00008,
    lr_scheduler_name="constantwithwarmup",
    train_batch_size=4096,
    context_size=128,
    lr_warm_up_steps=5000,
    # Activation Store Parameters
    n_batches_in_buffer=128,
    training_tokens=1_000_000 * 300,
    store_batch_size=32,
    # Dead Neurons and Sparsity
    use_ghost_grads=True,
    feature_sampling_window=1000,
    dead_feature_window=5000,
    dead_feature_threshold=1e-6,
    # WANDB
    log_to_wandb=True,
    wandb_project="gpt2",
    wandb_entity=None,
    wandb_log_frequency=100,
    # Misc
    device="cuda",
    seed=42,
    n_checkpoints=10,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

sparse_autoencoder = language_model_sae_runner(cfg)

Error message:

Traceback (most recent call last):
  File "/home/daniel/ml_workspace/SAELens/scripts/train_gated_sae.py", line 51, in <module>
    sparse_autoencoder = language_model_sae_runner(cfg)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/lm_runner.py", line 34, in language_model_sae_runner
    sparse_autoencoder = train_sae_on_language_model(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/train_sae_on_language_model.py", line 92, in train_sae_on_language_model
    return train_sae_group_on_language_model(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/train_sae_on_language_model.py", line 132, in train_sae_group_on_language_model
    train_contexts = {
                     ^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/train_sae_on_language_model.py", line 133, in <dictcomp>
    name: _build_train_context(sae, total_training_steps)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/train_sae_on_language_model.py", line 302, in _build_train_context
    scheduler = get_scheduler(
                ^^^^^^^^^^^^^^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/optim.py", line 38, in get_scheduler
    main_scheduler = _get_main_scheduler(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/optim.py", line 98, in _get_main_scheduler
    raise ValueError(f"Unsupported scheduler: {scheduler_name}")
ValueError: Unsupported scheduler: constantwithwarmup

themachinefan commented 5 months ago

I have the same issue. It seems lr_scheduler_name should be set to "constant" instead; warm-up is added on top of any scheduler as long as lr_warm_up_steps != 0.

(Not really a bug, though: constantwithwarmup is simply no longer supported.)
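
For reference, a minimal sketch of the corrected scheduler settings, keeping all other fields from the config in the report unchanged. (That warm-up applies on top of any scheduler is themachinefan's description of the current behaviour, not something verified against the source here.)

cfg = LanguageModelSAERunnerConfig(
    # ... all other fields as in the original report above ...
    lr=0.0004,
    lr_scheduler_name="constant",  # replaces the removed "constantwithwarmup"
    lr_warm_up_steps=5000,         # warm-up is still applied while this is nonzero
)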

jbloomAus commented 5 months ago

Apologies, the docs got out of date. I think the training tutorial has up-to-date hyperparameter examples. @dtch1997, could you take a moment to make a PR putting those hyperparameters in the docs (and maybe add a note in that tutorial about keeping them in sync with the docs)?