huggingface / optimum-quanto

A pytorch quantization backend for optimum
Apache License 2.0

Add support for quantized Conv2d #74

Closed dacorvo closed 8 months ago

dacorvo commented 9 months ago

This layer is required for all computer vision models.

dacorvo commented 9 months ago

A first implementation here: https://github.com/huggingface/quanto/tree/qconv2d The gradient test needs to be fixed (tricky weight gradient calculation), but otherwise it seems to work as expected.

sayakpaul commented 9 months ago

@dacorvo I have got some findings for you (cf. the internal thread).

To set the context first: unlike transformers models, diffusers deals with "pipelines" (e.g., DiffusionPipeline) that can contain multiple nn.Modules.

My setup is our internal audace machine (two 4090s, of which I used only one by setting CUDA_VISIBLE_DEVICES=0). I am on PyTorch nightly with CUDA 12.1. I installed quanto from https://github.com/huggingface/quanto/tree/qconv2d.

I tried applying quantize() to the unet and the vae and then benchmarked the memory and timing. Here's my script.

Timing and memory:

| setting | timing (secs) | memory (GB) |
|---|---|---|
| vanilla fp16 | 4.016 | 8.958 |
| weight-only | 5.769 | 8.952 |
| weight & activations | 14.046 | 16.845 |
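
For reference, here is a minimal sketch of how such timing and peak-memory numbers can be collected on a single GPU; this is an assumed methodology for illustration, not necessarily what the linked script does:

```python
import time

import torch


def benchmark(pipeline, prompt, num_inference_steps=30):
    # Warm-up run so one-time CUDA initialization doesn't skew the timing.
    _ = pipeline(prompt, num_inference_steps=num_inference_steps)

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    _ = pipeline(prompt, num_inference_steps=num_inference_steps)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    return elapsed, peak_gb
```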

For the "weight & activations" setting, I changed the load_pipeline() function in the script like so, because float16 is not supported as said here:

Modified pipeline loading code:

```python
def load_pipeline(do_quantize):
    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        variant="fp16",
    ).to("cuda")
    pipeline.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix").to("cuda")

    if do_quantize:
        quantize(pipeline.unet, weights=torch.int8, activations=torch.int8)
        quantize(pipeline.vae, weights=torch.int8, activations=torch.int8)

    pipeline.set_progress_bar_config(disable=True)
    return pipeline
```

The image is just random noise with "weight & activations". Maybe it needs calibration but I didn't try it out. Here's a comparison:

Visual comparison (images omitted): vanilla fp16 vs. weight-only vs. weight & activations.

Note that I tried disabling activation quantization for the VAE, but that didn't help prevent the issue, either.

Now, I shifted gears to the UNet to test a single model component. Here is my script. Findings are below.

| setting | timing (secs) | memory (GB) |
|---|---|---|
| vanilla fp16 | 0.069 | 5.053 |
| weight-only | 0.170 | 5.094 |
| weight & activations | 5.094 | 10.032 |

Let me know if anything is unclear. Also, let me know if you would like me to run other tests.

dacorvo commented 9 months ago

@sayakpaul Thank you very much for your tests and feedback.

The results are as expected, but you missed two important things in your tests (I will update the README to make it clearer).

1. When quantizing activations, you need to calibrate the model with a few samples to find the correct activation ranges (otherwise [-1, 1] is assumed). This explains the garbage output you currently get. You just need to pass a few samples through the forward pass:

        with calibration():
            model(samples)

2. Weights are dynamically quantized by default (i.e. stored in float and quantized at inference) to allow fine-tuning. If you only want to run inference, you can freeze your model to convert the weights to integers. This will save on-device memory and significantly speed up inference:

        freeze(model)
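
Putting the two points together, a minimal end-to-end sketch on a toy model (the `from quanto import ...` path and the capitalized `Calibration` context manager are assumptions based on the scripts shared later in this thread):

```python
import torch
from quanto import Calibration, freeze, quantize

# Toy stand-in for e.g. pipeline.unet
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
samples = torch.randn(8, 64)

# 1. Quantize weights and activations to int8
quantize(model, weights=torch.int8, activations=torch.int8)

# 2. Calibrate: a few forward passes record the activation ranges
with Calibration(), torch.no_grad():
    model(samples)

# 3. Freeze: store the weights as integers for inference-only use
freeze(model)
```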

I would be happy to find out the results you get after applying these two changes.

sayakpaul commented 9 months ago

Thanks much! Let me get back to you after applying these changes.

sayakpaul commented 9 months ago

Alright. Things seem to be better now (script here):

Timing and memory:

| setting | timing (secs) | memory (GB) |
|---|---|---|
| vanilla fp16 | 4.016 | 8.958 |
| weight-only | 4.614 | 6.729 |
| weight & activations | 11.582 | 13.375 |

Observations / questions:

Coming to the visual quality:

Visual comparison (images omitted): vanilla fp16 vs. weight-only vs. weight & activations.

I think I am not doing the calibration properly, which is why the quality is still degraded. What would you advise?

dacorvo commented 9 months ago

@sayakpaul this is indeed much better.

Regarding the activations:

sayakpaul commented 9 months ago

Using float8 seems to immediately improve the results:

(generated image omitted)

Timing-wise, we're much worse now: 15.500 secs.

I applied more calibration to the earlier setup (non-float8) as well:

(generated image omitted)

With calibration and float8:

(generated image omitted)

I suspect this will be improved with a bit more calibration. WDYT?

Relevant calibration code:

```python
CALIBRATION_PROMPTS = load_dataset("nateraw/parti-prompts", split="train").shuffle(seed=2024).select(range(100))

if args.act_quant:
    chunk_size = 2
    print("Calibrating")
    with Calibration():
        for i in range(0, len(CALIBRATION_PROMPTS), chunk_size):
            _ = pipeline(
                CALIBRATION_PROMPTS[i: i + chunk_size]["Prompt"],
                num_inference_steps=10,
                generator=torch.manual_seed(2024),
            )
    print("Calibration done.")
```

> I suspect the increased timing is due to the dynamic quantization of the activations, which is not compensated by the int8 matmul speedup. Maybe the matrices are too small?

You may be correct. Actually, this is hinted at in https://pytorch.org/blog/accelerating-generative-ai-3/ as well. Let me try out some ideas from there to make sure we're not applying int8 to matrix multiplications that involve small matrices. I will update this issue thread after running a couple of experiments. Will this apply to float8 weight-activation quantization, too? I think yes.

Let me know if you have specific experiments for me to run.

dacorvo commented 9 months ago

This is super interesting! First, it confirms that float8 is probably the best candidate type for quantizing activations to 8-bit, due to its non-linear representation and wider range.

Unfortunately, as you noticed, most hardware has no native float8 support, so it is slower. I might improve this in the future with dedicated kernels.

You can pinpoint which specific parts of a model you want to quantize by passing a list of modules to quantize.

Finally, about calibration: you can try adding more samples, but I think it is more important to pass them as batches, so that they are averaged. If you pass them individually, you give too much importance to the first sample (subsequent samples always contribute only (1 - momentum) to the moving average). If you really need to pass them one by one, you can use a lower momentum in the moving average, but I am not sure it will be better.
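
To illustrate the momentum point, the tracked activation statistic behaves like an exponential moving average; the update rule and numbers below are just an illustration, not quanto's exact implementation:

```python
# With momentum = 0.9, each new observation only contributes (1 - momentum) = 0.1,
# so the very first sample dominates when samples are fed one by one.
momentum = 0.9


def update(running, new_value):
    return momentum * running + (1 - momentum) * new_value


samples = [10.0, 1.0, 1.0, 1.0]

# Fed one by one: the running value stays close to the first sample.
running = samples[0]
for s in samples[1:]:
    running = update(running, s)
print(running)                      # ~7.56

# Fed as one batch: the batch statistic is their average.
print(sum(samples) / len(samples))  # 3.25
```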

sayakpaul commented 9 months ago

> You can pinpoint which specific parts of a model you want to quantize by passing a list of modules to quantize.

Could you provide an example here?

Regarding calibration batches, let’s not forget we’re in the image generation space :D. So, asking a model to generate four 1024x1024 images is just infeasible for most consumer GPUs.

So, it appears to me that quantization remains challenging in the diffusion world.

I will try to selectively quantize some modules with larger matrix shapes, as was done in the PyTorch post (weight-only quant), and share my findings here.

dacorvo commented 9 months ago

Example of selecting a specific module:

    quantize(model, modules=[model.lm_head])
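
Building on that, one possible way to restrict quantization to the larger linear layers of a model (a sketch; the helper and the size threshold are hypothetical, assuming `modules=` accepts a list of submodules as in the example above):

```python
import torch
from quanto import quantize


def large_linears(model, min_features=1280):
    # Hypothetical helper: keep only the bigger nn.Linear submodules,
    # following the idea of skipping small matmuls that don't benefit from int8.
    return [
        m
        for m in model.modules()
        if isinstance(m, torch.nn.Linear) and min(m.in_features, m.out_features) >= min_features
    ]


# e.g. quantize(pipeline.unet, weights=torch.int8, modules=large_linears(pipeline.unet))
```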
sayakpaul commented 9 months ago

I enabled fusion of the QKV projection matrices to increase the matrix sizes and see whether the quantization speed-up becomes more evident. However, I am running into:

Traceback (most recent call last):
  File "/home/sayak/brrr_quanto_diffusers.py", line 72, in <module>
    _ = pipeline(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 1216, in __call__
    noise_pred = self.unet(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 1121, in forward
    sample, res_samples = downsample_block(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/unets/unet_2d_blocks.py", line 1199, in forward
    hidden_states = attn(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/transformers/transformer_2d.py", line 391, in forward
    hidden_states = block(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/attention.py", line 329, in forward
    attn_output = self.attn1(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/attention_processor.py", line 512, in forward
    return self.processor(
  File "/home/sayak/diffusers/src/diffusers/models/attention_processor.py", line 1335, in __call__
    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
  File "/home/sayak/quanto/quanto/calibrate.py", line 49, in __torch_function__
    output = func(*args, **kwargs)
  File "/home/sayak/quanto/quanto/tensor/core.py", line 291, in __torch_function__
    return func(*args, **kwargs)
  File "/home/sayak/quanto/quanto/tensor/core.py", line 309, in __torch_dispatch__
    return qdispatch.qop(*args, **kwargs)
  File "/home/sayak/quanto/quanto/tensor/ops.py", line 317, in view
    return qfallback(op, input, *shape)
  File "/home/sayak/quanto/quanto/tensor/core.py", line 160, in qfallback
    args, kwargs = pytree.tree_map_only(QTensor, lambda x: x.dequantize(), (args, kwargs or {}))
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/utils/_pytree.py", line 858, in tree_map_only
    return tree_map(map_only(__type_or_types)(func), tree)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/utils/_pytree.py", line 732, in tree_map
    return treespec.unflatten(map(func, *flat_args))
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/utils/_pytree.py", line 599, in unflatten
    leaves = list(leaves)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/utils/_pytree.py", line 818, in wrapped
    return func(x)
  File "/home/sayak/quanto/quanto/tensor/core.py", line 160, in <lambda>
    args, kwargs = pytree.tree_map_only(QTensor, lambda x: x.dequantize(), (args, kwargs or {}))
  File "/home/sayak/quanto/quanto/tensor/core.py", line 259, in dequantize
    return Dequantizer.apply(self)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/sayak/quanto/quanto/tensor/core.py", line 202, in forward
    return (t._scale.to(torch.float32) * t._data).to(t._scale.dtype)
RuntimeError: The size of tensor a (1920) must match the size of tensor b (640) at non-singleton dimension 2

Is this expected?

dacorvo commented 9 months ago

Not really: the QKV projections were quantized per-axis/channel with vector scales of shape (640), corresponding to their last dim, but were later concatenated into a tensor whose last dim is (1920) = (640) x 3 (I think). This should also have triggered a concatenation of the scales, but it didn't, because the corresponding dispatched op for quantized tensors has not been updated since older versions of quanto and still assumes scalar scales. https://github.com/huggingface/quanto/blob/5c2a4114eab4dc37208fbc9fff35143453a2017e/quanto/tensor/ops.py#L89 Sorry you hit that bug ...
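
To see the shape mismatch with plain tensors (just an illustration of the per-axis bookkeeping, not quanto's internal layout):

```python
import torch

B, T, C = 2, 16, 640

# Per-axis (last-dim) quantization: each of Q, K, V carries int8 data of shape (B, T, 640)
# plus its own per-channel scale of shape (640,).
q, q_scale = torch.randint(-128, 128, (B, T, C), dtype=torch.int8), torch.rand(C)
k, k_scale = torch.randint(-128, 128, (B, T, C), dtype=torch.int8), torch.rand(C)
v, v_scale = torch.randint(-128, 128, (B, T, C), dtype=torch.int8), torch.rand(C)

# Fusing QKV concatenates the data along the last dim -> (B, T, 1920) ...
qkv = torch.cat([q, k, v], dim=-1)
# ... so the scales must be concatenated too -> (1920,).
qkv_scale = torch.cat([q_scale, k_scale, v_scale])

# Multiplying against a stale (640,) scale reproduces the reported error:
# "The size of tensor a (1920) must match the size of tensor b (640)".
dequantized = qkv.float() * qkv_scale  # fine, because the scale was concatenated as well
```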

sayakpaul commented 9 months ago

Ah cool. Seems like I should wait for a fix?

sayakpaul commented 9 months ago

This is my pipeline loading function, btw:

import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from quanto import freeze, quantize


def load_pipeline(do_quantize, act_quant):
    # float16 activations are not supported for quantization, so fall back to float32 when act_quant is set.
    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        variant="fp16",
        torch_dtype=torch.float32 if act_quant else torch.float16,
    ).to("cuda")
    pipeline.vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float32 if act_quant else torch.float16
    ).to("cuda")
    pipeline.fuse_qkv_projections()

    if do_quantize:
        # int8 weights; float8 activations only when activation quantization is requested.
        quantize(pipeline.unet, weights=torch.int8, activations=torch.float8_e4m3fn if act_quant else None)
        quantize(pipeline.vae, weights=torch.int8, activations=torch.float8_e4m3fn if act_quant else None)

        freeze(pipeline.unet)
        freeze(pipeline.vae)

    pipeline.set_progress_bar_config(disable=True)
    return pipeline

I don't think changing the order in which fuse_qkv_projections() is called would make a difference here, would it?

dacorvo commented 8 months ago

Closing as it has been merged in #91