jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

KeyError in `geometric_medians` in training #23

Closed — chanind closed this issue 7 months ago

chanind commented 7 months ago

Running TinyStories training raises an error:

    "mats_sae_training/sae_training/train_sae_on_language_model.py", line 90, in train_sae_on_language_model
        geometric_medians[sae_layer_id].append(median)
        ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
    KeyError: 0

As a workaround, this can currently be avoided by setting `b_dec_init_method="mean"`.
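For context, the error is the standard plain-dict pitfall: calling `.append` on a key that was never initialized. A minimal standalone sketch (not the repo's code; the layer id and value are made up) showing the failure and how `collections.defaultdict` sidesteps it:

```python
from collections import defaultdict

# Plain dict: appending to a missing key raises KeyError, as in the traceback.
geometric_medians = {}
try:
    geometric_medians[0].append("median")
except KeyError:
    print("KeyError: 0")

# defaultdict(list) creates an empty list on first access, so append succeeds.
geometric_medians = defaultdict(list)
geometric_medians[0].append("median")
print(geometric_medians[0])  # → ['median']
```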

Code to reproduce:

    import torch

    # Imports assumed from the mats_sae_training repo layout shown in the traceback.
    from sae_training.config import LanguageModelSAERunnerConfig
    from sae_training.lm_runner import language_model_sae_runner

    device = "cuda" if torch.cuda.is_available() else "cpu"

    cfg = LanguageModelSAERunnerConfig(
        # Data Generating Function (Model + Training Distribution)
        model_name="tiny-stories-2L-33M",
        hook_point="blocks.1.mlp.hook_post",
        hook_point_layer=1,
        d_in=4096,
        dataset_path="roneneldan/TinyStories",
        is_dataset_tokenized=False,
        # SAE Parameters
        expansion_factor=4,
        # Training Parameters
        lr=1e-4,
        l1_coefficient=3e-4,
        train_batch_size=4096,
        context_size=128,
        # Activation Store Parameters
        n_batches_in_buffer=128,
        total_training_tokens=1_000_000 * 10,  # want 500M eventually.
        store_batch_size=32,
        # Resampling protocol
        feature_sampling_window=2500,  # Doesn't currently matter.
        dead_feature_window=1250,
        dead_feature_threshold=0.0005,
        # Misc
        device=device,
        seed=42,
        n_checkpoints=0,
        checkpoint_path="checkpoints",
        dtype=torch.float32,
        # Wandb
        log_to_wandb=True,
        wandb_project="mats_sae_training_language_benchmark_tests",
        wandb_entity=None,
        wandb_log_frequency=10,
    )

    sparse_autoencoder = language_model_sae_runner(cfg)
themachinefan commented 7 months ago

Changing it to `geometric_medians[sae_layer_id] = median` should work, I think.
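If the dict is meant to hold a single median per layer, plain assignment avoids the KeyError entirely; if append semantics are needed, `dict.setdefault` is the one-line guard. A toy illustration (not the repo's code; `sae_layer_id` and `median` values are hypothetical):

```python
geometric_medians = {}
sae_layer_id, median = 0, 1.234  # hypothetical values

# Assignment creates the key on first use; no prior initialization needed.
geometric_medians[sae_layer_id] = median
print(geometric_medians[sae_layer_id])  # → 1.234

# Equivalent guard if a list of medians per layer must be kept:
medians_by_layer = {}
medians_by_layer.setdefault(sae_layer_id, []).append(median)
print(medians_by_layer)  # → {0: [1.234]}
```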

jbloomAus commented 7 months ago

Fixed. Evidence that we need to test more thoroughly.