explainingai-code / StableDiffusion-PyTorch

This repo implements a Stable Diffusion model in PyTorch with all the essential components.

How to improve the reconstruction of high-frequency details in the VQVAE training? #25

Open vadori opened 1 month ago

vadori commented 1 month ago

Thank you so much for sharing!

Could you provide insights into the number of epochs required to achieve high-resolution, fine details during VQVAE training for 256x256 RGB images?

Additionally, has anyone compared VQVAE to VAE? Is VQVAE performing better? What key parameters can be adjusted in VQVAE to improve performance if the results aren't satisfactory?

I am trying to use the VQVAE to encode histological images. After 6 epochs with a batch size of 4, I get this result on a sample image:

[Image: reconstruction of a sample image after 6 epochs]

I'm wondering if the uniform color in the foreground, the absence of high-resolution details, and the presence of a seemingly repeated textural pattern across the image are simply the result of training for too few epochs and the model needing more time. Or could this indicate a more fundamental issue with my experimental setup? What are your thoughts?

Of course, it would be great if anyone could share their opinion on this!

Thanks again!

vadori commented 1 month ago


After 60 epochs, the results have significantly improved (see below). However, I still notice that high-frequency details are missing from the reconstructions, and almost every image contains green spots. Is there a way to enhance the reconstruction of high-frequency components and eliminate the green spots? Thank you!

[Image: sample reconstructions after 60 epochs]

explainingai-code commented 1 month ago

Hello @Vadori, can you share your config file for VQVAE training? Regarding the green spots, I feel they might be caused by the pixels not being bounded at inference time. There is code that handles this by clamping the reconstruction, but could you double-check that this line isn't commented out or modified in your setup?
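For reference, here is a minimal sketch of what "bounding the pixels" usually means, assuming the images are normalized to [-1, 1] for training (an assumption; check the dataset class in your setup). The function name `to_uint8_image` is hypothetical and this is not the repo's code; it only shows the clamp-then-rescale order.

```python
import numpy as np


def to_uint8_image(recon: np.ndarray) -> np.ndarray:
    """Map a decoder output (assumed to lie roughly in [-1, 1]) to an 8-bit image.

    Values outside [-1, 1] are clamped first. Without this step, out-of-range
    values can wrap around or saturate unpredictably when cast to 8 bits,
    which can show up as colored spots.
    """
    recon = np.clip(recon, -1.0, 1.0)                 # bound the reconstruction
    recon = (recon + 1.0) / 2.0                       # [-1, 1] -> [0, 1]
    return np.round(recon * 255.0).astype(np.uint8)   # [0, 1] -> [0, 255]


# A single RGB pixel whose green channel overshoots to 1.4
pixel = np.array([[[0.2, 1.4, -1.3]]])
print(to_uint8_image(pixel))
```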

vadori commented 1 month ago

Hi @explainingai-code,

Thanks for replying!! Let me share the config file.

dataset_config:
  im_channels : 3
  im_size : 256
  name: 'cell'

diffusion_params:
  num_timesteps : 1000
  beta_start: 0.00085
  beta_end: 0.012

ldm_params:
  down_channels: [ 256, 384, 512, 768 ]
  mid_channels: [ 768, 512 ]
  down_sample: [ True, True, True ]
  attn_down : [True, True, True]
  time_emb_dim: 512
  norm_channels: 32
  num_heads: 16
  conv_out_channels : 128
  num_down_layers : 2
  num_mid_layers : 2
  num_up_layers : 2
  condition_config:
    condition_types: None #[ 'classes' ]
    class_condition_config:
      cond_drop_prob : 0.1
      class_condition_l : 11
    text_condition_config:
      text_embed_model: 'clip'
      train_text_embed_model: False
      text_embed_dim: 512
      cond_drop_prob: 0.1
    image_condition_config:
      image_condition_input_channels: 18
      image_condition_output_channels: 3
      image_condition_h : 512
      image_condition_w : 512
      cond_drop_prob: 0.1

autoencoder_params:
  z_channels: 3
  codebook_size : 8192
  down_channels : [64, 128, 256, 256]
  mid_channels : [256, 256]
  down_sample : [True, True, True]
  attn_down : [False, False, False]
  norm_channels: 32
  num_heads: 4
  num_down_layers : 2
  num_mid_layers : 2
  num_up_layers : 2
  num_headblocks : 1

train_params:
  seed : 1111
  task_name: 'cell'
  ldm_batch_size: 16
  autoencoder_batch_size: 4
  disc_start: 10
  disc_weight: 0.5
  codebook_weight: 1
  commitment_beta: 0.2
  perceptual_weight: 1
  kl_weight: 0.000005
  ldm_epochs: 2 # 100
  autoencoder_epochs: 60
  num_samples: 1
  num_grid_rows: 1
  ldm_lr: 0.000005
  autoencoder_lr: 0.00001
  autoencoder_acc_steps: 4
  autoencoder_img_save_steps: 256
  save_latents : False
  cf_guidance_scale : 1.0
  load_ckpt: True
  vae_latent_dir_name: 'vae_latents'
  vqvae_latent_dir_name: 'vqvae_latents'
  ldm_ckpt_name: 'ddpm_ckpt_classes_cond.pth'
  vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
  vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
  vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
  vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'

Regarding inference, the snapshots I posted are from the visualizations during training, where values are actually clipped to 255, so the line of code in the inference script might solve the problem. Thank you! Do you also have any suggestions on which parameters to tweak to ensure high-frequency components are preserved?

explainingai-code commented 1 month ago

I see that the disc_start parameter has been set to 10 iterations. Have you tried training the VQVAE with a higher value of disc_start, like maybe 5000? If not, can you try that once? (Assuming you have at least 1K images, with a batch size of 4 that's about 20 epochs.) The reason I say this is that I have seen folks mention in issues on the official repo that they get the best results when the discriminator is kicked in after the autoencoder (trained without disc_loss) has learnt to reconstruct as well as it can (though still blurry). If it's started too early (like 10 iterations in your config), it often leads to worse outputs and convergence. Maybe that helps fix the quality issue in your case.
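The iterations-to-epochs arithmetic above, plus the usual "no adversarial term before disc_start" gating, can be sketched as below. This assumes one iteration means one batch; with autoencoder_acc_steps > 1 the training script may count steps differently, so check how the step counter is incremented. Both helper names are hypothetical.

```python
import math


def iterations_to_epochs(num_iterations: int, num_images: int, batch_size: int) -> float:
    """Epochs covered by num_iterations steps, assuming one step per batch."""
    steps_per_epoch = math.ceil(num_images / batch_size)
    return num_iterations / steps_per_epoch


def generator_adv_weight(step: int, disc_start: int, disc_weight: float) -> float:
    """Weight of the adversarial term in the autoencoder loss: zero until disc_start."""
    return disc_weight if step >= disc_start else 0.0


print(iterations_to_epochs(5000, 1000, 4))  # 1000 images / batch 4 = 250 steps per epoch
print(generator_adv_weight(10, 10, 0.5))    # original config: discriminator already active at step 10
print(generator_adv_weight(10, 5000, 0.5))  # suggested config: still reconstruction-only at step 10
```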

vadori commented 1 month ago

Many thanks for your kind reply @explainingai-code! I am testing with more vectors in the codebook - do you think it can help based on your experience? Additionally, what parameter controls the size of the autoencoder bottleneck? z_channels? I want to test with a different configuration and see if it helps. I am also experimenting with additional losses to see if I can enforce the reconstruction of high frequencies. Again, thank you. I am looking forward to your feedback!

explainingai-code commented 1 month ago

@Vadori You can try with more vectors. I didn't find it helped much on the dataset that I worked with (CelebHQ), but give it a try; it might help in your case. Below are parameters you can use to significantly increase the capacity of the autoencoder.

autoencoder_params:
  z_channels: 4
  codebook_size : 16384
  down_channels : [256, 384, 512, 768] 
  mid_channels : [768]
  down_sample : [True, True, True]
  attn_down : [False, False, False]
  norm_channels: 32
  num_heads: 4
  num_down_layers : 2
  num_mid_layers : 2
  num_up_layers : 2

This autoencoder config will have a downscale factor of 8 (256x256 reduced to 32x32), as down_sample is applied three times and there are three down blocks (from len(down_channels) - 1). If you want to change the downscale factor to 4 but keep the three down blocks, simply change down_sample to [True, True, False]. If you want to change the downscale factor to 4 and also reduce the down blocks to two, then also change down_channels to something like [256, 384, 512].
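A quick way to sanity-check these variants is to compute the latent size from the config lists. This is a hypothetical helper, not part of the repo; it assumes one down block per consecutive pair in down_channels and that each block halves the resolution when its down_sample flag is True. For the two-block variant it also assumes down_sample is shortened to two entries, which is my reading of "also change down_channels".

```python
def latent_size(im_size: int, down_channels: list, down_sample: list) -> int:
    """Latent spatial size for an autoencoder config (hypothetical helper)."""
    num_down_blocks = len(down_channels) - 1
    if len(down_sample) != num_down_blocks:
        raise ValueError(
            f"down_sample needs {num_down_blocks} entries, got {len(down_sample)}"
        )
    factor = 2 ** sum(down_sample)  # each True halves height and width
    return im_size // factor


configs = {
    "factor 8, three down blocks": ([256, 384, 512, 768], [True, True, True]),
    "factor 4, three down blocks": ([256, 384, 512, 768], [True, True, False]),
    "factor 4, two down blocks": ([256, 384, 512], [True, True]),
}
for name, (down_channels, down_sample) in configs.items():
    size = latent_size(256, down_channels, down_sample)
    print(f"{name}: 256x256 -> {size}x{size}")
```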

Let me know if this works for you or if you run into any issues with these changes.

vadori commented 1 month ago


Hi again @explainingai-code, thanks for your reply!! Before your suggestion, I started experimenting with z_channels = 32, codebook_size = 16k, leaving the remaining parameters unchanged as follows:

autoencoder_params:
  z_channels: 32
  codebook_size : 16000
  down_channels : [64, 128, 256, 256]
  mid_channels : [256, 256]
  down_sample : [True, True, True]
  attn_down : [False, False, False]
  norm_channels: 32
  num_heads: 4
  num_down_layers : 2
  num_mid_layers : 2
  num_up_layers : 2
  num_headblocks : 1

I also added 2 additional loss functions, and it looks like the performance has improved. However, in order to understand what could be a better solution (such as the one you suggested) I would need to be sure I understand the parameters. Maybe you can help me here? Or is there any documentation I may look at to get this info? codebook_size is quite self-explanatory, but what about the following variables? Could you please confirm/correct? It would be immensely helpful!

 z_channels: dimension of the codewords in the codebook?
 down_channels: number of channels in the outputs of the four down layers?
 mid_channels: similar, for the middle layers?
 down_sample: if set to False, the width and height of the image do not change from one layer to the next?
 attn_down: whether each down layer uses self-attention (so all False means no self-attention)?
 norm_channels: number of groups for group normalization?
 num_heads: number of attention heads?
 num_down_layers: number of convolutional blocks in each residual block in the down layers?
 num_mid_layers: similar, for the middle layers?
 num_up_layers: similar, for the up layers?
 num_headblocks: ?
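To make these parameters concrete, here is a hypothetical down block that wires them up in the way described in the list above (which the maintainer confirms below): norm_channels as the number of GroupNorm groups, num_heads for self-attention, num_down_layers residual sublayers, and a stride-2 conv when down_sample is True. It is a sketch under those assumptions, not the repo's module; num_headblocks is left out because it is a custom parameter.

```python
import torch
import torch.nn as nn


class DownBlockSketch(nn.Module):
    """Hypothetical down block: residual sublayers, optional self-attention, optional 2x downsample."""

    def __init__(self, in_ch, out_ch, norm_channels=32, num_heads=4,
                 num_layers=2, down_sample=True, attn=False):
        super().__init__()
        self.resnets = nn.ModuleList()
        self.skips = nn.ModuleList()
        for i in range(num_layers):  # num_down_layers
            c_in = in_ch if i == 0 else out_ch
            self.resnets.append(nn.Sequential(
                nn.GroupNorm(norm_channels, c_in),  # norm_channels = number of groups
                nn.SiLU(),
                nn.Conv2d(c_in, out_ch, 3, padding=1),
                nn.GroupNorm(norm_channels, out_ch),
                nn.SiLU(),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
            ))
            self.skips.append(nn.Conv2d(c_in, out_ch, 1))  # match channels for the residual
        # attn_down entry for this block: self-attention over the H*W positions
        self.attn_norm = nn.GroupNorm(norm_channels, out_ch) if attn else None
        self.attn = nn.MultiheadAttention(out_ch, num_heads, batch_first=True) if attn else None
        # down_sample entry for this block: stride-2 conv halves H and W
        self.down = nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1) if down_sample else nn.Identity()

    def forward(self, x):
        for resnet, skip in zip(self.resnets, self.skips):
            x = resnet(x) + skip(x)
        if self.attn is not None:
            b, c, h, w = x.shape
            seq = self.attn_norm(x).reshape(b, c, h * w).transpose(1, 2)  # (B, H*W, C)
            out, _ = self.attn(seq, seq, seq)
            x = x + out.transpose(1, 2).reshape(b, c, h, w)
        return self.down(x)


# First down block of the [256, 384, 512, 768] config; a small 32x32 input keeps this fast
block = DownBlockSketch(256, 384, norm_channels=32, num_heads=4, num_layers=2,
                        down_sample=True, attn=False)
print(block(torch.randn(1, 256, 32, 32)).shape)
```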

Also, for example, with the following parameters, are the sizes of the outputs below correct?

z_channels:4
down_channels : [256, 384, 512, 768] 
down_sample : [True, True, True]

output of the first down layer: 128x128x384
output of the second down layer: 64x64x512
output of the third down layer: 32x32x768

Then after the middle layers, a Conv2d module is applied to get a 32x32x4 output, which is the size of the bottleneck, and each vector with dimension 4 is mapped onto a codeword (quantized).
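The shapes in the question can be traced with a few lines of arithmetic. This hypothetical helper assumes an input conv that maps RGB to down_channels[0] at full resolution, that the middle layers keep the spatial size, and that a final conv maps to z_channels; those assumptions match the description in the question, not a reading of the repo's code.

```python
def trace_encoder_shapes(im_size, down_channels, down_sample, z_channels):
    """(stage name, (H, W, C)) after each encoder stage (hypothetical helper)."""
    h = w = im_size
    shapes = [("conv_in", (h, w, down_channels[0]))]
    for i, downsample in enumerate(down_sample):
        if downsample:
            h, w = h // 2, w // 2
        shapes.append((f"down block {i + 1}", (h, w, down_channels[i + 1])))
    shapes.append(("bottleneck", (h, w, z_channels)))  # after the conv to z_channels
    return shapes


for name, (h, w, c) in trace_encoder_shapes(256, [256, 384, 512, 768], [True, True, True], 4):
    print(f"{name}: {h}x{w}x{c}")
```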

If I set z_channels:32, I would get a 32x32x32 output, which is the size of the bottleneck, and each vector with dimension 32 would be mapped onto a codeword (quantized).
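The "each vector is mapped onto a codeword" step is a nearest-neighbour lookup. Below is a NumPy sketch of just that lookup; the function name is hypothetical, and a real VQVAE quantizer also needs the codebook/commitment losses (cf. codebook_weight and commitment_beta in the config) and a straight-through gradient, which are not shown. The codebook is kept small here for illustration.

```python
import numpy as np


def quantize(z: np.ndarray, codebook: np.ndarray):
    """Nearest-codeword vector quantization.

    z:        (H, W, D) encoder output, D = z_channels
    codebook: (K, D) codewords, K = codebook_size
    returns:  (H, W, D) quantized latent and (H, W) codeword indices
    """
    h, w, d = z.shape
    flat = z.reshape(-1, d)  # (H*W, D)
    # squared Euclidean distances via ||a||^2 - 2 a.b + ||b||^2, shape (H*W, K)
    dists = (
        (flat ** 2).sum(axis=1, keepdims=True)
        - 2.0 * flat @ codebook.T
        + (codebook ** 2).sum(axis=1)[None, :]
    )
    indices = dists.argmin(axis=1)  # closest codeword for each latent vector
    return codebook[indices].reshape(h, w, d), indices.reshape(h, w)


rng = np.random.default_rng(0)
z = rng.standard_normal((32, 32, 4))       # 32x32 bottleneck with z_channels = 4
codebook = rng.standard_normal((512, 4))   # small codebook for illustration
z_q, idx = quantize(z, codebook)
print(z_q.shape, idx.shape)  # (32, 32, 4) (32, 32)
```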

Thank you!!

explainingai-code commented 1 month ago

@Vadori I went through everything, and your understanding of all of them is correct. Just one thing: I don't have any num_headblocks in the config; I am assuming that's a parameter you added. Also, regarding different parameter values, it might help to see what the authors of latent diffusion used (just as a guideline): https://arxiv.org/pdf/2112.10752. See Appendix E1, Table 13. Those parameters are for the diffusion model, but from them you can infer the autoencoder params like codebook size, z_channels and so on.

vadori commented 3 weeks ago

Hi again @explainingai-code,
Thank you once again for your helpful responses, so much appreciated! I noticed that in the current implementation, the cross-attention mechanism is applied only when using text conditioning. I’m planning to use cross-attention with image conditioning and was wondering if your decision to apply cross-attention only with text conditioning was based on experimental results (e.g., cross-attention with images not offering significant benefits and increasing training time, making its removal more practical). Thanks again!

Edit: In the LDM paper (https://arxiv.org/pdf/2112.10752), it looks like for image-to-image translation tasks, and specifically for semantic image synthesis, they actually concatenated downsampled versions of the semantic maps to the latent image representation. This is somewhat surprising to me, as I was expecting the conditioning to be required at more points in the network. It would be great if you could let me know whether you have had any experience with the two settings (with and without cross-attention at intermediate layers for semantic image synthesis).

explainingai-code commented 3 weeks ago

Hello @Vadori, yes, in the current implementation the only image conditioning is spatial conditioning (for the mask-to-image task). For spatial conditioning, the authors concatenate the condition to the input, so I used that too. I would guess this works better than cross-attention (but this is not something I validated through experiments).

However, for use cases like generating variations of an image, where the goal is to retain the semantic aspects of the image but not necessarily the spatial layout, cross-attention would be the appropriate choice. But since the repo only has code for spatial conditioning, you would need to make changes to experiment with cross-attention. As a quick fix, you can train it exactly like a text-conditioned diffusion model (which the repo supports), but instead of passing text encodings to the model, pass image encodings extracted from an image encoder like ViT.
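For anyone trying the "image encodings instead of text encodings" route, the core operation is latent tokens attending to the condition tokens. This is a hypothetical module, not the repo's text-conditioning code; the random (2, 197, 768) tensor merely stands in for ViT patch embeddings, and cond_dim plays the role that text_embed_dim plays in the config. PyTorch's `kdim`/`vdim` arguments let keys and values have a different width than the queries.

```python
import torch
import torch.nn as nn


class ImageCrossAttentionSketch(nn.Module):
    """Latent positions (queries) attend to image-encoder tokens (keys/values)."""

    def __init__(self, latent_channels, cond_dim, num_heads=4, norm_channels=32):
        super().__init__()
        self.norm = nn.GroupNorm(norm_channels, latent_channels)
        self.attn = nn.MultiheadAttention(latent_channels, num_heads,
                                          kdim=cond_dim, vdim=cond_dim, batch_first=True)

    def forward(self, x, cond_tokens):
        b, c, h, w = x.shape
        q = self.norm(x).reshape(b, c, h * w).transpose(1, 2)  # (B, H*W, C)
        out, _ = self.attn(q, cond_tokens, cond_tokens)        # condition supplies keys/values
        return x + out.transpose(1, 2).reshape(b, c, h, w)     # residual connection


x = torch.randn(2, 256, 8, 8)        # intermediate feature map of the denoiser
cond = torch.randn(2, 197, 768)      # stand-in for ViT patch embeddings (hypothetical)
layer = ImageCrossAttentionSketch(latent_channels=256, cond_dim=768)
print(layer(x, cond).shape)  # torch.Size([2, 256, 8, 8])
```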

vadori commented 3 weeks ago

Hi @explainingai-code,

Thank you! Why would you say that avoiding cross-attention would work better? I am interested in your intuition, even though you did not experiment with it. I am modifying the code accordingly using a custom encoder to generate mask encodings. I am planning to try both solutions (with cross and without).

explainingai-code commented 3 weeks ago

By better, I mainly mean how easy it is for the model to learn 'how to use spatial conditioning'. With concatenation, because of the convolution layer, each noisy pixel x_ij is only impacted by the corresponding spatial pixel cond_ij, which is exactly what we need, since we want the denoising process to generate an image with exactly the same layout as the conditioning image. With cross-attention, this is something the model would have to learn through training. It should indeed end up learning it, given enough training compute, but concatenation does the job in a much simpler manner with fewer parameters. However, like I said, this is just my guess :)
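The locality argument can be checked directly on a first conv layer that receives the noisy latent and the condition concatenated along channels. In this sketch (layer sizes are illustrative assumptions), the condition is changed at a single position (10, 10) and we record which outputs move. With a 3x3 kernel the influence is the 3x3 neighbourhood around that position rather than strictly x_ij alone (a 1x1 kernel would give exactly x_ij), and deeper layers widen it further, but it stays spatially aligned without the model having to learn the alignment.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

z_channels, cond_channels = 4, 3
# first conv of the denoiser sees noisy latent and spatial condition stacked along channels
conv_in = nn.Conv2d(z_channels + cond_channels, 64, kernel_size=3, padding=1)

noisy_latent = torch.randn(1, z_channels, 32, 32)
cond = torch.randn(1, cond_channels, 32, 32)  # condition already at latent resolution

with torch.no_grad():
    base = conv_in(torch.cat([noisy_latent, cond], dim=1))
    cond_perturbed = cond.clone()
    cond_perturbed[:, :, 10, 10] += 1.0  # change the condition at one location
    moved = conv_in(torch.cat([noisy_latent, cond_perturbed], dim=1))

# (32, 32) map of output positions that changed (tolerance absorbs float rounding)
changed = (moved - base).abs().sum(dim=1)[0] > 1e-5
print(changed.nonzero().tolist())  # only positions within one pixel of (10, 10)
```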

vadori commented 3 weeks ago


Thank you very much for your response. You may be right :) I am curious to see what works. What I like about cross-attention is that it is repeated throughout the network multiple times, offering repeated guidance, while conditioning via concatenation is performed once, and the information is potentially diluted (i.e., lost) on the way from input to output, but this is just how I imagine it. Maybe concatenation could be performed at multiple layers. Experiments should answer this. Maybe, in my case, neither works because I am working in the latent space with semantic mask encodings, not downsampled versions of the semantic mask. This is because with simple downsampling I lose too many details (the components in the mask are tiny, some disappear, and I must ensure that all of them persist). The conditioning input (encoded mask) encodes a spatial layout, but the model must learn this in order to use that information to actually apply spatial conditioning when generating the images. When the conditioning semantic mask is downsampled rather than encoded, the fact that the mask carries a spatial layout is explicit, and the model does not need to learn it, only to constrain the generation accordingly.

explainingai-code commented 3 weeks ago

Got it. Regarding simple downscaling leading to loss of details, another thing you could try: instead of passing a downsampled version, pass the full-size mask (same size as the original image) and add additional Conv2d layers that take the mask from, say, 256x256 (original image size) to 32x32 (latent size). Then concatenate this to the noisy input, allowing the model to learn features that provide the low-level details the simply downscaled mask loses.
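A minimal sketch of that suggestion: a small learned mask encoder with three stride-2 convs (256 -> 128 -> 64 -> 32) whose output is concatenated to the noisy latent. The module name, channel widths, and the 8 output channels are illustrative assumptions; the encoder would be trained jointly with the denoiser, and the denoiser's first conv would need z_channels + out_channels input channels.

```python
import torch
import torch.nn as nn


class MaskDownsampler(nn.Module):
    """Hypothetical learned mask encoder: 256x256 mask -> 32x32 feature map."""

    def __init__(self, mask_channels=1, hidden=32, out_channels=8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(mask_channels, hidden, 3, padding=1), nn.SiLU(),
            nn.Conv2d(hidden, hidden, 4, stride=2, padding=1), nn.SiLU(),          # 256 -> 128
            nn.Conv2d(hidden, hidden * 2, 4, stride=2, padding=1), nn.SiLU(),      # 128 -> 64
            nn.Conv2d(hidden * 2, hidden * 2, 4, stride=2, padding=1), nn.SiLU(),  # 64 -> 32
            nn.Conv2d(hidden * 2, out_channels, 3, padding=1),
        )

    def forward(self, mask):
        return self.net(mask)


mask_encoder = MaskDownsampler(mask_channels=1, out_channels=8)
mask = torch.randint(0, 2, (2, 1, 256, 256)).float()  # full-resolution binary mask
noisy_latent = torch.randn(2, 4, 32, 32)               # latent with z_channels = 4
unet_input = torch.cat([noisy_latent, mask_encoder(mask)], dim=1)
print(unet_input.shape)  # torch.Size([2, 12, 32, 32])
```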

vadori commented 3 weeks ago


Great suggestion, thank you! I'll update you as soon as I have some results. Hopefully, it won’t be too long! 😄