axinc-ai / ailia-models

The collection of pre-trained, state-of-the-art AI models for ailia SDK

ADD LivePortrait #1506

Open · kyakuno opened 1 month ago

kyakuno commented 1 month ago

https://github.com/KwaiVGI/LivePortrait

kyakuno commented 1 month ago

Test https://huggingface.co/spaces/KwaiVGI/LivePortrait

https://github.com/axinc-ai/ailia-models/assets/38881041/50e12738-295b-44e8-9781-262443d05df0

kyakuno commented 1 month ago

@ooe1123 Thank you for handling GroundingDINO and SAM. After SAM, if possible, I'd be glad if you could take this one on as well.

kyakuno commented 1 month ago

Facial keypoints are extracted from both the input image and the reference image; the input image's keypoints are then adjusted by a network so that they move toward the reference image's keypoints, and the adjusted keypoints together with the input image are fed into the warping module to transform the image. The keypoint adjustment is performed independently for the whole face, the eyes, and the lips.
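
A minimal numpy sketch of that flow (the 21-keypoint count matches LivePortrait's implicit keypoints, but the region index sets and the blend rule here are illustrative stand-ins, not the actual networks):

import numpy as np

NUM_KP = 21  # LivePortrait uses 21 implicit 3-D keypoints

def get_keypoints(image: np.ndarray) -> np.ndarray:
    # Stand-in for the motion extractor: one (NUM_KP, 3) keypoint set per image.
    rng = np.random.default_rng(abs(int(image.sum())) % 2**32)
    return rng.standard_normal((NUM_KP, 3)).astype(np.float32)

def retarget(kp_in, kp_ref, idx, alpha=0.5):
    # Move only one region's keypoints (face / eyes / lip) toward the reference.
    kp_out = kp_in.copy()
    kp_out[idx] += alpha * (kp_ref[idx] - kp_in[idx])
    return kp_out

def animate(input_image, reference_image, warp_fn):
    kp_in = get_keypoints(input_image)
    kp_ref = get_keypoints(reference_image)
    kp_new = kp_in
    # Correct each region independently, as described above (index sets are made up).
    for region_idx in (slice(0, NUM_KP), [11, 13, 15], [6, 12, 14]):
        kp_new = retarget(kp_new, kp_ref, region_idx)
    # warp_fn stands in for the warping module + decoder.
    return warp_fn(input_image, kp_in, kp_new)

out = animate(np.zeros((256, 256, 3), np.float32), np.ones((256, 256, 3), np.float32),
              warp_fn=lambda img, kp0, kp1: img)  # identity warp stub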

kyakuno commented 1 month ago

I'd like to use an image captured in video mode as the reference and apply the deformation in real time.

yuananf commented 1 month ago

https://github.com/warmshao/FasterLivePortrait

ooe1123 commented 1 month ago

For each model below, the original code in src/live_portrait_wrapper.py is shown first, followed by the modified version used to export the corresponding ONNX file.

appearance_feature_extractor.onnx

〇 src/live_portrait_wrapper.py

class LivePortraitWrapper(object):
    ...
    def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                feature_3d = self.appearance_feature_extractor(x)

class LivePortraitWrapper(object):
    ...
    def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                if 1:  # modified: export the module to ONNX once, then stop
                    print("------>")
                    torch.onnx.export(
                        self.appearance_feature_extractor, x, 'appearance_feature_extractor.onnx',
                        input_names=["x"],
                        output_names=["f_s"],
                        verbose=False, opset_version=17
                    )
                    print("<------")
                    exit()
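
As a quick sanity check, the exported file can be run with onnxruntime. A hedged sketch: the 1x3x256x256 input shape is an assumption based on LivePortrait's 256x256 crop, so adjust it if the export used a different dummy input.

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("appearance_feature_extractor.onnx",
                            providers=["CPUExecutionProvider"])
x = np.random.rand(1, 3, 256, 256).astype(np.float32)
(f_s,) = sess.run(None, {"x": x})  # single output named "f_s"
print("f_s:", f_s.shape)  # 3-D appearance feature volume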

motion_extractor.onnx

〇 src/live_portrait_wrapper.py

class LivePortraitWrapper(object):
    ...
    def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                kp_info = self.motion_extractor(x)

class LivePortraitWrapper(object):
    ...
    def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                kp_info = self.motion_extractor(x)

            if 1:  # modified: wrap the module so the output dict becomes named tensors
                class Exp(torch.nn.Module):
                    def __init__(self, motion_extractor):
                        super().__init__()
                        self.motion_extractor = motion_extractor

                    def forward(self, x):
                        kp_info = self.motion_extractor(x)
                        return kp_info["pitch"], kp_info["yaw"], kp_info["roll"], kp_info["t"], kp_info["exp"], kp_info["scale"], kp_info["kp"]

                print("------>")
                model = Exp(self.motion_extractor)
                torch.onnx.export(
                    model, x, 'motion_extractor.onnx',
                    input_names=["x"],
                    output_names=["pitch", "yaw", "roll", "t", "exp", "scale", "kp"],
                    verbose=False, opset_version=17
                )
                print("<------")
                exit()
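
Because the wrapper flattens the dict into seven named tensors, consumer code has to rebuild kp_info from the session outputs. A hedged sketch, reusing the 1x3x256x256 input assumption from above:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("motion_extractor.onnx",
                            providers=["CPUExecutionProvider"])
x = np.random.rand(1, 3, 256, 256).astype(np.float32)
names = ["pitch", "yaw", "roll", "t", "exp", "scale", "kp"]
kp_info = dict(zip(names, sess.run(names, {"x": x})))  # back to a dict
print({k: v.shape for k, v in kp_info.items()})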

stitching.onnx

〇 src/live_portrait_wrapper.py

class LivePortraitWrapper(object):
    ...
    def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            delta = self.stitching_retargeting_module['stitching'](feat_stiching)

class LivePortraitWrapper(object):
    ...
    def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            if 1:  # modified: export the stitching MLP once, then stop
                print("------>")
                torch.onnx.export(
                    self.stitching_retargeting_module['stitching'], feat_stiching, 'stitching.onnx',
                    input_names=["x"],
                    output_names=["out"],
                    verbose=False, opset_version=17
                )
                print("<------")
                exit()
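
For reference, feat_stiching is the concatenation of the flattened source and driving keypoints, and the resulting delta is applied back onto kp_driving. A hedged sketch of driving the export (the 126-dim input, 65-dim output, and the 63+2 split are assumptions based on a 21-keypoint model):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("stitching.onnx", providers=["CPUExecutionProvider"])
kp_source = np.random.rand(1, 21, 3).astype(np.float32)
kp_driving = np.random.rand(1, 21, 3).astype(np.float32)
feat = np.concatenate([kp_source.reshape(1, -1), kp_driving.reshape(1, -1)], axis=1)
(delta,) = sess.run(None, {"x": feat})
kp_new = kp_driving + delta[:, :63].reshape(1, 21, 3)  # per-keypoint offset
kp_new[..., :2] += delta[:, 63:65].reshape(1, 1, 2)    # global x/y shift
print("kp_new:", kp_new.shape)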

warping_module.onnx

〇 src/live_portrait_wrapper.py

class LivePortraitWrapper(object):
    ...
    def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                ...
                ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)

class LivePortraitWrapper(object):
    ...
    def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            if 1:  # modified: export via a wrapper that unpacks the output dict
                class Exp(torch.nn.Module):
                    def __init__(self, warping_module):
                        super().__init__()
                        self.warping_module = warping_module

                    def forward(self, feature_3d, kp_source, kp_driving):
                        ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
                        return ret_dct["out"], ret_dct["occlusion_map"], ret_dct["deformation"]

                with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
                    print("------>")
                    model = Exp(self.warping_module)
                    x = (feature_3d, kp_source, kp_driving)
                    torch.onnx.export(
                        model, x, 'warping_module.onnx',
                        input_names=["feature_3d", "kp_source", "kp_driving"],
                        output_names=["out", "occlusion_map", "deformation"],
                        verbose=False, opset_version=20  # 5-D grid_sample needs GridSample from opset 20
                    )
                    print("<------")
                    exit()
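
The warping module takes three named inputs. A hedged sketch of running the export, where the (1, 32, 16, 64, 64) feature_3d shape is an assumption based on the appearance extractor's output volume (dtypes may need adjusting if the export kept half precision):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("warping_module.onnx", providers=["CPUExecutionProvider"])
feed = {
    "feature_3d": np.random.rand(1, 32, 16, 64, 64).astype(np.float32),
    "kp_source": np.random.rand(1, 21, 3).astype(np.float32),
    "kp_driving": np.random.rand(1, 21, 3).astype(np.float32),
}
out, occlusion_map, deformation = sess.run(None, feed)
print("out:", out.shape, "occlusion_map:", occlusion_map.shape)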

spade_generator.onnx

〇 src/live_portrait_wrapper.py

class LivePortraitWrapper(object):
    ...
    def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                ...
                ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)

                # decode
                ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])

class LivePortraitWrapper(object):
    ...
    def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                ...
                ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)

                if 1:  # modified: export on CPU in fp32, outside autocast half precision
                    print("------>")
                    torch.onnx.export(
                        self.spade_generator.cpu(), ret_dct['out'].cpu().type(torch.float32), 'spade_generator.onnx',
                        input_names=["feature"],
                        output_names=["out"],
                        verbose=False, opset_version=17
                    )
                    print("<------")
                    exit()
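
Once both files exist, the decoder can be chained after the warping module. A hedged end-to-end sketch, with all shapes carried over from the assumptions above:

import numpy as np
import onnxruntime as ort

warp = ort.InferenceSession("warping_module.onnx", providers=["CPUExecutionProvider"])
spade = ort.InferenceSession("spade_generator.onnx", providers=["CPUExecutionProvider"])
feed = {
    "feature_3d": np.random.rand(1, 32, 16, 64, 64).astype(np.float32),
    "kp_source": np.random.rand(1, 21, 3).astype(np.float32),
    "kp_driving": np.random.rand(1, 21, 3).astype(np.float32),
}
warped, _, _ = warp.run(None, feed)  # "out" feature before decoding
(image,) = spade.run(None, {"feature": warped.astype(np.float32)})
print("decoded:", image.shape)  # generated portrait frame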