Open ngaloppo opened 17 hours ago
Is it an option to call `fuse_qkv_projections` first and then `quantize_`?
Ha! That seems to work. Thanks! We may want to document that, or provide some kind of example of it somewhere, either in this repo or in the diffusers docs.
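A minimal sketch of the ordering discussed above: fuse the QKV projections first, then quantize. The helper name `fuse_then_quantize` and its `quantize_fn` parameter are hypothetical and only for illustration. The calls it relies on, `fuse_qkv_projections()` on the transformer and torchao's `quantize_` with `int8_weight_only()`, are the ones used in the script from this issue.

```python
from typing import Any, Callable, Optional


def fuse_then_quantize(
    transformer: Any,
    quantize_fn: Optional[Callable[[Any], None]] = None,
) -> Any:
    """Fuse QKV projections first, then quantize the module in place.

    Hypothetical helper, not part of diffusers or torchao.
    """
    if quantize_fn is None:
        # Imported lazily so the helper can be defined without torchao installed.
        from torchao.quantization import int8_weight_only, quantize_

        def quantize_fn(module: Any) -> None:
            quantize_(module, int8_weight_only())

    # Step 1: fuse while the projection weights are still unquantized.
    transformer.fuse_qkv_projections()
    # Step 2: quantize the module, which now contains the fused projections.
    quantize_fn(transformer)
    return transformer
```

With the pipeline from the script in this issue, this would be called as `fuse_then_quantize(pipe.transformer)` right after `FluxPipeline.from_pretrained(...)`, in place of quantizing first and fusing afterwards.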
The attached code is still not working completely for me. When quantizing at runtime, things go well. When loading the serialized fused+quantized model from disk, I'm getting the following error message:
Taking `'Attention' object has no attribute 'to_qkv'` while using `accelerate.load_checkpoint_and_dispatch` to mean flux-quantized-transformer_black-forest-labs_FLUX.1-schnell was saved with deprecated attention block weight names. We will load it with the deprecated attention block names and convert them on the fly to the new attention block format. Please re-save the model after this conversion, so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint, please also re-upload it or open a PR on the original repository.
Traceback (most recent call last):
File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/diffusers/models/modeling_utils.py", line 774, in from_pretrained
accelerate.load_checkpoint_and_dispatch(
File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/accelerate/big_modeling.py", line 613, in load_checkpoint_and_dispatch
load_checkpoint_in_model(
File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 1878, in load_checkpoint_in_model
set_module_tensor_to_device(
File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 334, in set_module_tensor_to_device
new_module = getattr(module, split)
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1729, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'Attention' object has no attribute 'to_qkv'. Did you mean: 'to_k'?
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/sysperf/code/flux/run_flux.py", line 100, in <module>
pipe = load_pipe(args)
^^^^^^^^^^^^^^^
File "/Users/sysperf/code/flux/run_flux.py", line 29, in load_pipe
pipe.transformer = FluxTransformer2DModel.from_pretrained(quantized_path, torch_dtype=dtype, use_safetensors=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/diffusers/models/modeling_utils.py", line 804, in from_pretrained
accelerate.load_checkpoint_and_dispatch(
File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/accelerate/big_modeling.py", line 613, in load_checkpoint_and_dispatch
load_checkpoint_in_model(
File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 1878, in load_checkpoint_in_model
set_module_tensor_to_device(
File "/Users/sysperf/miniforge3/envs/flux/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 440, in set_module_tensor_to_device
new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: AffineQuantizedTensor.__new__() got an unexpected keyword argument 'requires_grad'
Can you try this under a `with torch.no_grad():` context? Assuming you're using this for inference.
Thanks for the discussion here. @ngaloppo, it would be helpful if you could open a PR.
Summary
When calling `fuse_qkv_projections()` on an instance of `FluxTransformer2DModel` that was quantized with `torchao`'s `quantize_`, it fails with the following error:
How to reproduce
Run the script below with the following commandline:
Code (save as run_flux.py)
```python
import torch
from torchao.quantization import quantize_, int8_weight_only
import argparse
from diffusers import FluxPipeline, DiffusionPipeline, FluxTransformer2DModel
from time import perf_counter
from pathlib import Path
from statistics import mean, median, pstdev


def load_pipe(args) -> DiffusionPipeline:
    dtype = torch.bfloat16
    pipe = FluxPipeline.from_pretrained(args.model_id, torch_dtype=dtype)

    if args.int8:
        quantized_path = Path(f"flux-quantized-transformer_{args.model_id.replace('/','_')}")
        if not quantized_path.exists():
            quantized_path.mkdir()
            print("Quantizing...")
            quantize_(pipe.transformer, int8_weight_only())
            print(f"Saving to {quantized_path}...")
            pipe.transformer.save_pretrained(quantized_path, safe_serialization=False)
        else:
            print(f"Quantized model loading from {quantized_path}...!")
            with torch.device('meta'):
                pipe.transformer = FluxTransformer2DModel.from_pretrained(quantized_path, torch_dtype=dtype, use_safetensors=False)

    if args.fp16:
        print("Running in float16 mode...")
        pipe.vae.enable_slicing()
        pipe.vae.enable_tiling()
        pipe.to(torch.float16)

    pipe.to(device=args.device)

    # Doesn't seem to work together with torchao
    pipe.transformer.fuse_qkv_projections()

    return pipe


def benchmark(pipe, args):
    seed = args.seed
    device = args.device

    if args.bench:
        # warmup
        print("Running in benchmark mode, performing a warmup run (single diffusion iteration)...")
        image = pipe(
            args.prompt,
            output_type="pil",
            num_inference_steps=1,
            width=args.width,
            height=args.height,
            generator=torch.Generator(device).manual_seed(seed)
        ).images[0]

    num_iters = 1 if not args.bench else args.bench
    elapsed = []
    for _ in range(num_iters):
        t0 = perf_counter()
        image = pipe(
            args.prompt,
            output_type="pil",
            num_inference_steps=args.n,
            max_sequence_length=256 if "schnell" in args.model_id else 512,
            width=args.width,
            height=args.height,
            generator=torch.Generator(device).manual_seed(seed)
        ).images[0]
        elapsed.append(perf_counter() - t0)

    if args.bench:
        mean_elapsed = mean(elapsed)
        stdev_elapsed = pstdev(elapsed)
        print(f"Average elapsed time: {mean_elapsed:.3f} s (± {stdev_elapsed} s)")
        print(f"MPS memory usage: {torch.mps.driver_allocated_memory() / 1024 / 1024 / 1024:.3f} GiB.")

    if args.output and image:
        image.save(args.output)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--prompt", type=str, default="a photo of a cute cat being threatened by darth vader in the back")
    parser.add_argument("--model-id", type=str, default="black-forest-labs/FLUX.1-schnell")
    parser.add_argument("-n", type=int, help="Number of diffusion steps", default=4)
    parser.add_argument("--seed", type=int, help="RNG seed", default=23)
    parser.add_argument("--width", type=int, help="Output width", default=1024)
    parser.add_argument("--height", type=int, help="Output height", default=1024)
    parser.add_argument("--fp16", action="store_true", help="Enable float16 inference")
    parser.add_argument("--int8", action="store_true", help="Enable int8 quantization")
    parser.add_argument("-o", "--output", type=str, help="Save to output file", default="flux-output.png")
    parser.add_argument("--bench", type=int, help="Number of images to benchmark. Only benchmark when specified.")
    args = parser.parse_args()

    args.device = "cpu"
    pipe = load_pipe(args)
    benchmark(pipe, args)
```

cc: https://github.com/huggingface/diffusers/pull/9185#issuecomment-2412893029