protossw512 / AdaptiveWingLoss

[ICCV 2019] Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression - Official Implementation
Apache License 2.0

When testing in videos, the results are dithering #9

Closed. 1996scarlet closed this issue 4 years ago.

1996scarlet commented 4 years ago

When testing on videos, the results jitter noticeably.

protossw512 commented 4 years ago

@1996scarlet Hi, did you run face detection first? The model is not designed to localize landmarks in the wild; you have to run a face detector to crop the faces first.
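Since `get_landmarks` resizes whatever crop it receives to 256x256, a non-square detector box gets stretched before it reaches the network. A minimal sketch of turning a detector box into a square crop box, assuming boxes are `[x1, y1, x2, y2]` in pixels (possibly followed by a score). The `square_box` helper and its 0.2 margin are hypothetical choices, not the repository's documented preprocessing; the margin is only meant to help keep jaw and chin landmarks inside the crop.

```python
import numpy as np


def square_box(box, frame_w, frame_h, margin=0.2):
    """Hypothetical helper: expand a detector box [x1, y1, x2, y2, ...] to a
    square around its centre, enlarged by `margin` of the longer side on
    each side, then clip it to the frame."""
    x1, y1, x2, y2 = [float(v) for v in box[:4]]  # ignore a trailing score
    cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
    side = max(x2 - x1, y2 - y1) * (1.0 + 2.0 * margin)
    half = side / 2.0
    # Clipping keeps the crop inside the image, so boxes touching the
    # frame border may end up slightly non-square.
    return np.array([
        max(0.0, cx - half),
        max(0.0, cy - half),
        min(float(frame_w), cx + half),
        min(float(frame_h), cy + half),
    ])
```

For example, `square_box([100, 100, 200, 150], 640, 480)` gives `[80, 55, 220, 195]`, which can then be passed as the `box` argument of `get_landmarks`.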

The algorithm is not designed to use temporal information (each frame is processed independently), so landmarks jitter on videos. You could implement a smoothing algorithm to reduce this effect.
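One widely used option for this kind of jitter is the One Euro filter (Casiez, Roussel and Vogel, CHI 2012), an adaptive low-pass filter: it smooths heavily when landmarks barely move and loosens when they move fast, trading jitter against lag. This is a swapped-in technique, not something the repository provides; the sketch below applies it element-wise to a landmark array, and the default parameter values are placeholders.

```python
import math

import numpy as np


class OneEuroFilter:
    """Element-wise One Euro filter for landmark arrays, e.g. the (N, 2)
    result of get_landmarks(). Keep one instance per tracked face."""

    def __init__(self, min_cutoff=1.0, beta=0.0, d_cutoff=1.0):
        self.min_cutoff = min_cutoff  # Hz; lower means smoother when still
        self.beta = beta              # speed coefficient; higher means less lag
        self.d_cutoff = d_cutoff      # Hz; cutoff for the speed estimate
        self.raw_prev = None          # previous unfiltered landmarks
        self.x_prev = None            # previous filtered landmarks
        self.dx_prev = None           # previous filtered speed

    @staticmethod
    def _alpha(cutoff, dt):
        # Smoothing factor of a first-order low-pass filter; also works
        # element-wise when cutoff is an array.
        tau = 1.0 / (2.0 * math.pi * cutoff)
        return 1.0 / (1.0 + tau / dt)

    def __call__(self, x, dt):
        """Filter one frame of landmarks; dt is the time step in seconds."""
        x = np.asarray(x, dtype=np.float64)
        if self.x_prev is None:
            # First frame: nothing to smooth against yet.
            self.raw_prev, self.x_prev = x.copy(), x.copy()
            self.dx_prev = np.zeros_like(x)
            return x.copy()
        # Low-pass filtered estimate of how fast each coordinate moves.
        dx = (x - self.raw_prev) / dt
        a_d = self._alpha(self.d_cutoff, dt)
        dx_hat = a_d * dx + (1.0 - a_d) * self.dx_prev
        # Fast motion raises the cutoff (less lag); slow motion lowers it
        # (more smoothing of jitter).
        cutoff = self.min_cutoff + self.beta * np.abs(dx_hat)
        a = self._alpha(cutoff, dt)
        x_hat = a * x + (1.0 - a) * self.x_prev
        self.raw_prev, self.x_prev, self.dx_prev = x.copy(), x_hat, dx_hat
        return x_hat
```

Call it once per frame, e.g. `smoothed = smoother(landmarks, 1.0 / fps)`, and create a fresh instance when the face is lost. The filter's authors suggest tuning with `beta = 0` first, lowering `min_cutoff` until the jitter is acceptable, then raising `beta` until the lag on fast motion is acceptable. The same filter could also be applied to the detector box, since box jitter changes the crop from frame to frame.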

1996scarlet commented 4 years ago

@protossw512 Hi, thanks for your reply! I used Retinaface for face detection on the videos. I'll try adding a smoothing algorithm.

flynnamy commented 4 years ago

@1996scarlet Could you please share your test scripts? I will test them with some smoothing algorithms.

1996scarlet commented 4 years ago

@flynnamy Try this:

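# usage
# frame: a BGR image (H x W x 3), e.g. from cv2.VideoCapture.read()
# box: a face box [x1, y1, x2, y2] from a face detector such as Retinaface
from face_alignment import AlignmentorModel
fa = AlignmentorModel(stack=4)
landmarks = fa.get_landmarks(frame, box)  # (N, 2) array of (x, y) in frame pixels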

And the face_alignment.py code:

# coding: utf-8
import os
import torch
import numpy as np
import cv2
import time
import collections
import models

models_urls = {
    'WFLW_4HG': './weights/WFLW_4HG.pth',
}

class AlignmentorModel:
    def __init__(self, device='cuda', stack=4, verbose=False):
        self.device = device
        self.verbose = verbose
        self.stack_num = stack

        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True

        print(f'Using model 2DFAN (Hourglass stack {self.stack_num})')

        PRETRAINED_WEIGHTS = models_urls['WFLW_4HG']
        self.net = models.FAN(self.stack_num)

        if PRETRAINED_WEIGHTS != "None":
            # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
            checkpoint = torch.load(PRETRAINED_WEIGHTS, map_location=device)
            pretrained_weights = checkpoint['state_dict']
            model_weights = self.net.state_dict()
            # Keep only the keys the current model defines, then load them.
            pretrained_weights = {k: v for k, v in pretrained_weights.items()
                                  if k in model_weights}
            model_weights.update(pretrained_weights)
            self.net.load_state_dict(model_weights)

        self.net.to(device)
        self.net.eval()
        torch.set_grad_enabled(False)

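    def crop_face(self, frame, det):
        # Clamp the detector box [x1, y1, x2, y2] to the frame; negative
        # coordinates would otherwise wrap around in NumPy slicing.
        h, w = frame.shape[:2]
        x1, y1 = max(int(det[0]), 0), max(int(det[1]), 0)
        x2, y2 = min(int(det[2]), w), min(int(det[3]), h)
        img = frame[y1:y2, x1:x2, :]
        H, W, _ = img.shape
        # The network outputs 64x64 heatmaps, so one heatmap cell covers
        # (W / 64, H / 64) crop pixels; (x1, y1) maps the crop back to the frame.
        offset = W / 64.0, H / 64.0, x1, y1
        # offset = W / 128.0, H / 128.0, x1, y1  # for 128x128 heatmaps
        return img, offset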

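    def get_landmarks(self, image, det=None):
        # The model expects a face crop, not a full frame.
        if det is None:
            raise ValueError('get_landmarks needs a face box [x1, y1, x2, y2]; '
                             'run a face detector first')
        img, offset = self.crop_face(image, det)
        # Resize the crop to the 256x256 network input and convert BGR -> RGB.
        cropped = cv2.resize(img, (256, 256))[..., ::-1]
        inp = cropped.copy()
        inp = torch.from_numpy(inp.transpose((2, 0, 1))).float()
        inp = inp.to(self.device)
        inp.div_(255.0).unsqueeze_(0)

        # st = time.perf_counter()
        outputs, boundary_channels = self.net(inp)
        # Use the last hourglass stack; its final channel holds the boundary
        # heatmap, so drop it and keep only the landmark heatmaps.
        out = outputs[-1][:, :-1, :, :]
        heatmaps = out.detach().cpu().numpy()
        # print(time.perf_counter() - st)

        pred = self._calculate_points(heatmaps).reshape(-1, 2)

        # Scale heatmap coordinates to crop pixels, then shift to frame coordinates.
        pred *= offset[:2]
        pred += offset[-2:]

        return pred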

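    def _calculate_points(self, heatmaps, center=None, scale=None):
        # Convert heatmaps (B, N, H, W) into (B, N, 2) sub-pixel (x, y)
        # coordinates: the argmax cell plus a quarter-cell shift toward the
        # larger neighbouring value.
        B, N, H, W = heatmaps.shape
        HW = H * W
        BN_range = np.arange(B * N)

        heatline = heatmaps.reshape(B, N, HW)
        indexes = np.argmax(heatline, axis=2)

        preds = np.stack((indexes % W, indexes // W), axis=2)
        # np.float was removed in NumPy 1.24; use np.float64 instead.
        preds = preds.astype(np.float64, copy=False)

        inr = indexes.ravel()
        px, py = inr % W, inr // W

        heatline = heatline.reshape(B*N, HW)
        # Clip neighbour positions so peaks on the heatmap border neither wrap
        # into the adjacent row nor index past the end of the array.
        x_up = heatline[BN_range, py * W + np.minimum(px + 1, W - 1)]
        x_down = heatline[BN_range, py * W + np.maximum(px - 1, 0)]
        y_up = heatline[BN_range, np.minimum(py + 1, H - 1) * W + px]
        y_down = heatline[BN_range, np.maximum(py - 1, 0) * W + px]

        think_diff = np.sign(np.stack((x_up - x_down, y_up - y_down), axis=1))
        think_diff *= .25

        preds += think_diff.reshape(B, N, 2)
        # Shift from the cell corner to the cell centre.
        preds += .5

        return preds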
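To sanity-check the decoding and rescaling without the model weights, here is a minimal single-landmark sketch of the same arithmetic (argmax, quarter-cell shift with border clipping, +0.5, then scaling by crop size / heatmap size and adding the box corner). The `decode_one` name is hypothetical, and it assumes a single `(H, W)` heatmap and a box `[x1, y1, x2, y2]` already clamped to the frame.

```python
import numpy as np


def decode_one(heatmap, box):
    """Hypothetical helper: decode one (H, W) heatmap to frame (x, y)."""
    H, W = heatmap.shape
    y, x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    # Quarter-cell shift toward the larger neighbour, clipped at borders.
    dx = np.sign(heatmap[y, min(x + 1, W - 1)] - heatmap[y, max(x - 1, 0)])
    dy = np.sign(heatmap[min(y + 1, H - 1), x] - heatmap[max(y - 1, 0), x])
    px, py = x + 0.25 * dx + 0.5, y + 0.25 * dy + 0.5
    # One heatmap cell spans (crop_w / W, crop_h / H) frame pixels.
    x1, y1, x2, y2 = box
    return px * (x2 - x1) / W + x1, py * (y2 - y1) / H + y1
```

For a 64x64 heatmap peaking at column 10, row 20, with its right and lower neighbours larger than the opposite ones, the refined cell position is (10.75, 20.75); with a 128x128 box at (100, 50) each cell spans 2 pixels, giving frame coordinates (121.5, 91.5).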

antithing commented 4 years ago

@flynnamy @1996scarlet did you manage to get smooth landmarks on video? Would it be possible to share your code? Thanks!