dhakalnirajan / LLaMA-BitNet

LLaMA-BitNet is a repository dedicated to empowering users to train their own BitNet models built upon the LLaMA 2 model, inspired by the groundbreaking paper 'The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits'.
https://arxiv.org/pdf/2402.17764
MIT License

Inference mode kernel #3

Open wx02shi opened 2 months ago

wx02shi commented 2 months ago

The Training Tips, Code and FAQ specifies that BitLinear has different forward() definitions for training and for inference.

If I understand correctly, convert_to_bitnet() is being used in both scenarios here? While this does produce a working LLM, there are no efficiency gains being made (a sketch of the training-time forward is below for reference).
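For context, the training-time forward published in the Training Tips, Code and FAQ looks roughly like the following minimal sketch (activation quantization omitted for brevity). The class name BitLinearTrain is my own label, not necessarily what convert_to_bitnet() produces in this repo:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def weight_quant(w: torch.Tensor) -> torch.Tensor:
    # Absmean quantization: map weights to ternary {-1, 0, 1}, then rescale
    # so the tensor keeps its original magnitude (still stored in fp16).
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    return (w * scale).round().clamp(-1, 1) / scale


class BitLinearTrain(nn.Linear):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight
        # Straight-through estimator: the forward pass sees ternary values,
        # while gradients flow back to the fp16 master weights.
        w_q = w + (weight_quant(w) - w).detach()
        return F.linear(x, w_q, self.bias)
```

The key point is that the weights here are still fp16 tensors and the matmul is still a full-precision F.linear; the quantization is simulated on the fly rather than exploited by a low-bit kernel.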

As the FAQ states for inference:

  1. The model weights are offline quantized to 1.58 bits.
  2. The standard F.Linear operation is replaced with a customized low-bit kernel.

Without these two steps, all the weights are still fp16 and still go through full-precision floating-point operations.
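To illustrate what I mean, here is a minimal sketch of an inference-mode replacement. The module name and structure are my own illustration, not from this repo: step 1 quantizes the weights offline to ternary int8 plus one per-tensor scale, and step 2 would swap the matmul for a customized low-bit kernel. The fallback forward below only emulates the arithmetic in fp16, so the actual speedup does not appear until a real packed-ternary kernel replaces that call:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BitLinearInference(nn.Module):
    # Hypothetical inference-mode layer: ternary weights stored as int8
    # plus one per-tensor scale, produced by offline quantization (step 1).
    def __init__(self, weight: torch.Tensor, bias=None):
        super().__init__()
        scale = weight.abs().mean().clamp(min=1e-5)
        ternary = (weight / scale).round().clamp(-1, 1).to(torch.int8)
        self.register_buffer("w_ternary", ternary)
        self.register_buffer("scale", scale)
        self.register_buffer("bias", bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Step 2 would replace this with a customized low-bit kernel that
        # only needs additions and sign flips, no fp16 multiplies. This
        # fallback just emulates the math so the module stays runnable.
        w = self.w_ternary.to(x.dtype) * self.scale
        return F.linear(x, w, self.bias)
```

Swapping every converted BitLinear in the checkpoint for such a module at load time would cover step 1; the customized kernel in step 2 is the part that is still missing.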

I don't mean to criticize; in fact, I think this codebase and approach are overall the cleanest I've seen for BitNet!
But it is closer to a regular LLaMA architecture. The ternary constraint to $\{-1, 0, 1\}$ is meant to trade precision for speed, but that advantage is not actually being realized here.

dhakalnirajan commented 1 month ago

If you can implement the solution for the actual ternary operators you mentioned and open a pull request, I can merge it. 😉