jwyang / faster-rcnn.pytorch

A faster pytorch implementation of faster r-cnn
MIT License
7.7k stars 2.33k forks source link

train on vgg11 based model get near 0 mAP #233

Open twangnh opened 6 years ago

twangnh commented 6 years ago

I was trying to change the backbone to VGG11 as follows:

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
import torchvision.models as models
from model.faster_rcnn.faster_rcnn import _fasterRCNN
import pdb

class vgg11(_fasterRCNN):
    """Faster R-CNN with a torchvision VGG11 backbone.

    The conv feature extractor (without its final max-pool) becomes
    ``RCNN_base``. The torchvision classifier (without its final Linear
    layer) becomes the RoI head ``RCNN_top``.
    """

    def __init__(self, classes, pretrained=False, class_agnostic=False):
        """Store backbone settings, then defer to ``_fasterRCNN.__init__``.

        Args:
            classes: class names, passed straight through to ``_fasterRCNN``.
            pretrained: if True, ``_init_modules`` loads weights from
                ``self.model_path``.
            class_agnostic: if True, the box regressor predicts a single
                box (4 values) per RoI instead of one box per class.
        """
        # NOTE(review): the repo's other backbones use caffe-converted weights
        # and caffe-style input preprocessing; make sure this file really
        # holds VGG11 weights whose preprocessing matches the data pipeline.
        self.model_path = 'data/pretrained_model/vgg11_caffe.pth'
        # Output channels of VGG11's last conv stage (conv5).
        self.dout_base_model = 512
        self.pretrained = pretrained
        self.class_agnostic = class_agnostic

        _fasterRCNN.__init__(self, classes, class_agnostic)

    def _init_modules(self):
        """Build the backbone, the RoI head and the prediction layers."""
        vgg = models.vgg11()
        if self.pretrained:
            print("Loading pretrained weights from %s" % (self.model_path))
            # Load onto the CPU so a GPU-saved checkpoint also loads on a
            # CPU-only host. load_state_dict copies the tensors into vgg's
            # own parameters, so the final placement is unaffected.
            state_dict = torch.load(self.model_path,
                                    map_location=lambda storage, loc: storage)
            # Build the model's key set once instead of once per checkpoint key.
            model_keys = vgg.state_dict()
            vgg.load_state_dict(
                {k: v for k, v in state_dict.items() if k in model_keys})

        # Drop the final (ImageNet classification) Linear layer. The remaining
        # fc6/fc7 stack feeds the 4096-d layers created below.
        vgg.classifier = nn.Sequential(*list(vgg.classifier._modules.values())[:-1])

        # Use the conv features without the last max-pool layer.
        self.RCNN_base = nn.Sequential(*list(vgg.features._modules.values())[:-1])

        # Freeze the layers before conv3: everything up to and including the
        # second max-pool, i.e. the conv1 and conv2 stages. Deriving the cutoff
        # from the pools avoids the off-by-one of a hard-coded count.
        # (For VGG11 that is indices 0-5. The old range(7) also froze conv3_1.)
        pools_seen = 0
        for module in self.RCNN_base:
            if pools_seen >= 2:
                break
            for p in module.parameters():
                p.requires_grad = False
            if isinstance(module, nn.MaxPool2d):
                pools_seen += 1

        self.RCNN_top = vgg.classifier

        self.RCNN_cls_score = nn.Linear(4096, self.n_classes)

        if self.class_agnostic:
            self.RCNN_bbox_pred = nn.Linear(4096, 4)
        else:
            self.RCNN_bbox_pred = nn.Linear(4096, 4 * self.n_classes)

    def _head_to_tail(self, pool5):
        """Run pooled RoI features through the fc head.

        Args:
            pool5: pooled RoI features, one RoI per entry along dim 0.
                Presumably (num_rois, 512, 7, 7) — TODO confirm. The
                flattened size must match the first Linear layer of
                ``RCNN_top``.

        Returns:
            The fc7 features, one row per RoI.
        """
        pool5_flat = pool5.view(pool5.size(0), -1)
        fc7 = self.RCNN_top(pool5_flat)

        return fc7

I trained with `--dataset pascal_voc --net vgg11 --bs 1 --nw 4 --lr 1e-3 --lr_decay_step 5 --cuda --gpu 1`. The loss seemed right, decreasing all the way down to around 0.9 after 6 epochs, but when I evaluated the model, I got:

AP for aeroplane = 0.0000
AP for bicycle = 0.0003
AP for bird = 0.0003
AP for boat = 0.0000
AP for bottle = 0.0000
AP for bus = 0.0000
AP for car = 0.0003
AP for cat = 0.0005
AP for chair = 0.0004
AP for cow = 0.0004
AP for diningtable = 0.0001
AP for dog = 0.0004
AP for horse = 0.0013
AP for motorbike = 0.0003
AP for person = 0.0163
AP for pottedplant = 0.0002
AP for sheep = 0.0000
AP for sofa = 0.0008
AP for train = 0.0002
AP for tvmonitor = 0.0000
Mean AP = 0.0011
wadimkehl commented 6 years ago

Your hyperparameters are bad. A batch size of 1 alone already tells me that your network breaks numerically. Use a batch size of at least 16.

jwyang commented 6 years ago

@MrWanter have you solved this problem?

twangnh commented 6 years ago

@jwyang Yes. I don't know where I went wrong before, but now, training and testing on VOC07, I only get:

--bs 24 --lr 1e-2 --lr_decay_step 10 55.2 mAP(test at epoch 11) 
--bs 12 --lr 5e-3 --lr_decay_step 10 56.6 mAP(test at epoch 11) 
--bs 4 --lr 4e-3 --lr_decay_step 8  59.7 mAP(test at epoch 11) 

It seems the mAP is much lower than with vgg16 (70 mAP), and smaller batch sizes give better results. Do you think this is the expected performance for vgg11, or could you please give me some suggestions for improving the vgg11 version of Faster R-CNN? I merely adopted the vgg16 settings. Thank you.