ZhengPeng7 / BiRefNet

[CAAI AIR'24] Bilateral Reference for High-Resolution Dichotomous Image Segmentation
https://www.birefnet.top
MIT License

Training with custom dataset #10

Closed ZeVicTech closed 3 months ago

ZeVicTech commented 3 months ago

Hello, I am amazed at the performance of your model. I want to train it with custom data, but I'm having some issues.

  1. When I resumed training after an interruption, the training loss increased significantly. (Is this because the model weights are saved but the optimizer state is not?)

  2. In the init_models_optimizers function in train.py, there is a variable epoch_st. I think epoch_st should be a global variable, but is there a reason why you have it set up like this?

  3. I currently have a custom dataset of about 9000 images. Because the amount of data is small, I am adding the DIS and HRSOD datasets for training. Is it okay to train like this, or should I train with the custom data only? (I use BiRefNet_ep580.pth)

I look forward to your response, thank you.

ZhengPeng7 commented 3 months ago

Hi, thanks for your interest. About the questions:

  1. Yeah, you are right, the other optimizer info was not saved. I didn't use the resuming functionality much, so the handling there might be poor. (See the save/resume sketch after this list.)
  2. If training is from scratch, epoch_st starts from 1. If not, epoch_st will be the number of epochs for which the saved ckpt was already trained. For example, if the total number of training epochs is 100, training was interrupted during epoch 80, and the last saved ckpt is ep79.pth, then on resuming, epoch_st should start from 80.
  3. Perfect! I've made the code flexible enough for more training sets. Check training_set in config.py. For example, you can add a new key-value pair to it such as: 'custom_dataset': 'YOUR_9000_DATA+DIS-TR+TR-DUTS+TR-HRSOD+TR-UHRSD'. Make sure the training sets named in the value are placed in the correct folder, with the images/GTs of every set arranged as SET/im and SET/gt (see the config sketch after this list). If the data in your test set is more similar to your 9000-image set, I recommend training from scratch. And if the resolution is very high, TR-DUTS can be removed from training, since its images are low-resolution, around $(300\sim400\ \text{px})^2$.
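For reference, here is a minimal save/resume sketch for points 1 and 2, assuming a standard PyTorch training loop; the function names and checkpoint keys are illustrative, not BiRefNet's exact code:

```python
import torch

def save_checkpoint(model, optimizer, finished_epoch, path):
    # Store the optimizer state (e.g. Adam moments) alongside the weights so that
    # resuming does not restart optimization from scratch and spike the loss.
    torch.save({
        'epoch': finished_epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, path):
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    epoch_st = ckpt['epoch'] + 1   # resume from the epoch after the saved one
    return epoch_st
```

And a sketch of the training_set entry described in point 3; the surrounding keys in config.py may differ, only the new pair and the folder layout matter:

```python
# In config.py: each value is a '+'-joined list of set folders under the data root.
training_set = {
    'DIS5K': 'DIS-TR',  # existing entry, shown only as an illustration
    'custom_dataset': 'YOUR_9000_DATA+DIS-TR+TR-DUTS+TR-HRSOD+TR-UHRSD',
}
# Every set folder listed in the value must be arranged as:
#   SET/im   # images
#   SET/gt   # ground-truth masks
```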

Good luck to you and your training. Tell me if you have any more questions.

ZeVicTech commented 3 months ago

Thank you for your reply. I have one more question: when you first created the repository, did you train the BiRefNet_ep580 weights using only the DIS dataset?

ZhengPeng7 commented 3 months ago

Yeah, all the weights with 'DIS' as the task name in the file name, or with no task name specified, were trained with only the DIS5K dataset.

ZeVicTech commented 2 months ago

Hello, I'd like to ask for your advice.

I want to speed up the model's inference, so I'm going to apply quantization. Do you think quantization will improve the speed significantly without degrading performance?

Thank you.

ZhengPeng7 commented 2 months ago

I know techniques like half precision and TensorRT can increase inference speed while keeping almost the same performance, but I don't have time for it. If you make it work and want to share it (for example, in a Colab demo), please tell me; I can mention it in the README for others to use.
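If it helps, here is a rough half-precision inference sketch, not an official BiRefNet script; the tiny nn.Sequential is only a placeholder for the loaded model:

```python
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Placeholder network standing in for a loaded BiRefNet; swap in your real model.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 1, 3, padding=1),
).to(device).eval()

image = torch.rand(1, 3, 1024, 1024, device=device)  # preprocessed input tensor

# Autocast runs convolutions/matmuls in fp16 on GPU; it is disabled on CPU here.
with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.float16,
                                     enabled=(device == 'cuda')):
    pred = model(image).sigmoid()
print(pred.shape, pred.dtype)
```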

ZeVicTech commented 2 months ago

Okay, I'll share it when I succeed at that task. I have one more question: the current model uses a vision model as the backbone. I wonder exactly what role the backbone plays in BiRefNet. If I use a lightweight vision model, can I speed up inference?

Thank you.

ZhengPeng7 commented 2 months ago

Yeah, of course, that's a trade-off between accuracy and inference speed. You can take a look at this issue, where these kinds of things have been discussed, and I provided a lightweight version with Swin-Tiny as the backbone.
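As a rough illustration of that trade-off (not BiRefNet code): the backbone is the feature extractor whose outputs the decoder refines, so it dominates the compute. Comparing a tiny and a base Swin from timm (model names and input size are assumptions) shows the gap:

```python
import time
import timm
import torch

x = torch.randn(1, 3, 224, 224)

for name in ['swin_tiny_patch4_window7_224', 'swin_base_patch4_window7_224']:
    model = timm.create_model(name, pretrained=False).eval()
    n_params = sum(p.numel() for p in model.parameters()) / 1e6
    with torch.no_grad():
        t0 = time.time()
        feats = model.forward_features(x)  # dense features a decoder would consume
        dt = time.time() - t0
    print(f'{name}: {n_params:.1f}M params, {dt:.3f}s, features {tuple(feats.shape)}')
```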

ZhengPeng7 commented 1 month ago

Hi, @ZeVicTech, I've updated a BiRefNet for general segmentation with swin_v1_tiny as the backbone for edge devices. The well-trained model has been uploaded to my Google Drive. Check the README for access to the weights, performance, predicted maps, and training log in the corresponding folder (exp-xxx). The performance is a bit lower than the large version with the same massive training, but still good (HCE↓: 1152 -> 1182 on DIS-VD). Feel free to download and use them.

Meanwhile, check the update in inference.py. Setting torch.set_float32_matmul_precision to 'high' can increase the FPS of the large version on an A100 from 5 to 12 with ~0 performance drop (because I set it to 'high' during training).
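For reference, a minimal sketch of where that switch goes, assuming a typical inference script (the flag itself is standard PyTorch):

```python
import torch

# 'high' allows TensorFloat-32 matmuls on Ampere GPUs (e.g. A100): a large speed-up
# with negligible accuracy change, matching the precision used during training.
torch.set_float32_matmul_precision('high')

# ... then load BiRefNet and run inference as usual.
```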

Good luck with the smaller and faster BiRefNet with ~0 degradation.