danier97 / LDMVFI

[AAAI'2024] "LDMVFI: Video Frame Interpolation with Latent Diffusion Models", Duolikun Danier, Fan Zhang, David Bull
MIT License
123 stars · 11 forks

The model doesn't work as expected #6

Closed zeinshaheen closed 1 year ago

zeinshaheen commented 1 year ago

Hello!

Thank you for the interesting work.

I played with your pretrained model, and I noticed that it doesn't work as expected.

To check, I took three consecutive frames from a Vimeo90k septuplet and used the first and third as conditions (left and right).

I encoded the middle frame with the VQ encoder to get its latents, then immediately decoded those latents conditioned on the features extracted from the left and right frames. I expected an output very close to the middle frame; however, the generated frame differs noticeably.

To make sure, I replaced the middle-frame latents with zeros and passed them through the decoder, again conditioning on the features from the left and right frames, and I got the same output. Finally, I passed the left frame as zeros, with the right and middle frames taken from the septuplet, and got an image mixed between black and the left frame.

Basically, I think your VQ model ignores the latents produced by the latent diffusion model (or, in my case, the latents encoded from the middle frame) and instead generates the interpolated frame on its own, conditioning only on the left and right frames.

def decode(self, rawFrame0, rawFrame1, rawFrame_middle):
    frame0 = TF.normalize(TF.to_tensor(rawFrame0), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))[None, ...].cuda()
    frame1 = TF.normalize(TF.to_tensor(rawFrame1), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))[None, ...].cuda()
    frame2 = TF.normalize(TF.to_tensor(rawFrame_middle), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))[None, ...].cuda()

    with torch.no_grad():
        with self.model.ema_scope():
            # form condition tensor and define shape of latent rep
            xc = {'prev_frame': frame0, 'next_frame': frame1}
            c, phi_prev_list, phi_next_list = self.model.get_learned_conditioning(xc)
            shape = (self.model.channels, c.shape[2], c.shape[3])

            # encode the GT middle frame to get its latent
            out = self.model.first_stage_model.encode(frame2)
            if isinstance(out, tuple):
                out = out[0]
            # reconstruct interpolated frame from latent
            out = self.model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
            out = torch.clamp(out, min=-1., max=1.)  # interpolated frame in [-1, 1]

    return tensor2rgb(out)[0]
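The zero-latent ablation described above can be packaged as a small probe: decode the same conditioning twice, once with the GT latent and once with zeros, and measure how much the output changes. This is a minimal sketch; `decode_fn` is a hypothetical placeholder standing in for `model.decode_first_stage(...)` with the conditioning already bound, and the toy decoders below are illustrations only, not LDMVFI code.

```python
import torch

def latent_sensitivity(decode_fn, z_gt):
    # Decode with the GT latent and with an all-zero latent of the same shape.
    with torch.no_grad():
        out_gt = decode_fn(z_gt)
        out_zero = decode_fn(torch.zeros_like(z_gt))
    # Mean absolute difference; ~0 means the decoder ignores the latent.
    return (out_gt - out_zero).abs().mean().item()

# Toy decoders for illustration only (8x14 latent, 32x upsampling).
z = torch.randn(1, 3, 8, 14)
ignores_latent = lambda _z: torch.ones(1, 3, 256, 448)  # constant output
uses_latent = lambda z_: z_.repeat_interleave(32, dim=2).repeat_interleave(32, dim=3)

print(latent_sensitivity(ignores_latent, z))  # 0.0
print(latent_sensitivity(uses_latent, z))     # > 0
```

A value near zero for the real model would support the hypothesis that the decoder relies only on the left/right conditioning features.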
zeinshaheen commented 1 year ago

I think your idea is nice, and I am interested in implementing a VQ model that reconstructs its input while conditioning on the left and right frames, to improve the quality of the reconstructed frames.

danier97 commented 1 year ago

Hi,

Thank you for your interest in our work; you have raised a very interesting question!

Regarding your first observation (differences between the GT image and the decoded image when the GT latent is used): I believe this is expected, because the autoencoder (AE) is not a lossless codec and so does not guarantee perfect reconstruction.

Regarding your second observation (the AE being insensitive to the latent input): indeed, in some cases the AE is not affected by the latent input, but this is not generally the case, as illustrated below.

Considering the case of f=32 (AE downsampling factor of 32, i.e. the model you tested with), below are the outputs decoded by the AE when the input latent to the decoder is (1) the encoding of the GT middle frame and (2) zeros.

[image: f=32, simple scene, outputs for GT latent vs zero latent]

In this case, as you observed, the AE output isn't affected by the latent. However, for a more challenging scenario with more complex motion below,

[image: f=32, complex motion, outputs for GT latent vs zero latent]

it can be seen that the latent does play a role. My explanation for this behaviour is that when f=32, there is a significant amount of information in the features of the neighbouring frames (the phis), so the AE has learned to rely more heavily on those. When the scene is relatively simple, the information from the phis may be sufficient to reconstruct the frame; in more challenging cases (e.g. larger, more complex motions), decoding a good latent becomes more important.

To support the explanation above, I took a look at the case of f=8, where much less information from the two input frames is passed into the decoder compared to the previous case. Below are the outputs in this case.

[image: f=8, outputs for GT latent vs zero latent]

As can be seen, here the latent plays a more important role even in simple scenarios.
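For intuition on the f=8 vs f=32 comparison, here is a back-of-envelope check of latent capacity at the two downsampling factors, assuming Vimeo90k-sized 448x256 frames. `latent_hw` is an illustrative helper, not part of the LDMVFI codebase.

```python
def latent_hw(height, width, f):
    # Spatial size of the latent after downsampling by factor f.
    return (height // f, width // f)

print(latent_hw(256, 448, 32))  # (8, 14)
print(latent_hw(256, 448, 8))   # (32, 56)
```

At f=8 the latent has 16x more spatial positions than at f=32, which is consistent with the observation that the f=8 decoder depends more on the latent and less on the phis.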

Many thanks.

zeinshaheen commented 1 year ago

The case with f=8 is interesting. Thank you for the explanation.