sayakpaul / diffusers-torchao

End-to-end recipes for optimizing diffusion models with torchao and diffusers (inference and FP8 training).
Apache License 2.0

fuse_qkv_projections() fails with torchao-quantized FluxTransformer2DModel #37

Open ngaloppo opened 17 hours ago

ngaloppo commented 17 hours ago

Summary

When calling fuse_qkv_projections() on an instance of FluxTransformer2DModel that was quantized with torchao's quantize_, it fails with the following error:

File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/torchao/utils.py", line 389, in _dispatch__torch_dispatch__
    raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}")
NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: aten.cat.default
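
For context, `fuse_qkv_projections()` folds the separate q/k/v projections into one fused linear layer by concatenating their weights, which is where the quantized model hits `aten.cat`: with `int8_weight_only()`, the weights are `AffineQuantizedTensor` subclasses that do not implement that op. A simplified sketch of the fusion step (a hypothetical helper for illustration, not the actual diffusers implementation):

```python
import torch
import torch.nn as nn

# Illustrative stand-in for what QKV fusion does internally.
def fuse_qkv(to_q: nn.Linear, to_k: nn.Linear, to_v: nn.Linear) -> nn.Linear:
    out_features = to_q.out_features + to_k.out_features + to_v.out_features
    fused = nn.Linear(to_q.in_features, out_features, bias=to_q.bias is not None)
    # This concatenation is the aten.cat call that raises NotImplementedError
    # once quantize_() has swapped the weights for AffineQuantizedTensor.
    fused.weight.data = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
    if to_q.bias is not None:
        fused.bias.data = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data])
    return fused
```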

How to reproduce

Run the script below with the following command line:

```
python run_flux.py --seed 23 -n 4 --width 512 --height 512 --int8
```
Code (save as `run_flux.py`):

```python
import argparse
from pathlib import Path
from statistics import mean, median, pstdev
from time import perf_counter

import torch
from diffusers import DiffusionPipeline, FluxPipeline, FluxTransformer2DModel
from torchao.quantization import quantize_, int8_weight_only


def load_pipe(args) -> DiffusionPipeline:
    dtype = torch.bfloat16
    pipe = FluxPipeline.from_pretrained(args.model_id, torch_dtype=dtype)

    if args.int8:
        quantized_path = Path(f"flux-quantized-transformer_{args.model_id.replace('/', '_')}")
        if not quantized_path.exists():
            # Quantize the transformer in place and cache it to disk.
            quantized_path.mkdir()
            print("Quantizing...")
            quantize_(pipe.transformer, int8_weight_only())
            print(f"Saving to {quantized_path}...")
            pipe.transformer.save_pretrained(quantized_path, safe_serialization=False)
        else:
            # Reload the cached quantized transformer.
            print(f"Quantized model loading from {quantized_path}...!")
            with torch.device("meta"):
                pipe.transformer = FluxTransformer2DModel.from_pretrained(
                    quantized_path, torch_dtype=dtype, use_safetensors=False
                )

    if args.fp16:
        print("Running in float16 mode...")
        pipe.vae.enable_slicing()
        pipe.vae.enable_tiling()
        pipe.to(torch.float16)

    pipe.to(device=args.device)

    # Doesn't seem to work together with torchao
    pipe.transformer.fuse_qkv_projections()

    return pipe


def benchmark(pipe, args):
    seed = args.seed
    device = args.device

    if args.bench:
        # Warmup run before timing.
        print("Running in benchmark mode, performing a warmup run (single diffusion iteration)...")
        image = pipe(
            args.prompt,
            output_type="pil",
            num_inference_steps=1,
            width=args.width,
            height=args.height,
            generator=torch.Generator(device).manual_seed(seed),
        ).images[0]

    num_iters = 1 if not args.bench else args.bench
    elapsed = []
    for _ in range(num_iters):
        t0 = perf_counter()
        image = pipe(
            args.prompt,
            output_type="pil",
            num_inference_steps=args.n,
            max_sequence_length=256 if "schnell" in args.model_id else 512,
            width=args.width,
            height=args.height,
            generator=torch.Generator(device).manual_seed(seed),
        ).images[0]
        elapsed.append(perf_counter() - t0)

    if args.bench:
        mean_elapsed = mean(elapsed)
        stdev_elapsed = pstdev(elapsed)
        print(f"Average elapsed time: {mean_elapsed:.3f} s (± {stdev_elapsed} s)")
        print(f"MPS memory usage: {torch.mps.driver_allocated_memory() / 1024 / 1024 / 1024:.3f} GiB.")

    if args.output and image:
        image.save(args.output)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--prompt", type=str, default="a photo of a cute cat being threatened by darth vader in the back")
    parser.add_argument("--model-id", type=str, default="black-forest-labs/FLUX.1-schnell")
    parser.add_argument("-n", type=int, help="Number of diffusion steps", default=4)
    parser.add_argument("--seed", type=int, help="RNG seed", default=23)
    parser.add_argument("--width", type=int, help="Output width", default=1024)
    parser.add_argument("--height", type=int, help="Output height", default=1024)
    parser.add_argument("--fp16", action="store_true", help="Enable float16 inference")
    parser.add_argument("--int8", action="store_true", help="Enable int8 quantization")
    parser.add_argument("-o", "--output", type=str, help="Save to output file", default="flux-output.png")
    parser.add_argument("--bench", type=int, help="Number of images to benchmark. Only benchmark when specified.")
    args = parser.parse_args()
    args.device = "cpu"

    pipe = load_pipe(args)
    benchmark(pipe, args)
```

cc: https://github.com/huggingface/diffusers/pull/9185#issuecomment-2412893029

cpuhrsch commented 16 hours ago

Is it an option to call fuse_qkv_projections first and then quantize_?

ngaloppo commented 14 hours ago

Ha! That seems to work. Thanks! We may want to document that, or provide some kind of example of it somewhere, either in this repo or in the diffusers docs.
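
For anyone landing here later, a minimal sketch of the working order, assuming the same FLUX.1-schnell checkpoint as the script above:

```python
import torch
from diffusers import FluxPipeline
from torchao.quantization import quantize_, int8_weight_only

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)

# Fuse first, then quantize: fusion concatenates plain bf16 weights, which
# no longer works once they have been replaced by AffineQuantizedTensor.
pipe.transformer.fuse_qkv_projections()
quantize_(pipe.transformer, int8_weight_only())
```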

ngaloppo commented 14 hours ago

The attached code is still not working completely for me. When quantizing at runtime, things go well. When loading the serialized fused+quantized model from disk, I get the following error:

```
Taking `'Attention' object has no attribute 'to_qkv'` while using `accelerate.load_checkpoint_and_dispatch` to mean flux-quantized-transformer_black-forest-labs_FLUX.1-schnell was saved with deprecated attention block weight names. We will load it with the deprecated attention block names and convert them on the fly to the new attention block format. Please re-save the model after this conversion, so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint, please also re-upload it or open a PR on the original repository.
Traceback (most recent call last):
  File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/diffusers/models/modeling_utils.py", line 774, in from_pretrained
    accelerate.load_checkpoint_and_dispatch(
  File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/accelerate/big_modeling.py", line 613, in load_checkpoint_and_dispatch
    load_checkpoint_in_model(
  File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 1878, in load_checkpoint_in_model
    set_module_tensor_to_device(
  File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 334, in set_module_tensor_to_device
    new_module = getattr(module, split)
                 ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1729, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'Attention' object has no attribute 'to_qkv'. Did you mean: 'to_k'?

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/sysperf/code/flux/run_flux.py", line 100, in <module>
    pipe = load_pipe(args)
           ^^^^^^^^^^^^^^^
  File "/Users/sysperf/code/flux/run_flux.py", line 29, in load_pipe
    pipe.transformer = FluxTransformer2DModel.from_pretrained(quantized_path, torch_dtype=dtype, use_safetensors=False)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/diffusers/models/modeling_utils.py", line 804, in from_pretrained
    accelerate.load_checkpoint_and_dispatch(
  File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/accelerate/big_modeling.py", line 613, in load_checkpoint_and_dispatch
    load_checkpoint_in_model(
  File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 1878, in load_checkpoint_in_model
    set_module_tensor_to_device(
  File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 440, in set_module_tensor_to_device
    new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: AffineQuantizedTensor.__new__() got an unexpected keyword argument 'requires_grad'
```

cpuhrsch commented 13 hours ago

Can you try this under a `with torch.no_grad():` context? Assuming you're using this for inference.
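
A sketch of that suggestion applied to the loading path from the script above (the directory name matches what the script saves; whether this avoids the `requires_grad` error is exactly what is being asked here):

```python
import torch
from diffusers import FluxTransformer2DModel

quantized_path = "flux-quantized-transformer_black-forest-labs_FLUX.1-schnell"

# Load the serialized quantized transformer for inference with autograd
# disabled, per the suggestion above.
with torch.no_grad():
    transformer = FluxTransformer2DModel.from_pretrained(
        quantized_path, torch_dtype=torch.bfloat16, use_safetensors=False
    )
```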

sayakpaul commented 10 hours ago

Thanks for the discussions here. @ngaloppo it would be helpful if you could maybe open a PR?