Xiuyu-Li / q-diffusion

[ICCV 2023] Q-Diffusion: Quantizing Diffusion Models.
https://xiuyuli.com/qdiffusion/
MIT License

Calibration Process in the 'resume_cali_model' Function #5

Closed: wangjialinEcopia closed 11 months ago

wangjialinEcopia commented 1 year ago

Hello, I have a question about the code you provided. The resume_cali_model function appears to contain some calibration-related steps. Could you confirm whether this code actually includes a calibration process and, if so, which sections are responsible for it? Thank you!

Xiuyu-Li commented 1 year ago

Hi, resume_cali_model refers to "loading the checkpoint of the calibrated quantized model". We have not yet released the code for the calibration process, but we are planning to do so in the future. Stay tuned!

wangjialinEcopia commented 1 year ago

Thank you for the information. I look forward to the release of the code for the calibration process. I'll stay tuned for updates. Keep up the good work!

wangjialinEcopia commented 1 year ago

I noticed that FID results are reported in the text, but the script provided under "# CIFAR-10 (DDIM)" seems to only generate images, without computing or evaluating FID. Could you please consider sharing the code for the evaluation part as well?

Xiuyu-Li commented 1 year ago

We use torch-fidelity to calculate the IS and FID scores. I will try writing a script with instructions on evaluating FIDs later. For now, you can refer to the following snippet to reproduce the results for CIFAR-10:

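import torch
import torch_fidelity

# Metrics are computed in inference mode; no gradients are needed.
torch.set_grad_enabled(False)

# Placeholders: replace both with your own paths before running.
img_path = "path/to/generated/images"  # fill this with your generated images path
cache_root = "path/to/cache"  # fill this with your own path

metrics_dict = torch_fidelity.calculate_metrics(
    input1=img_path,
    input2="cifar10-train",  # reference set: CIFAR-10 training split
    batch_size=256,
    cuda=True,
    isc=True,   # Inception Score
    fid=True,   # Frechet Inception Distance
    kid=False,
    cache_root=cache_root,
    cache=True,
    verbose=True,
    samples_find_deep=True,  # search subdirectories of img_path
    samples_find_ext="jpg,jpeg,png,webp",
)
print(metrics_dict)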
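As a small follow-up to the snippet above, here is a hedged sketch of reading IS and FID out of the returned dictionary. The helper name summarize_metrics is hypothetical, and the key names are the ones torch-fidelity is understood to return when isc=True and fid=True; printing metrics_dict first is an easy way to confirm them for your installed version. The numbers in the example are dummy values, not results from the paper.

```python
def summarize_metrics(metrics):
    """Format the IS and FID entries of a torch-fidelity result dict.

    Assumes the keys below are present, which is expected when
    calculate_metrics is called with isc=True and fid=True.
    """
    is_mean = metrics["inception_score_mean"]
    is_std = metrics["inception_score_std"]
    fid = metrics["frechet_inception_distance"]
    return f"IS: {is_mean:.2f} +/- {is_std:.2f} | FID: {fid:.2f}"


# Dummy values for illustration only; they are NOT measured results.
example = {
    "inception_score_mean": 9.0,
    "inception_score_std": 0.1,
    "frechet_inception_distance": 4.0,
}
print(summarize_metrics(example))
```

In practice you would pass the metrics_dict returned by calculate_metrics instead of the dummy dictionary.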

For LSUN, the raw datasets and corresponding splits need to be loaded to generate the reference batch. For example, the process for the bedrooms dataset looks something like this:

from PIL import Image
import numpy as np
import torch
import torch_fidelity
from ldm.data import lsun

class LSUNBaseTensor(lsun.LSUNBase):
    def __getitem__(self, i):
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example["file_path_"])
        if not image.mode == "RGB":
            image = image.convert("RGB")

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)
        h, w = img.shape[0], img.shape[1]
        crop = min(h, w)
        # Center-crop to a square of side `crop`
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        # Since flip_p is set to 0, we do not flip and remove this function call.
        # image = self.flip(image)
        image = np.array(image).astype(np.uint8)

        # Tensor with dtype uint8, original shape: (256, 256, 3), output shape: (3, 256, 256)
        example = torch.tensor(image).permute(2, 0, 1)
        return example

    def __len__(self):
        # Use the full dataset by default. To evaluate on a subsample instead,
        # return a fixed count here (e.g. return 500000).
        return super().__len__()

class LSUNBedroomsTensorTrain(LSUNBaseTensor):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)

lsun_beds256_train = lambda root, download: LSUNBedroomsTensorTrain(
                size=256,
                interpolation="bicubic",
                flip_p=0.)

torch_fidelity.register_dataset('lsun_beds256_train', lsun_beds256_train)

Then run the CIFAR-10 script with input2 set to "lsun_beds256_train", in the same Python process as the register_dataset call so that the name is known to torch-fidelity. Hope this helps!
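To make the center-crop step in __getitem__ easier to check, here is a minimal pure-Python sketch of the same index arithmetic. The function name center_crop_bounds is hypothetical and is not part of the repository or of ldm; it only mirrors the slicing used in the dataset class above.

```python
def center_crop_bounds(h, w):
    """Return (top, bottom, left, right) for a centered square crop.

    The crop side is min(h, w), matching the slicing
    img[(h - crop) // 2:(h + crop) // 2, (w - crop) // 2:(w + crop) // 2].
    """
    crop = min(h, w)
    top, bottom = (h - crop) // 2, (h + crop) // 2
    left, right = (w - crop) // 2, (w + crop) // 2
    return top, bottom, left, right


# A 300x400 (height x width) image keeps all rows and the middle 300 columns.
print(center_crop_bounds(300, 400))  # (0, 300, 50, 350)
```

The crop is always exactly min(h, w) pixels on each side, including when the difference between height and width is odd, so the later resize to 256x256 does not distort the aspect ratio.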