junjiehe96 / UniPortrait

UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization
Apache License 2.0

About implementation details of Qformer #5

Closed 963658029 closed 2 months ago

963658029 commented 2 months ago

Thanks for your great work! Can you share the code for Qformer combined with CurricularFace and CLIP? I want to see how it is implemented without model weights.

963658029 commented 2 months ago

Are all the Attention Blocks in the illustrated QFormer cross-attention? Are the Q the learnable queries, and are K and V the CLIP features or face features? How many MLP layers does the FFN include, and does it use dropout and normalization? Do the Attention Blocks include an FFN, dropout, and normalization? [QFormer architecture figure attached]

junjiehe96 commented 2 months ago

Hi, thank you for your interest in our work. Our QFormer implementation is modified from https://github.com/tencent-ailab/IP-Adapter/blob/62e4af9d0c1ac7d5f8dd386a0ccf2211346af1a2/ip_adapter/resampler.py#L81; here is the implementation code. More details will be released later.

import torch

# PerceiverAttention and FeedForward follow IP-Adapter's ip_adapter/resampler.py,
# with PerceiverAttention modified to accept an optional attention mask
# (see the sketch after this class).


class UniPortraitFaceIDResampler(torch.nn.Module):
    def __init__(
            self,
            intrinsic_id_embedding_dim=512,
            structure_embedding_dim=64+128+256+1280,
            num_tokens=16,
            depth=6,
            dim=768,
            dim_head=64,
            heads=12,
            ff_mult=4,
            output_dim=768,
    ):
        super().__init__()

        # Learnable query tokens, initialized with variance scaled by 1/dim.
        self.latents = torch.nn.Parameter(torch.randn(1, num_tokens, dim) / dim ** 0.5)

        # Projects the CurricularFace intrinsic ID embeddings to the latent dimension.
        self.proj_id = torch.nn.Sequential(
            torch.nn.Linear(intrinsic_id_embedding_dim, intrinsic_id_embedding_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(intrinsic_id_embedding_dim * 2, dim),
        )
        # Projects the multi-level CLIP (structure) embeddings to the latent dimension.
        self.proj_clip = torch.nn.Sequential(
            torch.nn.Linear(structure_embedding_dim, structure_embedding_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(structure_embedding_dim * 2, dim),
        )

        # Each block: cross-attention to the ID features, cross-attention to the
        # structure features, then a feed-forward layer (all with residual connections).
        self.layers = torch.nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                torch.nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.proj_out = torch.nn.Linear(dim, output_dim)
        self.norm_out = torch.nn.LayerNorm(output_dim)

    def forward(
            self,
            intrinsic_id_embeds,
            structure_embeds,
            structure_scale=1.0,
            intrinsic_id_attention_mask=None,
            structure_attention_mask=None
    ):

        latents = self.latents.repeat(intrinsic_id_embeds.size(0), 1, 1)

        intrinsic_id_embeds = self.proj_id(intrinsic_id_embeds)
        structure_embeds = self.proj_clip(structure_embeds)

        for attn1, attn2, ff in self.layers:
            # Cross-attend to the intrinsic ID features, then to the structure
            # features (weighted by structure_scale), each with a residual connection.
            latents = attn1(intrinsic_id_embeds, latents, intrinsic_id_attention_mask) + latents
            latents = structure_scale * attn2(structure_embeds, latents, structure_attention_mask) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)
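
For completeness, here is a minimal sketch of the two helper modules referenced above. It assumes they follow IP-Adapter's resampler.py, with PerceiverAttention extended to accept an optional key mask; the mask convention (True/1 = keep, False/0 = drop) and the exact masking logic are assumptions, not the released implementation.

# Sketch only: IP-Adapter-style modules plus an assumed optional attention mask.
import torch


def FeedForward(dim, mult=4):
    # LayerNorm -> Linear -> GELU -> Linear, as in IP-Adapter's resampler.
    inner_dim = int(dim * mult)
    return torch.nn.Sequential(
        torch.nn.LayerNorm(dim),
        torch.nn.Linear(dim, inner_dim, bias=False),
        torch.nn.GELU(),
        torch.nn.Linear(inner_dim, dim, bias=False),
    )


class PerceiverAttention(torch.nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = torch.nn.LayerNorm(dim)  # normalizes the context features
        self.norm2 = torch.nn.LayerNorm(dim)  # normalizes the latent queries

        self.to_q = torch.nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = torch.nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = torch.nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, attention_mask=None):
        # x:       (b, n1, dim) context tokens (ID or structure features)
        # latents: (b, n2, dim) learnable query tokens
        # attention_mask: (b, n1) boolean mask over x; None means full attention
        b, n2, _ = latents.shape
        x = self.norm1(x)
        latents = self.norm2(latents)

        q = self.to_q(latents)
        # As in IP-Adapter, keys/values cover the context tokens and the latents.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        def split_heads(t):
            return t.reshape(b, t.shape[1], self.heads, -1).transpose(1, 2)  # (b, h, n, d)

        q, k, v = map(split_heads, (q, k, v))

        weight = (q @ k.transpose(-2, -1)) * self.scale  # (b, h, n2, n1 + n2)
        if attention_mask is not None:
            # Latent tokens are always attended to; masked context tokens are dropped.
            keep = torch.cat(
                [attention_mask.bool(),
                 torch.ones(b, n2, dtype=torch.bool, device=latents.device)], dim=1)
            weight = weight.masked_fill(~keep[:, None, None, :], float("-inf"))
        weight = weight.softmax(dim=-1)

        out = weight @ v  # (b, h, n2, dim_head)
        out = out.transpose(1, 2).reshape(b, n2, -1)
        return self.to_out(out)
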
963658029 commented 2 months ago

Great! Thanks for your reply. I would also like to know the role of "structure_attention_mask" and "intrinsic_id_attention_mask": why not use full attention? It seems that IP-Adapter's PerceiverAttention does not have this parameter, and your arXiv paper does not cover it.

963658029 commented 2 months ago

Also, can you share the input shapes of intrinsic_id_embeds and structure_embeds? e.g., (b=, c=, h=, w=). Thanks.

junjiehe96 commented 2 months ago

The structure_attention_mask is used to implement the drop-token operation, and the intrinsic_id_attention_mask is None (i.e., full attention over the ID tokens);

The shape of intrinsic_id_embeds is (b, l=7*7, d=512), and the shape of structure_embeds is (b, l=16*16, d=64+128+256+1280).
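
To make these shapes concrete, a hypothetical call could look like the following; the batch size, keep ratio, and boolean mask convention are illustrative assumptions, not values from the paper or the released code.

# Hypothetical usage based on the shapes above; not from the released code.
import torch

resampler = UniPortraitFaceIDResampler()

b = 2
intrinsic_id_embeds = torch.randn(b, 7 * 7, 512)                   # CurricularFace ID features
structure_embeds = torch.randn(b, 16 * 16, 64 + 128 + 256 + 1280)  # multi-level CLIP features

# Drop-token: randomly keep a subset of structure tokens (the keep ratio here is a guess).
structure_attention_mask = torch.rand(b, 16 * 16) < 0.7

id_tokens = resampler(
    intrinsic_id_embeds,
    structure_embeds,
    structure_scale=1.0,
    intrinsic_id_attention_mask=None,              # full attention over ID tokens
    structure_attention_mask=structure_attention_mask,
)
print(id_tokens.shape)  # torch.Size([2, 16, 768])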

963658029 commented 2 months ago

Thanks for your reply. Keep in touch!