jjihwan / FIFO-Diffusion_public

Official implementation of FIFO-Diffusion
https://jjihwan.github.io

Multi-prompt Video Generation Fix #22

Open prajwalsingh opened 5 days ago

prajwalsingh commented 5 days ago

Thank you for making the code publicly available.

Here are the fixes needed in the code for multi-prompt video generation:

  1. Add the following line in the argparse section

    parser.add_argument("--multiprompt", action="store_true", default=False, help="given prompt is multi-action prompt")
    
  2. Add the following code to the main function

    if not args.multiprompt:
        video_frames = fifo_ddim_sampling(
            args, model, cond, noise_shape, ddim_sampler, args.unconditional_guidance_scale,
            output_dir=output_dir, latents_dir=latents_dir, save_frames=args.save_frames
        )
    else:
        video_frames = fifo_ddim_sampling_multiprompts(
            args, model, cond, noise_shape, ddim_sampler, cfg_scale=args.unconditional_guidance_scale,
            output_dir=output_dir, latents_dir=latents_dir, save_frames=args.save_frames, multiprompts=prompts
        )
    

Also update the output-path and saving logic at the end of the main function, so that single-prompt runs keep the last new_video_length frames while multi-prompt runs keep everything after the first num_inference_steps - video_length frames:

    if args.output_dir is None:
        output_path = output_dir + "/fifo"
    else:
        output_path = output_dir + f"/{prompt[:100]}"

    if args.use_mp4 and not args.multiprompt:
        imageio.mimsave(output_path + ".mp4", video_frames[-args.new_video_length:], fps=args.output_fps)
    elif not args.multiprompt:
        imageio.mimsave(output_path + ".gif", video_frames[-args.new_video_length:], duration=int(1000 / args.output_fps))

    if args.use_mp4 and args.multiprompt:
        imageio.mimsave(output_path + ".mp4", video_frames[(args.num_inference_steps - args.video_length):], fps=args.output_fps)
    elif args.multiprompt:
        imageio.mimsave(output_path + ".gif", video_frames[(args.num_inference_steps - args.video_length):], duration=int(1000 / args.output_fps))
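
For intuition about why the single- and multi-prompt branches slice video_frames differently, here is a quick standalone sketch; the concrete numbers are assumptions for illustration, not the repository defaults:

    # Illustration only: assumed example values, not the repository defaults.
    video_length = 16                        # frames per denoising window
    num_inference_steps = 64                 # length of the FIFO queue in frames
    new_video_length = 100                   # frames requested in single-prompt mode
    sub_prompt_char_lengths = [50, 48, 52]   # len() of each comma-separated sub-prompt

    warmup = num_inference_steps - video_length               # 48 frames dropped in multi-prompt mode
    total_multi = sum(sub_prompt_char_lengths) + 2 * warmup   # iterations of the multi-prompt sampling loop
    kept_multi = total_multi - warmup                         # frames kept by video_frames[warmup:]
    print(warmup, total_multi, kept_multi)                    # 48 246 198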

  3. In the funcs.py file, replace the fifo_ddim_sampling_multiprompts function with the following

    def fifo_ddim_sampling_multiprompts(args, model, conditioning, noise_shape, ddim_sampler, multiprompts,
                                        cfg_scale=1.0, output_dir=None, latents_dir=None, save_frames=False, **kwargs):
        batch_size = noise_shape[0]
        kwargs.update({"clean_cond": True})

        # each sub-prompt is active for a number of frames equal to its character count;
        # the cumulative sums mark the frame indices at which the conditioning switches
        prompt_lengths = np.array([len(i) for i in multiprompts[-1].split(',')]).cumsum()
        print('prompt_lengths:', prompt_lengths)
        multiprompts_embed = [model.get_learned_conditioning(prompt) for prompt in multiprompts[-1].split(',')]

        # check condition bs
        if conditioning is not None:
            if isinstance(conditioning, dict):
                try:
                    cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                except:
                    cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]

                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        cond = conditioning
        ## construct unconditional guidance
        if cfg_scale != 1.0:
            prompts = batch_size * [""]
            uc_emb = model.get_learned_conditioning(prompts)

            uc = {key: cond[key] for key in cond.keys()}
            uc.update({'c_crossattn': [uc_emb]})
        else:
            uc = None

        latents = prepare_latents(args, latents_dir, ddim_sampler)

        num_frames_per_gpu = args.video_length
        fifo_dir = os.path.join(output_dir, "fifo")
        os.makedirs(fifo_dir, exist_ok=True)

        fifo_video_frames = []

        timesteps = ddim_sampler.ddim_timesteps
        indices = np.arange(args.num_inference_steps)

        if args.lookahead_denoising:
            timesteps = np.concatenate([np.full((args.video_length // 2,), timesteps[0]), timesteps])
            indices = np.concatenate([np.full((args.video_length // 2,), 0), indices])

        j = 0
        for i in trange(prompt_lengths[-1] + 2 * (args.num_inference_steps - args.video_length), desc="fifo sampling"):
            # advance to the next sub-prompt once the current one has produced its share of frames
            if ((i - (args.num_inference_steps - args.video_length)) >= prompt_lengths[j]) and \
               ((i - (args.num_inference_steps - args.video_length)) < prompt_lengths[-1]):
                j = j + 1

            embed = multiprompts_embed[j]
            cond.update({'c_crossattn': [embed]})

            for rank in reversed(range(2 * args.num_partitions if args.lookahead_denoising else args.num_partitions)):
                start_idx = rank * (num_frames_per_gpu // 2) if args.lookahead_denoising else rank * num_frames_per_gpu
                midpoint_idx = start_idx + num_frames_per_gpu // 2
                end_idx = start_idx + num_frames_per_gpu

                t = timesteps[start_idx:end_idx]
                idx = indices[start_idx:end_idx]

                input_latents = latents[:, :, start_idx:end_idx].clone()
                output_latents, _ = ddim_sampler.fifo_onestep(
                    cond=cond,
                    shape=noise_shape,
                    latents=input_latents,
                    timesteps=t,
                    indices=idx,
                    unconditional_guidance_scale=cfg_scale,
                    unconditional_conditioning=uc,
                    **kwargs
                )
                if args.lookahead_denoising:
                    latents[:, :, midpoint_idx:end_idx] = output_latents[:, :, -(num_frames_per_gpu // 2):]
                else:
                    latents[:, :, start_idx:end_idx] = output_latents
                del output_latents

            # reconstruct from latent to pixel space
            first_frame_idx = args.video_length // 2 if args.lookahead_denoising else 0
            frame_tensor = model.decode_first_stage_2DAE(latents[:, :, [first_frame_idx]])  # b,c,1,H,W
            image = tensor2image(frame_tensor)
            if save_frames:
                fifo_path = os.path.join(fifo_dir, f"{i}.png")
                image.save(fifo_path)
            fifo_video_frames.append(image)

            latents = shift_latents(latents)

        return fifo_video_frames
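
For reference, the inner loop's indexing can be checked in isolation. The snippet below is a standalone sketch with assumed values (video_length=16, num_partitions=4; the relation num_inference_steps = video_length * num_partitions is also an assumption here). It prints which slice of the latent queue each rank touches when lookahead_denoising is enabled, and which half of that slice gets written back:

    # Standalone illustration of the window indexing above; the concrete values are assumptions.
    video_length = 16
    num_partitions = 4
    num_inference_steps = video_length * num_partitions   # assumed relation between the two settings
    half = video_length // 2

    # with lookahead denoising the queue is padded by half a window at the front,
    # so the window indices run from 0 up to num_inference_steps + half
    for rank in reversed(range(2 * num_partitions)):
        start_idx = rank * half
        midpoint_idx = start_idx + half
        end_idx = start_idx + video_length
        # the whole window [start_idx, end_idx) is denoised one step,
        # but only [midpoint_idx, end_idx) is written back into the queue
        print(f"rank {rank}: window [{start_idx}, {end_idx}), write-back [{midpoint_idx}, {end_idx})")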
    

The above modifications help in generating multi-prompt videos, as shown in the paper.

  4. Example multi-prompt inputs (comma-separated sub-prompts)
    A tiger running on the grassland photorealistic 4k high definition, A tiger standing on the grassland photorealistic 4k high definition, A tiger resting on the grassland photorealistic 4k high definition.
    A tiger resting on the grassland photorealistic 4k high definition, A tiger standing on the grassland photorealistic 4k high definition, A tiger running on the grassland photorealistic 4k high definition.
    Ironman running on the road 4k high resolution, Ironman standing on the road 4k high resolution, Ironman flying on the road 4k high resolution.
    A teddy bear running on the street 4k high resolution, A teddy bear standing on the street 4k high resolution, A teddy bear dancing on the street 4k high resolution.
    A whale swimming on the surface of the ocean 4k high resolution, A whale jumps out of water on the surface of the ocean 4k high resolution.
    Titanic sailing through the sunny calm ocean 4k high resolution, Titanic sailing through a stormy ocean with lightning 4k high resolution.
    A pair of tango dancers performing in Buenos 4k high resolution, A pair of tango dancers kissing in Buenos 4k high resolution.
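
For example, the first tiger line above splits into three sub-prompts. The sketch below (illustration only, mirroring the arithmetic in fifo_ddim_sampling_multiprompts) shows the resulting switch points: each sub-prompt stays active for a number of frames equal to its character count.

    # Illustration only: reproduces the prompt-scheduling arithmetic, not the sampling itself.
    import numpy as np

    line = ("A tiger running on the grassland photorealistic 4k high definition,"
            " A tiger standing on the grassland photorealistic 4k high definition,"
            " A tiger resting on the grassland photorealistic 4k high definition.")

    sub_prompts = line.split(',')
    prompt_lengths = np.array([len(p) for p in sub_prompts]).cumsum()

    print(len(sub_prompts))   # 3 -> three text embeddings are computed
    print(prompt_lengths)     # cumulative character counts; sub-prompt j stays active while
                              # i - (num_inference_steps - video_length) < prompt_lengths[j]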
    
jjihwan commented 5 days ago

Thank you so much for your dedicated efforts. Honestly, I forgot to release the multi-prompt version of FIFO in this public repository, but you have already succeeded! I will incorporate your changes into my repository within a week.

prajwalsingh commented 5 days ago

@jjihwan Thank you.