AlignmentResearch / tuned-lens

Tools for understanding how transformer predictions are built layer-by-layer
https://tuned-lens.readthedocs.io/en/latest/
MIT License
432 stars 47 forks

Support tuned lens training in 8 bit with `bitsandbytes` #88

Closed · norabelrose closed this 1 year ago

norabelrose commented 1 year ago

This adds an `int8` field to the `Model` config class which sets `load_in_8bit=True` when `AutoModelForCausalLM.from_pretrained` is called.

I had to do a little bit of refactoring for this to work, because `load_in_8bit=True` requires a `device_map` to be set, and in `main` the `Model` class doesn't actually know what device it's supposed to be on. Also, FSDP is simply incompatible with the `device_map` flag, so I have to skip setting `device_map` when `fsdp` is enabled.
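The wiring described above could look roughly like the sketch below. This is an illustration, not the PR's actual diff: the `int8` field name comes from the description, but the `load_kwargs` helper and its signature are hypothetical, and `{"": device}` is simply one valid single-device `device_map` accepted by `transformers`.

```python
from dataclasses import dataclass


@dataclass
class Model:
    """Hypothetical sketch of a Model config with the int8 field from this PR."""

    name: str
    int8: bool = False  # assumption: field name taken from the PR description

    def load_kwargs(self, device: str, fsdp: bool = False) -> dict:
        """Build keyword arguments for AutoModelForCausalLM.from_pretrained.

        load_in_8bit=True requires a device_map, and FSDP is incompatible
        with device_map, so the two options cannot be combined.
        """
        kwargs: dict = {}
        if self.int8:
            if fsdp:
                raise ValueError(
                    "load_in_8bit requires a device_map, which FSDP does not support"
                )
            kwargs["load_in_8bit"] = True
            # A map of {"": device} places the whole model on a single device.
            kwargs["device_map"] = {"": device}
        return kwargs
```

With `int8=False` the returned dict is empty, so the non-quantized path is unchanged; with `int8=True` and `fsdp=True` the conflict is raised eagerly instead of failing inside `from_pretrained`.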

codecov[bot] commented 1 year ago

Codecov Report

Merging #88 (8ceb6ea) into main (9974450) will increase coverage by 0.03%. The diff coverage is 85.71%.

Impacted file tree graph

```diff
@@            Coverage Diff             @@
##             main      #88      +/-   ##
==========================================
+ Coverage   78.74%   78.77%   +0.03%     
==========================================
  Files          32       32              
  Lines        1750     1753       +3     
==========================================
+ Hits         1378     1381       +3     
  Misses        372      372              
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| `tuned_lens/scripts/eval_loop.py` | 88.70% <66.66%> (+0.09%) | :arrow_up: |
| `tuned_lens/scripts/train_loop.py` | 71.80% <85.71%> (+0.15%) | :arrow_up: |
| `tuned_lens/scripts/ingredients.py` | 84.96% <100.00%> (+0.09%) | :arrow_up: |