bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.
https://huggingface.co/docs/bitsandbytes/main/en/index
MIT License

Linear8bitLt cannot be moved back to CPU #1332

Open Nerogar opened 3 months ago

Nerogar commented 3 months ago

System Info

During inference of larger models in VRAM-constrained environments, offloading unused model layers from VRAM to RAM is an easy way to reduce overall VRAM usage. However, Linear8bitLt layers cannot be moved back to CPU memory once forward() has been called on them. There are two issues:

  1. The internal state object is ignored by the to("cpu") call, so all of its tensors (e.g. state.CxB) remain in VRAM.
  2. Moving the layer back to CUDA runs the quantization logic again, which breaks the model (see output 2 below).
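The first point matches how nn.Module.to() behaves in general: it only converts tensors registered on the module tree as parameters or buffers (it recurses through submodules via the internal _apply method), so tensors kept on an arbitrary Python object, such as a state attribute, are left where they are. Presumably this is what happens to Linear8bitLt's state. A minimal CPU-only sketch of that general PyTorch behavior; ToyLayer and ToyState are hypothetical names, not bitsandbytes classes, and a dtype conversion stands in for a device move so it runs without a GPU:

```python
import torch
from torch import nn


class ToyState:
    """Hypothetical stand-in for a state object holding tensors as plain attributes."""

    def __init__(self, tensor: torch.Tensor):
        self.cached = tensor


class ToyLayer(nn.Module):
    """Hypothetical layer: one registered buffer, one tensor hidden inside a plain object."""

    def __init__(self):
        super().__init__()
        # Registered buffers are visited by .to() (and .cuda(), .double(), ...).
        self.register_buffer("registered", torch.zeros(4))
        # A tensor stored on an ordinary Python object is invisible to .to().
        self.state = ToyState(torch.zeros(4))


layer = ToyLayer()
# A dtype conversion goes through the same module traversal as a device move,
# so it shows the effect without needing CUDA.
layer = layer.to(torch.float64)

print(layer.registered.dtype)    # torch.float64
print(layer.state.cached.dtype)  # torch.float32 (untouched)
```

Moving such a state along with the module would need either registering its tensors as buffers or an explicit override that also converts the state.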

Reproduction

import torch
import bitsandbytes as bnb

if __name__ == '__main__':
    # initialize a simple model
    generator = torch.Generator()
    generator.manual_seed(42)

    linear = bnb.nn.Linear8bitLt(
        input_features=32,
        output_features=32,
        bias=False,
        has_fp16_weights=False,
    )
    x_in = torch.randn(size=(1, 32), generator=generator)
    # in-place initializer; torch.nn.init.xavier_uniform (no underscore) is deprecated
    torch.nn.init.xavier_uniform_(linear.weight, generator=generator)

    # move everything to CUDA for the first time
    linear.to("cuda")
    x_in = x_in.to("cuda")

    # call the model once to get a "good" result
    x_out_1 = linear(x_in)
    print(f"output 1: {x_out_1}")

    # move the model to cpu and observe that some tensors are still stored in VRAM
    linear.to("cpu")
    print(f"CxB device after cpu offloading: {linear.state.CxB.device}")
    linear.to("cuda")

    # call the model again after moving it to CUDA
    x_out_2 = linear(x_in)
    print(f"output 2: {x_out_2}")

Output:

output 1: tensor([[ 0.0978, -0.4744, 0.0976, -1.7158, 0.3936, 0.7334, -0.6406, 0.5264, 1.7373, -1.0938, -1.0625, -0.3091, 0.9946, 2.1582, 0.4675, 1.2090, -0.7349, -0.2979, 0.6055, 0.1614, 1.0742, -2.6758, -2.2266, 0.7310, -1.5635, 0.1646, -0.0526, 0.4590, 1.0068, -1.6650, 0.5469, 0.1232]], device='cuda:0')

CxB device after cpu offloading: cuda:0

output 2: tensor([[ 41.0000, -198.3750, 40.5000, -716.0000, 173.3750, 314.5000, -267.0000, 219.5000, 731.0000, -459.5000, -444.7500, -134.5000, 429.0000, 908.0000, 199.3750, 527.5000, -306.7500, -130.5000, 256.7500, 68.0625, 447.5000, -1117.0000, -941.0000, 305.7500, -726.0000, 69.8750, -22.7344, 195.1250, 440.0000, -694.0000, 241.1250, 51.9062]], device='cuda:0')
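Each value in output 2 is roughly 415 to 440 times the corresponding value in output 1 (e.g. 41.0 / 0.0978 ≈ 419, 731.0 / 1.7373 ≈ 421). One plausible explanation, consistent with point 2 above, is that the already-quantized int8 weights are quantized a second time when the layer returns to CUDA. With row-wise absmax quantization the int8 codes have an absmax of 127, so a second pass keeps the codes unchanged and sets the scale to 127; the effective weight then becomes the raw int8 code, i.e. the original weight times 127 / absmax(row). For this 32x32 xavier_uniform init, absmax is at most sqrt(6/64) ≈ 0.306, which gives a factor of about 415 or more. A pure-Python sketch of that arithmetic, assuming the simplified quantize/dequantize formulas below (they are illustrative, not the bitsandbytes kernels):

```python
def quantize_row(row):
    """Row-wise absmax int8 quantization: returns (int codes, scale)."""
    absmax = max(abs(v) for v in row)
    codes = [round(v * 127 / absmax) for v in row]
    return codes, absmax


def dequantize_row(codes, scale):
    """Inverse mapping used at matmul time: codes * scale / 127."""
    return [c * scale / 127 for c in codes]


# Made-up weight row in the range of a 32x32 xavier_uniform init (about +/-0.306).
row = [0.30, -0.12, 0.05, -0.25]

codes, scale = quantize_row(row)
print(dequantize_row(codes, scale))  # close to the original row

# Assumed bug scenario: the int8 codes are quantized again as if they were weights.
codes2, scale2 = quantize_row([float(c) for c in codes])
effective = dequantize_row(codes2, scale2)
print(effective)                     # the int8 codes themselves, as floats
print(effective[0] / row[0])         # 127 / 0.30 ≈ 423.3
```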

Expected behavior

  1. to("cpu") should move all tensors of a model, including the internal state, to CPU memory.
  2. Moving a model between devices should not change its outputs.

Linear4bit already implements this behavior; I would expect Linear8bitLt to behave the same way.
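The two expected behaviors can be written down as a small regression check. This is a sketch, assuming CUDA is available for real use; tensors_off_device and check_offload_roundtrip are hypothetical helpers, and looking for tensors on an attribute named state is an assumption based on linear.state in the reproduction:

```python
import torch
from torch import nn


def tensors_off_device(module: nn.Module, device_type: str) -> list:
    """Names of parameters, buffers, and tensor attributes of a submodule's
    `state` object (if it has one) that are NOT on `device_type`."""
    stray = []
    for name, t in list(module.named_parameters()) + list(module.named_buffers()):
        if t.device.type != device_type:
            stray.append(name)
    for mod_name, sub in module.named_modules():
        state = getattr(sub, "state", None)  # assumption: state lives on `.state`
        if state is None or not hasattr(state, "__dict__"):
            continue
        prefix = f"{mod_name}." if mod_name else ""
        for attr, value in vars(state).items():
            if isinstance(value, torch.Tensor) and value.device.type != device_type:
                stray.append(f"{prefix}state.{attr}")
    return stray


@torch.no_grad()
def check_offload_roundtrip(module, x, compute_device="cuda", offload_device="cpu"):
    """Expectation 1: nothing stays behind after .to(offload_device).
    Expectation 2: outputs are unchanged after moving back."""
    module.to(compute_device)
    x = x.to(compute_device)
    before = module(x)
    module.to(offload_device)
    stray = tensors_off_device(module, torch.device(offload_device).type)
    module.to(compute_device)
    after = module(x)
    return stray, torch.allclose(before, after)
```

With the reproduction's layer this would be called as check_offload_roundtrip(linear, x_in); going by the outputs reported above, one would expect a non-empty stray list (presumably including state.CxB) and False. On a plain torch.nn.Linear both checks pass.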

Arcitec commented 4 weeks ago

@matthewdouglas This is definitely a serious bug which impacts me too. Any idea if or when it's fixable?

O-J1 commented 4 weeks ago

Agreed, this impacts me as well.