King-HAW / GMS

Official repository of Generative Medical Segmentation
https://arxiv.org/abs/2403.18198

Questions related to the SD-VAE #2

Open batman47steam opened 6 months ago

batman47steam commented 6 months ago

Hi, thank you for sharing this work. It has been very inspiring for me. However, I still have some questions about the SD-VAE part. Does the SD-VAE directly use the existing Stable Diffusion VAE? The paper mentions that the SD-VAE has three corresponding upsampling and downsampling layers. Was that part designed independently? If so, how are the SD pretrained weights used? Looking forward to your response.

King-HAW commented 6 months ago

Hi, thanks for following our work. Yes, we use the first-stage model (the VAE) weights of Stable Diffusion. You can find the original model weights here; we deleted the denoising UNet part, since we only care about image compression and reconstruction. The three upsampling and downsampling layers are included in the original model weights, and we did not make any changes to that part.
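For anyone adapting this themselves, the idea of keeping only the first-stage VAE can be sketched as filtering the checkpoint's state dict by key prefix. This is a minimal, hedged sketch (not code from the GMS repo): it assumes a CompVis-style checkpoint layout where VAE weights live under the `first_stage_model.` prefix and UNet weights under `model.diffusion_model.`; the helper name is ours, and the toy dict stands in for real tensors.

```python
def extract_vae_state_dict(full_state_dict, prefix="first_stage_model."):
    """Keep only the keys under `prefix`, stripping the prefix so the
    result can be loaded into a standalone VAE (e.g. AutoencoderKL)."""
    return {
        key[len(prefix):]: value
        for key, value in full_state_dict.items()
        if key.startswith(prefix)
    }

# Toy demo: strings stand in for weight tensors.
ckpt = {
    "first_stage_model.encoder.conv_in.weight": "vae_tensor",
    "model.diffusion_model.input_blocks.0.weight": "unet_tensor",  # dropped
}
vae_sd = extract_vae_state_dict(ckpt)
print(sorted(vae_sd))  # ['encoder.conv_in.weight']
```

In practice you would call this on `torch.load(ckpt_path, map_location="cpu")["state_dict"]` and pass the result to the VAE's `load_state_dict`.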

batman47steam commented 6 months ago

Hi, thank you for your response, it helps a lot! I find that the training speed on an RTX 3090 GPU is a bit slow. Is this normal? Is the main overhead during training due to the inference of the SD-VAE? The latent mapping model itself seems relatively lightweight.

King-HAW commented 6 months ago

Yes, it is normal. The structure of the SD-VAE is much more complex than the latent mapping model, so its inference can take some time. Also, I use deep supervision during model training (you can find the code here), which means the decoding process is repeated 3 times. If you want to train GMS faster, try removing deep supervision (set ds_list=['out'], code is here).
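The cost trade-off described above can be sketched in a few lines: with deep supervision, each supervised output triggers its own decoder pass, so a three-entry `ds_list` roughly triples the decode cost per step, while `ds_list=['out']` needs just one pass. This toy sketch only counts decoder calls; the `decode` stub and `train_step` helper are hypothetical stand-ins, not the repo's actual training loop.

```python
def train_step(latents, ds_list):
    """Toy training step: one decoder pass per deep-supervision output.
    Returns the (fake) outputs and how many decoder passes were run."""
    decode_calls = 0
    outputs = {}
    for name in ds_list:
        decode_calls += 1          # each supervised output re-runs the decoder
        outputs[name] = f"decoded-{name}"
    return outputs, decode_calls

# With deep supervision: three decoder passes per step.
_, calls_ds = train_step("z", ds_list=["level1", "level2", "out"])
# Without deep supervision: a single pass.
_, calls_fast = train_step("z", ds_list=["out"])
print(calls_ds, calls_fast)  # 3 1
```

Since the SD-VAE decoder dominates the per-step cost, cutting the passes from 3 to 1 is where most of the speedup comes from.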

batman47steam commented 6 months ago

Thank you very much for your prompt and helpful response! I get it now.

JAYCHOU2020 commented 6 months ago

> Yes, it is normal. The structure of the SD-VAE is much more complex than the latent mapping model, so its inference can take some time. Also, I use deep supervision during model training (you can find the code here), which means the decoding process is repeated 3 times. If you want to train GMS faster, try removing deep supervision (set ds_list=['out'], code is here).

Hello, thanks for your great work. I would like to ask about an error when I load weights for training: RuntimeError: Error(s) in loading state_dict for AutoencoderKL: Missing key(s) in state_dict: "encoder.conv_in.weight", "encoder.conv_in.bias", "encoder.down.0.block.0.norm1.weight", "encoder.down.0.block. Is it because I loaded the wrong weights? I downloaded this weight: https://huggingface.co/stabilityai/stable-diffusion-2/resolve/main/768-v-ema.ckpt?download=true

King-HAW commented 6 months ago

Since the SD team might have restructured the SD-VAE, please try the model weights in this link.
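A "Missing key(s) in state_dict" error like the one above usually means the checkpoint's keys are namespaced differently from what the model expects. A quick, hedged way to diagnose this is to list the top-level key prefixes actually present in the checkpoint; the helper below is our own illustration (the toy dict stands in for a real `torch.load` result), not code from the repo.

```python
from collections import Counter

def key_prefixes(state_dict, depth=1):
    """Count state-dict keys by their first `depth` dot-separated
    components, to see which namespaces the checkpoint actually has."""
    return Counter(".".join(k.split(".")[:depth]) for k in state_dict)

# Toy stand-in for a full Stable Diffusion checkpoint's state dict.
ckpt = {
    "model.diffusion_model.input_blocks.0.weight": 0,
    "first_stage_model.encoder.conv_in.weight": 0,
    "first_stage_model.decoder.conv_out.weight": 0,
}
print(key_prefixes(ckpt))
# If 'encoder.*' keys are absent but 'first_stage_model.encoder.*' keys
# exist, the checkpoint needs its prefix stripped (or a different
# checkpoint entirely) before it fits a standalone AutoencoderKL.
```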

JAYCHOU2020 commented 6 months ago

> Since the SD team might have restructured the SD-VAE, please try the model weights in this link.

Thanks for your reply, I'll try it.