This turned out to be even easier than I thought, because when you set `load_in_8bit=True` and `output_hidden_states=True`, the hiddens are still in float16, not some weird int8 thing.
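To illustrate why that works: in 8-bit loading, only the weights are quantized, and each matmul output is dequantized back to half precision, so the hidden states you collect are ordinary float16 tensors. Here's a toy NumPy sketch of that idea (absmax int8 quantization with a float16 output) — this is not the actual bitsandbytes implementation, just an illustration of the dtype behavior:

```python
import numpy as np

def quantize_int8(w):
    # Per-row absmax scale maps each row of the weight matrix into [-127, 127].
    scale = np.abs(w).max(axis=1, keepdims=True).astype(np.float32) / 127.0
    q = np.round(w.astype(np.float32) / scale).astype(np.int8)
    return q, scale

def int8_linear(x, q, scale):
    # Dequantize, accumulate in float32, then cast the output back to float16 --
    # so downstream "hiddens" are float16 even though the weights live in int8.
    w_deq = q.astype(np.float32) * scale
    return (x.astype(np.float32) @ w_deq.T).astype(np.float16)

rng = np.random.default_rng(0)
w = rng.normal(size=(8, 16)).astype(np.float16)   # weights, stored as int8 below
x = rng.normal(size=(4, 16)).astype(np.float16)   # input activations

q, scale = quantize_int8(w)
y_int8 = int8_linear(x, q, scale)
y_fp16 = (x.astype(np.float32) @ w.astype(np.float32).T).astype(np.float16)

print(y_int8.dtype)                      # float16, not int8
print(np.abs(y_int8 - y_fp16).max())     # small quantization error
```

The same logic is why comparing the extracted hiddens against a pure-float16 run (as below) should show only small quantization noise.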
You can turn this on with the `--int8` flag on the command line. It does require installing `bitsandbytes` first, but that should be as easy as `python -m pip install bitsandbytes` on the LAS nodes now.

Currently checking that this gives ~the same results as float16 inference, although I would be shocked if they were noticeably different.
This PR depends on #226