bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.
https://huggingface.co/docs/bitsandbytes/main/en/index
MIT License

Enable certain CUDA kernels to accept specified cuda stream #1330

Closed: jeejeelee closed this 3 months ago

jeejeelee commented 3 months ago

FIX https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308

Passing a specified CUDA stream to certain kernel functions lets cudagraph capture those kernels correctly. This enables the downstream repo vLLM to run inference in cudagraph mode, resulting in significant speed improvements for BNB models.

ping @matthewdouglas @Titus-von-Koeller @TimDettmers cc @chenqianfzh
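For readers unfamiliar with the mechanism, here is a minimal sketch of CUDA graph capture in PyTorch. It is not code from this PR. It assumes PyTorch with a CUDA device and returns `None` otherwise; `capture_and_replay` is a hypothetical name. The point it illustrates: inside `torch.cuda.graph(...)` the current stream is the capture stream, so a kernel launched through a C API would have to be given that stream's raw handle to be recorded.

```python
import ctypes

try:
    import torch
except ImportError:  # keep the sketch importable without PyTorch
    torch = None


def capture_and_replay():
    """Capture `y = x * 2` into a CUDA graph, then replay it on new input."""
    if torch is None or not torch.cuda.is_available():
        return None  # nothing to demonstrate without a GPU

    x = torch.ones(4, device="cuda")
    y = torch.zeros(4, device="cuda")

    # Warm up on a side stream before capturing.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        torch.mul(x, 2, out=y)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # Inside this block the current stream is the capture stream.
        # Illustration only: this is the raw handle a ctypes-invoked
        # kernel would need to receive in order to be captured.
        capture_stream = ctypes.c_void_p(torch.cuda.current_stream().cuda_stream)
        torch.mul(x, 2, out=y)  # recorded into the graph, not run eagerly

    x.fill_(3.0)           # update the input in place
    graph.replay()         # re-launch the recorded kernels
    torch.cuda.synchronize()
    return y.tolist()


result = capture_and_replay()
```

On a machine with a GPU, `result` holds the doubled values of the updated input; a kernel launched on a different stream during capture would not be part of the replay.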

Titus-von-Koeller commented 3 months ago

Dear @jeejeelee,

Really cool, we weren't aware vLLM uses cudagraph. Just looked over this with Tim and overall, especially given the performance benefits this may have, this is a very strong contribution, thanks!

I checked out your branch and tried running the tests, but I get the segfault below, which doesn't happen on main. Rerunning the tests gives the same result. Could you please look into this? Can you reproduce it on your machine? I have a quad L4 setup with CC 8.9, CUDA 12.4, and PyTorch 2.4.

tests/test_autograd.py::test_matmul_fp8[matmul_fp8_mixed-fp16-transpose=FT-req_grad=TTT-dim4=61-dim3=59-dim2=43-dim1=17] Fatal Python error: Segmentation fault

Thread 0x00007deaa34006c0 (most recent call first):
<no Python frame>

Thread 0x00007dea95e006c0 (most recent call first):
<no Python frame>

Thread 0x00007deaa2a006c0 (most recent call first):
<no Python frame>

Thread 0x00007deaa3e006c0 (most recent call first):
<no Python frame>

Current thread 0x00007dec9d834740 (most recent call first):
  File "/home/ubuntu/src/bnb/bitsandbytes/functional.py", line 1535 in dequantize_no_absmax
  File "/home/ubuntu/src/bnb/bitsandbytes/functional.py", line 1476 in dequantize
  File "/home/ubuntu/src/bnb/bitsandbytes/research/autograd/_functions.py", line 42 in forward
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/torch/autograd/function.py", line 574 in apply
  File "/home/ubuntu/src/bnb/bitsandbytes/research/autograd/_functions.py", line 407 in matmul_fp8_mixed
  File "/home/ubuntu/src/bnb/tests/test_autograd.py", line 456 in test_matmul_fp8
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/python.py", line 159 in pytest_pyfunc_call
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/python.py", line 1627 in runtest
  File "/home/ubuntu/src/bnb/tests/conftest.py", line 9 in pytest_runtest_call
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/runner.py", line 242 in <lambda>
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/runner.py", line 341 in from_call
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/runner.py", line 241 in call_and_report
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/runner.py", line 132 in runtestprotocol
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/runner.py", line 113 in pytest_runtest_protocol
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/main.py", line 362 in pytest_runtestloop
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/main.py", line 337 in _main
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/main.py", line 283 in wrap_session
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/main.py", line 330 in pytest_cmdline_main
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/config/__init__.py", line 175 in main
  File "/home/ubuntu/.condax/mamba/envs/bnb/lib/python3.8/site-packages/_pytest/config/__init__.py", line 201 in console_main
  File "/home/ubuntu/.condax/mamba/envs/bnb/bin/pytest", line 10 in <module>
Segmentation fault (core dumped)

Please also be sure to install the pre-commit hooks 🤗

jeejeelee commented 3 months ago

@Titus-von-Koeller, thank you for the feedback. I've corrected the error mentioned above and am now verifying whether all the unit tests pass.

jeejeelee commented 3 months ago

On my machine with a 3090 GPU, my test results are as follows:

=========================================================== 13 failed, 3264 passed, 35 skipped, 1061 warnings, 16 errors in 875.66s (0:14:35) ==========================================================

All tests in test_generation.py failed due to a network connection error. All tests in test_triton.py failed because my local triton version is 3.0.0. The other errors are likely due to precision issues. I'm not certain whether these are caused by this PR.

jeejeelee commented 3 months ago

@Titus-von-Koeller please review again, thanks~

matthewdouglas commented 3 months ago

@danielhanchen I believe you're directly calling some of these C-API functions in Unsloth, so I want to make sure you've got a heads up here since this changes their signatures.

matthewdouglas commented 3 months ago

@jeejeelee Thank you for the contribution! The only nit I have is the one that I noted about using c_void_p instead of uint64.
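To illustrate the nit, here is a small ctypes sketch. It is not the PR's code: `stream_arg` is a hypothetical helper and the handle value is made up. The idea is that a CUDA stream handle is a pointer on the C side, so `c_void_p` states that intent and always matches the platform's pointer width, whereas `c_uint64` is a fixed 8-byte integer.

```python
import ctypes


def stream_arg(raw_handle: int) -> ctypes.c_void_p:
    """Wrap an integer stream handle as a pointer-typed ctypes argument.

    Hypothetical helper for illustration; not part of bitsandbytes.
    """
    return ctypes.c_void_p(raw_handle)


handle = 0x1000  # made-up address standing in for a real stream handle

as_ptr = stream_arg(handle)
as_u64 = ctypes.c_uint64(handle)

# Both carry the same numeric value, but only c_void_p is pointer-typed.
same_value = as_ptr.value == as_u64.value

# c_void_p tracks the native pointer size on this platform.
pointer_sized = ctypes.sizeof(ctypes.c_void_p) == ctypes.sizeof(ctypes.c_size_t)

# A null handle reads back as None rather than 0.
null_value = stream_arg(0).value
```

With PyTorch, the integer would typically come from `torch.cuda.current_stream().cuda_stream`.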

A few test failures in test_kbit_backprop and test_gemv_4bit are OK and not related to this PR. I see similar results on my 4090. The generation tests passed for me. Looks nice!

danielhanchen commented 3 months ago

> @danielhanchen I believe you're directly calling some of these C-API functions in Unsloth, so I want to make sure you've got a heads up here since this changes their signatures.

Super thanks for the heads up!! Yep we use the C API directly!

Titus-von-Koeller commented 3 months ago

I'll be off until Monday, @matthewdouglas will be taking the lead. Thanks both!