jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

[Bug Report] #337

Open Yoon-Jeong-ho opened 1 month ago

Yoon-Jeong-ho commented 1 month ago

If you are submitting a bug report, please fill in the following details and use the tag [bug].

Describe the bug

I encountered a RuntimeError during training while using sae_lens. The error appears to be a device mismatch between the tensor being indexed and the indices used to index it (CPU vs. CUDA).

Error message

Training SAE:   0%|          | 0/2048000000 [00:00<?, ?it/s]
/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/activations_store.py:283: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  yield torch.tensor(
Estimating norm scaling factor:   0%|          | 0/1000 [00:00<?, ?it/s]
Estimating norm scaling factor: 100%|██████████| 1000/1000 [08:32<00:00, 1.95it/s]
89900| MSE Loss 15.375 | L1 51.068:  18%|          | 368230400/2048000000 [13:27:20<66:45:33, 6989.32it/s]
Traceback (most recent call last):
  File "/data_x/aa007878/SAE/sae_training.py", line 95, in <module>
    sparse_autoencoder = SAETrainingRunner(cfg).run()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 106, in run
    sae = self.run_trainer_with_interruption_handling(trainer)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 150, in run_trainer_with_interruption_handling
    sae = trainer.fit()
          ^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 176, in fit
    self._run_and_log_evals()
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 333, in _run_and_log_evals
    eval_metrics = run_evals(
                   ^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 105, in run_evals
    metrics |= get_sparsity_and_variance_metrics(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 289, in get_sparsity_and_variance_metrics
    flattened_sae_input = flattened_sae_input[flattened_mask]
                          ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:2)
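
For context, a minimal standalone sketch of the failure mode in the last frame (not taken from the SAELens code, and assuming at least two visible GPUs): indexing a tensor on one GPU with a boolean mask that lives on a different GPU, plus the obvious local workaround of moving the mask first.

import torch

x = torch.randn(8, 4, device="cuda:1")       # stands in for flattened_sae_input on one GPU
mask = torch.rand(8, device="cuda:0") > 0.5  # stands in for flattened_mask on a different GPU
# x[mask] raises: RuntimeError: indices should be either on cpu or on the same device
# as the indexed tensor (cuda:1)
y = x[mask.to(x.device)]                     # workaround: move the mask to the tensor's device first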

Code example

import os
from setproctitle import setproctitle

setproctitle("aa007878")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 4"
gpu_num = 4
import torch

from huggingface_hub import login

# Log in with a HuggingFace API token

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
scale_factor = 1
total_training_steps = int(500000 * scale_factor)  # probably we should do more
batch_size = int(4096 / scale_factor)
total_training_tokens = total_training_steps * batch_size
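# Sanity check of the arithmetic above: 500,000 steps * 4,096 tokens/step
# = 2,048,000,000 training tokens, matching the 0/2048000000 total shown in the
# progress bars in the traceback.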

context_size = 1024
latent_size = 8
layer = 10
l1_coefficient = 0.05

lr_warm_up_steps = 200
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="meta-llama/Llama-3.2-1B",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name=f"blocks.{layer}.hook_resid_pre",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=layer,  # the layer of the hook point above.
    d_in=2048,  # the width of the residual stream at that hook (d_model for Llama-3.2-1B).
    dataset_path=f"yoonLM/llama3.2_org_1b_tokenizingdata_{context_size}", 
    is_dataset_tokenized=True,
    streaming=False,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=latent_size,  # the width of the SAE. Larger will result in better stats but slower training.
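    # Note on the resulting size (derived from the values above): d_in=2048 with expansion_factor=8 gives 2048 * 8 = 16,384 SAE latents.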
    b_dec_init_method="zeros",  # The geometric median can also be used to initialize the decoder bias (b_dec).
    apply_b_dec_to_input=False,  # We won't apply the decoder bias to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-6,  # learning rate; lower is generally more stable but slower to converge.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup; there may be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=l1_coefficient,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=context_size,  # controls the length of the prompts fed to the model; larger is better but slower.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # roughly 2 billion tokens here; plenty for good stats, but slow.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using them.
    dead_feature_threshold=1e-8,  # would affect resampling or ghost grads if we were using them.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project=f"sae_LLaMa3.2_1B_{context_size}_{latent_size}_{l1_coefficient}",
    wandb_log_frequency=300,
    eval_every_n_wandb_logs=300,
    model_from_pretrained_kwargs={"n_devices": gpu_num},
    # Misc
    device=device,
    seed=42,
    n_checkpoints=5,
    checkpoint_path=f"checkpoints_LLama3.2_1B_{context_size}_{latent_size}_{l1_coefficient}",
    dtype="float32"
)
# look at the next cell to see some instruction for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()   

from sae_lens import upload_saes_to_huggingface

layer_sae_path = f"layer_{layer}_sae"
sparse_autoencoder.save_model(layer_sae_path)

saes_dict = {
    f"blocks.{layer}.hook_resid_pre": layer_sae_path, # values can be an SAE object
}

upload_saes_to_huggingface(
    saes_dict,
    # change this to your own huggingface username and repo
    hf_repo_id=f"yoonLM/sae_llama3.2org_1B_{context_size}_{latent_size}_l1_{l1_coefficient}",
)

System Info: Python 3.11.9, CUDA 12.4, GPU NVIDIA RTX A6000, Ubuntu 20.04.1 LTS, PyTorch 2.0.1, torch 2.4.1, sae-lens 3.22.2, transformer-lens 2.7.0

Checklist

chanind commented 1 month ago

I can't reproduce this on my local machine, but I also don't have multiple GPUs. Does this only happen when using multiple GPUs?

Yoon-Jeong-ho commented 1 month ago

Yes. I didn't encounter this error when using just one GPU, but it does occur when using multiple GPUs with a larger context size and latent size (and therefore higher GPU memory usage).
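
A minimal sketch of where the extra devices come from (an illustration only, not taken from SAELens internals; it assumes 4 visible GPUs and transformer_lens installed): with model_from_pretrained_kwargs={"n_devices": 4}, TransformerLens splits the model's layers across GPUs, so activations captured at blocks.10.hook_resid_pre can live on cuda:1 or cuda:2 rather than on the device set in the config, which matches the device named in the indexing error.

from transformer_lens import HookedTransformer

# Load the model split across 4 GPUs, mirroring model_from_pretrained_kwargs above.
model = HookedTransformer.from_pretrained("meta-llama/Llama-3.2-1B", n_devices=4)

# Capture the residual stream at layer 10 and check which device it landed on.
_, cache = model.run_with_cache("hello world", names_filter="blocks.10.hook_resid_pre")
print(cache["blocks.10.hook_resid_pre"].device)  # may print cuda:1 or cuda:2, not cuda:0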

Yoon-Jeong-ho commented 1 month ago

Previously, the same error occurred when training a sparse autoencoder of the same size.

self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
Training SAE:   0%|          | 0/2048000000 [00:00<?, ?it/s]
/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/activations_store.py:283: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  yield torch.tensor(
Estimating norm scaling factor:   0%|          | 0/1000 [00:00<?, ?it/s]
Estimating norm scaling factor: 100%|██████████| 1000/1000 [08:33<00:00, 1.95it/s]
249900| MSE Loss 7.922 | L1 72.222:  50%|          | 1023590400/2048000000 [37:11:49<40:19:52, 7055.51it/s]
Traceback (most recent call last):
  File "/data_x/aa007878/SAE/sae_training.py", line 95, in <module>
    sparse_autoencoder = SAETrainingRunner(cfg).run()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 106, in run
    sae = self.run_trainer_with_interruption_handling(trainer)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 150, in run_trainer_with_interruption_handling
    sae = trainer.fit()
          ^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 176, in fit
    self._run_and_log_evals()
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 333, in _run_and_log_evals
    eval_metrics = run_evals(
                   ^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 105, in run_evals
    metrics |= get_sparsity_and_variance_metrics(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 289, in get_sparsity_and_variance_metrics
    flattened_sae_input = flattened_sae_input[flattened_mask]
                          ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)

Yoon-Jeong-ho commented 1 month ago

I tried to lower the learning rate under the same conditions, but the same error occurred in the same place.

self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
Estimating norm scaling factor: 100%|██████████| 1000/1000 [08:41<00:00, 1.92it/s]
89900| MSE Loss 141.638 | L1 62.357:  18%|          | 368230400/2048000000 [13:34:58<67:16:47, 6935.25it/s]
Traceback (most recent call last):
  File "/data_x/aa007878/SAE/sae_training.py", line 95, in <module>
    sparse_autoencoder = SAETrainingRunner(cfg).run()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 106, in run
    sae = self.run_trainer_with_interruption_handling(trainer)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 150, in run_trainer_with_interruption_handling
    sae = trainer.fit()
          ^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 176, in fit
    self._run_and_log_evals()
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 333, in _run_and_log_evals
    eval_metrics = run_evals(
                   ^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 105, in run_evals
    metrics |= get_sparsity_and_variance_metrics(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 289, in get_sparsity_and_variance_metrics
    flattened_sae_input = flattened_sae_input[flattened_mask]
                          ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:2)

callummcdougall commented 1 month ago

Update here: I believe the root cause of this issue might be last week's PR that added get_sae_config, which introduced a bug where SAEs wouldn't be loaded onto the specified device. I also hit a device-related bug, and making this change has fixed things for me. See the link to my PR.
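
For anyone hitting this before upgrading, a rough self-contained illustration of the failure mode described above (a sketch, not the actual PR diff; it assumes at least two visible GPUs): a module constructed without an explicit device lands on the default device while the activations it processes live on another GPU, and explicitly moving it avoids the mismatch.

import torch
import torch.nn as nn

sae_like = nn.Linear(2048, 16384)             # stands in for an SAE left on the default device
acts = torch.randn(4, 2048, device="cuda:1")  # activations gathered from a model shard on another GPU

sae_like = sae_like.to(acts.device)           # explicit move onto the activations' device
out = sae_like(acts)                          # no device-mismatch error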

chanind commented 1 month ago

@Yoon-Jeong-ho Is this fixed in the most recent version of SAELens (4.0.9)? Thanks for the fix @callummcdougall!
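
A quick, generic way to confirm which version is actually installed before re-running (not SAELens-specific API):

from importlib.metadata import version

# Print the installed SAELens version; the fix discussed above is expected in a 4.x release.
print(version("sae-lens"))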