huggingface / optimum-quanto

A pytorch quantization backend for optimum
Apache License 2.0

fp8 leads to black images (numerical instabilities) for transformer diffusion models #231

Closed. sayakpaul closed this issue 1 month ago.

sayakpaul commented 2 months ago

I am on the main branch of quanto.

from diffusers import DiffusionPipeline
import torch
from optimum.quanto import quantize, freeze, qfloat8_e4m3fn

pipeline_id = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "pixart"
pipeline = DiffusionPipeline.from_pretrained(
    pipeline_id,
    torch_dtype=torch.float16
).to("cuda")

quantize(pipeline.transformer, weights=qfloat8_e4m3fn)
freeze(pipeline.transformer)

image = pipeline(
    prompt="ghibli style, a fantasy landscape with castles",
    num_images_per_prompt=1,
    generator=torch.manual_seed(0),
).images[0]

This leads to a black image.
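A quick way to confirm what "black" means here (a debugging sketch, not part of the original report) is to inspect the decoded pixel values; an all-zero array is the typical signature of NaNs being produced during denoising and then turned into zeros in post-processing:

import numpy as np

# An all-black output usually means NaNs propagated through the denoising loop
# and were converted to zeros during post-processing.
arr = np.asarray(image)
print(arr.min(), arr.max())  # prints "0 0" for a fully black image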

sayakpaul commented 1 month ago

@dacorvo a gentle ping :)

dacorvo commented 1 month ago

I can reproduce the issue. Somewhere in the third transformer block, NaN outputs are generated. I suspect this comes from one of the quantized weights after dequantization (and not from a torch.float16 operation). I am checking them one by one to find out which one.
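A generic way to narrow this down (a sketch of the kind of instrumentation involved, not necessarily the exact procedure used here) is to register forward hooks that report which modules produce NaN outputs:

import torch

def register_nan_hooks(model):
    # Attach a forward hook to every submodule and print the ones whose output contains NaNs.
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and torch.isnan(output).any():
                print(f"NaN in output of {name} ({module.__class__.__name__})")
        return hook

    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    return handles

# Usage: attach with handles = register_nan_hooks(pipeline.transformer), run one
# generation, then detach with: for h in handles: h.remove()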

dacorvo commented 1 month ago

It works well with int8 though ...
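In other words, the same script with int8 weights is a workable interim option (assuming the pipeline object from the report above):

from optimum.quanto import quantize, freeze, qint8

# int8 weight quantization does not trigger the float16 overflow discussed below.
quantize(pipeline.transformer, weights=qint8)
freeze(pipeline.transformer)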

dacorvo commented 1 month ago

"Somewhere in the third transformer block, NaN outputs are generated. I suspect this comes from one of the quantized weights after dequantization (and not from a torch.float16 operation)"

So I was wrong: this was indeed an overflow in a torch.float16 matmul, but only because I applied the float8 quantization scale after the matmul. By applying the scale to the float8 weights before the matmul, they are back in the numerical range they had before quantization and the overflow disappears.
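A toy illustration of the ordering issue, with made-up numbers rather than quanto's internal kernels: the stored float8 values are q = w / scale, so they are much larger in magnitude than the original weights, and accumulating them in float16 before applying the scale can exceed the float16 range (about 65504). Each sum below stands in for one output element of the matmul:

import torch

x = torch.full((4096,), 2.0, dtype=torch.float16)    # activations
w = torch.full((4096,), 0.01, dtype=torch.float16)   # original weights
scale = torch.tensor(1e-4, dtype=torch.float16)      # per-tensor quantization scale
w_q = w / scale                                      # quantized representation, magnitude ~100

# Scale applied after the accumulation: the sum reaches ~8.2e5 and overflows float16 to inf.
print((x * w_q).sum() * scale)    # inf

# Scale applied to the weights first: operands return to their original range, no overflow.
print((x * (w_q * scale)).sum())  # ~81.9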

Here are the images generated with a float16 transformer and a float8 transformer (there are very subtle differences in the towers, and the landscape is a bit different).

[image: bs@1-dtype@fp16-qtype@fp16]

[image: bs@1-dtype@fp16-qtype@fp8]

dacorvo commented 1 month ago

I am pushing the fix with a working example. I also quantized the text_encoder to make a real difference in size.
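A sketch of what that presumably looks like, reusing the pipeline object from the report above (the actual example ships with the fix):

from optimum.quanto import quantize, freeze, qfloat8_e4m3fn

# Quantize and freeze both the transformer and the (large) text encoder so the
# memory savings are substantial.
for component in (pipeline.transformer, pipeline.text_encoder):
    quantize(component, weights=qfloat8_e4m3fn)
    freeze(component)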