guoqiang-zhang-x / BDIA


the choice of gamma on CIFAR10 #2

Open zituitui opened 6 months ago

zituitui commented 6 months ago

Hi, I found that the paper recommends gamma = 1.0 for the CIFAR10 generation task. I replicated BDIA and tried it on CIFAR10, and found that BDIA performs best at gamma = 0.1 and performs very badly at gamma = 1.0. Is it because I chose a different checkpoint from yours?
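
For context, the recursion I implemented (as I read it from the paper; code in my comment below) is, with $a_i = \sqrt{\bar\alpha_{t_i}}$, $\sigma_i = \sqrt{1 - \bar\alpha_{t_i}}$, and $d_{i \to j}$ denoting the DDIM step from $t_i$ to $t_j$:

$$d_{i \to j}(x_i) = \frac{a_j}{a_i}\, x_i + \Big(\sigma_j - \sigma_i\, \frac{a_j}{a_i}\Big)\, \epsilon_\theta(x_i, t_i),$$

$$x_{i+1} = \gamma\, x_{i-1} + d_{i \to i+1}(x_i) - \gamma\, d_{i \to i-1}(x_i),$$

so gamma = 0 should reduce BDIA to a plain DDIM sampler.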

guoqiang-zhang-x commented 6 months ago

Hi,

My current guess is that your implementation is not correct. If possible, could you share your Python code with me?

Best regards, Guoqiang

zituitui commented 6 months ago

Hi, thanks. Here is my BDIA implementation:

import torch


def bdia_forward(ddpm_pipe, batch_size, num_inference_steps, seed=0, states=None, gamma=1.0):
    dtype = torch.float32
    ddpm_pipe.scheduler.set_timesteps(num_inference_steps, device='cuda')
    timesteps = ddpm_pipe.scheduler.timesteps
    torch.manual_seed(seed)

    xis = []
    # Sample Gaussian noise to begin the loop, unless an initial state was passed in.
    if states is None:
        if isinstance(ddpm_pipe.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                ddpm_pipe.unet.config.in_channels,
                ddpm_pipe.unet.config.sample_size,
                ddpm_pipe.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, ddpm_pipe.unet.config.in_channels, *ddpm_pipe.unet.config.sample_size)
        states = torch.randn(image_shape, device='cuda', dtype=dtype)

    xis.append(states)
    with torch.no_grad():
        for i, t in enumerate(timesteps):
            noise_pred = ddpm_pipe.unet(states, t, return_dict=False)[0]

            # alpha_t: cumulative alpha at the current step t; alpha_s: at the next
            # (less noisy) step s, with cumulative alpha = 1 at the final clean step.
            if i < num_inference_steps - 1:
                alpha_s = ddpm_pipe.scheduler.alphas_cumprod[timesteps[i + 1]].to(torch.float32)
            else:
                alpha_s = 1
            alpha_t = ddpm_pipe.scheduler.alphas_cumprod[t].to(torch.float32)

            sigma_s = (1 - alpha_s) ** 0.5
            sigma_t = (1 - alpha_t) ** 0.5
            alpha_s = alpha_s ** 0.5
            alpha_t = alpha_t ** 0.5

            # Plain DDIM coefficients for the step t -> s.
            coef_xt = alpha_s / alpha_t
            coef_eps = sigma_s - sigma_t * coef_xt
            if i == 0:
                # The first step is a plain DDIM step.
                states = coef_xt * states + coef_eps * noise_pred
            else:
                # BDIA step: add gamma * x_{i-1} and subtract gamma times the
                # reversed DDIM step t -> p, where p is the previous (noisier) timestep.
                alpha_p = ddpm_pipe.scheduler.alphas_cumprod[timesteps[i - 1]].to(torch.float32)
                sigma_p = (1 - alpha_p) ** 0.5
                alpha_p = alpha_p ** 0.5
                coef_xt = coef_xt - gamma * alpha_p / alpha_t
                coef_eps = coef_eps - gamma * (sigma_p - sigma_t * alpha_p / alpha_t)
                states = gamma * xis[-2] + coef_xt * xis[-1] + coef_eps * noise_pred

            xis.append(states)

    image = xis[-1]
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    return ddpm_pipe.numpy_to_pil(image)

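For completeness, this is roughly how I call it. A minimal sketch, assuming the diffusers DDPMPipeline and the standard google/ddpm-cifar10-32 checkpoint, which may differ from the checkpoint you used:

from diffusers import DDPMPipeline

# Hypothetical driver; the checkpoint name is an assumption on my side.
ddpm_pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").to("cuda")
images = bdia_forward(ddpm_pipe, batch_size=16, num_inference_steps=100,
                      seed=0, gamma=1.0)
images[0].save("bdia_cifar10_sample.png")
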
guoqiang-zhang-x commented 6 months ago

Can you also send me the original Python code for DDIM sampling?

zituitui commented 6 months ago

def ddim_forward(ddpm_pipe, seed, num_inference_steps, states=None):
    dtype = torch.float32
    ddpm_pipe.scheduler.set_timesteps(num_inference_steps, device='cuda')
    timesteps = ddpm_pipe.scheduler.timesteps
    torch.manual_seed(seed)

    xis = []
    # Sample Gaussian noise to begin the loop, unless an initial state was passed in.
    if states is None:
        if isinstance(ddpm_pipe.unet.config.sample_size, int):
            image_shape = (
                1,
                ddpm_pipe.unet.config.in_channels,
                ddpm_pipe.unet.config.sample_size,
                ddpm_pipe.unet.config.sample_size,
            )
        else:
            image_shape = (1, ddpm_pipe.unet.config.in_channels, *ddpm_pipe.unet.config.sample_size)
        states = torch.randn(image_shape, device='cuda', dtype=dtype)

    xis.append(states)
    with torch.no_grad():
        for i, t in enumerate(timesteps):
            noise_pred = ddpm_pipe.unet(states, t, return_dict=False)[0]

            # alpha_t: cumulative alpha at the current step t; alpha_s: at the next
            # (less noisy) step s, with cumulative alpha = 1 at the final clean step.
            if i < num_inference_steps - 1:
                alpha_s = ddpm_pipe.scheduler.alphas_cumprod[timesteps[i + 1]].to(torch.float32)
            else:
                alpha_s = 1
            alpha_t = ddpm_pipe.scheduler.alphas_cumprod[t].to(torch.float32)

            sigma_s = (1 - alpha_s) ** 0.5
            sigma_t = (1 - alpha_t) ** 0.5
            alpha_s = alpha_s ** 0.5
            alpha_t = alpha_t ** 0.5

            # DDIM update: x_s = (a_s / a_t) * x_t + (sigma_s - sigma_t * a_s / a_t) * eps.
            coef_xt = alpha_s / alpha_t
            coef_eps = sigma_s - sigma_t * coef_xt
            states = coef_xt * states + coef_eps * noise_pred
            xis.append(states)

    image = xis[-1]
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    return ddpm_pipe.numpy_to_pil(image)
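
One sanity check that might help narrow this down: with gamma = 0 the BDIA step degenerates to exactly this DDIM step, so the two samplers should produce identical images from the same seed. A hypothetical check, reusing the two functions and the pipeline above:

import numpy as np

# With gamma = 0, the BDIA recursion drops the x_{i-1} and backward-step terms,
# so it should match DDIM bit-for-bit when both runs start from the same seed.
ddim_imgs = ddim_forward(ddpm_pipe, seed=0, num_inference_steps=100)
bdia_imgs = bdia_forward(ddpm_pipe, batch_size=1, num_inference_steps=100,
                         seed=0, gamma=0.0)
print(np.array_equal(np.asarray(ddim_imgs[0]), np.asarray(bdia_imgs[0])))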