Jungduri / MLPaperReivew

0 stars 0 forks source link

HandOccNet: Occlusion-Robust 3D Hand Mesh Estimation Network #1

Open Jungduri opened 1 year ago

Jungduri commented 1 year ago

HandOccNet: Occlusion-Robust 3D Hand Mesh Estimation Network

official repo: https://github.com/namepllet/HandOccNet

Introduction

HandOccNet

image

Backbone → Feature injecting transformer (FIT) → Self-Enhancing transformer (SET) → Regressor

Backbone

Softmax-based attention module

image

Fs로부터 Fp와 가장 관련이 있는 정보를 추출함. Occlusion을 야기하는 특정 정보는 이 모듈을 통해서 Fs와 Fp의 강한 상관관계를 표현할 수 있음.

Sigmoid-based attention module

Figure (g)와 같이, softmax는 값을 상대적으로 정규화(mapping)하기 때문에 실제로는 관련이 적은 위치에서도 correlation이 불필요하게 커질 수 있음. 이렇게 커져버린 correlation을 걸러주는 모듈이 필요하여 sigmoid-based attention module을 사용함.

Feature injection

query 정보를 residual connection과 함께 output 단계에서 사용하는 기존 transformer와 달리, 여기서는 query 정보는 사라지고 value 정보가 비워진 곳으로 투입되기 때문에 injection이라는 용어를 사용함.

image

위의 그림처럼 고전적인 transformer는 q,k,v를 입력으로 하는 multi-head에서 나오는 출력과 residual을 사용하지만 이와 달리 HandOccNet은 primary feature의 value만 사용

Self-Enhancing transformer (SET)

image

F_fit에 self-attention module을 적용함 (Figure 5 참조). 전형적인 self-attention과 마찬가지로, 모든 key와 query는 적어도 자기 자신과 한 번은 연관성을 갖게 됨.

Regressor

Hand mesh를 추출하기 위해 SET의 출력 feature를 MANO pose parameter(48)와 shape parameter(10)로 mapping하여 mesh를 얻음. Regressor는 single-block hourglass, 4개의 residual block, 그리고 fully connected layer로 구성됨.

Experiments

FIT and SET

image

Architecture of FIT

image image

고찰

Jungduri commented 1 year ago

전체 Forward

# The two transformers are constructed elsewhere as:
#    FIT = Transformer(injection=True) # feature injecting transformer
#    SET = Transformer(injection=False) # self enhancing transformer

def forward(self, inputs, targets, meta_info, mode):
    """Run HandOccNet end to end: backbone -> FIT -> SET -> MANO regressor.

    Args:
        inputs: dict whose ``'img'`` entry is passed to the backbone.
        targets: ground-truth dict; only read in ``'train'`` mode, where it
            must provide ``'mano_pose'``, ``'mano_shape'`` and ``'joints_img'``.
        meta_info: not used by this function.
        mode: ``'train'`` returns a dict of weighted losses; any other value
            returns a dict with ``'joints_coord_cam'`` and ``'mesh_coord_cam'``.
    """
    p_feats, s_feats = self.backbone(inputs['img'])  # primary, secondary feats
    # FIT takes the secondary features as query and the primary as key;
    # SET then attends the injected features to themselves.
    feats = self.FIT(s_feats, p_feats)
    feats = self.SET(feats, feats)

    is_train = mode == 'train'
    # Ground-truth MANO params (pose and shape concatenated) are only built for training.
    gt_mano_params = (
        torch.cat([targets['mano_pose'], targets['mano_shape']], dim=1) if is_train else None
    )
    pred_mano_results, gt_mano_results, preds_joints_img = self.regressor(feats, gt_mano_params)

    if is_train:
        # loss functions
        return {
            'mano_verts': cfg.lambda_mano_verts * F.mse_loss(pred_mano_results['verts3d'], gt_mano_results['verts3d']),
            'mano_joints': cfg.lambda_mano_joints * F.mse_loss(pred_mano_results['joints3d'], gt_mano_results['joints3d']),
            'mano_pose': cfg.lambda_mano_pose * F.mse_loss(pred_mano_results['mano_pose'], gt_mano_results['mano_pose']),
            'mano_shape': cfg.lambda_mano_shape * F.mse_loss(pred_mano_results['mano_shape'], gt_mano_results['mano_shape']),
            'joints_img': cfg.lambda_joints_img * F.mse_loss(preds_joints_img[0], targets['joints_img']),
        }

    # test output -- previously a bare ``return`` discarded this dict and
    # inference always yielded None.
    return {
        'joints_coord_cam': pred_mano_results['joints3d'],
        'mesh_coord_cam': pred_mano_results['verts3d'],
    }

