AIDC-AI / Ovis

A novel Multimodal Large Language Model (MLLM) architecture, designed to structurally align visual and textual embeddings.
https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B
Apache License 2.0

About the equivalence and a slightly more complex MLP connection #5

Closed lucasjinreal closed 3 months ago

lucasjinreal commented 3 months ago

Thank you for sharing the dataset and the open-source model. Ovis uses VE + Head + Tokenize (essentially a softmax) and then obtains features with the same hidden dimension for the LLM. I am still curious about the precise difference from a complex MLP connector in LLaVA, aside from the softmax operation. Have you run experiments training Ovis on vanilla LLaVA data and then compared it with LLaVA? I am unsure how much of the improvement comes from the extensive data used to train Ovis. The most convincing comparison would be an experiment that swaps in a more complex MLP connector and measures it against the Ovis method.

To be more specific, I can wrap the tokenize part into a LLaVA-style connector, which looks essentially the same as Ovis:

vision_encoder = Siglip(..)

connector = [Head(fea_strided, visual_vocab_size), tokenize(func=softmax), VTE(linear=(visual_vocab_size, llm_hidden_size))]

LLM

From what I can see, the operation just adds a softmax between two linear layers. The only differences are that Ovis uses a larger intermediate dimension (the visual vocabulary) for its linear layers, adds a softmax between them, and introduces a hidden_stride mechanism, similar to InternVL's pixel shuffle, to reduce the number of tokens.
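
A minimal sketch of the comparison I have in mind (the dimensions are placeholders, not the actual Ovis configuration, and the VTE is written as a bias-free Linear here for brevity):

import torch
import torch.nn as nn

vit_dim, vocab_size, llm_dim = 1152, 16384, 3584  # placeholder sizes

# LLaVA-style connector: two linear layers with a nonlinearity in between
mlp_connector = nn.Sequential(
    nn.Linear(vit_dim, llm_dim),
    nn.GELU(),
    nn.Linear(llm_dim, llm_dim),
)

# Ovis-style connector as I read it: two linear maps with a softmax in
# between, where the intermediate size is the visual vocabulary
head = nn.Linear(vit_dim, vocab_size, bias=False)
vte = nn.Linear(vocab_size, llm_dim, bias=False)  # stands in for the VTE

x = torch.randn(1, 729, vit_dim)          # dummy ViT features
probs = torch.softmax(head(x), dim=-1)    # "tokenize": probabilities over the visual vocab
out = vte(probs)                          # expectation over visual word embeddings
print(mlp_connector(x).shape, out.shape)  # both are [1, 729, llm_dim]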

Since this VTE is essentially a Linear to me (the input is a float tensor; the only minor differences are that Linear uses a Kaiming initializer while you use a normal one, and Linear enables a bias by default, which could be disabled):

import torch
from torch import Tensor, LongTensor, IntTensor
from torch.nn import init


class VisualEmbedding(torch.nn.Embedding):
    def forward(self, input: Tensor) -> Tensor:
        if any((isinstance(input, LongTensor), isinstance(input, IntTensor))):
            return super().forward(input)  # integer ids: ordinary embedding lookup
        return torch.matmul(input, self.weight)  # float probabilities: linear map over the table

    def reset_parameters(self, mean=0., std=1.) -> None:
        init.normal_(self.weight, mean=mean, std=std)  # Gaussian init instead of Linear's Kaiming default
        self._fill_padding_idx_with_zero()
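
To illustrate the point, a quick numerical check (toy sizes; a bias-free Linear with its weight tied to the embedding table computes the same thing):

import torch
import torch.nn as nn

vocab_size, llm_dim = 1024, 256                 # toy sizes for a quick check
emb = nn.Embedding(vocab_size, llm_dim)         # weight shape: (vocab_size, llm_dim)
lin = nn.Linear(vocab_size, llm_dim, bias=False)
lin.weight.data.copy_(emb.weight.data.T)        # Linear stores (out_features, in_features)

probs = torch.softmax(torch.randn(2, 16, vocab_size), dim=-1)
out_emb = probs @ emb.weight                    # what VisualEmbedding.forward does for float input
out_lin = lin(probs)
print(torch.allclose(out_emb, out_lin, atol=1e-5))  # True: same map, only init/bias defaults differ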

Hoping for your discussion; thanks again for open-sourcing the work!

runninglsy commented 3 months ago

Thank you for your interest in our work.

For the ablation study with the MLP connector, please see Section 4.3 of the Ovis report. We also provide the mathematical formulation of the operations conducted in the vision module of Ovis in Section 3. Regarding the implementation of the visual embedding table, our early experiments show that the initialization method matters: Gaussian initialization leads to a better loss curve in the first stage of training Ovis. Therefore, in our code, we construct the visual embedding table by inheriting from the Embedding layer instead of the Linear layer.
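
As a concrete illustration of the initialization point, the PyTorch defaults being contrasted can be checked directly (toy/placeholder sizes, not the Ovis training code):

import torch
import torch.nn as nn

vocab_size, llm_dim = 16384, 1024            # placeholder sizes

emb = nn.Embedding(vocab_size, llm_dim)      # default init: N(0, 1), i.e. the Gaussian init used for the VTE
lin = nn.Linear(vocab_size, llm_dim)         # default init: Kaiming-uniform, scaled by fan_in

print(emb.weight.std().item())               # ~1.0
print(lin.weight.std().item())               # roughly 1/sqrt(3 * fan_in), far smaller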

lucasjinreal commented 3 months ago

I have reimplemented Ovis as a connector in LLaVA; the architecture appears identical to Ovis. Essentially, it is as follows:

vision_encoder = Siglip(..)

connector = [Head(fea_strided, visual_vocab_size), tokenize(func=softmax), VTE(linear=(visual_vocab_size, llm_hidden_size))]

LLM

Consequently, the LLaVA-based model code should not require significant changes.

However, I found that when I pretrain only the connector, as LLaVA does, without unfreezing the LLM's last layer, it fails to converge at all, even though the learning rate I set is similar to yours.

Are there any possible reasons for this?

The code looks like this (`VisualEmbedding` is the class quoted above):

import torch
import torch.nn as nn
import torch.nn.functional as F


class OvisProjector(nn.Module):
    def __init__(self, in_hidden_size, out_hidden_size, hidden_stride=1):
        super().__init__()

        self.hidden_stride = hidden_stride
        self.visual_vocab_size = 16384
        self.tokenize_function = "softmax"
        self.tau = 1.0

        # map the (optionally folded) visual features into the visual vocabulary
        self.head = nn.Linear(
            in_hidden_size * self.hidden_stride * self.hidden_stride,
            self.visual_vocab_size,
        )
        self.vte = VisualEmbedding(self.visual_vocab_size, out_hidden_size)

    def _fold_features(self, features):
        # space-to-depth fold (akin to InternVL's pixel shuffle): merge each
        # hidden_stride x hidden_stride block of patches into a single token
        if self.hidden_stride > 1:
            n, l, d = features.shape  # note: this `d` differs from the folded `d` below
            sqrt_l = int(l**0.5)
            assert (
                sqrt_l**2 == l
            ), "The token sequence length should be a perfect square."
            assert (
                l % (self.hidden_stride**2) == 0
            ), "The token sequence length should be divisible by `hidden_stride**2`."
            features = features.reshape(n, sqrt_l, sqrt_l, d)
            features = features.reshape(
                n,
                sqrt_l // self.hidden_stride,
                self.hidden_stride,
                sqrt_l // self.hidden_stride,
                self.hidden_stride,
                d,
            )
            features = features.permute(
                0, 1, 3, 2, 4, 5
            )  # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
            features = features.flatten(3)  # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
            features = features.reshape(
                n,
                l // (self.hidden_stride * self.hidden_stride),
                self.hidden_stride * self.hidden_stride * d,
            )

        return features

    def tokenize(self, logits):
        def st_argmax(y_soft, dim):  # straight-through argmax: hard one-hot forward, soft gradient backward
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(
                y_soft, memory_format=torch.legacy_contiguous_format
            ).scatter_(dim, index, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
            return ret

        if self.tokenize_function == "softmax":
            tokens = F.softmax(logits, dim=-1)
        elif self.tokenize_function == "gumbel_argmax":
            tokens = F.gumbel_softmax(logits, tau=self.tau, hard=True)
        elif self.tokenize_function == "st_argmax":
            tokens = st_argmax(logits, dim=-1)
        else:
            raise ValueError(
                f"Invalid `max_type`, expected softmax or gumbel_argmax or st_argmax, but got {self.tokenize_function}"
            )
        return tokens

    def forward(self, x, attention_mask: torch.Tensor = None):
        x = self._fold_features(x)
        x = self.head(x)      # project folded features to logits over the visual vocabulary
        x = self.tokenize(x)  # softmax over the visual vocabulary
        x = self.vte(x)       # probabilities @ embedding table -> LLM-sized tokens
        # final tokens
        print(f'final token shape: {x.shape}')
        return x
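
A quick shape check of this module with dummy features (the sizes are assumptions; a 32x32 patch grid is used so hidden_stride=2 divides evenly):

proj = OvisProjector(in_hidden_size=1152, out_hidden_size=3584, hidden_stride=2)
feats = torch.randn(2, 1024, 1152)  # dummy ViT features on a 32x32 grid
out = proj(feats)                   # prints: final token shape: torch.Size([2, 256, 3584])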

Did I make any mistake in it?

runninglsy commented 3 months ago

I would like to suggest following the training configurations outlined in the scripts.

lucasjinreal commented 3 months ago

Does my reimplementation deviate from Ovis in any way?

It really cannot converge at all.

Is unfreezing the last layer critical? I would like to discuss some insights into your approach.

runninglsy commented 3 months ago

I cannot determine the reason for the loss not converging because it may be related to the full training code, data format, and training parameters. I recommend using our open-source codebase, data format, and training parameters directly for training. This way, if any issues arise, I will be able to reproduce and try to resolve them. As for whether to train the last layer, our experimental experience indicates that the loss can converge regardless of whether it is trained or not.

lucasjinreal commented 3 months ago

@runninglsy thanks for the reply. Could you confirm whether my implementation is identical to Ovis in terms of model architecture? If training only the projector can converge, what could I be missing? The data I used has certainly trained many LLaVA-like models successfully.

runninglsy commented 3 months ago

There is a layer normalization in the visual head, and two image indicator tokens are present in the visual tokenizer. For further details, please refer to this file. Additionally, please note that we did not include an LLM conversation template in the first phase.
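
Not the actual Ovis code, but roughly the structure being described, for orientation (the exact ordering and how the indicator slots are wired should be checked against the referenced file):

import torch
import torch.nn as nn

vit_dim, vocab_size, llm_dim = 1152, 16384, 3584   # placeholder sizes
num_indicators = 2                                  # the two image indicator tokens

# assumed visual head: bias-free linear into the visual vocabulary, then a LayerNorm
visual_head = nn.Sequential(
    nn.Linear(vit_dim, vocab_size - num_indicators, bias=False),
    nn.LayerNorm(vocab_size - num_indicators),
)

# the embedding table covers the full vocabulary, so the reserved rows can act as
# learnable image-indicator embeddings (assumption; see the repo for details)
vte = nn.Embedding(vocab_size, llm_dim)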

lucasjinreal commented 3 months ago

@runninglsy Thank you for the insight.

I edited the code to add the LayerNorm, disabled the bias in the head's Linear, and added the indicator tokens. Training only the projector, the loss still does not decrease (it drops extremely slowly).

I noticed the provided scripts use 8 x 128 (per-device batch size x gradient accumulation). Does that make the global batch size 8192 for lr 3e-4 (assuming you were using 8 GPUs as the default setting)?

runninglsy commented 3 months ago

Yes, the parameters in the scripts are configured for an 8-GPU setting, resulting in an overall batch size of 8×128×8=8192. During actual training, we utilize an internal distributed system that typically employs 64 or more GPUs in the first stage. However, the overall batch size remains consistent with that defined in the scripts. Specifically, we set per_device_train_batch_size to 8 and gradient_accumulation_steps to 16 in the internal 64-GPU environment.
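
Put differently, both settings land on the same global batch size:

# global batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(8 * 128 * 8)   # scripts, 8 GPUs   -> 8192
print(8 * 16 * 64)   # internal, 64 GPUs -> 8192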

lucasjinreal commented 3 months ago

@runninglsy Then this lr seems a little small compared to LLaVA's projector-training lr, which is typically 1e-4 with a global batch size of 128.

Have you tried using a larger lr? Can it converge with a small lr?

runninglsy commented 3 months ago

We did not reference the training parameters of LLaVA-type models when setting Ovis's hyperparameters. I can confirm that the code, data, and hyperparameters we used for training are consistent with the open-source version, and the loss converges properly.

ShacklesLay commented 2 months ago

Thank you for sharing! I'm really curious if you were able to successfully replicate Ovis later on.