Tencent / TFace

A trusty face analysis research platform developed by Tencent Youtu Lab
Apache License 2.0
1.29k stars 225 forks source link

CurricularFace_Backbone 加载错误问题 似乎源码有错误? #21

Open DefTruth opened 3 years ago

DefTruth commented 3 years ago

对于CurricularFace的pretrained,原来的BasicBlockIR我无法加载成功,似乎是有一个算子错位了。我修改后加载成功了。

class BasicBlockIR(Module):
    """Basic residual block for IRNet.

    The output is the sum of a shortcut branch and a residual branch laid
    out as ``BN -> 3x3 Conv -> [BN] -> PReLU -> 3x3 Conv -> BN``.

    Args:
        in_channel: Number of channels of the input feature map.
        depth: Number of channels produced by the block.
        stride: Stride applied by the second 3x3 convolution and by the
            shortcut branch.
        with_mid_bn: When True, insert a ``BatchNorm2d`` between the first
            convolution and the ``PReLU``. The default (False) is the
            two-BN layout that the released CurricularFace checkpoint was
            reported to load with; checkpoints trained with the three-BN
            layout need True.
    """
    def __init__(self, in_channel, depth, stride, with_mid_bn=False):
        super(BasicBlockIR, self).__init__()
        if in_channel == depth:
            # Same channel count: a 1x1 max-pool only applies the stride.
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            # Channel count changes: project with a strided 1x1 convolution.
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))

        layers = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
        ]
        if with_mid_bn:
            # The optional BN shifts the indices of the following layers
            # inside the Sequential, and therefore their state_dict keys.
            layers.append(BatchNorm2d(depth))
        layers.extend([
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        ])
        self.res_layer = Sequential(*layers)

    def forward(self, x):
        """Return ``res_layer(x) + shortcut_layer(x)``."""
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)

        return res + shortcut

加载成功的log:

  %1081 = Gemm[alpha = 1, beta = 1, transB = 1](%1080, %output_layer.3.weight, %output_layer.3.bias)
  %1082 = Constant[value = <Tensor>]()
  %1083 = Constant[value = <Tensor>]()
  %embedding = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%1081, %1082, %1083, %output_layer.4.running_mean, %output_layer.4.running_var)
  return %embedding
}
Checking 0/3...
Checking 1/3...
Checking 2/3...
Converted ./pretrained/CurricularFace_Backbone.pth to ./pretrained/CurricularFace_Backbone.onnx done!

而在修改之前会出现以下错误:

    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Backbone:
        Missing key(s) in state_dict: "body.0.res_layer.2.bias", "body.0.res_layer.2.running_mean", "body.0.res_layer.2.running_var", "body.0.res_layer.5.weight", "body.0.res_layer.5.bias", "body.0.res_layer.5.running_mean", "body.0.res_layer.5.running_var", "body.1.res_layer.2.bias", "body.1.res_layer.2.running_mean", "body.1.res_layer.2.running_var", "body.1.res_layer.5.weight", "body.1.res_layer.5.bias", "body.1.res_layer.5.running_mean",  
...
       size mismatch for body.0.res_layer.3.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for body.0.res_layer.4.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
        size mismatch for body.1.res_layer.3.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for body.1.res_layer.4.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
        size mismatch for body.2.res_layer.3.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for body.2.res_layer.4.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
        size mismatch for body.3.res_layer.3.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for body.3.res_layer.4.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
        size mismatch for body.4.res_layer.3.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128]).
        size mismatch for body.4.res_layer.4.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
        size mismatch for body.5.res_layer.3.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128]).

从这段log看,似乎是由于一个算子错位引起的,它尝试将一个卷积Conv的权重拷贝到BatchNorm上。我修改过后可以成功加载pth,并转换成onnx。我想这应该不影响使用。 以下是我完整的测试代码:

import cv2
import onnx
import torch
import numpy as np

from torchkit.backbone import get_model

