RViMLab / ECCV2022-multi-scale-and-cross-scale-contrastive-segmentation

Implementation of the methods described in "Multi-scale and Cross-scale Contrastive Learning for Semantic Segmentation", ECCV 2022

Q: Performance problem #4

Closed CK-Sung closed 1 year ago

CK-Sung commented 1 year ago

Hi, thank you for providing this nice work.

However, I got mIoU 73 @40K iterations for hrnet_contrastive_CTS with hrnetv2_w48_imagenet_pretrained.

Should I train without the pretrained model to get 81.5 @40K iterations? In addition, Table 9 lists a weight decay of 0.00005, but in hrnet_contrastive_CTS.json the wd is 0.0005. Which value should I use for weight decay?

Thank you

TheoPis commented 1 year ago

Thanks for letting me know about this issue.

I noticed that in the config "hrnet_contrastive_CTS.json" the entry "graph.pretrained" is (by my mistake) set to false, which would mean it does not use the ImageNet-pretrained checkpoint for HRNet. Did you change this to true? In general, all paper results require ImageNet-pretrained checkpoints. The gap you mention (73 vs 81.5) is very large, so I assume this may be the issue. If not, could you please share the exact command-line arguments and config file you used to obtain this result so that I can investigate further.

CK-Sung commented 1 year ago

Yes, I had already changed it to true, and I also assumed it would use the ImageNet-pretrained checkpoint. Here are the config and arguments I used for training. I set the batch size to 8; as far as I remember, the ablation study in the paper is done with batch size 8 for Cityscapes.

Is there any problem other than the weight decay? Also, could you provide the checkpoints that you trained?

Arguments: `-d 0 1 -p -u theo -c config/CITYSCAPES/hrnet_contrastive_CTS.json -bs 8 -ws 3`

```json
{
  "name": "hrn",
  "mode": "training",
  "manager": "HRNet",
  "graph": {
    "model": "HRNet",
    "backbone": "hrnet48",
    "sync_bn": true,
    "out_stride": 4,
    "pretrained": true,
    "align_corners": true,
    "ms_projector": {"mlp": [[1, -1, 1]], "scales": 4, "d": 256, "use_bn": true, "before_context": true}
  },
  "load_last": true,
  "tta": true,
  "tta_scales": [0.75, 1.25, 1.5, 1.75, 2],
  "run_final_val": false,
  "data": {
    "num_workers": 8,
    "dataset": "CITYSCAPES",
    "use_relabeled": false,
    "blacklist": false,
    "experiment": 1,
    "split": ["train", "val"],
    "transforms": ["flip", "random_scale", "RandomCropImgLbl", "colorjitter", "torchvision_normalise"],
    "transform_values": {"crop_shape": [512, 1024], "crop_class_max_ratio": 0.75, "scale_range": [0.5, 2]},
    "transforms_val": ["torchvision_normalise"],
    "transform_values_val": {},
    "batch_size": 12
  },
  "loss": {
    "name": "LossWrapper",
    "label_scaling_mode": "nn",
    "dominant_mode": "all",
    "temperature": 0.1,
    "cross_scale_contrast": true,
    "weights": [1, 0.7, 0.4, 0.1],
    "scales": 4,
    "losses": {"CrossEntropyLoss": 1, "DenseContrastiveLossV2_ms": 0.1},
    "losses___": {"CrossEntropyLoss": 1},
    "min_views_per_class": 5,
    "max_views_per_class": 2500,
    "max_features_total": 10000
  },
  "train": {
    "learning_rate": 0.01,
    "lr_fct": "polynomial",
    "optim": "SGD",
    "lr_batchwise": true,
    "epochs": 484,
    "momentum": 0.9,
    "wd": 0.0005
  },
  "valid_batch_size": 1,
  "max_valid_imgs": 2,
  "valid_freq": 20,
  "log_every_n_epochs": 20,
  "cuda": true,
  "gpu_device": 0,
  "parallel": false,
  "seed": 0
}
```

Thank you so much!

TheoPis commented 1 year ago

I looked into this a bit and I suggest you check the following:

1) I assume you have "hrnetv2_w48_imagenet_pretrained.pt" in the root folder and that it is selected and loaded at this point in the code: https://github.com/RViMLab/ECCV2022-multi-scale-and-cross-scale-contrastive-segmentation/blob/97e84981fbba479dd32da3ad24de1ef8152b4070/models/HRNet.py#L669. Please ensure this actually happens in your run, as without it the results will be much worse.

2) Regarding weight decay, I found that in fact I've been using 0.0005 (as you do above) in all cases with HRNet.

3) Another important thing to consider is that the results @40K refer to using the learning rate schedule for 40K iterations. If you use a batch size of 8 and set epochs to 484 (as in the above config), the learning rate schedule extends to 484 * 2975 / 8 ≈ 180K iterations (where 2975 is the number of Cityscapes train-set images). That probably means that @40K iterations your run has not decayed the learning rate enough (i.e. you are at 40K out of ~180K learning rate decay steps). So you need to adjust the epochs so that the maximum number of steps equals 40K (i.e. epochs = 108); see the sketch below. Let me know if that is clear.
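A minimal sketch of the arithmetic behind point 3 (the `epochs_for_target` helper is only for illustration, not part of this repo; 2975 is the Cityscapes train-set size):

```python
import math

def epochs_for_target(target_iters: int, num_train_images: int = 2975, batch_size: int = 8) -> int:
    """Epochs needed so a batch-wise LR schedule ends at ~target_iters optimizer steps."""
    iters_per_epoch = num_train_images / batch_size   # 2975 / 8 ≈ 372 steps per epoch
    return math.ceil(target_iters / iters_per_epoch)  # round up to cover the full schedule

print(epochs_for_target(40_000))  # -> 108, i.e. set "epochs": 108 in the config
print(484 * 2975 // 8)            # -> 179987, the ~180K-step schedule implied by "epochs": 484
```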

Regarding the checkpoints: I will do my best to commit several checkpoints (including HRNet) and a few clarifications in the codebase in the coming week to make it a bit more user-friendly. I apologize for any inconvenience and lack of clarity in the code.

CK-Sung commented 1 year ago

Thank you so much! Point 1 does happen in my run, and I will use 0.0005 as well. I had missed point 3; I will modify the epochs and try again. Thank you again for the kind answers!

CK-Sung commented 1 year ago

Hi, I just noticed that I got 80.03 as the best mIoU (much better than before) with 1 A100 GPU (batch size 8, epochs 108). I will try to tune the hyper-parameters.

TheoPis commented 1 year ago

Hi, I do not believe the remaining gap is due to hyperparameter tuning (after all, I did not perform any tuning of weight decay, learning rate, etc.). I believe it has to do with the (subtle) fact that the contrastive loss hyperparameters (especially max_features_total and max_views_per_class, which control the sampling of anchor features for computing the contrastive loss terms) are GPU-specific, i.e. they are applied to the chunk of the batch processed by each GPU. In my experiments I always used 4 x 24GB GPUs, so for the batch_size=8 runs these hyperparameters refer to a per-GPU batch size of 2. If you use a single GPU, you may have to adjust these settings: in your setup you sample a maximum of 10K anchor points from every 8 images, whereas in my runs I sample at most 10K anchor points from every 2 images; one possible way to rescale them is sketched below.
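As a rough illustration of how one might adjust these caps for a single-GPU run (this is an assumption, not a tuned recipe: the `rescale_sampling_caps` helper and the linear scaling rule are hypothetical, and larger caps will also increase memory use), scale the per-GPU sampling caps by the per-GPU batch-size ratio:

```python
import json

def rescale_sampling_caps(config_path: str, ref_gpus: int = 4, ref_batch_size: int = 8,
                          gpus: int = 1, batch_size: int = 8) -> dict:
    """Scale per-GPU anchor-sampling caps so anchors-per-image stays close to the 4-GPU reference.

    Assumption: max_features_total and max_views_per_class are applied per GPU,
    so they should grow roughly with the per-GPU batch size.
    """
    with open(config_path) as f:
        cfg = json.load(f)

    ref_per_gpu = ref_batch_size / ref_gpus  # 2 images per GPU in the original runs
    new_per_gpu = batch_size / gpus          # e.g. 8 images per GPU on a single A100
    scale = new_per_gpu / ref_per_gpu        # 4x in this example

    loss = cfg["loss"]
    loss["max_features_total"] = int(loss["max_features_total"] * scale)    # 10000 -> 40000
    loss["max_views_per_class"] = int(loss["max_views_per_class"] * scale)  # 2500  -> 10000
    return cfg

# Example: single-GPU run with batch size 8
# cfg = rescale_sampling_caps("config/CITYSCAPES/hrnet_contrastive_CTS.json", gpus=1, batch_size=8)
```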

CK-Sung commented 1 year ago

Yes, you are right, I totally missed that part! Thank you so much! I believe it will work fine once I modify that.