Transformer

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

class Transformer(nn.Module):
    """Stack of ``Block`` layers with an optional fusion head.

    Every ``Block`` is built with the same ``injection`` flag. With
    ``injection=True`` (FIT), the block output is concatenated with ``key``
    along the channel axis and fused back to ``dim`` channels by the sum of
    a 3x3-conv branch (conv -> ReLU -> conv) and a 1x1-conv branch. With
    ``injection=False`` (SET), the stacked blocks' output is returned as is.

    Args:
        inp_res: accepted for interface compatibility but not used by this
            class and not forwarded to ``Block``.
        dim: channel count of ``query``/``key`` and of the output.
        depth: number of stacked ``Block`` layers.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden width relative to ``dim``.
        injection: build the FIT (True) or SET (False) variant.
    """

    def __init__(self, inp_res=32, dim=256, depth=2, num_heads=4, mlp_ratio=4., injection=True):
        super().__init__()

        self.injection = injection

        self.layers = nn.ModuleList(
            Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, injection=injection)
            for _ in range(depth)
        )

        if self.injection:
            # Fusion head: maps cat([key, output]) (2*dim channels) back to dim.
            self.conv1 = nn.Sequential(
                nn.Conv2d(dim*2, dim, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(dim, dim, 3, padding=1),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(dim*2, dim, 1, padding=0),
            )

    def forward(self, query, key):
        """Run all blocks on ``query`` (attending to ``key``), then fuse if FIT.

        Assumes ``query`` and ``key`` are (B, dim, H, W) maps -- TODO confirm
        against the backbone; the fusion concat requires matching B, H, W.
        """
        output = query
        for layer in self.layers:
            output = layer(query=output, key=key)

        if self.injection:
            fused = torch.cat([key, output], dim=1)
            output = self.conv1(fused) + self.conv2(fused)

        return output

class Mlp(nn.Module):
    """Two-layer feed-forward network.

    Computes ``drop(fc2(drop(act(fc1(x)))))``. ``hidden_features`` and
    ``out_features`` fall back to ``in_features`` when falsy. Weights use
    Xavier-uniform init and biases a near-zero normal (std 1e-6).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        # A single dropout module is shared by both dropout points.
        self.drop = nn.Dropout(drop)
        self._init_weights()

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

    def _init_weights(self):
        linears = (self.fc1, self.fc2)
        # Order matters for reproducibility: both weights first, then both biases.
        for linear in linears:
            nn.init.xavier_uniform_(linear.weight)
        for linear in linears:
            nn.init.normal_(linear.bias, std=1e-6)

class Attention(nn.Module):
    """Parameter-free multi-head scaled dot-product attention.

    Takes already-projected (B, N, C) token tensors and splits ``C`` into
    ``num_heads`` contiguous chunks of ``C // num_heads`` channels. When
    ``use_sigmoid`` is true, a second query/key pair yields one gate per
    query token -- the sigmoid of the sum of its scaled dot products over
    all keys -- which rescales that token's softmax weights.
    """

    def __init__(self, dim, num_heads=1):
        super().__init__()
        self.num_heads = num_heads
        # 1/sqrt(head_dim) temperature, shared by both score computations.
        self.scale = (dim // num_heads) ** -0.5
        self.sigmoid = nn.Sigmoid()

    def _split_heads(self, tokens, batch, length, channels):
        """Reshape (B, N, C) tokens into (B, heads, N, C // heads)."""
        per_head = channels // self.num_heads
        return tokens.reshape(batch, length, self.num_heads, per_head).permute(0, 2, 1, 3)

    def forward(self, query, key, value, query2, key2, use_sigmoid):
        """Attend ``query`` over ``key``/``value``; optionally gate with ``query2``/``key2``.

        All token tensors are reshaped using ``query``'s (B, N, C), so they
        must hold the same number of elements. Returns a (B, N, C) tensor.
        """
        batch, length, channels = query.shape
        q, k, v = (self._split_heads(t, batch, length, channels) for t in (query, key, value))
        weights = torch.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)

        if use_sigmoid:
            q_gate = self._split_heads(query2, batch, length, channels)
            k_gate = self._split_heads(key2, batch, length, channels)
            gate_logits = ((q_gate @ k_gate.transpose(-2, -1)) * self.scale).sum(dim=-1)
            weights = weights * self.sigmoid(gate_logits).unsqueeze(3)

        mixed = weights @ v  # (B, heads, N, C // heads)
        return mixed.transpose(1, 2).reshape(batch, length, channels)

class Block(nn.Module):
    """One attention + MLP layer over a (B, dim, H, W) feature map.

    Query and key maps get learned positional embeddings added, are
    projected by 1x1 convolutions, and are flattened into H*W tokens.

    * ``injection=True`` (FIT): a second query/key projection pair feeds the
      sigmoid gate of ``Attention``, and the attention output *replaces* the
      query tokens (no residual connection).
    * ``injection=False`` (SET): plain softmax attention added to the query
      tokens as a residual.

    Both variants finish with a pre-norm residual MLP.

    Args:
        dim: channel count of the query/key maps and of every projection.
        num_heads: attention heads passed to ``Attention``.
        mlp_ratio: MLP hidden width relative to ``dim``.
        act_layer: activation class used inside the MLP.
        norm_layer: normalization class applied before the MLP.
        injection: select FIT (True) or SET (False) behaviour.
        inp_res: spatial size (H == W) of the learned positional embeddings;
            inputs must match it. Defaults to 32, the value previously
            hard-coded.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., act_layer=nn.GELU, norm_layer=nn.LayerNorm, injection=True, inp_res=32):
        super().__init__()

        self.injection = injection

        self.channels = dim

        self.encode_value = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)
        self.encode_query = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)
        self.encode_key = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)

        if self.injection:
            # Extra projections that only feed the sigmoid gate.
            self.encode_query2 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)
            self.encode_key2 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)

        self.attn = Attention(dim, num_heads=num_heads)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
        # Learned positional embeddings. These were hard-coded to 256 channels,
        # which broke every dim != 256; they now follow ``dim`` / ``inp_res``.
        self.q_embedding = nn.Parameter(torch.randn(1, dim, inp_res, inp_res))
        self.k_embedding = nn.Parameter(torch.randn(1, dim, inp_res, inp_res))

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def _tokens(self, feature_map, b):
        """Flatten a (b, channels, H, W) map into (b, H*W, channels) tokens."""
        # reshape (not view) so non-contiguous inputs are accepted too.
        return feature_map.reshape(b, self.channels, -1).permute(0, 2, 1)

    def forward(self, query, key, query_embed=None, key_embed=None):
        """Apply the block.

        Args:
            query: (B, dim, H, W) map supplying the attention queries and,
                for SET, the residual stream.
            key: map with the same shape supplying keys and values.
            query_embed, key_embed: NOTE(review): ignored -- the learned
                ``q_embedding``/``k_embedding`` always override them.

        Returns:
            A (B, dim, H, W) tensor.
        """
        b, c, h, w = query.shape
        query_embed = repeat(self.q_embedding, '() n c d -> b n c d', b = b)
        key_embed = repeat(self.k_embedding, '() n c d -> b n c d', b = b)

        q_embed = self.with_pos_embed(query, query_embed)
        k_embed = self.with_pos_embed(key, key_embed)

        # Values are projected from the raw key map (no positional embedding).
        v = self._tokens(self.encode_value(key), b)
        q = self._tokens(self.encode_query(q_embed), b)
        k = self._tokens(self.encode_key(k_embed), b)

        if self.injection:
            q2 = self._tokens(self.encode_query2(q_embed), b)
            k2 = self._tokens(self.encode_key2(k_embed), b)
            # FIT: no residual -- the attended values replace the query tokens.
            query = self.attn(query=q, key=k, value=v, query2=q2, key2=k2, use_sigmoid=True)
        else:
            # SET: standard residual attention.
            query = self._tokens(query, b)
            query = query + self.attn(query=q, key=k, value=v, query2=None, key2=None, use_sigmoid=False)

        query = query + self.mlp(self.norm2(query))
        return query.permute(0, 2, 1).reshape(b, self.channels, h, w)