sayakpaul / diffusers-torchao

End-to-end recipes for optimizing diffusion models with torchao and diffusers (inference and FP8 training).
Apache License 2.0

End-to-end example on serializing and loading #3

Closed · sayakpaul closed this issue 3 months ago

sayakpaul commented 3 months ago

https://github.com/pytorch/ao/blob/main/docs/source/serialization.rst
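
For quick reference, the pattern that doc describes: torchao stores quantized weights as tensor subclasses that round-trip through the regular state_dict machinery, provided loading is done with `assign=True`. A minimal sketch on a toy module (the filename and module below are illustrative, not from the doc):

import torch
from torchao.quantization.quant_api import quantize_, int8_weight_only

# Toy stand-in for any nn.Module.
model = torch.nn.Sequential(torch.nn.Linear(64, 64))
quantize_(model, int8_weight_only())

# The int8 weights are tensor subclasses, so torch.save handles them directly.
torch.save(model.state_dict(), "toy_int8.pt")

# Reload into a fresh float model; assign=True swaps in the quantized tensors
# instead of copying their values into the existing float parameters.
# (On newer PyTorch versions, torch.load may need weights_only=False here.)
model_loaded = torch.nn.Sequential(torch.nn.Linear(64, 64))
model_loaded.load_state_dict(torch.load("toy_int8.pt"), assign=True)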

sayakpaul commented 3 months ago

Serialization code:

import torch
import tempfile
from torchao.utils import get_model_size_in_bytes
from torchao.quantization.quant_api import quantize_, int8_weight_only
from diffusers import PixArtTransformer2DModel

# Dummy transformer inputs matching PixArt-Sigma's 1024px setup (128x128 latents, 4 channels).
def get_inputs(dtype=torch.float32, device="cpu"):
    hidden_states = torch.randn(1, 4, 128, 128).to(dtype=dtype, device=device)
    timestep = torch.tensor([1]).to(device)
    encoder_hidden_states = torch.randn(1, 32, 4096).to(dtype=dtype, device=device)
    return {
        "hidden_states": hidden_states,
        "timestep": timestep,
        "encoder_hidden_states": encoder_hidden_states,
        "added_cond_kwargs": {"aspect_ratio": None, "resolution": None}
    }

dtype, device = torch.bfloat16, "cuda"

ckpt_id = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
model = PixArtTransformer2DModel.from_pretrained(ckpt_id, subfolder="transformer", torch_dtype=dtype).to(device)

print(f"original model size: {get_model_size_in_bytes(model) / 1024 / 1024} MB")

# Replace the linear weights with int8 tensor subclasses, in place.
quantize_(model, int8_weight_only())
print(f"quantized model size: {get_model_size_in_bytes(model) / 1024 / 1024} MB")

example_inputs = get_inputs(dtype=dtype, device=device)
ref = model(**example_inputs).sample

# Round-trip the quantized state dict through torch.save / torch.load.
with tempfile.NamedTemporaryFile() as f:
    torch.save(model.state_dict(), f)
    f.seek(0)
    state_dict = torch.load(f)

# Instantiate a fresh model on the meta device so no memory is allocated
# for weights; the quantized weights come from the state dict below.
with torch.device("meta"):
    config = PixArtTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
    model_loaded = PixArtTransformer2DModel.from_config(config).to(dtype)

# `linear.weight` is nn.Parameter, so we check the type of `linear.weight.data`
print(f"type of weight before loading: {type(model_loaded.proj_out.weight.data)=}")
# `assign=True` swaps the quantized tensors from the state dict into the
# model instead of copying their values into the existing (meta) parameters.
model_loaded.load_state_dict(state_dict, assign=True)
model_loaded.to(device)
print(f"type of weight after loading: {type(model_loaded.proj_out.weight)=}")

res = model_loaded(**example_inputs).sample
assert torch.equal(res, ref)

Output:

original model size: 1174.1111450195312 MB
quantized model size: 596.2789306640625 MB

type of weight before loading: type(model_loaded.proj_out.weight.data)=<class 'torch.Tensor'>
type of weight after loading: type(model_loaded.proj_out.weight)=<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>
sayakpaul commented 3 months ago

End-to-end example

Serialize:

import torch
from torchao.quantization.quant_api import quantize_, int8_weight_only
from diffusers import PixArtTransformer2DModel

dtype = torch.bfloat16

ckpt_id = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
model = PixArtTransformer2DModel.from_pretrained(ckpt_id, subfolder="transformer", torch_dtype=dtype)
# Quantize the transformer's linear weights to int8 (weight-only) before saving.
quantize_(model, int8_weight_only())

# should ideally be possible with safetensors but
# https://github.com/huggingface/safetensors/issues/515
torch.save(model.state_dict(), "pixart_sigma_int8.pt")

Inference:

import torch
from diffusers import PixArtTransformer2DModel, PixArtSigmaPipeline

dtype, device = torch.bfloat16, "cuda"
ckpt_id = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"

# Instantiate the model on the meta device so no memory is allocated for
# weights; the quantized weights come from the serialized state dict.
with torch.device("meta"):
    config = PixArtTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
    model = PixArtTransformer2DModel.from_config(config).to(dtype)

state_dict = torch.load("pixart_sigma_int8.pt", map_location="cpu")
# `assign=True` is required so the quantized tensors replace the meta parameters.
model.load_state_dict(state_dict, assign=True)

pipeline = PixArtSigmaPipeline.from_pretrained(ckpt_id, transformer=model, torch_dtype=dtype).to(device)

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipeline(prompt).images[0]
image.save("pixart_int8.png")
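
Optionally, the quantized transformer can also be compiled before sampling; a minimal sketch (not part of the snippet above; whether compile composes cleanly with the quantized tensor subclasses depends on the torch/torchao versions in use):

# Optional: compile the quantized transformer for faster inference.
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune")
image = pipeline(prompt).images[0]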

Result: (the generated pixart_int8.png image was attached in the original issue)

sayakpaul commented 3 months ago

Cc: @jerryzh168 in case you have any feedback. Otherwise, will close this issue.

jerryzh168 commented 3 months ago

@sayakpaul thanks, this looks great! Neither safetensors nor Hugging Face non-safetensors serialization is supported yet; I'm looking into it.

sayakpaul commented 3 months ago

Alright. Thanks. Closing this then.