Hi, great work! I have a question about the code implementation of the cross-domain attention module.
In the previous version of the code, this part appears to implement cross-domain attention, and it matches the cross-domain design described in the paper:
```python
if multiview_attention:
    key = rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
    value = rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
    if cross_domain_attention:
        # memory efficient, cross domain attention
        key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2)  # keys shape (b t) d c
        value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
        key_cross = torch.concat([key_1, key_0], dim=0)
        value_cross = torch.concat([value_1, value_0], dim=0)  # shape (b t) d c
        key = torch.cat([key, key_cross], dim=1)
        value = torch.cat([value, value_cross], dim=1)  # shape (b t) (t+1 d) c
else:
    # print("don't use multiview attention.")
    key = key_raw
    value = value_raw
```
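To check my reading of the shapes, here is a minimal NumPy sketch of the previous branch. NumPy stands in for torch/einops, `prev_cross_domain_kv` is a hypothetical name, and I assume the two domains occupy the first and second halves of the batch, as the `torch.chunk(..., chunks=2)` call suggests. Under that assumption, each view ends up with the tokens of all views in its own domain plus the same view from the other domain:

```python
import numpy as np

def prev_cross_domain_kv(key_raw, num_views, multiview=True, cross_domain=True):
    """NumPy mirror of the previous key construction (value is built the same way).

    ASSUMPTION: key_raw has shape ((b*t), d, c) and the batch axis holds
    domain 0 in its first half and domain 1 in its second half.
    """
    if not multiview:
        return key_raw
    bt, d, c = key_raw.shape
    b = bt // num_views
    # "(b t) d c -> b (t d) c": put all views of one sample on a single token axis
    key = key_raw.reshape(b, num_views * d, c)
    # .repeat_interleave(num_views, dim=0): one copy per view -> ((b*t), t*d, c)
    key = np.repeat(key, num_views, axis=0)
    if cross_domain:
        # torch.chunk(..., chunks=2, dim=0), then swap the two halves
        key_0, key_1 = np.split(key_raw, 2, axis=0)
        key_cross = np.concatenate([key_1, key_0], axis=0)
        # append the other domain's tokens for the same view -> ((b*t), (t+1)*d, c)
        key = np.concatenate([key, key_cross], axis=1)
    return key

b, t, d, c = 2, 4, 3, 5  # b counts both domains: one sample per domain
key_raw = np.arange(b * t * d * c, dtype=np.float32).reshape(b * t, d, c)
out = prev_cross_domain_kv(key_raw, num_views=t)
print(out.shape)  # (8, 15, 5)
```

Row 0 (domain 0, view 0) ends with the tokens of row 4 (domain 1, view 0), which is the cross-domain part.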
But in the current code I can't find this part, and `my_repeat` is used instead. Did you delete the code that implements cross-domain attention, or do you achieve cross-domain attention with another strategy? Either way, could you please explain the reason? The current version reads:
```python
if multiview_attention:
    if not sparse_mv_attention:
        key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
        value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
    else:
        key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)  # [(b t), d, c]
        value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
        key = torch.cat([key_front, key_raw], dim=1)  # shape (b t) (2 d) c
        value = torch.cat([value_front, value_raw], dim=1)
else:
    # print("don't use multiview attention.")
    key = key_raw
    value = value_raw
```
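For comparison, here is a similar NumPy sketch of the current branch. The definition of `my_repeat` is not shown here, so `my_repeat_assumed` below assumes it behaves like `repeat_interleave(n, dim=0)`; `current_kv` is a hypothetical name. With the same assumed batch layout (domain 0 in the first half, domain 1 in the second), every row of both the dense and the sparse keys draws tokens from a single half of the batch, so within this snippet alone nothing seems to cross domains:

```python
import numpy as np

def my_repeat_assumed(x, n):
    # ASSUMPTION: the repo's my_repeat is not shown in this issue; it is taken
    # here to behave like torch's repeat_interleave(n, dim=0).
    return np.repeat(x, n, axis=0)

def current_kv(key_raw, num_views, sparse_mv_attention):
    """NumPy mirror of the current key construction (value is built the same way)."""
    bt, d, c = key_raw.shape
    b = bt // num_views
    if not sparse_mv_attention:
        # dense: every view sees all t views of its own sample -> ((b*t), t*d, c)
        return my_repeat_assumed(key_raw.reshape(b, num_views * d, c), num_views)
    # sparse: take view 0 ("front") of each sample -> (b, d, c), copy it to each view
    key_front = my_repeat_assumed(key_raw.reshape(b, num_views, d, c)[:, 0], num_views)
    # each view sees the front view plus itself -> ((b*t), 2*d, c)
    return np.concatenate([key_front, key_raw], axis=1)

b, t, d, c = 2, 4, 3, 5  # b counts both domains: one sample per domain
key_raw = np.arange(b * t * d * c, dtype=np.float32).reshape(b * t, d, c)
dense = current_kv(key_raw, t, sparse_mv_attention=False)
sparse = current_kv(key_raw, t, sparse_mv_attention=True)
print(dense.shape, sparse.shape)  # (8, 12, 5) (8, 6, 5)
```

Whether cross-domain mixing now happens somewhere else in the current code can't be told from this excerpt.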
Waiting for your answer, thanks a lot!!