jjihwan / FIFO-Diffusion_public

Official implementation of FIFO-Diffusion
https://jjihwan.github.io

Timestep Injection Strategy #13

Closed. BurhanUlTayyab closed this issue 3 weeks ago.

BurhanUlTayyab commented 3 weeks ago

Hi,

I'm implementing FIFO-Diffusion in Open-Sora Plan. This is my prototype:

```python
import os

import numpy as np
import torch
from tqdm import trange

# Runs inside the pipeline's __call__: `self`, `timesteps`, `latents_dir`, `num_inference_steps`,
# `prompt_embeds`, `added_cond_kwargs`, `do_classifier_free_guidance` and
# `enable_temporal_attentions` come from the surrounding method.
latents_muna = self.prepare_latent_muna(latents_dir, num_inference_steps, video_length=17)
num_frames_per_gpu = 17
video_length = 17
os.makedirs("fifo_dir", exist_ok=True)
indices = np.arange(num_inference_steps)

# Pad the front of the queue with video_length // 2 extra frames for lookahead denoising.
timesteps = np.concatenate([np.full((video_length//2,), timesteps.cpu()[0]), timesteps.cpu()])
indices = np.concatenate([np.full((video_length//2,), 0), indices])

new_video_length = 102
num_partitions = 4
lookahead_denoising = True
# Lookahead denoising uses overlapping half-shifted windows, so it needs twice as many ranks.
num_ranks = 2 * num_partitions if lookahead_denoising else num_partitions
for i in trange(new_video_length + num_inference_steps - video_length, desc="fifo sampling"):
    for rank in reversed(range(num_ranks)):
        start_idx = rank*(num_frames_per_gpu // 2) if lookahead_denoising else rank*num_frames_per_gpu
        midpoint_idx = start_idx + num_frames_per_gpu // 2
        end_idx = start_idx + num_frames_per_gpu

        # One timestep per frame in this window (a vector, not a scalar).
        t = timesteps[start_idx:end_idx]
        idx = indices[start_idx:end_idx]

        input_latents = latents_muna[:,:,start_idx:end_idx].clone()
        latent_model_input = torch.cat([input_latents] * 2) if do_classifier_free_guidance else input_latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        if not torch.is_tensor(t):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = latent_model_input.device.type == "mps"
            if isinstance(t, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            t = torch.tensor(t, dtype=dtype, device=latent_model_input.device)
        elif len(t.shape) == 0:
            t = t[None].to(latent_model_input.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        # t = t.expand(latent_model_input.shape[0])
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
        noise_pred = self.transformer(
            latent_model_input,
            encoder_hidden_states=prompt_embeds,
            timestep=t,
            added_cond_kwargs=added_cond_kwargs,
            enable_temporal_attentions=enable_temporal_attentions,
            return_dict=False,
        )[0]
```

However, it seems I need to implement some kind of timestep injection, because `t = timesteps[start_idx:end_idx]` holds a different timestep for each frame in the window rather than a single timestep for the whole clip. Could you let me know how you implemented it?
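The following standalone sketch reproduces the prototype's window arithmetic to show why each model call needs a vector of per-frame timesteps. It uses hypothetical small sizes (`num_inference_steps=8`, `video_length=4`, `num_partitions=2`) and assumes a stand-in DDIM-style schedule ordered from least to most noisy, so the head of the queue is the frame closest to clean. `fifo_windows` is an illustrative helper, not part of either codebase.

```python
import numpy as np


def fifo_windows(num_inference_steps, video_length, num_partitions, lookahead_denoising=True):
    """Return (rank, start_idx, end_idx, per-frame timesteps) for each model call in one FIFO sweep.

    Hypothetical helper for illustration; the schedule below is a stand-in, not a real scheduler.
    """
    # Assumed ascending (least to most noisy) DDIM-style schedule.
    timesteps = np.arange(1, 1000, 1000 // num_inference_steps)[:num_inference_steps]
    if lookahead_denoising:
        # Pad the head of the queue with copies of the first timestep, as the prototype does.
        timesteps = np.concatenate([np.full((video_length // 2,), timesteps[0]), timesteps])

    num_ranks = 2 * num_partitions if lookahead_denoising else num_partitions
    windows = []
    for rank in reversed(range(num_ranks)):
        start_idx = rank * (video_length // 2) if lookahead_denoising else rank * video_length
        end_idx = start_idx + video_length
        # One timestep per frame in the window: this vector is what the model has to accept.
        windows.append((rank, start_idx, end_idx, timesteps[start_idx:end_idx].tolist()))
    return windows


for rank, start, end, t in fifo_windows(num_inference_steps=8, video_length=4, num_partitions=2):
    print(f"rank {rank}: frames [{start}, {end}) -> t = {t}")
```

With these sizes, rank 0 receives `[1, 1, 1, 126]` and rank 3 receives `[501, 626, 751, 876]`, so no single scalar `timestep` describes a window.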

jjihwan commented 3 weeks ago

You're right, you might have to slightly modify the model's code (i.e., `models.diffusion.latte.modeling_latte.LatteT2V`). I'm planning to release the code for Open-Sora Plan v1.1.0 soon, so please wait a few days.
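As a rough illustration only, and not the code that was released: assuming the model follows the common DiT/Latte pattern of embedding one timestep per sample and repeating that embedding across frames, one plausible modification is to embed a `(B, F)` tensor of per-frame timesteps instead and apply the adaLN shift and scale frame by frame. `PerFrameAdaLN` and `sinusoidal_embedding` below are hypothetical and do not reproduce `LatteT2V`'s internals or argument names.

```python
import math

import torch
import torch.nn as nn


def sinusoidal_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Sinusoidal timestep embedding for a tensor of any shape; appends a trailing dim of size `dim`."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32, device=t.device) / half)
    args = t.float()[..., None] * freqs
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


class PerFrameAdaLN(nn.Module):
    """Hypothetical adaLN modulation that accepts one timestep per frame (not LatteT2V's API)."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        assert hidden_dim % 2 == 0, "sinusoidal embedding assumes an even hidden_dim"
        self.hidden_dim = hidden_dim
        self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.t_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, 2 * hidden_dim)
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x: (B, F, N, D) tokens, N spatial tokens per frame.
        # t: (B, F) per-frame timesteps, or (B,) for the usual one-timestep-per-clip case.
        if t.ndim == 1:
            t = t[:, None].expand(x.shape[0], x.shape[1])
        emb = self.t_mlp(sinusoidal_embedding(t, self.hidden_dim))  # (B, F, 2D)
        shift, scale = emb.chunk(2, dim=-1)  # each (B, F, D)
        # Insert a token axis so each frame's shift/scale applies only to that frame's tokens.
        return self.norm(x) * (1 + scale[:, :, None, :]) + shift[:, :, None, :]


torch.manual_seed(0)
block = PerFrameAdaLN(hidden_dim=64)
x = torch.randn(2, 4, 16, 64)
t = torch.tensor([[1, 1, 1, 126], [501, 626, 751, 876]])  # one FIFO window per sample
print(block(x, t).shape)  # torch.Size([2, 4, 16, 64])
```

Keeping one shared MLP and varying only its input means no new weights are needed, and passing a `(B,)` tensor reproduces the usual shared-timestep behaviour. If the spatial blocks flatten frames into the batch as `(B*F, N, D)`, the same idea amounts to embedding `t.reshape(-1)` rather than repeating a per-sample embedding.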

roundchuan commented 3 weeks ago


Hello, I'm running into the same problem. Could you share the code?

jjihwan commented 3 weeks ago

The code has been released :)