LLaMA-BitNet is a repository dedicated to empowering users to train their own BitNet models built upon the LLaMA 2 model, inspired by the groundbreaking paper 'The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits'.
The Training Tips, Code and FAQ specifies that BitLinear has different forward() definitions for training vs. inference.
If I understand correctly, convert_to_bitnet() is being used in both scenarios here? While this does produce a working LLM, there are no efficiency gains being made.
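For context, the training-time BitLinear described in the Training Tips paper keeps master weights in fp16 and only fake-quantizes them on the fly, so the matmul itself is still a full-precision F.linear. A minimal sketch of that forward pass is below (the class and helper names here are illustrative, not necessarily what this repo uses):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def weight_quant(w: torch.Tensor) -> torch.Tensor:
    # Absmean quantization to ternary {-1, 0, 1}, then rescale back to fp16.
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale

def activation_quant(x: torch.Tensor) -> torch.Tensor:
    # Per-token absmax quantization to 8 bits, then rescale back to fp16.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale

class BitLinearTrain(nn.Linear):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight
        # Straight-through estimator: quantize in the forward pass,
        # but let gradients flow to the fp16 master weights.
        x_q = x + (activation_quant(x) - x).detach()
        w_q = w + (weight_quant(w) - w).detach()
        # The actual matmul is still a full-precision F.linear on fp16 tensors.
        return F.linear(x_q, w_q, self.bias)
```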
As the FAQ states for inference:
The model weights are offline quantized to 1.58 bits.
The standard F.Linear operation is replaced with a customized low-bit kernel.
Without doing these two steps, all the weights are still fp16, and as such, still going through full-precision floating-point operations.
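To make that concrete, here is a rough sketch of what those two steps would amount to: an offline pass that collapses each weight matrix to ternary int8 values plus a single scale, and an inference-only module that hands those to a low-bit kernel. Everything below is illustrative (the names and the fallback matmul are mine, not the repo's or the paper's); plain PyTorch has no ternary matmul op, so the forward just dequantizes, which is exactly why a customized kernel is needed for real speedups:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def offline_quantize(w: torch.Tensor):
    # Collapse fp16 weights into ternary int8 values plus one scale,
    # so the original fp16 weights can be discarded after conversion.
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    w_ternary = (w * scale).round().clamp_(-1, 1).to(torch.int8)
    return w_ternary, scale

class BitLinearInfer(nn.Module):
    def __init__(self, w_ternary: torch.Tensor, scale: torch.Tensor, bias=None):
        super().__init__()
        self.register_buffer("w_ternary", w_ternary)  # values in {-1, 0, 1}
        self.register_buffer("scale", scale)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Placeholder for a real low-bit kernel: a dedicated kernel would do
        # the ternary matmul with additions/subtractions only. This fallback
        # dequantizes to fp16 and calls F.linear, so it gains nothing.
        w = self.w_ternary.to(x.dtype) / self.scale
        return F.linear(x, w, self.bias)
```

In other words, each converted linear layer would be swapped for something like BitLinearInfer(*offline_quantize(layer.weight), layer.bias), and the speed and memory benefits only appear once the fallback matmul is replaced by an actual low-bit kernel.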
I don't mean to criticize, in fact I think this codebase and approach overall is the cleanest I've seen for BitNet!
But it is closer to a regular LLaMA architecture. There is meant to be a trade-off of precision for speed by constraining weights to ternary $[-1, 0, 1]$, but this implementation is not actually taking advantage of that.