huggingface / optimum-quanto

A pytorch quantization backend for optimum
Apache License 2.0

Pixart sigma example crashes on CUDA arch >= 80 with int4 weights #248

Open dacorvo opened 1 month ago

dacorvo commented 1 month ago

When running the pixart sigma example on CUDA arch >= 80 with int4 weights, the following error happens:

  File "/home/ubuntu/dev/quanto/optimum/quanto/tensor/qtensor_func.py", line 152, in linear
    return QTensorLinear.apply(input, other, bias)
  File "/home/ubuntu/dev/quanto/.venv/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ubuntu/dev/quanto/optimum/quanto/tensor/qtensor_func.py", line 130, in forward
    output = output + bias
RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

This means one of the optimized kernels is crashing (either AWQ or tinygemm, as those are only triggered when CUDA arch >= 80).
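
For anyone trying to reproduce: a quick way to check whether a GPU is even eligible for the optimized int4 path is to look at the compute capability reported by torch. This is just a triage sketch, the actual kernel selection happens inside quanto:

import torch

# The optimized fp16/int4 kernels (AWQ gemm / tinygemm) are only candidates on
# devices with compute capability >= 8.0 (Ampere and newer).
major, minor = torch.cuda.get_device_capability()
print(f"compute capability: {major}.{minor}")
print("optimized int4 kernels eligible:", (major, minor) >= (8, 0))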

dacorvo commented 1 month ago

The error happens in the final projection of the transformer. Excluding it from quantization solves the issue, but it is only a workaround.

    if qtype:
        quantize(pipe.transformer, weights=qtype, exclude='proj_out')
        freeze(pipe.transformer)
        quantize(pipe.text_encoder, weights=qtype)
        freeze(pipe.text_encoder)
sayakpaul commented 1 month ago

It actually fails on our science cluster too.

nvidia-smi:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:A8:00.0 Off |                    0 |
| N/A   34C    P0              67W / 700W |      2MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

I did include the exclude argument:

diff --git a/examples/vision/text-to-image/quantize_pixart_sigma.py b/examples/vision/text-to-image/quantize_pixart_sigma.py
index 6d3e2b2..0064f7a 100644
--- a/examples/vision/text-to-image/quantize_pixart_sigma.py
+++ b/examples/vision/text-to-image/quantize_pixart_sigma.py
@@ -22,7 +22,7 @@ def load_pipeline(model_id, torch_dtype, qtype=None, device="cpu"):
     pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True).to(device)

     if qtype:
-        quantize(pipe.transformer, weights=qtype)
+        quantize(pipe.transformer, weights=qtype, exclude="proj_out")
         freeze(pipe.transformer)
         quantize(pipe.text_encoder, weights=qtype)
         freeze(pipe.text_encoder)

But this time the error is different:

Traceback (most recent call last):
  File "/fsx/sayak/optimum-quanto/examples/vision/text-to-image/quantize_pixart_sigma.py", line 78, in <module>
    image = pipeline(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py", line 834, in __call__
    noise_pred = self.transformer(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/pixart_transformer_2d.py", line 270, in forward
    timestep, embedded_timestep = self.adaln_single(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/normalization.py", line 169, in forward
    embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/embeddings.py", line 1175, in forward
    timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/embeddings.py", line 472, in forward
    sample = self.act(sample)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 396, in forward
    return F.silu(input, inplace=self.inplace)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/functional.py", line 2102, in silu
    return torch._C._nn.silu(input)
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
dacorvo commented 1 month ago

I made progress on this: the issue comes from the fp16/int4 AWQ kernels, which crash when the output features are fewer than 128. @sayakpaul you can try this branch: https://github.com/huggingface/optimum-quanto/tree/disable_awq_gemm_out_feats_less_128
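
For context, the branch essentially guards the AWQ gemm dispatch on the output dimension. A minimal sketch of the idea (the helper below is hypothetical, not the actual code in the branch):

# Hypothetical sketch: skip the fp16/int4 AWQ gemm kernel when the layer has
# fewer than 128 output features, and fall back to the generic
# dequantize + matmul path instead.
def can_use_awq_gemm(out_features: int) -> bool:  # hypothetical helper
    return out_features >= 128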

sayakpaul commented 1 month ago

@dacorvo the error still persists on the cluster.

dacorvo commented 1 month ago

I don't have access to an H100. I can disable the kernel for Hopper architectures entirely, but it would be slow.

The issue seems to happen in an AdaLayerNormSingle layer. If by chance these all use the same naming convention, maybe you can exclude them too (you can pass a list of patterns, like exclude=['proj_out', '*adaln_single']).
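
In the example script that would look something like this (a sketch; whether '*adaln_single' matches the nested Linear names depends on how the patterns are resolved against module names, so a trailing wildcard as below may be needed):

    if qtype:
        # proj_out is the known-bad final projection; the adaln pattern is an
        # attempt to also skip the AdaLayerNormSingle linears by name.
        quantize(pipe.transformer, weights=qtype, exclude=["proj_out", "*adaln_single*"])
        freeze(pipe.transformer)
        quantize(pipe.text_encoder, weights=qtype)
        freeze(pipe.text_encoder)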

If that fixes the issue, then we can take a closer look at the shape of the Linear weights in that layer.

sayakpaul commented 1 month ago

Still the same issue. The error occurs when it hits nn.SiLU(), though.

sayakpaul commented 1 month ago

Tried with "fal/AuraFlow" and SiLU fails:

Traceback (most recent call last):
  File "/fsx/sayak/optimum-quanto/examples/vision/text-to-image/quantize_pixart_sigma.py", line 78, in <module>
    image = pipeline(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py", line 555, in __call__
    noise_pred = self.transformer(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/auraflow_transformer_2d.py", line 339, in forward
    temb = self.time_step_proj(temb)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/embeddings.py", line 472, in forward
    sample = self.act(sample)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 396, in forward
    return F.silu(input, inplace=self.inplace)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/functional.py", line 2102, in silu
    return torch._C._nn.silu(input)
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
dacorvo commented 1 month ago

> Still the same issue. The error occurs when it hits nn.SiLU(), though.

Yes, but this is generally a deferred error, and the culprit is likely the quantized Linear just before that.
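
To pinpoint the real offending op rather than the next kernel launch, the suggestion in the error message itself is the way to go: force synchronous kernel launches. A minimal sketch (the variable must be set before torch initializes CUDA):

import os

# Make CUDA kernel launches synchronous so the Python traceback points at the
# op that actually faulted instead of a later call like F.silu.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # noqa: E402  (imported after setting the env var on purpose)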

sayakpaul commented 1 month ago

I excluded time_step_proj but it's still the same. I am wondering if this is a separate issue from the AWQ kernel crash.

dacorvo commented 1 month ago

As a first step, it would be great to run the optimum-quanto unit tests on an H100.
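
Something like this from the repo root would be a starting point (a sketch; the -k filter is a guess at the relevant test names, adjust as needed):

import pytest

# Run only the quantized linear tests first, stopping at the first failure.
pytest.main(["-x", "-k", "qlinear"])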

sayakpaul commented 1 month ago

> Yes, but this is generally a deferred error, and the culprit is likely the quantized Linear just before that.

Understood. But that makes me wonder why this would happen when we are excluding the layer from quantization. So,

> As a first step, it would be great to run the optimum-quanto unit tests on an H100.

Makes perfect sense. Let's wait for the infra team to reply then.

github-actions[bot] commented 3 weeks ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.