AlignmentResearch / tuned-lens

Tools for understanding how transformer predictions are built layer-by-layer
https://tuned-lens.readthedocs.io/en/latest/
MIT License

Support manually setting the dtype of the model and tuned lens weights #94

Closed: norabelrose closed this 1 year ago

norabelrose commented 1 year ago

Importantly, this PR also changes the dtype of the tuned lens parameters. Previously they were always float32 (or, more precisely, whatever torch.get_default_dtype() returned), which is pointless and wastes VRAM for large models. We now give the tuned lens the same dtype as the model parameters, except for int8 models, where we use float16.
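To make the dtype rule and the VRAM argument concrete, here is a minimal sketch in plain Python (dtypes as strings, so no torch is needed). The helper names are hypothetical, not tuned-lens API, and the size estimate assumes one d_model x d_model affine translator (weight plus bias) per layer; the layer count and width are illustrative, not a specific model.

```python
# Bytes per element for the floating-point dtypes discussed in the PR.
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2, "float16": 2}


def lens_dtype_for(model_dtype: str) -> str:
    """Hypothetical helper: pick the lens dtype from the model's dtype.

    Follows the rule stated in the PR: match the model, except that an
    int8 model gets a float16 lens.
    """
    if model_dtype == "int8":
        return "float16"
    if model_dtype not in BYTES_PER_ELEMENT:
        raise ValueError(f"unsupported dtype: {model_dtype}")
    return model_dtype


def lens_param_count(num_layers: int, d_model: int) -> int:
    """Assumes one d_model x d_model translator (weight + bias) per layer."""
    return num_layers * (d_model * d_model + d_model)


def lens_bytes(num_layers: int, d_model: int, dtype: str) -> int:
    """Memory needed to store the lens parameters in the given dtype."""
    return lens_param_count(num_layers, d_model) * BYTES_PER_ELEMENT[dtype]


if __name__ == "__main__":
    layers, width = 32, 4096  # illustrative sizes only
    for model_dtype in ("float32", "float16", "int8"):
        dtype = lens_dtype_for(model_dtype)
        gib = lens_bytes(layers, width, dtype) / 2**30
        print(f"model {model_dtype:>8} -> lens {dtype:>8}: {gib:.2f} GiB")
```

Under these assumptions the lens has about 537M parameters, roughly 2.0 GiB in float32 versus 1.0 GiB in float16, so matching a half-precision model halves the lens's memory footprint.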

This PR also rolls the --int8 functionality into the new --precision flag; to load the model in int8, just set --precision int8.
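A rough sketch of how a single --precision flag can replace a boolean --int8 switch, using argparse. The set of choices, the float32 default, and the resolve_precision helper are assumptions for illustration; the actual tuned-lens flag may accept different values or behave differently.

```python
import argparse

# Assumed choices for illustration; the real flag may differ.
PRECISION_CHOICES = ("bfloat16", "float16", "float32", "int8")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Sketch of a --precision flag")
    parser.add_argument(
        "--precision",
        choices=PRECISION_CHOICES,
        default="float32",  # assumed default, matching the old behavior
        help="Dtype for the model weights; 'int8' replaces the old --int8 switch.",
    )
    return parser


def resolve_precision(precision: str) -> dict:
    """Hypothetical helper: derive load settings from the flag value.

    int8 means the model is quantized and the lens is kept in float16;
    any other value is used for the lens as-is.
    """
    quantize_int8 = precision == "int8"
    lens_dtype = "float16" if quantize_int8 else precision
    return {"quantize_int8": quantize_int8, "lens_dtype": lens_dtype}
```

For example, `resolve_precision(build_parser().parse_args(["--precision", "int8"]).precision)` yields `{"quantize_int8": True, "lens_dtype": "float16"}`. One enumerated flag avoids contradictory combinations such as passing both --int8 and a float dtype.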

As a bonus, this PR adds a new prevent_name_conflicts context manager, which ensures that HuggingFace doesn't stupidly try to load a model from a local tuned lens directory that happens to have the same name as the model.
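The conflict arises because a name like "gpt2" passed to a HuggingFace loader is first checked as a path relative to the current working directory. One plausible way to avoid that, sketched below, is to temporarily switch to an empty temporary directory. This is our own illustration under that assumption, not the code from the PR; the name isolated_cwd is hypothetical.

```python
import os
import tempfile
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def isolated_cwd() -> Iterator[str]:
    """Hypothetical sketch: run the body from an empty temporary directory.

    While active, a relative name such as "gpt2" cannot resolve to a local
    folder, so a name-based loader has to look it up elsewhere.
    """
    old_cwd = os.getcwd()
    with tempfile.TemporaryDirectory() as tmp_dir:
        os.chdir(tmp_dir)
        try:
            yield tmp_dir
        finally:
            # Restore the original cwd before TemporaryDirectory deletes tmp_dir.
            os.chdir(old_cwd)
```

In use, the model-loading call would be wrapped, e.g. `with isolated_cwd(): model = AutoModelForCausalLM.from_pretrained("gpt2")`. Note that os.chdir changes the working directory for the whole process, so this approach is not thread-safe.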

codecov[bot] commented 1 year ago

Codecov Report

Merging #94 (08c2218) into main (9974450) will increase coverage by 0.00%. The diff coverage is 82.85%.

❗ Current head 08c2218 differs from the pull request's most recent head 927d166. Consider uploading reports for commit 927d166 to get more accurate results.


@@           Coverage Diff           @@
##             main      #94   +/-   ##
=======================================
  Coverage   78.74%   78.74%           
=======================================
  Files          32       32           
  Lines        1750     1769   +19     
=======================================
+ Hits         1378     1393   +15     
- Misses        372      376    +4     
Impacted file                        Coverage Δ
tuned_lens/scripts/eval_loop.py      88.70% <66.66%> (+0.09%) ↑
tuned_lens/utils.py                  49.55% <71.42%> (+1.44%) ↑
tuned_lens/scripts/ingredients.py    84.27% <84.61%> (-0.60%) ↓
tuned_lens/scripts/train_loop.py     71.80% <85.71%> (+0.15%) ↑
tests/test_lenses.py                 86.76% <100.00%> (+0.19%) ↑
tuned_lens/nn/lenses.py              91.81% <100.00%> (+0.15%) ↑
tuned_lens/nn/unembed.py             87.17% <100.00%> (ø)