stanfordnlp / pyvene

Stanford NLP Python Library for Understanding and Improving PyTorch Models via Interventions
http://pyvene.ai
Apache License 2.0

Save/load trainable params in `IntervenableBase` methods #153

Open aryamanarora opened 2 months ago

aryamanarora commented 2 months ago

Description

Add saving and loading of the model's trainable parameters (e.g. classification heads) to `IntervenableModel.save()` and `IntervenableModel.load()`. This is a draft PR because some tests are still failing; I will finalise it tomorrow.
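For context, here is a minimal sketch of the general idea in plain PyTorch. It is not the code in this PR and does not use pyvene's API: the helper names `trainable_state_dict` and `load_trainable_state_dict` and the `ToyClassifier` model are hypothetical, chosen only for illustration. It assumes "trainable" means `requires_grad=True`, and that parameters are not tied across modules.

```python
import io

import torch
from torch import nn


class ToyClassifier(nn.Module):
    """Hypothetical stand-in: a frozen backbone plus a trainable head."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.score = nn.Linear(8, 2)  # classification head, left trainable
        for param in self.backbone.parameters():
            param.requires_grad = False  # frozen, so not worth re-saving

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.score(torch.relu(self.backbone(x)))


def trainable_state_dict(model: nn.Module) -> dict:
    """Collect only the state_dict entries whose parameter requires grad."""
    trainable = {
        name for name, param in model.named_parameters() if param.requires_grad
    }
    return {
        key: value.detach().cpu().clone()
        for key, value in model.state_dict().items()
        if key in trainable
    }


def load_trainable_state_dict(model: nn.Module, state: dict) -> None:
    """Load a partial state dict, rejecting keys the model does not have."""
    # strict=False tolerates the frozen weights that were never saved.
    result = model.load_state_dict(state, strict=False)
    if result.unexpected_keys:
        raise KeyError(f"Unexpected keys in checkpoint: {result.unexpected_keys}")


# Round trip through an in-memory buffer instead of a file on disk.
source = ToyClassifier()
buffer = io.BytesIO()
torch.save(trainable_state_dict(source), buffer)
buffer.seek(0)

target = ToyClassifier()
load_trainable_state_dict(target, torch.load(buffer))
```

Only the head's weights are written to the checkpoint; the frozen backbone would be restored separately from the original pretrained weights.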

Testing Done

Saving and loading Gemma 2B-IT for sequence classification works as expected.

Checklist: