XgTu / 2DASL

The code (pytorch for testing & matlab for 3D plot and evaluation) for our project: Joint 3D Face Reconstruction and Dense Face Alignment from A Single Image with 2D-Assisted Self-Supervised Learning (2DASL)
MIT License
464 stars 116 forks source link

compare with 3ddfa!!! i think 3ddfa is better。。 #14

Open ZHANG-SHI-CHANG opened 5 years ago

ZHANG-SHI-CHANG commented 5 years ago

3ddfa: test_vertex

2dasl: test_vertex

我感觉3ddfa更好,你认为我测试2dasl时问题会出在哪呢。。

cleardusk commented 5 years ago

What about the running (inference) speed? @ZHANG-SHI-CHANG

ZHANG-SHI-CHANG commented 5 years ago

What about the running (inference) speed? @ZHANG-SHI-CHANG

我没有关注测试时间,不过2dasl使用resnet50作为backbone,我想肯定比你的3ddfa慢很多,毕竟你用的是mobilenetv2,我比较奇怪的是为什么2dsal的表现不如3ddfa,可能我使用的数据的原因或者我写的2dasl测试程序哪里出毛病了?

XgTu commented 5 years ago

@ZHANG-SHI-CHANG Hi, our 2DASL follows the work 3DDFA, and we use resnet50 as backbone, so the running time should be longer than 3DDFA. But our performance is much better than 3DDFA, I am quite sure about that. I think there should be something wrong with your reimplementation, pls check carefully. Thx~ By the way, thanks for the great job of 3DDFA @cleardusk

Light-SH commented 5 years ago

@ZHANG-SHI-CHANG 请问方便分享一下你的测试代码吗

ZHANG-SHI-CHANG commented 5 years ago

@ZHANG-SHI-CHANG Hi, our 2DASL follows the work 3DDFA, and we use resnet50 as backbone, so the running time should be longer than 3DDFA. But our performance is much better than 3DDFA, I am quite sure about that. I think there should be something wrong with your reimplementation, pls check carefully. Thx~ By the way, thanks for the great job of 3DDFA @cleardusk

我会尽快检查并把测试代码发在这,谢谢回复!希望你有时间帮忙看一下和你实现的diff在哪里,十分感谢!!

ZHANG-SHI-CHANG commented 5 years ago
###main
from utils import *

ckpt_path = 'models/2DASL_checkpoint_epoch_allParams_stage2.pth.tar'
ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['res_state_dict']
state_dict = {}
for key, value in ckpt.items():
    if key.startswith('module'):
        state_dict[key[7:]] = value
    else:
        state_dict[key] = value
model = resnet50(pretrained=False, num_classes=62)
model.load_state_dict(state_dict)
model.eval()
#print(model)
transform = transforms.Compose([
                                ToTensorGjz(),
                                NormalizeGjz(mean=127.5, std=128)
                                ])

save_path = os.path.join(os.getcwd(), 'frames_vertex')
if not os.path.exists(save_path):
    os.makedirs(save_path)
frames_path = os.path.join(os.getcwd(), 'frames')

for idx in range(len(glob.glob(os.path.join(frames_path, '*.png')))):
    frame_path = os.path.join(frames_path, '{:05}.png'.format(idx))

    img = cv2.imread(frame_path, cv2.IMREAD_COLOR)
    lmks = get_face_lmks_fan(img)
    #show_lmks(img, lmks)
    #cv2.imwrite('lmks.png', img)
    #exit(0)

    roi_box = parse_roi_box_from_landmark(lmks.T.copy())
    img_crop = crop_img(img, roi_box)
    lmks_crop = crop_lmks(roi_box, lmks)

    lmks_crop = fit_lmks(lmks_crop, img_crop.shape[:2])
    lmks_crop[lmks_crop>119] = 119
    img_crop = cv2.resize(img_crop, dsize=(120, 120), interpolation=cv2.INTER_LINEAR)
    #show_lmks(img_crop, lmks_crop)
    #cv2.imwrite('lmks.png', img_crop)
    #exit(0)

    lmks_map = get_18lmks_map_by_68lmks(lmks_crop.T)
    lmks_map = lmks_map[:,:,np.newaxis]

    lmks_map = torch.from_numpy(lmks_map).unsqueeze(0).permute(0,3,1,2)

    input = transform(img_crop).unsqueeze(0)
    input = torch.cat([input, lmks_map], dim=1)

    with torch.no_grad():
        param = model(input)
        param = param.squeeze().cpu().numpy().flatten().astype(np.float32)
    dense = get_dense_from_param(param, roi_box)
    show_lmks(img, dense.T)
    cv2.imwrite(os.path.join(save_path, '{:05}.png'.format(idx)), img)
    print('complete {}...'.format(idx))
