Fantasy-Studio / Paint-by-Example

Paint by Example: Exemplar-based Image Editing with Diffusion Models

question about training input #52

Open D222097 opened 9 months ago

D222097 commented 9 months ago
Nice work! I'm wondering why the inputs can be set up this way during training.
[Figure: example training inputs, labeled image_GT, inpaint_image, inpaint_mask, ref_imgs / img, masked_img, msk, ref_img]

In this work, I found that the inputs are the GT image (with noise added), masked_img, mask, and ref_img. As shown below, the UNet input x_start is the concatenation of z (the encoding of the GT image), z_inpaint (the encoding of masked_img), and mask_resize (the downsampled mask):

z_new = torch.cat((z, z_inpaint, mask_resize), dim=1)  # x_start
def p_losses(self, x_start, cond, t, noise=None, ):
    if self.first_stage_key == 'inpaint':
        # x_start=x_start[:,:4,:,:]
        noise = default(noise, lambda: torch.randn_like(x_start[:,:4,:,:]))
        x_noisy = self.q_sample(x_start=x_start[:,:4,:,:], t=t, noise=noise)
        x_noisy = torch.cat((x_noisy, x_start[:,4:,:,:]), dim=1)
    model_output = self.apply_model(x_noisy, t, cond)

    if self.parameterization == "x0":
        target = x_start
    elif self.parameterization == "eps":
        target = noise
    else:
        raise NotImplementedError()

    loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
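To make the shapes in the snippet above concrete, here is a minimal NumPy sketch of the same flow: a 4 + 4 + 1 = 9-channel x_start, noise added only to the first 4 (GT latent) channels, the conditioning channels appended unchanged, and the noise used as the target under the "eps" parameterization. NumPy stands in for torch; the sizes, the single scalar alpha_bar_t, and the all-zeros model_output are hypothetical placeholders, and q_sample here is a simplified version of the standard DDPM forward process, not the repository's own implementation.

```python
import numpy as np

# Hypothetical sizes: batch of 2, 4-channel latents at 64x64.
B, C, H, W = 2, 4, 64, 64
rng = np.random.default_rng(0)

z = rng.standard_normal((B, C, H, W))          # latent of the GT image (placeholder values)
z_inpaint = rng.standard_normal((B, C, H, W))  # latent of masked_img (placeholder values)
mask_resize = (rng.random((B, 1, H, W)) > 0.5).astype(np.float64)  # downsampled mask

# Equivalent of torch.cat((z, z_inpaint, mask_resize), dim=1): 9 channels.
x_start = np.concatenate((z, z_inpaint, mask_resize), axis=1)


def q_sample(x0, alpha_bar_t, noise):
    """Simplified forward diffusion: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    return np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * noise


alpha_bar_t = 0.3  # hypothetical value for one sampled timestep (the real code indexes a schedule by t)
noise = rng.standard_normal((B, C, H, W))

# Only the GT latent channels are noised ...
x_noisy = q_sample(x_start[:, :4], alpha_bar_t, noise)
# ... and the masked-image latent and mask are appended unchanged.
unet_input = np.concatenate((x_noisy, x_start[:, 4:]), axis=1)

# Placeholder for the UNet output (9 channels in, 4 channels out).
model_output = np.zeros((B, C, H, W))

# "eps" parameterization: the target is the sampled noise, not the GT latent.
loss_simple = ((model_output - noise) ** 2).mean(axis=(1, 2, 3))
```

With the zero placeholder output, each per-sample loss is just the mean squared noise, so it lands close to 1.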

I am curious why the GT image can be fed into the UNet directly. Even though noise has been added to it, it is still visible to the UNet.
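How much of the GT remains visible depends on the sampled timestep t: in the forward process the GT latent is weighted by sqrt(alpha_bar_t) and the noise by sqrt(1 - alpha_bar_t). The sketch below prints both weights at a few timesteps. It assumes the "linear" beta schedule from the latent-diffusion codebase (linspace of square roots, then squared) with linear_start=0.00085, linear_end=0.012 and 1000 timesteps; these are typical Stable Diffusion values and may differ from the actual Paint-by-Example config.

```python
import numpy as np

# Assumed schedule: betas = linspace(sqrt(start), sqrt(end), T) ** 2.
T = 1000
betas = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, T) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

for t in (0, 250, 500, 750, 999):
    gt_weight = np.sqrt(alphas_cumprod[t])           # coefficient on the GT latent
    noise_weight = np.sqrt(1.0 - alphas_cumprod[t])  # coefficient on the Gaussian noise
    print(f"t={t:4d}  GT weight={gt_weight:.3f}  noise weight={noise_weight:.3f}")
```

Under this assumed schedule, the GT weight shrinks monotonically with t and is close to zero at the last timestep.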

Using the images above as an example: during training, the input is the car image and the expected output is also the car image. At inference, however, the input is an image unrelated to the car (an arbitrary object or just background), while the expected output is still the car image.

This seems a little odd. On the one hand, the model needs the GT to be optimized, and in other generative models the GT is usually used as the target rather than as a direct input. On the other hand, diffusion models usually predict Gaussian noise rather than pixels, so there seems to be no other way for the GT to constrain a diffusion model. I don't quite understand how the model learns, and I'd be grateful if anyone could give me some advice.
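For reference, the standard DDPM relations below (assuming the "eps" parameterization, with x_0 the GT latent, i.e. x_start[:, :4], epsilon the sampled noise, and c the conditioning) show how a noise-prediction loss relates to the GT: given x_t and a predicted noise, an estimate of x_0 follows directly, and the squared noise error equals the squared x_0 error up to a timestep-dependent weight. This is the generic DDPM identity, not something specific to this repository.

```latex
x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon,
  \qquad \epsilon \sim \mathcal{N}(0, I)

\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\, \epsilon_\theta(x_t, t, c)}{\sqrt{\bar\alpha_t}}

\lVert \epsilon - \epsilon_\theta(x_t, t, c) \rVert^2
  = \frac{\bar\alpha_t}{1-\bar\alpha_t}\, \lVert x_0 - \hat{x}_0 \rVert^2
```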