Closed by 1996scarlet 4 years ago
@1996scarlet Hi, did you perform face detection first? The model is not designed to localize landmarks in the wild; you have to run a face detector and crop the faces first.
The algorithm is not designed to take temporal information into account, so the jittering effect exists on videos. You might implement some smoothing algorithm to smooth out this effect.
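For example, here is a minimal smoothing sketch (just an illustration, not part of this repo; it assumes the landmarks come back as an (N, 2) NumPy array each frame and uses a simple exponential moving average):

```python
import numpy as np


class LandmarkSmoother:
    """Exponential moving average over per-frame landmark predictions."""

    def __init__(self, alpha=0.4):
        # alpha: weight of the new frame; smaller = smoother but laggier.
        self.alpha = alpha
        self.state = None

    def __call__(self, landmarks):
        landmarks = np.asarray(landmarks, dtype=np.float64)
        if self.state is None:
            self.state = landmarks
        else:
            self.state = self.alpha * landmarks + (1.0 - self.alpha) * self.state
        return self.state
```

A fixed-alpha EMA trades responsiveness for stability; an adaptive filter such as the One Euro filter adjusts its cutoff with motion speed and usually keeps fast motion sharp while still suppressing jitter.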
@protossw512 Hi, thanks for your reply! RetinaFace has been used for face detection in the videos. I'll try to add a smoothing algorithm.
@1996scarlet Could you please share your test scripts? I will test with some smoothing algorithms.
@flynnamy Try this:

```python
# Usage
from face_alignment import AlignmentorModel

fa = AlignmentorModel(stack=4)
landmarks = fa.get_landmarks(frame, box)
```

And here is the `face_alignment.py` code:
```python
# coding: utf-8
import os
import torch
import numpy as np
import cv2
import time
import collections

import models

models_urls = {
    'WFLW_4HG': './weights/WFLW_4HG.pth',
}


class AlignmentorModel:
    def __init__(self, device='cuda', stack=4, verbose=False):
        self.device = device
        self.verbose = verbose
        self.stack_num = stack

        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True

        print(f'Using model 2DFAN (Hourglass stack {self.stack_num})')

        PRETRAINED_WEIGHTS = models_urls['WFLW_4HG']
        self.net = models.FAN(self.stack_num)

        if PRETRAINED_WEIGHTS != "None":
            # Load only the pretrained weights whose names match the current model.
            checkpoint = torch.load(PRETRAINED_WEIGHTS)
            pretrained_weights = checkpoint['state_dict']
            model_weights = self.net.state_dict()
            pretrained_weights = {k: v for k, v in pretrained_weights.items()
                                  if k in model_weights}
            model_weights.update(pretrained_weights)
            self.net.load_state_dict(model_weights)

        self.net.to(device)
        self.net.eval()
        torch.set_grad_enabled(False)

    def crop_face(self, frame, det):
        # det is an [x1, y1, x2, y2] detection box.
        img = frame[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :]
        H, W, _ = img.shape
        # Scale factors from the 64x64 heatmap back to the crop, plus the crop origin.
        offset = W / 64.0, H / 64.0, det[0], det[1]
        # offset = W / 128.0, H / 128.0, det[0], det[1]
        return img, offset

    def get_landmarks(self, image, det=None):
        img, offset = self.crop_face(image, det)
        # Resize to the 256x256 network input and convert BGR -> RGB.
        cropped = cv2.resize(img, (256, 256))[..., ::-1]
        inp = cropped.copy()
        inp = torch.from_numpy(inp.transpose((2, 0, 1))).float()
        inp = inp.to(self.device)
        inp.div_(255.0).unsqueeze_(0)

        # st = time.perf_counter()
        outputs, boundary_channels = self.net(inp)
        out = outputs[-1][:, :-1, :, :]
        heatmaps = out.detach().cpu().numpy()
        # print(time.perf_counter() - st)

        # Map heatmap coordinates back to the original image.
        pred = self._calculate_points(heatmaps).reshape(-1, 2)
        pred *= offset[:2]
        pred += offset[-2:]
        return pred

    def _calculate_points(self, heatmaps, center=None, scale=None):
        # Argmax over each heatmap, refined by a quarter-pixel shift towards
        # the larger of the two neighbouring activations in each direction.
        B, N, H, W = heatmaps.shape
        HW = H * W
        BN_range = np.arange(B * N)

        heatline = heatmaps.reshape(B, N, HW)
        indexes = np.argmax(heatline, axis=2)

        preds = np.stack((indexes % W, indexes // W), axis=2)
        preds = preds.astype(float, copy=False)

        inr = indexes.ravel()
        heatline = heatline.reshape(B * N, HW)
        x_up = heatline[BN_range, inr + 1]
        x_down = heatline[BN_range, inr - 1]
        y_up = heatline[BN_range, inr + W]
        y_down = heatline[BN_range, inr - W]

        think_diff = np.sign(np.stack((x_up - x_down, y_up - y_down), axis=1))
        think_diff *= .25

        preds += think_diff.reshape(B, N, 2)
        preds += .5
        return preds
```
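For reference, a rough per-frame loop to combine this with a detector and a simple EMA smoother (a sketch only; `detect_largest_face` and `input.mp4` are placeholders, standing in for whatever detector you use, e.g. RetinaFace, which is assumed to return an [x1, y1, x2, y2] box):

```python
import cv2
import numpy as np

from face_alignment import AlignmentorModel


def detect_largest_face(frame):
    # Placeholder: plug in your own detector (e.g. RetinaFace) here and
    # return a single [x1, y1, x2, y2] box, or None if no face is found.
    raise NotImplementedError


fa = AlignmentorModel(stack=4)
prev_landmarks = None
alpha = 0.4  # EMA weight for the new frame

cap = cv2.VideoCapture('input.mp4')  # hypothetical input path
while True:
    ok, frame = cap.read()
    if not ok:
        break

    box = detect_largest_face(frame)
    if box is None:
        continue

    landmarks = fa.get_landmarks(frame, box)
    if prev_landmarks is not None:
        # Blend with the previous frame to damp per-frame jitter.
        landmarks = alpha * landmarks + (1.0 - alpha) * prev_landmarks
    prev_landmarks = landmarks

    for x, y in landmarks:
        cv2.circle(frame, (int(x), int(y)), 1, (0, 255, 0), -1)
    cv2.imshow('landmarks', frame)
    if cv2.waitKey(1) == 27:  # Esc to quit
        break

cap.release()
cv2.destroyAllWindows()
```

Note that the detection box itself also jitters between frames, so smoothing (or simply reusing) the box can help as much as smoothing the landmarks.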
@flynnamy @1996scarlet did you manage to get smooth landmarks on video? Would it be possible to share your code? Thanks!
When testing on videos, the results jitter noticeably.