suxuann / ddib

Dual Diffusion Implicit Bridges for Image-to-Image Translation. ICLR 2023.
MIT License

Why is cycle consistency not robust? #9

Open JunMa11 opened 1 year ago

JunMa11 commented 1 year ago

Dear @suxuann ,

Thanks for sharing the awesome work.

I tried DDIB for modality transfer: CT image to MR image.

I trained a CT model and an MR model separately on my own dataset using guided-diffusion, and I have verified that both generate good samples.

Then I tried cycle consistency (encoding a CT image to noise and decoding it back) and CT-to-MR modality transfer. However, cycle consistency is not robust, and the translated MR images have very different structures from the input CT images.

Here are some examples:

[Example images: ct_0, ct_1, ct_2, ct_3]

What could be the possible reason? Any comments are highly appreciated.

JunMa11 commented 1 year ago

I'm also attaching the code:

import argparse
import numpy as np
import os
join = os.path.join
import pathlib
import torch.distributed as dist
from skimage import io, color
import torch
from improved_diffusion import dist_util, logger
from improved_diffusion.script_util import (
    model_and_diffusion_defaults,
    add_dict_to_argparser,
    create_model_and_diffusion,
    args_to_dict
)
import matplotlib.pyplot as plt

def create_argparser():
    defaults = dict(
        image_size=256,
        batch_size=1,
        num_channels=64,
        num_res_blocks=3,
        num_heads=1,
        diffusion_steps=1000,
        noise_schedule='linear',
        lr=1e-4,
        clip_denoised=False,
        num_samples=1, # 10000
        use_ddim=True,
        # timestep_respacing='ddim250',
        model_path="",
    )
    ori = model_and_diffusion_defaults()
    # defaults.update(model_and_diffusion_defaults())
    ori.update(defaults)
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, ori)
    return parser

# def main():
args = create_argparser().parse_args()

logger.log(f"args: {args}")

dist_util.setup_dist()
logger.configure(dir='./log')

code_folder = './'
# data_folder = './datasets' # get_code_and_dataset_folders()

#%% load model
def read_model_and_diffusion(args, model_path):
    """Reads the latest model from the given directory."""

    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys()),
    )
    model.load_state_dict(dist_util.load_state_dict(model_path, map_location="cuda"))
    model.to(dist_util.dev())
    # if args.use_fp16:
    #     model.convert_to_fp16()
    model.eval()
    return model, diffusion

ct_model_path = './work_dir/abdomenCT256/ema_0.9999_480000.pt'
s_model, s_diffusion = read_model_and_diffusion(args, ct_model_path)
mr_model_path = './work_dir/abdomenMR256/ema_0.9999_480000.pt'
t_model, t_diffusion = read_model_and_diffusion(args, mr_model_path)
save_path = './log'
#%% translate image
s_img_path = './demo-img'
names = sorted(os.listdir(s_img_path))
# names = ['ct_ori.png']
def sample2img(sample):
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous().cpu().numpy()[0]

    return sample

for name in names:
    ct_data = io.imread(join(s_img_path, name))

    s_np = ct_data / np.max(ct_data)
    s_np = (s_np - 0.5) * 2.0
    # s_np = np.repeat(np.expand_dims(s_np, -1), 3, -1)
    assert s_np.shape == (256, 256, 3), f'shape error! Current shape: {ct_data.shape}'
    s_np = np.expand_dims(s_np, 0)

    source = torch.from_numpy(s_np.astype(np.float32)).permute(0,3,1,2).to('cuda')
    # print(f"{source.shape=}")
    noise = s_diffusion.ddim_reverse_sample_loop(
        s_model, source,
        clip_denoised=False,
        device=dist_util.dev(),
    )

    source_recon = s_diffusion.ddim_sample_loop(
        s_model, (args.batch_size, 3, args.image_size, args.image_size),
        noise=noise,
        clip_denoised=False,
        device=dist_util.dev(),
    )

    target = t_diffusion.ddim_sample_loop(
        t_model, (args.batch_size, 3, args.image_size, args.image_size),
        noise=noise,
        clip_denoised=False,
        device=dist_util.dev(),
    )

    #%% plot
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8,8))
    images = [ct_data, color.rgb2gray(sample2img(noise)), sample2img(source_recon), sample2img(target)]
    titles = ['CT image', 'CT noise encode', 'CT reconstruction', 'CT2MR']
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i], cmap='gray')
        ax.set_title(titles[i])
        ax.axis('off')
    plt.suptitle(name)

    plt.savefig(join(save_path, name), dpi=300)
    plt.close(fig)  # release the figure; otherwise one open figure per image accumulates in memory

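To make "not robust" measurable rather than purely visual, a small helper could report the reconstruction error for each image. This is a hypothetical addition (`cycle_error` is not part of DDIB or guided-diffusion) and it assumes both inputs are NumPy arrays in the model's [-1, 1] range with identical shapes.

```python
import numpy as np


def cycle_error(source: np.ndarray, recon: np.ndarray) -> dict:
    """Compare an image with its DDIM reconstruction (both assumed to be in [-1, 1])."""
    if source.shape != recon.shape:
        raise ValueError(f"shape mismatch: {source.shape} vs {recon.shape}")
    diff = source.astype(np.float64) - recon.astype(np.float64)
    mse = float(np.mean(diff ** 2))
    max_abs = float(np.max(np.abs(diff)))
    # The data range of [-1, 1] is 2, so PSNR = 10 * log10(2**2 / MSE).
    psnr_db = float("inf") if mse == 0.0 else float(10.0 * np.log10(4.0 / mse))
    return {"mse": mse, "max_abs": max_abs, "psnr_db": psnr_db}
```

For example, `print(name, cycle_error(source.cpu().numpy(), source_recon.cpu().numpy()))` could be called inside the loop of the script above, right after `source_recon` is computed. Logging it for every image makes it easier to see whether the reconstruction failure is systematic or limited to a few slices.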
suxuann commented 1 year ago

Hi Jun, thanks for your interest in our work, and for trying to validate our method on CT and MR images.

DDIBs translate images via a (regularized) optimal transport process. This is both an advantage and a limitation of our method. Training the diffusion models on the two domains independently decouples the training processes, but the resulting optimal-transport-based translation may not produce the images you want.
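To make the optimal-transport point concrete, here is a 1-D toy sketch (our own illustration, not code from the DDIB repository). Each "domain" is assumed to be a Gaussian, so its exact noise predictor is known in closed form and stands in for a trained model. Values are encoded to noise with the source predictor and decoded with the target predictor, mirroring the encode/decode structure of DDIB. The translated values closely follow the 1-D optimal transport map x -> mu_t + (s_t / s_s)(x - mu_s): each value keeps its relative position within its distribution, and nothing in the map knows which structure a value belongs to. The linear schedule over 1000 steps is an assumption chosen to mirror `noise_schedule='linear'` and `diffusion_steps=1000` in the script above.

```python
import numpy as np

# Assumed schedule: linear betas over 1000 steps (mirrors noise_schedule='linear').
BETAS = np.linspace(1e-4, 0.02, 1000)
# alpha_bar with a leading 1.0 so the ODE starts exactly at the clean value.
ALPHA_BAR = np.concatenate([[1.0], np.cumprod(1.0 - BETAS)])


def gaussian_eps(x, a, mu, s):
    """Exact E[eps | x_t] for data ~ N(mu, s^2); stands in for a trained model."""
    return np.sqrt(1.0 - a) * (x - np.sqrt(a) * mu) / (a * s**2 + 1.0 - a)


def ddim_move(x, a_from, a_to, mu, s):
    """One deterministic DDIM step (eta = 0) between noise levels a_from and a_to."""
    eps = gaussian_eps(x, a_from, mu, s)
    x0_hat = (x - np.sqrt(1.0 - a_from) * eps) / np.sqrt(a_from)
    return np.sqrt(a_to) * x0_hat + np.sqrt(1.0 - a_to) * eps


def encode(x, mu, s):
    """Data -> latent noise using the source domain's predictor."""
    for k in range(len(ALPHA_BAR) - 1):
        x = ddim_move(x, ALPHA_BAR[k], ALPHA_BAR[k + 1], mu, s)
    return x


def decode(z, mu, s):
    """Latent noise -> data using the target domain's predictor."""
    for k in range(len(ALPHA_BAR) - 1, 0, -1):
        z = ddim_move(z, ALPHA_BAR[k], ALPHA_BAR[k - 1], mu, s)
    return z


# Hypothetical 1-D domains: source N(-0.3, 0.6^2), target N(0.4, 0.3^2).
mu_s, s_s, mu_t, s_t = -0.3, 0.6, 0.4, 0.3
x = mu_s + s_s * np.array([-2.0, -1.0, 0.0, 1.0, 2.0])  # source values at fixed z-scores
y = decode(encode(x, mu_s, s_s), mu_t, s_t)             # DDIB-style translation
ot = mu_t + (s_t / s_s) * (x - mu_s)                     # 1-D optimal transport map
print(np.round(y, 3))
print(np.round(ot, 3))
```

Real images are high-dimensional and the networks are not exact, so this only hints at the behavior: the translation matches distributions, which need not preserve the anatomical layout you would expect from a CT-to-MR mapping.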

You can refer to Appendix B of our paper: https://arxiv.org/pdf/2203.08382.pdf, for detailed explanations about the phenomenon you observe. Let us know if you have additional questions!

JunMa11 commented 1 year ago

Hi @suxuann ,

Thank you very much for your answer. Now I understand the reason behind the second issue (the structural differences in the translated MR images).

Could you please elaborate a little on the following question?

Why is cycle consistency not robust, i.e., why can't the original image be reconstructed from its noise encoding? Based on the proof, it should hold for any image.
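A toy model can show why the encode/decode cycle is only approximately invertible in practice. The flow of a continuous probability-flow ODE is exactly invertible, but `ddim_reverse_sample_loop` and `ddim_sample_loop` take discrete steps, so each direction carries discretization error, and a trained network adds its own prediction error on top. The sketch below is our own illustration, not DDIB code: it replaces the network with the exact noise predictor for hypothetical 1-D data N(0.1, 0.3^2), so the only error left is discretization, and it assumes a linear schedule with 1000 steps. The reconstruction error shrinks as the number of evenly spaced DDIM steps grows, but it is not zero for any finite number of steps.

```python
import numpy as np

BETAS = np.linspace(1e-4, 0.02, 1000)                       # assumed linear schedule, T = 1000
FULL_AB = np.concatenate([[1.0], np.cumprod(1.0 - BETAS)])  # alpha_bar; index 0 = clean data
MU, S = 0.1, 0.3                                            # hypothetical 1-D data: N(0.1, 0.3^2)


def eps_model(x, a):
    """Exact noise predictor for N(MU, S^2) data, standing in for a trained network."""
    return np.sqrt(1.0 - a) * (x - np.sqrt(a) * MU) / (a * S**2 + 1.0 - a)


def ddim_step(x, a_from, a_to):
    """Deterministic DDIM update (eta = 0); the same rule encodes and decodes."""
    eps = eps_model(x, a_from)
    x0_hat = (x - np.sqrt(1.0 - a_from) * eps) / np.sqrt(a_from)
    return np.sqrt(a_to) * x0_hat + np.sqrt(1.0 - a_to) * eps


def cycle(x, num_steps):
    """Encode x to noise and decode it back on `num_steps` evenly spaced timesteps."""
    ab = FULL_AB[np.linspace(0, 1000, num_steps + 1).round().astype(int)]
    for k in range(num_steps):             # image -> noise
        x = ddim_step(x, ab[k], ab[k + 1])
    for k in range(num_steps, 0, -1):      # noise -> image
        x = ddim_step(x, ab[k], ab[k - 1])
    return x


x = np.array([-0.5, 0.1, 0.9])
errs = [float(np.abs(cycle(x, n) - x).max()) for n in (10, 50, 250, 1000)]
print(errs)
```

In this toy the error is purely a step-size effect; with real CT/MR networks, imperfect noise predictions (for example on inputs that differ from the training distribution) can make the cycle considerably worse.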

leoil commented 1 year ago

Hi @JunMa11, I'd like to ask some questions about model training. If I want to train a new model on my own dataset, like your ./work_dir/abdomenCT256/ema_0.9999_480000.pt, could you please tell me how to prepare the training script?

yang1173350896 commented 1 year ago

Hi @JunMa11, I also tried to reconstruct the original MR image, but my reconstruction has a color problem. I tried normalizing the image to [0, 1], but it still can't reconstruct the original image. Could you please tell me what the possible reason could be? [image]
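One thing worth checking for the color problem: the sampling code in this thread works in [-1, 1] (`sample2img` maps `(sample + 1) * 127.5` back to pixels), and guided-diffusion's default image loader converts images to RGB and scales training pixels as `pixel / 127.5 - 1`. If the models were trained that way (an assumption), inputs normalized to [0, 1], or scaled per image with `np.max`, would not match the training distribution. A hypothetical pair of helpers (`to_model_range` and `from_model_range` are illustrative names, not library functions) that apply the fixed scaling and its inverse:

```python
import numpy as np


def to_model_range(img: np.ndarray) -> np.ndarray:
    """uint8 HxW or HxWx3 image -> float32 1x3xHxW array in [-1, 1].

    Assumes training used a fixed scaling (pixel / 127.5 - 1), not per-image
    min/max normalization.
    """
    if img.dtype != np.uint8:
        raise TypeError(f"expected uint8 pixels, got {img.dtype}")
    if img.ndim == 2:                          # grayscale slice -> replicate to 3 channels
        img = np.repeat(img[..., None], 3, axis=-1)
    x = img.astype(np.float32) / 127.5 - 1.0
    return np.transpose(x, (2, 0, 1))[None]    # HWC -> NCHW with batch size 1


def from_model_range(x: np.ndarray) -> np.ndarray:
    """Inverse of to_model_range: 1x3xHxW in [-1, 1] -> uint8 HxWx3."""
    img = np.clip((x[0] + 1.0) * 127.5, 0, 255).round().astype(np.uint8)
    return np.transpose(img, (1, 2, 0))
```

Using the same fixed mapping for both encoding and visualization also keeps the displayed intensities comparable between the input and its reconstruction.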