mlfoundations / open_clip

An open source implementation of CLIP.

Any plans to support the modified ViT arch based on the ViT-22B paper? #426

Open JianbangZ opened 1 year ago

JianbangZ commented 1 year ago

The changes include QK normalization, parallel layers, etc. It would be cool to see how CLIP performs when applying those changes to ViT-B, ViT-L, and ViT-H.

rwightman commented 1 year ago

@JianbangZ

  1. QK normalization might improve stability at the larger end of the model scale. So far I've managed to mitigate instability with bfloat16 + AMP and lowering AdamW beta2, but not perfectly for large-scale models.

  2. The parallel MLP + attn won't make much difference within a single device until it's (non-trivially) supported in distributed training via model/tensor parallel code. Kernels typically use the full GPU and don't execute in parallel, even if you try.

I tried torch.compile (inductor) with the parallel attn, thought I was seeing gains at the B/16 model size, and then it reversed on H/14. Talked to an expert and he figures it's just compiler behaviour (a good fit in one case vs not so great in the other, with some differences due to the parallel layout), BUT there were no parallel execution optimizations in the compiler (it does not try).
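
For context, a minimal sketch of the parallel layout under discussion (illustrative names, not the open_clip or timm code; `mlp_ratio` and the single pre-norm placement are assumptions):

```python
import torch
import torch.nn as nn


class ParallelBlock(nn.Module):
    """Sketch of the ViT-22B / PaLM parallel layout: y = x + attn(norm(x)) + mlp(norm(x))."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both branches read the same normalized input; on a single GPU the two
        # matmul chains still run one after the other unless kernels are fused.
        y = self.norm(x)
        attn_out, _ = self.attn(y, y, y, need_weights=False)
        return x + attn_out + self.mlp(y)
```

As noted above, writing the block this way doesn't buy anything by itself on one device; the layout mainly pays off with tensor/model parallel sharding or fused kernels.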

lucidrains commented 1 year ago

@JianbangZ yea, i can add the qk rmsnorm this week as an option

however, there are still researchers who disagree, just fyi

~~also, i feel like most of the instability has been ironed out after switching to bfloat16?~~ Ross already mentioned this

rwightman commented 1 year ago

I have a different approach for the parallel blocks (manually fusing) that looks like it'll work better. Trying it in timm. My first naive approach did not yield gains, as mentioned, and the compiler couldn't do much with it either.
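
Roughly, the manual fusion idea is for the attention qkv projection and the MLP input projection to share one large matmul, since both branches read the same normalized input. A sketch under that assumption (illustrative only, not the timm code linked below):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedParallelBlock(nn.Module):
    """Parallel attn + MLP with the two input projections fused into one matmul."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.mlp_hidden = int(dim * mlp_ratio)
        self.norm = nn.LayerNorm(dim)
        # Single projection producing q, k, v and the MLP hidden features.
        self.in_proj = nn.Linear(dim, 3 * dim + self.mlp_hidden)
        self.attn_out = nn.Linear(dim, dim)
        self.mlp_out = nn.Linear(self.mlp_hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        y = self.in_proj(self.norm(x))  # one large matmul feeds both branches
        q, k, v, h = y.split([C, C, C, self.mlp_hidden], dim=-1)
        q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        a = F.scaled_dot_product_attention(q, k, v)  # PyTorch 2.0 fused kernel
        a = a.transpose(1, 2).reshape(B, N, C)
        return x + self.attn_out(a) + self.mlp_out(F.gelu(h))
```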

For qk norm here, it won't work with the default transformer block (that block relies entirely on the builtin nn.MHA, which I don't plan to alter right now); it would need to be added to the custom attention block. PyTorch 2.0 has a new F.scaled_dot_product_attention, a fused kernel with flash attention (or xformers memory-efficient attention), that can bring the custom impl closer to nn.MHA in performance. qk norm would work with that, since the fused kernel only covers the scaled dot product itself.

WIP https://github.com/rwightman/pytorch-image-models/blob/b6eb652924a40555d6bfcee64c7ef4c8d6e4aa9c/timm/models/vision_transformer.py#L54-L102
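
For reference, a hedged sketch of what qk norm on top of F.scaled_dot_product_attention could look like in a custom attention block (the class name, `qk_norm` flag, and the LayerNorm choice are illustrative assumptions; see the WIP link above for the actual timm work in progress):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class QKNormAttention(nn.Module):
    """Custom attention with optional per-head norm on q and k before the dot product."""

    def __init__(self, dim: int, num_heads: int, qk_norm: bool = True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)
        # ViT-22B-style stabilization: normalize q and k per head (identity if disabled).
        self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, heads, N, head_dim)
        q, k = self.q_norm(q), self.k_norm(k)
        # Fused scaled dot product attention (flash / xformers-style mem-efficient kernels).
        out = F.scaled_dot_product_attention(q, k, v)
        return self.proj(out.transpose(1, 2).reshape(B, N, C))
```

Since F.scaled_dot_product_attention only fuses the dot-product / softmax / value step, the q and k norms slot in just before it without touching the fused kernel.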

lucidrains commented 1 year ago

@rwightman ugh yea, i forgot about that (re: nn.mha). ok, let us leave it untouched then, until flash attention is released properly in pytorch 2.0