OpenMOSS / Language-Model-SAEs

For OpenMOSS Mechanistic Interpretability Team's Sparse Autoencoder (SAE) research.

Add finetuning code #2

Closed StarConnor closed 6 months ago

StarConnor commented 6 months ago

No changes to existing code

  1. No line of the original source code has been changed.

Add a file

  2. Add a file called "sae_finetuning.py" in the "core" directory.
    • Copied from sae_training.py with the minor changes listed below.
    • Remove the L1 term from the loss.
    • Remove ghost-related code.
    • Remove the "sparsity/dead_features" metric.
    • Add a norm_ratio calculation as a metric that is logged to wandb on every step.
    • Import the function "norm_ratio" from core/utils/misc.py.

Add code to existing files

In core/config.py

  3. Add a class called "LanguageModelSAEFinetuningConfig" to "core/config.py".
    • Copied from the class LanguageModelSAETrainingConfig with the minor changes listed below.
    • Remove the "dead_feature_window" variable.

In core/optim.py

  4. Add a lambda function called "get_smoothing_lambda" to "core/optim.py".
    • This function smooths the junctions where the linear warmup and linear cooldown phases meet the constant phase.
  5. Add a scheduler option called "constantwithwarmupsmooth" to "core/optim.py".
    • Added as a new branch in the existing if-else chain.
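The body of "get_smoothing_lambda" is not shown in this description, so the following is only a sketch of one way such smoothing could work, not the code in this PR. It takes a soft minimum (log-sum-exp) of the warmup ramp, the constant level 1, and the cooldown ramp, which rounds off the two corners of the piecewise-linear schedule. The function name `make_smooth_lr_lambda` and its parameters (`total_steps`, `warmup_steps`, `cooldown_steps`, `sharpness`) are hypothetical and chosen for illustration.

```python
import math
from typing import Callable


def make_smooth_lr_lambda(
    total_steps: int,
    warmup_steps: int,
    cooldown_steps: int,
    sharpness: float = 10.0,
) -> Callable[[int], float]:
    """Return a step -> LR multiplier function (hypothetical helper, not the PR's code).

    The multiplier is a soft minimum of three curves:
      * a linear warmup ramp rising from 0,
      * the constant 1,
      * a linear cooldown ramp falling to 0 at total_steps.
    A larger `sharpness` gives tighter corners, closer to the plain
    piecewise-linear schedule.
    """

    def lr_lambda(step: int) -> float:
        # Clamp so the exponentials below never receive a large positive argument.
        step = min(max(step, 0), total_steps)
        up = step / max(1, warmup_steps)
        down = (total_steps - step) / max(1, cooldown_steps)
        # Soft minimum: -log(sum(exp(-k * x))) / k.
        total = (
            math.exp(-sharpness * up)
            + math.exp(-sharpness)
            + math.exp(-sharpness * down)
        )
        # The soft minimum dips slightly below 0 at both ends, so clamp it.
        return max(0.0, -math.log(total) / sharpness)

    return lr_lambda


if __name__ == "__main__":
    f = make_smooth_lr_lambda(total_steps=1000, warmup_steps=100, cooldown_steps=100)
    for s in (0, 50, 100, 200, 500, 900, 1000):
        print(s, round(f(s), 4))
```

A function of this shape can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, which multiplies the base learning rate by the returned factor at each step. Note that the multiplier at the junction itself stays somewhat below 1, which is the price of rounding the corner.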

In core/runner.py

  6. Add a runner function called "finetune_runner" to "core/runner.py".
    • Copied from the function language_model_sae_runner in this file with the minor changes listed below.
    • Add a line that freezes the SAE's encoder parameters before finetuning, using a method of the SAE class.
    • Change the training function from "train_sae" to "finetune_sae".
    • Import the function "finetune_sae" from core/sae_finetuning.py.

In core/utils/misc.py

  7. Add a metric function called "norm_ratio" to "core/utils/misc.py".
    • This function calculates the ratio between the norms of the SAE's input and output, for use as a metric.
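The description does not say which norm is used or which way the ratio is taken, so the sketch below makes both choices explicit as assumptions: it uses the L2 norm and reports the output norm divided by the input norm, averaged over the batch. It is written with plain Python lists so that it runs without PyTorch; the function name `norm_ratio_sketch` and the `eps` parameter are hypothetical, and the function in core/utils/misc.py presumably operates on tensors.

```python
import math
from typing import Sequence


def l2_norm(vec: Sequence[float]) -> float:
    """Euclidean norm of a single vector."""
    return math.sqrt(sum(v * v for v in vec))


def norm_ratio_sketch(
    inputs: Sequence[Sequence[float]],
    outputs: Sequence[Sequence[float]],
    eps: float = 1e-8,
) -> float:
    """Mean over the batch of ||output|| / ||input|| (assumed definition).

    `inputs` holds the activations fed to the SAE and `outputs` holds its
    reconstructions, one vector per batch element. `eps` guards against
    division by zero for all-zero inputs.
    """
    if len(inputs) != len(outputs):
        raise ValueError("inputs and outputs must have the same batch size")
    if not inputs:
        raise ValueError("batch must not be empty")
    ratios = [l2_norm(y) / max(l2_norm(x), eps) for x, y in zip(inputs, outputs)]
    return sum(ratios) / len(ratios)


if __name__ == "__main__":
    batch_in = [[3.0, 4.0], [0.0, 2.0]]
    batch_out = [[1.5, 2.0], [0.0, 2.0]]
    print(norm_ratio_sketch(batch_in, batch_out))  # 0.75
```

Under this definition, a value below 1 means the reconstruction is shorter than the input on average, which is the shrinkage that removing the L1 term during finetuning is typically meant to reduce.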