I'm pretty sure the reason it shows fp32 numbers in the parameters is that during training you need to keep the original floating-point values for backprop (the 1.58-bit quantization destroys the gradient). The weights are then re-quantized in the forward pass on every training step. When you actually deploy the model, I believe you would simply take the quantized weights and use them.
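For illustration, here is a rough sketch of that training-time pattern (my own simplified code, not what this repo does): the optimizer keeps updating fp32 master weights, and a straight-through estimator re-quantizes them on every forward pass.

import torch

def quantize_with_ste(w):
    # Binarize around the mean and rescale by the mean absolute value.
    scale = w.abs().mean()
    w_q = (w - w.mean()).sign() * scale
    # Straight-through estimator: the forward pass sees w_q, but the
    # gradient flows back to the fp32 master weights w unchanged.
    return w + (w_q - w).detach()

At deployment you would export the quantized weights once and drop the fp32 copies.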
The answer from @rolson24 is correct in the sense that the weights are only converted right before the linear operation and the output is then dequantized, just like the original paper describes (see the image below; a short forward-pass sketch also follows after the output). However, digging a little deeper into the code, I found that the weight quantization is not actually done correctly: the sign operation is applied before the scale multiplication. Here's how to test it:
import torch

# Current implementation in the repo:
# https://github.com/kyegomez/BitNet/blob/f56addac025b8bd58a02f9f00eb4cf530770658a/bitnet/bitlinear.py#L20C1-L24C13
def current_weight_quant(w):
    # sign first, then multiply by the scale -> values in {-scale, +scale}
    scale = w.abs().mean()
    e = w.mean()
    u = (w - e).sign() * scale
    return u

def correct_weight_quant(w):
    # multiply by the scale first, then take the sign -> values in {-1, +1}
    scale = w.abs().mean()
    e = w.mean()
    u = torch.sign((w - e) * scale)
    return u

weights = torch.rand(15)
print("Original weights: ", weights)
print("Repo quant fn: ", current_weight_quant(weights))
print("Correct quant fn:", correct_weight_quant(weights))
Output:
Original weights: tensor([0.5857, 0.0053, 0.8400, 0.5586, 0.8302, 0.9758, 0.6332, 0.4917, 0.4092,
0.6722, 0.5738, 0.1896, 0.5210, 0.6124, 0.1334])
Repo quant fn: tensor([ 0.5355, -0.5355, 0.5355, 0.5355, 0.5355, 0.5355, 0.5355, -0.5355,
-0.5355, 0.5355, 0.5355, -0.5355, -0.5355, 0.5355, -0.5355])
Correct quant fn: tensor([ 1., -1., 1., 1., 1., 1., 1., -1., -1., 1., 1., -1., -1., 1.,
-1.])
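For context, here is a minimal sketch of the quantize → linear → rescale flow the paper describes (my own simplification; the real BitLinear also quantizes the activations to 8 bits, which I omit here):

import torch
import torch.nn.functional as F

def bitlinear_forward_sketch(x, w):
    # w has shape (out_features, in_features); x has shape (..., in_features)
    scale = w.abs().mean()
    w_q = (w - w.mean()).sign()   # binary weights in {-1, +1}
    y = F.linear(x, w_q)          # matmul with the quantized weights
    return y * scale              # rescale ("dequantize") the output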
Just created the PR #59 to get it fixed.
What about the zero case, @jmbrito01? The paper mentions the weights can be {-1, 0, 1}, but I only see {-1, 1}.
Edit: OK, my bad. I read an updated version of the article and realized that b1 and b1.58 are different architectures; in this case b1 refers to the BitLinear implementation, and its values are always in {-1, 1}.
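For reference, the b1.58 variant uses an absmean scale with a round-and-clip step, which is where the zeros come from; a minimal sketch (my own, not code from this repo):

import torch

def weight_quant_ternary_sketch(w, eps=1e-5):
    # BitNet b1.58-style: scale by the mean absolute value, then round
    # and clip to the ternary set {-1, 0, 1}.
    scale = w.abs().mean().clamp(min=eps)
    return (w / scale).round().clamp(-1, 1)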
Tip: we should have a unit test suite to avoid this kind of issue and to validate future PRs.
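For example, a minimal pytest-style check along these lines (hypothetical, reusing the function from the snippet above) would have caught the mismatch:

import torch

def correct_weight_quant(w):
    # same function as in the snippet above
    return torch.sign((w - w.mean()) * w.abs().mean())

def test_quantized_weights_are_binary():
    w = torch.rand(128)
    q = correct_weight_quant(w)
    # every entry must be exactly -1 or +1
    assert torch.all((q == 1) | (q == -1))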
Hello, I presume that according to the BitNet paper the weights should be -1 or 1. But:
Output
Am I missing something?