SthPhoenix / InsightFace-REST

InsightFace REST API for easy deployment of face recognition services with TensorRT in Docker.
Apache License 2.0

What is the workflow to add a new model? #121

Open xaviermerino opened 9 months ago

xaviermerino commented 9 months ago

Hello!

Thanks for the incredible work here!

I am wondering how to add a new model to the project. I'm currently looking to work with AdaFace. I have an ONNX file derived from their checkpoint for R100-WebFace12M, and I have tried adding it as a new model. I noticed that ArcFace expects the ONNX input to be named input.1, so I renamed the input of my ONNX file accordingly (roughly as sketched below). The outputs are 512-dimensional, just as in ArcFace.
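
For reference, this is roughly how I renamed the graph input with the onnx package (a minimal sketch; the file names are placeholders, and it assumes the input tensor is not also listed as an initializer):

import onnx

model = onnx.load("adaface_r100_webface12m.onnx")  # placeholder file name
old_name = model.graph.input[0].name
new_name = "input.1"

# Rename the graph input and every node input that references it
model.graph.input[0].name = new_name
for node in model.graph.node:
    node.input[:] = [new_name if name == old_name else name for name in node.input]

onnx.save(model, "adaface_r100_webface12m_input1.onnx")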

By the way, AdaFace clarifies:

Note that our pretrained model takes the input in BGR color channel. This is different from the InsightFace released model which uses RGB color channel.

~The main issue is that I can't run it as if it were a custom-trained ArcFace model, because it requires a mean of 0.5 and a std of 0.5, although that is very similar to the 127.5 that is ArcFace's default in this project.~ (Those two normalizations turn out to be numerically identical; see the quick check after the config entry below.) So I thought I would add another entry to config.py like this:

'adaface_torch': {
    'in_package': False,
    'shape': (1, 3, 112, 112),
    'allow_batching': True,
    'function': 'adaface_torch',
    'reshape': False
}
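
On the struck-out normalization note above: a mean/std of 0.5 on a [0, 1] scale is the same operation as 127.5 on a [0, 255] scale, which a quick check confirms (this snippet is only an illustration, not part of the project code):

import numpy as np

pixels = np.random.randint(0, 256, size=(112, 112, 3)).astype(np.float32)
adaface_style = (pixels / 255.0 - 0.5) / 0.5   # scale to [0, 1], then mean 0.5 / std 0.5
project_style = (pixels - 127.5) / 127.5       # input_mean / input_std of 127.5
print(np.allclose(adaface_style, project_style))  # True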

Where I declare the function adaface_torch in model_zoo/face_processors.py as:

# Backend wrapper for PyTorch-trained models, which require image normalization
def adaface_torch(model_path, backend, **kwargs):
    model = backend.Arcface(rec_name=model_path, input_mean=127.5, input_std=127.5, **kwargs)
    return model

And add it to the func_map in model_zoo/getter.py to look like:

func_map = {
    'genderage_v1': genderage_v1,
    'retinaface_r50_v1': retinaface_r50_v1,
    'retinaface_mnet025_v1': retinaface_mnet025_v1,
    'retinaface_mnet025_v2': retinaface_mnet025_v2,
    'mnet_cov2': mnet_cov2,
    'centerface': centerface,
    'dbface': dbface,
    'scrfd': scrfd,
    'scrfd_v2': scrfd_v2,
    'arcface_mxnet': arcface_mxnet,
    'arcface_torch': arcface_torch,
    'mask_detector': mask_detector,
    'yolov5_face': yolov5_face,
    'adaface_torch': adaface_torch,
}

At this point, it errors out when loading the workers with a NameError for adaface_torch. I am not sure where it needs to be defined, so I was wondering if you could shed some light on the general workflow needed to add additional models.

SthPhoenix commented 9 months ago

Hi! I can't figure out what you have missed; it looks like you have done everything needed to add the model.
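
If I had to guess, though: func_map references adaface_torch by its bare name, so the function also has to be imported at the top of model_zoo/getter.py alongside the other processor functions. A minimal sketch, assuming the processors are imported from the face_processors module (the exact import path in the repo may differ):

# model_zoo/getter.py -- sketch only; the exact import path is an assumption
from .face_processors import arcface_mxnet, arcface_torch, adaface_torch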

I have just committed code to support the adaface_ir101_webface12m model, including automatic download of the ONNX file.

Converting the AdaFace model to a compatible ONNX file took a few more steps, though:

  1. Remove the norm output and the embedding normalization from the model's forward pass (see the sketch after this list).
  2. Add dynamic shapes for the input and output to support batch inference.
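
For step 1, the edit in AdaFace's net.py looks roughly like this (a sketch from memory; the exact attribute names in the original forward may differ):

# Sketch of the modified forward in AdaFace's net.py (attribute names approximate)
def forward(self, x):
    x = self.input_layer(x)
    x = self.body(x)
    x = self.output_layer(x)
    # The original forward returned an L2-normalized embedding plus its norm:
    #   norm = torch.norm(x, 2, 1, True)
    #   output = torch.div(x, norm)
    #   return output, norm
    # For InsightFace-REST we export the raw embedding and normalize downstream:
    return x
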
SthPhoenix commented 9 months ago

Here's the conversion code:

import os

import numpy as np
import torch
import onnx

import net

adaface_models = {
    "ir_101": "./pretrained/adaface_ir101_webface12m.ckpt",
}

def load_pretrained_model(architecture="ir_101"):
    # load model and pretrained statedict
    assert architecture in adaface_models.keys()
    model = net.build_model(architecture)
    statedict = torch.load(
        adaface_models[architecture])["state_dict"]
    # Strip the "model." prefix from the checkpoint keys
    model_statedict = {
        key[6:]: val
        for key, val in statedict.items()
        if key.startswith("model.")
    }
    model.load_state_dict(model_statedict)
    model.eval()
    return model

def to_input(pil_rgb_image):
    np_img = np.array(pil_rgb_image)
    # RGB -> BGR, scale to [0, 1], then normalize with mean 0.5 and std 0.5
    bgr_img = ((np_img[:, :, ::-1] / 255.0) - 0.5) / 0.5
    # HWC -> CHW and add a batch dimension
    tensor = torch.tensor([bgr_img.transpose(2, 0, 1)]).float()
    return tensor

if __name__ == "__main__":
    model = load_pretrained_model("ir_101")
    model.eval()
    # Dummy input used only for tracing during export
    x = torch.randn(112, 112, 3)
    x = to_input(x)

    input_names = ['input']
    output_names = ['output']
    # Mark the batch dimension of both input and output as dynamic
    dynamic_axes = {out: {0: '?'} for out in output_names}
    dynamic_axes[input_names[0]] = {0: '?'}

    # Export the model to ONNX with a dynamic batch dimension
    torch.onnx.export(
        model,
        x,
        "adaface_ir101_webface12m.onnx",
        input_names=["input"],
        output_names=["output"],
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        verbose=False,
        dynamic_axes=dynamic_axes,
        opset_version=13,
        export_params=True,
    )

Based on this comment: https://github.com/mk-minchul/AdaFace/issues/43#issuecomment-1714965955
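
A quick way to sanity-check the exported file and the dynamic batch axis, assuming onnxruntime is installed (just an illustration, not part of the conversion script):

import numpy as np
import onnx
import onnxruntime as ort

# Validate the graph, then run a small batch to confirm the dynamic axis works
onnx.checker.check_model("adaface_ir101_webface12m.onnx")
sess = ort.InferenceSession(
    "adaface_ir101_webface12m.onnx", providers=["CPUExecutionProvider"]
)
batch = np.random.rand(4, 3, 112, 112).astype(np.float32)  # batch of 4 random inputs
embeddings = sess.run(["output"], {"input": batch})[0]
print(embeddings.shape)  # expected: (4, 512)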