Janspiry / Palette-Image-to-Image-Diffusion-Models

Unofficial implementation of Palette: Image-to-Image Diffusion Models in PyTorch
MIT License

How to correctly set up a custom dataset class to train the model? #14

Closed DanBigioi closed 2 years ago

DanBigioi commented 2 years ago

Specifically,

I'm looking at line 58 in Palette-Image-to-Image-Diffusion-Models/data/dataset.py

I'm trying to set up my own custom mask function, to process my dataset, for a variation of the image2image cropping task, and this is what I have so far.

I have code that generates a custom mask like this for a given image. The mask has shape 256×256×3:

[attached image: lip-contour mask]

and a target image like this: [attached image: Obama_256x256]

For the `__getitem__` method in data/dataset.py, I have some questions regarding the following lines:

```python
img = self.tfs(self.loader(path))
mask = self.get_mask()
cond_image = img * (1. - mask) + mask * torch.randn_like(img)
mask_img = img * (1. - mask) + mask
```

1) What does that `tfs` method do, and is it necessary to use it on my ground-truth image?
2) For my own get-mask function, my mask has shape (h, w, 3). Does this need to be (h, w, 1) instead to account for the "hole" and valid regions? If so, how do I work around this so that I can still include information about the lip position in the mask, as in image 1?
3) What does the `cond_image` calculation do, and why is it done?
4) Why is the masked image calculated this way, and not by a bitwise AND (multiplication) between the img and the mask? When I try that with my own data, the result is completely messed up.

Additionally, I have one extra question. Because I want to go from a masked image with a drawing of the lips to the ground-truth image, is this better suited to an image-colorization task?

Thanks!

DanBigioi commented 2 years ago

Just to answer my 4th question: I was making a mistake. I was calculating mask_img like this: `mask_img = img*(255. - mask) + mask`

when instead I should have omitted the `.` after the 255, as it was completely screwing up my result. Very stupid mistake, hahaha.
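For anyone else hitting this: the trailing `.` makes `255.` a float literal, so NumPy promotes the whole expression from uint8 to float64, and image-saving code then misinterprets the values. A minimal sketch of just the dtype promotion (hypothetical stand-in arrays, not the repo's code):

```python
import numpy as np

img = np.full((2, 2, 3), 200, dtype=np.uint8)   # stand-in uint8 image
mask = np.zeros_like(img)                        # stand-in uint8 mask

masked_int = img * (255 - mask) + mask     # integer arithmetic, stays uint8
masked_float = img * (255. - mask) + mask  # float literal promotes to float64

print(masked_int.dtype, masked_float.dtype)  # uint8 float64
```

Once the array is float64, values are no longer in the 0–255 uint8 range that viewers and save routines expect, which is the kind of "completely screwed up" result described above.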

Janspiry commented 2 years ago

Hi, thanks for your attention.

  1. As shown in https://github.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/blob/d8853b89956a4653dae38c0becc51f64c5fde86c/data/dataset.py#L42, the `tfs` function just transforms the PIL image into a tensor in [-1, 1], a range consistent with the standard Gaussian distribution. You need to apply it to all images.
  2. `mask` is a binary tensor {0, 1} with shape [H, W, 1]; you can refer to https://github.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/blob/d8853b89956a4653dae38c0becc51f64c5fde86c/data/util/mask.py#L100
  3. As explained in the paper: [screenshot of the relevant passage]
  4. As I understand it, you can set the lip-area mask to 1 and the rest to 0, and then do inpainting.
DanBigioi commented 2 years ago

Awesome, thank you for your answers, they clear it all up for me. With regard to your 4th answer, do you think something like this would do the trick?

mask: [attached image]

masked_image (please ignore the colour): [attached image]

Janspiry commented 2 years ago

I don't quite understand the use of this mask. Are you trying to recover the full image from the given lip outline and part of the face? In any case, it should be possible to recover the face; it depends on how much you train.

DanBigioi commented 2 years ago

Yup! Given the lip outline + part of the face, I want it to output the full face. Ideally, once it's trained, at the inference stage I can choose any lip outline for a particular face, and the network should generate the full face with the lips generated based on the outline.

The application is to generate a set of facial landmarks from audio and, using these generated landmarks, render a photorealistic video based on the new lip positions.

Janspiry commented 2 years ago

I understand.

  1. `mask` is 0 for masked pixels, so you may need to change your setting; this convention is used by default in the model.
  2. The second schematic you gave is easy to understand, and I think it will have some effect. But the model is not designed for contour information, so its perception of a given contour may not be particularly good.
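If it's useful, collapsing a 3-channel lip drawing into the single-channel binary mask the dataset expects could look something like this. The helper name and threshold are hypothetical; flip the polarity to match whichever convention the model setting above ends up using.

```python
import torch

def drawing_to_mask(drawing: torch.Tensor, thresh: float = 0.5) -> torch.Tensor:
    """drawing: (H, W, 3) tensor with the lip contour drawn in.
    Returns a (1, H, W) float mask, binary {0, 1}."""
    gray = drawing.float().mean(dim=-1)          # collapse RGB -> (H, W)
    mask = (gray > thresh).float().unsqueeze(0)  # threshold -> (1, H, W)
    return mask

drawing = torch.zeros(256, 256, 3)
drawing[100:150, 80:180] = 1.0  # pretend this block is the drawn lip region
mask = drawing_to_mask(drawing)
print(mask.shape, mask.unique())
```

This keeps the mask shape [1, H, W] so it broadcasts cleanly against a [3, H, W] image tensor, which answers the (h, w, 3) vs (h, w, 1) question from earlier in the thread.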
DanBigioi commented 2 years ago

Got it! Thanks so much for your help, I'll close the issue now 😄