ChenWu98 / cycle-diffusion

[ICCV 2023] A latent space for stochastic diffusion models
Other
560 stars, 35 forks

train the unpaired image-to-image translation on one GPU #9

Closed: JunMa11 closed this issue 1 year ago

JunMa11 commented 1 year ago

Thanks for sharing the great work!

How can I train the unpaired image-to-image translation model on one GPU?

export CUDA_VISIBLE_DEVICES=1
export RUN_NAME=translate_afhqcat256_to_afhqdog256_ddim_eta01
export SEED=42
nohup python -m torch.distributed.launch --nproc_per_node 1 --master_port 1446 main.py \
  --seed $SEED --cfg experiments/$RUN_NAME.cfg --run_name $RUN_NAME$SEED \
  --logging_strategy steps --logging_first_step true --logging_steps 4 \
  --evaluation_strategy steps --eval_steps 50 \
  --metric_for_best_model CLIPEnergy --greater_is_better false \
  --save_strategy steps --save_steps 50 --save_total_limit 1 --load_best_model_at_end \
  --gradient_accumulation_steps 4 --num_train_epochs 0 --adafactor false --learning_rate 1e-3 \
  --do_eval --output_dir output/$RUN_NAME$SEED --overwrite_output_dir \
  --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --eval_accumulation_steps 4 \
  --ddp_find_unused_parameters true --verbose true > $RUN_NAME$SEED.log 2>&1 &
ChenWu98 commented 1 year ago

Hi,

These commands are intended to translate cats to dogs. Did you encounter any errors?

JunMa11 commented 1 year ago

Hi @ChenWu98 ,

I want to train a new model to translate horses to zebras on one GPU. How should I use the code to achieve this?

ChenWu98 commented 1 year ago

Please follow these steps:

  1. Train a DDPM on horses and a DDPM on zebras using the training code in OpenAI's guided diffusion.
  2. Change these lines to support zebra and horse.
  3. Add a condition here to support zebra and horse.
  4. Create a dataloader of the source images following this file; create a task config file following this file; change this line to support the new task config.

Let me know if you encounter errors in any step other than the first.

JunMa11 commented 1 year ago

Thank you so much for the detailed guidance.

JunMa11 commented 1 year ago

Hi @ChenWu98 ,

May I ask two minor questions?

  1. Is the code for Algorithm 1 in https://github.com/ChenWu98/cycle-diffusion/blob/main/model/gan_wrapper/ddpm_ddim_wrapper.py?

  2. What do the following outputs mean when running the cat2dog model?
export CUDA_VISIBLE_DEVICES=1
export RUN_NAME=translate_afhqcat256_to_afhqdog256_ddim_eta01
export SEED=42
nohup python -m torch.distributed.launch --nproc_per_node 1 --master_port 1446 main.py \
  --seed $SEED --cfg experiments/$RUN_NAME.cfg --run_name $RUN_NAME$SEED \
  --logging_strategy steps --logging_first_step true --logging_steps 4 \
  --evaluation_strategy steps --eval_steps 50 \
  --metric_for_best_model CLIPEnergy --greater_is_better false \
  --save_strategy steps --save_steps 50 --save_total_limit 1 --load_best_model_at_end \
  --gradient_accumulation_steps 4 --num_train_epochs 0 --adafactor false --learning_rate 1e-3 \
  --do_eval --output_dir output/$RUN_NAME$SEED --overwrite_output_dir \
  --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --eval_accumulation_steps 4 \
  --ddp_find_unused_parameters true --verbose true > $RUN_NAME$SEED.log 2>&1 &
Rank 0 Trainer build successfully.
at tensor([[[[0.0007]]]], device='cuda:0')
0 196064.25
1 196010.484375
2 196223.25
3 196828.453125
4 196779.8125
5 196672.828125
6 196237.796875
7 196950.9375
8 196427.0
9 195980.59375
10 197091.3125
11 196416.78125
12 195744.375
13 196913.28125
14 197019.4375
15 196915.0
16 197006.171875
ChenWu98 commented 1 year ago
  1. Yes, but in this file, we use two independently trained diffusion models rather than two text conditions.
  2. Each line is the timestep and the norm of the predicted noise. This is used as a sanity check and is not reported in the paper.
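
A rough way to read those numbers (my own arithmetic, not something stated in the repo): the logged values sit near 196,000, which is what the squared L2 norm, e.g. `(eps ** 2).sum()`, of a standard Gaussian tensor of shape 3 x 256 x 256 gives, since that tensor has 3 * 256 * 256 = 196,608 elements (the plain norm would be about 443). The NumPy sketch below assumes the encoded noise should look like i.i.d. N(0, 1) samples; `squared_norm_ratio` is a hypothetical helper name.

```python
import numpy as np

def squared_norm_ratio(eps: np.ndarray) -> float:
    """Squared L2 norm of eps divided by its number of elements.

    For i.i.d. standard Gaussian noise this ratio should be close to 1.
    """
    return float((eps ** 2).sum()) / eps.size

rng = np.random.default_rng(0)
# Hypothetical stand-in for one encoded noise map of a 256x256 RGB image.
eps = rng.standard_normal((3, 256, 256))

n = eps.size                       # number of elements: 3 * 256 * 256
sq_norm = float((eps ** 2).sum())  # lands near n, like the logged values
ratio = squared_norm_ratio(eps)
print(n, round(sq_norm, 1), round(ratio, 4))
```

Under that assumption, a logged value far from 196,608 at some step would suggest the encoded noise at that step no longer looks Gaussian.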
cse891 commented 1 year ago

@ChenWu98 Could you also provide a training dataloader example for afhqcat256.py? I think the input during training consists of two samples, instead of one sample as in the inference stage. Thanks.

JunMa11 commented 1 year ago

Hi @ChenWu98 ,

Thank you very much for your answer. May I ask some follow-up questions?

  1. Where is the code for SDEdit denoising and Cross Attention Control in unpaired image-to-image translation?

  2. Based on Algorithm 1, the translation is training-free (since the dog and cat generative models have already been trained). Why does the code still call the trainer?
INFO:trainer.trainer:  Num examples = 2
INFO:trainer.trainer:  Batch size = 1
  3. For the cat2dog model, the translation should only need cat images, so why does the code also need dog images?
ChenWu98 commented 1 year ago
  1. The code for SDEdit denoising is already in the codebase. You can find it by searching for the corresponding argument refine_steps. Cross Attention Control is only applied in the zero-shot case (because CAC requires text as input), and its code is available on the demo page.
  2. trainer is a wrapper for both training and evaluation, following the convention of Hugging Face Transformers. Please note that the training part is not used here.
  3. Dog images are used for computing the FID for evaluation. The inference code does not require dog images.
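
For intuition about refine_steps, here is a small NumPy sketch of the forward-noising step an SDEdit-style refinement starts from, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, evaluated at t = refine_steps - 1 as in `sample_xt` / `generate()` in the script later in this thread. It assumes the linear beta schedule from 1e-4 to 0.02 over 1000 steps (what guided diffusion's `--noise_schedule linear` uses at 1000 steps); `linear_alpha_bar` and `renoise` are hypothetical names.

```python
import numpy as np

def linear_alpha_bar(beta_start=1e-4, beta_end=0.02, num_steps=1000):
    """Cumulative product of (1 - beta_t) for a linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)
    return np.cumprod(1.0 - betas)

def renoise(x0, t, alpha_bar, rng):
    """Sample x_t ~ q(x_t | x_0), the first half of an SDEdit-style refinement."""
    at = alpha_bar[t]
    return np.sqrt(at) * x0 + np.sqrt(1.0 - at) * rng.standard_normal(x0.shape)

alpha_bar = linear_alpha_bar()
for refine_steps in (0, 50, 150, 500):
    if refine_steps == 0:
        # refine_steps = 0 skips the refinement loop entirely in the wrapper.
        print(refine_steps, "no re-noising")
        continue
    t = refine_steps - 1
    # How much of the generated image's signal survives before re-denoising.
    print(refine_steps, "signal scale:", round(float(np.sqrt(alpha_bar[t])), 3))

rng = np.random.default_rng(0)
x0 = np.zeros((3, 8, 8))  # toy "image"
x_t = renoise(x0, 149, alpha_bar, rng)
```

The larger refine_steps is, the smaller the surviving signal scale, so more of the image is regenerated by the denoising pass.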
JunMa11 commented 1 year ago

Thank you very much for your detailed explanation :)

JunMa11 commented 1 year ago

Hi @ChenWu98 ,

I'm sorry to bother you.

Could you please explain custom_steps and es_steps, and how we should set these two parameters?

https://github.com/ChenWu98/cycle-diffusion/blob/57036b535f9f9f34ec803a5050ebec1e86be1e82/model/gan_wrapper/ddpm_ddim_wrapper.py#L319-L320

ChenWu98 commented 1 year ago

Hi, custom_steps and es_steps are T and T_es in Algorithm 1, respectively.

The model is trained with 1000 steps, and custom_steps allows us to skip several steps for faster inference. This is the same as "num_inference_steps" in the diffusers library.

es_steps lets us stop encoding the noise before reaching the last timestep, which helps preserve the content of the original image. If you want the translated image to stay close to the original one, a smaller es_steps should help. This is similar to "strength" in the diffusers library.
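
To make the two parameters concrete, here is a sketch that reproduces the `seq_inv` / `seq_inv_next` logic of the wrapper's `encode()` and `generate()` (as copied in the script below, with `t_0` defaulting to 999); `build_timesteps` and `latent_dim` are hypothetical names. It shows which timesteps are visited and how large the latent gets (resolution^2 * channels * es_steps, as in the wrapper's `latent_dim`).

```python
import numpy as np

def build_timesteps(custom_steps, es_steps, t_0=999):
    """Return (seq_inv, seq_inv_next) following the wrapper's encode()/generate()."""
    if (t_0 + 1) % custom_steps == 0:
        # Evenly strided timesteps 0, k, 2k, ... with k = (t_0 + 1) // custom_steps.
        seq_inv = range(0, t_0 + 1, (t_0 + 1) // custom_steps)
    else:
        seq_inv = np.linspace(0, 1, custom_steps) * t_0
    # Keep the first es_steps timesteps: encoding stops early at seq_inv[es_steps - 1].
    seq_inv = [int(s) for s in seq_inv][:es_steps]
    # Each step goes from seq_inv[k] to seq_inv_next[k]; -1 means "to x_0".
    seq_inv_next = ([-1] + seq_inv[:-1])[:es_steps]
    return seq_inv, seq_inv_next

def latent_dim(resolution, channels, es_steps):
    """Number of values in one DPM-Encoder latent (x_T plus one noise map per step)."""
    return resolution ** 2 * channels * es_steps

full, _ = build_timesteps(custom_steps=1000, es_steps=850)
fast, fast_next = build_timesteps(custom_steps=100, es_steps=85)
print(full[-1], len(full))                # largest timestep visited and number of steps
print(fast[:3], fast[-1], fast_next[:3])  # strided schedule for faster inference
print(latent_dim(256, 3, 850))
```

With custom_steps = 1000 every trained timestep is used; lowering es_steps stops the encoding at a less noisy timestep. For 256x256 RGB with es_steps = 850 the latent holds 167,116,800 values, roughly 670 MB per image in float32, which is worth keeping in mind when saving encodings to disk.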

JunMa11 commented 1 year ago

Hi @ChenWu98 ,

Thank you very much for your explanation.

I wrote a standalone script based on your repo to test the cycle consistency.

Here is my pipeline:

MODEL_FLAGS="--image_size 256 --num_channels 64 --num_res_blocks 3 --num_heads 1"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 16"

python scripts/image_train.py --data_dir ../CT256 --log_dir ./work_dir/CT256 $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

Original image

ct_ori

Reconstructed image

ct_recon

However, the reconstructed image is very different from the original image.

What would be the possible reason? Any comments are highly appreciated.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 30 20:38:03 2022

@author: jma
"""
#%% load packages
import os
join = os.path.join
import argparse
import yaml
import numpy as np
import torch
# import torchvision.transforms as transforms
import torch.nn.functional as F
from skimage import io
from diffusion import DDPM
from iddpm import i_DDPM
# from ..lib.ddpm_ddim.models.improved_ddpm.script_util import i_DDPM # return model
from diffusion_utils import (
    get_beta_schedule, denoising_step, extract, requires_grad
)
from tqdm import tqdm
from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    add_dict_to_argparser,
    create_model_and_diffusion,
    args_to_dict
)
#%%
def read_model_and_diffusion(args, model_path):
    """Reads the latest model from the given directory."""

    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys()),
    )
    model.load_state_dict(dist_util.load_state_dict(model_path, map_location="cpu"))
    model.to(dist_util.dev())
    # if args.use_fp16:
    #     model.convert_to_fp16()
    model.eval()
    return model, diffusion

def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace

def set_seed(seed: int):
    """
    Helper function for reproducible behavior that sets the seed in `numpy` and `torch`.

    Args:
        seed (`int`): The seed to set.
    """
    # random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # set torch benchmark 
    torch.backends.cudnn.benchmark = True

def sample2img(sample):
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous().cpu().numpy()[0]

    return sample

# source: https://github.com/ChenWu98/cycle-diffusion/blob/main/model/gan_wrapper/ddpm_ddim_wrapper.py

#%% denoising step: x_t -> x_{t_next}, driven by a given noise eps
def denoising_step_with_eps(xt, eps, t, t_next, *,
                            models,
                            logvars,
                            b,
                            sampling_type='ddpm',
                            eta=0.0,
                            learn_sigma=False,
                            hybrid=False,
                            hybrid_config=None,
                            ratio=1.0,
                            out_x0_t=False,
                            ):

    assert eps.shape == xt.shape

    # Compute noise and variance
    if type(models) != list:
        model = models
        et = model(xt, t)
        if et.shape != xt.shape:
            # The model also outputs variance values; split them off once.
            et, model_var_values = torch.split(et, et.shape[1] // 2, dim=1)
        if learn_sigma:
            # calculations for posterior q(x_{t-1} | x_t, x_0)
            bt = extract(b, t, xt.shape)
            at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)
            at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)
            posterior_variance = bt * (1.0 - at_next) / (1.0 - at)
            # log calculation clipped because the posterior variance is 0 at the
            # beginning of the diffusion chain.
            min_log = torch.log(posterior_variance.clamp(min=1e-6))
            max_log = torch.log(bt)
            frac = (model_var_values + 1) / 2
            logvar = frac * max_log + (1 - frac) * min_log
        else:
            logvar = extract(logvars, t, xt.shape)
    else:
        if not hybrid:
            et = 0
            logvar = 0
            if ratio != 0.0:
                et_i = ratio * models[1](xt, t)
                if learn_sigma:
                    raise NotImplementedError()
                    et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                    logvar += logvar_learned
                else:
                    logvar += ratio * extract(logvars, t, xt.shape)
                et += et_i

            if ratio != 1.0:
                et_i = (1 - ratio) * models[0](xt, t)
                if learn_sigma:
                    raise NotImplementedError()
                    et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                    logvar += logvar_learned
                else:
                    logvar += (1 - ratio) * extract(logvars, t, xt.shape)
                et += et_i

        else:
            for thr in list(hybrid_config.keys()):
                if t.item() >= thr:
                    et = 0
                    logvar = 0
                    for i, ratio in enumerate(hybrid_config[thr]):
                        ratio /= sum(hybrid_config[thr])
                        et_i = models[i+1](xt, t)
                        if learn_sigma:
                            raise NotImplementedError()
                            et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                            logvar_i = logvar_learned
                        else:
                            logvar_i = extract(logvars, t, xt.shape)
                        et += ratio * et_i
                        logvar += ratio * logvar_i
                    break

    # Compute the next x
    bt = extract(b, t, xt.shape)  # bt is the \beta_t
    at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)

    if t_next.sum() == -t_next.shape[0]:  # if t_next is -1
        at_next = torch.ones_like(at)
    else:
        at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)  # at_next is the \hat{\alpha}_{t_next}

    xt_next = torch.zeros_like(xt)
    if sampling_type == 'ddpm':
        weight = bt / torch.sqrt(1 - at)

        mean = 1 / torch.sqrt(1.0 - bt) * (xt - weight * et)
        noise = eps
        mask = 1 - (t == 0).float()
        mask = mask.reshape((xt.shape[0],) + (1,) * (len(xt.shape) - 1))
        xt_next = mean + mask * torch.exp(0.5 * logvar) * noise
        xt_next = xt_next.float()

    elif sampling_type == 'ddim':
        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()  # predicted x0_t
        if eta == 0:
            xt_next = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et
        elif at > (at_next):
            print('Inversion process is only possible with eta = 0')
            raise ValueError
        else:
            c1 = eta * ((1 - at / (at_next)) * (1 - at_next) / (1 - at)).sqrt()  # sigma_t
            c2 = ((1 - at_next) - c1 ** 2).sqrt()  # direction pointing to x_t
            xt_next = at_next.sqrt() * x0_t + c2 * et + c1 * eps

    if out_x0_t == True:
        return xt_next, x0_t
    else:
        return xt_next

def compute_eps(xt, xt_next, t, t_next, models, sampling_type, b, logvars, eta, learn_sigma):

    assert eta is None or eta > 0
    # Compute noise and variance
    if type(models) != list:
        model = models
        et = model(xt, t)
        if et.shape != xt.shape:
            et, model_var_values = torch.split(et, et.shape[1] // 2, dim=1)
        if learn_sigma:
            # calculations for posterior q(x_{t-1} | x_t, x_0)
            bt = extract(b, t, xt.shape)
            at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)
            at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)
            posterior_variance = bt * (1.0 - at_next) / (1.0 - at)
            # log calculation clipped because the posterior variance is 0 at the
            # beginning of the diffusion chain.
            min_log = torch.log(posterior_variance.clamp(min=1e-6))
            max_log = torch.log(bt)
            frac = (model_var_values + 1) / 2
            logvar = frac * max_log + (1 - frac) * min_log
        else:
            logvar = extract(logvars, t, xt.shape)
    else:
        raise NotImplementedError()

    # Compute the next x
    bt = extract(b, t, xt.shape)  # bt is the \beta_t
    at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)

    assert not t_next.sum() == -t_next.shape[0]  # t_next should never be -1
    assert not t.sum() == 0  # t should never be 0
    at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)  # at_next is the \hat{\alpha}_{t_next}

    if sampling_type == 'ddpm':
        weight = bt / torch.sqrt(1 - at)

        mean = 1 / torch.sqrt(1.0 - bt) * (xt - weight * et)
        # print('torch.exp(0.5 * logvar).sum()', torch.exp(0.5 * logvar).sum())
        eps = (xt_next - mean) / torch.exp(0.5 * logvar)

    elif sampling_type == 'ddim':
        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()  # predicted x0_t

        c1 = eta * ((1 - at / (at_next)) * (1 - at_next) / (1 - at)).sqrt()  # sigma_t
        c2 = ((1 - at_next) - c1 ** 2).sqrt()  # direction pointing to x_t
        eps = (xt_next - at_next.sqrt() * x0_t - c2 * et) / c1
    else:
        raise ValueError()

    return eps

def sample_xt_next(x0, xt, t, t_next, sampling_type, b, eta):
    bt = extract(b, t, xt.shape)  # bt is the \beta_t
    at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)

    assert not t_next.sum() == -t_next.shape[0]  # t_next should never be -1
    assert not t.sum() == 0  # t should never be 0
    at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)  # at_next is the \hat{\alpha}_{t_next}

    if sampling_type == 'ddpm':
        w0 = at_next.sqrt() * bt / (1 - at)
        wt = (1 - bt).sqrt() * (1 - at_next) / (1 - at)
        mean = w0 * x0 + wt * xt

        var = bt * (1 - at_next) / (1 - at)

        xt_next = mean + var.sqrt() * torch.randn_like(x0)
    elif sampling_type == 'ddim':
        et = (xt - at.sqrt() * x0) / (1 - at).sqrt()  # posterior et given x0 and xt
        c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()  # sigma_t
        c2 = ((1 - at_next) - c1 ** 2).sqrt()  # direction pointing to x_t
        xt_next = at_next.sqrt() * x0 + c2 * et + c1 * torch.randn_like(x0)
    else:
        raise ValueError()

    return xt_next

def prepare_ddpm_ddim(source_model_type, source_model_path):
    parser = argparse.ArgumentParser(description=globals()['__doc__'])
    # Default
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')

    # Train & Test
    parser.add_argument('--model_path', type=str, default=None, help='Test model path')

    if source_model_type == 'ct256':
        # assert source_model_path is None
        ddim_args = parser.parse_args(
            [
                '--config', 'ct256.yml',
                '--model_path', source_model_path,
            ]
        )
    elif source_model_type == 'mr256':
        # assert source_model_path is not None
        ddim_args = parser.parse_args(
            [
                '--config', 'mr256.yml',
                '--model_path', source_model_path,
            ]
        )    

    # parse config file
    with open(os.path.join('configs', ddim_args.config), 'r') as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    return ddim_args, new_config

def sample_xt(x0, t, b):
    at = extract((1.0 - b).cumprod(dim=0), t, x0.shape)  # at is the \hat{\alpha}_t
    print('at', at)
    xt = at.sqrt() * x0 + (1 - at).sqrt() * torch.randn_like(x0)
    return xt

class DDPMDDIMWrapper(torch.nn.Module):

    def __init__(self, args, source_model_type, sample_type, custom_steps, es_steps, source_model_path=None,
                 refine_steps=150, refine_iterations=1, eta=0.1, t_0=None, enforce_class_input=None):
        super(DDPMDDIMWrapper, self).__init__()
        self.args = args
        self.enforce_class_input = enforce_class_input
        self.custom_steps = custom_steps
        self.refine_steps = refine_steps
        self.refine_iterations = refine_iterations
        self.sample_type = sample_type
        self.eta = eta
        self.t_0 = t_0 if t_0 is not None else 999
        self.es_steps = es_steps
        self.learn_sigma = args.learn_sigma

        if self.sample_type == 'ddim':
            assert self.eta > 0
        elif self.sample_type == 'ddpm':
            if not self.eta is None:
                self.eta = None
        else:
            raise ValueError()

        # Set up generator
        self.ddim_args, config = prepare_ddpm_ddim(source_model_type, source_model_path)

        print(f"{self.ddim_args}")
        print(f"{config=}")

        betas = get_beta_schedule(
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps
        )
        self.register_buffer(
            'betas', torch.from_numpy(betas).float()
        )
        self.num_timesteps = betas.shape[0]

        # ----------- Model -----------#
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.generator, s_diffusion = read_model_and_diffusion(self.args, source_model_path)
        self.logvar = np.log(np.maximum(posterior_variance, 1e-20))

        init_ckpt = torch.load(self.ddim_args.model_path)
        self.generator.load_state_dict(init_ckpt)

        self.resolution = config.data.image_size
        self.channels = config.data.channels
        self.latent_dim = self.resolution ** 2 * self.channels * self.es_steps
        # Freeze.
        requires_grad(self.generator, False)

        # Post process.
        # self.post_process = transforms.Compose(  # To un-normalize from [-1.0, 1.0] (GAN output) to [0, 1]
        #     [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])]
        # )

    def generate(self, z, class_label=None):
        if (self.t_0 + 1) % self.custom_steps == 0:
            seq_inv = range(0, self.t_0 + 1, (self.t_0 + 1) // self.custom_steps)
            assert len(seq_inv) == self.custom_steps
        else:
            seq_inv = np.linspace(0, 1, self.custom_steps) * self.t_0
        seq_inv = [int(s) for s in list(seq_inv)][:self.es_steps]  # 0, 1, ..., t_0
        seq_inv_next = ([-1] + list(seq_inv[:-1]))[:self.es_steps]  # -1, 0, 1, ..., t_0-1

        bsz = z.shape[0]
        eps_list = z.view(bsz, self.es_steps, self.channels, self.resolution, self.resolution)
        x_T = eps_list[:, 0]
        eps_list = eps_list[:, 1:]

        x = x_T

        for it, (i, j) in enumerate(zip(reversed(seq_inv), reversed(seq_inv_next))):
            t = (torch.ones(bsz) * i).to(self.device)
            t_next = (torch.ones(bsz) * j).to(self.device)

            if it < self.es_steps - 1:
                eps = eps_list[:, it]
                x = denoising_step_with_eps(x, eps=eps, t=t, t_next=t_next, models=self.generator,
                                            logvars=self.logvar,
                                            sampling_type=self.sample_type,
                                            b=self.betas,
                                            eta=self.eta,
                                            learn_sigma=self.learn_sigma)
            else:
                x = denoising_step(x, t=t, t_next=t_next, models=self.generator,
                                   logvars=self.logvar,
                                   sampling_type=self.sample_type,
                                   b=self.betas,
                                   eta=self.eta,
                                   learn_sigma=self.learn_sigma)

        if self.refine_steps == 0:
            img = x
        else:
            for r in range(self.refine_iterations):
                refine_eta = 1
                # Sample xt
                t = (torch.ones(bsz) * self.refine_steps - 1).to(self.device)
                xt = sample_xt(x0=x, t=t, b=self.betas)
                # Denoise
                x = xt
                assert self.refine_steps < self.custom_steps
                seq_inv_refine = seq_inv[:self.refine_steps]
                seq_inv_next_refine = seq_inv_next[:self.refine_steps]
                for i, j in zip(reversed(seq_inv_refine), reversed(seq_inv_next_refine)):
                    t = (torch.ones(bsz) * i).to(self.device)
                    t_next = (torch.ones(bsz) * j).to(self.device)
                    x = denoising_step(x, t=t, t_next=t_next, models=self.generator,
                                       logvars=self.logvar,
                                       sampling_type=self.sample_type,
                                       b=self.betas,
                                       eta=refine_eta,
                                       learn_sigma=self.learn_sigma)
            img = x

        return img

    def encode(self, image, class_label=None):
        # Eval mode for the generator.
        self.generator.eval()

        if (self.t_0 + 1) % self.custom_steps == 0:
            seq_inv = range(0, self.t_0 + 1, (self.t_0 + 1) // self.custom_steps)
            assert len(seq_inv) == self.custom_steps
        else:
            seq_inv = np.linspace(0, 1, self.custom_steps) * self.t_0
        seq_inv = [int(s) for s in list(seq_inv)][:self.es_steps]
        seq_inv_next = ([-1] + list(seq_inv[:-1]))[:self.es_steps]

        # Normalize.
        image = (image - 0.5) * 2.0
        # Resize.
        assert image.shape[2] == image.shape[3] == self.resolution

        with torch.no_grad():
            x0 = image
            bsz = x0.shape[0]

            # DPM-Encoder.
            T = (torch.ones(bsz) * (self.es_steps - 1)).to(self.device)
            xT = sample_xt(x0=x0, t=T, b=self.betas)
            z_list = [xT, ]

            xt = xT
            for it, (i, j) in enumerate(zip(reversed(seq_inv), reversed(seq_inv_next))):
                t = (torch.ones(bsz) * i).to(self.device)
                t_next = (torch.ones(bsz) * j).to(self.device)

                if it < self.es_steps - 1:
                    xt_next = sample_xt_next(
                        x0=x0,
                        xt=xt,
                        t=t,
                        t_next=t_next,
                        sampling_type=self.sample_type,
                        b=self.betas,
                        eta=self.eta,
                    )
                    eps = compute_eps(
                        xt=xt,
                        xt_next=xt_next,
                        t=t,
                        t_next=t_next,
                        models=self.generator,
                        sampling_type=self.sample_type,
                        b=self.betas,
                        logvars=self.logvar,
                        eta=self.eta,
                        learn_sigma=self.learn_sigma,
                    )
                    # print(it, (eps ** 2).sum().item())
                    xt = xt_next
                    z_list.append(eps)
                else:
                    break
            # z = z_list
            z = torch.stack(z_list, dim=1)#.view(bsz, -1)
            # assert z.shape[1] == self.latent_dim
            # np.savez('encoding.npz', z = z.cpu().numpy(),
            #           image=image.cpu().numpy())

        return z

    def forward(self, z, class_label=None):
        # Eval mode for the generator.
        self.generator.eval()

        img = self.generate(z, class_label)

        # Post process.
        img = (img + 1.0) / 2.0  # un-normalize from [-1, 1] to [0, 1] (self.post_process is commented out in __init__)

        return img

    @property
    def device(self):
        return next(self.parameters()).device

#%% set unet and diffusion parameters
defaults_param = model_and_diffusion_defaults()
new_param = dict(
    image_size=256,
    batch_size=1,
    num_channels=64,
    num_res_blocks=3,
    num_heads=1,
    diffusion_steps=1000,
    noise_schedule='linear',
    lr=1e-4,
    clip_denoised=False,
    num_samples=1, 
    use_ddim=True,
    # timestep_respacing='ddim250',
    model_path="",
)
defaults_param.update(new_param)
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults_param)
args = parser.parse_args()
set_seed(42)

# source image path
s_img_path = '/home/jma/Documents/I2I/diffusers/data/abdomenCT256'
names = sorted(os.listdir(s_img_path))

name = names[10]
ct_data = io.imread(join(s_img_path, name))
# NOTE: encode() applies (image - 0.5) * 2.0 itself and expects [0, 1]; this line maps to [-1, 1]
s_np = ct_data.astype(np.float32) / 127.5 - 1
source = torch.from_numpy(np.expand_dims(s_np, 0)).permute(0,3,1,2).to('cuda')

ct_model_path = './work_dir/CT256/ema_0.9999_400000.pt'
s_model, s_diffusion = read_model_and_diffusion(args, ct_model_path)

# test model inference; passed: the model can generate reasonable images
# s_model.eval()
# with torch.no_grad():
#     img_demo = s_model(source, (torch.ones(1)*900).to("cuda"))

#%% get_encoding
cycle_wapper = DDPMDDIMWrapper(args=args, source_model_type='ct256', sample_type='ddim', custom_steps=1000, es_steps=850,
                                refine_steps=150, source_model_path=ct_model_path)
embedding_exist = True  # set to False to (re)compute the encoding
with torch.no_grad():
    if not embedding_exist:
        encoding = cycle_wapper.encode(source)
        # save encoding to disk
        np.savez('encoding.npz', z=encoding.cpu().numpy())
    else:
        encod_npz = np.load('./encoding.npz')
        encoding = torch.from_numpy(encod_npz['z']).to('cuda')
    # Decode in both branches so recon_img is always defined.
    recon_img = cycle_wapper.generate(encoding)

# save input image and reconstructed image
io.imsave('ori.png', ct_data)
io.imsave('recon.png', sample2img(recon_img))
ChenWu98 commented 1 year ago

Can you try setting refine_steps = 0? This parameter is used to remove potential noise during translation (which may be needed because the cat and dog DDPMs are not well trained; the text-guided experiment does not require such noise removal). I'm happy to check in detail if this does not solve the problem.
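
One way to see why refine_steps matters for a cycle-consistency check (my reading of the equations, not a statement from the repo): with eta > 0, the encoder (`compute_eps` in the script above) solves for the noise that makes one DDIM step land exactly on the sampled x_{t_next}, so replaying that noise through the same step (`denoising_step_with_eps`) should reproduce x_{t_next} up to floating-point error, whereas the SDEdit refinement adds fresh noise. The NumPy sketch checks this for a single step; a random array stands in for the model's prediction `et` (an assumption: any fixed `et` works because both sides use the same one). Function names are hypothetical.

```python
import numpy as np

def ddim_coeffs(at, at_next, eta):
    """sigma_t (c1) and the coefficient of the predicted-noise direction (c2)."""
    c1 = eta * np.sqrt((1 - at / at_next) * (1 - at_next) / (1 - at))
    c2 = np.sqrt((1 - at_next) - c1 ** 2)
    return c1, c2

def ddim_step(xt, et, eps, at, at_next, eta):
    """One stochastic DDIM step x_t -> x_{t_next} driven by the given noise eps."""
    x0_t = (xt - et * np.sqrt(1 - at)) / np.sqrt(at)
    c1, c2 = ddim_coeffs(at, at_next, eta)
    return np.sqrt(at_next) * x0_t + c2 * et + c1 * eps

def recover_eps(xt, xt_next, et, at, at_next, eta):
    """Invert ddim_step: the noise that maps xt to xt_next (requires eta > 0)."""
    x0_t = (xt - et * np.sqrt(1 - at)) / np.sqrt(at)
    c1, c2 = ddim_coeffs(at, at_next, eta)
    return (xt_next - np.sqrt(at_next) * x0_t - c2 * et) / c1

alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
t, t_next, eta = 500, 499, 0.1
at, at_next = alpha_bar[t], alpha_bar[t_next]

rng = np.random.default_rng(0)
x0 = rng.uniform(-1, 1, size=(3, 16, 16))                                # toy image in [-1, 1]
xt = np.sqrt(at) * x0 + np.sqrt(1 - at) * rng.standard_normal(x0.shape)  # x_t ~ q(x_t | x_0)
et = rng.standard_normal(x0.shape)                                       # stand-in for model(xt, t)

# Encoder side: sample x_{t_next} from the DDIM posterior given x0 (like sample_xt_next) ...
c1, c2 = ddim_coeffs(at, at_next, eta)
et_post = (xt - np.sqrt(at) * x0) / np.sqrt(1 - at)
xt_next = np.sqrt(at_next) * x0 + c2 * et_post + c1 * rng.standard_normal(x0.shape)
# ... then solve for the noise the sampler would need (like compute_eps).
eps = recover_eps(xt, xt_next, et, at, at_next, eta)

# Decoder side: replaying eps through the same step lands on xt_next again.
replayed = ddim_step(xt, et, eps, at, at_next, eta)
print(np.max(np.abs(replayed - xt_next)))
```

This only covers the replayed intermediate steps; the final step returns the model's own x0 prediction, and any refine_steps > 0 re-noises the result, so under these assumptions one would expect the reconstruction to be close to, not identical to, the input.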

JunMa11 commented 1 year ago

Hi @ChenWu98 ,

Thank you very much for your guidance.

Here is the result for refine_steps = 0. The reconstruction quality looks worse than with refine_steps = 150.

cycle_wapper = DDPMDDIMWrapper(args=args, source_model_type='ct256', sample_type='ddim', custom_steps=1000, es_steps=850, refine_steps=0, source_model_path=ct_model_path)

original image:

ori

reconstructed image:

encoding = cycle_wapper.encode(source)
recon_img = cycle_wapper.generate(encoding)

recon

I'm validating the cycle-consistency property (the image should be reconstructed from its DPM-Encoder encoding), but the results are very weird. Most of the above code is copied from https://github.com/ChenWu98/cycle-diffusion/blob/main/model/gan_wrapper/ddpm_ddim_wrapper.py, and I didn't modify the functions. I also checked that my DDPM model is well trained, since it can generate very good images from random noise. Any further comments are highly appreciated.

ChenWu98 commented 1 year ago

Hi, here are two things that may help.

  1. Normalization. The .encode() function assumes that its input image is in [0, 1], while the input to the diffusion model is in [-1, 1]. It may be worth double-checking the normalization and post-processing.
  2. If things still don't go well after double-checking the normalization, can you provide the weights and the input image so that I can run the code from my side?

Thanks a lot for your patience!
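
A NumPy sketch of the value ranges involved, assuming an 8-bit image on disk: `encode()` expects [0, 1] and applies `(image - 0.5) * 2.0` internally, and the script's `sample2img` maps [-1, 1] back to uint8. The helper names are hypothetical. The last lines show what happens when an image already scaled with `/ 127.5 - 1` (as in the script above) is passed to `encode()`.

```python
import numpy as np

def to_unit_range(img_uint8):
    """uint8 HxWxC in [0, 255] -> float32 CxHxW in [0, 1], the range encode() expects."""
    return np.transpose(img_uint8.astype(np.float32) / 255.0, (2, 0, 1))

def to_model_range(x_unit):
    """The normalization encode() applies internally: [0, 1] -> [-1, 1]."""
    return (x_unit - 0.5) * 2.0

def to_uint8(x_model):
    """Like sample2img ([-1, 1] -> uint8, CxHxW -> HxWxC), but rounds instead of truncating."""
    out = np.rint(np.clip((x_model + 1.0) * 127.5, 0, 255)).astype(np.uint8)
    return np.transpose(out, (1, 2, 0))

# Toy 4x4 RGB image with known pixel values 0, 5, ..., 235.
img = (np.arange(48) * 5).astype(np.uint8).reshape(4, 4, 3)

x_unit = to_unit_range(img)
x_model = to_model_range(x_unit)  # what the diffusion model should see
roundtrip = to_uint8(x_model)     # back to the original pixels

# Double normalization: scaling to [-1, 1] before encode() lands in [-3, 1] inside it.
already_scaled = np.transpose(img.astype(np.float32) / 127.5 - 1.0, (2, 0, 1))
double = to_model_range(already_scaled)
print(x_model.min(), x_model.max(), double.min(), double.max())
```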

JunMa11 commented 1 year ago

Hi @ChenWu98 ,

Thank you very much for your quick reply.

I tried normalizing the image to [0, 1], but it still can't reconstruct the original image.

I have sent you an email so that you can reproduce the results. You can also download the data and code at https://drive.google.com/file/d/1cm1omI0RgOt592sM_DkVrpLiWTr-Sn6H/view?usp=share_link

You can reproduce my results by

Step 1. pip install -e .

Step 2. python cyclediffusion.py

ChenWu98 commented 1 year ago

Thanks, I'll run the code soon.

JunMa11 commented 1 year ago

Hi @ChenWu98 ,

Happy Lunar New Year, the Year of the Rabbit! I hope you enjoyed the weekend.

I also tested DDIB over the past week and got some promising results with the same model I shared with you.

But for some images, the results are not robust.

Does the code work on your side?

JunMa11 commented 1 year ago

Hi @ChenWu98 ,

I also opened an issue on the DDIB repository in case you are interested in this problem.

https://github.com/suxuann/ddib/issues/9

ChenWu98 commented 1 year ago

Hi @JunMa11,

Happy new year and sorry about the late reply! I have tested your code (with minor modifications), and it worked well for me. Minor modifications I made:

  1. Improved memory efficiency, so now there is no need to save to numpy first.
  2. Fixed some import errors.
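For a sense of scale, encode() returns a list of es_steps tensors: x_T plus es_steps - 1 noise maps, each of shape (batch, channels, resolution, resolution), which is also what the wrapper's latent_dim = resolution ** 2 * channels * es_steps counts. A back-of-the-envelope sketch, assuming float32 values, batch size 1 and 3 channels with the resolution=256, es_steps=850 settings of the script below; encoding_size_bytes is a hypothetical helper.

```python
def encoding_size_bytes(resolution: int, channels: int, es_steps: int,
                        batch_size: int = 1, bytes_per_value: int = 4) -> int:
    """Bytes needed to hold the DPM-Encoder output: es_steps tensors
    (x_T plus es_steps - 1 noise maps) of shape (batch, C, H, W)."""
    values_per_tensor = batch_size * channels * resolution * resolution
    return es_steps * values_per_tensor * bytes_per_value

size = encoding_size_bytes(resolution=256, channels=3, es_steps=850)
print(size, f"bytes = {size / 2**20:.1f} MiB")  # 668467200 bytes = 637.5 MiB
```

So each encoded 256 x 256 image occupies roughly 0.64 GiB in float32 under these assumptions, which is worth keeping in mind when writing encodings to disk.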

The code is below; let me know if there are any problems!

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 30 20:38:03 2022

@author: jma
"""
#%% load packages
import os
join = os.path.join
import argparse
import yaml
import numpy as np
import torch
# import torchvision.transforms as transforms
import torch.nn.functional as F
from skimage import io
# from ..lib.ddpm_ddim.models.improved_ddpm.script_util import i_DDPM # return model
from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    add_dict_to_argparser,
    create_model_and_diffusion,
    args_to_dict
)
#%%

def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

def get_beta_schedule(*, beta_start, beta_end, num_diffusion_timesteps):
    betas = np.linspace(beta_start, beta_end,
                        num_diffusion_timesteps, dtype=np.float64)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas

def extract(a, t, x_shape):
    """Extract coefficients from a based on t and reshape to make it
    broadcastable with x_shape."""
    bs, = t.shape
    assert x_shape[0] == bs
    out = torch.gather(torch.tensor(a, dtype=torch.float, device=t.device), 0, t.long())
    assert out.shape == (bs,)
    out = out.reshape((bs,) + (1,) * (len(x_shape) - 1))
    return out

def denoising_step(xt, t, t_next, *,
                   models,
                   logvars,
                   b,
                   sampling_type='ddpm',
                   eta=0.0,
                   learn_sigma=False,
                   hybrid=False,
                   hybrid_config=None,
                   ratio=1.0,
                   out_x0_t=False,
                   ):

    # Compute noise and variance
    if type(models) != list:
        model = models
        et = model(xt, t)
        if et.shape != xt.shape:
            et, model_var_values = torch.split(et, et.shape[1] // 2, dim=1)
        if learn_sigma:
            # calculations for posterior q(x_{t-1} | x_t, x_0)
            bt = extract(b, t, xt.shape)
            at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)
            if t_next.sum() == -t_next.shape[0]:  # if t_next is -1
                at_next = torch.ones_like(at)
            else:
                at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)
            posterior_variance = bt * (1.0 - at_next) / (1.0 - at)
            # log calculation clipped because the posterior variance is 0 at the
            # beginning of the diffusion chain.
            min_log = torch.log(posterior_variance.clamp(min=1e-6))
            max_log = torch.log(bt)
            frac = (model_var_values + 1) / 2
            logvar = frac * max_log + (1 - frac) * min_log
        else:
            logvar = extract(logvars, t, xt.shape)
    else:
        if not hybrid:
            et = 0
            logvar = 0
            if ratio != 0.0:
                et_i = ratio * models[1](xt, t)
                if learn_sigma:
                    raise NotImplementedError()
                    et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                    logvar += logvar_learned
                else:
                    logvar += ratio * extract(logvars, t, xt.shape)
                et += et_i

            if ratio != 1.0:
                et_i = (1 - ratio) * models[0](xt, t)
                if learn_sigma:
                    raise NotImplementedError()
                    et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                    logvar += logvar_learned
                else:
                    logvar += (1 - ratio) * extract(logvars, t, xt.shape)
                et += et_i

        else:
            for thr in list(hybrid_config.keys()):
                if t.item() >= thr:
                    et = 0
                    logvar = 0
                    for i, ratio in enumerate(hybrid_config[thr]):
                        ratio /= sum(hybrid_config[thr])
                        et_i = models[i+1](xt, t)
                        if learn_sigma:
                            raise NotImplementedError()
                            et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                            logvar_i = logvar_learned
                        else:
                            logvar_i = extract(logvars, t, xt.shape)
                        et += ratio * et_i
                        logvar += ratio * logvar_i
                    break

    # Compute the next x
    bt = extract(b, t, xt.shape)  # bt is the \beta_t
    at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)

    if t_next.sum() == -t_next.shape[0]:  # if t_next is -1
        at_next = torch.ones_like(at)
    else:
        at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)  # at_next is the \hat{\alpha}_{t_next}

    xt_next = torch.zeros_like(xt)
    if sampling_type == 'ddpm':
        weight = bt / torch.sqrt(1 - at)

        mean = 1 / torch.sqrt(1.0 - bt) * (xt - weight * et)
        noise = torch.randn_like(xt)
        mask = 1 - (t == 0).float()
        mask = mask.reshape((xt.shape[0],) + (1,) * (len(xt.shape) - 1))
        xt_next = mean + mask * torch.exp(0.5 * logvar) * noise
        xt_next = xt_next.float()

    elif sampling_type == 'ddim':
        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()  # predicted x0_t
        if eta == 0:
            xt_next = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et
        elif at > (at_next):
            print('Inversion process is only possible with eta = 0')
            raise ValueError
        else:
            c1 = eta * ((1 - at / (at_next)) * (1 - at_next) / (1 - at)).sqrt()  # sigma_t
            c2 = ((1 - at_next) - c1 ** 2).sqrt()  # direction pointing to x_t
            xt_next = at_next.sqrt() * x0_t + c2 * et + c1 * torch.randn_like(xt)

    if out_x0_t == True:
        return xt_next, x0_t
    else:
        return xt_next

def read_model_and_diffusion(args, model_path):
    """Reads the latest model from the given directory."""

    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys()),
    )
    model.load_state_dict(dist_util.load_state_dict(model_path, map_location="cpu"))
    model.to(dist_util.dev())
    # if args.use_fp16:
    #     model.convert_to_fp16()
    model.eval()
    return model, diffusion

def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace

def set_seed(seed: int):
    """
    Seed `numpy` and `torch` (CPU and all GPUs) for reproducible sampling.

    Args:
        seed (`int`): The seed to set.
    """
    # random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cuDNN autotuning: faster, but GPU results may not be bit-for-bit reproducible
    torch.backends.cudnn.benchmark = True

def sample2img(sample):
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous().cpu().numpy()[0]

    return sample

#%% reverse: x_{t-1} -> x_t
def denoising_step_with_eps(xt, eps, t, t_next, *,
                            models,
                            logvars,
                            b,
                            sampling_type='ddpm',
                            eta=0.0,
                            learn_sigma=False,
                            hybrid=False,
                            hybrid_config=None,
                            ratio=1.0,
                            out_x0_t=False,
                            ):

    assert eps.shape == xt.shape

    # Compute noise and variance
    if type(models) != list:
        model = models
        et = model(xt, t)
        if et.shape != xt.shape:
            et, model_var_values = torch.split(et, et.shape[1] // 2, dim=1)
        if learn_sigma:
            # et was already split into (noise, variance) just above; do not split it again.
            # calculations for posterior q(x_{t-1} | x_t, x_0)
            bt = extract(b, t, xt.shape)
            at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)
            if t_next.sum() == -t_next.shape[0]:  # if t_next is -1
                at_next = torch.ones_like(at)
            else:
                at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)  # at_next is the \hat{\alpha}_{t_next}
            posterior_variance = bt * (1.0 - at_next) / (1.0 - at)
            # log calculation clipped because the posterior variance is 0 at the
            # beginning of the diffusion chain.
            min_log = torch.log(posterior_variance.clamp(min=1e-6))
            max_log = torch.log(bt)
            frac = (model_var_values + 1) / 2
            logvar = frac * max_log + (1 - frac) * min_log
        else:
            logvar = extract(logvars, t, xt.shape)
    else:
        if not hybrid:
            et = 0
            logvar = 0
            if ratio != 0.0:
                et_i = ratio * models[1](xt, t)
                if learn_sigma:
                    raise NotImplementedError()
                    et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                    logvar += logvar_learned
                else:
                    logvar += ratio * extract(logvars, t, xt.shape)
                et += et_i

            if ratio != 1.0:
                et_i = (1 - ratio) * models[0](xt, t)
                if learn_sigma:
                    raise NotImplementedError()
                    et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                    logvar += logvar_learned
                else:
                    logvar += (1 - ratio) * extract(logvars, t, xt.shape)
                et += et_i

        else:
            for thr in list(hybrid_config.keys()):
                if t.item() >= thr:
                    et = 0
                    logvar = 0
                    for i, ratio in enumerate(hybrid_config[thr]):
                        ratio /= sum(hybrid_config[thr])
                        et_i = models[i+1](xt, t)
                        if learn_sigma:
                            raise NotImplementedError()
                            et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                            logvar_i = logvar_learned
                        else:
                            logvar_i = extract(logvars, t, xt.shape)
                        et += ratio * et_i
                        logvar += ratio * logvar_i
                    break

    # Compute the next x
    bt = extract(b, t, xt.shape)  # bt is the \beta_t
    at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)

    if t_next.sum() == -t_next.shape[0]:  # if t_next is -1
        at_next = torch.ones_like(at)
    else:
        at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)  # at_next is the \hat{\alpha}_{t_next}

    xt_next = torch.zeros_like(xt)
    if sampling_type == 'ddpm':
        weight = bt / torch.sqrt(1 - at)

        mean = 1 / torch.sqrt(1.0 - bt) * (xt - weight * et)
        noise = eps
        mask = 1 - (t == 0).float()
        mask = mask.reshape((xt.shape[0],) + (1,) * (len(xt.shape) - 1))
        xt_next = mean + mask * torch.exp(0.5 * logvar) * noise
        xt_next = xt_next.float()

    elif sampling_type == 'ddim':
        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()  # predicted x0_t
        if eta == 0:
            xt_next = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et
        elif at > (at_next):
            print('Inversion process is only possible with eta = 0')
            raise ValueError
        else:
            c1 = eta * ((1 - at / (at_next)) * (1 - at_next) / (1 - at)).sqrt()  # sigma_t
            c2 = ((1 - at_next) - c1 ** 2).sqrt()  # direction pointing to x_t
            xt_next = at_next.sqrt() * x0_t + c2 * et + c1 * eps

    if out_x0_t == True:
        return xt_next, x0_t
    else:
        return xt_next

def compute_eps(xt, xt_next, t, t_next, models, sampling_type, b, logvars, eta, learn_sigma):

    assert eta is None or eta > 0
    # Compute noise and variance
    if type(models) != list:
        model = models
        et = model(xt, t)
        if et.shape != xt.shape:
            et, model_var_values = torch.split(et, et.shape[1] // 2, dim=1)
        if learn_sigma:
            # calculations for posterior q(x_{t-1} | x_t, x_0)
            bt = extract(b, t, xt.shape)
            at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)
            at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)
            posterior_variance = bt * (1.0 - at_next) / (1.0 - at)
            # log calculation clipped because the posterior variance is 0 at the
            # beginning of the diffusion chain.
            min_log = torch.log(posterior_variance.clamp(min=1e-6))
            max_log = torch.log(bt)
            frac = (model_var_values + 1) / 2
            logvar = frac * max_log + (1 - frac) * min_log
        else:
            logvar = extract(logvars, t, xt.shape)
    else:
        raise NotImplementedError()

    # Compute the next x
    bt = extract(b, t, xt.shape)  # bt is the \beta_t
    at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)

    assert not t_next.sum() == -t_next.shape[0]  # t_next should never be -1
    assert not t.sum() == 0  # t should never be 0
    at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)  # at_next is the \hat{\alpha}_{t_next}

    if sampling_type == 'ddpm':
        weight = bt / torch.sqrt(1 - at)

        mean = 1 / torch.sqrt(1.0 - bt) * (xt - weight * et)
        # print('torch.exp(0.5 * logvar).sum()', torch.exp(0.5 * logvar).sum())
        eps = (xt_next - mean) / torch.exp(0.5 * logvar)

    elif sampling_type == 'ddim':
        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()  # predicted x0_t

        c1 = eta * ((1 - at / (at_next)) * (1 - at_next) / (1 - at)).sqrt()  # sigma_t
        c2 = ((1 - at_next) - c1 ** 2).sqrt()  # direction pointing to x_t
        eps = (xt_next - at_next.sqrt() * x0_t - c2 * et) / c1
    else:
        raise ValueError()

    return eps

def sample_xt_next(x0, xt, t, t_next, sampling_type, b, eta):
    bt = extract(b, t, xt.shape)  # bt is the \beta_t
    at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)  # at is the \hat{\alpha}_t (DDIM does not use \hat notation)

    assert not t_next.sum() == -t_next.shape[0]  # t_next should never be -1
    assert not t.sum() == 0  # t should never be 0
    at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)  # at_next is the \hat{\alpha}_{t_next}

    if sampling_type == 'ddpm':
        w0 = at_next.sqrt() * bt / (1 - at)
        wt = (1 - bt).sqrt() * (1 - at_next) / (1 - at)
        mean = w0 * x0 + wt * xt

        var = bt * (1 - at_next) / (1 - at)

        xt_next = mean + var.sqrt() * torch.randn_like(x0)
    elif sampling_type == 'ddim':
        et = (xt - at.sqrt() * x0) / (1 - at).sqrt()  # posterior et given x0 and xt
        c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()  # sigma_t
        c2 = ((1 - at_next) - c1 ** 2).sqrt()  # direction pointing to x_t
        xt_next = at_next.sqrt() * x0 + c2 * et + c1 * torch.randn_like(x0)
    else:
        raise ValueError()

    return xt_next

def prepare_ddpm_ddim(source_model_type, source_model_path):
    parser = argparse.ArgumentParser(description=globals()['__doc__'])
    # Default
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')

    # Train & Test
    parser.add_argument('--model_path', type=str, default=None, help='Test model path')

    if source_model_type == 'ct256':
        # assert source_model_path is None
        ddim_args = parser.parse_args(
            [
                '--config', 'ct256.yml',
                '--model_path', source_model_path,
            ]
        )
    elif source_model_type == 'mr256':
        # assert source_model_path is not None
        ddim_args = parser.parse_args(
            [
                '--config', 'mr256.yml',
                '--model_path', source_model_path,
            ]
        )
    else:
        raise ValueError(f"Unknown source_model_type: {source_model_type}")

    # parse config file
    with open(os.path.join('configs', ddim_args.config), 'r') as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    return ddim_args, new_config

def sample_xt(x0, t, b):
    at = extract((1.0 - b).cumprod(dim=0), t, x0.shape)  # at is the \hat{\alpha}_t
    print('at', at)
    xt = at.sqrt() * x0 + (1 - at).sqrt() * torch.randn_like(x0)
    return xt

class DDPMDDIMWrapper(torch.nn.Module):

    def __init__(self, args, source_model_type, sample_type, custom_steps, es_steps, source_model_path=None,
                 refine_steps=150, refine_iterations=1, eta=0.1, t_0=None, enforce_class_input=None):
        super(DDPMDDIMWrapper, self).__init__()
        self.args = args
        self.enforce_class_input = enforce_class_input
        self.custom_steps = custom_steps
        self.refine_steps = refine_steps
        self.refine_iterations = refine_iterations
        self.sample_type = sample_type
        self.eta = eta
        self.t_0 = t_0 if t_0 is not None else 999
        self.es_steps = es_steps
        self.learn_sigma = args.learn_sigma

        if self.sample_type == 'ddim':
            assert self.eta > 0
        elif self.sample_type == 'ddpm':
            if not self.eta is None:
                self.eta = None
        else:
            raise ValueError()

        # Set up generator
        self.ddim_args, config = prepare_ddpm_ddim(source_model_type, source_model_path)

        print(f"{self.ddim_args}")
        print(f"{config=}")

        betas = get_beta_schedule(
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps
        )
        self.register_buffer(
            'betas', torch.from_numpy(betas).float()
        )
        self.num_timesteps = betas.shape[0]

        # ----------- Model -----------#
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.generator, s_diffusion = read_model_and_diffusion(self.args, source_model_path)
        self.logvar = np.log(np.maximum(posterior_variance, 1e-20))
        # read_model_and_diffusion() already loaded the weights from source_model_path
        # (the same file as self.ddim_args.model_path), so no second torch.load() is needed.

        self.resolution = config.data.image_size
        self.channels = config.data.channels
        self.latent_dim = self.resolution ** 2 * self.channels * self.es_steps
        # Freeze.
        requires_grad(self.generator, False)

        # Post process.
        # self.post_process = transforms.Compose(  # To un-normalize from [-1.0, 1.0] (GAN output) to [0, 1]
        #     [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])]
        # )

    def generate(self, z, class_label=None):
        if (self.t_0 + 1) % self.custom_steps == 0:
            seq_inv = range(0, self.t_0 + 1, (self.t_0 + 1) // self.custom_steps)
            assert len(seq_inv) == self.custom_steps
        else:
            seq_inv = np.linspace(0, 1, self.custom_steps) * self.t_0
        seq_inv = [int(s) for s in list(seq_inv)][:self.es_steps]  # 0, 1, ..., t_0
        seq_inv_next = ([-1] + list(seq_inv[:-1]))[:self.es_steps]  # -1, 0, 1, ..., t_0-1

        bsz = z[0].shape[0]
        eps_list = z
        x_T = eps_list[0]
        eps_list = eps_list[1:]

        x = x_T

        for it, (i, j) in enumerate(zip(reversed(seq_inv), reversed(seq_inv_next))):
            t = (torch.ones(bsz) * i).to(self.device)
            t_next = (torch.ones(bsz) * j).to(self.device)

            if it < self.es_steps - 1:
                eps = eps_list[it]
                x = denoising_step_with_eps(x, eps=eps, t=t, t_next=t_next, models=self.generator,
                                            logvars=self.logvar,
                                            sampling_type=self.sample_type,
                                            b=self.betas,
                                            eta=self.eta,
                                            learn_sigma=self.learn_sigma)
            else:
                x = denoising_step(x, t=t, t_next=t_next, models=self.generator,
                                   logvars=self.logvar,
                                   sampling_type=self.sample_type,
                                   b=self.betas,
                                   eta=self.eta,
                                   learn_sigma=self.learn_sigma)

        if self.refine_steps == 0:
            img = x
        else:
            for r in range(self.refine_iterations):
                refine_eta = 1
                # Sample xt
                t = (torch.ones(bsz) * self.refine_steps - 1).to(self.device)
                xt = sample_xt(x0=x, t=t, b=self.betas)
                # Denoise
                x = xt
                assert self.refine_steps < self.custom_steps
                seq_inv_refine = seq_inv[:self.refine_steps]
                seq_inv_next_refine = seq_inv_next[:self.refine_steps]
                for i, j in zip(reversed(seq_inv_refine), reversed(seq_inv_next_refine)):
                    t = (torch.ones(bsz) * i).to(self.device)
                    t_next = (torch.ones(bsz) * j).to(self.device)
                    x = denoising_step(x, t=t, t_next=t_next, models=self.generator,
                                       logvars=self.logvar,
                                       sampling_type=self.sample_type,
                                       b=self.betas,
                                       eta=refine_eta,
                                       learn_sigma=self.learn_sigma)
            img = x

        return img

    def encode(self, image, class_label=None):
        # Eval mode for the generator.
        self.generator.eval()

        if (self.t_0 + 1) % self.custom_steps == 0:
            seq_inv = range(0, self.t_0 + 1, (self.t_0 + 1) // self.custom_steps)
            assert len(seq_inv) == self.custom_steps
        else:
            seq_inv = np.linspace(0, 1, self.custom_steps) * self.t_0
        seq_inv = [int(s) for s in list(seq_inv)][:self.es_steps]
        seq_inv_next = ([-1] + list(seq_inv[:-1]))[:self.es_steps]

        # Normalize.
        image = (image - 0.5) * 2.0
        # Resize.
        assert image.shape[2] == image.shape[3] == self.resolution

        with torch.no_grad():
            x0 = image
            bsz = x0.shape[0]

            # DPM-Encoder. x_T is sampled at the largest timestep in seq_inv, where generate() starts.
            T = (torch.ones(bsz) * seq_inv[-1]).to(self.device)
            xT = sample_xt(x0=x0, t=T, b=self.betas)
            z_list = [xT, ]

            xt = xT
            for it, (i, j) in enumerate(zip(reversed(seq_inv), reversed(seq_inv_next))):
                t = (torch.ones(bsz) * i).to(self.device)
                t_next = (torch.ones(bsz) * j).to(self.device)

                if it < self.es_steps - 1:
                    xt_next = sample_xt_next(
                        x0=x0,
                        xt=xt,
                        t=t,
                        t_next=t_next,
                        sampling_type=self.sample_type,
                        b=self.betas,
                        eta=self.eta,
                    )
                    eps = compute_eps(
                        xt=xt,
                        xt_next=xt_next,
                        t=t,
                        t_next=t_next,
                        models=self.generator,
                        sampling_type=self.sample_type,
                        b=self.betas,
                        logvars=self.logvar,
                        eta=self.eta,
                        learn_sigma=self.learn_sigma,
                    )
                    # print(it, (eps ** 2).sum().item())
                    xt = xt_next
                    z_list.append(eps)
                else:
                    break
            z = z_list
            # assert z.shape[1] == self.latent_dim
            # np.savez('encoding.npz', z = z.cpu().numpy(),
            #           image=image.cpu().numpy())

        return z

    def forward(self, z, class_label=None):
        # Eval mode for the generator.
        self.generator.eval()

        img = self.generate(z, class_label)

        # Post process: un-normalize from [-1, 1] to [0, 1] (self.post_process is commented out in __init__).
        img = (img + 1.0) / 2.0

        return img

    @property
    def device(self):
        return next(self.parameters()).device

#%% set unet and diffusion parameters
defaults_param = model_and_diffusion_defaults()
new_param = dict(
    image_size=256,
    batch_size=1,
    num_channels=64,
    num_res_blocks=3,
    num_heads=1,
    diffusion_steps=1000,
    noise_schedule='linear',
    lr=1e-4,
    clip_denoised=False,
    num_samples=1, 
    use_ddim=True,
    # timestep_respacing='ddim250',
    model_path="",
)
defaults_param.update(new_param)
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults_param)
# def main():
args = parser.parse_args()

set_seed(42)
s_img_path = './'
name = 'ct_ori.png'
ct_data = io.imread(join(s_img_path, name))
source_np = ct_data.astype(np.float32) / 127.5 -1 # normalize to [-1, 1]
source_np01 = ct_data.astype(np.float32) / 255.0 # normalize to [0, 1]
source = torch.from_numpy(np.expand_dims(source_np, 0)).permute(0,3,1,2).to('cuda')
source01 = torch.from_numpy(np.expand_dims(source_np01, 0)).permute(0,3,1,2).to('cuda')

ct_model_path = './work_dir/abdomenCT256/ema_0.9999_480000.pt'
s_model, s_diffusion = read_model_and_diffusion(args, ct_model_path)

#%% get_encoding and reconstruct the image based on the encoding --> verify the cycle consistency
ct2ct_wapper = DDPMDDIMWrapper(args=args, source_model_type='ct256', sample_type='ddim', custom_steps=1000, es_steps=850,
                                refine_steps=0, source_model_path=ct_model_path).to('cuda')
embedding_exist = False
# With the in-memory encoding, encode() and generate() now run in the same session,
# so caching the encoding to disk is optional. encode() returns a list of tensors
# ([x_T, eps_1, ...]), so each entry is stored as its own array in the .npz file.

if not embedding_exist:
    encoding = ct2ct_wapper.encode(source01)
    # np.savez('encoding.npz', *[z.cpu().numpy() for z in encoding])
else:
    encod_npz = np.load('./encoding.npz')
    encoding = [torch.from_numpy(encod_npz[f'arr_{i}']).to('cuda') for i in range(len(encod_npz.files))]
with torch.no_grad():
    recon_img = ct2ct_wapper.generate(encoding)
io.imsave('ct_ori_check.png', ct_data)
io.imsave('ct_recon.png', sample2img(recon_img))

#%% generate MR image
mr_model_path = './work_dir/abdomenMR256/ema_0.9999_480000.pt'

ct2mr_wapper = DDPMDDIMWrapper(args=args, source_model_type='ct256', sample_type='ddim', custom_steps=1000, es_steps=850,
                               refine_steps=0, source_model_path=mr_model_path).to('cuda')
with torch.no_grad():
    recon_mr = ct2mr_wapper.generate(encoding)
io.imsave('ct2mr.png', sample2img(recon_mr))
print('done')

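Why the reconstruction can be exact whatever the model predicts: for each step, compute_eps solves the DDIM update used in denoising_step_with_eps for eps, so feeding that eps back reproduces x_{t_next} up to floating-point error, provided the generator sees the same x_t and t in both passes (hence the need for matching timestep sequences and normalization). A minimal NumPy sketch of that algebra for a single step with eta > 0, using a random array as a stand-in for the model output et and arbitrary alpha-bar values (assumptions, not taken from any checkpoint); the function names are hypothetical.

```python
import numpy as np

def ddim_coeffs(at, at_next, eta):
    # sigma_t (c1) and the weight on the predicted noise (c2), written the
    # same way as in denoising_step_with_eps / compute_eps above.
    c1 = eta * np.sqrt((1 - at / at_next) * (1 - at_next) / (1 - at))
    c2 = np.sqrt((1 - at_next) - c1 ** 2)
    return c1, c2

def ddim_step_with_eps(xt, et, eps, at, at_next, eta):
    """Generator side: x_{t_next} from x_t, predicted noise et and latent eps."""
    x0_t = (xt - et * np.sqrt(1 - at)) / np.sqrt(at)
    c1, c2 = ddim_coeffs(at, at_next, eta)
    return np.sqrt(at_next) * x0_t + c2 * et + c1 * eps

def ddim_compute_eps(xt, xt_next, et, at, at_next, eta):
    """Encoder side: the eps for which ddim_step_with_eps returns xt_next."""
    x0_t = (xt - et * np.sqrt(1 - at)) / np.sqrt(at)
    c1, c2 = ddim_coeffs(at, at_next, eta)
    return (xt_next - np.sqrt(at_next) * x0_t - c2 * et) / c1

rng = np.random.default_rng(0)
shape = (1, 3, 8, 8)
at, at_next, eta = 0.30, 0.35, 0.1      # denoising: alpha_bar grows as t decreases
x0 = rng.uniform(-1, 1, size=shape)
xt = np.sqrt(at) * x0 + np.sqrt(1 - at) * rng.standard_normal(shape)
xt_next = rng.standard_normal(shape)    # any target, e.g. a posterior sample
et = rng.standard_normal(shape)         # stand-in for model(xt, t)

eps = ddim_compute_eps(xt, xt_next, et, at, at_next, eta)
recon = ddim_step_with_eps(xt, et, eps, at, at_next, eta)
print(np.max(np.abs(recon - xt_next)))  # floating-point-level error
```

The same formula also explains why encode() stores one fewer eps than there are steps: on the final step (t = 0 to t_next = -1) the code sets alpha-bar_next to 1, which makes c1 = 0, so no noise term is needed.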
JunMa11 commented 1 year ago

Hi @ChenWu98 ,

Thank you so much for your great help. The updated code works and robustly preserves cycle consistency.

I ran some CT-to-MR translation experiments. The CT and MR models were both trained on the same human abdomen dataset. However, the translated MR images have structures very different from those of the input CT images. Do you have any suggestions for improving structural consistency in the translated images?

[Image: FLARE22_Tr_0001-0014]

[Image: FLARE22_Tr_0003-0053]
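One quantity that may help when reasoning about structure preservation: in the script above, encode() starts from x_T = sample_xt(x0, T), i.e. x_T = sqrt(alpha_bar_T) * x_0 + sqrt(1 - alpha_bar_T) * noise, with T = es_steps - 1 (849 for es_steps=850 and custom_steps=1000). The sketch below computes sqrt(alpha_bar_T) for a few es_steps values, assuming a linear schedule with beta_start=1e-4 and beta_end=0.02 over 1000 steps (guided-diffusion's "linear" defaults; the actual values come from ct256.yml, which is not shown). The noise maps eps_t also depend on x_0, so this covers only part of the picture, and whether a smaller es_steps improves CT-to-MR structure is untested here.

```python
import numpy as np

def linear_betas(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule, built the same way as get_beta_schedule() above."""
    return np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)

def signal_coefficient(t, betas):
    """sqrt(alpha_bar_t): the weight of x_0 in x_t = sqrt(ab) * x_0 + sqrt(1 - ab) * noise."""
    alpha_bar = np.cumprod(1.0 - betas)
    return float(np.sqrt(alpha_bar[t]))

betas = linear_betas()
for es_steps in (850, 700, 500, 300):
    # With custom_steps=1000 and t_0=999, encode() samples x_T at timestep es_steps - 1.
    T = es_steps - 1
    print(f"es_steps={es_steps}: sqrt(alpha_bar_T) = {signal_coefficient(T, betas):.4f}")
```

A smaller es_steps leaves a larger share of x_0 in x_T, at the cost of giving the target model fewer steps in which to change the image.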

teaghan commented 1 year ago

@JunMa11 I was also considering using this method for a domain transfer application. Did you have any luck getting the translated images to maintain the structure of the original images? Thanks!