xichenpan / ARLDM

Official PyTorch implementation of "Synthesizing Coherent Story with Auto-Regressive Latent Diffusion Models"
https://arxiv.org/abs/2211.10950
MIT License

Implementation of classifier-free guidance #20

Closed · skywalker00001 closed this 10 months ago

skywalker00001 commented 1 year ago

Hi, I have a few small questions about how classifier-free guidance generation is implemented in your code.

As far as I know, classifier-free guidance needs two steps (a sketch follows the list):

  1. Training: randomly select samples (e.g., with probability p = 0.1) and mask the entire context of the selected samples. This means we jointly train two models with the same architecture (with p = 0.1 we train an unconditional model on the null context, and with p = 0.9 a conditional model).
  2. Sampling: run both models at each step, generating Noise1 from the conditional model and Noise2 from the unconditional model, and combine them as Noise = Noise1 + w * (Noise1 - Noise2).
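
A minimal PyTorch-style sketch of these two steps (names such as `p_uncond`, `null_context`, and `unet` are illustrative assumptions, not taken from the ARLDM code):

```python
import torch
import torch.nn.functional as F

p_uncond, w = 0.1, 2.0  # hypothetical dropout probability and guidance weight

def train_step(unet, latents, timesteps, noise, context, null_context):
    # Step 1: per-sample condition dropout, replacing the whole context
    # (shape B x L x D) with the null context for the selected samples
    drop = torch.rand(context.size(0), device=context.device) < p_uncond
    context = torch.where(drop[:, None, None], null_context, context)
    return F.mse_loss(unet(latents, timesteps, context), noise)

@torch.no_grad()
def guided_noise(unet, latents, timesteps, context, null_context):
    # Step 2: combine conditional (Noise1) and unconditional (Noise2) predictions
    noise1 = unet(latents, timesteps, context)
    noise2 = unet(latents, timesteps, null_context)
    return noise1 + w * (noise1 - noise2)
```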

But in your code, I'm confused about why you randomly discard some frames in the context rather than all of the frames, because in the sampling stage you seem to generate Noise2 from a null context.

image

In the above training stage, you only fill randomly chosen frames with the null condition, rather than the entire sample. (I think it should perhaps be classifier_free_idx = np.random.rand(B) rather than classifier_free_idx = np.random.rand(B*V).)
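
For concreteness, a small NumPy sketch of the difference (B, V, and the 0.1 rate are just example values, not the repo's settings):

```python
import numpy as np

B, V = 2, 5  # example: 2 stories, 5 frames each
# per-frame dropout: any subset of the B*V frames may be masked independently
per_frame_mask = np.random.rand(B * V) < 0.1
# per-sample dropout: either all V frames of a story are masked, or none of them
per_sample_mask = np.repeat(np.random.rand(B) < 0.1, V)
```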

And in the sampling, I think your code is correct.

image

We should align the behaviour between training and sampling by discarding all frames of a sample in both, yeah?

Besides, I am not sure whether there is a small typo at main.py line 355.

image

Is it supposed to be noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)? Because according to the paper:

image
skywalker00001 commented 1 year ago

If you just discard some frames and keep the other frames during training, it works like MAE: the model learns to generate the next frame from corrupted previous frames. (Say we are generating frame 3 and the free index lands on frame 1; then the stable diffusion model is learning to denoise from the (CLIP[3] + BLIP[0, 2]) context.) But this method does not align with how you do the sampling.

xichenpan commented 1 year ago

Hi @skywalker00001, thank you for your comment. Actually, you can treat our batch size as B * V, which means we generate every single frame according to the previous frames in a story. We do drop all CLIP and BLIP conditions at the classifier_free_idx frames, and for the other frames the conditions are not corrupted; as you can see, we save a copy in https://github.com/xichenpan/ARLDM/blob/b8c1db4627e44a881d67191199a18706f7e0af93/main.py#L200. You can debug our code to better understand the shapes; they are quite confusing.
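
A rough sketch of the idea described above (tensor names, shapes, and the null condition are assumptions for illustration, not copied from main.py):

```python
import torch

B, V, L, D = 2, 5, 77, 768  # assumed sizes: stories, frames per story, tokens, embedding dim
clip_embeds = torch.randn(B * V, L, D)   # per-frame CLIP text conditions
blip_embeds = torch.randn(B * V, L, D)   # per-frame BLIP conditions on previous frames
null_embed = torch.zeros(L, D)           # null condition for dropped frames (assumed)

# keep an uncorrupted copy for the frames that stay conditional
source_clip, source_blip = clip_embeds.clone(), blip_embeds.clone()

# drop *all* conditions at the sampled frame indices
classifier_free_idx = torch.rand(B * V) < 0.1
clip_embeds[classifier_free_idx] = null_embed
blip_embeds[classifier_free_idx] = null_embed
```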

And the code

https://github.com/xichenpan/ARLDM/blob/b8c1db4627e44a881d67191199a18706f7e0af93/main.py#L355

you mentioned is actually following the implementation of Diffusers

https://github.com/huggingface/diffusers/blob/b9b891621e8ed5729761cc6a31b23072315d2df0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L686

Actually $scale = 1 + w$, so the two formulas are the same.
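
Spelling out the equivalence, with $\epsilon_c$ the conditional prediction (noise_pred_text), $\epsilon_u$ the unconditional one (noise_pred_uncond), and Diffusers' guidance_scale $s = 1 + w$:

$$\epsilon_u + s(\epsilon_c - \epsilon_u) = \epsilon_u + (1 + w)(\epsilon_c - \epsilon_u) = \epsilon_c + w(\epsilon_c - \epsilon_u).$$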

skywalker00001 commented 1 year ago

Hi xichenpan, thank you for your answer! I see. Does that mean you actually trained 4 extra unconditional models (let's say our task is continuation), one per frame?

* unconditional model 1: [null CLIP + null BLIP frame 0]
* unconditional model 2: [null CLIP + null BLIP frame 0 + null BLIP frame 1]
* unconditional model 3: [null CLIP + null BLIP frame 0 + null BLIP frame 1 + null BLIP frame 2]
* unconditional model 4: [null CLIP + null BLIP frame 0 + null BLIP frame 1 + null BLIP frame 2 + null BLIP frame 3]

And you use attention mask = 1 to mask the future frames in each model. Is my interpretation right?

And for the second question, yeah, you are right, the two formulas are the same. Thanks again for your patience!

skywalker00001 commented 1 year ago

Besides, I encountered another problem. How do you calculate the FID score? I debugged and printed the shapes of "original_images" and "images": https://github.com/xichenpan/ARLDM/blob/b8c1db4627e44a881d67191199a18706f7e0af93/main.py#L313-L323

original_images has the shape (4, 3, 128, 128), but don't we need to permute it before turning it into a PIL image? Because "original_images" currently yields PIL images of size 128 × 3, while "images" yields PIL images of size 512 × 512 (the Stable Diffusion output size).

I know that in the code https://github.com/xichenpan/ARLDM/blob/b8c1db4627e44a881d67191199a18706f7e0af93/main.py#L381-L390 the images are resized to (3, 64, 64) at line 382, but I suspect the values will be corrupted. For example, when I apply "transforms.ToTensor()(images[0])" to "original_images", the shape is (3, 3, 128).

And when I run [transforms.ToPILImage()(self.fid_augment(im)).save("fakedpics/ori{:02}.png".format(idx)) for idx, im in enumerate(images)], the saved images are all corrupted for "original_images". For the generated "images" there is no error: "transforms.ToTensor()(images[0])" has the shape (3, 512, 512) and PIL saves the correct image.

image

Therefore, I suppose it should be

original_images = original_images.permute(0, 2, 3, 1).cpu().numpy().astype('uint8')

at line 316? https://github.com/xichenpan/ARLDM/blob/b8c1db4627e44a881d67191199a18706f7e0af93/main.py#L316

If that is the case, maybe the FID scores will change too...
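
As an illustration of why the permute matters (a standalone sketch with dummy data, not the repo's code): PIL's Image.fromarray expects an H × W × C uint8 array, so a BCHW batch has to be permuted to BHWC first.

```python
import torch
from PIL import Image

# dummy BCHW batch standing in for original_images
original_images = torch.randint(0, 255, (4, 3, 128, 128), dtype=torch.uint8)

# without the permute, each array is interpreted as a 3 x 128 image with 128 "channels"
arr = original_images.permute(0, 2, 3, 1).cpu().numpy().astype('uint8')  # B x H x W x C
pils = [Image.fromarray(im) for im in arr]  # each is a valid 128 x 128 RGB image
```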

skywalker00001 commented 1 year ago

And I think multi-GPU inference is supported by PyTorch Lightning.

I only added "strategy="ddp"," to the Trainer, set "args.gpu_ids" to [0, 1, 2, 3], and commented out your code at https://github.com/xichenpan/ARLDM/blob/b8c1db4627e44a881d67191199a18706f7e0af93/main.py#L425

It succeeded, so I think it may help others.

image

for https://github.com/xichenpan/ARLDM/blob/b8c1db4627e44a881d67191199a18706f7e0af93/main.py#L430-L435
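
Something along these lines (a sketch only; the variable name `predictor` is a placeholder and the exact keyword arguments depend on the installed PyTorch Lightning version):

```python
import pytorch_lightning as pl

# modified sampling Trainer, PyTorch Lightning 1.x style kwargs
predictor = pl.Trainer(
    gpus=[0, 1, 2, 3],   # was restricted to a single GPU in the original script
    strategy="ddp",      # DistributedDataParallel across the listed GPUs
)
```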

xichenpan commented 1 year ago

@skywalker00001 Hi, for the first issue, it is one single unconditional model (with varied context length), because all parameters are shared.

For the second issue, thanks for pointing that out. I found that in our original implementation we do not permute original_images, so it has a shape of BHWC https://github.com/xichenpan/ARLDM/blob/a24e2e94332eb86fcc071abb83aaf341006aa622/ARLDM.py#L146 while in our current implementation we do, so it has a shape of BCHW https://github.com/xichenpan/ARLDM/blob/b8c1db4627e44a881d67191199a18706f7e0af93/datasets/flintstones.py#L65-L66 This causes an inconsistency, because we still copied the code from the old implementation: https://github.com/xichenpan/ARLDM/blob/a24e2e94332eb86fcc071abb83aaf341006aa622/ARLDM.py#L353-L357 We will remove the permute in the dataset code so that the shape is correct. So I believe the FID score we reported in our original paper is correct, while the current repo does not correctly migrate the original implementation. Another user has reported this issue, and I am sorry that I forgot to correct it: https://github.com/xichenpan/ARLDM/issues/10#issuecomment-1445405938

For the final issue, you can do so, but running on multiple GPUs may drop the last samples or assign the same sample to multiple GPUs, so we only use a single GPU for inference. If you want to get results quickly, it is acceptable.
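
To illustrate the multi-GPU caveat (a standalone example using torch.utils.data.DistributedSampler directly; the sizes are made up): when the dataset size is not divisible by the number of GPUs, the sampler pads by repeating indices, so a few samples are evaluated twice and aggregate metrics like FID can shift slightly.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))  # 10 samples across 4 GPUs -> padded to 12
for rank in range(4):
    sampler = DistributedSampler(dataset, num_replicas=4, rank=rank, shuffle=False)
    print(rank, list(sampler))  # indices 0 and 1 are repeated to fill the last shards
```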

skywalker00001 commented 1 year ago

Got it! Thanks again.

skywalker00001 commented 10 months ago

Hello, this is Yi Hou. Your email has been received. Wishing you a pleasant day~