sayakpaul / diffusers-torchao

End-to-end recipes for optimizing diffusion models with torchao and diffusers (inference and FP8 training).

Error with FP8 compilation #6

Closed: sayakpaul closed this 3 months ago

sayakpaul commented 3 months ago

@drisspg I am getting an error when trying to run with dynamic FP8 quantization:

```
python benchmark_pixart.py --compile --quantization=fp8
 File "/fsx/sayak/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py", line 865, in __call__
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
  File "/fsx/sayak/diffusers/src/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 321, in decode
    decoded = self._decode(z).sample
  File "/fsx/sayak/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 292, in _decode
    dec = self.decoder(z)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/autoencoders/vae.py", line 332, in forward
    sample = self.mid_block(sample, latent_embeds)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/unets/unet_2d_blocks.py", line 738, in forward
    hidden_states = attn(hidden_states, temb=temb)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_processor.py", line 490, in forward
    return self.processor(
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_processor.py", line 2191, in __call__
    query = attn.to_q(hidden_states)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/ao/torchao/float8/inference.py", line 112, in forward
    return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
  File "/fsx/sayak/ao/torchao/float8/float8_tensor.py", line 360, in __torch_dispatch__
    raise NotImplementedError(f"attempting to run {func}, this is not supported")
NotImplementedError: attempting to run aten.expand.default, this is not supported
```

I am on the latest torch nightly as well as the latest torchao.

I am running this on an H100.
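
For context, the quantization path that triggers this is roughly the following. This is a minimal sketch rather than the actual benchmark script: the `quantize_to_float8`, `QuantConfig`, and `ActivationCasting` names are assumed to be imported from `torchao.float8.inference` (which is where the traceback points), and the compile flags, prompt, and step count below are placeholders.

```python
import torch
from diffusers import PixArtSigmaPipeline
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8

pipeline = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Dynamic FP8 quantization of both the transformer and the VAE. Quantizing the VAE
# is what ends up hitting the unsupported aten.expand.default path inside its
# attention block (see the traceback above).
pipeline.transformer = quantize_to_float8(pipeline.transformer, QuantConfig(ActivationCasting.DYNAMIC))
pipeline.vae = quantize_to_float8(pipeline.vae, QuantConfig(ActivationCasting.DYNAMIC))

# --compile: the benchmark also compiles the model; the exact compile settings may differ from these.
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)

image = pipeline("a small cactus with a happy face", num_inference_steps=30).images[0]
```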

drisspg commented 3 months ago

Interesting, I was able to run your benchmark script today. The benchmark worked with dynamic quantization and showed a slight speedup; I still need to look at the traces.

I was able to reproduce your error with static quant.

sayakpaul commented 3 months ago

Could you try with the latest benchmark code? I am still hitting this error with the latest nightly and the latest ao.

drisspg commented 3 months ago

Yeah, will try.

drisspg commented 3 months ago

On PixArt with batch size 1 I am getting the results below.

Compile + static FP8 quant:

|                ckpt_id                 |   batch_size |  fuse  |  compile  |  quantization  |  sparsify  |   memory |   time |
|:--------------------------------------:|-------------:|:------:|:---------:|:--------------:|:----------:|---------:|-------:|
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS |            1 | False  |   True    |      fp8       |   False    |    9.672 |  1.242 |

Compile only (no quantization):

|                ckpt_id                 |   batch_size |  fuse  |  compile  |  quantization  |  sparsify  |   memory |   time |
|:--------------------------------------:|-------------:|:------:|:---------:|:--------------:|:----------:|---------:|-------:|
| PixArt-alpha/PixArt-Sigma-XL-2-1024-MS |            1 | False  |   True    |      None      |   False    |   10.211 |  1.353 |
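
Reading the two rows together (assuming the time column is seconds per image and the memory column is in GB, neither of which is stated here):

```python
# Rough comparison of the fp8 row vs. the compile-only row above.
fp8_time, baseline_time = 1.242, 1.353
fp8_mem, baseline_mem = 9.672, 10.211

print(f"speedup: {baseline_time / fp8_time:.2f}x")  # ~1.09x
print(f"memory saved: {baseline_mem - fp8_mem:.3f} ({100 * (1 - fp8_mem / baseline_mem):.1f}%)")  # ~0.539 (~5.3%)
```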

Diff:

```diff
diff --git a/inference/benchmark_pixart.py b/inference/benchmark_pixart.py
index 64e0ff9..353813e 100644
--- a/inference/benchmark_pixart.py
+++ b/inference/benchmark_pixart.py
@@ -82,8 +82,8 @@ def load_pipeline(
             quantize_(pipeline.transformer, fp6_llm_weight_only())
             quantize_(pipeline.vae, fp6_llm_weight_only())
         elif quantization == "fp8":
-            pipeline.transformer = quantize_to_float8(pipeline.transformer, QuantConfig(ActivationCasting.DYNAMIC))
-            pipeline.vae = quantize_to_float8(pipeline.vae, QuantConfig(ActivationCasting.DYNAMIC))
+            pipeline.transformer = quantize_to_float8(pipeline.transformer, QuantConfig(ActivationCasting.STATIC, torch.tensor([1.0], dtype=torch.float32, device="cuda")))
+            # pipeline.vae = quantize_to_float8(pipeline.vae, QuantConfig(ActivationCasting.DYNAMIC), module_filter_fn=module_fn)
         elif quantization == "autoquant":
             pipeline.transformer = autoquant(pipeline.transformer)
             pipeline.vae = autoquant(pipeline.vae)
```

drisspg commented 3 months ago

The original error is because we don't currently support bmm (batched matmul) for Float8Tensor.

sayakpaul commented 3 months ago

Yeah, I was able to run static FP8 quant before without quantizing the VAE, which is what you seem to be doing here as well. That limitation with the VAE is known. Thanks for looking into it.
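
For anyone landing here with the same error, a sketch of the configuration that works in this thread: FP8 on the transformer only, with the VAE left unquantized. The scale passed to `QuantConfig` is the same placeholder value used in the diff above, not a calibrated activation scale, and the loading details are assumed rather than copied from the benchmark.

```python
import torch
from diffusers import PixArtSigmaPipeline
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8

pipeline = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Static FP8 quantization of the transformer only.
pipeline.transformer = quantize_to_float8(
    pipeline.transformer,
    QuantConfig(ActivationCasting.STATIC, torch.tensor([1.0], dtype=torch.float32, device="cuda")),
)
# The VAE is intentionally left in its original precision: its attention path goes
# through bmm, which Float8Tensor does not support yet (the aten.expand.default error above).
```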