karrykkk / BayesDiff


Could you share the visualization code for the per-pixel uncertainty estimation? #1

Status: Open. hejiaxiang1 opened this issue 3 months ago.

hejiaxiang1 commented 3 months ago

I tried to visualize the variance (var) part, but the output contains no useful information. My modified /sd/dpmsolver_skipUQ.py code is as follows:

    #########   start sample  ########## 
    c = model.get_learned_conditioning(opt.prompt)
    c = torch.concat(opt.sample_batch_size * [c], dim=0)
    exp_dir = f'./dpm_solver_2_exp/skipUQ/{opt.prompt}_train{opt.train_la_data_size}_step{opt.timesteps}_S{opt.mc_size}/'
    os.makedirs(exp_dir, exist_ok=True)
    total_n_samples = opt.total_n_samples
    if total_n_samples % opt.sample_batch_size != 0:
        raise ValueError("Total samples for sampling must be divided exactly by opt.sample_batch_size, but got {} and {}".format(total_n_samples, opt.sample_batch_size))
    n_rounds = total_n_samples // opt.sample_batch_size
    var_sum = torch.zeros((opt.sample_batch_size, n_rounds)).to(device)
    sample_x = []
    var_x = [] # add
    img_id = 1000000
    precision_scope = autocast if opt.precision=="autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                for loop in tqdm(
                    range(n_rounds), desc="Generating image samples for FID evaluation."
                ):

                    xT, timestep, mc_sample_size  = torch.randn([opt.sample_batch_size, opt.C, opt.H // opt.f, opt.W // opt.f], device=device), opt.timesteps//2, opt.mc_size
                    T = t_seq[timestep]
                    if uq_array[timestep] == True:
                        xt_next = xT
                        exp_xt_next, var_xt_next = xT, torch.zeros_like(xT).to(device)
                        eps_mu_t_next, eps_var_t_next = custom_ld(xT, get_model_input_time(ns, T).expand(xT.shape[0]), c=c) 
                        cov_xt_next_epst_next = torch.zeros_like(xT).to(device)
                        _, model_s1, _ = conditioned_update(ns, xt_next, T, t_seq[timestep-1], custom_ld, eps_mu_t_next, pre_wuq=True, r1=0.5, c=c)
                        list_eps_mu_t_next_i = torch.unsqueeze(model_s1, dim=0)
                    else:
                        xt_next = xT
                        exp_xt_next, var_xt_next = xT, torch.zeros_like(xT).to(device)
                        eps_mu_t_next = custom_ld.accurate_forward(xT, get_model_input_time(ns, T).expand(xT.shape[0]), c=c)

                    ####### Start skip UQ sampling  ######
                    for timestep in range(opt.timesteps//2, 0, -1):

                        if uq_array[timestep] == True:
                            xt = xt_next
                            exp_xt, var_xt = exp_xt_next, var_xt_next
                            eps_mu_t, eps_var_t, cov_xt_epst = eps_mu_t_next, eps_var_t_next, cov_xt_next_epst_next
                            mc_eps_exp_t = torch.mean(list_eps_mu_t_next_i, dim=0)
                        else: 
                            xt = xt_next
                            exp_xt, var_xt = exp_xt_next, var_xt_next
                            eps_mu_t = eps_mu_t_next

                        s, t = t_seq[timestep], t_seq[timestep-1]
                        if uq_array[timestep] == True:
                            eps_t= sample_from_gaussion(eps_mu_t, eps_var_t)
                            xt_next, _ , model_s1_var = conditioned_update(ns=ns, x=xt, s=s, t=t, custom_ld=custom_ld, model_s=eps_t, pre_wuq=uq_array[timestep], c=c, r1=0.5)
                            exp_xt_next = conditioned_exp_iteration(exp_xt, ns, s, t, pre_wuq=uq_array[timestep], mc_eps_exp_s1=mc_eps_exp_t)
                            var_xt_next = conditioned_var_iteration(var_xt, ns, s, t, pre_wuq=uq_array[timestep], cov_xt_epst= cov_xt_epst, var_epst=model_s1_var)
                            # decide whether to see xt_next as a random variable
                            if uq_array[timestep-1] == True:
                                list_xt_next_i, list_eps_mu_t_next_i=[], []
                                s_next = t_seq[timestep-1]
                                t_next = t_seq[timestep-2]
                                lambda_s_next, lambda_t_next = ns.marginal_lambda(s_next), ns.marginal_lambda(t_next)
                                h_next = lambda_t_next - lambda_s_next
                                lambda_s1_next = lambda_s_next + 0.5 * h_next
                                s1_next = ns.inverse_lambda(lambda_s1_next)
                                sigma_s1_next = ns.marginal_std(s1_next)
                                log_alpha_s_next, log_alpha_s1_next = ns.marginal_log_mean_coeff(s_next), ns.marginal_log_mean_coeff(s1_next)
                                phi_11_next = torch.expm1(0.5*h_next)

                                for _ in range(mc_sample_size):

                                    var_xt_next = torch.clamp(var_xt_next, min=0)
                                    xt_next_i = sample_from_gaussion(exp_xt_next, var_xt_next)
                                    list_xt_next_i.append(xt_next_i)
                                    model_t_i, model_t_i_var = custom_ld(xt_next_i, get_model_input_time(ns, s_next).expand(xt_next_i.shape[0]), c=c)
                                    xu_next_i = sample_from_gaussion(torch.exp(log_alpha_s1_next - log_alpha_s_next) * xt_next_i-(sigma_s1_next * phi_11_next) * model_t_i, \
                                                                    torch.square(sigma_s1_next * phi_11_next) * model_t_i_var)
                                    model_u_i, _ = custom_ld(xu_next_i, get_model_input_time(ns, s1_next).expand(xt_next_i.shape[0]), c=c)
                                    list_eps_mu_t_next_i.append(model_u_i)

                                eps_mu_t_next, eps_var_t_next = custom_ld(xt_next, get_model_input_time(ns, s_next).expand(xt_next.shape[0]), c=c)
                                list_xt_next_i = torch.stack(list_xt_next_i, dim=0).to(device)
                                list_eps_mu_t_next_i = torch.stack(list_eps_mu_t_next_i, dim=0).to(device)
                                cov_xt_next_epst_next = torch.mean(list_xt_next_i*list_eps_mu_t_next_i, dim=0)-exp_xt_next*torch.mean(list_eps_mu_t_next_i, dim=0)
                            else:
                                eps_mu_t_next = custom_ld.accurate_forward(xt_next, get_model_input_time(ns, t).expand(xt_next.shape[0]), c=c)

                        else:
                            xt_next, model_s1 = conditioned_update(ns=ns, x=xt, s=s, t=t, custom_ld=custom_ld, model_s=eps_mu_t, pre_wuq=uq_array[timestep], c=c, r1=0.5)
                            exp_xt_next = conditioned_exp_iteration(exp_xt, ns, s, t, exp_s1= model_s1, pre_wuq=uq_array[timestep])
                            var_xt_next = conditioned_var_iteration(var_xt, ns, s, t, pre_wuq=uq_array[timestep])
                            if uq_array[timestep-1] == True:
                                list_xt_next_i, list_eps_mu_t_next_i=[], []
                                s_next = t_seq[timestep-1]
                                t_next = t_seq[timestep-2]
                                lambda_s_next, lambda_t_next = ns.marginal_lambda(s_next), ns.marginal_lambda(t_next)
                                h_next = lambda_t_next - lambda_s_next
                                lambda_s1_next = lambda_s_next + 0.5 * h_next
                                s1_next = ns.inverse_lambda(lambda_s1_next)
                                sigma_s1_next = ns.marginal_std(s1_next)
                                log_alpha_s_next, log_alpha_s1_next = ns.marginal_log_mean_coeff(s_next), ns.marginal_log_mean_coeff(s1_next)
                                phi_11_next = torch.expm1(0.5*h_next)

                                for _ in range(mc_sample_size):

                                    var_xt_next = torch.clamp(var_xt_next, min=0)
                                    xt_next_i = sample_from_gaussion(exp_xt_next, var_xt_next)
                                    list_xt_next_i.append(xt_next_i)
                                    model_t_i, model_t_i_var = custom_ld(xt_next_i, get_model_input_time(ns, s_next).expand(xt_next_i.shape[0]), c=c)
                                    xu_next_i = sample_from_gaussion(torch.exp(log_alpha_s1_next - log_alpha_s_next) * xt_next_i-(sigma_s1_next * phi_11_next) * model_t_i, \
                                                                    torch.square(sigma_s1_next * phi_11_next) * model_t_i_var)
                                    model_u_i, _ = custom_ld(xu_next_i, get_model_input_time(ns, s1_next).expand(xt_next_i.shape[0]), c=c)
                                    list_eps_mu_t_next_i.append(model_u_i)

                                eps_mu_t_next, eps_var_t_next = custom_ld(xt_next, get_model_input_time(ns, s_next).expand(xt_next.shape[0]), c=c)
                                list_xt_next_i = torch.stack(list_xt_next_i, dim=0).to(device)
                                list_eps_mu_t_next_i = torch.stack(list_eps_mu_t_next_i, dim=0).to(device)
                                cov_xt_next_epst_next = torch.mean(list_xt_next_i*list_eps_mu_t_next_i, dim=0)-exp_xt_next*torch.mean(list_eps_mu_t_next_i, dim=0)
                            else:
                                eps_mu_t_next = custom_ld.accurate_forward(xt_next, get_model_input_time(ns, t).expand(xt_next.shape[0]), c=c)

                    ####### Save variance and sample image  ######         
                    var_sum[:, loop] = var_xt_next.sum(dim=(1,2,3))
                    x_samples = model.decode_first_stage(xt_next) # 
                    # var_xt_next = model.decode_first_stage(var_xt_next)# add
                    x = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                    # os.makedirs(os.path.join(exp_dir, 'sam/'), exist_ok=True)
                    # for i in range(x.shape[0]):
                    #     path = os.path.join(exp_dir, 'sam/', f"{img_id}.png")
                    #     tvu.save_image(x.cpu()[i].float(), path)
                    #     img_id += 1
                    sample_x.append(x)
                    var_x.append(var_xt_next) # add

                sample_x = torch.concat(sample_x, dim=0)
                var_x = torch.concat(var_x, dim=0)# add
                var = []
                for j in range(n_rounds):
                    var.append(var_sum[:, j])
                var = torch.concat(var, dim=0)
                sorted_var, sorted_indices = torch.sort(var, descending=True)
                reordered_sample_x = torch.index_select(sample_x, dim=0, index=sorted_indices.int())
                grid_sample_x = tvu.make_grid(reordered_sample_x, nrow=8, padding=2)
                tvu.save_image(grid_sample_x.cpu().float(), os.path.join(exp_dir, "sorted_sample.png"))

                print(f'Sampling {total_n_samples} images in {exp_dir}')
                torch.save(var_sum.cpu(), os.path.join(exp_dir, 'var_sum.pt'))

                var_x = var_x.mean(dim=1, keepdim=True) # add
                reordered_var_x = torch.index_select(var_x, dim=0, index=sorted_indices.int()) # add
                grid_var_x = tvu.make_grid(reordered_var_x, nrow=12, padding=1, normalize=True) # add
                tvu.save_image(grid_var_x.cpu().float(), os.path.join(exp_dir, "sorted_var.png")) # add
xiexh20 commented 3 months ago

I have a similar issue. I tried to visualize the DDPM uncertainty of ImageNet generations, but the resulting image is not very meaningful.

In Section 4.3, you write that you "sample a variety of latent states ... estimate the empirical variance ... as the final pixel-wise uncertainty".

How exactly did you sample? Did you draw Gaussian samples around the final exp_xt?

Thank you for your time and help!

karrykkk commented 3 months ago

Thanks for your interest in our work!

For variance visualization of Stable Diffusion in the latent space, we save $E(z_0)$ and $\mathrm{Var}(z_0)$ (exp_xt_next and var_xt_next in the xxUQ.py script) and resample $z_{0,1}, \dots, z_{0,N}$ from the Gaussian distribution $\mathcal{N}(E(z_0), \mathrm{Var}(z_0))$. We then decode them to $x_{0,1}, \dots, x_{0,N}$ and estimate the empirical variance as the final pixel-wise variance.
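
The resampling procedure described above can be written compactly as follows. This assumes a diagonal Gaussian over latent entries, which matches the elementwise standard deviations used in the script. $D$ denotes the VAE decoder (`model.decode_first_stage`), $\odot$ is the elementwise product, and the square is taken elementwise. The $N-1$ denominator matches the default (unbiased) behaviour of `torch.var`.

$$
\begin{aligned}
z_{0,i} &= E(z_0) + \sqrt{\mathrm{Var}(z_0)} \odot \epsilon_i, \qquad \epsilon_i \sim \mathcal{N}(0, I), \quad i = 1, \dots, N \\
x_{0,i} &= D(z_{0,i}) \\
\bar{x}_0 &= \frac{1}{N} \sum_{i=1}^{N} x_{0,i}, \qquad
\widehat{\mathrm{Var}}(x_0) = \frac{1}{N-1} \sum_{i=1}^{N} \left(x_{0,i} - \bar{x}_0\right)^{2}
\end{aligned}
$$

The script saves $\sqrt{\widehat{\mathrm{Var}}(x_0)}$, i.e. the per-pixel standard deviation.
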

For the visualization code, you can refer to the script below. Feel free to ask if you have any further questions.

    import os

    import torch
    import torchvision.utils as tvu
    from ldm.util import instantiate_from_config
    from omegaconf import OmegaConf

    def load_model_from_config(config, ckpt, verbose=False):
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
        model.eval()
        return model

    device = torch.device("cuda:5")
    config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
    model = load_model_from_config(config, "your_local_sd_ckpt").to(device)

    exp_dir = 'your_local_exp_dir'

    # Load the saved latent mean E(z_0) and variance Var(z_0) for one image
    z_dev_list = []
    z_exp_list = []
    img_id = 0
    z_var_i = torch.load(f'{exp_dir}/z_var/{img_id}.pth')
    z_exp_i = torch.load(f'{exp_dir}/z_exp/{img_id}.pth')
    z_dev_i = torch.clamp(z_var_i, min=0) ** 0.5  # std = sqrt(Var), clamping tiny negative values
    z_dev_list.append(z_dev_i)
    z_exp_list.append(z_exp_i)

    def get_dev_x_from_z(dev, exp, N):
        # Draw N latent samples z_{0,i} ~ N(E(z_0), Var(z_0)); Gaussian noise needs randn_like, not rand_like
        z_list = [exp + torch.randn_like(exp) * dev for _ in range(N)]
        # Decode the latent samples into image space
        Z = torch.stack(z_list, dim=0)
        X = model.decode_first_stage(Z.to(device))
        # Pixel-wise empirical standard deviation over the N decoded images
        var_x = torch.var(X, dim=0)
        return var_x ** 0.5

    os.makedirs(f'{exp_dir}/x_dev', exist_ok=True)

    N = 15
    for index in range(1):
        z_dev = z_dev_list[index]
        z_exp = z_exp_list[index]
        dev_x = get_dev_x_from_z(z_dev, z_exp, N)
        # Scale the small std values up so they are visible; save_image clips to [0, 1]
        tvu.save_image(dev_x * 100, f'{exp_dir}/x_dev/{img_id}.jpg')

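
For readers who want to try the procedure without a Stable Diffusion checkpoint, here is a minimal self-contained sketch of the same idea. It is not the authors' code. `pixel_std_from_latent`, `normalize_for_display`, `toy_decode` and `toy_proj` are hypothetical names. `toy_decode` is an assumed stand-in for `model.decode_first_stage` (4 latent channels to 3 image channels, 8x upsampling). The latent tensors are assumed to have shape `(C, h, w)` without a batch dimension. The std map is min-max normalized for display instead of being multiplied by a fixed factor such as 100; this is only a display choice.

```python
import torch
import torch.nn.functional as F

def pixel_std_from_latent(z_exp, z_var, decode, n_samples=15, generator=None):
    """Estimate the per-pixel std of decode(z) for z ~ N(z_exp, diag(z_var))."""
    z_std = torch.clamp(z_var, min=0).sqrt()
    # Standard Gaussian noise, one draw per latent sample
    eps = torch.randn((n_samples, *z_exp.shape), generator=generator,
                      dtype=z_exp.dtype, device=z_exp.device)
    z = z_exp.unsqueeze(0) + eps * z_std.unsqueeze(0)   # (N, C, h, w)
    with torch.no_grad():
        x = decode(z)                                   # (N, C_img, H, W)
    # Unbiased empirical variance over the N decoded samples, then sqrt -> std
    return x.var(dim=0).sqrt()

def normalize_for_display(dev_x):
    """Min-max scale a std map to [0, 1] so it can be saved as an image."""
    lo, hi = dev_x.min(), dev_x.max()
    return (dev_x - lo) / torch.clamp(hi - lo, min=1e-12)

# Hypothetical toy decoder standing in for model.decode_first_stage
torch.manual_seed(0)
toy_proj = torch.randn(3, 4)

def toy_decode(z):
    x = torch.einsum("oc,nchw->nohw", toy_proj, z)      # 4 -> 3 channels
    x = F.interpolate(x, scale_factor=8, mode="nearest")  # 8x spatial upsampling
    return torch.tanh(x)                                # nonlinear, like a real decoder

z_exp = torch.randn(4, 8, 8)
z_var = torch.rand(4, 8, 8) * 0.1
dev_x = pixel_std_from_latent(z_exp, z_var, toy_decode, n_samples=15)
vis = normalize_for_display(dev_x)
print(dev_x.shape)  # torch.Size([3, 64, 64])
```

Because the decoder is nonlinear, the pixel-space std is estimated by sampling and decoding rather than by decoding the latent variance directly.
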
cilevanmarken commented 2 months ago

Hi! I have successfully created uncertainty maps for Stable Diffusion. However, the uncertainty maps I generated for DDIM_and_guided by visualizing the variance do not match the results in your paper. Could you kindly share the visualization code for this? Thank you in advance.

karrykkk commented 2 months ago

Hi👋~ Thank you for your interest in our work! For CelebA uncertainty visualization with the DDIM sampler, you can try the Python script ./ddpm_and_guided/ddim_skipUQ_visualization.py with this bash configuration:

    DEVICES="5"
    data="celeba"
    steps="100"
    mc_size="10"
    sample_batch_size="16"
    total_n_sample="16"
    train_la_data_size="5000"
    DIS="uniform"
    fixed_class="10"
    seed=123

    CUDA_VISIBLE_DEVICES=$DEVICES python ddim_skipUQ_visualization.py \
        --config $data".yml" --timesteps=$steps --skip_type=$DIS --train_la_batch_size 32 \
        --mc_size=$mc_size --sample_batch_size=$sample_batch_size --fixed_class=$fixed_class --train_la_data_size=$train_la_data_size \
        --total_n_sample=$total_n_sample --seed=$seed

cilevanmarken commented 2 months ago

Thank you for your fast reply! However, when I run the given visualization code on ImageNet instead of CelebA (with the settings from your last post as well as the standard settings from the ddim.sh file), the generated uncertainty maps still don't make much sense. Do you have any pointers as to why this might be? Or does the visualization code for ImageNet differ from the one for CelebA? Thanks in advance!

[Images: visualize_sample, visualize_var]

karrykkk commented 2 months ago

Hi @cilevanmarken ~

For ImageNet visualization, the dataset is much larger, so you need to increase train_la_data_size. The posterior distribution is fitted on #total_dataset_size/train_la_data_size samples, so a larger train_la_data_size means fitting the posterior on less data, which yields a larger variance. For example, after setting train_la_data_size=500000 in the bash script above, you will get the following results:

[Images: visualize_sample, visualize_var]
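
To see why fitting the posterior on fewer samples gives a larger variance, here is a toy illustration. It is not BayesDiff's code. It uses a one-parameter Bayesian linear regression, for which the Laplace approximation is exact. The posterior precision equals the prior precision plus a data term that grows with the number of fitting samples. As a result, the posterior variance shrinks as more data is used. The function name, dataset and numbers are hypothetical. The fitting-set size follows the description above (total dataset size divided by train_la_data_size).

```python
import numpy as np

def laplace_posterior_var(x, noise_var=1.0, prior_prec=1.0):
    """Posterior variance of w in y = w * x + noise, with prior w ~ N(0, 1/prior_prec).

    For this linear-Gaussian model the Hessian of the negative log posterior is
    prior_prec + sum(x**2) / noise_var, so the Laplace approximation is exact.
    """
    precision = prior_prec + np.sum(np.asarray(x, dtype=float) ** 2) / noise_var
    return 1.0 / precision

rng = np.random.default_rng(0)
pool = rng.normal(size=100_000)          # hypothetical "full dataset" of inputs
total_dataset_size = pool.size

results = {}
for train_la_data_size in (10, 1_000, 50_000):
    # Fit on total_dataset_size / train_la_data_size samples (nested prefixes of the pool)
    n_fit = total_dataset_size // train_la_data_size
    results[train_la_data_size] = laplace_posterior_var(pool[:n_fit])
    print(f"train_la_data_size={train_la_data_size:>6}  n_fit={n_fit:>6}  "
          f"posterior var={results[train_la_data_size]:.2e}")
```

A larger train_la_data_size leaves fewer fitting samples, so the printed posterior variance increases.
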

LibertyRoamer commented 1 month ago

Hi! I would like to ask a question. If I switch to a different dataset that contains only 200 images, what value should I set for train_la_data_size? Or would such a small dataset lead to suboptimal results?