kyegomez / BitNet

Implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in pytorch

Question about weight quantization methodology memory savings #25

Closed nnethercott closed 6 months ago

nnethercott commented 8 months ago

Thanks for your quick implementation! I was reading through bitnet/bitbnet_b158.py and just had a short question.

In your implementation of quantize_weights you follow the same procedure outlined in the paper "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits", but it looks like the quantized weights are stored in float32, while the activation quantization is explicitly cast to int8. I could be missing something, but how are you saving memory (beyond the 8-bit activations, same as in the paper) when the quantized weights are kept as float32?
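
For context, the activation side I'm referring to is roughly the usual absmax int8 recipe; the snippet below is my own paraphrase for illustration (the helper name and the exact clipping constant are my assumptions, not the repo's code):

    import torch

    def quantize_activations_absmax(x, eps=1e-5):
        # Sketch only: scale per tensor by the absolute max so values land in
        # [-127, 127], then round and cast to int8. The cast to int8 is where
        # the activation memory actually shrinks.
        Q_b = 127
        scale = Q_b / (x.abs().max() + eps)
        x_q = (x * scale).round().clamp_(-Q_b, Q_b).to(torch.int8)
        return x_q, scale  # keep the scale around to dequantize later

The weight path quoted below, by contrast, never leaves float32: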

    def quantize_weights(self, W):
        """
        Quantizes the weights using the absmean quantization function.

        Args:
            W (Tensor): The weight tensor to be quantized.

        Returns:
            Tensor: Quantized weight tensor.
        """
        gamma = torch.mean(torch.abs(W)) + self.eps
        W_scaled = W / gamma
        W_quantized = torch.sign(W_scaled) * torch.clamp(
            torch.abs(W_scaled).round(), max=1.0  # torch.float32
        )
        return W_quantized
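
If it helps clarify the question: to actually realize savings on the weight side, the {-1, 0, +1} values would need a sub-float32 container, e.g. packed four per byte with the absmean scale gamma kept around for dequantization. The sketch below is purely illustrative; the packing layout and helper names are my own assumptions, not anything from this repo:

    import torch

    def quantize_and_pack_weights(W, eps=1e-5):
        # Same absmean quantization as above, but the ternary codes are packed
        # four per byte (2 bits each), so weight storage drops ~16x vs float32.
        gamma = W.abs().mean() + eps
        W_q = (W / gamma).round().clamp_(-1, 1).to(torch.int8)  # values in {-1, 0, 1}

        codes = (W_q + 1).to(torch.uint8).flatten()             # shift to {0, 1, 2}
        pad = (-codes.numel()) % 4                              # pad to a multiple of 4
        if pad:
            codes = torch.cat([codes, codes.new_zeros(pad)])
        codes = codes.view(-1, 4)
        packed = (codes[:, 0]
                  | (codes[:, 1] << 2)
                  | (codes[:, 2] << 4)
                  | (codes[:, 3] << 6))
        return packed, gamma, W.shape

    def unpack_weights(packed, gamma, shape):
        # Recover float weights for the matmul (or hand the 2-bit codes to a custom kernel).
        codes = torch.stack([(packed >> s) & 0b11 for s in (0, 2, 4, 6)], dim=1).flatten()
        W_q = codes[: torch.Size(shape).numel()].to(torch.float32) - 1.0
        return W_q.view(shape) * gamma

Without something like this (or at minimum an int8 cast), the ternary values still occupy 4 bytes each, which is what prompted the question.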


github-actions[bot] commented 6 months ago

Stale issue message