xuebinqin / DIS

This is the repo for our new project Highly Accurate Dichotomous Image Segmentation
Apache License 2.0

Details about training the GT encoder #71

Closed: AeroDEmi closed this issue 1 year ago

AeroDEmi commented 1 year ago

As I understand it, we first train the GT encoder (the target is the GT, but I still don't know what its input is), and then we train the U2Net with our images as input and the masks as targets. My questions are:

Thank you

tldrafael commented 1 year ago

I'm answering based only on my personal experience with this project.

The GTEncoder's input is the label (the GT mask). Its architecture is very similar to ISNet's (ISNet and U2Net are basically the same), but it has no "decoder" stage and uses fewer channels in its conv layers. The GTEncoder yields feature maps that are used to regularize the feature maps ISNet generates during training.
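To make the idea concrete, here is a minimal PyTorch sketch of feature-level regularization with a GT encoder. TinyGTEncoder and feature_sync_loss are hypothetical stand-ins, not the repo's ISNetGTEncoder or its loss functions. The sketch assumes the GT encoder is kept frozen while the segmentation network trains, and it uses MSE as the feature distance, which is also an assumption. The mask goes into the encoder, and the encoder's multi-scale features become targets for the segmenter's intermediate features, on top of the usual BCE loss against the mask.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyGTEncoder(nn.Module):
    """Hypothetical stand-in for a GT encoder: mask in, multi-scale features out.

    Encoder stages only (no decoder) and few channels, as described above.
    """

    def __init__(self, ch: int = 16):
        super().__init__()
        self.stage1 = nn.Sequential(nn.Conv2d(1, ch, 3, padding=1), nn.ReLU())
        self.stage2 = nn.Sequential(nn.Conv2d(ch, ch, 3, stride=2, padding=1), nn.ReLU())

    def forward(self, mask: torch.Tensor) -> list:
        f1 = self.stage1(mask)  # full resolution
        f2 = self.stage2(f1)    # half resolution
        return [f1, f2]


def feature_sync_loss(pred_logits, gt_mask, net_feats, gt_feats):
    """BCE on the prediction plus MSE pulling each network feature map toward
    the matching GT-encoder feature map (hypothetical helper)."""
    seg = F.binary_cross_entropy_with_logits(pred_logits, gt_mask)
    sync = sum(F.mse_loss(nf, gf.detach()) for nf, gf in zip(net_feats, gt_feats))
    return seg + sync


torch.manual_seed(0)
gt_mask = (torch.rand(2, 1, 32, 32) > 0.5).float()

# The GT encoder only sees the label and stays frozen during segmenter training.
gt_encoder = TinyGTEncoder().eval()
for p in gt_encoder.parameters():
    p.requires_grad_(False)
with torch.no_grad():
    gt_feats = gt_encoder(gt_mask)

# Random tensors standing in for the segmenter's intermediate features and output.
net_feats = [torch.randn_like(f, requires_grad=True) for f in gt_feats]
pred_logits = torch.randn(2, 1, 32, 32, requires_grad=True)

loss = feature_sync_loss(pred_logits, gt_mask, net_feats, gt_feats)
loss.backward()  # gradients reach the segmenter-side tensors only
print(f"loss = {loss.item():.4f}")
```

In a real network the segmenter's feature maps would have to match the GT encoder's shapes for the MSE term to apply, which is why the stand-ins here are built with randn_like.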

It is only used if you set this parameter to True:

https://github.com/xuebinqin/DIS/blob/f3837183a33dab157c636e0124e091acd6da9dd1/IS-Net/train_valid_inference_main.py#L683

So, to answer your questions:
1) No, you need to enable it in the parameters. If you do, the encoder is trained on the fly before ISNet training starts (this is a quick process).
2) Again, the regularized loss is only used if you set hypar["interm_sup"] to True.
3) No, it is not used in inference mode.
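The answers above boil down to a small piece of control flow. The sketch below is hypothetical: prepare_gt_encoder and the stub callbacks are not functions from the repo, and it assumes, as the parameter name suggests, that a non-empty hypar["gt_encoder_model"] names a saved GT-encoder checkpoint to restore instead of training a new one. Inference is left out entirely, since the GT encoder plays no part there.

```python
from typing import Callable, Optional


def prepare_gt_encoder(hypar: dict,
                       train_gt_encoder: Callable[[], object],
                       load_gt_encoder: Callable[[str], object]) -> Optional[object]:
    """Hypothetical helper: decide whether and how to obtain a GT encoder
    before ISNet training starts."""
    # With interm_sup off there is no GT encoder and no regularized loss.
    if not hypar.get("interm_sup", False):
        return None
    # Assumption: a non-empty gt_encoder_model is a saved GT-encoder checkpoint.
    checkpoint = hypar.get("gt_encoder_model", "")
    if checkpoint:
        return load_gt_encoder(checkpoint)
    # Otherwise the GT encoder is trained on the fly from the masks.
    return train_gt_encoder()


def train_stub() -> str:
    return "freshly trained GT encoder"


def load_stub(path: str) -> str:
    return f"GT encoder restored from {path}"


off = prepare_gt_encoder({"interm_sup": False}, train_stub, load_stub)
fresh = prepare_gt_encoder({"interm_sup": True, "gt_encoder_model": ""}, train_stub, load_stub)
restored = prepare_gt_encoder(
    {"interm_sup": True, "gt_encoder_model": "my_gt_encoder.pth"},  # hypothetical file name
    train_stub, load_stub)
print(off, fresh, restored, sep="\n")
```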

AeroDEmi commented 1 year ago

Thank you for your answer. I will try to experiment with it, but the cache step takes some time.

kabbas570 commented 9 months ago

Hello @tldrafael and @AeroDEmi, thanks for the explanation. Could you please guide me on this? 1- When I set hypar["interm_sup"] = True, I get the size mismatch error below. I tried both isnet.pth and isnet-general-use.pth here:

hypar["gt_encoder_model"] = "isnet.pth"
or
hypar["gt_encoder_model"] = "isnet-general-use.pth"

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ISNetGTEncoder:

    size mismatch for stage6.rebnconv3d.bn_s1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for stage6.rebnconv3d.bn_s1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for stage6.rebnconv3d.bn_s1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for stage6.rebnconv2d.conv_s1.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 128, 3, 3]).
    size mismatch for stage6.rebnconv2d.conv_s1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for stage6.rebnconv2d.bn_s1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for stage6.rebnconv2d.bn_s1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for stage6.rebnconv2d.bn_s1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for stage6.rebnconv2d.bn_s1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for stage6.rebnconv1d.conv_s1.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 128, 3, 3]).
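The shapes in this log (256/512 channels in the checkpoint versus 64/128 in ISNetGTEncoder) are what PyTorch reports when a state_dict saved from a wider network is loaded into a narrower one. That suggests isnet.pth and isnet-general-use.pth hold weights for a different, wider network than ISNetGTEncoder. The minimal sketch below reproduces the same kind of error with two hypothetical blocks that differ only in channel width, and adds a hypothetical shape_mismatches helper for listing mismatched tensors without triggering the exception.

```python
import torch.nn as nn


def make_block(mid_ch: int) -> nn.Sequential:
    # Hypothetical conv + batch-norm block; only the channel width varies.
    return nn.Sequential(
        nn.Conv2d(mid_ch * 2, mid_ch, 3, padding=1),
        nn.BatchNorm2d(mid_ch),
    )


def shape_mismatches(model: nn.Module, state_dict: dict) -> list:
    """Return (key, checkpoint_shape, model_shape) for every shared key whose shapes differ."""
    own = model.state_dict()
    return [
        (key, tuple(tensor.shape), tuple(own[key].shape))
        for key, tensor in state_dict.items()
        if key in own and own[key].shape != tensor.shape
    ]


wide = make_block(256)   # plays the role of the wider checkpoint
narrow = make_block(64)  # plays the role of the narrower target model

load_error = None
try:
    narrow.load_state_dict(wide.state_dict())
except RuntimeError as err:  # same error type as in the log above
    load_error = str(err)
    print(load_error.splitlines()[0])

for key, ckpt_shape, model_shape in shape_mismatches(narrow, wide.state_dict()):
    print(f"{key}: checkpoint {ckpt_shape} vs model {model_shape}")
```

With a real file, the same helper can be pointed at torch.load(path, map_location="cpu") to see which architecture a .pth checkpoint matches before passing it to hypar["gt_encoder_model"].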