ZHANG-SHI-CHANG commented 5 years ago
###utils.py
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from test_codes.resnet_xgtu_4chls import resnet50
from test_codes.params import *

import numpy as np
import cv2
from math import sqrt
import pickle as pkl

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '5'
import glob
import sys
from tqdm import tqdm

import dlib
dlib_detector = dlib.get_frontal_face_detector()
from tools.FaceBoxes_PyTorch.detector import FaceBoxesDetector
detector = FaceBoxesDetector()
from tools.FAN.api import FAN
fan = FAN(os.path.join('tools', 'FAN', 'model', '2DFAN-4.pth.tar'))

##face and lmks detector
def get_face_box(img, is_fan=False):
    rects = detector.forward(img)
    rect = rects[0]

    if is_fan:
        pass
    else:
        rect = (rect[0], rect[1], rect[2] - rect[0] + 1, rect[3] - rect[1] + 1)
    return rect
def get_face_box_dlib(img, is_fan=False):
    rects = dlib_detector(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
    rect = rects[0]

    if is_fan:
        rect = (rect.left(), rect.top(), rect.right(), rect.bottom())
    else:
        rect = (rect.left(), rect.top(), rect.right() - rect.left() + 1, rect.bottom() - rect.top() + 1)
    return rect
def get_face_lmks_fan(img):
    #box = get_face_box(img, True)
    box = get_face_box_dlib(img, True)
    lmks_68 = fan.detect(img, box)
    return lmks_68

def resize_1000(img):
    h, w = img.shape[:2]
    if max(h, w)>1000:
        img = cv2.resize(img, (1000, int(1000*h/w)))
    return img
def show_box(im, box, color=(0,255,0)):
    x1, y1, w, h = box
    x2 = x1 + w - 1
    y2 = y1 + h - 1
    cv2.rectangle(im, (int(x1),int(y1)), (int(x2), int(y2)), color, 2)
def show_lmks(im, landmarks):
    if isinstance(landmarks, list):
        points = []
        for pair in landmarks:
            x, y = int(pair[0]), int(pair[1])
            points.append([x,y])
        landmarks = np.array(points)

    for i in range(landmarks.shape[0]):
        x, y = landmarks[i,0], landmarks[i,1]
        cv2.circle(im, (int(x), int(y)), 1, (255,255,255), -1)
def show_fan_lmks(im, lmks):
    show_lmks(im, lmks)
    for i in range(0, 17):
        if i<16:
            cv2.line(im, (int(lmks[i][0]), int(lmks[i][1])), (int(lmks[i+1][0]), int(lmks[i+1][1])), (255,255,255), 2)
    for i in range(17, 22):
        if i<21:
            cv2.line(im, (int(lmks[i][0]), int(lmks[i][1])), (int(lmks[i+1][0]), int(lmks[i+1][1])), (255,255,255), 2)
    for i in range(22, 27):
        if i<26:
            cv2.line(im, (int(lmks[i][0]), int(lmks[i][1])), (int(lmks[i+1][0]), int(lmks[i+1][1])), (255,255,255), 2)
    for i in range(27, 31):
        if i<30:
            cv2.line(im, (int(lmks[i][0]), int(lmks[i][1])), (int(lmks[i+1][0]), int(lmks[i+1][1])), (255,255,255), 2)
    for i in range(31, 36):
        if i<35:
            cv2.line(im, (int(lmks[i][0]), int(lmks[i][1])), (int(lmks[i+1][0]), int(lmks[i+1][1])), (255,255,255), 2)
    for i in range(36, 42):
        if i<41:
            cv2.line(im, (int(lmks[i][0]), int(lmks[i][1])), (int(lmks[i+1][0]), int(lmks[i+1][1])), (255,255,255), 2)
    for i in range(42, 48):
        if i<47:
            cv2.line(im, (int(lmks[i][0]), int(lmks[i][1])), (int(lmks[i+1][0]), int(lmks[i+1][1])), (255,255,255), 2)
    for i in range(48, 60):
        if i<59:
            cv2.line(im, (int(lmks[i][0]), int(lmks[i][1])), (int(lmks[i+1][0]), int(lmks[i+1][1])), (255,255,255), 2)
    for i in range(60, 68):
        if i<67:
            cv2.line(im, (int(lmks[i][0]), int(lmks[i][1])), (int(lmks[i+1][0]), int(lmks[i+1][1])), (255,255,255), 2)
def show_img(img, wait=1):
    cv2.namedWindow('debugg', cv2.WINDOW_NORMAL)
    cv2.imshow('debugg', img)
    cv2.waitKey(wait)

def crop_img(img, roi_box):
    h, w = img.shape[:2]

    sx, sy, ex, ey = [int(round(_)) for _ in roi_box]
    dh, dw = ey - sy, ex - sx
    if len(img.shape) == 3:
        res = np.zeros((dh, dw, 3), dtype=np.uint8)
    else:
        res = np.zeros((dh, dw), dtype=np.uint8)
    if sx < 0:
        sx, dsx = 0, -sx
    else:
        dsx = 0

    if ex > w:
        ex, dex = w, dw - (ex - w)
    else:
        dex = dw

    if sy < 0:
        sy, dsy = 0, -sy
    else:
        dsy = 0

    if ey > h:
        ey, dey = h, dh - (ey - h)
    else:
        dey = dh

    res[dsy:dey, dsx:dex] = img[sy:ey, sx:ex]
    return res
def crop_lmks(roi_box, lmks):
    sx, sy, _, _ = [int(round(_)) for _ in roi_box]
    lmks[:, 0] = lmks[:, 0] - sx
    lmks[:, 1] = lmks[:, 1] -sy
    return lmks
def fit_lmks(lmks, original_shape, target_shape=(120, 120)):
    o_h, o_w = original_shape
    t_h, t_w = target_shape
    lmks[:,0], lmks[:,1] = lmks[:,0]*t_w/o_w, lmks[:,1]*t_h/o_h
    return lmks
def parse_roi_box_from_landmark(pts):
    """calc roi box from landmark"""
    bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])]
    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
    radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
    bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius]

    llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2)
    center_x = (bbox[2] + bbox[0]) / 2
    center_y = (bbox[3] + bbox[1]) / 2

    roi_box = [0] * 4
    roi_box[0] = center_x - llength / 2
    roi_box[1] = center_y - llength / 2
    roi_box[2] = roi_box[0] + llength
    roi_box[3] = roi_box[1] + llength

    return roi_box

