JianbangZ opened 1 year ago
@JianbangZ
QK normalization might improve stability at the larger end of the model scale. So far I've managed to mitigate the instability with bfloat16 + AMP and lowering the AdamW beta2, but not perfectly for large-scale models.
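For context, a minimal sketch of that mitigation, assuming a toy model and loss (the real setup is the CLIP training loop): bfloat16 autocast plus a lower AdamW beta2, e.g. 0.95 instead of the default 0.999.

```python
import torch
import torch.nn as nn

# Hypothetical toy setup; illustrates the mitigation described above.
model = nn.Linear(512, 512).cuda()  # stand-in for the actual ViT/CLIP model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95), weight_decay=0.2)

x = torch.randn(8, 512, device='cuda')
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    loss = model(x).square().mean()  # placeholder loss
loss.backward()   # no GradScaler needed with bfloat16 (unlike fp16 AMP)
optimizer.step()
optimizer.zero_grad()
```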
The parallel MLP + attn won't make much difference within a single device until it's (non-trivially) supported in distributed training via model/tensor parallel code. Kernels typically use the full GPU and don't execute in parallel, even if you try.
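For reference, a rough sketch of what a parallel block computes versus the standard sequential block (illustrative names, not the actual timm/open_clip implementation):

```python
import torch.nn as nn

class ParallelBlock(nn.Module):
    """Illustrative parallel attn + MLP block: both branches read the same input.

    Sequential (standard ViT):  x = x + attn(ln1(x)); x = x + mlp(ln2(x))
    Parallel (PaLM / GPT-J):    x = x + attn(ln1(x)) + mlp(ln2(x))
    """
    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        y = self.norm1(x)
        a, _ = self.attn(y, y, y, need_weights=False)
        return x + a + self.mlp(self.norm2(x))
```

Note that on a single device the two branches still launch as separate kernels, which is the point above: the layout alone doesn't buy concurrency.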
I tried torch.compile (inductor) with the parallel attn, thought I was seeing gains at the B/16 model size, and then it reversed on H/14. I talked to an expert and he figures it's just compiler behaviour (a good fit vs. not so great, with some differences due to the parallel layout), BUT there were no parallel execution optimizations in the compiler (it does not try).
@JianbangZ yea, i can add the qk rmsnorm this week as an option
however, there are still researchers who disagree, just fyi
~~also, i feel like most of the instability has been ironed out after switching to bfloat16?~~ Ross already mentioned this
I have a different approach for the parallel blocks (manually fusing) that looks like it'll work better; trying it in timm. My first naive approach did not yield gains as mentioned, and the compiler couldn't do much with it either.
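Roughly what manual fusing could look like (a sketch under my own naming, not the actual timm code): the qkv projection and the MLP input projection are merged into one wide Linear, so both branches share a single large matmul on the normalized input.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedParallelBlock(nn.Module):
    """Sketch of a manually fused parallel block (not the actual timm implementation).

    One wide Linear produces q, k, v and the MLP hidden activations from a single
    normalized input, so the attn and MLP branches share one large matmul.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.mlp_hidden = int(dim * mlp_ratio)
        self.norm = nn.LayerNorm(dim)
        # fused projection for q, k, v and the MLP hidden layer
        self.in_proj = nn.Linear(dim, 3 * dim + self.mlp_hidden)
        self.attn_out = nn.Linear(dim, dim)
        self.mlp_out = nn.Linear(self.mlp_hidden, dim)

    def forward(self, x):
        B, N, C = x.shape
        y = self.in_proj(self.norm(x))
        q, k, v, h = torch.split(y, [C, C, C, self.mlp_hidden], dim=-1)
        q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        a = F.scaled_dot_product_attention(q, k, v)  # fused kernel, PyTorch 2.0+
        a = a.transpose(1, 2).reshape(B, N, C)
        return x + self.attn_out(a) + self.mlp_out(F.gelu(h))
```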
For qk norm here, it won't work with the default transformer block (that one fully relies on the builtin nn.MHA, and I don't plan to alter that right now); it would need to be added to the custom attention block. There is a new F.scaled_dot_product_attention in PyTorch 2.0, a fused kernel with flash attention (or xformers memory-efficient attention), that can bring the custom impl closer to nn.MHA in performance. qk norm would work with that, since the fused kernel only covers the scaled dot product.
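As a sketch of that custom-attention route (illustrative, not the actual open_clip code, and assuming PyTorch 2.0's F.scaled_dot_product_attention): normalize q and k per head, then hand them to the fused kernel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QKNormAttention(nn.Module):
    """Sketch of a custom attention block with LayerNorm applied to q and k
    before the (fused) scaled dot product attention."""
    def __init__(self, dim, num_heads, qk_norm=True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        # dispatches to flash / memory-efficient kernels in PyTorch 2.0
        x = F.scaled_dot_product_attention(q, k, v)
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
```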
@rwightman ugh yea, i forgot about that (re: nn.mha). ok, let us leave it untouched then, until flash attention is released properly in pytorch 2.0
The changes include QK normalization, parallel layers, etc. It would be cool to see how CLIP performs when applying those changes to ViT-B, ViT-L, and ViT-H.