jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

[Bug Report] SAE training tutorial metrics do not match linked run #276

Open naterush opened 1 month ago

naterush commented 1 month ago

Describe the bug

Hey. I'm working through the training tutorial, and without any changes I'm unable to train a basic SAE with loss numbers as good as those in the linked run. I'm not sure if this is numerical instability, something that changed in the library, or whether my differences are actually inconsequential, so I'm opening this issue to get to the bottom of it!

My steps:

  1. Opened the notebook in Google Colab (see notebook here)
  2. Selected an A100 GPU (for fast execution)
  3. Executed the notebook all the way through

Differences between my training run and yours

  1. My overall loss is 360, yours is 133
  2. My L0 is 160ish, yours is 80ish

There are a lot more differences, but I'm wondering if you have thoughts on why this is. I'm new to SAE work generally, so any tips here would be appreciated.
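For anyone comparing their own runs against the tutorial's, here is a toy numpy sketch (not SAELens code; the shapes and values are made up) of what the two metrics above measure: L0 is the average number of nonzero feature activations per token, and MSE is the mean squared reconstruction error.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy SAE feature activations: 4 tokens, 10 features, mostly zeroed out.
feature_acts = rng.random((4, 10))
feature_acts[feature_acts < 0.7] = 0.0

# L0: average count of active (nonzero) features per token.
l0 = np.count_nonzero(feature_acts, axis=1).mean()

# MSE: mean squared error between activations and their reconstruction.
x = rng.random((4, 8))
x_hat = x + 0.1          # stand-in for an imperfect SAE reconstruction
mse = np.mean((x - x_hat) ** 2)

print(f"L0 = {l0:.1f}, MSE = {mse:.4f}")   # MSE is exactly 0.01 here
```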

Code example

See notebook here

System Info

  1. Google Colab Pro+
  2. A100 GPU, took about 1 hour to train

niniack commented 2 weeks ago

+1, I've been experimenting with the library to reproduce the results from the wandb tutorial run, as well as these runs https://wandb.ai/jbloom/mats_sae_training_gpt2_small_resid_pre_5?nw=nwuserjbloom, but have not had success with either.

I have replicated the hyperparameters that were set in the gpt2 runs (linked above) to no avail. I suspect that later versions of the library introduced changes which need different hyperparameters, but I don't have a good theory.
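One low-effort way to rule version drift in or out is to log the exact installed sae-lens release next to each run and compare it against whichever release produced the reference run. A minimal stdlib-only sketch (`sae-lens` is the PyPI package name; nothing else here is library-specific):

```python
from importlib.metadata import PackageNotFoundError, version

# Print the installed sae-lens release so runs can be compared like-for-like.
try:
    print("sae-lens version:", version("sae-lens"))
except PackageNotFoundError:
    print("sae-lens is not installed in this environment")
```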

Side note: @naterush your wandb run is private, other users cannot see the results!

Numeri commented 2 weeks ago

I've also run it several times and not managed to get anything with good loss curves – it plateaus very quickly around MSE loss of 200 and L1 loss of 165.
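For context on how those two numbers combine: the usual SAE training objective is reconstruction MSE plus an L1 sparsity penalty on the feature activations, which is consistent with the reported overall loss being roughly the sum of the two. A rough sketch (the coefficient is illustrative, not the tutorial's setting):

```python
import numpy as np

def sae_loss(x, x_hat, feature_acts, l1_coefficient=1.0):
    """Toy version of the usual SAE objective: MSE + l1_coefficient * L1.

    The coefficient value is illustrative, not the tutorial's setting.
    """
    mse = float(np.mean((x - x_hat) ** 2))
    l1 = float(np.abs(feature_acts).sum(axis=-1).mean())
    return mse + l1_coefficient * l1, mse, l1

# Tiny worked example: perfect reconstruction, uniformly active features.
x = np.zeros((2, 4))
total, mse, l1 = sae_loss(x, x, np.ones((2, 3)))
print(total, mse, l1)   # → 3.0 0.0 3.0
```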

jbloomAus commented 2 weeks ago

Odd, I'll take a look.
