Status: Open — kyakuno opened this issue 1 month ago
@ooe1123 GroundingDINOとSAMの対応、ありがとうございます。SAMの後、可能であればこちらをお願いできると嬉しいです。
入力画像とリファレンス画像の顔のキーポイントを取得、入力画像のキーポイントをリファレンス画像のキーポイントに近づけるようにAIで補正、変形したキーポイントと入力画像をワープモジュールに入れて画像変換する。顔全体、目、リップで独立してキーポイントの補正をしている。
ビデオモードで入力した画像をリファレンスとして、リアルタイムに変形したい。
〇 src/live_portrait_wrapper.py
class LivePortraitWrapper(object):
...
def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
...
with torch.no_grad():
with torch.autocast(...):
feature_3d = self.appearance_feature_extractor(x)
↓
class LivePortraitWrapper(object):
...
def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
...
with torch.no_grad():
with torch.autocast(...):
if 1:
print("------>")
torch.onnx.export(
self.appearance_feature_extractor, x, 'appearance_feature_extractor.onnx',
input_names=["x"],
output_names=["f_s"],
verbose=False, opset_version=17
)
print("<------")
exit()
〇 src/live_portrait_wrapper.py
class LivePortraitWrapper(object):
...
def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
...
with torch.no_grad():
with torch.autocast(...):
kp_info = self.motion_extractor(x)
↓
class LivePortraitWrapper(object):
...
def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
...
with torch.no_grad():
with torch.autocast(...):
kp_info = self.motion_extractor(x)
if 1:
class Exp(torch.nn.Module):
def __init__(self, motion_extractor):
super().__init__()
self.motion_extractor = motion_extractor
def forward(self, x):
kp_info = self.motion_extractor(x)
return kp_info["pitch"], kp_info["yaw"], kp_info["roll"], kp_info["t"], kp_info["exp"], kp_info["scale"], kp_info["kp"]
print("------>")
model = Exp(self.motion_extractor)
torch.onnx.export(
model, x, 'motion_extractor.onnx',
input_names=["x"],
output_names=["pitch", "yaw", "roll", "t", "exp", "scale", "kp"],
verbose=False, opset_version=17
)
print("<------")
exit()
〇 src/live_portrait_wrapper.py
class LivePortraitWrapper(object):
...
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
...
with torch.no_grad():
delta = self.stitching_retargeting_module['stitching'](feat_stiching)
↓
class LivePortraitWrapper(object):
...
def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
...
with torch.no_grad():
if 1:
print("------>")
torch.onnx.export(
self.stitching_retargeting_module['stitching'], feat_stiching, 'stitching.onnx',
input_names=["x"],
output_names=["out"],
verbose=False, opset_version=17
)
print("<------")
exit()
〇 src/live_portrait_wrapper.py
class LivePortraitWrapper(object):
...
def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
...
with torch.no_grad():
with torch.autocast(...):
...
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
↓
class LivePortraitWrapper(object):
...
def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
...
with torch.no_grad():
if 1:
class Exp(torch.nn.Module):
def __init__(self, warping_module):
super().__init__()
self.warping_module = warping_module
def forward(self, feature_3d, kp_source, kp_driving):
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
return ret_dct["out"], ret_dct["occlusion_map"], ret_dct["deformation"]
with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
print("------>")
model = Exp(self.warping_module)
x = (feature_3d, kp_source, kp_driving)
torch.onnx.export(
model, x, 'warping_module.onnx',
input_names=["feature_3d", "kp_source", "kp_driving"],
output_names=["out", "occlusion_map", "deformation"],
verbose=False, opset_version=20
)
print("<------")
exit()
〇 src/live_portrait_wrapper.py
class LivePortraitWrapper(object):
...
def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
...
with torch.no_grad():
with torch.autocast(...):
...
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
# decode
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
↓
class LivePortraitWrapper(object):
...
def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
...
with torch.no_grad():
with torch.autocast(...):
...
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
if 1:
print("------>")
torch.onnx.export(
self.spade_generator.cpu(), ret_dct['out'].cpu().type(torch.float32), 'spade_generator.onnx',
input_names=["feature"],
output_names=["out"],
verbose=False, opset_version=17
)
print("<------")
exit()
https://github.com/KwaiVGI/LivePortrait