potterhsu / easy-faster-rcnn.pytorch

An easy implementation of Faster R-CNN (https://arxiv.org/pdf/1506.01497.pdf) in PyTorch.
MIT License

Training on custom dataset #4

Closed: culechetoo closed this issue 5 years ago

culechetoo commented 5 years ago

Hi @potterhsu

I am trying to train a Faster R-CNN (ResNet-101) model on my custom dataset. The dataset is in the same format as MS COCO, with 16 categories. However, I am not sure how to proceed. I am new to PyTorch and have no idea where to start. I would really appreciate your help with this.

Thanks and regards,
Chaitanya Agrawal

potterhsu commented 5 years ago

Basically, it can be done in two steps:

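  1. Copy dataset/coco2017.py to dataset/coco2017_custom.py (any name works, but it must be a valid Python module name, so use underscores rather than hyphens, and it must match the import in step 2)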
    1. Change all COCO2017 to COCO2017Custom, for example:
      • class COCO2017(Base) -> class COCO2017Custom(Base)
      • self._mode == COCO2017.Mode.TRAIN -> self._mode == COCO2017Custom.Mode.TRAIN
      • os.path.join('caches', 'coco2017' -> os.path.join('caches', 'coco2017-custom' (so the custom dataset gets its own cache directory instead of reusing the COCO one)
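    2. Replace the entries in CATEGORY_TO_LABEL_DICT with your own categories, and update num_classes() so that it returns the new number of classes (it must be larger than every label in CATEGORY_TO_LABEL_DICT)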
  2. In dataset/base.py, add a new branch to the from_name function:
    elif name == 'coco2017-custom':
        from dataset.coco2017_custom import COCO2017Custom
        return COCO2017Custom
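For the CATEGORY_TO_LABEL_DICT / num_classes() step, here is a minimal standalone sketch of building the dictionary from the "categories" section of a COCO-format annotation file instead of typing it by hand. It is not code from this repository: build_category_dicts and num_classes below are hypothetical helpers, and the sketch assumes that label 0 is reserved for background and that each annotation's category_id is used directly as its label. Check both assumptions against your copy of coco2017.py.

```python
import json
from typing import Dict, List


def build_category_dicts(categories: List[dict]) -> Dict[str, int]:
    """Map category name -> label, reserving label 0 for background.

    Assumption: each COCO category's own 1-based ``id`` is used as its label.
    """
    category_to_label = {'background': 0}
    for category in categories:
        if category['id'] == 0:
            raise ValueError('category id 0 would collide with background')
        category_to_label[category['name']] = category['id']
    return category_to_label


def num_classes(category_to_label: Dict[str, int]) -> int:
    # Labels index the classifier output, so the class count must
    # cover the largest label (this also tolerates gaps in the ids).
    return max(category_to_label.values()) + 1


# Hypothetical "categories" section of a COCO-format annotation file.
annotations = json.loads('''{"categories": [
    {"id": 1, "name": "widget"},
    {"id": 2, "name": "gadget"},
    {"id": 3, "name": "gizmo"}
]}''')

CATEGORY_TO_LABEL_DICT = build_category_dicts(annotations['categories'])
LABEL_TO_CATEGORY_DICT = {v: k for k, v in CATEGORY_TO_LABEL_DICT.items()}
print(CATEGORY_TO_LABEL_DICT)
print(num_classes(CATEGORY_TO_LABEL_DICT))  # 4
```

Under these assumptions, 16 categories with ids 1 to 16 give num_classes() == 17. Using max(label) + 1 rather than len() keeps the count correct even if the ids in the file are not contiguous.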

Please let me know if it works for you.

culechetoo commented 5 years ago

Hi @potterhsu

Your solution worked perfectly for me! The only small issue was the underlying assumption that all image_ids are integers, but that was easy to deal with.
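For anyone hitting the same integer image_id assumption, one possible workaround is sketched below. It assumes the problem comes from building an image filename out of the ID (for example with a zero-padded format such as '{:012d}.jpg', which raises ValueError for a string ID); that is a guess about where the assumption lives, not a description of the fix used here. COCO-format files already store a file_name for every entry in "images", so the sketch looks that up instead. build_id_to_filename and image_path are hypothetical helper names.

```python
import os
from typing import Dict, List, Union

ImageId = Union[int, str]


def build_id_to_filename(images: List[dict]) -> Dict[str, str]:
    # Key by str(id) so integer and string IDs are handled the same way.
    return {str(image['id']): image['file_name'] for image in images}


def image_path(images_dir: str, id_to_filename: Dict[str, str], image_id: ImageId) -> str:
    # Use the file_name recorded in the annotation file instead of
    # formatting image_id as a zero-padded integer, which fails for strings.
    return os.path.join(images_dir, id_to_filename[str(image_id)])


# Hypothetical "images" section mixing an integer ID and a string ID.
images = [
    {'id': 42, 'file_name': '000000000042.jpg'},
    {'id': 'cam1_0007', 'file_name': 'cam1_0007.png'},
]
ID_TO_FILENAME = build_id_to_filename(images)
print(image_path('images', ID_TO_FILENAME, 'cam1_0007'))
```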

Thanks a lot for your help and your wonderful library!

Regards,
Chaitanya Agrawal