facebookresearch / detr

End-to-End Object Detection with Transformers
Apache License 2.0

Need suggestions to improve custom training using detectron2 wrapper #192

Open · theoutsider8060 opened 4 years ago

theoutsider8060 commented 4 years ago

Hi, thank you for this great repo on DETR. I was trying to fine-tune DETR on my custom dataset using the detectron2 wrapper and I need some suggestions regarding that.

My dataset consists of around 14k training images and 5k validation images with a maximum of 30 objects per image. There is only a single class.

These are the changes I made to the default configuration:

  1. I changed num_classes to 1.
  2. I only have 1 GPU with about 12 GB of memory, so I reduced the batch size from 64 to 4 and divided the default learning rate of 1e-4 by 16 (see the configuration sketch after this list).
  3. My dataset is in standard COCO format (for each object, XY(topleft)_WH). I did not change the annotations and fed them directly to the network.
  4. In the default configuration file inside d2/configs, the pretrained weights were set to ImageNet weights. Instead, I initialized my custom single-class DETR with the COCO-pretrained DETR weights from https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth, using the converter.py script to convert them into the format of the detectron2 wrapper. After doing this, all parameters were loaded except class_embed.weight and class_embed.bias, which were skipped as shown below.

    WARNING [07/31 15:27:46 fvcore.common.checkpoint]: Skip loading parameter 'detr.class_embed.weight' to the model due to incompatible shapes: (81, 256) in the checkpoint but (2, 256) in the model! You might want to double check if this is expected.
    WARNING [07/31 15:27:46 fvcore.common.checkpoint]: Skip loading parameter 'detr.class_embed.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (2,) in the model! You might want to double check if this is expected.
    [07/31 15:27:46 fvcore.common.checkpoint]: Some model parameters or buffers are not found in the checkpoint: detr.class_embed.{weight, bias} criterion.empty_weight

  5. I trained the network for 70 epochs; the total training loss curve and the validation mAP curve are attached below.
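A minimal sketch of how these overrides map onto the wrapper's config keys (key names from d2/detr/config.py, import layout as in d2/train_net.py; the values are the ones described above, not repo defaults, and the checkpoint path is hypothetical):

    # Sketch: fine-tuning overrides for the d2 wrapper (values from this issue).
    from detectron2.config import get_cfg
    from detr import add_detr_config  # as imported in d2/train_net.py

    cfg = get_cfg()
    add_detr_config(cfg)
    cfg.merge_from_file("d2/configs/detr_256_6_6_torchvision.yaml")

    cfg.MODEL.DETR.NUM_CLASSES = 1        # single foreground class
    cfg.SOLVER.IMS_PER_BATCH = 4          # 64 -> 4 to fit one 12 GB GPU
    cfg.SOLVER.BASE_LR = 1e-4 / 16        # scale LR linearly with batch size
    cfg.MODEL.WEIGHTS = "converted_detr_r50.pth"  # hypothetical converter output

The class_embed warnings above are expected: those tensors have a different shape in a single-class model, so they are reinitialized rather than loaded from the checkpoint.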

Now, it seems my mAP is stuck around 35-38 and not improving much, and the training loss curve has plateaued as well. What can I do to make my model perform better? These are some questions I have off the top of my head:

  1. I tried changing num_queries to 60 and trained for 50 epochs, and my mAP only reached 6. Should I pursue this further? I do not have enough computational resources to train for 300+ epochs.

  2. Should I change the eos_coef? Will that help? (A sketch of where eos_coef enters the loss follows this list.)

  3. I am currently using the default augmentation. Should I add extra augmentation to help the model generalize better?

  4. Is there something specific I need to change because I have only a single class? The loss values are still quite high, around 20-30, after 70 epochs.
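For context on question 2: eos_coef down-weights the "no-object" class in the classification loss. A minimal sketch of how it enters SetCriterion (paraphrasing models/detr.py in this repo, not the full criterion):

    # eos_coef builds the class-weight vector used by the CE loss in SetCriterion.
    import torch
    import torch.nn.functional as F

    num_classes, eos_coef = 1, 0.1              # 0.1 is the repo default
    empty_weight = torch.ones(num_classes + 1)  # foreground classes + "no-object"
    empty_weight[-1] = eos_coef                 # down-weight the "no-object" class
    # In the criterion, roughly:
    # loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, empty_weight)

Raising eos_coef penalizes queries that predict an object where the target is "no-object" more heavily; lowering it tolerates more background predictions.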

Any help would be greatly appreciated. Thanks in advance.

[Attached: total training loss curve and validation mAP curve]

alcinos commented 4 years ago

Hi @theoutsider8060, thank you for your interest in DETR.

I suggest you have a look at these issues: #9, #124 and #125 to see if you can gain some insights.

Regarding your specific case, some remarks/questions:

Best of luck.

theoutsider8060 commented 4 years ago

Thanks for the reply, @alcinos. I will definitely look more into the issues you mentioned.

With regard to results using other models: at the moment I am getting an AP of 40.2 using Faster R-CNN, with an inference time of 91 ms. With DETR, my inference time per image has dropped to around 70 ms, but the AP has also dropped, as mentioned above.

Just a few queries before I make the change:

  1. Since I will be going back to the default learning rate of 1e-4, can you suggest when exactly I should drop it further, and by how much?
  2. I am not sure if I have got this wrong, but since I have only a single class, shouldn't my class_error drop to zero as soon as I start training? While fine-tuning, I noticed my class_error started at around 7 and has steadily declined towards 0. Is this how it is supposed to be?
alcinos commented 4 years ago
  1. Hard to tell; monitor training and drop the LR by a factor of 2 or 10 when you feel it plateaus. On COCO, for a quick run, one could do 150 epochs with a drop at 100 (but that is from scratch). 1e-4 might be a bit high given that you already see some instability, but try it and work from there (a scheduling sketch follows below).
  2. You technically have two classes: the class you are trying to detect and the "no-object" class, which is why you don't get 0.
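To make the first point concrete, the drop could be expressed with detectron2's standard SOLVER keys, for example as below; the iteration counts are illustrative assumptions only (roughly 3.5k iterations per epoch at batch size 4 on 14k images):

    # Illustrative step schedule; numbers are assumptions, not recommendations.
    cfg.SOLVER.BASE_LR = 1e-4
    cfg.SOLVER.MAX_ITER = 525_000   # ~150 epochs at ~3.5k iters/epoch
    cfg.SOLVER.STEPS = (350_000,)   # drop around epoch 100, once the loss plateaus
    cfg.SOLVER.GAMMA = 0.1          # multiply the LR by 0.1 at each step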
Antoine-ls commented 4 years ago

@theoutsider8060 Hello, I am now facing the same case. Since I am using the detectron2 wrapper, how do you change num_classes to 1? I did not see it in the YAML files.

@alcinos I have a question about image transformations. In my case, all images have the same shape (1080, 1440); do I still need to resize them? And how should I set the input part of the YAML files, like MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)? Does it mean images are resized to a square shape rather than a rectangular one?

Thanks in advance!

danielfennhagencab commented 4 years ago

Hi @Antoine-ls, the second part of your question has been answered by @fmassa in https://github.com/facebookresearch/detr/issues/245#issuecomment-699491315

You should resize the input to have a minimum size of 800 (and ideally a max size of 1333); the model does not do the resizing internally. Check our colab notebook for further details on how to perform the input preprocessing: https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb
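For a standalone example, the preprocessing in that notebook looks roughly like this (torchvision transforms with ImageNet normalization; T.Resize(800) rescales the shorter edge and keeps the aspect ratio):

    # Input preprocessing along the lines of the DETR colab notebooks.
    import torchvision.transforms as T

    transform = T.Compose([
        T.Resize(800),                      # shorter edge -> 800 px
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406],  # ImageNet mean
                    [0.229, 0.224, 0.225]), # ImageNet std
    ])
    # img = transform(pil_image).unsqueeze(0)  # add a batch dimension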

ruodingt commented 4 years ago

@Antoine-ls

  1. For changing num_classes:

https://github.com/facebookresearch/detr/blob/4e1a9281bc5621dcd65f3438631de25e255c4269/d2/detr/config.py#L11

  2. Re: image transformation.

Images will still be rectangular, since MIN_SIZE_TRAIN is used together with MAX_SIZE_TRAIN (which defaults to 1333). You can see how these params are used in the following functions:

https://github.com/facebookresearch/detr/blob/4e1a9281bc5621dcd65f3438631de25e255c4269/d2/detr/dataset_mapper.py#L15

https://github.com/facebookresearch/detr/blob/4e1a9281bc5621dcd65f3438631de25e255c4269/d2/detr/dataset_mapper.py#L36

class ResizeShortestEdge(Augmentation):
    """
    Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
    If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
    """

Hi @theoutsider8060 Any luck with your DETR training? I am in a similar situation. Did you manage to get performance comparable to Mask R-CNN?

Regards, Ruoding