xxlong0 / Wonder3D

Single Image to 3D using Cross-Domain Diffusion for 3D Generation
https://www.xxlong.site/Wonder3D/
GNU Affero General Public License v3.0

The code implementation of cross-domain attention #161

Open A-pril opened 2 months ago

A-pril commented 2 months ago

Hi, great work! I have some questions about the code implementation of the cross-domain attention module. In a previous version of the code, the snippet below appears to implement cross-domain attention, matching the design described in the paper:

    if multiview_attention:
        # multiview attention: flatten each sample's views into one token sequence,
        # then repeat it so every view's queries attend over all views
        key = rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
        value = rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)

        if cross_domain_attention:
            # memory-efficient cross-domain attention: swap the two domain halves
            # of the batch and append them as extra key/value tokens
            key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2)  # each (b t) d c
            value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
            key_cross = torch.concat([key_1, key_0], dim=0)
            value_cross = torch.concat([value_1, value_0], dim=0)  # (b t) d c
            key = torch.cat([key, key_cross], dim=1)
            value = torch.cat([value, value_cross], dim=1)  # (b t) ((t+1) d) c
    else:
        # print("don't use multiview attention.")
        key = key_raw
        value = value_raw
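
To make sure I read this correctly: assuming the batch stacks the two domains (normals and colors) along dim 0, the chunk-and-swap pairs each query with the other domain's keys/values. Here is a toy reproduction of the shape bookkeeping (dummy tensors, arbitrary sizes, not the repo code):

    import torch
    from einops import rearrange

    b, t, d, c = 2, 4, 16, 8                # samples, views, tokens, channels (arbitrary)
    key_raw = torch.randn(2 * b * t, d, c)  # two domains stacked along dim 0 (my assumption)
    value_raw = torch.randn(2 * b * t, d, c)

    # multiview attention: each view's queries see all views of its own sample
    key = rearrange(key_raw, "(b t) d c -> b (t d) c", t=t).repeat_interleave(t, dim=0)
    value = rearrange(value_raw, "(b t) d c -> b (t d) c", t=t).repeat_interleave(t, dim=0)

    # cross-domain: swap the two batch halves so each query also sees the other domain
    key_0, key_1 = torch.chunk(key_raw, chunks=2, dim=0)
    value_0, value_1 = torch.chunk(value_raw, chunks=2, dim=0)
    key = torch.cat([key, torch.cat([key_1, key_0], dim=0)], dim=1)
    value = torch.cat([value, torch.cat([value_1, value_0], dim=0)], dim=1)

    print(key.shape)  # torch.Size([16, 80, 8]) == (2*b*t, (t+1)*d, c)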

But in the current code I can't find this part; `my_repeat` is used instead. Did you delete the code that achieves cross-domain attention, or do you achieve it with a different strategy? Either way, could you explain the reasoning, please?

    if multiview_attention:
        if not sparse_mv_attention:
            key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
            value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
        else:
            # sparse: each view attends only to itself plus the front view (index 0)
            key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)  # [(b t), d, c]
            value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
            key = torch.cat([key_front, key_raw], dim=1)  # (b t) (2 d) c
            value = torch.cat([value_front, value_raw], dim=1)
    else:
        # print("don't use multiview attention.")
        key = key_raw
        value = value_raw
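
If I understand correctly, `my_repeat` just repeats a tensor along the batch dimension (my assumption; I haven't traced its exact definition), so the sparse branch hands the front view's keys/values to every view. A toy sketch of my reading (`my_repeat_sketch` is my own stand-in, not the repo helper):

    import torch
    from einops import rearrange, repeat

    def my_repeat_sketch(tensor: torch.Tensor, num_repeats: int) -> torch.Tensor:
        # assumed behavior: repeat each batch row consecutively along dim 0,
        # like tensor.repeat_interleave(num_repeats, dim=0)
        return repeat(tensor, "b d c -> (b v) d c", v=num_repeats)

    b, t, d, c = 1, 4, 16, 8
    key_raw = torch.randn(b * t, d, c)

    # take the front view's keys and give one copy to every view
    key_front = my_repeat_sketch(
        rearrange(key_raw, "(b t) d c -> b t d c", t=t)[:, 0, :, :], t
    )  # (b*t, d, c)
    key = torch.cat([key_front, key_raw], dim=1)
    print(key.shape)  # torch.Size([4, 32, 8]) == (b*t, 2*d, c)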

Waiting for your answer, thanks a lot!

onpix commented 2 months ago

Looks like it is implemented as `XFormersJointAttnProcessor` in the same file.
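
For anyone else landing here, the rough idea of joint attention (a conceptual sketch, not the exact code of `XFormersJointAttnProcessor`; names and shapes below are made up) is to keep the queries per-domain but let both domains attend over the concatenation of the two domains' keys and values:

    import torch
    import torch.nn.functional as F

    def joint_cross_domain_attention(q, k, v):
        # q, k, v: (2b, heads, d, c); domain 0 in the first b rows, domain 1 after
        k0, k1 = torch.chunk(k, 2, dim=0)
        v0, v1 = torch.chunk(v, 2, dim=0)
        k_joint = torch.cat([k0, k1], dim=2)            # (b, heads, 2d, c)
        v_joint = torch.cat([v0, v1], dim=2)
        k_joint = torch.cat([k_joint, k_joint], dim=0)  # same shared KV for both domains
        v_joint = torch.cat([v_joint, v_joint], dim=0)
        return F.scaled_dot_product_attention(q, k_joint, v_joint)

    b, heads, d, c = 2, 4, 16, 8
    q, k, v = (torch.randn(2 * b, heads, d, c) for _ in range(3))
    print(joint_cross_domain_attention(q, k, v).shape)  # torch.Size([4, 4, 16, 8])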