kyegomez / BitNet

Implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in PyTorch
https://discord.gg/qUtxnK2NMf
MIT License

The output of BitLinear is quite abnormal #35

Closed · Jiangxg closed this issue 6 months ago

Jiangxg commented 6 months ago

Describe the bug
I print the mean and variance of the tensor y in example.py, and they are abnormal, as follows:

mean and var of BitLinear output: -0.567935049533844 1149.9969482421875

To make sure, I also print the mean and variance of the outputs from Linear and BitLinear simultaneously.

mean and var of Linear output: 0.012186492793262005 0.33256232738494873
mean and var of BitLinear output: 0.9070871472358704 992.69384765625

I believe there are mistakes in the implementation of BitLinear in bitnet/bitlinear.py.

To Reproduce
Steps to reproduce the behavior:

  1. Print the mean and variance of y in example.py.
  2. Insert output_linear = torch.nn.functional.linear(x, self.weight, self.bias) at bitnet/bitlinear.py line 129, then print the mean and variance of output_linear (see the repro sketch below).
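A minimal repro sketch along these lines, assuming BitLinear is importable as `from bitnet import BitLinear`, takes `(in_features, out_features)`, and exposes an `nn.Linear`-style `weight`/`bias`; the input shape is also an assumption, not the exact one used in example.py:

```python
import torch
from bitnet import BitLinear  # assumed import path for the repo's BitLinear

torch.manual_seed(0)

# Dummy input; batch/sequence/feature sizes are illustrative assumptions.
x = torch.randn(2, 128, 512)
layer = BitLinear(512, 512)

# Output of the quantized BitLinear forward pass
y = layer(x)
print("mean and var of BitLinear output:", y.mean().item(), y.var().item())

# Same weights pushed through a plain float linear for comparison
y_ref = torch.nn.functional.linear(x, layer.weight, getattr(layer, "bias", None))
print("mean and var of Linear output:", y_ref.mean().item(), y_ref.var().item())
```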


suzuke commented 6 months ago

The implementation of this BitLinear is completely wrong: not only does it fail to follow the process outlined in the BitNet paper, it also misunderstands the underlying computational principles. I don't understand why it still receives so many stars.

suzuke commented 6 months ago

Gamma, beta, and alpha are calculated from the weights and input before quantization. These factors are then used for weight binarization and input quantization. The binarized weights and quantized input go through the linear operation to produce the output, which is then dequantized using the previously calculated gamma and beta. It makes no sense to calculate gamma and beta separately for the quantization and dequantization stages, and the grouping implementation here is entirely nonsensical as well.
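A rough sketch of that flow, with alpha/beta taken from the weights (mean and mean absolute value) and gamma from the activations (absmax) as in the BitNet paper; the function name and the per-tensor scaling are illustrative assumptions, not the repo's code:

```python
import torch
import torch.nn.functional as F

def bitlinear_forward(x, weight, bias=None, bits=8, eps=1e-5):
    # 1) Compute the scaling factors from the *unquantized* tensors.
    alpha = weight.mean()            # zero point for weight binarization
    beta = weight.abs().mean()       # per-tensor weight scale (beta)
    gamma = x.abs().max()            # per-tensor activation scale (gamma, absmax)
    Qb = 2 ** (bits - 1)

    # 2) Binarize weights and quantize activations with those factors.
    w_bin = torch.sign(weight - alpha)                             # values in {-1, +1}
    x_q = torch.clamp(x * Qb / (gamma + eps), -Qb + eps, Qb - eps)

    # 3) Low-precision linear, then dequantize with the same beta and gamma.
    y = F.linear(x_q, w_bin) * (beta * gamma / Qb)
    if bias is not None:
        y = y + bias
    return y
```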

2020zyc commented 6 months ago

> Gamma, beta, and alpha are calculated from the weights and input before quantization. These factors are then used for weight binarization and input quantization. The binarized weights and quantized input go through the linear operation to produce the output, which is then dequantized using the previously calculated gamma and beta. It makes no sense to calculate gamma and beta separately for the quantization and dequantization stages, and the grouping implementation here is entirely nonsensical as well.

Hi, I don't understand what you're saying. Could you explain in more detail? The code calculates gamma/beta dynamically during the quantization stage and then uses those two statistics to dequantize the activation; there is no extra calculation of gamma/beta in the dequantization stage. You can of course move that calculation out of the quantization stage, but you still need to compute gamma/beta dynamically.

2020zyc commented 6 months ago

> Gamma, beta, and alpha are calculated from the weights and input before quantization. These factors are then used for weight binarization and input quantization. The binarized weights and quantized input go through the linear operation to produce the output, which is then dequantized using the previously calculated gamma and beta. It makes no sense to calculate gamma and beta separately for the quantization and dequantization stages, and the grouping implementation here is entirely nonsensical as well.

Another implementation is BIT-Transformers. I don't understand how its BitLinear works, especially the forward function: there is no obvious beta/gamma and no dequantization of the output. Can you make sense of this code? Thanks.

(screenshot of the BIT-Transformers BitLinear forward function)
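One way a forward pass like that can skip an explicit dequantization step is to fold the scales back in during quantization and rely on a straight-through estimator, so the tensors entering F.linear are already in the original numeric range. A hypothetical sketch of that pattern (not necessarily what BIT-Transformers actually does):

```python
import torch
import torch.nn.functional as F

def ste(quantized, original):
    # Straight-through estimator: the forward pass sees the quantized value,
    # while gradients flow to `original` as if no quantization happened.
    return original + (quantized - original).detach()

def bitlinear_no_explicit_dequant(x, weight, bits=8, eps=1e-5):
    Qb = 2 ** (bits - 1)

    # Weights: binarize, then immediately rescale by beta so the result
    # stays in the original numeric range.
    beta = weight.abs().mean()
    w_q = ste(torch.sign(weight) * beta, weight)

    # Activations: absmax-quantize, then immediately rescale by gamma / Qb.
    gamma = x.abs().max()
    x_scaled = torch.clamp(x * Qb / (gamma + eps), -Qb + eps, Qb - eps)
    x_q = ste(x_scaled * gamma / Qb, x)

    # Because the scales were folded back in above, no separate
    # dequantization of the output is needed.
    return F.linear(x_q, w_q)
```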

suzuke commented 6 months ago

The issues I mentioned have been addressed in the commit 6cdb2ea998e843b454f2fbaaef73bc6bf92c305f.

Jiangxg commented 6 months ago

> The issues I mentioned have been addressed in the commit 6cdb2ea.

Yes, most of the problems have been addressed. There is still a bug in the grouping implementation; I am working on that.
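For reference, the group-wise variant in the paper computes a separate alpha/beta per group of weights instead of per tensor. A hedged sketch of what that might look like (the grouping axis, function name, and return layout are assumptions):

```python
import torch

def groupwise_binarize(weight, num_groups):
    # Split the rows of the weight matrix into groups and compute one
    # (alpha, beta) pair per group, following the group quantization idea
    # in the BitNet paper.
    out_features, in_features = weight.shape
    assert out_features % num_groups == 0
    rows_per_group = out_features // num_groups

    w = weight.view(num_groups, rows_per_group, in_features)
    alpha = w.mean(dim=(1, 2), keepdim=True)        # per-group zero point
    beta = w.abs().mean(dim=(1, 2), keepdim=True)   # per-group weight scale
    w_bin = torch.sign(w - alpha)                   # values in {-1, +1}

    # Expand beta to one value per output row so it can be applied to the
    # matching slice of the output during dequantization.
    beta_per_row = beta.view(num_groups, 1).repeat_interleave(rows_per_group, dim=0)
    return w_bin.view(out_features, in_features), beta_per_row.squeeze(-1)
```

The per-row beta (together with the activation gamma) would then scale the corresponding slice of the F.linear output during dequantization.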