ROCm / bitsandbytes

8-bit CUDA functions for PyTorch
MIT License

bitsandbytes loading models in 8-bit doesn't work on Radeon GPUs (gfx1100) #41

Open farshadghodsian opened 2 months ago

farshadghodsian commented 2 months ago

System Info

Platform: Ubuntu 22.04
Python Version: 3.10
Hardware: Threadripper Pro 5975wx, WRX80 motherboard, 128GB RAM, 1x Radeon Pro W7900 GPU
Python Libraries:

accelerate==0.31.0
fastapi==0.111.0
gradio==4.37.2
torch==2.5.0.dev20240705+rocm6.1
transformers==4.37.2

bitsandbytes version: built and installed using the instructions in the rocm_enabled branch

Reproduction

I was able to load models in 4-bit using ROCm/bitsandbytes and the transformers library. To get it working I had to tell PyTorch not to use hipBLASLt, as Radeon GPUs do not support it, by setting the environment variable TORCH_BLAS_PREFER_HIPBLASLT=0. This is an upstream PyTorch issue and is not related to bitsandbytes.

While this works for models loaded in 4-bit, inference fails when I do the same with a model loaded in 8-bit. Running the following code produces an error even with TORCH_BLAS_PREFER_HIPBLASLT=0 set (see the error output below):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_name = "facebook/opt-350m"
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Example input
input_text = "Hello, world!"

# Prepare the input for the model
inputs = tokenizer.encode_plus(
    input_text,
    add_special_tokens=True,
    max_length=512,
    return_attention_mask=True,
    return_tensors='pt'
)

# Generate text using the model
generated_text = model_8bit.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'])

# Decode the generated text using the tokenizer
decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)

# Print the generated output
print(decoded_text)

Error:

rocblaslt warning: No paths matched /home/amd-user/.local/lib/python3.10/site-packages/torch/lib/hipblaslt/library/*gfx1100*co. Make sure that HIPBLASLT_TENSILE_LIBPATH is set correctly.
A: torch.Size([5, 512]), B: torch.Size([1024, 512]), C: (5, 1024); (lda, ldb, ldc): (c_int(5), c_int(1024), c_int(5)); (m, n, k): (c_int(5), c_int(1024), c_int(512))
error detected
Traceback (most recent call last):
  File "/home/amd-user/Documents/bitsandbytes 8-bit test.py", line 23, in <module>
    generated_text = model_8bit.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'])
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 1479, in generate
    return self.greedy_search(
  File "/home/amd-user/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 2340, in greedy_search
    outputs = self(
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 1145, in forward
    outputs = self.model.decoder(
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 863, in forward
    inputs_embeds = self.project_in(inputs_embeds)
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 801, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/home/amd-user/.local/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 559, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/amd-user/.local/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 398, in forward
    out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
  File "/home/amd-user/.local/lib/python3.10/site-packages/bitsandbytes/functional.py", line 2388, in igemmlt
    raise Exception("cublasLt ran into an error!")
Exception: cublasLt ran into an error!

Expected behavior

The above code works fine when I load the model in 4-bit, changing only quantization_config = BitsAndBytesConfig(load_in_8bit=True) to quantization_config = BitsAndBytesConfig(load_in_4bit=True). The issue also occurs with Hugging Face's accelerate library when using load_in_8bit=True.
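For anyone who wants to reproduce the working 4-bit path, here is a minimal sketch of how the environment variable and the quantization switch could be combined in one script. The helpers quantization_kwargs and load_quantized are hypothetical names introduced only for illustration. Setting the variable before torch is imported is an assumption made to be safe, since the report does not say when PyTorch reads it.

```python
import os

# Value taken from the report above. Assumption: set it before torch is
# imported so it is already in place whenever PyTorch reads it.
os.environ["TORCH_BLAS_PREFER_HIPBLASLT"] = "0"


def quantization_kwargs(bits: int) -> dict:
    """Hypothetical helper: map a bit width to the BitsAndBytesConfig keyword."""
    if bits == 4:
        return {"load_in_4bit": True}
    if bits == 8:
        return {"load_in_8bit": True}
    raise ValueError(f"unsupported bit width: {bits}")


def load_quantized(model_name: str, bits: int = 4):
    """Hypothetical helper: load a model and tokenizer with bitsandbytes quantization."""
    # Imported lazily so the environment variable is set before torch loads.
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    config = BitsAndBytesConfig(**quantization_kwargs(bits))
    model = AutoModelForCausalLM.from_pretrained(
        model_name, quantization_config=config
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


# Example usage (needs a ROCm build of bitsandbytes and a supported GPU):
#   model, tokenizer = load_quantized("facebook/opt-350m", bits=4)
#   inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
#   output_ids = model.generate(**inputs)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Calling load_quantized with bits=8 selects the configuration that fails at inference in this report, so the same script can be used to compare both paths.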

I don't use 8-bit much, so this is not a priority for me, but I wanted to report the issue so that others are aware. Loading 4-bit models with bitsandbytes works without issues.