def get_18lmks_by_68lmks(lms2d_68):
    _18_indx_3d22d = [17, 19, 21, 22, 24, 26, 36, 40, 39, 42, 46, 45, 31, 30, 35, 48, 66, 54]
    lms2d = lms2d_68[:,_18_indx_3d22d]
    lms2d[:,7] = (lms2d_68[:,37] + lms2d_68[:,40])/2
    lms2d[:,10] = (lms2d_68[:,43] + lms2d_68[:,46])/2
    lms2d[:,16] = (lms2d_68[:,62] + lms2d_68[:,66])/2
    return lms2d
def get_18lmks_map_by_68lmks(pts):
    pts = get_18lmks_by_68lmks(pts)
    ptsMap = np.zeros([120, 120]) - 1
    indx = np.int32(np.floor(pts))
    ptsMap[indx[1], indx[0]] = 1
    return ptsMap.astype(np.float32)

def get_dense_from_param(param, roi_box):
    param = param * param_std + param_mean
    p_ = param[:12].reshape(3, -1)
    p = p_[:, :3]
    offect = p_[:, -1].reshape(3, 1)
    alpha_shp = param[12:52].reshape(-1, 1)
    alpha_exp = param[52:].reshape(-1, 1)

    std_size = 120
    vertex = p @ (u + w_shp @ alpha_shp + w_exp @ alpha_exp).reshape(3, -1, order='F') + offect
    vertex[1, :] = std_size + 1 - vertex[1, :]
    sx, sy, ex, ey = roi_box
    scale_x = (ex - sx) / std_size
    scale_y = (ey - sy) / std_size
    vertex[0, :] = vertex[0, :]*scale_x + sx
    vertex[1, :] = vertex[1, :]*scale_y + sy
    s = (scale_x + scale_y) / 2
    vertex[2, :] *= s

    return vertex

class ToTensorGjz(object):
    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img.float()
    def __repr__(self):
        return self.__class__.__name__ + '()'
