microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

BEiT3 Vision-Language Expert question #74

andreapdr closed this issue 9 months ago

andreapdr commented 9 months ago

Hello :smile:

The BEiT3 paper mentions that vision-language experts are employed in the top three Multiway Transformer layers. However, looking at the MultiwayNetwork implementation, I find it difficult to understand where this is supposed to happen.

Could you help me understand this?
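A minimal PyTorch sketch of the layout described in the paper, for readers comparing it with the code: in the lower layers image tokens go to a V-FFN and text tokens to an L-FFN, while in the top three layers all tokens share one VL-FFN. Attention and layer norm are omitted, and `ToyFFNStack`, `num_image_tokens`, and the assumption that image tokens precede text tokens are illustrative choices, not taken from BEiT3 or torchscale.

```python
import torch
from torch import nn


class ToyFFNStack(nn.Module):
    """Hypothetical sketch of the paper's FFN placement (attention omitted).

    Lower layers: image tokens -> V-FFN, text tokens -> L-FFN.
    Top `num_vl_layers` layers: all tokens -> one shared VL-FFN per layer.
    """

    def __init__(self, dim: int, num_layers: int, num_vl_layers: int = 3):
        super().__init__()
        if not 0 <= num_vl_layers <= num_layers:
            raise ValueError("num_vl_layers must be between 0 and num_layers")
        num_modality_layers = num_layers - num_vl_layers

        def ffn() -> nn.Module:
            return nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

        self.v_ffn = nn.ModuleList([ffn() for _ in range(num_modality_layers)])
        self.l_ffn = nn.ModuleList([ffn() for _ in range(num_modality_layers)])
        self.vl_ffn = nn.ModuleList([ffn() for _ in range(num_vl_layers)])

    def forward(self, x: torch.Tensor, num_image_tokens: int) -> torch.Tensor:
        # x: (batch, seq_len, dim), assumed to hold image tokens first, then text tokens.
        for v, l in zip(self.v_ffn, self.l_ffn):
            img, txt = x[:, :num_image_tokens], x[:, num_image_tokens:]
            x = x + torch.cat([v(img), l(txt)], dim=1)  # modality experts + residual
        for vl in self.vl_ffn:
            x = x + vl(x)  # shared vision-language expert + residual
        return x


stack = ToyFFNStack(dim=16, num_layers=6)
out = stack(torch.randn(2, 10, 16), num_image_tokens=4)
print(out.shape)  # torch.Size([2, 10, 16])
```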

donglixp commented 9 months ago

The usage example is at https://github.com/microsoft/unilm/blob/master/beit3/modeling_utils.py#L21

andreapdr commented 9 months ago

Thank you for your prompt response, @donglixp.

From my understanding, the MultiwayNetwork is the module that processes the visual modality (V-FFN), the textual modality (L-FFN), or both (VL-FFN). It routes the input to the corresponding FFN (A, B, or both) according to the split_position attribute.
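A small sketch of the routing rule in the `MultiwayNetwork.forward` quoted below, assuming for illustration that the first `split_position` positions along `dim=1` are image tokens. Each position is sent to exactly one expert; the hypothetical `route` helper simply returns the slices instead of running A and B on them.

```python
import torch


def route(x: torch.Tensor, split_position: int, dim: int = 1) -> dict:
    """Hypothetical helper mirroring MultiwayNetwork.forward's routing decision."""
    if split_position == -1:
        return {"A": x}  # every position goes to expert A
    if split_position == 0:
        return {"B": x}  # every position goes to expert B
    x1, x2 = torch.split(x, [split_position, x.size(dim) - split_position], dim=dim)
    return {"A": x1, "B": x2}  # first split_position positions -> A, the rest -> B


x = torch.randn(2, 10, 16)  # (batch, seq_len, dim): 4 "image" + 6 "text" tokens
parts = route(x, split_position=4)
print({k: tuple(v.shape) for k, v in parts.items()})
# {'A': (2, 4, 16), 'B': (2, 6, 16)}
```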

What I can't understand is how and where you set split_position (or, in BEiT3, multiway_split_position) so that multimodal input is passed to both the A and B FFNs only in the top three layers, while being routed to either A or B in the lower ones.

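```python
import copy

import torch
import torch.nn as nn


class MultiwayNetwork(nn.Module):
    """Wraps `module` as two experts: A handles the first `split_position`
    positions along `dim`, B handles the remaining positions."""

    def __init__(self, module, dim=1):
        super().__init__()
        self.dim = dim
        self.A = module
        # B starts as a copy of A and then gets freshly initialized weights,
        # so `module` must implement reset_parameters().
        self.B = copy.deepcopy(module)
        self.B.reset_parameters()
        self.split_position = -1  # -1: everything to A; 0: everything to B

    def forward(self, x, **kwargs):
        if self.split_position == -1:
            return self.A(x, **kwargs)
        if self.split_position == 0:
            return self.B(x, **kwargs)
        # Split along `dim` into [0, split_position) and [split_position, end).
        x1, x2 = torch.split(
            x,
            [self.split_position, x.size(self.dim) - self.split_position],
            dim=self.dim,
        )
        # Equivalent only when dim == 0:
        # x1, x2 = x[:self.split_position], x[self.split_position:]
        y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
        return torch.cat([y1, y2], dim=self.dim)
```
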
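One way a single split position could be pushed into every multiway submodule before a forward pass is `nn.Module.apply`, sketched below. `Multiway` and `apply_split_position` are hypothetical names and are not claimed to be torchscale's code. With this pattern every layer receives the same value, so the mechanism by itself does not single out the top three layers.

```python
import torch
from torch import nn


class Multiway(nn.Module):
    """Hypothetical stand-in exposing the same split_position attribute."""

    def __init__(self, dim: int):
        super().__init__()
        self.A, self.B = nn.Linear(dim, dim), nn.Linear(dim, dim)
        self.split_position = -1


def apply_split_position(model: nn.Module, position: int) -> None:
    """Hypothetical helper: set split_position on every submodule that has one."""

    def fn(module: nn.Module) -> None:
        if hasattr(module, "split_position"):
            module.split_position = position

    model.apply(fn)  # visits model and all of its submodules recursively


model = nn.Sequential(Multiway(8), nn.ReLU(), Multiway(8))
apply_split_position(model, 4)
print([m.split_position for m in model if isinstance(m, Multiway)])  # [4, 4]
```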
wenhui0924 commented 9 months ago

Hi @andreapdr, for the Multiway Transformer implemented in torchscale, we removed the VL-FFN experts and use separate attention and FFN parameters for vision and language. We perform VL fusion by concatenating the Q/K/V of vision and language. Please refer to Table 16 in our supplementary material. For the VL-expert implementation, please refer to this code.
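A single-head sketch of the fusion described above: vision and language tokens get separate Q/K/V projections, the projected sequences are concatenated, and attention runs over the joint sequence so every token can attend across modalities. `MultiwaySelfAttention` is hypothetical; splitting the output projection per modality is an assumption here, and multi-head splitting, masking, and dropout are left out.

```python
import torch
import torch.nn.functional as F
from torch import nn


class MultiwaySelfAttention(nn.Module):
    """Hypothetical sketch: per-modality projections, joint attention."""

    def __init__(self, dim: int):
        super().__init__()
        self.qkv_v = nn.Linear(dim, 3 * dim)  # vision Q/K/V parameters
        self.qkv_l = nn.Linear(dim, 3 * dim)  # language Q/K/V parameters
        self.out_v = nn.Linear(dim, dim)      # vision output projection (assumed)
        self.out_l = nn.Linear(dim, dim)      # language output projection (assumed)

    def forward(self, x: torch.Tensor, split_position: int) -> torch.Tensor:
        # x: (batch, seq_len, dim); the first split_position tokens are image tokens.
        img, txt = x[:, :split_position], x[:, split_position:]
        # Concatenate the per-modality Q/K/V along the sequence axis.
        qkv = torch.cat([self.qkv_v(img), self.qkv_l(txt)], dim=1)
        q, k, v = qkv.chunk(3, dim=-1)
        # Scaled dot-product attention over the joint sequence: the VL fusion step.
        attn = F.softmax(q @ k.transpose(-2, -1) / q.size(-1) ** 0.5, dim=-1)
        h = attn @ v
        return torch.cat(
            [self.out_v(h[:, :split_position]), self.out_l(h[:, split_position:])],
            dim=1,
        )


layer = MultiwaySelfAttention(dim=16)
y = layer(torch.randn(2, 10, 16), split_position=4)
print(y.shape)  # torch.Size([2, 10, 16])
```

Keeping the projections modality-specific while sharing the attention computation is what lets the model fuse vision and language without a dedicated VL-FFN.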

andreapdr commented 9 months ago

Thank you @wenhui0924 for the clarification: I totally skipped over the supplementary material when reading the paper! Closing the issue now :smile: