lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch

Inpainting sampling results with all-zeros masks look off. #251

Closed xiankgx closed 2 years ago

xiankgx commented 2 years ago

This is the code used for generating the images with the model.

@torch.no_grad()
def generate(self, images, cond_scale: float = 10.0, inpaint_masks=None) -> np.ndarray:
        assert not self.autoencoder.training

        # Encode images with VAE
        posterior = self.autoencoder.encode(images)
        z = posterior.sample().detach()

        # Rescale latent values to a range suitable for diffusion

        if self.hparams.latent_scale == "ada_norm":
            z_mean = self.running_z_mean
            z_std = self.running_z_std

            # Normalize with the running latent stats, then rescale to the
            # running image stats so values lie in an image-like range
            z = (z - z_mean) / (z_std + 1e-6) * \
                self.running_images_std + self.running_images_mean
        elif isinstance(self.hparams.latent_scale, float):
            z = z * self.latent_scale
        else:
            raise Exception("Unknown latent rescaling")

        # Get conditioning
        assert not self.image_encoder.training
        image_embeds = self.image_encoder.encode_image(
            inv_normalize(images, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ).detach()
        # print(f"image_embeds.shape: {image_embeds.shape}")
        image_embeds = self.clip_image_embeds_to_text_embeds(image_embeds) \
            .reshape(image_embeds.size(0), self.num_image_tokens, -1)

        samples = self.imagen.sample(
            # texts=self.sample_texts
            text_embeds=image_embeds,
            cond_scale=cond_scale,
            inpaint_images=z if inpaint_masks is not None else None,
            inpaint_masks=inpaint_masks
        )

        # Decode the latent feature map using the decoder from the autoencoder
        assert not self.autoencoder.training

        # Undo the latent rescaling applied before sampling
        if self.hparams.latent_scale == "ada_norm":
            samples = ((samples - self.running_images_mean) /
                       self.running_images_std) * z_std + z_mean
        elif isinstance(self.hparams.latent_scale, float):
            samples = samples / self.latent_scale
        else:
            raise Exception("Unknown latent rescaling")

        samples = self.autoencoder.decode(samples)

        samples = inv_normalize(samples, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        samples.clamp_(0, 1)
        samples = (samples.permute((0, 2, 3, 1)) *
                   255.0).detach().cpu().numpy().astype(np.uint8)

        return samples
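
To make the observation concrete, here is a minimal sketch of the kind of call that exhibits the issue. The tensor shapes and the model handle are illustrative assumptions, not the actual training code. Since an all-zeros mask marks no region to be kept from the original, the expectation would be that sampling behaves like plain, non-inpainting sampling:

import torch

# Hypothetical reproduction; shapes are illustrative assumptions.
images = torch.randn(4, 3, 256, 256)              # batch of normalized inputs
masks = torch.zeros(4, 32, 32, dtype=torch.bool)  # all-zeros mask at latent resolution

samples = model.generate(images, cond_scale=10.0, inpaint_masks=masks)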
QinSY123 commented 2 years ago

Hello, do you have a trained model now? I would like to ask you about the specific values of the Unet parameters. The results I am getting so far are not good.

xiankgx commented 2 years ago

I can provide you the params, but they probably won't mean much to you, because I'm working in latent space rather than image space. Instead of providing and predicting image pixels, I'm providing and predicting the latent feature map of the fixed autoencoder from Stable/Latent Diffusion.

But here it is anyway:

    --image_encoder_name "vggface2" \
    --num_image_tokens 8 \
    --dim 320 \
    --dim_mults "1,1,2,3" \
    --cond_dim 512 \
    --layer_attns "0,0,1,1" \
    --layer_cross_attns "0,0,1,1"
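
For context, here is a sketch of how those flags could map onto the imagen-pytorch Unet/Imagen constructors. The latent resolution and channel count below are assumptions (I'm diffusing SD/LDM latents rather than pixels, and the exact values aren't shown above):

from imagen_pytorch import Unet, Imagen

unet = Unet(
    dim = 320,
    cond_dim = 512,
    dim_mults = (1, 1, 2, 3),
    layer_attns = (False, False, True, True),
    layer_cross_attns = (False, False, True, True),
)

imagen = Imagen(
    unets = (unet,),
    image_sizes = (32,),   # assumed latent resolution
    channels = 4,          # assumed 4-channel SD/LDM latents instead of RGB
    text_embed_dim = 512,  # matches the 512-dim vggface2 embedding
)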
QinSY123 commented 2 years ago

Thank you for your reply. I still have some questions: do you use only a single Unet network? And what is the role of the parameter num_image_tokens? I only saw its definition in the file imagen_pytorch.py and did not see where it is used.

xiankgx commented 2 years ago

In the original implementation, the conditioning model is the T5 text encoder. In my model, I have modified it to take the embedding from a vggface2 feature extractor (facenet_pytorch), which outputs a 512-dimensional embedding for each image.

For example, if I have a batch size of 8, the size of this tensor would be (8, 512). However, the original model accepts text conditioning in the form (batch_size, sequence_length, embedding_dim). Hence, I apply a linear layer to go from (8, 512) to (8, num_image_tokens(=8) * 512) and then reshape to (8, 8, 512). This is then used as the text conditioning (text_embeds) in imagen.
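
To make that concrete, here is a minimal sketch of the conditioning pipeline; the layer and variable names are illustrative, not the exact training code:

import torch
import torch.nn as nn
from facenet_pytorch import InceptionResnetV1

num_image_tokens, embed_dim, batch_size = 8, 512, 8

# vggface2 feature extractor; outputs a 512-dim embedding per image
image_encoder = InceptionResnetV1(pretrained='vggface2').eval()

# illustrative stand-in for clip_image_embeds_to_text_embeds in the code above
to_text_embeds = nn.Linear(embed_dim, num_image_tokens * embed_dim)

images = torch.randn(batch_size, 3, 160, 160)  # hypothetical face crops
with torch.no_grad():
    image_embeds = image_encoder(images)       # (8, 512)

text_embeds = to_text_embeds(image_embeds)     # (8, 8 * 512)
text_embeds = text_embeds.reshape(batch_size, num_image_tokens, embed_dim)  # (8, 8, 512)

# text_embeds now has the (batch_size, sequence_length, embedding_dim) shape
# that imagen.sample(text_embeds=...) expects as conditioning.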

Hence, you need not concern yourself with this parameter.

QinSY123 commented 2 years ago

Thanks