class NormalizeGjz(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std
    def __call__(self, tensor):
        tensor.sub_(self.mean).div_(self.std)
        return tensor
ZHANG-SHI-CHANG commented 5 years ago

@ZHANG-SHI-CHANG 请问方便分享一下你的测试代码吗

我在utils.py里不会分享faceboxes和fan的使用哦,太多了,只提供实现思路,这部分靠你自己啦,主要参考是@cleardusk大佬的3ddfa哦

Light-SH commented 5 years ago

@ZHANG-SHI-CHANG 谢谢你的分享

ZHANG-SHI-CHANG commented 5 years ago

@ZHANG-SHI-CHANG 谢谢你的分享

人脸检测你也可以直接用dlib,但是关键点最好用FAN,我测试dlib,FAN和自己的关键点模型,FAN表现的最好,但也没有3ddfa稳定,我想这个输入语义的影响还是蛮大的

ZHANG-SHI-CHANG commented 5 years ago

@ZHANG-SHI-CHANG Hi, our 2DASL follows the work 3DDFA, and we use resnet50 as backbone, so the running time should be longer than 3DDFA. But our performance is much better than 3DDFA, I am quite sure about that. I think there should be something wrong with your reimplementation, pls check carefully. Thx~ By the way, thanks for the great job of 3DDFA @cleardusk

3ddfa performance in your paper has a low score, i re-training it and got 3.99 in AFLW2000-3D, by the way, can you give some advice about my test code? thx~ for your great job

CodingMice commented 4 years ago

@ZHANG-SHI-CHANG 谢谢你的分享

人脸检测你也可以直接用dlib,但是关键点最好用FAN,我测试dlib,FAN和自己的关键点模型,FAN表现的最好,但也没有3ddfa稳定,我想这个输入语义的影响还是蛮大的

我测试了3ddfa和2dasl,但感觉对于视频输入都不太稳定,有什么好的解决思路吗?

ZHANG-SHI-CHANG commented 4 years ago

@ZHANG-SHI-CHANG 谢谢你的分享

人脸检测你也可以直接用dlib,但是关键点最好用FAN,我测试dlib,FAN和自己的关键点模型,FAN表现的最好,但也没有3ddfa稳定,我想这个输入语义的影响还是蛮大的

我测试了3ddfa和2dasl,但感觉对于视频输入都不太稳定,有什么好的解决思路吗?

如果是抖动问题就使用平滑策略吧,你的测试效果哪个好一些呢

CodingMice commented 4 years ago

@ZHANG-SHI-CHANG 谢谢你的分享

人脸检测你也可以直接用dlib,但是关键点最好用FAN,我测试dlib,FAN和自己的关键点模型,FAN表现的最好,但也没有3ddfa稳定,我想这个输入语义的影响还是蛮大的

我测试了3ddfa和2dasl,但感觉对于视频输入都不太稳定,有什么好的解决思路吗?

如果是抖动问题就使用平滑策略吧,你的测试效果哪个好一些呢

是平滑检测框吗?我测试的时候大侧脸抖动的比较明显。单凭肉眼看也看不太出来哪个更好些,有什么指标可以衡量3Dmesh抖动程度的么?

ZHANG-SHI-CHANG commented 4 years ago

@ZHANG-SHI-CHANG 谢谢你的分享

人脸检测你也可以直接用dlib,但是关键点最好用FAN,我测试dlib,FAN和自己的关键点模型,FAN表现的最好,但也没有3ddfa稳定,我想这个输入语义的影响还是蛮大的

我测试了3ddfa和2dasl,但感觉对于视频输入都不太稳定,有什么好的解决思路吗?

如果是抖动问题就使用平滑策略吧,你的测试效果哪个好一些呢

是平滑检测框吗?我测试的时候大侧脸抖动的比较明显。单凭肉眼看也看不太出来哪个更好些,有什么指标可以衡量3Dmesh抖动程度的么?

平滑关键点或者pose、形状基系数、表情基系数,你可能不是用关键点扩充作为输入的,可以参考上面的测试代码

HOMGH commented 4 years ago

@ZHANG-SHI-CHANG Hi, our 2DASL follows the work 3DDFA, and we use resnet50 as backbone, so the running time should be longer than 3DDFA. But our performance is much better than 3DDFA, I am quite sure about that. I think there should be something wrong with your reimplementation, pls check carefully. Thx~ By the way, thanks for the great job of 3DDFA @cleardusk

3ddfa performance in your paper has a low score, i re-training it and got 3.99 in AFLW2000-3D, by the way, can you give some advice about my test code? thx~ for your great job

Hi, Could you please share the training code for this paper? Thanks.

ZHANG-SHI-CHANG commented 4 years ago

@ZHANG-SHI-CHANG Hi, our 2DASL follows the work 3DDFA, and we use resnet50 as backbone, so the running time should be longer than 3DDFA. But our performance is much better than 3DDFA, I am quite sure about that. I think there should be something wrong with your reimplementation, pls check carefully. Thx~ By the way, thanks for the great job of 3DDFA @cleardusk

3ddfa performance in your paper has a low score, i re-training it and got 3.99 in AFLW2000-3D, by the way, can you give some advice about my test code? thx~ for your great job

Hi, Could you please share the training code for this paper? Thanks.

this paper have no training code, i re-training 3ddfa to 3.99