Lightning-AI / pytorch-lightning


Fail to quantize SDXL model with lightning codes #19169

Open moonlightian opened 11 months ago

moonlightian commented 11 months ago

Bug description

It does not seem to work to use BitsandbytesPrecision directly the way it is shown on the front page of Lightning.

What should I do to quantize SDXL and save it after quantization? The code and the error are shown below.

What version are you seeing the problem on?

v2.1

How to reproduce the bug

from lightning.fabric import Fabric
from lightning.fabric.plugins import BitsandbytesPrecision
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
import torch
import time

path = '/mnt/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/bf714989e22c57ddc1c453bf74dab4521acb81d8/'
prompt = "hyperrealistic glamour portrait of an old weary wizard surrounded by elemental magic, arcane, freckles, skin pores, pores, velus hair, macro, extreme details, looking at viewer"
negative_prompt = "sketch, cartoon, drawing, anime:1.4, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions"

torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True

# baseline fp16 generation works as expected
with torch.inference_mode():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        path, torch_dtype=torch.float16, use_safetensors=True
    )
    pipe.to("cuda")
    pipe.unet.to(device="cuda", dtype=torch.float16, memory_format=torch.channels_last)
    img = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50, guidance_scale=9, num_images_per_prompt=1).images[0]
    img.save("image.png")

# quantization: the error below is raised here
mode = "nf4"
plugin = BitsandbytesPrecision(mode=mode)
fabric = Fabric(plugins=plugin)
model = fabric.setup_module(pipe)  # quantizes the layers

Error messages and logs

The error is as follows:

Traceback (most recent call last):
  File "/tmp/lighting/quant.py", line 52, in <module>
  File "/opt/conda/lib/python3.9/site-packages/lightning/fabric/fabric.py", line 289, in setup_module
    module = self._precision.convert_module(module)
  File "/opt/conda/lib/python3.9/site-packages/lightning/fabric/plugins/precision/bitsandbytes.py", line 101, in convert_module
    if not any(isinstance(m, torch.nn.Linear) for m in module.modules()):
  File "/opt/conda/lib/python3.9/site-packages/diffusers/configuration_utils.py", line 137, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'StableDiffusionXLPipeline' object has no attribute 'modules'

Environment

Current environment:
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

awaelchli commented 11 months ago

@moonlightian Fabric's setup only accepts torch.nn.Module models. You get the error because StableDiffusionXLPipeline is not a PyTorch module.

I looked a bit but I couldn't find an easy way to apply bitsandbytes to this pipeline.
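A minimal sketch of the distinction, assuming the pipe object from the reproduction above:

import torch

isinstance(pipe, torch.nn.Module)       # False: the pipeline is a plain container object, so it has no .modules()
isinstance(pipe.unet, torch.nn.Module)  # True: the UNet (like the text encoders and the VAE) is a real nn.Module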

moonlightian commented 11 months ago

I looked a bit but I couldn't find an easy way to apply bitsandbytes to this pipeline.

Thank you for your response, but I would still like to know whether there is a way to apply the quantization shown on the webpage to SDXL. (Screenshot of the example from the Lightning homepage.)
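For context, the webpage example follows roughly this pattern (a sketch from memory, assuming a model that is itself a torch.nn.Module such as an LLM; MyModel is a placeholder, not part of any library):

import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins import BitsandbytesPrecision

precision = BitsandbytesPrecision(mode="nf4", dtype=torch.float16)
fabric = Fabric(plugins=precision)
model = MyModel()                   # placeholder: any nn.Module containing nn.Linear layers
model = fabric.setup_module(model)  # nn.Linear layers are replaced with 4-bit bitsandbytes layers

That pattern works because the object passed to setup_module is an nn.Module, unlike the SDXL pipeline.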

moonlightian commented 11 months ago

And I found that it works if I quantize the UNet with

quantized_unet = fabric.setup_module(pipe.unet)
pipe.unet = quantized_unet

but it does not work to replace SDXL's original UNet with the quantized one at inference time, because the input and output data types no longer match. What should I do if I want to quantize a model module by module?
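One possible direction (a sketch, not a verified fix; the parameter choices are assumptions): quantize only the nn.Module components and give BitsandbytesPrecision an explicit compute dtype so the quantized linear layers return float16 and match the rest of the fp16 pipeline. Specific submodules can also be excluded by name via ignore_modules if they turn out to be sensitive.

import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins import BitsandbytesPrecision

# assumption: the dtype mismatch comes from the quantized layers computing in a different
# dtype than the fp16 pipeline, so pin the compute dtype to float16
plugin = BitsandbytesPrecision(mode="nf4", dtype=torch.float16)
fabric = Fabric(plugins=plugin)

pipe.unet = fabric.setup_module(pipe.unet)  # quantize only the UNet's Linear layers
# optionally also: pipe.text_encoder and pipe.text_encoder_2 (each is an nn.Module)

with torch.inference_mode():
    img = pipe(prompt=prompt, negative_prompt=negative_prompt,
               num_inference_steps=50, guidance_scale=9).images[0]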