stanfordnlp / pyvene

Stanford NLP Python Library for Understanding and Improving PyTorch Models via Interventions
http://pyvene.ai
Apache License 2.0
559 stars 50 forks

[P1] Speed up training of multiple DAS interventions with caching #76

Open aryamanarora opened 6 months ago

aryamanarora commented 6 months ago

Commonly, we want to exhaustively train DAS on every layer and position (or, e.g., every attention head in a layer) to find which ones are causally relevant to the model's computations. When working with a fixed dataset, we could speed this process up by caching and reusing activations. It's unclear what the best way to implement this is, but it should already be possible to put together a minimal example with CollectIntervention and the activations_sources= arg at inference time; a rough sketch is below.
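
A rough, untested sketch of the idea, under stated assumptions: collect the source-side activations once per layer with CollectIntervention, cache them, and then feed them back via activations_sources= while training DAS at that layer. The exact unit_locations format, the structure of the collected/return values, and the accepted format of activations_sources are assumptions here and may differ across pyvene versions.

```python
# Rough sketch only -- unit_locations format, return-value unpacking, and the
# activations_sources format are assumptions; check against your pyvene version.
import torch
import pyvene as pv
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")

layer = 8        # layer currently being swept
source_pos = 4   # source token position to read from
base_pos = 4     # base token position to intervene on

base = tokenizer("The capital of Spain is", return_tensors="pt")
source = tokenizer("The capital of Italy is", return_tensors="pt")

# (1) Collect the source activations at this layer once and cache them.
collect_model = pv.IntervenableModel(
    pv.IntervenableConfig(
        representations=[{"layer": layer, "component": "block_output"}],
        intervention_types=pv.CollectIntervention,
    ),
    gpt2,
)
with torch.no_grad():
    collected = collect_model(source, unit_locations={"base": source_pos})
# Assumed: the collected activations can be pulled out of `collected`; in a real
# sweep we would store them per (example, layer) and reuse them for every position.
cached_source_activations = collected

# (2) Train a DAS intervention at the same layer, reusing the cached activations
#     instead of re-running the source forward pass.
das_model = pv.IntervenableModel(
    pv.IntervenableConfig(
        representations=[
            {"layer": layer, "component": "block_output", "low_rank_dimension": 1}
        ],
        intervention_types=pv.LowRankRotatedSpaceIntervention,
    ),
    gpt2,
)
das_model.disable_model_gradients()  # only the rotation parameters stay trainable

_, counterfactual_outputs = das_model(
    base,
    sources=None,                                    # no source forward pass needed
    activations_sources=cached_source_activations,   # assumed kwarg, per this issue
    unit_locations={"base": base_pos},
)
# counterfactual_outputs.logits then feed into the DAS training loss as usual.
```

If this works, the collect pass would only need to run once per (example, layer), and the per-position (or per-head) DAS runs at that layer could all share the cached tensors.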