juhongm999 / hsnet

Official PyTorch Implementation of Hypercorrelation Squeeze for Few-Shot Segmentation, ICCV 2021

Training on data from different domain #16

Closed desmania closed 2 years ago

desmania commented 2 years ago

Hi,

Thank you for making your work publicly available.

Currently, I am trying to train on satellite imagery (5 cm resolution) but it seems like the model can't converge.

The dataset consists of 6 classes, of which class 0 is background. I have set the following in my custom Dataset:

 self.nfolds = 5
 self.nclass = 5

Can you confirm whether my thought process is correct: each fold will have 4 base classes in the training set and 1 held-out class in the validation set?
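For reference, a minimal sketch of the fold split described above, assuming `nclass = 5` and `nfolds = 5` (the helper name `fold_split` is illustrative, not the repo's code):

```python
def fold_split(nclass=5, nfolds=5, fold=0):
    """Return (train_classes, val_classes) for one cross-validation fold.

    With 5 classes and 5 folds, each fold holds out exactly 1 class for
    validation and trains on the remaining 4.
    """
    per_fold = nclass // nfolds  # 1 class per fold in this setup
    val_classes = list(range(fold * per_fold, (fold + 1) * per_fold))
    train_classes = [c for c in range(nclass) if c not in val_classes]
    return train_classes, val_classes


# e.g. fold 0 trains on classes [1, 2, 3, 4] and validates on [0]
print(fold_split(fold=0))
```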

As for the training, I thought about unfreezing the gradients because the domain data is so different from the pretraining data.

Do you have any ideas on how I could obtain feasible results on my data?

Cheers!

juhongm999 commented 2 years ago

To follow the same problem setup as ours, the training and validation datasets should be disjoint with respect to categories (and I think this condition is satisfied). As you said, I would try fine-tuning the backbone network during training because of the domain difference. How are the results with the gradients of the backbone network unfrozen during training?
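A minimal sketch of unfreezing the backbone while giving it a smaller learning rate than the rest of the model, which is a common precaution when fine-tuning pretrained weights on a new domain. The attribute names (`backbone`, `head`) and the tiny `nn.Linear` stand-ins are illustrative, not the repo's actual modules:

```python
import torch
import torch.nn as nn

# Stand-in model: `backbone` plays the role of the pretrained feature
# extractor, `head` the rest of the network (names are assumptions).
model = nn.Module()
model.backbone = nn.Linear(8, 8)
model.head = nn.Linear(8, 2)

# Unfreeze the backbone so its weights receive gradients.
for p in model.backbone.parameters():
    p.requires_grad = True

# Per-parameter-group learning rates: gentler updates for pretrained weights.
optimizer = torch.optim.SGD([
    {'params': model.backbone.parameters(), 'lr': 1e-4},
    {'params': model.head.parameters(),     'lr': 1e-3},
], momentum=0.9)
```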

desmania commented 2 years ago

Thank you for your response!

The progress is not great. With the same learning rate and optimizer (Adam, lr 1e-3), the model diverges (mIoU goes to 0 very quickly).

With the SGD optimizer and the same learning rate, the model converges slightly on the training data (mIoU ~20) while the validation mIoU fluctuates heavily between 0 and 20. This suggests to me that the model is overfitting...

In your work you have not included data augmentation. Would you suggest adding it to training, or any other form of regularization?

Also, could you tell me where I should change the code to enable n-shot training?

Thanks in advance.

juhongm999 commented 2 years ago

I would recommend the augmentation methods widely adopted in FSS: random scale, crop, rotation, flip, and blur. To achieve n-shot training, you should change the dataloader code accordingly so that the model takes a batch of 1 query image and n support images/masks to predict a single query mask (and is optimized by cross-entropy as before).