mlfoundations / open_flamingo

An open-source framework for training large multimodal models.
MIT License

Worse performance compared to vanilla cross attention #286

Closed: miguelscarv closed this issue 9 months ago

miguelscarv commented 10 months ago

I've been building a Vision-Language model based on cross-attention, inspired by both Flamingo and regular (vanilla) cross-attention. I've noticed that whenever I add cross-attention layers before the first decoder layer, as is done for the 1B-parameter decoder models, the initial training loss is much higher than if I simply skip the first decoder layer. If I skip the first layer, the initial training loss is around 6; if I don't skip it, it is around 12. This happens with an implementation inspired by this project (OpenFlamingo), where my only changes are removing the dense layers and gating mechanisms from the gated cross-attention layers and using a single image. A similar initial training loss can be seen in this issue: https://github.com/mlfoundations/open_flamingo/issues/129#issuecomment-1666563478
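For scale, a model that spreads its probability uniformly over a vocabulary of V tokens has a cross-entropy of ln V nats (PyTorch's cross-entropy loss uses the natural log). The issue does not say which language model was used, so the sketch below assumes a GPT-2-sized vocabulary of 50,257 tokens. Under that assumption the uniform baseline is roughly 10.8, so an initial loss around 6 is well below it, while a loss around 12 is above it, i.e. worse than guessing uniformly.

```python
import math


def uniform_cross_entropy(vocab_size: int) -> float:
    """Cross-entropy (nats) of a predictor that is uniform over vocab_size tokens: -ln(1/V) = ln(V)."""
    return math.log(vocab_size)


def perplexity(loss_nats: float) -> float:
    """Perplexity corresponding to a mean cross-entropy measured in nats."""
    return math.exp(loss_nats)


# Assumption: GPT-2-sized vocabulary; the thread does not name the language model.
VOCAB_SIZE = 50257

baseline = uniform_cross_entropy(VOCAB_SIZE)
for observed in (6.0, 12.0):
    verdict = "better" if observed < baseline else "worse"
    print(
        f"loss {observed:.1f}: perplexity {perplexity(observed):,.0f}, "
        f"{verdict} than uniform guessing (loss {baseline:.2f})"
    )
```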

I tried replicating this problem with HuggingFace's VisionEncoderDecoderModel, using the same vision encoder and language model and switching the order of the self-attention block and the cross-attention block, but I couldn't reproduce it; the initial training loss stays around 6. I can't find anything obvious: there is no masking in either implementation, both always use residual connections, and so on.

TLDR: I implemented the exact same cross-attention mechanism in two different ways, one based on OpenFlamingo and one based on HuggingFace's VisionEncoderDecoderModel. The two lead to very different initial training losses, and I'm not sure why.

Could this be part of the reason why there is a difference in performance between OpenFlamingo and Flamingo?

anas-awadalla commented 10 months ago

Hi @miguelscarv. This is very interesting, thanks for sharing! This is definitely something I want to investigate.

Just to confirm my understanding: for both the OpenFlamingo- and HuggingFace-inspired implementations, you do two things, namely remove the gating/dense layers from x-attn and add the first x-attn layer before the first decoder layer?

A few other questions:

  1. By 'skip the first layer', do you mean adding the first x-attn layer after the first decoder layer, or removing it altogether?
  2. Despite the different starting losses, do the models converge to similar losses later in training? Is evaluation performance also different?

miguelscarv commented 10 months ago

Hi @anas-awadalla, thank you for the quick reply! Yes, you understood correctly, but note that only OpenFlamingo's implementation has gating and dense layers in its cross-attention block. The default implementation in HuggingFace is simply a layer norm followed by vanilla cross-attention. To match HuggingFace's implementation, I copied the code from this repo and removed the gating and dense layers.
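To illustrate what removing the gate changes, here is a minimal pure-Python sketch of a single-head, pre-norm cross-attention block whose output is added to the residual stream with and without a tanh gate. As I understand it, OpenFlamingo's GatedCrossAttentionBlock initializes its gates to zero, so tanh(0) = 0 makes the block an exact identity at initialization; without the gate, the randomly initialized cross-attention output goes straight into the language model's residual stream. The toy dimensions, the std=0.02 initialization, and every function name here are hypothetical, and the dense (feed-forward) sub-block is omitted.

```python
import math
import random
from typing import List, Optional

# Hypothetical toy types: a token is a list of floats, a sequence is a list of tokens.
Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, v: Vector) -> Vector:
    """Return w @ v for a weight matrix w of shape (out, in)."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def layer_norm(v: Vector, eps: float = 1e-5) -> Vector:
    """Layer norm without learned scale/shift."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def softmax(scores: Vector) -> Vector:
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def random_matrix(rows: int, cols: int, std: float, rng: random.Random) -> Matrix:
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


class CrossAttention:
    """Single-head pre-norm cross-attention: text tokens (queries) attend to image tokens."""

    def __init__(self, dim: int, dim_visual: int, std: float, rng: random.Random):
        self.wq = random_matrix(dim, dim, std, rng)
        self.wk = random_matrix(dim, dim_visual, std, rng)
        self.wv = random_matrix(dim, dim_visual, std, rng)
        self.wo = random_matrix(dim, dim, std, rng)
        self.scale = 1.0 / math.sqrt(dim)

    def __call__(self, x: List[Vector], media: List[Vector]) -> List[Vector]:
        keys = [matvec(self.wk, m) for m in media]
        values = [matvec(self.wv, m) for m in media]
        out = []
        for token in x:
            q = matvec(self.wq, layer_norm(token))  # only the text side is normalized here
            weights = softmax([self.scale * sum(a * b for a, b in zip(q, k)) for k in keys])
            mixed = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(q))]
            out.append(matvec(self.wo, mixed))
        return out


def add_residual(x: List[Vector], delta: List[Vector], gate: Optional[float] = None) -> List[Vector]:
    """Return x + tanh(gate) * delta when gated, or plain x + delta when gate is None."""
    scale = 1.0 if gate is None else math.tanh(gate)
    return [[xi + scale * di for xi, di in zip(tok, d)] for tok, d in zip(x, delta)]


def max_abs_change(new: List[Vector], old: List[Vector]) -> float:
    return max(abs(a - b) for rn, ro in zip(new, old) for a, b in zip(rn, ro))


rng = random.Random(0)
dim, dim_visual = 8, 6
x = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(4)]  # 4 text tokens
media = [[rng.gauss(0.0, 1.0) for _ in range(dim_visual)] for _ in range(3)]  # 3 image tokens
delta = CrossAttention(dim, dim_visual, std=0.02, rng=rng)(x, media)

gated = add_residual(x, delta, gate=0.0)  # zero-initialized tanh gate
ungated = add_residual(x, delta)          # vanilla residual cross-attention

print("gated change:", max_abs_change(gated, x))
print("ungated change:", max_abs_change(ungated, x))
```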

As for the other questions:

  1. I mean that I add the first x-attn layer between the first and second decoder layers (when cross_attn_every_n_layers = 1). All I did was change line 100 in src/flamingo_lm.py, adding "and layer_idx != 0" to the if statement there. When cross_attn_every_n_layers is set to anything other than 1, I never get this issue, with or without that change, so it seems the issue is only in the first layer...?

  2. No, not yet. I was first trying to validate my implementation by comparing the same architecture across different implementations, and since the results were so different, I assumed I was doing something wrong. I even tried initializing the linear layers with the same hyperparameters that my decoder's (HuggingFace) cross-attention implementation uses (bias initialized to zero, weights drawn from a normal distribution with std=0.02), but that results in an even higher initial training loss.
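A sketch of the placement rule discussed in answer 1, assuming the condition in src/flamingo_lm.py has the form (layer_idx + 1) % cross_attn_every_n_layers == 0 (my reading of the OpenFlamingo source; the thread itself only mentions line 100). Under that assumption, a cross-attention block lands in front of decoder layer 0 only when cross_attn_every_n_layers = 1, which would be consistent with the observation that other values never trigger the issue. The helper name and the 24-layer depth are hypothetical.

```python
from typing import List


def xattn_layer_indices(n_layers: int, every_n: int, skip_first: bool = False) -> List[int]:
    """Hypothetical helper: indices of decoder layers that get a cross-attention block in front.

    Assumes an OpenFlamingo-style rule, (layer_idx + 1) % every_n == 0;
    skip_first=True mimics appending "and layer_idx != 0" to that condition.
    """
    return [
        i
        for i in range(n_layers)
        if (i + 1) % every_n == 0 and not (skip_first and i == 0)
    ]


N_LAYERS = 24  # hypothetical decoder depth, for illustration only

for every_n in (1, 2, 4):
    default = xattn_layer_indices(N_LAYERS, every_n)
    patched = xattn_layer_indices(N_LAYERS, every_n, skip_first=True)
    print(
        f"every_n={every_n}: x-attn before layer 0: {0 in default}, "
        f"patch changes placement: {default != patched}"
    )
```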

EDIT: I previously ran experiments with the default HuggingFace implementation (vanilla cross-attention between the self-attention and MLP blocks) using these same models, and got very good results in both training and evaluation. Since I wanted to improve on them, I thought Flamingo-style cross-attention would do better, but because of the issue I hit when adding cross-attention layers before the first decoder layer, I started running the following experiments:

  1. Adding Dense Cross Attention every 4 layers (basically OpenFlamingo but without gating)
  2. Adding Dense Cross Attention every 2 layers (basically OpenFlamingo but without gating)
  3. Adding Cross Attention every layer except for the first (basically OpenFlamingo but without gating and dense blocks)
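The three runs can be written down as hypothetical configuration objects (the XAttnConfig class, its field names, and the 24-layer depth are illustrative assumptions, not OpenFlamingo's API), reusing the assumed (layer_idx + 1) % n placement rule. This makes it easy to check that none of them puts a cross-attention block before decoder layer 0 and to see how many blocks each one adds.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class XAttnConfig:
    """Hypothetical description of where, and what kind of, cross-attention is inserted."""

    every_n: int      # analogous to cross_attn_every_n_layers
    dense: bool       # add the feed-forward ("dense") sub-block after cross-attention
    gated: bool       # tanh-gate the residual branches (False in all three runs)
    skip_first: bool  # never insert cross-attention before decoder layer 0

    def layer_indices(self, n_layers: int) -> List[int]:
        # Assumed OpenFlamingo-style placement rule: (layer_idx + 1) % every_n == 0.
        return [
            i
            for i in range(n_layers)
            if (i + 1) % self.every_n == 0 and not (self.skip_first and i == 0)
        ]


EXPERIMENTS = {
    "1: dense x-attn every 4 layers": XAttnConfig(4, dense=True, gated=False, skip_first=False),
    "2: dense x-attn every 2 layers": XAttnConfig(2, dense=True, gated=False, skip_first=False),
    "3: x-attn every layer except the first": XAttnConfig(1, dense=False, gated=False, skip_first=True),
}

N_LAYERS = 24  # hypothetical decoder depth

for name, cfg in EXPERIMENTS.items():
    indices = cfg.layer_indices(N_LAYERS)
    print(f"{name}: {len(indices)} blocks, first one before layer {indices[0]}, dense={cfg.dense}")
```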

The first experiment is halfway through training, and there is already a fairly large gap in loss compared with the earlier HuggingFace-based experiment that gave good results. According to the original Flamingo paper, vanilla cross-attention in every layer performs worse than Gated X-Attn every 4 layers, but there is still training left to do before I can conclude anything.

Note: the gating mechanism worsens convergence on smaller datasets (see https://github.com/mlfoundations/open_flamingo/issues/129#issuecomment-1666563478), and because my dataset is relatively small (1M image-text pairs), I decided to exclude gating. I also tried running with gating, and at the halfway point its results were even worse than what experiment 1 is showing now, so I killed it.