AlignmentResearch / tuned-lens

Tools for understanding how transformer predictions are built layer-by-layer
https://tuned-lens.readthedocs.io/en/latest/
MIT License
437 stars, 47 forks

Added support for creating a prediction trajectory from a model cache #103

Closed: levmckinney closed this 1 year ago

levmckinney commented 1 year ago

This PR adds several new features that allow for easy integration with the TransformerLens repo.
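As a rough illustration of what building a prediction trajectory from a cache means, the sketch below projects cached per-layer hidden states to vocabulary log-probabilities, one layer at a time. It is a minimal, standard-library-only sketch and not the tuned-lens or TransformerLens API: `trajectory_from_cache`, the cache keys, and the toy unembedding matrix `W_U` are all hypothetical names chosen for the example, and a plain logit-lens-style projection stands in for whatever lens is actually applied.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def unembed(hidden, weight):
    """Project a hidden vector (d_model) to logits using a d_model x vocab matrix."""
    vocab_size = len(weight[0])
    return [
        sum(h * weight[i][v] for i, h in enumerate(hidden))
        for v in range(vocab_size)
    ]


def trajectory_from_cache(cache, layer_names, weight):
    """Return log-probabilities indexed as [layer][position][vocab].

    Hypothetical helper: `cache` maps a layer name to a list of hidden
    vectors, one per token position.
    """
    return [
        [log_softmax(unembed(hidden, weight)) for hidden in cache[name]]
        for name in layer_names
    ]


# Hypothetical cache: two layers, two token positions, d_model = 2.
cache = {
    "layer_0": [[1.0, 0.0], [0.0, 1.0]],
    "layer_1": [[2.0, 0.0], [0.0, 2.0]],
}

# Hypothetical unembedding matrix for a 3-token vocabulary.
W_U = [
    [1.0, 0.0, -1.0],
    [0.0, 1.0, -1.0],
]

traj = trajectory_from_cache(cache, ["layer_0", "layer_1"], W_U)

# Most likely token id at each layer and position.
top_tokens = [
    [max(range(len(log_probs)), key=log_probs.__getitem__) for log_probs in layer]
    for layer in traj
]
print(top_tokens)
```

In this toy setup the top token at each position is the same in both layers, while its log-probability increases from the first layer to the second, which is the kind of layer-by-layer change a trajectory is meant to expose.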

Features:

Recommended reading order

  1. If you are not familiar with the prediction trajectory class, start with the prediction_trajectories.ipynb notebook. Note that you may need to comment the thing out to get the plots to appear on your local system.
  2. Read through the combining_with_transformer_lens.ipynb tutorial. It covers most of the new features, with some worked examples.
  3. Look at prediction_trajectory.py.

What I want feedback on

Does this interface with the TransformerLens repo make sense? Is there anything obvious missing? Is the tutorial clear? Are there parts that need more explanation or that might be misleading?

codecov[bot] commented 1 year ago

Codecov Report

Merging #103 (cf8b597) into main (1308bdd) will increase coverage by 1.31%. The diff coverage is 90.13%.

@@            Coverage Diff             @@
##             main     #103      +/-   ##
==========================================
+ Coverage   79.57%   80.88%   +1.31%     
==========================================
  Files          32       32              
  Lines        1826     2077     +251     
==========================================
+ Hits         1453     1680     +227     
- Misses        373      397      +24     

| Impacted Files | Coverage Δ |
|---|---|
| tuned_lens/nn/lenses.py | 91.81% <ø> (ø) |
| tuned_lens/plotting/token_formatter.py | 88.00% <25.00%> (-12.00%) :arrow_down: |
| tuned_lens/nn/unembed.py | 86.25% <71.42%> (-0.93%) :arrow_down: |
| tuned_lens/model_surgery.py | 61.94% <83.33%> (+4.33%) :arrow_up: |
| tuned_lens/plotting/trajectory_plotting.py | 88.37% <86.66%> (+3.26%) :arrow_up: |
| tuned_lens/plotting/prediction_trajectory.py | 92.34% <87.61%> (-3.81%) :arrow_down: |
| tests/plotting/test_prediction_trajectory.py | 100.00% <100.00%> (ø) |
| tests/plotting/test_trajectory_plotting.py | 100.00% <100.00%> (ø) |

... and 6 files with indirect coverage changes
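
The project-level percentages in the coverage diff can be reproduced from the hit and line counts it reports. The short sketch below does that arithmetic; the helper names are made up for the example. Truncating to two decimals rather than rounding is an assumption inferred from the numbers (1680 / 2077 is about 80.886%, reported as 80.88%), not a documented Codecov behavior.

```python
def coverage_bp(hits, lines):
    """Coverage in hundredths of a percent, truncated (assumed, see note above)."""
    return hits * 10_000 // lines


def fmt(bp):
    """Format hundredths of a percent as a percentage string, e.g. 7957 -> '79.57%'."""
    return f"{bp // 100}.{bp % 100:02d}%"


# Counts taken from the coverage diff: main vs. PR #103.
main = {"lines": 1826, "hits": 1453}
pr = {"lines": 2077, "hits": 1680}

main_bp = coverage_bp(main["hits"], main["lines"])
pr_bp = coverage_bp(pr["hits"], pr["lines"])
delta_bp = pr_bp - main_bp

# Misses are simply lines that were not hit.
main_misses = main["lines"] - main["hits"]
pr_misses = pr["lines"] - pr["hits"]

print(fmt(main_bp), fmt(pr_bp), "+" + fmt(delta_bp))  # 79.57% 80.88% +1.31%
print(main_misses, pr_misses, pr_misses - main_misses)  # 373 397 24
```

The 90.13% diff coverage is not derived here, since it depends on per-line data for the changed lines that the summary does not include.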