Doubiiu / DynamiCrafter

[ECCV 2024] DynamiCrafter: Animating Open-domain Images with Video Diffusion Priors
Apache License 2.0

Gradio takes a long time to run #31

Closed: Yangfan-96 closed this issue 3 months ago

Yangfan-96 commented 4 months ago

When I run the 576×1024 model, inference takes 86 seconds with the 'run.sh' script but 250 seconds with Gradio. Why is this happening?

Doubiiu commented 4 months ago

Hi. We use mixed precision in the running script run.sh. To get the same speed-up in Gradio, manually add with torch.no_grad(), torch.cuda.amp.autocast(): at Line 56 of i2v_test.py and then restart the Gradio demo:

    with torch.no_grad(), torch.cuda.amp.autocast():   ## Add this line, then indent the existing code below under it
        # text cond
        text_emb = model.get_learned_conditioning([prompt])

        # img cond
        img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
        img_tensor = (img_tensor / 255. - 0.5) * 2

        image_tensor_resized = transform(img_tensor) #3,h,w
        videos = image_tensor_resized.unsqueeze(0) # bchw

        z = get_latent_z(model, videos.unsqueeze(2)) # b,c,1,h,w

        img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)

        cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
        img_emb = model.image_proj_model(cond_images)

        imtext_cond = torch.cat([text_emb, img_emb], dim=1)

        fs = torch.tensor([fs], dtype=torch.long, device=model.device)
        cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}

        ## inference
        batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
        ## b,samples,c,t,h,w
        prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
        prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
        prompt_str = prompt_str[:40]
        if len(prompt_str) == 0:
            prompt_str = 'empty_prompt'
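For reference, here is a minimal, self-contained sketch of the same no_grad + autocast pattern outside DynamiCrafter. The tiny nn.Linear model and the run_inference helper are hypothetical stand-ins, not part of the repository. It uses torch.autocast(device_type=...), the device-generic spelling of torch.cuda.amp.autocast() (recent PyTorch releases deprecate the latter in favor of torch.amp.autocast("cuda")), and falls back to CPU with bfloat16 when no GPU is available, so it only illustrates the dtype and autograd behavior, not DynamiCrafter's actual timings.

```python
import torch
import torch.nn as nn


def run_inference(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: forward pass with no autograd bookkeeping
    plus automatic mixed precision (AMP)."""
    device_type = x.device.type  # "cuda" or "cpu"
    # float16 is the usual AMP dtype on CUDA; CPU autocast supports bfloat16.
    amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=amp_dtype):
        return model(x)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Hypothetical stand-in for the real model: a single linear layer.
    model = nn.Linear(64, 64).to(device).eval()
    x = torch.randn(8, 64, device=device)

    out_fp32 = model(x)               # default path: float32, autograd graph recorded
    out_amp = run_inference(model, x)  # half precision, no graph

    print(out_fp32.dtype, out_fp32.requires_grad)  # torch.float32 True
    print(out_amp.dtype, out_amp.requires_grad)    # float16 on CUDA / bfloat16 on CPU, False
```

torch.no_grad() stops autograd from recording the computation graph, which saves memory and some time, while autocast runs eligible ops such as matrix multiplications and convolutions in half precision. On GPUs with Tensor Cores, that lower-precision math is typically where most of the speed-up comes from.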