ZhengPeng7 / BiRefNet

[CAAI AIR'24] Bilateral Reference for High-Resolution Dichotomous Image Segmentation
https://www.birefnet.top
MIT License

Training a smaller edge model #11

Closed: rishabh063 closed this 3 months ago

rishabh063 commented 3 months ago

Hey @ZhengPeng7, hope you are well. I'm thinking of training a smaller model for a downstream task. Any insights on where I can shrink the parameter count?

ZhengPeng7 commented 3 months ago

Hi Rishabh, I was also thinking about this kind of thing to help with deployment on edge devices. However, I have to pay out of my own pocket to rent GPUs for training. Now that the question has come up, I'll try to do it in the next few days. The easiest way to train a smaller model is, of course, to use Swin-tiny as the backbone network.

ZhengPeng7 commented 3 months ago

Hi, I've added more lightweight backbone choices to the code. You can now try swin_v1_tiny, swin_v1_small, pvt_v2_b0, pvt_v2_b1, or pvt_v2_b2 as the backbone. However, training models with them is temporarily suspended because of the situation I mentioned above. BTW, I also tested the FPS of these models and of the official one from the paper on a single A100-40G GPU with batch size == 1 and resolution == 1024x1024:

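+-----------------------------------+-------+
| Model                             | FPS   |
+-----------------------------------+-------+
| BiRefNet-Swin_v1_large (official) | 5.05  |
| BiRefNet-Swin_v1_tiny             | 20.20 |
| BiRefNet-Swin_v1_small            | 12.80 |
| BiRefNet-PVT_v2_b0                | 26.12 |
| BiRefNet-PVT_v2_b1                | 20.70 |
| BiRefNet-PVT_v2_b2                | 15.11 |
+-----------------------------------+-------+
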
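
To make the FPS figures above easier to compare, they can be converted into per-image latency and into a speedup relative to the official Swin_v1_large model. The following is a minimal sketch that uses only the six numbers reported in this comment; the variable names are illustrative and not part of the BiRefNet code.

```python
# FPS values copied from the comment above
# (single A100-40G, batch size 1, resolution 1024x1024).
fps = {
    "BiRefNet-Swin_v1_large (official)": 5.05,
    "BiRefNet-Swin_v1_tiny": 20.20,
    "BiRefNet-Swin_v1_small": 12.80,
    "BiRefNet-PVT_v2_b0": 26.12,
    "BiRefNet-PVT_v2_b1": 20.70,
    "BiRefNet-PVT_v2_b2": 15.11,
}

baseline = fps["BiRefNet-Swin_v1_large (official)"]

# Latency in milliseconds per image is the reciprocal of FPS.
latency_ms = {name: 1000.0 / value for name, value in fps.items()}

# Speedup is measured against the official large backbone.
speedup = {name: value / baseline for name, value in fps.items()}

for name in fps:
    print(f"{name:<34} {latency_ms[name]:7.1f} ms/image  {speedup[name]:4.2f}x")
```

On these numbers, swin_v1_tiny is about 4x faster than the official model and pvt_v2_b0 is the fastest of the six. Throughput says nothing about accuracy, so these ratios only rank speed.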
rishabh063 commented 3 months ago

Can you also share the parameter counts? Also, could you do a quality check when using 8-bit or 16-bit precision?

ZhengPeng7 commented 3 months ago

> Can you also share the parameter counts? Also, could you do a quality check when using 8-bit or 16-bit precision?

You can try [tensor/model].half() for this. I'm afraid I haven't had time for this kind of thing lately. Good luck!
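
As a rough illustration of the `.half()` suggestion, the sketch below casts a small model and its input to float16 and compares the parameter memory before and after. The two-layer network is a hypothetical stand-in, not BiRefNet, and the helper `param_bytes` is introduced only for this example. Note that `.half()` covers 16-bit only; 8-bit inference would need a separate quantization workflow, which is not shown here.

```python
import torch
from torch import nn

# Hypothetical stand-in for a segmentation network (NOT BiRefNet itself).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=3, padding=1),
)


def param_bytes(module: nn.Module) -> int:
    """Total memory taken by the module's parameters, in bytes."""
    return sum(p.numel() * p.element_size() for p in module.parameters())


fp32_bytes = param_bytes(model)

# .half() casts all floating-point parameters and buffers to float16
# in place and returns the module.
model = model.half()
fp16_bytes = param_bytes(model)

# The input must use the same dtype as the model.
x = torch.rand(1, 3, 64, 64).half()

print(f"fp32: {fp32_bytes} bytes, fp16: {fp16_bytes} bytes")

# Half-precision inference is normally run on a GPU, so the forward pass
# is only attempted when CUDA is available.
if torch.cuda.is_available():
    with torch.no_grad():
        y = model.cuda()(x.cuda())
    print(y.dtype)
```

Whether segmentation quality survives the cast has to be checked on the evaluation sets; the thread reports no such measurement.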

rishabh063 commented 3 months ago

Can you share the parameter count for the swin tiny backbone?

ZhengPeng7 commented 3 months ago

You can set bb in config.py to swin_v1_tiny, create the model with from models.baseline import BiRefNet; model = BiRefNet(), and then use existing tools to count the number of parameters.
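
One common way to count parameters in PyTorch is to sum `numel()` over `model.parameters()`. The sketch below does this on a small stand-in model so that it runs without the BiRefNet repository; `count_parameters` is a hypothetical helper name, and the same function could be applied to the `BiRefNet()` instance described above.

```python
from torch import nn


def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Sum the number of elements over all (or only trainable) parameters."""
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


# Stand-in model used only to keep the example self-contained.
# With the repository available, `toy` could be replaced by BiRefNet().
toy = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))

total = count_parameters(toy)
trainable = count_parameters(toy, trainable_only=True)
print(f"total: {total}, trainable: {trainable}, ~{total / 1e6:.6f} M")
```

The stand-in has 10*5 + 5 weights and biases in the first layer and 5*1 + 1 in the second, so the count can be checked by hand.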

ZhengPeng7 commented 3 months ago

Hi, I finished the 500-epoch training of a BiRefNet with swin_v1_tiny as the backbone. Even after 500 epochs, the model still hasn't converged in terms of HCE. You can find the related files for this tiny model here. Its weights are in the same folder as the other weights. The performance of BiRefNet-DIS_ep500-swin_v1_tiny is shown below:

+---------+----------------------------+-------+-----------+------+----------+--------+------+-------+--------+-------+-------+
& Dataset & Method                     & maxFm & wFmeasure & MAE  & Smeasure & meanEm & HCE  & maxEm & meanFm & adpEm & adpFm &
+---------+----------------------------+-------+-----------+------+----------+--------+------+-------+--------+-------+-------+
& DIS-TE1 & DIS-bb_swin_v1_tiny--ep500 & .804  & .756      & .053 & .845     & .876   & 132  & .884  & .792   & .870  & .770  &
& DIS-TE2 & DIS-bb_swin_v1_tiny--ep500 & .853  & .812      & .048 & .870     & .909   & 326  & .917  & .842   & .901  & .824  &
& DIS-TE3 & DIS-bb_swin_v1_tiny--ep500 & .886  & .847      & .039 & .891     & .935   & 696  & .944  & .873   & .935  & .866  &
& DIS-TE4 & DIS-bb_swin_v1_tiny--ep500 & .874  & .827      & .050 & .875     & .923   & 3091 & .938  & .854   & .930  & .848  &
& DIS-VD  & DIS-bb_swin_v1_tiny--ep500 & .850  & .806      & .048 & .867     & .912   & 1166 & .922  & .837   & .908  & .821  &
& DIS-TEs & DIS-bb_swin_v1_tiny--ep500 & .854  & .810      & .048 & .870     & .911   & 1061 & .921  & .840   & .909  & .827  &
+---------+----------------------------+-------+-----------+------+----------+--------+------+-------+--------+-------+-------+
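
For reading the table, it may help to see how the DIS-TEs row relates to the four test subsets. The sketch below assumes that DIS-TEs aggregates DIS-TE1 to DIS-TE4 and that the subsets are equally sized, so an unweighted mean is appropriate; that assumption is not stated in the thread. It uses only values copied from the table, and because those values are already rounded, agreement in the last digit is not guaranteed for every column.

```python
# (maxFm, Smeasure, HCE) per test subset, copied from the table above.
subsets = {
    "DIS-TE1": (0.804, 0.845, 132),
    "DIS-TE2": (0.853, 0.870, 326),
    "DIS-TE3": (0.886, 0.891, 696),
    "DIS-TE4": (0.874, 0.875, 3091),
}

# Values reported in the DIS-TEs row for the same three metrics.
reported_tes = {"maxFm": 0.854, "Smeasure": 0.870, "HCE": 1061}

n = len(subsets)
mean_maxfm = sum(row[0] for row in subsets.values()) / n
mean_smeasure = sum(row[1] for row in subsets.values()) / n
mean_hce = sum(row[2] for row in subsets.values()) / n

print(f"maxFm    mean={mean_maxfm:.3f}  reported={reported_tes['maxFm']:.3f}")
print(f"Smeasure mean={mean_smeasure:.3f}  reported={reported_tes['Smeasure']:.3f}")
print(f"HCE      mean={mean_hce:.0f}  reported={reported_tes['HCE']}")
```

For these three columns the unweighted means match the reported DIS-TEs values after rounding, which is consistent with the assumption. The HCE average is dominated by DIS-TE4.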

rishabh063 commented 3 months ago

Ohh great thanks 😃

ZhengPeng7 commented 3 months ago

Feel free to reopen it if there are any relevant questions.

ZhengPeng7 commented 1 month ago

Hi @rishabh063, I've added a BiRefNet for general segmentation with swin_v1_tiny as the backbone, intended for edge devices. The well-trained model has been uploaded to my Google Drive. Check the README for access to the weights, performance, predicted maps, and training log in the corresponding folder (exp-xxx). The performance is a bit lower than the large version with the same massive training, but still good (HCE↓: 1152 -> 1182 on DIS-VD). Feel free to download and use them.

Meanwhile, check the update in inference.py. Setting torch.set_float32_matmul_precision to 'high' can increase the FPS of the large version on an A100 from 5 to 12 with ~0 performance degradation (because I also set it to 'high' during training).
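
A minimal sketch of the setting mentioned above is given here. `torch.set_float32_matmul_precision` and `torch.get_float32_matmul_precision` are standard PyTorch calls; where exactly inference.py places the call is not shown in this thread, so the placement here (once, before any inference) is an assumption. With 'high', PyTorch may use faster reduced-precision internals such as TensorFloat32 for float32 matrix multiplications on GPUs that support them.

```python
import torch

# PyTorch's default is "highest" (full float32 internal precision).
previous = torch.get_float32_matmul_precision()

# Allow faster, slightly lower-precision float32 matmuls where the
# hardware supports them. This is a global setting; call it once
# before running the model.
torch.set_float32_matmul_precision("high")
current = torch.get_float32_matmul_precision()

print(f"float32 matmul precision: {previous} -> {current}")

# The calling code does not change: tensors stay float32.
a = torch.rand(4, 8)
b = torch.rand(8, 2)
c = a @ b
print(c.dtype, tuple(c.shape))
```

The 5 to 12 FPS figure quoted above is the author's measurement on an A100; the gain on other hardware may differ, and on CPUs the setting is not expected to change speed.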

Good luck with the smaller and faster BiRefNet with ~0 degradation.

rishabh063 commented 1 month ago

Hey, thanks, I'll try it out in a while.

Swamped with work.

BTW, any idea on that SOD bounding box?

ZhengPeng7 commented 1 month ago

That's still in progress, too much training in the queue. I'll reply to you here once it's done.

rishabh063 commented 1 month ago

What backbone are you thinking of using there?

I will also try training that. I'm new to model training and would love your help, @ZhengPeng7 (forgot to mention this yesterday).

ZhengPeng7 commented 1 month ago

I'll try both swin_v1_large and swin_v1_tiny for the best performance and efficiency, respectively. If still too heavy, I'll try pvt_v2_b1.

rishabh063 commented 1 month ago

Okay, I will be using u2netp without any backbone. Let's see if I can get anything.