xypu98 / CWSAM


How does cls_upscaling help predict masks with categories? #15

Closed: zkjisj closed this issue 3 months ago

zkjisj commented 3 months ago

In the code, the added cls_upscaling branch produces cls_upscaled_embedding with shape b (c num_classes // 8) h w. The shape of hyper_in is b 4 (c//8), so why can hyper_in @ cls_upscaled_embedding be executed?

If it can be executed, this will produce masks with shape b (c // 8) (num_classes h w), which are then viewed as b num_mask_tokens (5?) -1 h w. How does that work without num_classes? I'm still confused about why the final masks carry categories, i.e., what role num_classes plays.
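To make the shape bookkeeping in the question concrete, here is a minimal NumPy sketch. It assumes, without confirmation from the CWSAM source, that the channel axis of cls_upscaled_embedding is laid out as c//8 groups of num_classes maps and that the tensor is reshaped to b, c//8, num_classes*h*w before the product, so the contracted dimension is c//8 on both sides. The sizes and the helper predict_class_masks are hypothetical. Under that assumption the product has shape b, 4, num_classes*h*w, and num_classes reappears once the result is viewed as b, 4, num_classes, h, w.

```python
import numpy as np

# Hypothetical sizes for illustration; CWSAM's real values may differ.
b, c8, num_classes, h, w = 2, 32, 5, 16, 16
num_mask_tokens = 4

rng = np.random.default_rng(0)
hyper_in = rng.standard_normal((b, num_mask_tokens, c8))
# Assumed layout: channel axis = (c//8) outer, num_classes inner.
cls_upscaled_embedding = rng.standard_normal((b, c8 * num_classes, h, w))


def predict_class_masks(hyper_in, emb, num_classes):
    """Hypothetical helper: batched product of mask tokens with class maps."""
    bsz, tokens, c8 = hyper_in.shape
    _, _, h, w = emb.shape
    # Fold num_classes into the flattened spatial axis so that the
    # contracted dimension of the matmul is c//8 on both operands.
    emb_flat = emb.reshape(bsz, c8, num_classes * h * w)
    masks = hyper_in @ emb_flat  # (b, tokens, num_classes*h*w)
    # Unfold: one h x w map per (mask token, class) pair.
    return masks.reshape(bsz, tokens, num_classes, h, w)


masks = predict_class_masks(hyper_in, cls_upscaled_embedding, num_classes)
print(masks.shape)  # (2, 4, 5, 16, 16)
```

Keeping a single mask token, e.g. masks[:, 0], then leaves a b, num_classes, h, w array: one mask whose channel dimension equals the number of classes.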

xypu98 commented 3 months ago

We only output one mask, keep the channel dimension of the output mask equal to the number of classes, and compute the loss against the corresponding categories. It works.
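One common way to compute the loss "between the corresponding categories" for a single mask with num_classes channels is per-pixel softmax cross-entropy over the channel axis. The NumPy sketch below is a stand-in under that assumption; CWSAM's actual loss may differ (for example, it may combine several terms), and multiclass_mask_loss is a hypothetical name.

```python
import numpy as np


def multiclass_mask_loss(logits, labels):
    """Hypothetical helper: mean per-pixel cross-entropy.

    logits: (b, num_classes, h, w) raw scores, one channel per class.
    labels: (b, h, w) integer class index for each pixel.
    """
    # Numerically stable log-softmax over the class (channel) axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the ground-truth class at every pixel.
    picked = np.take_along_axis(log_probs, labels[:, None, :, :], axis=1)
    return float(-picked.mean())


logits = np.zeros((1, 5, 4, 4))  # uniform scores over 5 classes
labels = np.zeros((1, 4, 4), dtype=np.int64)
print(multiclass_mask_loss(logits, labels))  # ~ log(5) ~ 1.609
```

At inference, logits.argmax(axis=1) turns the same num_classes-channel mask into a per-pixel class map.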

zkjisj commented 3 months ago

Thanks a lot for your explanation! I really appreciate your help, and I understand the concept much better now.