kyegomez / BitNet

Implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in pytorch

Expected BitLinear weight to be 1 or -1 #52

Closed sanjeev-bhandari closed 1 month ago

sanjeev-bhandari commented 6 months ago

Hello, according to the BitNet paper I would presume the weights should be -1 or 1. But:

import torch
from bitnet import BitLinearNew

# Create a random input tensor of shape (2, 10, 10)
x = torch.randn(2, 10, 10)

# Create an instance of the BitLinearNew class with input size 10 and output size 20
layer = BitLinearNew(
    10,
    20,
)

# Perform a forward pass through the BitLinearNew layer with input x
output = layer(x)

print(layer.weight.dtype)
print(layer.weight)

Output

torch.float32
Parameter containing:
tensor([[ 0.1634,  0.2419, -0.0605,  0.1592,  0.2348, -0.1431, -0.1634,  0.0171,
         -0.1672, -0.1526],
        [-0.0848,  0.0079, -0.2014, -0.0492,  0.2833,  0.1290, -0.2156, -0.1515,
         -0.0473, -0.0839],
        [ 0.2230,  0.1434, -0.1410, -0.0626,  0.1189, -0.1652, -0.2978, -0.0287,
          0.1025,  0.2458],
        [-0.1670, -0.0222, -0.0272, -0.2312,  0.1880, -0.2040, -0.0305,  0.1009,
         -0.2247,  0.0124],
        [ 0.1351, -0.2926,  0.1891, -0.1614,  0.2894, -0.2931,  0.0802,  0.2884,
          0.0454, -0.1398],
        [-0.2954,  0.2651, -0.0062, -0.1592,  0.2138, -0.2038,  0.2965, -0.2545,
          0.0505, -0.0811],
        [-0.3062, -0.1191, -0.1521,  0.1021, -0.1865, -0.1102,  0.2120, -0.2865,
          0.1754,  0.1763],
        [ 0.1375, -0.2975,  0.0399, -0.1723, -0.0526, -0.2694,  0.1838, -0.1826,
          0.2806, -0.1438],
        [-0.3150,  0.2163,  0.1946, -0.0244,  0.0657, -0.1531, -0.0310,  0.0071,
          0.2590,  0.0985],
        [ 0.0402,  0.0704, -0.1441, -0.1929, -0.2450,  0.2408, -0.0750,  0.0238,
          0.3030,  0.0516],
        [ 0.1537, -0.2231, -0.0092, -0.1068,  0.3131,  0.0859, -0.1692, -0.2364,
          0.2257,  0.2601],
        [-0.0478, -0.2978, -0.2025, -0.2411, -0.3061, -0.2566,  0.0564, -0.0906,
          0.2113,  0.3118],
        [-0.1048,  0.2073, -0.2126, -0.1883,  0.0463, -0.1716, -0.3052,  0.0548,
          0.2079,  0.2587],
        [-0.1387,  0.1778, -0.1886,  0.1239,  0.0265, -0.0421, -0.1020,  0.2481,
         -0.0840,  0.1879],
        [ 0.0707, -0.0534,  0.0623,  0.0803,  0.3135,  0.2192, -0.1202,  0.3139,
          0.0781, -0.0512],
        [ 0.2812,  0.2515, -0.0371,  0.0248,  0.0231, -0.0437,  0.0875,  0.3085,
         -0.0482, -0.0092],
        [ 0.1735,  0.2584, -0.0900, -0.1616,  0.1253,  0.1352,  0.1841,  0.1416,
         -0.0686, -0.0269],
        [-0.3121, -0.1050,  0.0265,  0.0242,  0.1973,  0.1816, -0.0084,  0.2866,
          0.2559, -0.2523],
        [ 0.1272, -0.2361,  0.0847, -0.0724,  0.2531,  0.0948, -0.0765, -0.1252,
         -0.0459, -0.0133],
        [-0.0660,  0.0650,  0.2529, -0.1763, -0.1248, -0.1073, -0.2926,  0.1837,
          0.1265, -0.0090]], requires_grad=True)

Am I missing something?


rolson24 commented 6 months ago

I'm pretty sure it shows fp32 numbers in the parameters because, during training, you need to keep the original floating-point values for backprop (1.58-bit quantization destroys the gradient). On the forward pass, the weights are re-quantized at every training step. I believe that when you actually deploy the model, you would simply take the quantized weights and use them.
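
Roughly, the standard trick for this is a straight-through estimator: quantize on the forward pass but let the gradient flow through to the latent fp32 weights. A minimal sketch (the function name weight_quant_ste is just illustrative, not an API of this repo):

import torch

def weight_quant_ste(w: torch.Tensor) -> torch.Tensor:
    # Binarize around the mean and rescale by the mean absolute value,
    # mirroring the BitLinear weight quantization discussed in this issue.
    scale = w.abs().mean()
    w_q = (w - w.mean()).sign() * scale
    # Forward pass uses w_q; backward treats the quantization as identity,
    # so gradients update the latent fp32 weights stored in layer.weight.
    return w + (w_q - w).detach()

So layer.weight always holds the fp32 latent weights; the binarized values only exist transiently inside the forward pass.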

jmbrito01 commented 4 months ago

The answer from @rolson24 is correct in the sense that the weights are only quantized right before the linear operation and then dequantized afterwards, just like the original paper describes (see image below). However, digging a little deeper into the code, I found that the weight quantization is not actually done correctly: the sign operation is applied before the multiplication by the scale, so the resulting values are ±scale rather than ±1. Here's how to test it:

import torch
# https://github.com/kyegomez/BitNet/blob/f56addac025b8bd58a02f9f00eb4cf530770658a/bitnet/bitlinear.py#L20C1-L24C13
def current_weight_quant(w):
    scale = w.abs().mean()
    e = w.mean()
    u = (w - e).sign() * scale
    return u

def correct_weight_quant(w):
    scale = w.abs().mean()
    e = w.mean()
    u = torch.sign((w - e) * scale)
    return u

weights = torch.rand(15)
print("Original weights: ", weights)
print("Repo quant fn: ", current_weight_quant(weights))
print("Correct quant fn:", correct_weight_quant(weights))

Output:

Original weights:  tensor([0.5857, 0.0053, 0.8400, 0.5586, 0.8302, 0.9758, 0.6332, 0.4917, 0.4092,
        0.6722, 0.5738, 0.1896, 0.5210, 0.6124, 0.1334])
Repo quant fn:  tensor([ 0.5355, -0.5355,  0.5355,  0.5355,  0.5355,  0.5355,  0.5355, -0.5355,
        -0.5355,  0.5355,  0.5355, -0.5355, -0.5355,  0.5355, -0.5355])
Correct quant fn: tensor([ 1., -1.,  1.,  1.,  1.,  1.,  1., -1., -1.,  1.,  1., -1., -1.,  1.,
        -1.])

Just created PR #59 to get this fixed.

[image: weight quantization formula from the BitNet paper]
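
For reference, the weight binarization in the paper (writing it from memory, so double-check against the paper itself) is

$$\tilde{W} = \operatorname{Sign}(W - \alpha), \qquad \alpha = \frac{1}{nm} \sum_{i,j} W_{ij}, \qquad \beta = \frac{1}{nm} \lVert W \rVert_1,$$

so the binarized weights are strictly ±1, and the scale β is applied to the layer output after the matmul rather than being folded into the weights themselves.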

izavala97 commented 3 months ago

What about the zero case, @jmbrito01? The paper mentions the weights can be {-1, 0, 1}, but I only see {-1, 1}.

Edit: OK, my bad. I read an updated version of the article and realized that b1 and b1.58 are different architectures; in this case b1 refers to the BitLinear implementation, and its values are always in {-1, 1}.

Tip: We should have a unit test section to avoid this kind of issue and to validate future PRs.
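
For example, a minimal pytest-style check could look like this (a sketch only; a real test would import the repo's actual weight-quantization function instead of redefining it here):

import torch

def weight_quant(w: torch.Tensor) -> torch.Tensor:
    # Mirrors the fixed behavior from PR #59: binarize around the mean.
    return torch.sign(w - w.mean())

def test_weights_are_binary():
    w = torch.randn(20, 10)
    u = weight_quant(w)
    # Every quantized weight should be exactly +1 or -1
    # (ignoring the measure-zero case of a weight equal to the mean).
    assert torch.all((u == 1) | (u == -1))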

github-actions[bot] commented 1 month ago

Stale issue message