pytorch / ao

PyTorch native quantization and sparsity for training and inference

Outdated gpt-fast codebase #751

Open YJYJLee opened 2 months ago

YJYJLee commented 2 months ago

Thanks for the great work! I tried to enable AutoQuant on top of the latest gpt-fast repository, since the gpt-fast version that the ao repo provides as an example is outdated.

Here is the diff of enabling AutoQuant on top of the latest gpt-fast codebase.

[Screenshot 2024-08-26: diff enabling AutoQuant in gpt-fast's generate.py]
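(Since the screenshot doesn't survive as text, here is a rough sketch of the kind of patch involved, based on torchao 0.4's documented autoquant API. The exact insertion points in gpt-fast's generate.py and the `manual=True` flow are assumptions, not the literal diff from the screenshot.)

```python
# Rough sketch of the patch (not the literal diff from the screenshot).
# In gpt-fast's generate.py, after loading the model:
import torchao

if quantization == "autoquant":
    # Wrap linear weights so autoquant can record activation shapes.
    # manual=True defers kernel selection to finalize_autoquant().
    model = torchao.autoquant(model, manual=True)

# ... run one or two warm-up generations so shapes are observed ...

if quantization == "autoquant":
    # Pick the fastest quantized kernel per layer based on the
    # recorded shapes (this is the call that fails in the log below).
    model.finalize_autoquant()
```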

But I'm getting an error saying "CUDA generator expects graph capture to be underway, but the current stream is not capturing".

(/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin) yejinlee@a100-st-p4de24xlarge-47:/fsx-checkpoints/yejinlee/gpt-fast$ python generate.py --compile --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth  --prompt "Hello, my name is" --quantization autoquant
uintx feature need torch 2.3+, please upgrade pytorch
Using device=cuda
Loading model ...
Time to load model: 39.35 seconds
activation_shapes: torch.Size([6, 4096]), times_seen: 1
activation_shapes: torch.Size([1, 4096]), times_seen: 199
weight_shape: torch.Size([12288, 4096]), dtype: torch.bfloat16, bias_shape: None
AUTOTUNE mm(6x4096, 4096x12288)
  mm 0.0832 ms 100.0%
  triton_mm_8 0.0840 ms 99.0%
  triton_mm_6 0.0842 ms 98.8%
  triton_mm_4 0.0857 ms 97.1%
  triton_mm_3 0.0861 ms 96.6%
  triton_mm_9 0.0879 ms 94.7%
  triton_mm_5 0.0887 ms 93.8%
  triton_mm_2 0.0944 ms 88.1%
  triton_mm_1 0.0962 ms 86.5%
  triton_mm_0 0.1044 ms 79.7%
SingleProcess AUTOTUNE takes 2.7937 seconds
warning: failed to autoquant AQFloatLinearWeight for shape: (torch.Size([6, 4096]), torch.Size([12288, 4096]), None, torch.bfloat16) due to CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Traceback (most recent call last):
  File "/opt/hpcaas/.mounts/fs-08829104cb559c481/yejinlee/gpt-fast/generate.py", line 480, in <module>
    main(
  File "/opt/hpcaas/.mounts/fs-08829104cb559c481/yejinlee/gpt-fast/generate.py", line 354, in main
    model.finalize_autoquant()
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/autoquant.py", line 620, in finalize_autoquant
    _change_autoquantizable_to_quantized(
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/autoquant.py", line 494, in _change_autoquantizable_to_quantized
    _replace_with_custom_fn_if_matches_filter(
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 187, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 187, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 187, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  [Previous line repeated 1 more time]
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 183, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model)
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 238, in insert_subclass
    getattr(cls, from_float)(lin.weight, **kwargs), requires_grad=False
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/autoquant.py", line 146, in to_quantized
    self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape)
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/autoquant.py", line 94, in tune_autoquant
    act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device)
RuntimeError: CUDA generator expects graph capture to be underway, but the current stream is not capturing.
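
The failing frame is autoquant's benchmarking step, which materializes a random activation tensor to time each candidate kernel. Below is a simplified reconstruction from the traceback (my sketch, not torchao's exact code). My guess at the conflict: the latest gpt-fast compiles its decode step with `mode="reduce-overhead"`, which runs under CUDA graphs, and that seems to leave the CUDA RNG expecting an active graph capture when `torch.randn` is later called eagerly.

```python
import torch

# Simplified reconstruction of the failing frame in
# torchao/quantization/autoquant.py (tune_autoquant): autoquant
# benchmarks each candidate kernel against a freshly sampled
# activation of the recorded shape/dtype.
act_shape, act_dtype = torch.Size([6, 4096]), torch.bfloat16
act_mat = torch.randn(act_shape, dtype=act_dtype, device="cuda")

# torch.randn consults the CUDA RNG; if a prior CUDA-graph capture
# (e.g. from gpt-fast's torch.compile(..., mode="reduce-overhead"))
# left the generator expecting an active capture, this raises:
# "CUDA generator expects graph capture to be underway, but the
#  current stream is not capturing."
```

If that is the cause, one unverified workaround would be to run the autoquant warm-up and finalize_autoquant() before any reduce-overhead-compiled decoding executes.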

Attaching the env info here

torch                    2.3.1+cu121
torchao                  0.4.0+gitc2f44608
torchaudio               2.3.1+cu121
torchvision              0.18.1+cu121

Thanks for the help in advance!

msaroufim commented 2 months ago

I'm actually pretty happy to see this bug report, lol. We were hoping to enable torchao in gpt-fast and delete the broken quantization flows there. We do need a local fork of gpt-fast because we are making model changes, and unfortunately there isn't a good solution for us beyond occasionally syncing with upstream.

So an action item for @HDCharles is to fix the existing code here in AO. But @YJYJLee, would you be open to contributing your patch to gpt-fast directly as well? We're doing a big launch on Sep 21 at the CUDA MODE IRL conference and were hoping to feature an integration with gpt-fast by then. Granted, I would highly recommend you try out the PyTorch nightlies first.

HDCharles commented 2 months ago

I can look at it, but in reality these are updates needed in torchao, not gpt-fast; i.e., torchao's model/generate code is more up to date than gpt-fast's, rather than vice versa.