NVlabs / RADIO

Official repository for "AM-RADIO: Reduce All Domains Into One"

Problems using RADIO in the LLaVA setting #15

Open weiwei0224 opened 5 months ago

weiwei0224 commented 5 months ago

First, thanks for your great work! We're trying to replace the vision encoder in LLaVA, i.e., clip-l-336, with RADIO. Under the default LLaVA 1.5 settings, we pretrain a multimodal projection MLP and then run instruction tuning to finetune a Vicuna 7B-1.5 model with LoRA. The results are shown below; both experiments use the same settings.

| Vision encoder | GQA | SQA | TextVQA | VQAv2 |
|---|---|---|---|---|
| clip-l-336 | 62.9 | 68.4 | 58.6 | 77.3 |
| RADIO | 59.2 | 68.3 | 51.7 | 74.0 |
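For reference, the projector trained in the first (alignment) stage of LLaVA 1.5 is a two-layer MLP with a GELU, LLaVA's `mlp2x_gelu` option. Swapping the vision encoder only changes the projector's input width. A minimal PyTorch sketch, assuming a ViT-H/14 encoder width of 1280 and Vicuna-7B's hidden size of 4096; the function name is hypothetical:

```python
import torch
from torch import nn


def build_mlp2x_gelu_projector(vision_dim: int, llm_dim: int) -> nn.Sequential:
    """Hypothetical helper: two-layer MLP mapping vision tokens into the LLM embedding space."""
    return nn.Sequential(
        nn.Linear(vision_dim, llm_dim),
        nn.GELU(),
        nn.Linear(llm_dim, llm_dim),
    )


# Assumed widths: 1280 for a ViT-H/14 encoder (CLIP ViT-L would be 1024), 4096 for Vicuna-7B.
projector = build_mlp2x_gelu_projector(vision_dim=1280, llm_dim=4096)

# 576 patch tokens for a 336x336 input with 14x14 patches (24 * 24).
vision_tokens = torch.randn(2, 576, 1280)
llm_tokens = projector(vision_tokens)
print(llm_tokens.shape)  # torch.Size([2, 576, 4096])
```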

Unfortunately, we do not observe an improvement from using RADIO, which differs from the results in the paper. My questions are:
1) We tune the LLM with LoRA during the SFT stage. Did you finetune the LLM directly or with LoRA, and do you think this affects the final results?
2) We also want to replace EVA-CLIP-G with RADIO as the VLM vision encoder. Based on your experience, do you think RADIO is a better choice?

Your prompt reply would be greatly appreciated, thanks!

gheinrich commented 5 months ago

Hello, thank you for your interest in RADIO!

Using the weights we released for RADIO v1, we found that the activations are larger in magnitude than usual, with standard deviations in the many tens vs. the usual value of around one. We therefore tried adding a LayerNorm at the output of RADIO v1 by setting global_pool="token" in the constructor, and found that this improved our metrics:

| Vision Encoder | Variant | Notes | LLM | Resolution | Vision-Language Alignment Train Loss | Instruction Tuning Train Loss | GQA Val Accuracy | TextVQA Accuracy | ScienceQA (Image) Accuracy | VQAv2 Test-Dev Accuracy |
|---|---|---|---|---|---|---|---|---|---|---|
| RADIO | V1 | global_pool="avg" | Vicuna 1.5 7B | 336 | 2.52 | 0.818 | 67.48 | 52.90 | 65.94 | 75.78 |
| RADIO | V1 | global_pool="token" | Vicuna 1.5 7B | 336 | 2.34 | 0.798 | 69.17 | 54.92 | 67.08 | 77.88 |

Note that we used the val_all set for GQA. I realize most papers report on the testdev set. Sorry about that!
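To illustrate what an output LayerNorm does to features at the scale described above, the sketch below normalizes a synthetic tensor. It is not real RADIO activations: the scale of 40 is a hypothetical stand-in for "many tens", and the shape (576 tokens of width 1280) is an assumption matching a ViT-H/14 at 336 px.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic stand-in for RADIO v1 patch features: (batch, tokens, channels),
# deliberately scaled so the per-element standard deviation is about 40.
embed_dim = 1280
features = 40.0 * torch.randn(2, 576, embed_dim)

# A freshly initialized LayerNorm (weight 1, bias 0) normalizes each token over its channels.
norm = nn.LayerNorm(embed_dim)
with torch.no_grad():
    normalized = norm(features)

print(f"std before: {features.std().item():.1f}")   # about 40
print(f"std after:  {normalized.std().item():.2f}")  # about 1
```

Because the statistics are computed per token, every vision token reaching the projector ends up with zero mean and unit variance across channels, whatever the encoder's raw scale.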

Our training procedure was exactly that of LLaVA 1.5, i.e. we ran pre-training (multimodal alignment) followed by instruction tuning.
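The two-stage procedure can be sketched as a parameter-freezing schedule. This is a minimal sketch assuming LLaVA 1.5's defaults: vision encoder frozen throughout, LLM frozen during alignment pre-training and fully finetuned during instruction tuning. The helper names and the toy modules are hypothetical.

```python
from torch import nn


def configure_stage(vision_encoder: nn.Module, projector: nn.Module,
                    llm: nn.Module, stage: str) -> None:
    """Hypothetical helper: set requires_grad for a LLaVA-1.5-style training stage."""
    if stage not in ("pretrain", "finetune"):
        raise ValueError(f"unknown stage: {stage!r}")
    vision_encoder.requires_grad_(False)     # frozen in both stages
    projector.requires_grad_(True)           # trained in both stages
    llm.requires_grad_(stage == "finetune")  # frozen during alignment pre-training


def count_trainable(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Toy stand-ins for the real encoder, projector and LLM.
vision, proj, llm = nn.Linear(8, 8), nn.Linear(8, 16), nn.Linear(16, 16)

configure_stage(vision, proj, llm, "pretrain")
print(count_trainable(vision), count_trainable(proj), count_trainable(llm))  # 0 144 0

configure_stage(vision, proj, llm, "finetune")
print(count_trainable(vision), count_trainable(proj), count_trainable(llm))  # 0 144 272
```

With LoRA, as in the original question, the base LLM weights would instead stay frozen in the second stage and only the adapter weights would be trained.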

We think RADIO is well suited, particularly for tasks that require better spatial understanding. RADIO is flexible about input image dimensions, which lets you try out different resolutions or variations of input pre-processing (such as removing the padding around rectangular images in order to process non-square images). We also find that some of the common benchmarks are not sensitive enough to the vision encoder: for example, SQA performance is mainly a function of how good the LLM is. Similarly, the OCR hints in TextVQA make it possible to answer most questions without even looking at the image.
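One way to process a non-square image without padding is to resize it so that each side is a multiple of the patch size while roughly preserving the aspect ratio. A minimal sketch, assuming the encoder needs both sides divisible by a patch size of 14 (as the `vit_huge_patch14` name suggests) and a target short side of 336; the helper is hypothetical and does not reproduce RADIO's own pre-processing.

```python
from __future__ import annotations


def patch_aligned_size(height: int, width: int, target_short_side: int = 336,
                       patch_size: int = 14) -> tuple[int, int]:
    """Hypothetical helper: scale (height, width) so the short side is about
    target_short_side, rounding each side to the nearest multiple of patch_size."""
    scale = target_short_side / min(height, width)
    new_h = max(patch_size, round(height * scale / patch_size) * patch_size)
    new_w = max(patch_size, round(width * scale / patch_size) * patch_size)
    return new_h, new_w


print(patch_aligned_size(480, 640))    # (336, 448) -> 24 x 32 = 768 patch tokens
print(patch_aligned_size(1080, 1920))  # (336, 602) -> 24 x 43 = 1032 patch tokens
```

Compared with padding to a square, this spends no tokens on padding bands, at the cost of a variable number of vision tokens per image.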

Please let us know if the feature normalization helps on your end!

Thank you.

weiwei0224 commented 5 months ago

Thanks for your reply! But we are still confused about a couple of points:

  1. In the table above, does row 1 correspond to global_pool="avg" without LayerNorm, and row 2 to global_pool="token" with LayerNorm?
  2. Did you also feed the summary token into the LLM?

Waiting for your reply, thank you~

gheinrich commented 5 months ago

Hello, sorry for the confusion. By global_pool I was referring to the constructor argument of timm's VisionTransformer.

When you set global_pool='token', timm adds a LayerNorm. With global_pool='avg', there is no LayerNorm. So yes, your understanding is correct :-)

To be specific, in LLaVA my code for instantiating a RADIO vision encoder looks like:

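```python
from timm.models import create_model


def _create_timm_model(timm_model_name="vit_huge_patch14_224", global_pool="token"):
    # Build the timm VisionTransformer backbone without a classification head.
    # global_pool="token" makes timm apply a final LayerNorm to the output tokens;
    # global_pool="avg" skips that norm.
    model = create_model(
        timm_model_name,
        pretrained=False,  # no timm-pretrained weights
        num_classes=0,     # no classifier head
        global_pool=global_pool,
    )
    return model
```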

If you are using the HuggingFace model, I believe the same can be achieved by simply adding the LayerNorm explicitly. Additional note: the HuggingFace model expects inputs in the range [0, 1], so you should not normalize values in the image preprocessor. Alternatively, you can normalize values in the image preprocessor, but in that case you should set the input_conditioner of the HuggingFace model to the identity:

```python
import torch

# radio_model: the loaded HuggingFace RADIO model
radio_model.input_conditioner = torch.nn.Identity()
```
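A minimal sketch of adding that LayerNorm explicitly around a HuggingFace-style RADIO model. The wrapper class and the dummy backbone are hypothetical, and the sketch assumes the backbone returns `(summary, patch_tokens)` with `patch_tokens` shaped `(batch, num_patches, embed_dim)` and expects inputs in [0, 1], per the note above. Raw `uint8` pixels are only divided by 255, leaving mean/std normalization to the model's own input conditioner.

```python
import torch
from torch import nn


class RadioWithOutputNorm(nn.Module):
    """Hypothetical wrapper adding an explicit LayerNorm on the patch tokens."""

    def __init__(self, backbone: nn.Module, embed_dim: int):
        super().__init__()
        self.backbone = backbone
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, images: torch.Tensor):
        if images.dtype == torch.uint8:
            # Scale raw pixels to [0, 1]; no mean/std normalization here,
            # since the model's input_conditioner is assumed to handle it.
            images = images.float() / 255.0
        summary, patch_tokens = self.backbone(images)
        return summary, self.norm(patch_tokens)


class DummyBackbone(nn.Module):
    """Stand-in returning large-magnitude features, for illustration only."""

    def __init__(self, embed_dim: int = 32):
        super().__init__()
        self.embed_dim = embed_dim
        self.last_input_max = None

    def forward(self, x: torch.Tensor):
        self.last_input_max = x.max().item()
        b = x.shape[0]
        return torch.zeros(b, self.embed_dim), 50.0 * torch.randn(b, 16, self.embed_dim)


wrapped = RadioWithOutputNorm(DummyBackbone(32), embed_dim=32)
pixels = torch.randint(0, 256, (2, 3, 224, 224), dtype=torch.uint8)
with torch.no_grad():
    summary, tokens = wrapped(pixels)
print(wrapped.backbone.last_input_max <= 1.0, tokens.shape)  # True torch.Size([2, 16, 32])
```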

For the tokens fed into the LLM, I tried several variations:
(a) summary tokens only
(b) patch tokens only
(c) summary + patch tokens

(a) clearly performs worse; however, the difference between (b) and (c) is rather small.
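The three variants differ only in which tokens are handed to the projector. A minimal sketch, assuming (hypothetically) that the summary tokens share the patch tokens' width and are shaped `(batch, num_summary, dim)`; the function name and the mode strings are illustrative, not the author's code.

```python
import torch


def select_vision_tokens(summary: torch.Tensor, patches: torch.Tensor,
                         mode: str) -> torch.Tensor:
    """Hypothetical helper building the visual token sequence for the projector/LLM.

    summary: (batch, num_summary, dim); patches: (batch, num_patches, dim).
    mode: "summary" -> variant (a), "patch" -> (b), "both" -> (c).
    """
    if mode == "summary":
        return summary
    if mode == "patch":
        return patches
    if mode == "both":
        return torch.cat([summary, patches], dim=1)  # summary tokens first
    raise ValueError(f"unknown mode: {mode!r}")


summary = torch.randn(2, 1, 1280)    # hypothetical: one summary token per image
patches = torch.randn(2, 576, 1280)  # 24 x 24 patches at 336 px with patch size 14
for mode in ("summary", "patch", "both"):
    print(mode, tuple(select_vision_tokens(summary, patches, mode).shape))
# summary (2, 1, 1280)
# patch (2, 576, 1280)
# both (2, 577, 1280)
```

Under these assumptions, variant (a) gives the LLM a single visual token per image, while (b) and (c) differ by only one token out of several hundred.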