biubug6 / Pytorch_Retinaface

RetinaFace achieves 80.99% on the WIDER FACE hard validation set using a MobileNet-0.25 backbone.
MIT License
2.63k stars 774 forks source link

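retinaface net=mobilenet0.25,在onnx导出加入boxes解码和landmark解码,直接在onnx输出,但是框可以输出,landmark分支在onnx导出报错? (RetinaFace with mobilenet0.25: box decoding and landmark decoding were added to the ONNX export so both are output directly by the ONNX model; the boxes export fine, but the landmark branch raises an error during ONNX export?) #179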

Open EricHuiK opened 3 years ago

EricHuiK commented 3 years ago

from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models._utils as _utils
import torchvision.models.detection.backbone_utils as backbone_utils

from layers.functions.prior_box import PriorBox
from models.net import FPN as FPN
from models.net import MobileNetV1 as MobileNetV1
from models.net import SSH as SSH


def decode_fixed(loc, priors, variances=(0.1, 0.2)):
    """Decode box regressions using priors, undoing the train-time offset encoding.

    Uses only tensor slicing, arithmetic, ``exp`` and ``cat`` with ``...``
    indexing, so it works on batched inputs and can be traced for export.

    Args:
        loc (tensor): box regressions, shape ``[..., num_priors, 4]``
            (``[batch, num_priors, 4]`` when called from ``RetinaFace.forward``).
        priors (tensor): prior boxes in center-size form ``(cx, cy, w, h)``,
            shape ``[..., num_priors, 4]``; leading dims must broadcast
            against those of ``loc``.
        variances (sequence[float]): ``variances[0]`` scales the center
            offsets, ``variances[1]`` scales the log-size offsets.
            Defaults to ``(0.1, 0.2)``.

    Returns:
        tensor: decoded boxes, shape = broadcast of ``loc`` and ``priors``,
        in center-size form ``(cx, cy, w, h)`` -- not corner form.
        NOTE(review): if a consumer (e.g. NMS) expects corners, convert with
        ``xy_min = cxcy - wh / 2`` and ``xy_max = cxcy + wh / 2``.
    """
    center_var, size_var = variances[0], variances[1]
    # One coordinate per slice; each result keeps a trailing dim of 1 so the
    # four pieces can be concatenated back along the last axis.
    boxes = torch.cat((
        priors[..., 0:1] + loc[..., 0:1] * center_var * priors[..., 2:3],
        priors[..., 1:2] + loc[..., 1:2] * center_var * priors[..., 3:4],
        priors[..., 2:3] * torch.exp(loc[..., 2:3] * size_var),
        priors[..., 3:4] * torch.exp(loc[..., 3:4] * size_var),
    ), dim=-1)
    return boxes

def decode_landm(pre, priors, variances):
    """Decode landmark predictions using priors, undoing the train-time encoding.

    Each of the 5 ``(x, y)`` landmark points is decoded as
    ``prior_center + offset * variances[0] * prior_size``.

    Args:
        pre (tensor): landmark predictions, shape ``[..., num_priors, 10]``
            (e.g. ``[num_priors, 10]`` or ``[batch, num_priors, 10]``).
        priors (tensor): prior boxes in center-size form ``(cx, cy, w, h)``,
            shape ``[..., num_priors, 4]``; leading dims must broadcast
            against those of ``pre``.
        variances (sequence[float]): prior-box variances; only
            ``variances[0]`` is used.

    Returns:
        tensor: decoded landmarks, shape = broadcast of ``pre`` and
        ``priors`` with a last dim of 10.
    """
    centers = priors[..., :2]
    sizes = priors[..., 2:]
    # pre holds 5 consecutive (x, y) offset pairs: columns 0:2, 2:4, ..., 8:10.
    points = [centers + pre[..., start:start + 2] * variances[0] * sizes
              for start in range(0, 10, 2)]
    return torch.cat(points, dim=-1)

class ClassHead(nn.Module):
    """Classification head for one FPN level: 2 scores per anchor per location."""

    def __init__(self, inchannels=512, num_anchors=3):
        super(ClassHead, self).__init__()
        self.num_anchors = num_anchors
        # 1x1 conv producing num_anchors * 2 channels (2 scores for each anchor).
        self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        """Map features ``[N, inchannels, H, W]`` to scores ``[N, H*W*num_anchors, 2]``."""
        out = self.conv1x1(x)
        # NCHW -> NHWC so each location's anchors are contiguous before the
        # flatten; resulting rows are ordered (y, x, anchor).
        out = out.permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, 2)

class BboxHead(nn.Module):
    """Box-regression head for one FPN level: 4 offsets per anchor per location."""

    def __init__(self, inchannels=512, num_anchors=3):
        super(BboxHead, self).__init__()
        # 1x1 conv producing num_anchors * 4 channels (4 offsets for each anchor).
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        """Map features ``[N, inchannels, H, W]`` to offsets ``[N, H*W*num_anchors, 4]``."""
        out = self.conv1x1(x)
        # NCHW -> NHWC so each location's anchors are contiguous before the
        # flatten; resulting rows are ordered (y, x, anchor).
        out = out.permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, 4)

class LandmarkHead(nn.Module):
    """Landmark head for one FPN level: 10 values (5 x/y pairs) per anchor per location."""

    def __init__(self, inchannels=512, num_anchors=3):
        super(LandmarkHead, self).__init__()
        # 1x1 conv producing num_anchors * 10 channels (10 values for each anchor).
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        """Map features ``[N, inchannels, H, W]`` to ``[N, H*W*num_anchors, 10]``."""
        out = self.conv1x1(x)
        # NCHW -> NHWC so each location's anchors are contiguous before the
        # flatten; resulting rows are ordered (y, x, anchor).
        out = out.permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, 10)

class RetinaFace(nn.Module):
    """RetinaFace detector whose forward pass returns decoded predictions.

    Box and landmark decoding run inside ``forward`` so that an exported ONNX
    graph emits decoded outputs directly. Priors are generated once in
    ``__init__`` for a fixed square input of ``_IMAGE_SIZE`` pixels, and the
    per-level anchor counts used in ``forward`` are derived from the same size.
    """

    # Fixed (square) network input size, used both to build the priors and to
    # compute the static per-level anchor counts in forward().
    _IMAGE_SIZE = 320
    # Strides of the three FPN levels and anchors per feature-map location.
    # forward() reshapes the priors with the resulting total count, so these
    # must match the number of priors PriorBox generates.
    _STEPS = (8, 16, 32)
    _ANCHORS_PER_LOCATION = 2

    def __init__(self, cfg=None, phase='train'):
        """
        :param cfg: Network related settings.
        :param phase: train or test.
        """
        super(RetinaFace, self).__init__()
        self.phase = phase
        backbone = None
        if cfg['name'] == 'mobilenet0.25':
            backbone = MobileNetV1()
            if cfg['pretrain']:
                checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu'))
                from collections import OrderedDict
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    name = k[7:]  # remove module. (drop the first 7 characters of each key)
                    new_state_dict[name] = v
                # load params
                backbone.load_state_dict(new_state_dict)
        elif cfg['name'] == 'Resnet50':
            import torchvision.models as models
            backbone = models.resnet50(pretrained=cfg['pretrain'])

        self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
        in_channels_stage2 = cfg['in_channel']
        in_channels_list = [
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ]
        out_channels = cfg['out_channel']
        self.fpn = FPN(in_channels_list, out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])

        size = self._IMAGE_SIZE
        self.priorbox = PriorBox(cfg, image_size=(size, size), phase=self.phase)
        # Plain attribute (not a registered buffer): not saved in state_dict
        # and not moved by Module.to()/cuda(); forward() handles the device.
        self.priors = self.priorbox.forward()

    def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Build one ClassHead per FPN level."""
        classhead = nn.ModuleList()
        for _ in range(fpn_num):
            classhead.append(ClassHead(inchannels, anchor_num))
        return classhead

    def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Build one BboxHead per FPN level."""
        bboxhead = nn.ModuleList()
        for _ in range(fpn_num):
            bboxhead.append(BboxHead(inchannels, anchor_num))
        return bboxhead

    def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Build one LandmarkHead per FPN level."""
        landmarkhead = nn.ModuleList()
        for _ in range(fpn_num):
            landmarkhead.append(LandmarkHead(inchannels, anchor_num))
        return landmarkhead

    def forward(self, inputs):
        """Run the network and decode its predictions.

        Args:
            inputs (tensor): image batch; assumed to be ``_IMAGE_SIZE`` x
                ``_IMAGE_SIZE`` -- the static anchor counts below only hold
                for that size (TODO confirm layout with the caller).

        Returns:
            tuple: ``(boxes, conf, landm_mark)``:
              * ``boxes``: output of ``decode_fixed``, ``[batch, num_priors, 4]``;
              * ``conf``: softmaxed class scores, ``[batch, num_priors, 2]``;
              * ``landm_mark``: decoded landmarks, ``[num_priors, 10]`` for a
                batch of one, otherwise ``[batch, num_priors, 10]``.
        """
        from math import ceil

        out = self.body(inputs)

        # FPN
        fpn = self.fpn(out)

        # SSH
        feature1 = self.ssh1(fpn[0])
        feature2 = self.ssh2(fpn[1])
        feature3 = self.ssh3(fpn[2])
        features = [feature1, feature2, feature3]

        loc_out, conf_out, landm_out = [], [], []
        for x, bbox_head, class_head, landm_head in zip(
                features, self.BboxHead, self.ClassHead, self.LandmarkHead):
            loc_out.append(bbox_head(x))
            conf_out.append(class_head(x))
            landm_out.append(landm_head(x))

        # Anchors per level (locations * anchors per location) as Python ints,
        # so the view() sizes below are constants rather than shape lookups.
        size = self._IMAGE_SIZE
        feature_anchors = [self._ANCHORS_PER_LOCATION * ceil(size / step) * ceil(size / step)
                           for step in self._STEPS]

        bbox_regressions = torch.cat(
            [o.view(-1, n, 4) for n, o in zip(feature_anchors, loc_out)], 1)
        classifications = torch.cat(
            [o.view(-1, n, 2) for n, o in zip(feature_anchors, conf_out)], 1)
        ldm_regressions = torch.cat(
            [o.view(-1, n, 10) for n, o in zip(feature_anchors, landm_out)], 1)

        anchor_num = int(sum(feature_anchors))

        conf = F.softmax(classifications, dim=-1)

        # Local [1, anchor_num, 4] view of the priors on the predictions'
        # device; self.priors itself is left untouched.
        priors = self.priors.reshape([-1, anchor_num, 4]).to(bbox_regressions.device)

        boxes = decode_fixed(bbox_regressions, priors)

        # Landmark offsets are relative to the prior boxes, not to the raw box
        # regressions. Tensors are used directly (no .data, which returns a
        # tensor detached from the autograd graph). decode_landm is applied to
        # flattened [rows, 10] / [rows, 4] inputs, then the batch layout is restored.
        flat_priors = priors.expand_as(bbox_regressions).reshape(-1, 4)
        landm_mark = decode_landm(ldm_regressions.reshape(-1, 10), flat_priors, [0.1, 0.2])
        # Drop a batch dim of 1 so single-image output stays [num_priors, 10].
        landm_mark = landm_mark.view_as(ldm_regressions).squeeze(0)

        output = (boxes, conf, landm_mark)
        return output