stanfordnlp / pyvene

Stanford NLP Python Library for Understanding and Improving PyTorch Models via Interventions
http://pyvene.ai
Apache License 2.0

[P2] Add Sparse Autoencoder Interventions #164

Closed: explanare closed this pull request 4 days ago

explanare commented 4 days ago

Description

Add an AutoencoderLayer and an AutoencoderIntervention to support interpretability methods that use autoencoders, including sparse autoencoders (SAEs), to learn an interpretable feature space.
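The core idea behind an autoencoder intervention can be sketched in plain PyTorch: encode the base and source hidden states into the latent feature space, copy the selected latent features from source into base, and decode back to the model's hidden space. This is a hypothetical illustration, not pyvene's `AutoencoderLayer` or `AutoencoderIntervention`; `SketchAutoencoder` and `autoencoder_interchange` are made-up names, and adding back the base reconstruction error is an assumed design choice.

```python
import torch
from torch import nn


class SketchAutoencoder(nn.Module):
    """Hypothetical single-hidden-layer autoencoder (not pyvene's AutoencoderLayer)."""

    def __init__(self, input_dim: int, latent_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.encoder = nn.Linear(input_dim, latent_dim)
        self.decoder = nn.Linear(latent_dim, input_dim)

    def encode(self, h: torch.Tensor) -> torch.Tensor:
        # ReLU keeps features non-negative; sparsity itself would come from training
        # (e.g. an L1 penalty), which this sketch does not include.
        return torch.relu(self.encoder(h))

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)


def autoencoder_interchange(base, source, autoencoder, dims):
    """Copy latent features `dims` from `source` into `base`, then decode."""
    idx = torch.as_tensor(dims, dtype=torch.long)
    z_base = autoencoder.encode(base)
    z_source = autoencoder.encode(source)
    z_mixed = z_base.clone()
    z_mixed[..., idx] = z_source[..., idx]
    # Assumption: add back the base reconstruction error so that, when no features
    # are swapped, the output equals `base` up to floating-point rounding.
    # Dropping this term would return the plain reconstruction instead.
    error = base - autoencoder.decode(z_base)
    return autoencoder.decode(z_mixed) + error


torch.manual_seed(0)
ae = SketchAutoencoder(input_dim=8, latent_dim=32)
base, source = torch.randn(2, 8), torch.randn(2, 8)
out = autoencoder_interchange(base, source, ae, dims=[3, 17])
print(out.shape)  # torch.Size([2, 8])
```

Swapping only a few latent dimensions is what makes the intervention targeted: features outside `dims` keep their base values.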

The AutoencoderIntervention supports loading pre-trained autoencoders trained outside the pyvene framework, using the get_intervenable_with_autoencoder function below:

import pyvene as pv


def get_intervenable_with_autoencoder(
        model, autoencoder, intervention_dimensions, layer):
    # Build an intervention whose autoencoder has the same shape as the pretrained one.
    intervention = pv.AutoencoderIntervention(
        embed_dim=autoencoder.input_dim,
        latent_dim=autoencoder.latent_dim)
    # Copy the pretrained autoencoder (parameter names and shapes must match).
    intervention.autoencoder.load_state_dict(autoencoder.state_dict())
    # Latent dimensions to interchange from the source run into the base run.
    intervention.set_interchange_dim(intervention_dimensions)
    inv_config = pv.IntervenableConfig(
        model_type=type(model),
        representations=[
            pv.RepresentationConfig(
                layer,  # layer
                "block_output",  # intervention component
                "pos",  # intervention unit
                1,  # max number of units
                intervention=intervention,
                latent_dim=autoencoder.latent_dim)
        ],
        intervention_types=pv.AutoencoderIntervention,
    )
    intervenable = pv.IntervenableModel(inv_config, model)
    intervenable.set_device("cuda")  # assumes a CUDA device is available
    # Freeze the parameters of the underlying model.
    intervenable.disable_model_gradients()
    return intervenable
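Because the function copies weights with `load_state_dict`, the pretrained autoencoder needs the same parameter names and shapes as `intervention.autoencoder`, and must expose `input_dim` and `latent_dim`. Some SAE codebases instead store raw `W_enc`, `b_enc`, `W_dec`, `b_dec` tensors. A minimal sketch of remapping such weights, assuming the target uses `nn.Linear` modules named `encoder` and `decoder`; `TargetAutoencoder` and `remap_sae_weights` are hypothetical, and the real key names in pyvene's autoencoder may differ (check `intervention.autoencoder.state_dict().keys()`).

```python
import torch
from torch import nn


class TargetAutoencoder(nn.Module):
    """Hypothetical stand-in for `intervention.autoencoder`; real key names may differ."""

    def __init__(self, input_dim: int, latent_dim: int):
        super().__init__()
        self.encoder = nn.Linear(input_dim, latent_dim)
        self.decoder = nn.Linear(latent_dim, input_dim)


def remap_sae_weights(W_enc, b_enc, W_dec, b_dec):
    """Hypothetical helper: convert W_enc (input_dim, latent_dim) and
    W_dec (latent_dim, input_dim) into nn.Linear's (out_features, in_features) layout."""
    return {
        "encoder.weight": W_enc.T.contiguous(),
        "encoder.bias": b_enc,
        "decoder.weight": W_dec.T.contiguous(),
        "decoder.bias": b_dec,
    }


input_dim, latent_dim = 16, 64
W_enc, b_enc = torch.randn(input_dim, latent_dim), torch.randn(latent_dim)
W_dec, b_dec = torch.randn(latent_dim, input_dim), torch.randn(input_dim)

target = TargetAutoencoder(input_dim, latent_dim)
# With strict=True (the default), missing or unexpected keys raise an error;
# shape mismatches raise as well.
target.load_state_dict(remap_sae_weights(W_enc, b_enc, W_dec, b_dec))
```

After remapping, `target.encoder(x)` computes `x @ W_enc + b_enc`, so the converted module behaves like the original SAE's affine maps.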

The resulting intervenable, including the intervention dimensions and the autoencoder, can be saved as:

intervenable.save("path/to/save/dir")

Fix #77

Testing Done

[internal only] https://colab.research.google.com/drive/1_fxM7JUqkMy6Erz6K1JV0NwQBw1r8g0k?usp=sharing

I will add this Colab notebook as a tutorial.

Checklist:

frankaging commented 4 days ago

Thanks! Merging this despite the failing check. The failure is caused by a versioning issue with huggingface-hub, which I will fix in a follow-up change.