def convert_to_onnx(pretrained_path="./pretrained/BUPT_Balancedface_IR_34.pth",
                    backbone_type="IR_34", do_simplify=True,
                    output_path="./pretrained/BUPT_Balancedface_IR_34.onnx",
                    test_path="./test.png"):
    """Load a backbone checkpoint, run one test image and export to ONNX.

    Args:
        pretrained_path: Path of the ``.pth`` state dict to load.
        backbone_type: Backbone name passed to ``get_model``.
        do_simplify: When True, run ``onnxsim.simplify`` on the exported
            graph before saving it.
        output_path: Destination path of the ONNX file.
        test_path: Image used as the sample input for the forward pass and
            for tracing the export.

    Raises:
        FileNotFoundError: If ``test_path`` cannot be read as an image.
        AssertionError: If the simplified ONNX model fails validation.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = get_model(backbone_type)([112, 112])
    model.load_state_dict(torch.load(pretrained_path, map_location=device))
    model = model.to(device)
    model.eval()
    print(f"Load {pretrained_path} done! Device: {device}")

    img = cv2.imread(test_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None;
        # fail here instead of with an opaque error inside cvtColor.
        raise FileNotFoundError(f"Cannot read test image: {test_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (112, 112)).astype(np.float32)
    face = img / 255.0  # (0, 1)
    face = (face - 0.5) / 0.5  # (-1, 1)
    face = np.transpose(face, (2, 0, 1))  # (3,112,112)
    face = np.expand_dims(face, 0)  # add the batch axis
    # The input must live on the same device as the model, otherwise the
    # forward pass fails whenever CUDA is selected above.
    face_tensor = torch.from_numpy(face).to(device)

    with torch.no_grad():
        embeddings = model(face_tensor).cpu().numpy()

    print("Pth Detect done.")
    print(embeddings.shape)
    print('Converting ...')
    torch.onnx.export(model, face_tensor, output_path,
                      input_names=["input"],
                      output_names=["embedding"],
                      keep_initializers_as_inputs=False,
                      verbose=False, opset_version=11)
    model_onnx = onnx.load(output_path)
    print(onnx.helper.printable_graph(model_onnx.graph))
    if do_simplify:
        from onnxsim import simplify
        model_onnx, check = simplify(model_onnx, check_n=3)
        assert check, "Simplified ONNX model could not be validated"
    onnx.save(model_onnx, output_path)
    print(f"Converted {pretrained_path} to {output_path} done!")

if __name__ == "__main__":
    # Script entry point: export the CurricularFace IR_101 checkpoint.
    # To export the IR_34 checkpoint instead, use
    #   pretrained_path="./pretrained/BUPT_Balancedface_IR_34.pth"
    #   output_path="./pretrained/BUPT_Balancedface_IR_34.onnx"
    #   backbone_type="IR_34"
    export_kwargs = {
        "pretrained_path": "./pretrained/CurricularFace_Backbone.pth",
        "output_path": "./pretrained/CurricularFace_Backbone.onnx",
        "backbone_type": "IR_101",
    }
    convert_to_onnx(**export_kwargs)
    # Usage note, kept as a no-op string literal:
    """
    PYTHONPATH=. python3 ./inference.py
    """

以上是我在加载基于IR_101的CurricularFace时遇到的问题。除此之外,cifp中提供的模型是基于IR_34的,奇怪的是,IR_34必须使用我修改前的BasicBlockIR才能加载成功。也就是说,CurricularFace中基于IR_101的pretrained与cifp中基于IR_34的pretrained,应该是用两个不同版本的BasicBlockIR训练的,它们并不能共用同一个BasicBlockIR模块。cifp中基于IR_34的pretrained所需的BasicBlockIR如下:

class BasicBlockIR(Module):
    """Basic residual block for IRNet.

    The output is the sum of a shortcut branch and a residual branch laid
    out as ``BN -> 3x3 Conv -> [BN] -> PReLU -> 3x3 Conv -> BN``.

    Args:
        in_channel: Number of channels of the input feature map.
        depth: Number of channels produced by the block.
        stride: Stride applied by the second 3x3 convolution and by the
            shortcut branch.
        with_mid_bn: When True (default), keep the ``BatchNorm2d`` between
            the first convolution and the ``PReLU``. Pass False to build
            the two-BN layout that the released CurricularFace checkpoint
            was reported to require.
    """
    def __init__(self, in_channel, depth, stride, with_mid_bn=True):
        super(BasicBlockIR, self).__init__()
        if in_channel == depth:
            # Same channel count: a 1x1 max-pool only applies the stride.
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            # Channel count changes: project with a strided 1x1 convolution.
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))

        layers = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
        ]
        if with_mid_bn:
            # Dropping this BN shifts the indices of the following layers
            # inside the Sequential, and therefore their state_dict keys.
            layers.append(BatchNorm2d(depth))
        layers.extend([
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        ])
        self.res_layer = Sequential(*layers)

    def forward(self, x):
        """Return ``res_layer(x) + shortcut_layer(x)``."""
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)

        return res + shortcut
wjxzju commented 3 years ago

感谢这么细致的分析。这里的确是我们早期做CurricularFace实验时的一个bug:当时BasicBlockIR的res_layer中只有两个BN,我们放出来的CurricularFace的ckpt也是这个版本的。后来我们做cifp实验时,发现与arcface论文中的IR block相比少了这个BN,所以加上了。

DefTruth commented 3 years ago

> 感谢这么细致的分析,这里的确是我们之前做早期CurricularFace实验的一个bug,当时BasicBlockIR中res_layer只有两个BN,我们放出来的CurricularFace的ckpt也是这个版本的,后来我们做cifp实验时发现了与arcface论文中的IRblock对比少了这个BN,所以加上了

哦哦,那就明白了,谢谢回复~

zhangxiaopang88 commented 3 years ago

你们有试过加载cifp152的模型吗?我现在测试152的模型,发现相同人和不同人的相似度都是0.99左右;而用同样的代码测试cifp101的模型就没问题。

DefTruth commented 3 years ago

> 你们有试过加载cifp152的模型吗,我现在测试152的模型发现相同人和不同人的相似度都是0.99左右,我用同样的代码测试cifp101的模型就没问题。

没有试过,只试过 IR34 的,没有发现问题。