xichenpan / ARLDM

Official PyTorch Implementation of Synthesizing Coherent Story with Auto-Regressive Latent Diffusion Models
https://arxiv.org/abs/2211.10950
MIT License

About the image size #10

Closed: SaulZhang closed this issue 1 year ago.

SaulZhang commented 1 year ago

Hi, thank you for your wonderful work. I have a couple of questions.
1. Because each GPU has limited memory (8× A100 40GB), I can only resize images to 256×256 (instead of 512×512) so that a single GPU can hold one story when training AR-LDM. What impact will this change have on the FID score?
2. Also, sampling is very slow. Can we run sampling on multiple GPUs?

xichenpan commented 1 year ago

@SaulZhang Hi, thanks for your inquiry! For the first question: since the pretrained Stable Diffusion model (PTM) was trained at 512×512, I guess the performance will drop. I recommend trying fp16 training at 512×512 resolution. For the second question, the answer is yes! You can also increase the batch size. We use a single GPU to avoid sampling the same case multiple times, which would affect the FID score. (For normal experiments, feel free to use multiple GPUs and a large batch size.)
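One plausible way multi-GPU sampling can score the same case more than once: assuming the test set is sharded with PyTorch's DistributedSampler (which PyTorch Lightning injects by default for multi-GPU DDP runs) with drop_last=False, the sampler pads the index list by repeating indices from the start so every rank receives the same number of samples. The pure-Python helpers below (shard_indices and duplicated_samples are hypothetical names) mimic that padding with shuffling disabled; this is a sketch of the sampler's behavior, not the sampler itself.

```python
import math
from collections import Counter


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> list[int]:
    """Mimic DistributedSampler(shuffle=False, drop_last=False) for one rank."""
    indices = list(range(dataset_len))
    # Every rank must get the same number of samples, so round the total up.
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    padding = total_size - dataset_len
    if padding > 0:
        # Pad by repeating indices from the start of the list.
        repeated = indices * math.ceil(padding / dataset_len)
        indices += repeated[:padding]
    # Each rank takes every num_replicas-th index, starting at its rank.
    return indices[rank:total_size:num_replicas]


def duplicated_samples(dataset_len: int, num_replicas: int) -> list[int]:
    """Return dataset indices that would be sampled by more than one rank."""
    counts = Counter(
        i for r in range(num_replicas) for i in shard_indices(dataset_len, num_replicas, r)
    )
    return sorted(i for i, c in counts.items() if c > 1)


for rank in range(3):
    print(rank, shard_indices(10, 3, rank))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 0]
# 2 [2, 5, 8, 1]
print(duplicated_samples(10, 3))  # [0, 1]
print(duplicated_samples(10, 1))  # []
```

With 10 stories on 3 GPUs, stories 0 and 1 are generated twice and would be counted twice in the FID statistics unless duplicates are dropped (for example, by keying results on the dataset index) before computing the score. On a single GPU, no padding is needed, so every case is sampled exactly once.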

SaulZhang commented 1 year ago

Thanks for your reply. I tried setting precision=16 in the Trainer and also setting freeze_clip/freeze_blip/freeze_resnet=True. Unfortunately, even with these changes, a single A100 40GB GPU can only handle a maximum resolution of about 470×470. If I don't freeze the Stable Diffusion weights, could that reduce the impact of the lower image resolution?
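A rough back-of-the-envelope sketch of why memory grows faster than the pixel count, assuming the standard Stable Diffusion v1 setup: the VAE downsamples each image side by a factor of 8 into a 4-channel latent, and the UNet's first self-attention layers attend over every latent position. The helper sd_latent_stats is hypothetical, and the count ignores memory-efficient attention kernels, which avoid materializing the full attention matrix.

```python
def sd_latent_stats(side: int, vae_factor: int = 8, latent_channels: int = 4) -> dict:
    """Rough size of the latent and of one highest-resolution self-attention matrix."""
    latent_side = side // vae_factor
    # Number of spatial positions the first self-attention layers attend over.
    tokens = latent_side * latent_side
    return {
        "latent_shape": (latent_channels, latent_side, latent_side),
        "tokens": tokens,
        # One softmax(QK^T) matrix per attention head is tokens x tokens.
        "attn_entries_per_head": tokens * tokens,
    }


for side in (256, 512):
    stats = sd_latent_stats(side)
    print(
        f"{side}x{side}: latent {stats['latent_shape']}, {stats['tokens']} tokens, "
        f"{stats['attn_entries_per_head']:,} attention entries per head"
    )
# 256x256: latent (4, 32, 32), 1024 tokens, 1,048,576 attention entries per head
# 512x512: latent (4, 64, 64), 4096 tokens, 16,777,216 attention entries per head
```

Going from 256×256 to 512×512 quadruples the pixel count but multiplies each per-head attention matrix by 16, which is consistent with memory running out somewhere between the two sizes.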

xichenpan commented 1 year ago

@SaulZhang Hi, I guess it may work. You can also try enabling gradient checkpointing to save VRAM; it seems to be already implemented in Diffusers.

SaulZhang commented 1 year ago

Okay, thank you for the suggestion. I will give it a try.

SaulZhang commented 1 year ago

@Flash-321 Hello, I set the image size to 256×256 and ran story continuation experiments on three datasets. Although the generated images look satisfactory, I'm confused about why the calculated FID score exceeds 400. This seems quite unreasonable, and it happens consistently across all three datasets. I have attached some of the stories generated on the three datasets. Could you help me understand why the FID score is so high?
[Attached images: out, out1, out2]
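For reference, the standard FID definition compares Gaussian fits (mean and covariance) to Inception-v3 features of the reference images (r) and the generated images (g). Because the first set of statistics comes entirely from the reference images, corrupted or misformatted reference images can push the score very high even when the generated images look fine.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```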

xichenpan commented 1 year ago

@SaulZhang Hi, do you calculate the FID score across the whole dataset, or only a subset?

SaulZhang commented 1 year ago

I calculated the FID score over the whole test set and did not modify the sampling code.

SaulZhang commented 1 year ago

After thorough debugging, I found that the main issue lies in this line of code:

    original_images = [Image.fromarray(im, 'RGB') for im in original_images]

The arrays are channel-first (C, H, W), while Image.fromarray expects channel-last (H, W, C), so the correct code should be:

    original_images = [Image.fromarray(im.transpose(1,2,0), 'RGB') for im in original_images]
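The fix can be sketched as a small standalone helper. This is a minimal sketch assuming each entry of original_images is a uint8 NumPy array of shape (3, H, W); the helper name chw_to_pil is hypothetical and not part of the ARLDM code. It omits the explicit 'RGB' mode and lets Pillow infer it from the (H, W, 3) uint8 array, and it fails loudly on an unexpected layout instead of silently producing a wrong image.

```python
import numpy as np
from PIL import Image


def chw_to_pil(im: np.ndarray) -> Image.Image:
    """Convert a channel-first uint8 array of shape (3, H, W) into an RGB PIL image."""
    if im.ndim != 3 or im.shape[0] != 3:
        raise ValueError(f"expected shape (3, H, W), got {im.shape}")
    if im.dtype != np.uint8:
        raise ValueError(f"expected uint8 pixels, got {im.dtype}")
    # PIL expects channel-last (H, W, 3); make the transposed view contiguous.
    hwc = np.ascontiguousarray(im.transpose(1, 2, 0))
    # For a uint8 (H, W, 3) array, Pillow infers mode "RGB".
    return Image.fromarray(hwc)


# Toy (C, H, W) = (3, 4, 5) array where channel c is filled with the value 50 * c.
chw = np.stack([np.full((4, 5), 50 * c, dtype=np.uint8) for c in range(3)])
img = chw_to_pil(chw)
print(img.size, img.mode)    # (5, 4) RGB   (PIL reports (width, height))
print(img.getpixel((0, 0)))  # (0, 50, 100)
```

Applied to the sampling code, this would read `original_images = [chw_to_pil(im) for im in original_images]`.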