andreapdr closed this issue 9 months ago.
The usage example is at https://github.com/microsoft/unilm/blob/master/beit3/modeling_utils.py#L21
Thank you for your prompt response, @donglixp.
From my understanding, it is the MultiwayNetwork that is supposed to process the visual modality (V-FFN), the textual modality (L-FFN), or both (VL-FFN). It routes the information to the corresponding FFN (A, B, or both) according to the split_position attribute.
What I can't understand is how and where you set split_position (or, in BeIT3, multiway_split_position) so that multimodal information is passed to both the A and B FFNs only in the top three layers, while it is routed to either A or B in the lower ones.
```python
import copy

import torch
import torch.nn as nn


class MultiwayNetwork(nn.Module):
    def __init__(self, module, dim=1):
        super().__init__()
        self.dim = dim
        self.A = module  # handles positions before split_position
        self.B = copy.deepcopy(module)  # independent copy for the remaining positions
        self.B.reset_parameters()
        self.split_position = -1

    def forward(self, x, **kwargs):
        # -1: the whole input goes to A; 0: the whole input goes to B.
        if self.split_position == -1:
            return self.A(x, **kwargs)
        if self.split_position == 0:
            return self.B(x, **kwargs)
        # Otherwise split along self.dim: [0, split_position) -> A, the rest -> B.
        x1, x2 = torch.split(
            x,
            [self.split_position, x.size(self.dim) - self.split_position],
            dim=self.dim,
        )
        # x1, x2 = x[:self.split_position], x[self.split_position:]
        y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
        return torch.cat([y1, y2], dim=self.dim)
```
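To see the routing in isolation, here is a minimal self-contained sketch that repeats the class above and sets split_position on every multiway submodule through nn.Module.apply. The set_split_position helper is written here for illustration only (an assumption, not quoted from torchscale), and nn.Linear stands in for the real FFN.

```python
import copy

import torch
import torch.nn as nn


class MultiwayNetwork(nn.Module):
    # Same routing logic as the class above: positions before split_position go to A, the rest to B.
    def __init__(self, module, dim=1):
        super().__init__()
        self.dim = dim
        self.A = module
        self.B = copy.deepcopy(module)
        self.B.reset_parameters()
        self.split_position = -1

    def forward(self, x, **kwargs):
        if self.split_position == -1:
            return self.A(x, **kwargs)
        if self.split_position == 0:
            return self.B(x, **kwargs)
        x1, x2 = torch.split(
            x, [self.split_position, x.size(self.dim) - self.split_position], dim=self.dim
        )
        return torch.cat([self.A(x1, **kwargs), self.B(x2, **kwargs)], dim=self.dim)


def set_split_position(position):
    # Hypothetical helper: returns a function for nn.Module.apply that sets
    # split_position on every submodule that exposes that attribute.
    def apply_fn(module):
        if hasattr(module, "split_position"):
            module.split_position = position
    return apply_fn


torch.manual_seed(0)
model = nn.Sequential(MultiwayNetwork(nn.Linear(8, 8)), MultiwayNetwork(nn.Linear(8, 8)))

# Batch of 2 sequences: 5 "vision" tokens followed by 3 "language" tokens.
x = torch.randn(2, 8, 8)
model.apply(set_split_position(5))

layer = model[0]
out = layer(x)
# The first 5 positions went through A, the last 3 through B.
```

Because apply visits every submodule recursively, one call switches all multiway layers at once; a real encoder would presumably do this once per forward pass, before the layers run.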
Hi @andreapdr, for the Multiway Transformer implemented in torchscale, we remove the VL-FFN experts and use different attention and FFN parameters for vision and language. We perform VL fusion by concatenating the Q/K/V of vision and language. Please refer to Table 16 in our Supp. For the VL-expert implementation, please refer to this code.
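The fusion described in this reply can be sketched roughly as follows: each modality keeps its own Q/K/V and output projections, but attention runs once over the concatenated vision+language sequence, so every token can attend to both modalities. This is a hypothetical single-head module (MultiwaySelfAttention is not a torchscale class), with layer norms and multi-head splitting omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiwaySelfAttention(nn.Module):
    # Hypothetical sketch: modality-specific projections, shared attention.
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.qkv_v = nn.Linear(dim, 3 * dim)  # vision-specific Q/K/V weights
        self.qkv_l = nn.Linear(dim, 3 * dim)  # language-specific Q/K/V weights
        self.out_v = nn.Linear(dim, dim)
        self.out_l = nn.Linear(dim, dim)

    def forward(self, vision, language):
        n_v = vision.size(1)
        # Project each modality with its own weights, then concatenate along the sequence axis.
        qkv = torch.cat([self.qkv_v(vision), self.qkv_l(language)], dim=1)
        q, k, v = qkv.chunk(3, dim=-1)
        # A single softmax over all keys: this is where vision and language are fused.
        attn = F.softmax(q @ k.transpose(-2, -1) / self.dim ** 0.5, dim=-1)
        ctx = attn @ v
        # Split back and apply modality-specific output projections.
        return self.out_v(ctx[:, :n_v]), self.out_l(ctx[:, n_v:])


torch.manual_seed(0)
attn = MultiwaySelfAttention(16)
vision = torch.randn(2, 5, 16)
language = torch.randn(2, 3, 16)
out_v, out_l = attn(vision, language)
# Changing only the text changes the vision outputs, because attention spans both modalities.
out_v2, _ = attn(vision, torch.randn(2, 3, 16))
```

In this layout, no separate VL-FFN is needed for cross-modal interaction: the mixing happens entirely inside attention, while every other parameter stays modality-specific.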
Thank you @wenhui0924 for the clarification: I totally skipped over the supplementary material when reading the paper! Closing the issue now :smile:
Hello :smile:
The BeIT3 paper mentions that vision-language experts are employed in the top three Multiway Transformer layers. However, looking at the MultiwayNetwork implementation, I find it difficult to understand where this is supposed to happen.
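For comparison with the paper's description, here is a minimal sketch of a stack in which only the top three layers carry a vision-language expert. It assumes a VLMo-style scheme: for an image-text pair, the lower layers route vision and text tokens to V-FFN and L-FFN, while the top layers send every token to a shared VL-FFN. The names ExpertFFNLayer and build_stack are hypothetical, attention and normalization are omitted, and, as the reply in this thread explains, torchscale's implementation drops the VL-FFN experts anyway.

```python
import torch
import torch.nn as nn


def ffn(dim):
    return nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))


class ExpertFFNLayer(nn.Module):
    # Hypothetical layer: V-FFN and L-FFN in every layer, plus a VL-FFN
    # only when use_vl_expert is True (the top layers). Attention is omitted.
    def __init__(self, dim, use_vl_expert):
        super().__init__()
        self.v_ffn = ffn(dim)
        self.l_ffn = ffn(dim)
        self.vl_ffn = ffn(dim) if use_vl_expert else None

    def forward(self, x, split_position, is_pair=True):
        if self.vl_ffn is not None and is_pair:
            # Top layers, image-text pair: every token uses the VL expert.
            return x + self.vl_ffn(x)
        # Otherwise route by modality: vision tokens first, then text tokens.
        x_v, x_l = x[:, :split_position], x[:, split_position:]
        return x + torch.cat([self.v_ffn(x_v), self.l_ffn(x_l)], dim=1)


def build_stack(dim, num_layers, num_vl_layers=3):
    # Only the last num_vl_layers layers get a VL expert.
    return nn.ModuleList(
        [ExpertFFNLayer(dim, use_vl_expert=i >= num_layers - num_vl_layers) for i in range(num_layers)]
    )


torch.manual_seed(0)
layers = build_stack(dim=16, num_layers=12)
x = torch.randn(2, 8, 16)  # 5 vision tokens followed by 3 text tokens
for layer in layers:
    x = layer(x, split_position=5)
```

Here the "top three layers" decision is fixed at construction time by layer index, rather than by changing split_position at run time.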
Could you help me understand this?