cleardusk / 3DDFA_V2

The official PyTorch implementation of Towards Fast, Accurate and Stable 3D Dense Face Alignment, ECCV 2020.
MIT License

How to infer on batch #108

Open rzamarefat opened 3 years ago

rzamarefat commented 3 years ago

Hi, thank you for this awesome implementation. Is it possible to infer on a batch of images at the same time?

cleardusk commented 3 years ago

I think it is easy to run on batches with small modifications, if you are familiar with PyTorch.
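
For example (a sketch of my own, not code from the repo), if the faces are already cropped to the network's 120x120 input, batching amounts to stacking the transformed crops and running the backbone once; each output row still needs the param_mean/param_std re-scaling shown later in this thread:

import torch

# hypothetical sketch: `tddfa` is an instance of the repo's TDDFA class;
# `cropped_images` is a list of 120x120x3 face crops
batch = torch.stack([tddfa.transform(img) for img in cropped_images])  # (N, 3, 120, 120)
with torch.no_grad():
    params = tddfa.model(batch)  # (N, 62) raw 3DMM parameters, still normalized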

gitlabspy commented 2 years ago

> I think it is easy to run on batches with small modifications, if you are familiar with PyTorch.

But the model requires cropped images. Any suggestions on batched inference for aligned (but not cropped) images?

JuanFMontesinos commented 2 years ago

I did a batched version for a single box. It is intended for videos where there is a single person whose face stays inside the box the whole time. Most of the functions can be rewritten to work with batches in PyTorch.

gitlabspy commented 2 years ago

> I did a batched version for a single box. It is intended for videos where there is a single person whose face stays inside the box the whole time. Most of the functions can be rewritten to work with batches in PyTorch.

Sounds cool! Can you share these funcs with us?

JuanFMontesinos commented 2 years ago

After ECCV :)


JuanFMontesinos commented 2 years ago
"""
Source code from cleardusk
Optimized by Juan F. Montesinos to work with batches on GPU
"""
import os.path as osp
import torch
from torch import nn
from torchvision.transforms import Compose

import models
from bfm import BFMModel
from utils.io import _load
from utils.functions import (
    crop_video, reshape_fortran, parse_roi_box_from_bbox,
)
from utils.tddfa_util import (
    load_model, _batched_parse_param, batched_similar_transform,
    ToTensorGjz, NormalizeGjz
)

make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn)

class TDDFA(nn.Module):
    """TDDFA: named Three-D Dense Face Alignment (TDDFA)"""

    def __init__(self, **kvs):
        super().__init__()  # required so nn.Module tracks submodules and buffers
        self.size = kvs.get('size', 120)

        # load BFM
        self.bfm = BFMModel(
            bfm_fp=kvs.get('bfm_fp', make_abs_path('configs/bfm_noneck_v3.pkl')),
            shape_dim=kvs.get('shape_dim', 40),
            exp_dim=kvs.get('exp_dim', 10)
        )
        self.tri = self.bfm.tri

        param_mean_std_fp = kvs.get(
            'param_mean_std_fp', make_abs_path(f'configs/param_mean_std_62d_{self.size}x{self.size}.pkl')
        )

        # load model; the default output is a 62-dim vector: 12 (pose) + 40 (shape) + 10 (expression)
        model = getattr(models, kvs.get('arch'))(
            num_classes=kvs.get('num_params', 62),
            widen_factor=kvs.get('widen_factor', 1),
            size=self.size,
            mode=kvs.get('mode', 'small')
        )
        model = load_model(model, kvs.get('checkpoint_fp'))

        self.model = model

        # data normalization
        self.transform_normalize = NormalizeGjz(mean=127.5, std=128)
        transform_to_tensor = ToTensorGjz()
        transform = Compose([transform_to_tensor, self.transform_normalize])
        self.transform = transform

        # params normalization config; register as buffers so .to(device) moves them with the model
        r = _load(param_mean_std_fp)
        self.register_buffer('param_mean', torch.from_numpy(r.get('mean')))
        self.register_buffer('param_std', torch.from_numpy(r.get('std')))

    def batched_inference(self, video_ori, bbox, **kvs):
        """The main call of TDDFA: given a video tensor and a box, return 3DMM params and roi_box
        :param video_ori: the input video tensor, shape (frames, C, H, W)
        :param bbox: left, top, right, bottom (think in lines like y=25, not points)
        :param kvs: options
        :return: param tensor and roi_box list
        """
        roi_box = parse_roi_box_from_bbox(bbox)
        video = crop_video(video_ori, roi_box)
        img = torch.nn.functional.interpolate(video, size=(self.size, self.size), mode='bilinear', align_corners=False)

        inp = self.transform_normalize(img)
        param = self.model(inp)

        param = param * self.param_std + self.param_mean  # re-scale

        return param, roi_box

    def batched_recon_vers(self, param, roi_box, **kvs):
        dense_flag = kvs.get('dense_flag', False)
        size = self.size
        R, offset, alpha_shp, alpha_exp = _batched_parse_param(param)
        if dense_flag:
            tensor = self.bfm.u + self.bfm.w_shp @ alpha_shp + self.bfm.w_exp @ alpha_exp
        else:
            tensor = self.bfm.u_base + self.bfm.w_shp_base @ alpha_shp + self.bfm.w_exp_base @ alpha_exp
        pts3d = R @ reshape_fortran(tensor, (param.shape[0], 3, -1)) + offset
        pts3d = batched_similar_transform(pts3d, roi_box, size)

        return pts3d

So it's basically a matter of getting rid of the framework the author proposes. I find it's not very well designed, since PyTorch automatically lets you work on CPU or GPU just by writing a couple of words (e.g. `.to('cuda')`).
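
For reference, a hypothetical usage sketch of the class above (the architecture name and checkpoint path follow the repo's mb1 config; the video tensor is a toy stand-in, and the BFM arrays are assumed to have been converted to torch tensors as discussed at the end of this thread):

# hypothetical usage, on CPU for simplicity; on GPU, call .to('cuda') on the
# module and keep the video tensor and BFM arrays on the same device
tddfa = TDDFA(arch='mobilenet', checkpoint_fp='weights/mb1_120x120.pth')
video = torch.rand(16, 3, 480, 640)                     # (frames, C, H, W) dummy clip
bbox = [150.0, 100.0, 400.0, 350.0]                     # left, top, right, bottom
param, roi_box = tddfa.batched_inference(video, bbox)   # param: (16, 62)
pts3d = tddfa.batched_recon_vers(param, roi_box)        # (16, 3, 68) sparse landmarks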

Besides, a few helper functions are needed, like a Fortran-order reshape, which is not implemented in PyTorch:

def reshape_fortran(x, shape):
    # emulate numpy's reshape(..., order='F'): reverse the dims, do a C-order reshape, reverse back
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

def crop_video(video, roi_box):
    bs, c, h, w = video.shape

    sx, sy, ex, ey = [int(round(_)) for _ in roi_box]
    dh, dw = ey - sy, ex - sx
    res = torch.zeros(bs, c, dh, dw, dtype=video.dtype, device=video.device)

    if sx < 0:
        sx, dsx = 0, -sx
    else:
        dsx = 0

    if ex > w:
        ex, dex = w, dw - (ex - w)
    else:
        dex = dw

    if sy < 0:
        sy, dsy = 0, -sy
    else:
        dsy = 0

    if ey > h:
        ey, dey = h, dh - (ey - h)
    else:
        dey = dh

    res[..., dsy:dey, dsx:dex] = video[..., sy:ey, sx:ex]
    return res

def batched_similar_transform(pts3d, roi_box, size):
    pts3d[:, 0, :] -= 1  # for Python compatibility
    pts3d[:, 2, :] -= 1
    pts3d[:, 1, :] = size - pts3d[:, 1, :]

    sx, sy, ex, ey = roi_box
    scale_x = (ex - sx) / size
    scale_y = (ey - sy) / size
    pts3d[:, 0, :] = pts3d[:, 0, :] * scale_x + sx
    pts3d[:, 1, :] = pts3d[:, 1, :] * scale_y + sy
    s = (scale_x + scale_y) / 2
    pts3d[:, 2, :] *= s
    pts3d[:, 2, :] -= torch.min(pts3d[:, 2, :], dim=-1)[0].unsqueeze(-1)
    return pts3d.contiguous()

def _batched_parse_param(param):
    """matrix pose form
    param: shape=(trans_dim+shape_dim+exp_dim,), i.e., 62 = 12 + 40 + 10
    """

    assert param.ndim == 2
    bs, n = param.shape
    if n == 62:
        trans_dim, shape_dim, exp_dim = 12, 40, 10
    elif n == 72:
        trans_dim, shape_dim, exp_dim = 12, 40, 20
    elif n == 141:
        trans_dim, shape_dim, exp_dim = 12, 100, 29
    else:
        raise Exception('Undefined templated param parsing rule')

    R_ = param[:, :trans_dim].reshape(bs, 3, -1)
    R = R_[..., :3]
    offset = R_[..., -1].reshape(bs, 3, 1)
    alpha_shp = param[:, trans_dim:trans_dim + shape_dim].reshape(bs, -1, 1)
    alpha_exp = param[:, trans_dim + shape_dim:].reshape(bs, -1, 1)

    return R, offset, alpha_shp, alpha_exp
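
As a quick sanity check (my own, not from the thread), reshape_fortran should agree with NumPy's column-major reshape:

import numpy as np
import torch

x = torch.arange(24).reshape(2, 3, 4)
out = reshape_fortran(x, (4, 3, 2))
ref = np.reshape(x.numpy(), (4, 3, 2), order='F')
assert np.array_equal(out.numpy(), ref)  # both follow Fortran (column-major) ordering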

Lastly, rewrite some pieces as nn.Modules so the automatic device allocation works (see the sketch below).


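As an illustration of that last point (my own sketch, not part of 3DDFA_V2; the attribute names u, w_shp, w_exp, u_base, w_shp_base, w_exp_base follow the repo's BFMModel), the BFM arrays can be registered as buffers so .to(device) carries them along with the network weights:

import torch
from torch import nn

class BFMModule(nn.Module):
    # hypothetical thin wrapper around the repo's BFMModel: registering the
    # morphable-model arrays as buffers lets .to(device) move them automatically
    def __init__(self, bfm):
        super().__init__()
        for name in ('u', 'w_shp', 'w_exp', 'u_base', 'w_shp_base', 'w_exp_base'):
            self.register_buffer(name, torch.as_tensor(getattr(bfm, name)))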
And that's more or less all.

JuanFMontesinos commented 2 years ago

I'll eventually upload the code ready to use, but I have other priorities at the moment.