Igal20 opened this issue 5 years ago
I think this can help you: https://github.com/facebookresearch/maskrcnn-benchmark/issues/144
@wjp0408 Not so much; issue 144 handles negative samples simply by filtering them out and not including them in training at all. I do not want to exclude them; I want to use them during training.
Hi
One image should contain two things: a positive element and a negative element. This is always the case for detection, unless your bounding boxes cover the full image, in which case you might be looking for image classification rather than object detection?
@fmassa Thanks for the answer. If I understand correctly, "One image should contain two things: a positive element and a negative element" is true for evaluating the net, not for training.
I'll elaborate a little bit more on what I aim to do.
For example, I want to train Mask R-CNN to detect and segment giraffes, so I provide 1000 sample images and masks of giraffes. But I also have 10 images that merely look like the animal [for example, a blanket with a giraffe pattern], and I want to include them in the training process, even though the masks of those images will obviously be blank.
Thus I want to be able to add images without a segmentation mask.
Thanks
I see, thanks for the explanation!
It is possible to support your use case, but you'll need to adapt a few things in the code. The first issue you'll hit is this one: https://github.com/facebookresearch/maskrcnn-benchmark/issues/31 You can make the Matcher return a tensor of size N, but you'll probably face a few other issues down the road that will need to be addressed.
Let me know if you get stuck on a particular problem. Once the code is working, it might be interesting to see what it took to make it work, so that we could potentially send a PR adding support for this to the codebase.
Thanks, I'll check it.
I have the same objective of being able to train on negative examples (images where there is nothing to detect).
My goal is both one of learning from negative data (the fact that there is nothing should at least help train the RPN) and calibration (most of my images have nothing, so training only on positive data tends to lead to overconfident predictions).
It would be great if this were supported, although I can see how it would make many things harder (e.g., sampling, dimensions, etc.).
I'd be willing to support this use case, so if you have trouble getting the code to work for those images, just let me know and I'll help you with that so that we can have a PR for it. There might be a few tricky things to tune, though.
I have the same question.
@LU4E I still have the same answer as before :-)
I have some code for handling empty images, but I will not get around to disentangling it from the rest of my code before Dec 19.
@Iwontbecreative it would be a very nice addition!
Could you briefly summarize the things you did that made it work nicely?
This is probably not in the right form and not directly tested on the new codebase, but these should be the main changes I made.
@Iwontbecreative thanks a lot for the patch!
Do you know if applying this patch gives better results than just removing all the images that do not have any labels in them?
On my own dataset it seemed to help, but we have ~100 times as much unlabelled data (and it comes from a slightly different distribution, with our evaluation not being mainly about bounding-box prediction).
I have not done any experiments on COCO, sorry. Maybe it could be a configuration flag?
@Iwontbecreative Hi! Thanks for your awesome code for negative-sample training. But after changing the code as you suggested in https://pastebin.com/6xXEWtvg, I cannot train my model on multiple GPUs anymore (it always gets stuck at the beginning of training, as shown below). I would like to ask whether you have ever met such a problem, and @fmassa, could you please give any suggestions? (I'm sure I could use multiple GPUs before, so it should not be a driver problem.)
2019-01-17 20:12:01,811 maskrcnn_benchmark.trainer INFO: Start training
@BobZhangHT If you changed your PyTorch version in between, that might potentially be the reason. Apart from that, I don't know.
@fmassa Sincerely, thanks for your reply. Actually, I didn't change the PyTorch version. I will try to figure it out and let you know if it gets resolved.
I have not run into this issue, so I'm not sure what is going wrong, sorry...
Are the code changes available to see somewhere on GitHub? Unfortunately, I cannot access pastebin.com.
Sorry, my codebase was quite different from the commit at the time, so I cherry-picked the changes rather than making a proper merge request.
diff --git a/maskrcnn_benchmark/modeling/matcher.py b/maskrcnn_benchmark/modeling/matcher.py
index 35ec5f1..074734c 100644
--- a/maskrcnn_benchmark/modeling/matcher.py
+++ b/maskrcnn_benchmark/modeling/matcher.py
@@ -53,9 +53,9 @@ class Matcher(object):
         if match_quality_matrix.numel() == 0:
             # empty targets or proposals not supported during training
             if match_quality_matrix.shape[0] == 0:
-                raise ValueError(
-                    "No ground-truth boxes available for one of the images "
-                    "during training")
+                length = match_quality_matrix.shape[-1]
+                device = match_quality_matrix.device
+                return torch.ones(length, dtype=torch.int64, device=device) * (-1)
             else:
                 raise ValueError(
                     "No proposal boxes available for one of the images "
diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py
index 2c21f6c..bed4bbc 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py
@@ -38,7 +38,11 @@ class FastRCNNLossComputation(object):
         # NB: need to clamp the indices because we can have a single
         # GT in the image, and matched_idxs can be -2, which goes
         # out of bounds
-        matched_targets = target[matched_idxs.clamp(min=0)]
+        if target.bbox.shape[0]:
+            matched_targets = target[matched_idxs.clamp(min=0)]
+        else:
+            target.add_field("labels", matched_idxs.clamp(min=1, max=1))
+            matched_targets = target
         matched_targets.add_field("matched_idxs", matched_idxs)
         return matched_targets
@@ -63,9 +67,13 @@ class FastRCNNLossComputation(object):
             labels_per_image[ignore_inds] = -1  # -1 is ignored by sampler

             # compute regression targets
-            regression_targets_per_image = self.box_coder.encode(
-                matched_targets.bbox, proposals_per_image.bbox
-            )
+            if not matched_targets.bbox.shape[0]:
+                zeros = torch.zeros_like(labels_per_image, dtype=torch.float)
+                regression_targets_per_image = torch.stack((zeros, zeros, zeros, zeros), dim=1)
+            else:
+                regression_targets_per_image = self.box_coder.encode(
+                    matched_targets.bbox, proposals_per_image.bbox
+                )

             labels.append(labels_per_image)
             regression_targets.append(regression_targets_per_image)
diff --git a/maskrcnn_benchmark/modeling/rpn/loss.py b/maskrcnn_benchmark/modeling/rpn/loss.py
index 0847231..a3dae25 100644
--- a/maskrcnn_benchmark/modeling/rpn/loss.py
+++ b/maskrcnn_benchmark/modeling/rpn/loss.py
@@ -43,7 +43,10 @@ class RPNLossComputation(object):
         # NB: need to clamp the indices because we can have a single
         # GT in the image, and matched_idxs can be -2, which goes
         # out of bounds
-        matched_targets = target[matched_idxs.clamp(min=0)]
+        if matched_idxs.clamp(min=0).sum() > 0:
+            matched_targets = target[matched_idxs.clamp(min=0)]
+        else:
+            matched_targets = target
         matched_targets.add_field("matched_idxs", matched_idxs)
         return matched_targets
@@ -55,6 +58,7 @@ class RPNLossComputation(object):
                 anchors_per_image, targets_per_image
             )
+
             matched_idxs = matched_targets.get_field("matched_idxs")
             labels_per_image = matched_idxs >= 0
             labels_per_image = labels_per_image.to(dtype=torch.float32)
@@ -66,9 +70,13 @@ class RPNLossComputation(object):
             labels_per_image[inds_to_discard] = -1

             # compute regression targets
-            regression_targets_per_image = self.box_coder.encode(
-                matched_targets.bbox, anchors_per_image.bbox
-            )
+            if not matched_targets.bbox.shape[0]:
+                zeros = torch.zeros_like(labels_per_image)
+                regression_targets_per_image = torch.stack((zeros, zeros, zeros, zeros), dim=1)
+            else:
+                regression_targets_per_image = self.box_coder.encode(
+                    matched_targets.bbox, anchors_per_image.bbox
+                )

             labels.append(labels_per_image)
             regression_targets.append(regression_targets_per_image)
@@ -95,6 +103,8 @@ class RPNLossComputation(object):
         sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
+
+
         objectness_flattened = []
         box_regression_flattened = []
         # for each feature level, permute the outputs to make them be in the
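For a quick sanity check of the patched Matcher above, something like this should hold (a sketch only; it assumes the patch is applied and maskrcnn_benchmark is importable):

import torch
from maskrcnn_benchmark.modeling.matcher import Matcher

matcher = Matcher(0.7, 0.3, allow_low_quality_matches=True)
# 0 ground-truth boxes and 5 proposals give a 0 x 5 match-quality matrix
empty_quality = torch.empty(0, 5)
matches = matcher(empty_quality)
assert matches.shape == (5,)    # one entry per proposal
assert (matches == -1).all()    # every proposal is treated as background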
@Iwontbecreative Thank you! It seems to be a problem in my code.
@Iwontbecreative I take it you updated the BoxList class as well to allow for empty box arrays?
@Iwontbecreative apologies for not fully debugging before commenting. I just needed to make sure the box input is np.zeros((0, 4)). The code is running without error now.
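For reference, a minimal sketch of constructing such an empty target (image_size is a placeholder value here, not from the thread):

import numpy as np
import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList

image_size = (640, 480)  # placeholder (width, height) of the image
# a negative image gets an empty (0, 4) box array and an empty labels field
boxes = torch.as_tensor(np.zeros((0, 4)), dtype=torch.float32)
target = BoxList(boxes, image_size, mode="xyxy")
target.add_field("labels", torch.empty(0, dtype=torch.int64))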
Hi, I am facing the same thing, as some CityScapes images have no positive boxes. Is there a merged fix for this problem? I think it makes sense for the model to return a loss of around 0 if it makes no false positives on negative images, and a non-zero loss otherwise. Thanks!
@IssamLaradji can you check the fix from https://github.com/facebookresearch/maskrcnn-benchmark/issues/169#issuecomment-455204465 and see if it gives reasonable results in your case?
My solution in the CityScapes case would be to remove those images during training, but if the solution from @Iwontbecreative works and does not harm performance, then it might be worth considering merging it.
Thanks for your reply. It might be easier to just ignore images that have no annotations. Did you use something that looks roughly like this in the data loader class?
def __getitem__(self, index):
    annList = get_annotations(index)
    if len(annList) == 0:
        # resample a random image instead of returning the empty one
        return self.__getitem__(np.random.choice(self.__len__()))
    else:
        ...everything else...
I will give #169 (comment) a try as well, thanks for the reference!
@IssamLaradji I do it in the initialization of the dataset https://github.com/facebookresearch/maskrcnn-benchmark/blob/5f2a8263a1a0f2f5f0137042cd4ba64efcb6859a/maskrcnn_benchmark/data/datasets/coco.py#L18-L23
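For reference, the referenced lines look roughly like this (paraphrased from that commit, so treat it as a sketch):

# filter images without detection annotations
if remove_images_without_annotations:
    self.ids = [
        img_id
        for img_id in self.ids
        if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0
    ]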
thanks a lot @fmassa !
I have successfully used the method above to include negative images (and updated it for the mask head as well). I don't have lots of data, and the negative images are important.
thanks to @fmassa for a great repo.
@jgbos great, thanks for the information!
I'll keep a note of it. It seems that merging a patch with this fix would make things work out fine in many cases, so it would be a great addition!
Hi @fmassa, @Iwontbecreative,
based on this comment and @jgbos's success, it seems that submitting a PR with this patch would be quite useful.
@AdanMora, could you help make a unit test for this patch?
@botcs I agree; if someone sends a PR with unit tests, I'll be more than happy to merge it!
@botcs @fmassa Excellent, I'll take a look and try to write some unit tests.
Could you also add the mAP metric for the negative samples when testing?
@qq237942920 nyeh. It sounds like a simple task, but I have a hunch that you cannot compute mAP for negative samples: when evaluating AP for a single class, you iterate over Precision and Recall values at different IoU thresholds. What you could do is compute F1 or Precision/Recall with the following image-level classification:
[Positive: has prediction] [TruePositive: has annotation and has prediction]...
(or one extremely ugly trick would be annotating the background as an object category and making a Dataset with that approach, but I never said this)
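A sketch of that image-level Precision/Recall idea (the function and argument names are illustrative, not from any library):

def image_level_prec_rec(has_prediction, has_annotation):
    # has_prediction, has_annotation: lists of bools, one entry per image
    tp = sum(p and a for p, a in zip(has_prediction, has_annotation))
    fp = sum(p and not a for p, a in zip(has_prediction, has_annotation))
    fn = sum(a and not p for p, a in zip(has_prediction, has_annotation))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall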
Thank you for your patient reply! Maybe I didn't express myself well in my last comment. By computing the mAP for the negative samples I mean: when my dataset has both positive and negative samples, will coco_eval evaluate the mAP including false positives on the negative samples (images with no ground truth)? The trick you suggested at the end may be a way to compute that.
@Iwontbecreative the code matched_idxs.clamp(min=0).sum() > 0 in maskrcnn_benchmark/modeling/rpn/loss.py is not correct: when there is only one foreground target, matched_idxs.clamp(min=0).sum() can equal zero, because the matched target index is 0.
How should the category_id be assigned for negative-sample images? Should it be 0?
@Iwontbecreative @fmassa Can you write up a detailed process for how to train the model on negative samples? Thanks!
@cltdevelop all the information that I know about I've already written here, but if @Iwontbecreative has time and wants to send a PR adding more detailed information, I'd be happy to merge it.
I don't have a lot of time to commit to this right now, unfortunately. I also have never experimented on traditional datasets, and the one I did my experiments on is not public, so it'd be tricky to give concrete results/numbers. I hope to have more time at some point to assess when this can be useful on traditional datasets and to come up with a proper PR.
@fmassa @Iwontbecreative I have prepared a synthetic DebugDataset which I use for unit tests. It just puts a random number of white boxes on a black plane. It could also be used, as is, for sanity checking your implementation.
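Roughly, such a dataset can look like this (a simplified sketch; the real DebugDataset may differ in its details):

import random
import torch
from torch.utils.data import Dataset
from maskrcnn_benchmark.structures.bounding_box import BoxList

class DebugDataset(Dataset):
    # black images with a random number of white boxes; 0 boxes = negative image
    def __init__(self, size=100, image_size=128, max_boxes=3):
        self.size = size
        self.image_size = image_size
        self.max_boxes = max_boxes

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        img = torch.zeros(3, self.image_size, self.image_size)
        boxes = []
        for _ in range(random.randint(0, self.max_boxes)):
            x1 = random.randint(0, self.image_size - 32)
            y1 = random.randint(0, self.image_size - 32)
            x2, y2 = x1 + 16, y1 + 16
            img[:, y1:y2, x1:x2] = 1.0  # paint the white box
            boxes.append([x1, y1, x2, y2])
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        target = BoxList(boxes, (self.image_size, self.image_size), mode="xyxy")
        target.add_field("labels", torch.ones(len(boxes), dtype=torch.int64))
        return img, target, idx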
@tpys You are right!!
It should be something like this:
if len(target):
    matched_targets = target[matched_idxs.clamp(min=0)]
else:
    matched_targets = target
Has anyone run into trouble with the fix above using the latest master? I'm getting an error that a number of the proposals provided to the mask head now have zero width or height, which causes the code to crash on this line, where a divide-by-zero error happens.
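One possible guard (my assumption, not a confirmed fix from this thread) would be to drop degenerate proposals before they reach the mask head:

def remove_degenerate_boxes(boxlist, min_size=1):
    # boxlist: a BoxList in "xyxy" mode; keep only boxes with positive extent
    xmin, ymin, xmax, ymax = boxlist.bbox.unbind(dim=1)
    keep = ((xmax - xmin) >= min_size) & ((ymax - ymin) >= min_size)
    return boxlist[keep]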
Thank you for the details!
In the most recent state of the code, is it sufficient to initialize the BoxList with zeros and set the label to 0?
For example, in maskrcnn-benchmark/maskrcnn_benchmark/data/datasets/coco.py, adding the following check:
def __getitem__(self, idx):
    img, anno = super(COCODataset, self).__getitem__(idx)
    if not anno:
        boxes = [[0, 0, 0, 0]]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
        target.add_field("labels", torch.tensor([0]))
    else:
        # filter crowd annotations
        # TODO might be better to add an extra field
        anno = [obj for obj in anno if obj["iscrowd"] == 0]
        boxes = [obj["bbox"] for obj in anno]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
        classes = [obj["category_id"] for obj in anno]
        classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
        classes = torch.tensor(classes)
        target.add_field("labels", classes)
    target = target.clip_to_image(remove_empty=False)
    if self._transforms is not None:
        img, target = self._transforms(img, target)
    return img, target, idx
For @Igal20's question: for a single object class against background, i.e. binary segmentation, having no ground-truth mask or segmentation contour in a few training samples will simply hurt your training accuracy, since semantic segmentation is performed pixel-wise against the target. If your target contains no positive pixels, then during backpropagation the weights of the learned features will be pushed back toward unlearned weights. For example, if you have equal amounts of positive targets (ground truth with segmentation contours) and negative targets (ground truth without contours), the features learned on the positive targets will be undone each time you backpropagate on the negative targets. So in classification it may actually be needed, but in segmentation the only way is to avoid it.
❓ Questions and Help
Hello, my data has only two classes: background = 0 and object = 1. While training, it is necessary for me to present negative samples to the net, i.e., images without an object, just background. In that case I don't have a segmentation contour. How do I add those images to the training?
I use COCO-style annotations for the images, saved in JSON format. Thanks in advance.
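For illustration, with COCO-style annotations a negative image is simply listed in "images" with no corresponding entries in "annotations" (a sketch with made-up file names, written as a Python dict mirroring the JSON):

coco_dict = {
    "images": [
        {"id": 1, "file_name": "object.jpg", "width": 640, "height": 480},
        {"id": 2, "file_name": "background_only.jpg", "width": 640, "height": 480},
    ],
    "annotations": [
        # only image 1 has ground truth; image 2 is a pure negative sample
        {"id": 1, "image_id": 1, "category_id": 1,
         "bbox": [100, 100, 200, 150], "area": 30000, "iscrowd": 0,
         "segmentation": [[100, 100, 300, 100, 300, 250, 100, 250]]},
    ],
    "categories": [{"id": 1, "name": "object"}],
}

Note that the stock COCODataset would filter image 2 out unless remove_images_without_annotations is disabled (see the dataset-initialization link above).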