I think there is a bug in `compute_miou`: we should not count pixels without labels in the union, which makes the IoU of a class much lower than it should be.
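To make it concrete, here is a minimal sketch of the masking I have in mind (assuming, as in this repo, that unlabeled pixels carry the label -1; `class_iou` is just an illustrative helper, not code from the repo):

```python
import torch

def class_iou(pred_label, gt_label, cls):
    # pred_label, gt_label: (H, W) long tensors; unlabeled pixels in gt_label are -1
    labeled = gt_label >= 0                    # keep only annotated pixels
    pred_mask = (pred_label == cls) & labeled  # drop predictions falling on unlabeled pixels
    true_mask = gt_label == cls                # already excludes unlabeled pixels
    union = (pred_mask | true_mask).float().sum()
    intersection = (pred_mask & true_mask).float().sum()
    return None if union == 0 else (intersection / union).item()
```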
I tested one of my checkpoints, computing mIoU with and without the labeled mask, and the mIoU went from 0.3463 to 0.3904.
After modifying my `compute_miou`, the decline in mIoU during training still exists:
Hi,
For the mIoU formula, I agree that not including invalid pixel predictions is more accurate. Thanks for pointing it out. I will fix it soon in this repo.
However, from a mathematical point of view, fixing the invalid pixel predictions gives a constant ratio of improvement regardless of which class is predicted, so the overall performance follows the same trend.
From the visualization provided, the validation loss starts to increase very early in training. Could you confirm whether this trend still exists with another optimiser or learning rate? Also, could you confirm whether it exists on the CityScapes dataset? From my experience, CityScapes is far easier, and the performance is significantly more stable compared to NYUv2.
Thanks.
Thanks for your reply. I will run the experiments you suggested and post the results as soon as possible.
Also, I suggest checking out the paper here: https://arxiv.org/pdf/2004.13379.pdf (Supplementary A and B) for how they select their hyper-parameters and data augmentation, which might be helpful.
Note that the result for MTAN in the linked paper is not accurate; we have managed to sort it out, but have not yet updated the paper. With DeepLabv3, the author provided me with an updated result of 44.7 mean IoU and 72.4 overall accuracy.
Hi,
I did these experiments on NYUv2: first, I changed the optimizer to SGD with 0.9 momentum, which gave slightly lower accuracy but no decrease in mIoU; then I added data augmentation consistent with the linked paper and trained with Adam and the same hyperparameters again, and the decrease in mIoU did not happen either. What's more, I got a much higher mIoU and accuracy: about 45 mIoU and 77 pixel accuracy.
So I attribute the decrease to overfitting with Adam on NYUv2, which is quite a small dataset without augmentation.
Thanks a lot for the suggestions provided.
That is great. Glad you solved the problem. I highly appreciate the feedback.
For the fixed version of the mIoU computation, could you confirm whether the following modification is correct and consistent with your own version:
def compute_miou(x_pred, x_output):
    # x_pred: (B, C, H, W) logits; x_output: (B, H, W) labels with -1 for unlabeled pixels
    _, x_pred_label = torch.max(x_pred, dim=1)
    x_output_label = x_output
    batch_size = x_pred.size(0)
    class_nb = x_pred.size(1)
    device = x_pred.device
    for i in range(batch_size):
        true_class = 0
        first_switch = True
        invalid_mask = (x_output[i] >= 0).float()  # 1 for labeled pixels, 0 for unlabeled
        for j in range(class_nb):
            pred_mask = torch.eq(x_pred_label[i], j * torch.ones(x_pred_label[i].shape).long().to(device))
            true_mask = torch.eq(x_output_label[i], j * torch.ones(x_output_label[i].shape).long().to(device))
            mask_comb = pred_mask.float() + true_mask.float()
            union = torch.sum((mask_comb > 0).float() * invalid_mask)  # remove non-defined pixel predictions
            intsec = torch.sum((mask_comb > 1).float())
            if union == 0:
                continue
            if first_switch:
                class_prob = intsec / union
                first_switch = False
            else:
                class_prob = intsec / union + class_prob
            true_class += 1
        if i == 0:
            batch_avg = class_prob / true_class
        else:
            batch_avg = class_prob / true_class + batch_avg
    return batch_avg / batch_size
Also, could you kindly share your modification of the NYUv2 data augmentation, so I can update mine to make it more standardised?
Thanks.
Yep, I think your fixed version of mIoU is correct, and below is my own version. I carefully inspected them and believe they are equivalent.
def compute_miou(x_pred, x_output):
    x_pred = x_pred.detach()
    _, x_pred_label = torch.max(x_pred, dim=1)
    x_output_label = x_output
    batch_size = x_pred.size(0)
    class_nb = x_pred.size(1)
    FloatTensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor
    temp_ones = torch.ones(x_pred_label[0].shape).type(LongTensor)
    # pixels labeled -1 are excluded from both the union and the intersection
    labeled_mask = torch.eq(x_output_label, -1 * temp_ones).logical_not()
    for i in range(batch_size):
        true_class = 0
        first_switch = True
        for j in range(class_nb):
            pred_mask = torch.eq(x_pred_label[i], j * temp_ones) * labeled_mask[i]
            true_mask = torch.eq(x_output_label[i], j * temp_ones) * labeled_mask[i]
            mask_comb = pred_mask.type(FloatTensor) + true_mask.type(FloatTensor)
            union = torch.sum((mask_comb > 0).type(FloatTensor))
            intsec = torch.sum((mask_comb > 1).type(FloatTensor))
            if union == 0:
                continue
            if first_switch:
                class_prob = intsec / union
                first_switch = False
            else:
                class_prob = intsec / union + class_prob
            true_class += 1
        if i == 0:
            batch_avg = class_prob / true_class
        else:
            batch_avg = class_prob / true_class + batch_avg
    return batch_avg / batch_size
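For what it's worth, a quick sanity check on random tensors (illustrative shapes only) can be used to confirm that both versions give the same value:

```python
import torch

torch.manual_seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# illustrative shapes: batch of 2, 13 classes, 32x32 resolution; -1 marks unlabeled pixels
x_pred = torch.randn(2, 13, 32, 32, device=device)
x_output = torch.randint(-1, 13, (2, 32, 32), device=device)

print(compute_miou(x_pred, x_output))  # both versions should print the same mIoU
```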
And below is my NYUv2 dataset with data augmentation, based on your code:
from torch.utils.data.dataset import Dataset

import os
import torch
import fnmatch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
import random

from . import config
# import config


class RandomScaleCrop(object):
    def __init__(self, scale=[1.0, 1.2, 1.5]):
        self.scale = scale

    def __call__(self, img, label, depth, normal):
        height, width = img.shape[-2:]
        sc = self.scale[random.randint(0, len(self.scale) - 1)]
        h, w = int(height / sc), int(width / sc)
        i = random.randint(0, height - h)
        j = random.randint(0, width - w)
        # crop a 1/sc-sized region and resize it back to the full resolution
        img_ = F.interpolate(img[None, :, i:i+h, j:j+w], size=(height, width), mode='bilinear').squeeze(0)
        label_ = F.interpolate(label[None, None, i:i+h, j:j+w], size=(height, width), mode='nearest').squeeze(0).squeeze(0)
        depth_ = F.interpolate(depth[None, :, i:i+h, j:j+w], size=(height, width), mode='nearest').squeeze(0)
        depth_ = depth_ / sc  # depth calibration: zooming in by sc makes the scene appear sc times closer
        normal_ = F.interpolate(normal[None, :, i:i+h, j:j+w], size=(height, width), mode='nearest').squeeze(0)
        return img_, label_, depth_, normal_


class NYUv2(Dataset):
    """
    This file is directly modified from https://pytorch.org/docs/stable/torchvision/datasets.html
    """
    def __init__(self, root, train=True, transform=None, random_flip=False):
        self.train = train
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.random_flip = random_flip

        # read the data file
        if train:
            self.data_path = root + '/train'
        else:
            self.data_path = root + '/val'

        # calculate data length
        self.data_len = len(fnmatch.filter(os.listdir(self.data_path + '/image'), '*.npy'))

    """
    Data Augmentation:
    [1] PAD-Net: Multi-Tasks Guided Prediction-and-Distillation Network for Simultaneous Depth Estimation and Scene Parsing
    [2] Revisiting Multi-Task Learning in the Deep Learning Era
    """
    def __getitem__(self, index):
        # load image, semantic label, depth, and surface normal maps as float tensors
        image = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/image/{:d}.npy'.format(index)), -1, 0)).type(torch.FloatTensor)
        semantic = torch.from_numpy(np.load(self.data_path + '/label/{:d}.npy'.format(index))).type(torch.FloatTensor)
        depth = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/depth/{:d}.npy'.format(index)), -1, 0)).type(torch.FloatTensor)
        normal = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/normal/{:d}.npy'.format(index)), -1, 0)).type(torch.FloatTensor)

        if self.transform is not None:
            image, semantic, depth, normal = self.transform(image, semantic, depth, normal)
        if self.random_flip and torch.rand(1) < 0.5:
            # horizontal flip; the x-component of the surface normal changes sign
            image = torch.flip(image, dims=[2])
            semantic = torch.flip(semantic, dims=[1])
            depth = torch.flip(depth, dims=[2])
            normal = torch.flip(normal, dims=[2])
            normal[0, :, :] = -normal[0, :, :]

        return image, semantic, depth, normal

    def __len__(self):
        return self.data_len


def get_datasets(root=config.nyuv2_data_dir):
    nyuv2_train_set = NYUv2(root=root, train=True, transform=RandomScaleCrop(), random_flip=True)
    nyuv2_test_set = NYUv2(root=root, train=False)
    return nyuv2_train_set, nyuv2_test_set
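In case it is useful, this is roughly how I wrap it in data loaders (the root path here is a hypothetical placeholder; `get_datasets` above does the same thing via the config module):

```python
from torch.utils.data import DataLoader

# hypothetical path; point this at your pre-processed NYUv2 directory
nyuv2_train_set = NYUv2(root='data/nyuv2', train=True, transform=RandomScaleCrop(), random_flip=True)
nyuv2_test_set = NYUv2(root='data/nyuv2', train=False)

train_loader = DataLoader(nyuv2_train_set, batch_size=8, shuffle=True, num_workers=2)
test_loader = DataLoader(nyuv2_test_set, batch_size=8, shuffle=False, num_workers=2)

image, semantic, depth, normal = next(iter(train_loader))  # image: (8, 3, H, W) float tensor
```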
That's really helpful! Thanks.
Sorry, one more question: which specific set of hyper-parameters did you choose to obtain 0.45 mIoU on NYUv2 with data augmentation?
Sorry, I missed an important detail: I got 0.45 mIoU with the model `resnet_split.py`, which is simply modified from `resnet_mtan.py`, using a shared backbone and three `DeepLabHead` modules for semantic, depth, and normal prediction, respectively.
My hyperparameters are:
optimizer = optim.Adam(deeplab.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
with `total_epoch = 200` and `batch_size = 8`.
Moreover, Adam with or without weight decay gives comparable accuracy, so you can also drop `weight_decay=1e-5`.
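Put together, my training setup looks roughly like this (only a sketch: `deeplab` is whichever model is being trained, `train_loader` is the NYUv2 loader above, and `multi_task_loss` is a hypothetical placeholder for the weighted sum of the three task losses):

```python
import torch.optim as optim

optimizer = optim.Adam(deeplab.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

total_epoch = 200  # batch_size = 8 is set on the DataLoader
for epoch in range(total_epoch):
    for image, semantic, depth, normal in train_loader:
        optimizer.zero_grad()
        loss = multi_task_loss(deeplab(image), semantic, depth, normal)  # hypothetical helper
        loss.backward()
        optimizer.step()
    scheduler.step()  # halve the learning rate every 100 epochs
```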
I will run `resnet_mtan.py` with the same hyperparameters and report the results as soon as possible.
Ok. That's great. So basically we can re-use the hyper-parameter choice from the original paper.
Looking forward to your next update.
If it's okay with you, I will add your latest results in a table (I think split and MTAN are sufficient) in the main readme file, so other people can more easily compare against the latest results from a better backbone architecture.
No problem, it's my pleasure. I will report precise numbers after all experiments have been done.
Hi, I'm back.
First of all, I updated my NYUv2 augmentation code because I missed a detail proposed in PAD-Net: "the depth values are divided by the ratio". I call this operation depth calibration. The only modification needed is adding `depth_ = depth_ / sc` after `depth_` is interpolated. Experiments show that this gives better or comparable results on all tasks.
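To see why the division is needed: cropping a 1/sc-sized window and resizing it back to full resolution makes everything look sc times closer, so the stored depths should shrink by the same factor to stay geometrically consistent. A toy check of the operation (made-up values):

```python
import torch
import torch.nn.functional as F

sc = 1.5
depth = torch.full((1, 1, 6, 6), 3.0)      # toy depth map: every pixel 3.0 m away
h = w = int(6 / sc)                        # crop a 4x4 window
zoomed = F.interpolate(depth[:, :, :h, :w], size=(6, 6), mode='nearest')
calibrated = zoomed / sc                   # the scene appears 1.5x closer -> 2.0 m
print(calibrated[0, 0, 0, 0].item())       # 2.0
```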
Here are my results (averaged over the last 10 epochs):
Model | Data augmentation | Depth calibration | mIoU | Pix Acc | Abs Err | Rel Err | Mean (°) | Median (°) | <11.25° | <22.5° | <30° |
---|---|---|---|---|---|---|---|---|---|---|---|
resnet_split | True | False | 0.4493 | 0.7699 | 0.3881 | 0.1598 | 22.2689 | 15.8192 | 0.3727 | 0.6418 | 0.7484 |
resnet_split | True | True | 0.4527 | 0.7724 | 0.3511 | 0.1429 | 22.0554 | 15.6377 | 0.3769 | 0.6460 | 0.7518 |
resnet_mtan | True | False | 0.4562 | 0.7732 | 0.3816 | 0.1558 | 22.0135 | 15.4649 | 0.3818 | 0.6492 | 0.7534 |
resnet_mtan | True | True | 0.4577 | 0.7770 | 0.3500 | 0.1420 | 21.9099 | 15.4780 | 0.3815 | 0.6489 | 0.7539 |
The number of parameters is 71,888,721 for `resnet_split` and 92,347,921 for `resnet_mtan`.
Hyper-parameters:
total_epoch = 200
batch_size = 8
optimizer = optim.Adam(deeplab.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
Other details:
It seems that MTAN does not have a significant advantage over the MTL baseline when applied to a stronger backbone (perhaps due to a lack of careful hyperparameter tuning or task weighting?). Actually, many (encoder-based) approaches have been shown in https://arxiv.org/pdf/2004.13379.pdf to yield only marginal improvements when applied to a strong backbone. I wonder what your opinion is on that? Thanks in advance.
Thanks for the work and detailed evaluation.
First of all, I agree that depth calibration is a better augmentation method, and it is actually more precise mathematically. But I am wondering why you chose nearest-neighbour interpolation for the depth and normal maps, considering they are both continuous? Maybe it would be preferable to perform bilinear interpolation in order to get a smoother result?
Regarding the marginal improvements, I suppose it is possibly due to the fact that the strong generalisation in the pre-trained weights outweighs the design bias for multi-task learning. In other words, as long as we have a giant network pre-trained on a large dataset, the design of the multi-task network does not contribute much to the final performance, since the improvements mainly come from the pre-trained features rather than the proposed design bias.
That is one of the major reasons I did not apply pre-training in the original paper, and I strongly believe that multi-task learning performance based on pre-trained networks would not accurately reflect the true performance.
This argument can also be reflected in the following:
Sorry to take up so much of your personal time. But if you are really interested in this... you may perform the following experiments to validate my arguments. (This could be a good paper idea though.)
Hope that helps.
Thanks a lot for all the patient replies. I think this issue can be closed. Looking forward to further communication!
Regarding nearest-neighbour interpolation for the depth and normal maps, my concern is the unusual values that would be interpolated between labeled and unlabeled pixels. I am not sure whether they would affect the training results.
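To make the concern concrete, here is a toy example (assuming invalid depth pixels are stored as 0): nearest keeps only the original values, while bilinear creates in-between values around the hole.

```python
import torch
import torch.nn.functional as F

# toy 2x2 depth map with one invalid pixel stored as 0
depth = torch.tensor([[[[3.0, 0.0],
                        [3.0, 3.0]]]])

up_nearest = F.interpolate(depth, size=(4, 4), mode='nearest')
up_bilinear = F.interpolate(depth, size=(4, 4), mode='bilinear', align_corners=False)

print(up_nearest.unique())   # tensor([0., 3.]) -- only the original values
print(up_bilinear.unique())  # also contains blended values between 0 and 3
```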
I also ran the experiment in which depth and normal are interpolated using `bilinear`:
Model | Data augmentation | Depth calibration | mIoU | Pix Acc | Abs Err | Rel Err | Mean (°) | Median (°) | <11.25° | <22.5° | <30° |
---|---|---|---|---|---|---|---|---|---|---|---|
resnet_split-bilinear | True | True | 0.4537 | 0.7739 | 0.3525 | 0.1436 | 21.9065 | 15.4712 | 0.3811 | 0.6498 | 0.7547 |
The results are comparable with `nearest`. So I agree that using `bilinear` rather than `nearest` is the more reasonable option.
Hi, I ran your official MTAN-DeepLabv3 and found that mIoU decreases during training while pixAcc stays steady and the loss on the validation set increases consistently.
I also modified it into `resnet_split.py`, `resnet_single.py`, and `resnet_cs.py`. The same trends can be found. When training the MTL models, the depth and surface normal estimation benchmarks increase/decrease normally.
I also ran `model_segnet_*.py`. The decrease in mIoU during training also exists but is much smaller and barely perceptible; I think that is because of the low mIoU and accuracy with SegNet. I'm new to NYUv2 and segmentation, so I am not sure whether this is caused by overfitting.
Some implementation details: I use the same code as `model_segnet_*.py` for computing mIoU and pixAcc, except for a slight modification to accelerate training. The optimizer and lr_scheduler are the same as in `model_segnet_*.py`. For more details, my code can be found in `resnet_mtan.py` and `train_utils.py`.
Here are some records (mtan, split, single, and mixed, respectively):