microsoft / BitBLAS

BitBLAS is a library to support mixed-precision matrix multiplications, especially for quantized LLM deployment.

accuracy and performance of bfloat16 with bitblas linear #161

Open AbedKhateeb2 opened 2 weeks ago

AbedKhateeb2 commented 2 weeks ago

I tried to run the bfloat16 Linear of BitBLAS, but I got a different result from the torch layer.

output:

quantizing /decoder/block/0/layer/0/SelfAttention/k
torch linear took by avg 7.802581787109375e-05
BitBLAS Operator found in global_operator_cache.
bitblas linear took to init : 1.157283067703247 sec
bitblas linear took by avg 7.474946975708008e-05
torch compare : tensor(2.2344, device='cuda:0', dtype=torch.bfloat16)

The linear layer comes from a pretrained model that was trained with bf16.

Environment: CUDA 12.1, GPU A10G, Ubuntu, bitblas==0.0.1.dev15

import time
import torch
from bitblas import Linear as BitBLASLinear

# name and linear_layer (a torch.nn.Linear in bfloat16 on CUDA) are defined
# by the surrounding model-conversion loop.
print(f"quantizing {name}")
in_features = linear_layer.in_features
out_features = linear_layer.out_features

opt_M = 1

class Custom(BitBLASLinear):
    # Cast the BitBLAS output back to bfloat16 so it matches the torch layer's dtype.
    def forward(self, A):
        out = super().forward(A)
        out = out.to(torch.bfloat16)
        return out

input_tensor = torch.rand(opt_M, in_features).to(torch.bfloat16).cuda()

# Warm up the torch layer for ~1 s before timing.
st = time.time()
while time.time() - st < 1.0:
    linear_layer(input_tensor)

times = 1000
with torch.no_grad():
    start_time = time.time()
    for _ in range(times):
        output_torch = linear_layer(input_tensor)
    end_time = time.time()
print(f"torch linear took by avg {(end_time-start_time)/times}")

start_time = time.time()
# bitblas_linear = Int8Linear(linear_module=linear_torch)
# BitBLASLinear.STORAGE_DTYPE = 'bfloat16'
bitblas_linear = Custom(
    linear_layer.in_features,
    linear_layer.out_features,
    bias=linear_layer.bias is not None,
    opt_M=opt_M,
    accum_dtype='float32',
    A_dtype='bfloat16',
    W_dtype='bfloat16',
)
bitblas_linear.load_and_transform_weight(linear_layer.weight.clone())
if linear_layer.bias is not None:
    bitblas_linear.bias.data = linear_layer.bias.data.clone()

# Warm up the BitBLAS layer; the reported "init" time below covers layer
# construction, weight transform, and this ~1 s warm-up loop.
st = time.time()
while time.time() - st < 1.0:
    bitblas_linear(input_tensor)
end_time = time.time()
print(f"bitblas linear took to init : {(end_time-start_time)} sec")
bitblas_linear.cuda()

with torch.no_grad():
    start_time = time.time()
    for _ in range(times):
        output_bitblas = bitblas_linear(input_tensor)
    end_time = time.time()
print(f"bitblas linear took by avg {(end_time-start_time)/times}")

print("torch compare : ", torch.mean(torch.abs(output_torch.to(torch.bfloat16) - output_bitblas.to(torch.bfloat16))))
LeiWang1999 commented 2 weeks ago

Hi @AbedKhateeb2, a bfloat16-related test can be found at https://github.com/microsoft/BitBLAS/blob/main/testing/python/operators/test_general_matmul_bf16.py

Would you mind providing a simple unit test to reproduce this? I cannot access the layer you mentioned in your problem.
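
A minimal self-contained sketch of such a unit test, built only on the Linear API already used in the snippet above (the checks in the linked test file may differ), could look like this:

import torch
from bitblas import Linear as BitBLASLinear

# Compare a bfloat16 BitBLAS Linear against torch.nn.functional.linear on random data.
torch.manual_seed(0)
in_features, out_features, opt_M = 1024, 1024, 1

weight = torch.rand(out_features, in_features, dtype=torch.bfloat16).cuda()
x = torch.rand(opt_M, in_features, dtype=torch.bfloat16).cuda()

bitblas_linear = BitBLASLinear(
    in_features, out_features, bias=False, opt_M=opt_M,
    A_dtype="bfloat16", W_dtype="bfloat16", accum_dtype="float32",
)
bitblas_linear.load_and_transform_weight(weight.clone())
bitblas_linear.cuda()

with torch.no_grad():
    ref = torch.nn.functional.linear(x, weight)
    out = bitblas_linear(x)

print("max abs diff:", (ref.float() - out.float()).abs().max().item())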

AbedKhateeb2 commented 2 weeks ago

Thank you @LeiWang1999 for your response 😃 Here is a standalone script.

Environment: torch 2.4.0, torchaudio 2.4.0, torchvision 0.19.0, bitblas 0.0.1.dev15, Python 3.10.14
nvcc: Cuda compilation tools, release 12.1, V12.1.105 (Build cuda_12.1.r12.1/compiler.32688072_0)

import time
from bitblas import Linear as BitBLASLinear
import torch
import torch.nn as nn
import os
import torchvision.models as models
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# Load a pre-trained VGG-16 model
vgg16 = models.vgg16(pretrained=True)

# Take the 4096x4096 linear layer from the classifier (classifier[3]) and drop its bias
linear_layer = vgg16.classifier[3].to(torch.bfloat16).cuda()
linear_layer.bias = None
print(f"quantizing {linear_layer}")
in_features = linear_layer.in_features
out_features = linear_layer.out_features

opt_M = 1

input_tensor = torch.rand(opt_M, in_features).to(torch.bfloat16).cuda()
st = time.time()
while time.time() - st < 1.0:
    linear_layer(input_tensor)
times = 1000
with torch.no_grad():
    start_time = time.time()
    for _ in range(times):
        output_torch = linear_layer(input_tensor).to(torch.bfloat16)
    end_time = time.time()
print(f"torch linear took by avg {(end_time-start_time)/times}")

start_time = time.time()
bitblas_linear = BitBLASLinear(linear_layer.in_features, linear_layer.out_features, bias=linear_layer.bias is not None, opt_M=opt_M, accum_dtype='float32', A_dtype='bfloat16', W_dtype='bfloat16')
bitblas_linear.load_and_transform_weight(linear_layer.weight.clone())
if linear_layer.bias is not None:
    bitblas_linear.bias.data = linear_layer.bias.data.clone()

# Warm up the BitBLAS layer; the reported "init" time below covers layer
# construction, weight transform, and this ~1 s warm-up loop.
st = time.time()
while time.time() - st < 1.0:
    bitblas_linear(input_tensor)
end_time = time.time()
print(f"bitblas linear took to init : {(end_time-start_time)} sec")
bitblas_linear.cuda()

with torch.no_grad():
    start_time = time.time()
    for _ in range(times):
        output_bitblas = bitblas_linear(input_tensor)
    end_time = time.time()
print(f"bitblas linear took by avg {(end_time-start_time)/times}")

print("torch compare : ", torch.mean(torch.abs(output_torch.to(torch.bfloat16)-output_bitblas.to(torch.bfloat16))))

The result:

quantizing Linear(in_features=4096, out_features=4096, bias=False)
torch linear took by avg 7.706689834594727e-05
2024-08-29 17:20:22 [BitBLAS:WARNING]: [BitBLAS][Warning] with_zeros is not supported for int source format as int has a constant zeropoints already.
2024-08-29 17:20:23 [BitBLAS:WARNING]: [BitBLAS][Warning] with_zeros is not supported for int source format as int has a constant zeropoints already.
2024-08-29 17:20:25 [BitBLAS:WARNING]: [BitBLAS][Warning] with_zeros is not supported for int source format as int has a constant zeropoints already.
BitBLAS Operator found in global_operator_cache.
bitblas linear took to init : 10.60917353630066 sec
bitblas linear took by avg 9.263944625854492e-05
torch compare : tensor(2.0469, device='cuda:0', dtype=torch.bfloat16)
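
One caveat about the timing numbers above: the loops measure wall-clock time without synchronizing with the GPU, so the per-call averages may mostly reflect kernel-launch overhead rather than actual kernel time. A sketch of a synchronized measurement (an optional refinement, not required to reproduce the accuracy issue):

import time
import torch

def timed(fn, iters=1000):
    # Warm up, then synchronize so pending GPU work does not leak into the timed window.
    for _ in range(10):
        fn()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - start) / iters

# Hypothetical usage with the layers from the script above:
# print("torch  :", timed(lambda: linear_layer(input_tensor)))
# print("bitblas:", timed(lambda: bitblas_linear(input_tensor)))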

LeiWang1999 commented 2 weeks ago

https://github.com/microsoft/BitBLAS/blob/872d6d71b2c6caee294544b1364132f413be5262/bitblas/module/__init__.py#L267-L270

There does indeed exist a bug in BitBLASLinear that causes any datatype to be cast to float16.

LeiWang1999 commented 2 weeks ago

Hi @AbedKhateeb2, take a look at PR #164.

You can check out this fix by installing the upstream BitBLAS with the command pip install git+https://github.com/microsoft/BitBLAS.git
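
After installing from main, a quick sanity check (reusing output_torch and output_bitblas from the standalone script above; the tolerance here is an assumption, chosen loosely for bfloat16 precision) would be:

# Expect the difference to drop from ~2.0 to bfloat16 rounding error after the fix.
diff = (output_torch.float() - output_bitblas.float()).abs()
print("mean abs diff:", diff.mean().item(), " max abs diff:", diff.max().item())
torch.testing.assert_close(output_bitblas.float(), output_torch.float(), rtol=1e-2, atol=1e-2)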