TencentARC / BrushNet

[ECCV 2024] The official implementation of paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"
https://tencentarc.github.io/BrushNet/

Question about training #25

Closed dydxdt closed 4 months ago

dydxdt commented 5 months ago

Your work is great! I have some questions about training epochs. I want to train BrushNet on my own data, and I see the default number of training epochs is 10000. I also see the config.json in your released model weights: random_mask_brushnet_ckpt: "runs/logs/brushnet_randommask/checkpoint-100000"; segmentation_mask_brushnet_ckpt: "runs/logs/brushnet_segmask/checkpoint-550000". It seems that the other models also correspond to different amounts of training.

So can I generally use 10000 training epochs, or do I need to choose based on the loss values? The paper says: "BrushNet and all ablation models are trained for 430 thousand steps on 8 NVIDIA Tesla V100 GPUs, which takes around 3 days". In my case, training seems to take much longer than that.

Thanks for your reply.

juxuan27 commented 5 months ago

Hi, the "10000" epochs value is simply a hyperparameter; you can stop training when the generation quality is good enough. I personally don't recommend using the loss to decide when to stop, because the generation results sometimes keep improving even after the loss has stopped decreasing. The config.json also just records an arbitrary checkpoint path, so you can ignore it. As for the training time, how long does yours take? Different machines will need different amounts of time, but they should be roughly in the same range.
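
For anyone following this advice, here is a minimal sketch of that kind of qualitative evaluation: render a fixed validation set from each saved checkpoint and compare the results by eye instead of watching the loss curve. It assumes the BrushNetModel and StableDiffusionBrushNetPipeline classes from this repo's bundled diffusers, used as in the repo's example scripts; the base model, validation files, and checkpoint paths are placeholders, not official ones.

```python
import os

import cv2
import numpy as np
import torch
from PIL import Image

# These imports follow the repo's example scripts (BrushNet ships a patched diffusers).
from diffusers import BrushNetModel, StableDiffusionBrushNetPipeline

BASE_MODEL = "runwayml/stable-diffusion-v1-5"  # assumption: an SD v1.5 base model
# Hypothetical fixed validation set: (image path, mask path, caption) triples.
VAL_SET = [("val/cat.png", "val/cat_mask.png", "a cat sitting on a sofa")]

def load_inputs(image_path, mask_path):
    """Black out the masked region, mirroring the preprocessing in the repo examples."""
    init_image = cv2.imread(image_path)[:, :, ::-1]  # BGR -> RGB
    mask = 1.0 * (cv2.imread(mask_path).sum(-1) > 255)[:, :, np.newaxis]
    init_image = init_image * (1 - mask)
    init_pil = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
    mask_pil = Image.fromarray((mask.repeat(3, -1) * 255).astype(np.uint8)).convert("RGB")
    return init_pil, mask_pil

def render_checkpoint(ckpt_dir, out_dir):
    """Generate the fixed validation samples from one checkpoint for visual comparison."""
    brushnet = BrushNetModel.from_pretrained(ckpt_dir, torch_dtype=torch.float16)
    pipe = StableDiffusionBrushNetPipeline.from_pretrained(
        BASE_MODEL, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
    ).to("cuda")
    os.makedirs(out_dir, exist_ok=True)
    # Fixed seed so sample i uses the same noise at every checkpoint.
    generator = torch.Generator("cuda").manual_seed(1234)
    for i, (image_path, mask_path, caption) in enumerate(VAL_SET):
        init_image, mask_image = load_inputs(image_path, mask_path)
        image = pipe(
            caption, init_image, mask_image,
            num_inference_steps=50, generator=generator,
            brushnet_conditioning_scale=1.0,
        ).images[0]
        image.save(os.path.join(out_dir, f"sample_{i}.png"))

# Render the same validation set from successive checkpoints and compare by eye.
# Checkpoint paths are hypothetical; point them at folders holding BrushNet weights.
for step in (100_000, 200_000, 300_000):
    render_checkpoint(f"runs/logs/my_run/checkpoint-{step}", f"eval/step_{step}")
```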

Shuvo001 commented 4 months ago

@dydxdt Hi, I would like to ask how you prepared your custom dataset in the BrushData structure. Could you point me to some docs or give me an idea?

Thanks :)
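
Not an official answer, but as a generic illustration of what an inpainting training sample needs (an image, a binary mask, and a caption), here is a minimal PyTorch Dataset sketch. The folder layout, file names, and field names are all hypothetical; the authoritative BrushData schema is whatever the repo's data-loading code expects, so check that first.

```python
import json
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class InpaintingFolderDataset(Dataset):
    """Minimal paired image/mask/caption dataset for inpainting training.

    Hypothetical layout (not the official BrushData schema):
        root/images/00001.jpg
        root/masks/00001.png       # white = region to inpaint
        root/captions.json         # {"00001": "a cat sitting on a sofa", ...}
    """

    def __init__(self, root, resolution=512):
        self.root = Path(root)
        self.captions = json.loads((self.root / "captions.json").read_text())
        self.keys = sorted(self.captions)
        self.to_tensor = transforms.Compose([
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        key = self.keys[idx]
        image = self.to_tensor(Image.open(self.root / "images" / f"{key}.jpg").convert("RGB"))
        mask = self.to_tensor(Image.open(self.root / "masks" / f"{key}.png").convert("L"))
        mask = (mask > 0.5).float()        # binarize: 1 = hole to fill, 0 = keep
        image = image * 2.0 - 1.0          # scale to [-1, 1] as Stable Diffusion expects
        return {
            "pixel_values": image,                     # ground-truth target image
            "masks": mask,                             # binary inpainting mask
            "masked_images": image * (1.0 - mask),     # conditioning image, hole zeroed
            "caption": self.captions[key],             # text prompt for the sample
        }
```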

congwei1230 commented 4 months ago

Thanks for the great work! Regarding the segmentation_mask_brushnet_ckpt_sdxl_v0 and random_mask_brushnet_ckpt_sdxl_v0 checkpoints provided on Google Drive, could you tell me how many training steps were completed for each? I assume these models were trained with a batch size of 1 and gradient accumulation of 4 on 8 GPUs. Thank you! @juxuan27
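
For cross-checking numbers like these, the effective batch size and the number of samples a checkpoint has seen follow from simple arithmetic. A small sketch using the batch size 1, gradient accumulation 4, and 8 GPUs assumed above; the dataset size is a placeholder:

```python
# Effective batch size under multi-GPU training with gradient accumulation.
per_gpu_batch = 1        # assumed per-GPU batch size (from the comment above)
grad_accum = 4           # assumed gradient accumulation steps
num_gpus = 8             # assumed number of GPUs

effective_batch = per_gpu_batch * grad_accum * num_gpus   # 1 * 4 * 8 = 32

# Samples seen by a checkpoint, e.g. the checkpoint-550000 path in config.json,
# if the step counter tracks optimizer updates (accumulation already included).
train_steps = 550_000
samples_seen = train_steps * effective_batch              # 17,600,000 samples

dataset_size = 2_000_000  # placeholder: substitute your dataset's actual size
epochs_equivalent = samples_seen / dataset_size           # ~8.8 passes over the data

print(effective_batch, samples_seen, round(epochs_equivalent, 1))
```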