jwyang / faster-rcnn.pytorch

A faster PyTorch implementation of Faster R-CNN
MIT License

error while running demo.py #26

Closed: manoja328 closed this issue 6 years ago

manoja328 commented 6 years ago

While running demo.py on my own dataset, it throws the following error:

  load checkpoint ./savemodels/res101/xview/faster_rcnn_1_17_1043.pth
  While copying the parameter named RCNN_rpn.RPN_cls_score.weight, whose dimensions in the model are torch.Size([18, 512, 1, 1]) and whose dimensions in the checkpoint are torch.Size([24, 512, 1, 1]), ...

Do you know what could cause this issue?

jwyang commented 6 years ago

Hi, @manoja328

You might need to change a line in demo.py: currently it only supports the object categories from pascal_voc, so you need to change pascal_classes to match the classes of your own dataset.
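A minimal sketch of building a custom class list in the layout demo.py appears to use: '__background__' first, then the object classes in the same order as the dataset the model was trained on. The helper build_class_array and the class names are hypothetical placeholders; demo.py itself wraps the list in np.asarray, which is left out here to keep the sketch dependency-free.

```python
# Hypothetical class names for a custom dataset; replace with your own,
# in the same order that was used during training.
MY_DATASET_CLASSES = ["vehicle", "building", "ship"]


def build_class_array(names):
    """Return the class list with background at index 0 (hypothetical helper)."""
    if "__background__" in names:
        raise ValueError("'__background__' is prepended automatically; do not list it")
    if len(set(names)) != len(names):
        raise ValueError("class names must be unique")
    return ["__background__"] + list(names)


pascal_classes = build_class_array(MY_DATASET_CLASSES)
print(len(pascal_classes))  # 4: three object classes plus background
```

The length of this list is what the detection head is built for, so it has to equal the number of classes in the checkpoint.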

dotannn commented 6 years ago

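@manoja328 Your error is actually caused by the default ANCHOR_SCALES configuration: it has 3 scales, but this model was trained with 4. The RPN score layer has 2 outputs per anchor and len(ANCHOR_SCALES) × len(ANCHOR_RATIOS) anchors per location, so 3 scales give 2 × 3 × 3 = 18 channels while 4 scales give 2 × 4 × 3 = 24, which is exactly the mismatch in your error. You need to change the line in config.py to: __C.ANCHOR_SCALES = [4,8,16,32]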
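The channel arithmetic can be checked with a small sketch. It assumes the RPN classification conv predicts two scores (background/foreground) for every anchor, that anchors per location = len(anchor_scales) × len(anchor_ratios), and that the repository defaults are scales [8, 16, 32] and ratios [0.5, 1, 2]. The function name is hypothetical.

```python
def rpn_cls_score_channels(anchor_scales, anchor_ratios):
    """Output channels of the RPN classification conv (hypothetical helper).

    Assumes 2 scores (bg/fg) per anchor and one anchor per
    (scale, ratio) pair at each feature-map location.
    """
    num_anchors = len(anchor_scales) * len(anchor_ratios)
    return 2 * num_anchors


default = rpn_cls_score_channels([8, 16, 32], [0.5, 1, 2])     # assumed default config
trained = rpn_cls_score_channels([4, 8, 16, 32], [0.5, 1, 2])  # config the checkpoint used
print(default, trained)  # 18 24
```

These two numbers correspond to the first dimension of RCNN_rpn.RPN_cls_score.weight in the model and in the checkpoint, respectively.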

After you do that, you will hit the error @jwyang mentioned about the number of classes. You need to change the list of pascal_voc classes (21 classes, including background) to the MS COCO mapping (81 classes, including background).
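Rather than guessing, both settings can be read off the checkpoint's parameter shapes. This is a sketch under assumptions: the parameter names follow faster-rcnn.pytorch's naming (RCNN_rpn.RPN_cls_score.weight appears in the error above; RCNN_cls_score.weight is assumed), the model uses 3 anchor ratios, and the example shapes are hypothetical values for a COCO-trained ResNet-101.

```python
def infer_head_config(param_shapes, num_ratios=3):
    """Infer (num_anchor_scales, num_classes) from parameter shapes.

    param_shapes maps parameter names to shape tuples. The names used
    here are assumptions; check them against your own checkpoint.
    """
    # RPN classification conv: 2 scores per anchor, anchors = scales x ratios.
    rpn_out = param_shapes["RCNN_rpn.RPN_cls_score.weight"][0]
    num_scales, rem = divmod(rpn_out, 2 * num_ratios)
    if rem:
        raise ValueError(f"{rpn_out} channels is not a multiple of 2 x {num_ratios} ratios")
    # Final classifier: one output per class, background included.
    num_classes = param_shapes["RCNN_cls_score.weight"][0]
    return num_scales, num_classes


# Hypothetical shapes for a ResNet-101 checkpoint trained on MS COCO.
example_shapes = {
    "RCNN_rpn.RPN_cls_score.weight": (24, 512, 1, 1),
    "RCNN_cls_score.weight": (81, 2048),
}
print(infer_head_config(example_shapes))  # (4, 81)
```

With PyTorch installed, real shapes could be collected with something like {k: tuple(v.shape) for k, v in torch.load(load_name)['model'].items()}, assuming the checkpoint stores its state dict under a 'model' key as demo.py appears to expect.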

manoja328 commented 6 years ago

My solution was to make the following changes to demo.py:

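  # Match the anchor configuration the checkpoint was trained with
  # (4 scales x 3 ratios). This overwrites any set_cfgs given on the command line.
  args.set_cfgs = ['ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5, 1, 2]']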

  if args.cfg_file is not None:
    cfg_from_file(args.cfg_file)
  if args.set_cfgs is not None:
    cfg_from_list(args.set_cfgs)

  print('Using config:')
  pprint.pprint(cfg)
  np.random.seed(cfg.RNG_SEED)

  # train set
  # -- Note: Use validation set and disable the flipped to enable faster loading.

  input_dir = os.path.join(args.load_dir, args.net, args.dataset)
  if not os.path.exists(input_dir):
    raise Exception('There is no input directory for loading network from ' + input_dir)
  load_name = os.path.join(input_dir,
    'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))

  # MS COCO labels: background + 80 object classes = 81 entries, in training order.
  pascal_classes = np.asarray(['__background__',"person","bicycle","car","motorbike","aeroplane","bus","train","truck","boat",
                               "traffic light","fire hydrant","stop sign","parking meter","bench","bird","cat","dog","horse",
                               "sheep","cow","elephant","bear","zebra","giraffe","backpack","umbrella","handbag","tie",
                               "suitcase","frisbee","skis","snowboard","sports ball","kite","baseball bat",
                               "baseball glove","skateboard","surfboard","tennis racket","bottle","wine glass",
                               "cup","fork","knife","spoon","bowl","banana","apple","sandwich","orange","broccoli",
                               "carrot","hot dog","pizza","donut","cake","chair","sofa","pottedplant","bed",
                               "diningtable","toilet","tvmonitor","laptop","mouse","remote","keyboard","cell phone",
                               "microwave","oven","toaster","sink",
                               "refrigerator","book","clock","vase","scissors","teddy bear","hair drier","toothbrush"])
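The set_cfgs list above is a flat sequence of KEY, VALUE string pairs that cfg_from_list applies on top of the defaults. The sketch below is a simplified re-implementation of that idea, not the repository's function: it assumes values are parsed with ast.literal_eval and type-checked against the existing entry, and it handles only flat keys on a plain dict. The name apply_cfg_list is hypothetical.

```python
from ast import literal_eval


def apply_cfg_list(cfg, cfg_list):
    """Apply KEY, VALUE string pairs to a flat config dict (simplified sketch)."""
    if len(cfg_list) % 2 != 0:
        raise ValueError("cfg_list must contain KEY, VALUE pairs")
    for key, raw in zip(cfg_list[0::2], cfg_list[1::2]):
        if key not in cfg:
            raise KeyError(f"unknown config key: {key}")
        value = literal_eval(raw)  # e.g. '[4, 8, 16, 32]' -> [4, 8, 16, 32]
        if type(value) is not type(cfg[key]):
            raise TypeError(f"{key}: expected {type(cfg[key]).__name__}, got {type(value).__name__}")
        cfg[key] = value
    return cfg


# Assumed defaults, overridden the same way the demo.py change does.
cfg = {"ANCHOR_SCALES": [8, 16, 32], "ANCHOR_RATIOS": [0.5, 1, 2]}
apply_cfg_list(cfg, ["ANCHOR_SCALES", "[4, 8, 16, 32]", "ANCHOR_RATIOS", "[0.5,1,2]"])
print(cfg["ANCHOR_SCALES"])  # [4, 8, 16, 32]
```

Overriding through set_cfgs keeps config.py's defaults untouched for other scripts, whereas editing __C.ANCHOR_SCALES changes them globally.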