Thanks for the quick implementation! I was reading through bitnet/bitbnet_b158.py and had a short question.
In quantize_weights you use the same absmean procedure outlined in the paper "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits", but it looks like the quantized weights are kept in float32, while the activation quantization is explicitly cast to int8. I could be missing something, but how are you saving memory (beyond the 8-bit activations, just as in the paper) when the quantized weights are still stored as float32?
```python
def quantize_weights(self, W):
    """
    Quantizes the weights using the absmean quantization function.

    Args:
        W (Tensor): The weight tensor to be quantized.

    Returns:
        Tensor: Quantized weight tensor.
    """
    gamma = torch.mean(torch.abs(W)) + self.eps
    W_scaled = W / gamma
    W_quantized = torch.sign(W_scaled) * torch.clamp(
        torch.abs(W_scaled).round(), max=1.0  # torch.float32
    )
    return W_quantized
```
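For context, here is a minimal sketch of what I had in mind (the function name, the eps default, and the standalone signature are mine, not from your code): since the values are ternary after rounding, they could be cast to int8 and the per-tensor scale gamma kept separately, so each stored weight takes 1 byte instead of 4.

```python
import torch

def quantize_weights_int8(W: torch.Tensor, eps: float = 1e-5):
    """Absmean ternary quantization, but returning int8 values plus the scale.

    This is just an illustrative sketch, not a drop-in replacement for the
    module method in bitnet/bitbnet_b158.py.
    """
    gamma = torch.mean(torch.abs(W)) + eps
    W_scaled = W / gamma
    # Same {-1, 0, +1} values as in your implementation...
    W_ternary = torch.sign(W_scaled) * torch.clamp(
        torch.abs(W_scaled).round(), max=1.0
    )
    # ...but stored as int8 (1 byte per weight instead of 4 for float32).
    return W_ternary.to(torch.int8), gamma

# Example usage: dequantize on the fly for the matmul.
W = torch.randn(1024, 1024)
W_q, gamma = quantize_weights_int8(W)
W_dequant = W_q.to(torch.float32) * gamma
```

Packing four ternary values into a single byte (2 bits each) would shrink it further, but even the plain int8 cast would already give roughly a 4x reduction over float32 storage. Is something like this planned, or is the float32 storage intentional for now?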