microsoft / BitBLAS

BitBLAS is a library to support mixed-precision matrix multiplications, especially for quantized LLM deployment.
MIT License
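
For orientation, "mixed-precision" here means pairing high-precision activations with low-bit integer weights. A rough usage sketch based on the project's quickstart (argument names such as A_dtype and W_dtype follow the README and may differ between versions):

import torch
import bitblas

# Sketch: float16 activations against int4 weights.
config = bitblas.MatmulConfig(
    M=1, N=1024, K=1024,
    A_dtype="float16", W_dtype="int4",
    accum_dtype="float16", out_dtype="float16",
    layout="nt", with_bias=False,
)
matmul = bitblas.Matmul(config=config)

weight = torch.randint(0, 7, (1024, 1024), dtype=torch.int8).cuda()
activation = torch.rand(1, 1024, dtype=torch.float16).cuda()
output = matmul(activation, matmul.transform_weight(weight))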

Perplexity evaluation too high for 1bitLLM/bitnet_b1_58-3B #47

Closed MekkCyber closed 3 weeks ago

MekkCyber commented 3 weeks ago

Hello Everyone,

I am trying to evaluate the perplexity of 1bitLLM/bitnet_b1_58-3B using the script available in integration/BitNet. However, I am getting a very high loss and perplexity. Is this normal?

avg_loss = 14.603133460411653: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 174/174 [02:04<00:00,  1.40it/s]
wikitext2 PPL: 24887.495944644015
[23059.264828634798, 24887.495944644015]
Avg PPL: 23973.38038663941
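
For context on the scale of these numbers: perplexity is conventionally the exponential of the average per-token negative log-likelihood, so a loss only a few nats above normal inflates the PPL by orders of magnitude. A minimal illustration of that relationship (generic formula, not the actual integration/BitNet evaluation code; the loss values are illustrative):

import math

# PPL = exp(mean negative log-likelihood per token).
# ~2.3 nats is a plausible healthy wikitext2 loss for a 3B model (illustrative),
# while the reported PPL of ~24887 corresponds to log(24887) ≈ 10.1 nats.
for loss in (2.3, 10.1):
    print(f"loss={loss:>4} -> PPL={math.exp(loss):.1f}")
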
LeiWang1999 commented 3 weeks ago

Thanks for your report @MekkCyber! That's unusual; let me check it out and get back to you, thanks.

igrekun commented 3 weeks ago

A recent update might've broken something. Here are the results comparing the simulated and the BitBLAS matmul:

t_input = torch.randn(1, 64).to("cuda")
t_linear = utils_quant.BitLinear(64, 128, bias=False).to("cuda")
with torch.no_grad():
    n_output = t_linear.native_forward(t_input)
    s_output = t_linear.forward_fp32_simulated(t_input)
    t_linear.post_process_weights()
    q_output = t_linear.forward(t_input)
print("s_diff", (n_output - s_output).abs().mean(), "q_diff", (n_output - q_output).abs().mean())

s_diff tensor(6.7442e-05, device='cuda:0') q_diff tensor(0.4397, device='cuda:0')

LeiWang1999 commented 3 weeks ago

Thank you for your reports. Previously we merged a pull request that makes int1 represent the range [-1, 1] and int2 represent [-2, 1]. It seems this has introduced a bug in the BitNet integration. I'm currently working on other items, but I will address this issue in the coming days.
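
To illustrate why such a convention change can silently break an integration, here is a hypothetical decode sketch (not BitBLAS's actual packing code): under the new convention a 2-bit code carries a -2 offset, so weights packed under a different assumption come back shifted by a constant.

import torch

# Hypothetical illustration: a 2-bit unsigned code q in {0, 1, 2, 3}.
q = torch.tensor([0, 1, 2, 3])
decoded_new = q - 2   # new convention: int2 covers [-2, 1]
decoded_old = q       # e.g. if the integration still assumed an unsigned [0, 3] mapping (hypothetical)
print("new:", decoded_new.tolist(), "old:", decoded_old.tolist())

# BitNet's ternary weights {-1, 0, 1} fit inside [-2, 1]; packing with one
# convention and unpacking with the other shifts every weight by a constant,
# which is consistent with the large q_diff reported above.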

LeiWang1999 commented 3 weeks ago

Hi @MekkCyber and @igrekun, the pipeline and correctness tests pass on my main branch; could you please provide your BitBLAS version and device details?

igrekun commented 3 weeks ago

BitBLAS version: bitblas==0.0.1.dev7
Devices tested: NVIDIA T4, NVIDIA A100
CUDA version: 12.1

I ran the snippet above that computes the mean difference. Should I test the main branch and build from source?

LeiWang1999 commented 3 weeks ago

A recent update might've broken something. Here are the results comparing the simulated and the BitBLAS matmul:

t_input = torch.randn(1, 64).to("cuda")
t_linear = utils_quant.BitLinear(64, 128, bias=False).to("cuda")
with torch.no_grad():
    n_output = t_linear.native_forward(t_input)
    s_output = t_linear.forward_fp32_simulated(t_input)
    t_linear.post_process_weights()
    q_output = t_linear.forward(t_input)
print("s_diff", (n_output - s_output).abs().mean(), "q_diff", (n_output - q_output).abs().mean())

s_diff tensor(6.7442e-05, device='cuda:0') q_diff tensor(0.4397, device='cuda:0')

Would you mind providing the whole script to reproduce this result? :)

igrekun commented 3 weeks ago

Sure, put this anywhere in the integration/BitNet folder:

import torch
from utils_quant import BitLinear

t_input = torch.randn(1, 64).to("cuda")
t_linear = BitLinear(64, 128, bias=False).to("cuda")
with torch.no_grad():
    n_output = t_linear.native_forward(t_input)
    s_output = t_linear.forward_fp32_simulated(t_input)
    t_linear.post_process_weights()
    q_output = t_linear.forward(t_input)
print("s_diff", (n_output - s_output).abs().mean(), "q_diff", (n_output - q_output).abs().mean())

LeiWang1999 commented 3 weeks ago

Thanks @igrekun, PR https://github.com/microsoft/BitBLAS/pull/49 fixed this issue; you can apply it by installing 0.0.1.dev8:

pip install bitblas==0.0.1.dev8

igrekun commented 3 weeks ago

Validated. Works perfectly on the A100. It gives different results on the T4 (compute capability 7.5), though that card is not listed as supported, so I guess we can close this.

LeiWang1999 commented 3 weeks ago

Hi @igrekun, what are the results on the T4?

I think that's because the T4 doesn't have INT8 Tensor Cores, which might lead to compilation issues. I think we can fall back to CUDA cores in a future release.
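
As a quick triage aid, the compute capability of the card under test can be read with plain PyTorch (standard calls, nothing BitBLAS-specific):

import torch

# Print the name and CUDA compute capability of the current device,
# e.g. a Tesla T4 reports (7, 5) and an A100 reports (8, 0).
major, minor = torch.cuda.get_device_capability()
print(torch.cuda.get_device_name(0), f"sm_{major}{minor}")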

igrekun commented 3 weeks ago

That's most likely the cause; the logs don't show any grid tuning results.

2024-06-05 13:25:08 [BitBLAS:INFO]: Start fast tuning with dynamic range
BitBLAS Tuning done, appended operator to global_operator_cache.

LeiWang1999 commented 3 weeks ago

Thanks, closed :)

TeaPoly commented 1 week ago

GPU: NVIDIA RTX A4000
bitblas==0.0.1.dev8

Using the same example:

import torch
from utils_quant import BitLinear

t_input = torch.randn(1, 64).to("cuda")
t_linear = BitLinear(64, 128, bias=False).to("cuda")
with torch.no_grad():
    n_output = t_linear.native_forward(t_input)
    s_output = t_linear.forward_fp32_simulated(t_input)
    t_linear.post_process_weights()
    q_output = t_linear.forward(t_input)
print("s_diff", (n_output - s_output).abs().mean(), "q_diff", (n_output - q_output).abs().mean())

Result is here:

2024-06-17 14:28:03 [BitBLAS:INFO]: Database path /home/huanglk/.cache/bitblas does not exist, skipping loading operators from the database
2024-06-17 14:28:06 [BitBLAS:INFO]: Start fast tuning with dynamic range
BitBLAS Tuning done, appended operator to global_operator_cache.
s_diff tensor(6.4458e-05, device='cuda:0') q_diff tensor(nan, device='cuda:0')

LeiWang1999 commented 1 week ago

Hi @TeaPoly, would you mind updating the BitBLAS version to 0.0.1.dev12 and checking the script again?

TeaPoly commented 1 week ago

Hi @TeaPoly, would you mind updating the BitBLAS version to 0.0.1.dev12 and checking the script again?

I ran pip install bitblas==0.0.1.dev12 and rm -r ~/.cache/bitblas.

The result is the same.

When I print more detail, q_output contains inf values.

import torch
from utils_quant import BitLinear

torch.random.manual_seed(42)
t_input = torch.randn(1, 64).to("cuda")
t_linear = BitLinear(64, 128, bias=False).to("cuda")
with torch.no_grad():
    n_output = t_linear.native_forward(t_input)
    s_output = t_linear.forward_fp32_simulated(t_input)
    t_linear.post_process_weights()
    q_output = t_linear.forward(t_input)

print("n_output:", n_output)
print("q_output:", q_output)
# "s_diff", (n_output - s_output).abs().mean(),
print("q_diff", (n_output - q_output).abs().mean())
n_output: tensor([[-0.5695, -0.9579, -0.4282,  0.5234, -0.4041,  0.0607,  0.1152,  0.1612,
         -0.1382,  0.3141, -0.2858, -0.0890,  0.0345,  0.4156, -0.7401,  0.5517,
         -0.4030,  0.2094,  0.4009, -0.4093, -0.2984,  0.4910, -0.3507,  0.2638,
         -0.5444, -0.9558, -0.0147,  0.0858, -0.5590,  0.5276, -0.5507,  0.1371,
         -0.4973,  0.2837, -0.0764, -0.0722, -0.1905,  0.2418,  0.4187, -0.5789,
         -0.2408, -0.4397,  0.0241, -0.7349,  0.3748, -0.1696, -0.2104,  0.4167,
          0.1413, -0.4627,  0.1015, -0.4386, -0.6637, -0.4659, -0.5349,  0.6899,
          0.0167, -0.1068,  0.1790, -0.3078, -0.7213, -0.5140,  0.2115, -0.2041,
          0.3381,  0.0387,  0.2167, -0.1382,  0.4481, -0.5067,  0.1403, -0.1193,
          0.2827,  0.1036,  0.2690, -0.4208,  0.6899, -0.2293, -0.4114, -0.0911,
         -0.7632, -0.2806, -0.1382, -0.3162,  0.2240, -0.2827, -0.0890, -0.0827,
          0.2502, -0.7851,  0.4292, -0.2167,  0.8804, -0.5119,  0.3423,  0.0858,
         -0.1403, -0.2607, -0.7925,  0.4606,  0.8972, -0.3381,  0.1298,  0.3570,
          0.1120, -0.7004,  0.2387, -0.3078, -0.9076,  0.3172,  0.3643, -0.4093,
          0.2314, -0.2984, -0.1769,  0.7653,  0.3695, -0.5412,  0.0345, -0.7653,
          0.7181,  0.5821,  0.5548, -0.6061,  0.2230, -0.5559, -0.4083, -0.0450]],
       device='cuda:0')
q_output: tensor([[ 1.2036e-01,  9.3201e-02,  5.6519e-02, -1.3196e-01,  4.2908e-02,
         -7.7454e-02, -3.1414e-03, -1.0052e-01, -4.7119e-02,  1.0364e-01,
         -2.4078e-02, -8.7952e-02, -4.6051e-02, -3.4546e-02, -4.8157e-02,
          4.8157e-02,  1.0260e-01, -1.0468e-02, -3.1403e-02,  2.7222e-02,
         -4.7119e-02,  6.8054e-02,  5.0262e-02,  1.0571e-01,  8.0627e-02,
          8.1665e-02,  3.8727e-02,  8.3740e-02, -1.4656e-02,  2.0943e-03,
         -1.5701e-02,  5.3375e-02, -8.6914e-02, -5.4443e-02, -1.3611e-02,
          1.0785e-01,  1.9897e-02, -2.6169e-02,  1.8845e-02, -4.8157e-02,
         -9.7351e-02,  6.2805e-02, -5.5481e-02, -3.7689e-02, -7.9590e-02,
          1.3293e-01, -7.7454e-02, -3.0365e-02, -5.7587e-02, -4.0833e-02,
          5.2338e-03,  3.2440e-02, -3.0365e-02,  7.4341e-02, -5.1300e-02,
         -4.6051e-02, -8.7952e-02,  2.0943e-03, -4.1885e-03,  4.1870e-02,
         -6.2828e-03,  1.1517e-01, -7.4341e-02,  8.6914e-02,        -inf,
                 inf,         inf, -0.0000e+00,        -inf, -0.0000e+00,
                 inf,         inf,        -inf,        -inf,         inf,
                 inf,         inf,         inf,        -inf, -0.0000e+00,
                 inf, -1.3628e-03,        -inf, -0.0000e+00,        -inf,
                -inf,        -inf, -0.0000e+00, -8.7786e-04,         inf,
                -inf,        -inf,         inf,        -inf,         inf,
                -inf, -0.0000e+00,         inf,        -inf,        -inf,
                -inf,         inf,        -inf,        -inf,        -inf,
                 inf, -0.0000e+00, -6.4433e-05, -0.0000e+00,         inf,
                -inf,         inf,        -inf,        -inf,         inf,
                -inf,         nan,         inf,        -inf, -0.0000e+00,
                -inf, -1.8959e-03,        -inf, -0.0000e+00,        -inf,
         -0.0000e+00, -0.0000e+00,        -inf]], device='cuda:0',
       dtype=torch.float16)
q_diff tensor(nan, device='cuda:0')

LeiWang1999 commented 1 week ago

@MekkCyber, utilizing bitblas.set_log_level("DEBUG") can yield more detailed information.

import torch
from utils_quant import BitLinear
import bitblas
bitblas.set_log_level("DEBUG")

t_input = torch.randn(1, 64).to("cuda")
t_linear = BitLinear(64, 128, bias=False).to("cuda")
with torch.no_grad():
    n_output = t_linear.native_forward(t_input)
    s_output = t_linear.forward_fp32_simulated(t_input)
    t_linear.post_process_weights()
    q_output = t_linear.forward(t_input)
print("s_diff", (n_output - s_output).abs().mean(), "q_diff", (n_output - q_output).abs().mean())

The script works fine on my A100 GPU, and I don't know the architecture details of the NVIDIA A4000; does it support INT8 Tensor Cores?

TeaPoly commented 1 week ago

@MekkCyber, utilizing bitblas.set_log_level("DEBUG") can yield more detailed information.

import torch
from utils_quant import BitLinear
import bitblas
bitblas.set_log_level("DEBUG")

t_input = torch.randn(1, 64).to("cuda")
t_linear = BitLinear(64, 128, bias=False).to("cuda")
with torch.no_grad():
    n_output = t_linear.native_forward(t_input)
    s_output = t_linear.forward_fp32_simulated(t_input)
    t_linear.post_process_weights()
    q_output = t_linear.forward(t_input)
print("s_diff", (n_output - s_output).abs().mean(), "q_diff", (n_output - q_output).abs().mean())

The script works fine on my A100 GPU, and I don't know the architecture details of the NVIDIA A4000; does it support INT8 Tensor Cores?

It works fine on an NVIDIA GeForce RTX 3090:

q_diff tensor(5.8702e-05, device='cuda:0')

but the result on an NVIDIA RTX A4000 and an NVIDIA RTX A6000 is wrong:

q_diff tensor(nan, device='cuda:0')

The debug log from the A4000 is here: a4000.log

LeiWang1999 commented 1 week ago

@TeaPoly, thank you for your report. It appears there are some bugs related to the target data structure because TVM lacks an A4000 target tag. Unfortunately, I don't have an A4000 GPU to reproduce the issue. Perhaps setting export TVM_TARGET="cuda" before running python test.py might resolve it.
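
For anyone debugging the missing target tag, TVM exposes its registered target tags, and a generic CUDA target can be constructed explicitly; a small sketch (tvm.target.list_tags() and tvm.target.Target are standard TVM APIs, while the TVM_TARGET override is only the suggestion above, not a documented BitBLAS flag):

import tvm

# List the GPU target tags TVM knows about; if no A4000 tag is registered,
# BitBLAS cannot look up the card's architecture details by name.
tags = tvm.target.list_tags()
print([t for t in tags if "a4000" in t.lower()])  # likely empty on affected versions

# A generic CUDA target can still be constructed explicitly as a fallback.
print(tvm.target.Target("cuda"))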