Closed: norabelrose closed this 1 year ago
Merging #94 (08c2218) into main (9974450) will increase coverage by 0.00%. The diff coverage is 82.85%.

:exclamation: Current head 08c2218 differs from pull request most recent head 927d166. Consider uploading reports for the commit 927d166 to get more accurate results.
```diff
@@           Coverage Diff           @@
##             main      #94   +/-   ##
=======================================
  Coverage   78.74%   78.74%
=======================================
  Files          32       32
  Lines        1750     1769   +19
=======================================
+ Hits         1378     1393   +15
- Misses        372      376    +4
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| tuned_lens/scripts/eval_loop.py | 88.70% <66.66%> (+0.09%) | :arrow_up: |
| tuned_lens/utils.py | 49.55% <71.42%> (+1.44%) | :arrow_up: |
| tuned_lens/scripts/ingredients.py | 84.27% <84.61%> (-0.60%) | :arrow_down: |
| tuned_lens/scripts/train_loop.py | 71.80% <85.71%> (+0.15%) | :arrow_up: |
| tests/test_lenses.py | 86.76% <100.00%> (+0.19%) | :arrow_up: |
| tuned_lens/nn/lenses.py | 91.81% <100.00%> (+0.15%) | :arrow_up: |
| tuned_lens/nn/unembed.py | 87.17% <100.00%> (ø) | |
Importantly, this PR also changes the dtype of the tuned lens parameters. Before, they were always float32 (or more precisely, whatever `torch.get_default_dtype()` returns), which wastes VRAM for large models. We now make the tuned lens the same dtype as the model parameters (except in the int8 case, where we use float16). A rough sketch of this dtype-selection logic is shown below.

This PR also rolls the `--int8` functionality into the new `--precision` flag; just set `--precision int8`.
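A minimal sketch of the dtype selection described above (the helper name and exact branching here are illustrative assumptions, not the PR's actual code):

```python
import torch


def lens_dtype(model: torch.nn.Module, precision: str) -> torch.dtype:
    """Pick the dtype for tuned lens parameters (illustrative sketch).

    An int8-quantized model can't host the lens parameters directly, so we
    fall back to float16; otherwise the lens mirrors whatever dtype the
    model's own parameters use, rather than torch.get_default_dtype().
    """
    if precision == "int8":
        return torch.float16
    # Match the dtype of the model's parameters, so e.g. a float16 model
    # gets a float16 lens and no VRAM is wasted on float32 copies.
    return next(model.parameters()).dtype
```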
As a bonus, this PR also ensures that HuggingFace doesn't mistakenly try to load a model from a tuned lens directory of the same name, via the new `prevent_name_conflicts` context manager.
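For context, a context manager along these lines would have that effect; this is a hedged sketch of the idea, not the actual implementation in `tuned_lens/utils.py`:

```python
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def prevent_name_conflicts():
    """Run the enclosed block from an empty temporary directory (sketch).

    HuggingFace's from_pretrained() will treat a local folder whose name
    matches the requested model ID as the model itself. Temporarily switching
    the working directory to a fresh temp dir hides any such folder (e.g. the
    tuned lens output directory), so the model is resolved from the Hub.
    """
    old_cwd = os.getcwd()
    with tempfile.TemporaryDirectory() as tmp:
        try:
            os.chdir(tmp)
            yield
        finally:
            os.chdir(old_cwd)
```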