yisol / IDM-VTON

[ECCV2024] IDM-VTON : Improving Diffusion Models for Authentic Virtual Try-on in the Wild
https://idm-vton.github.io/

pre-trained model #151


naelsen commented 1 month ago

Can someone tell me where I can find the weights of the model? I'm lost among all this information. Basically, I just want to run inference with the pre-trained model, and I don't know how.

samsara-ku commented 1 month ago

@naelsen Here
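If it helps, the diffusion weights can be pulled programmatically. A minimal sketch, assuming they are hosted on the Hugging Face Hub under yisol/IDM-VTON (the link above); the DensePose, human-parsing, and OpenPose checkpoints are separate downloads that go into the repo's ckpt/ folder per the README:

    # Assumption: the Hub repo id is yisol/IDM-VTON, matching the link above.
    from huggingface_hub import snapshot_download

    local_dir = snapshot_download(repo_id="yisol/IDM-VTON")
    print(local_dir)  # local path with the diffusers-style subfolders (unet/, vae/, ...)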

naelsen commented 1 month ago

Hello @samsara-ku, could you give me a step-by-step example of how to use this repo to run inference on my own data? I appreciate your help.

samsara-ku commented 1 month ago

@naelsen Did you check gradio_demo/app.py? It contains a pipeline for custom data (i.e. data without a prepared agnostic mask or DensePose output). You should check it; I believe you can extend that code to handle many kinds of data.
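For orientation, the top of gradio_demo/app.py assembles the try-on pipeline roughly as below. This is a condensed sketch from memory; the src.* module names and the subfolder layout of the weights are assumptions to verify against the actual file:

    import torch
    from transformers import (
        AutoTokenizer, CLIPImageProcessor, CLIPTextModel,
        CLIPTextModelWithProjection, CLIPVisionModelWithProjection,
    )
    from diffusers import AutoencoderKL, DDPMScheduler
    # Repo-local modules (names as I remember them from app.py):
    from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
    from src.unet_hacked_tryon import UNet2DConditionModel
    from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref

    base_path = "yisol/IDM-VTON"  # the Hugging Face weights linked above
    dtype = torch.float16

    unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=dtype)
    unet_encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=dtype)
    vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=dtype)
    scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
    tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False)
    tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False)
    text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=dtype)
    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=dtype)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=dtype)

    pipe = TryonPipeline.from_pretrained(
        base_path,
        unet=unet,
        vae=vae,
        feature_extractor=CLIPImageProcessor(),
        text_encoder=text_encoder_one,
        text_encoder_2=text_encoder_two,
        tokenizer=tokenizer_one,
        tokenizer_2=tokenizer_two,
        scheduler=scheduler,
        image_encoder=image_encoder,
        torch_dtype=dtype,
    ).to("cuda")
    pipe.unet_encoder = unet_encoder  # the parallel garment-encoder UNet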

naelsen commented 1 month ago

@samsara-ku Yes, I checked, but I want to deploy a service that runs on my own hardware. What should I look at for that?

samsara-ku commented 1 month ago
This is the core inference call from gradio_demo/app.py:

    images = pipe(
        prompt_embeds=prompt_embeds.to(device, torch.float16),
        negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
        pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
        num_inference_steps=denoise_steps,
        generator=generator,
        strength=1.0,
        pose_img=pose_img.to(device, torch.float16),                  # DensePose rendering of the person
        text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),  # embedding of the garment caption
        cloth=garm_tensor.to(device, torch.float16),                  # garment image as a normalized tensor
        mask_image=mask,                                              # agnostic (inpainting) mask
        image=human_img,                                              # person image
        height=1024,
        width=768,
        ip_adapter_image=garm_img.resize((768, 1024)),                # garment image for the IP-Adapter branch
        guidance_scale=2.0,
    )[0]

Basically, all you have to do is prepare the arguments of the pipeline above: 1) the prompt embeddings (e.g. prompt_embeds, negative_prompt_embeds, and so on), 2) pose_img, 3) text_embeds_cloth, 4) cloth, 5) mask_image, 6) image.

  1. The prompt embeddings just convert text to vectors using the tokenizer and text encoders; check that code. The same goes for text_embeds_cloth (see the first sketch after this list).

  2. pose_img comes from args.func(args, human_img_arg), which is the output of DensePose. gradio_demo/app.py already prepares this output, so just read how that code works (see the second sketch after this list).

  3. garm_img is just the garment image; it needs no pre-processing.

  4. mask_image comes from the get_mask_location function; this is what VTON papers usually call the agnostic mask. All you need to know is that the agnostic mask is derived from OpenPose, so you need to run the openpose_model function on your custom dataset (see the third sketch after this list).

  5. I don't know exactly what human_img_arg is, but just compute it the same way and pass it to the pipe function (it is prepared alongside pose_img in the second sketch below).
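To make step 1 concrete, here is a minimal sketch, assuming pipe is the loaded try-on pipeline and that it keeps the standard SDXL encode_prompt interface the demo relies on; garment_des is a hypothetical garment description:

    import torch

    garment_des = "short sleeve round neck t-shirt"  # hypothetical description of the cloth
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

    with torch.inference_mode():
        # Embeddings for the try-on prompt (the person wearing the garment).
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            "model is wearing " + garment_des,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )
        # Embedding of the garment caption itself; this becomes text_embeds_cloth.
        prompt_embeds_c, _, _, _ = pipe.encode_prompt(
            "a photo of " + garment_des,
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
            negative_prompt=negative_prompt,
        )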
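For steps 2 and 5, this mirrors how the demo drives DensePose (apply_net is the helper module bundled in gradio_demo/, and the config/checkpoint paths are the ones the demo ships with; treat all of these names as assumptions to verify):

    import torch
    from PIL import Image
    from torchvision import transforms
    from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
    import apply_net  # repo-local helper (gradio_demo/apply_net.py)

    device = "cuda"
    human_img = Image.open("person.jpg").convert("RGB").resize((768, 1024))
    garm_img = Image.open("garment.jpg").convert("RGB").resize((768, 1024))

    # human_img_arg: the person image as a BGR numpy array at DensePose resolution.
    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")

    # args.func(args, human_img_arg) is exactly the call mentioned in step 2.
    args = apply_net.create_argument_parser().parse_args((
        "show",
        "./configs/densepose_rcnn_R_50_FPN_s1x.yaml",
        "./ckpt/densepose/model_final_162be9.pkl",
        "dp_segm", "-v",
        "--opts", "MODEL.DEVICE", device,
    ))
    pose_img = args.func(args, human_img_arg)
    pose_img = Image.fromarray(pose_img[:, :, ::-1]).resize((768, 1024))  # BGR -> RGB

    # The pipeline expects [-1, 1] tensors, as in the demo's tensor transform.
    to_tensor = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize([0.5], [0.5])])
    pose_img = to_tensor(pose_img).unsqueeze(0)
    garm_tensor = to_tensor(garm_img).unsqueeze(0)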
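And for step 4, the demo's auto-mask branch looks roughly like this; OpenPose, Parsing, and get_mask_location are the names used in gradio_demo/app.py and the preprocess/ modules (again, verify against the repo):

    from PIL import Image
    from preprocess.openpose.run_openpose import OpenPose
    from preprocess.humanparsing.run_parsing import Parsing
    from utils_mask import get_mask_location

    openpose_model = OpenPose(0)  # GPU id, as in the demo
    parsing_model = Parsing(0)

    human_img = Image.open("person.jpg").convert("RGB").resize((768, 1024))

    # Keypoints and human-parsing segmentation are computed at 384x512.
    keypoints = openpose_model(human_img.resize((384, 512)))
    model_parse, _ = parsing_model(human_img.resize((384, 512)))

    # "hd" mode and "upper_body" category mirror the demo's defaults.
    mask, mask_gray = get_mask_location("hd", "upper_body", model_parse, keypoints)
    mask = mask.resize((768, 1024))

With prompt_embeds, prompt_embeds_c, pose_img, garm_tensor, mask, and human_img prepared this way, the pipe(...) call quoted above should run as-is.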