lzx1413 / PytorchSSD

PyTorch version of SSD and its enhanced methods such as RFBSSD, FSSD and RefineDet
MIT License
709 stars 237 forks source link

How can I do inference after finishing the learning? #80

Open NogizakaDaisuki opened 5 years ago

NogizakaDaisuki commented 5 years ago

As mentioned in the title, how can I do inference after finishing the learning?

cxf2015 commented 5 years ago

self.weight_file = weight_file self.gpu_id = [gpu_id] self.img_dim = 640 self.num_classes = 6 self.net = build_net(self.img_dim, self.num_classes, use_refine=True) self.rgb_std = (1, 1, 1) self.rgb_means = (104, 117, 123) self.thresh = 0.4

    # Restore the trained weights from `weight_file` into the freshly built net.
    # NOTE(review): no `map_location` is passed to torch.load, so tensors are
    # restored onto the device they were saved from — confirm this matches the
    # inference machine.
    # print('Loading resume network', weight_file)
    state_dict = torch.load(weight_file)
    # create new OrderedDict that does not contain `module.`
    # (presumably the checkpoint was saved from a DataParallel-wrapped model,
    # whose parameter names carry a `module.` prefix — TODO confirm)
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Compare the first 7 characters against the 7-character prefix.
        head = k[:7]
        if head == 'module.':
            name = k[7:]  # remove `module.`
        else:
            name = k
        new_state_dict[name] = v
    self.net.load_state_dict(new_state_dict)

    # NOTE(review): self.gpu_id is assigned as the one-element list `[gpu_id]`
    # earlier in this snippet, so it is never None and this branch always runs.
    if self.gpu_id is not None:
        self.net = torch.nn.DataParallel(self.net, device_ids=self.gpu_id)
        self.net.cuda()
        cudnn.benchmark = True

    # Prior boxes and the detection decoder are both built from the VOC_512
    # config.
    # NOTE(review): self.img_dim is set to 640 earlier in this snippet while the
    # config here is VOC_512 — verify that the two are meant to agree.
    self.priorbox = PriorBox(VOC_512)
    self.detector = Detect(self.num_classes, 0, VOC_512, object_score=0.01)
    # Priors are computed once here and reused by every forward() call.
    self.priors = Variable(self.priorbox.forward(), volatile=True)

    self.net.eval()
    # `self.net.module` is only available because the net was wrapped in
    # DataParallel above; the wrapped model's `size` attribute is passed to the
    # transform together with the mean/std values and the (2, 0, 1) axis order.
    self.transform = BaseTransform(self.net.module.size, self.rgb_means, self.rgb_std, (2, 0, 1))
def forward(self, image):
    """Detect objects in a single image.

    The image is preprocessed with ``self.transform``, passed through
    ``self.net`` in test mode, decoded by ``self.detector`` and then, per
    class, thresholded at ``self.thresh`` and filtered with NMS (overlap
    0.45, at most 100 detections kept per class).

    Args:
        image: Array-like image whose ``shape[0]`` and ``shape[1]`` are used
            as the vertical and horizontal scale factors for the decoded
            boxes (presumably an H x W x C array as read by OpenCV —
            TODO confirm against the caller).

    Returns:
        A list with one ``[c0, c1, c2, c3, class_index]`` entry per kept
        detection, where ``c0..c3`` are the box coordinates scaled to the
        image size (presumably x_min, y_min, x_max, y_max) and
        ``class_index`` runs from 1 to ``self.num_classes - 1``. The list is
        empty when no score exceeds ``self.thresh``.
    """
    # NOTE(review): `volatile` is only honoured by PyTorch < 0.4; on newer
    # versions wrap the forward pass in `torch.no_grad()` instead.
    x = Variable(self.transform(image).unsqueeze(0), volatile=True)
    if self.gpu_id is not None:
        x = x.cuda()

    arm_loc, arm_conf, odm_loc, odm_conf = self.net(x=x, test=True)
    boxes, scores = self.detector.forward(
        (odm_loc, odm_conf), self.priors, (arm_loc, arm_conf))

    # Only one image is processed, so take the first batch entry.
    boxes = boxes[0].cpu().numpy()
    scores = scores[0].cpu().numpy()

    # Scale each detection back up to the image size.
    height, width = image.shape[0], image.shape[1]
    boxes *= np.array([width, height, width, height], dtype=np.float32)

    bboxes = []
    # Class index 0 is skipped (presumably the background class).
    for j in range(1, self.num_classes):
        inds = np.where(scores[:, j] > self.thresh)[0]
        if len(inds) == 0:
            continue
        c_dets = np.hstack(
            (boxes[inds], scores[inds, j][:, np.newaxis])
        ).astype(np.float32, copy=False)

        keep = nms(c_dets, 0.45, force_cpu=True)[:100]
        for det in c_dets[keep, :]:
            bboxes.append([det[0], det[1], det[2], det[3], j])

    return bboxes