jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License
193 stars, 67 forks

Performance improvements + using multiple GPUs. #189

Closed: jbloomAus closed this 2 weeks ago

jbloomAus commented 2 weeks ago

Description

I'm working with some larger models, and even though I/O is still likely our bottleneck, there were a few easy wins with multiple GPUs that I wanted to get out.

These were:

We should add these to the docs at some point, but here's an example of using separate devices for the model, the SAE, and the activation store.

import torch

from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.sae_training_runner import SAETrainingRunner

# For brevity, assume BASE_CFG is a dict of LanguageModelSAERunnerConfig kwargs for Gemma-2b (not shown).
cfg_dict = dict(BASE_CFG)  # shallow copy so BASE_CFG itself isn't modified
# we have 4 devices in this example
assert torch.cuda.device_count() == 4

# The model goes on cuda:0 and cuda:1. Setting n_devices makes us assume we're on CUDA.
cfg_dict["model_from_pretrained_kwargs"] = {"n_devices": torch.cuda.device_count() - 2}

# SAE goes on cuda:2
cfg_dict["device"] = "cuda:2"

# The activation store goes on cuda:3 (faster than keeping the activation store on CPU).
cfg_dict["act_store_device"] = "cuda:3"

# Instantiate Config.
cfg = LanguageModelSAERunnerConfig(**cfg_dict)  # type: ignore
sae = SAETrainingRunner(cfg).run()
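The layout above generalizes to any GPU count: the last two devices host the SAE and the activation store, and the model is spread over the remaining ones. Below is a minimal sketch assuming that layout; `plan_device_layout` is a hypothetical helper, not part of SAELens, and it only sets the three config keys used in the example.

```python
def plan_device_layout(n_gpus: int) -> dict:
    """Hypothetical helper (not part of SAELens): return config overrides that put
    the model on cuda:0..cuda:{n_gpus - 3}, the SAE on cuda:{n_gpus - 2}, and the
    activation store on cuda:{n_gpus - 1}."""
    if n_gpus < 3:
        # One GPU each for the SAE and the activation store, plus at least one for the model.
        raise ValueError(f"this layout needs at least 3 GPUs, got {n_gpus}")
    return {
        # The model is split across the first n_gpus - 2 devices.
        "model_from_pretrained_kwargs": {"n_devices": n_gpus - 2},
        # The SAE gets the second-to-last device.
        "device": f"cuda:{n_gpus - 2}",
        # The activation store gets the last device.
        "act_store_device": f"cuda:{n_gpus - 1}",
    }


print(plan_device_layout(4))
```

With 4 GPUs this reproduces the example's settings. It could be applied with `cfg_dict.update(plan_device_layout(torch.cuda.device_count()))`; note that `update` replaces any existing `model_from_pretrained_kwargs` entry wholesale, so merge it by hand if the base config already sets other model kwargs.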

Checklist:

You have tested formatting, typing and unit tests (acceptance tests not currently in use)

codecov[bot] commented 2 weeks ago

Codecov Report

Attention: Patch coverage is 51.85185% with 13 lines in your changes missing coverage. Please review.

Project coverage is 59.58%. Comparing base (33a612d) to head (7e9e501). Report is 1 commit behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| sae_lens/load_model.py | 0.00% | 7 Missing and 1 partial :warning: |
| sae_lens/evals.py | 76.92% | 3 Missing :warning: |
| sae_lens/training/sae_trainer.py | 33.33% | 2 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #189      +/-   ##
==========================================
- Coverage   59.68%   59.58%   -0.11%
==========================================
  Files          25       25
  Lines        2642     2660      +18
  Branches      446      450       +4
==========================================
+ Hits         1577     1585       +8
- Misses        987      996       +9
- Partials       78       79       +1
```
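The percentages in this report can be reproduced from the line counts. The sketch below assumes Codecov's usual definition (partially covered lines count against coverage, so coverage = hits / (hits + misses + partials)) and that displayed values are rounded down; rounding to nearest would show the base as 59.69%, so truncation appears to be what the report uses. The patch hit count (14 of 27 lines) is inferred from the 13 uncovered lines and the 51.85185% figure, not stated in the report.

```python
import math


def coverage(hits: int, misses: int, partials: int) -> float:
    """Unrounded coverage in percent; partially covered lines count as not covered."""
    return 100 * hits / (hits + misses + partials)


def round_down(value: float, precision: int = 2) -> float:
    """Round toward negative infinity at `precision` decimal places."""
    factor = 10 ** precision
    return math.floor(value * factor) / factor


# Patch: 13 uncovered lines (7 + 3 + 2 missing, 1 partial). A 51.85185% rate
# implies 27 patch lines in total, so 14 were hit (inferred, not reported).
patch = coverage(hits=14, misses=12, partials=1)

# Project totals from the coverage diff (base = main, head = this PR).
base = coverage(hits=1577, misses=987, partials=78)
head = coverage(hits=1585, misses=996, partials=79)

print(round_down(patch, 5))     # 51.85185
print(round_down(base))         # 59.68
print(round_down(head))         # 59.58
print(round_down(head - base))  # -0.11
```

Computing the delta from the unrounded ratios (about -0.103) and then rounding down also explains why the diff reports -0.11 rather than the -0.10 you would get by subtracting the two displayed values.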
