Serialization code:
```python
import torch
import tempfile

from torchao.utils import get_model_size_in_bytes
from torchao.quantization.quant_api import quantize_, int8_weight_only
from diffusers import PixArtTransformer2DModel


def get_inputs(dtype=torch.float32, device="cpu"):
    hidden_states = torch.randn(1, 4, 128, 128).to(dtype=dtype, device=device)
    timestep = torch.tensor([1]).to(device)
    encoder_hidden_states = torch.randn(1, 32, 4096).to(dtype=dtype, device=device)
    return {
        "hidden_states": hidden_states,
        "timestep": timestep,
        "encoder_hidden_states": encoder_hidden_states,
        "added_cond_kwargs": {"aspect_ratio": None, "resolution": None},
    }


dtype, device = torch.bfloat16, "cuda"
ckpt_id = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"

model = PixArtTransformer2DModel.from_pretrained(ckpt_id, subfolder="transformer", torch_dtype=dtype).to(device)
print(f"original model size: {get_model_size_in_bytes(model) / 1024 / 1024} MB")

quantize_(model, int8_weight_only())
print(f"quantized model size: {get_model_size_in_bytes(model) / 1024 / 1024} MB")

example_inputs = get_inputs(dtype=dtype, device=device)
ref = model(**example_inputs).sample

with tempfile.NamedTemporaryFile() as f:
    torch.save(model.state_dict(), f)
    f.seek(0)
    state_dict = torch.load(f)

with torch.device("meta"):
    config = PixArtTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
    model_loaded = PixArtTransformer2DModel.from_config(config).to(dtype)

# `proj_out.weight` is an nn.Parameter here, so we check the type of `.data`.
print(f"type of weight before loading: {type(model_loaded.proj_out.weight.data)=}")

model_loaded.load_state_dict(state_dict, assign=True)
model_loaded.to(device)
print(f"type of weight after loading: {type(model_loaded.proj_out.weight)=}")

res = model_loaded(**example_inputs).sample
assert torch.equal(res, ref)
```
Output:
```
original model size: 1174.1111450195312 MB
quantized model size: 596.2789306640625 MB
type of weight before loading: type(model_loaded.proj_out.weight.data)=<class 'torch.Tensor'>
type of weight after loading: type(model_loaded.proj_out.weight)=<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>
```
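A note on `load_state_dict(..., assign=True)`: the default loading path copies values into the existing parameters, which does not work for parameters created on the meta device (they have no storage) and would not preserve the quantized tensor subclass; `assign=True` swaps the loaded tensors in directly. A minimal sketch of the difference on a plain `nn.Linear`, independent of the snippet above:

```python
import torch
import torch.nn as nn

# Materialized layer whose weights we want to load.
src = nn.Linear(4, 4)

# Skeleton created on the meta device (no real storage).
with torch.device("meta"):
    dst = nn.Linear(4, 4)

# dst.load_state_dict(src.state_dict()) would error here: the default path
# copies into the existing meta parameters, which have no storage.
dst.load_state_dict(src.state_dict(), assign=True)
print(dst.weight.device)  # cpu -- the loaded tensors replaced the meta ones
```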
Serialize:
```python
import torch
from torchao.quantization.quant_api import quantize_, int8_weight_only
from diffusers import PixArtTransformer2DModel

dtype = torch.bfloat16
ckpt_id = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"

model = PixArtTransformer2DModel.from_pretrained(ckpt_id, subfolder="transformer", torch_dtype=dtype)
quantize_(model, int8_weight_only())

# should ideally be possible with safetensors but
# https://github.com/huggingface/safetensors/issues/515
torch.save(model.state_dict(), "pixart_sigma_int8.pt")
```
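As the comment notes, safetensors cannot handle torchao's tensor subclasses yet (that is what huggingface/safetensors#515 tracks). A minimal sketch of the attempt that fails, assuming safetensors is installed and reusing `model` from above:

```python
from safetensors.torch import save_file

# AffineQuantizedTensor is a torch.Tensor subclass; safetensors only accepts
# plain tensors, so this is expected to raise rather than save.
save_file(model.state_dict(), "pixart_sigma_int8.safetensors")
```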
Inference:
```python
import torch
from diffusers import PixArtTransformer2DModel, PixArtSigmaPipeline

dtype, device = torch.bfloat16, "cuda"
ckpt_id = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"

with torch.device("meta"):
    config = PixArtTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
    model = PixArtTransformer2DModel.from_config(config).to(dtype)

state_dict = torch.load("pixart_sigma_int8.pt", map_location="cpu")
model.load_state_dict(state_dict, assign=True)

pipeline = PixArtSigmaPipeline.from_pretrained(ckpt_id, transformer=model, torch_dtype=dtype).to(device)

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipeline(prompt).images[0]
image.save("pixart_int8.png")
```
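As a quick sanity check (not in the original snippet), one can assert that the deserialized weights are still the torchao subclass before running the pipeline; this assumes `AffineQuantizedTensor` is importable from `torchao.dtypes`, as the class path in the output above suggests:

```python
from torchao.dtypes import AffineQuantizedTensor

# If deserialization round-tripped correctly, quantized linear weights
# should come back as AffineQuantizedTensor, not plain torch.Tensor.
assert isinstance(model.proj_out.weight, AffineQuantizedTensor)
```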
Result: the generated `pixart_int8.png` image (attachment omitted here).
Cc: @jerryzh168 in case you have any feedback. Otherwise, will close this issue.
@sayakpaul thanks, this looks great! Neither safetensors nor Hugging Face non-safetensors serialization is supported yet; I'm looking into it.
Alright. Thanks. Closing this then.
https://github.com/pytorch/ao/blob/main/docs/source/serialization.rst