[Open] jramapuram opened this issue 2 years ago
Oh yes, on the current main branch nothing has landed addressing this yet. Could you try https://github.com/facebookresearch/xformers/pull/303 by any chance? I can try to start something later today, but I'm a little bit underwater at the moment :/
No worries! Will give that a shot :) [feel better!]
I added the reference pre-norm graphs above. Differences are basically:
Oh wow, it's pretty clear indeed, thanks @jramapuram. #303 definitely fixes a small bug, but I doubt it really explains this; I'll dive back into deepnorm. I may actually have a repro: with the recent metaformer + CIFAR-10, deepnorm does not work either, but I thought that was because of the decidedly different model structure. I'll give it a second look, sorry for the delay.
Hmm, I did spend some time on that and found nothing obviously wrong, it's really perplexing. I'll give ImageNet a shot. If you have the option, would it be possible to test this without AMP, just in case it's a matter of numerical accuracy (which would not be caught by the grad scaler unless it turns into NaNs)?
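If it helps, here is a minimal sketch of the comparison I have in mind (plain PyTorch, names illustrative): run the exact same step with autocast and the grad scaler disabled, so a pure fp32 run can be compared against the AMP one.

```python
import torch

USE_AMP = False  # set to True to reproduce the mixed-precision run

# With enabled=False both autocast and GradScaler become no-ops,
# so the same code path serves for the fp32 sanity check.
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

def train_step(model, batch, target, criterion, optimizer):
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(enabled=USE_AMP):
        loss = criterion(model(batch), target)
    scaler.scale(loss).backward()  # plain backward when AMP is disabled
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()
```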
Just in case @jramapuram, could you check that you're using triton == 2.0.0.dev20220403? I'm bumping into numerical stability issues with newer versions of triton (being worked on) and the symptoms are very similar (accuracy crash but not NaNs).
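For reference, a quick way to confirm which build is installed (assuming a standard triton install that exposes `__version__`):

```python
import triton

# The nightly date is part of the version string, e.g. "2.0.0.dev20220403"
print(triton.__version__)
```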
(No, I've not forgotten this issue.) I would love to be able to repro on something a little smaller than a full-blown ImageNet training over a couple of nodes, so I'm documenting that here. Attached is a minGPT training setup with pre/post/deepnorm (8-layer transformer, 25M params). Deepnorm doesn't converge to a solution as good as the others, but there's no catastrophic failure for any of them.
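For context, a toy sketch (not the xformers implementation) of the three residual wirings being compared; `alpha` follows the DeepNet formulation, where deepnorm is a post-norm residual with a gain larger than 1 plus a down-scaled init:

```python
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Wraps a sublayer (attention or MLP) with one of the residual styles."""

    def __init__(self, dim, sublayer, style="pre", alpha=1.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer
        self.style = style
        self.alpha = alpha  # deepnorm uses alpha > 1, plain post-norm uses 1.0

    def forward(self, x):
        if self.style == "pre":
            # pre-norm: x + f(LN(x))
            return x + self.sublayer(self.norm(x))
        # post-norm / deepnorm: LN(alpha * x + f(x))
        return self.norm(self.alpha * x + self.sublayer(x))
```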
> Just in case @jramapuram, could you check that you're using triton == 2.0.0.dev20220403? I'm bumping into numerical stability issues with newer versions of triton (being worked on) and the symptoms are very similar (accuracy crash but not NaNs).

Thanks for keeping this in mind @blefaudeux. Just checked: I'm using triton==2.0.0.dev20220430 -- I can drop down and test!

Re the minGPT: I'm surprised there is a perf drop -- does the test loss / negative log-likelihood follow the same trend?
20220430 was fine; the builds after that were broken, but that was fixed by https://github.com/openai/triton/commit/205a493b10a5112ec1fccdbe9d59fe9f172e027d, so it's back to being good at the moment! Re minGPT: I can check the other metrics. As mentioned in another thread, I think it may be due to the weight distribution being hardcoded right now for deepnorm. That code is not very readable or hackable, and not a great design overall; I'd like to come up with something better and more explicit (for instance a couple of possible inits exposed as part of the xformers config, with deepnorm respecting that). It's always possible to init from the outside, but that's tied to parameter naming conventions (not super clear right now), and it kind of negates the point of supporting deepnorm to begin with, I think.
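To illustrate the "init from the outside" point, something along these lines is possible today, but it has to guess at parameter names; the `"proj"`/`"mlp"` filters below are assumptions, not the actual xformers naming scheme, which is exactly why it is brittle:

```python
import torch.nn as nn

def reinit_from_outside(model: nn.Module, gain: float = 1.0):
    # Re-initialize weight matrices whose names look like projections;
    # the name filters are hypothetical and would break if the library renames things.
    for name, param in model.named_parameters():
        if param.dim() >= 2 and ("proj" in name or "mlp" in name):
            nn.init.xavier_normal_(param, gain=gain)
        elif name.endswith(".bias"):
            nn.init.zeros_(param)
```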
Unfortunately no joy @blefaudeux. I tried:
- triton==2.0.0.dev20220430 + pos + cls init
- triton==2.0.0.dev20220430 + pos + cls + weight init

Thanks a bunch @jramapuram! I have a draft PR getting ready which rewrites a lot of the input projections (something we discussed earlier), plus explicit handling of a couple of init methods (optional, users are still free to do as they please); I'm hoping that it solves this. To give some insight, I think this setting is not well handled and could be the culprit (deepnorm assumes a different projection per Q/K/V, and the default here should probably be "true", I believe).
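As a plain-PyTorch illustration of that point (not the xformers input-projection code): deepnorm's init is derived for independent Q/K/V projection matrices, which a fused projection cannot express per-matrix.

```python
import torch
import torch.nn as nn

dim = 768
x = torch.randn(2, 16, dim)

# Fused projection: a single (dim -> 3*dim) matrix, split afterwards.
fused_proj = nn.Linear(dim, 3 * dim)
q, k, v = fused_proj(x).chunk(3, dim=-1)

# Separate projections: three matrices that can be initialized and scaled
# independently, which is what a deepnorm-style init expects.
q_proj, k_proj, v_proj = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)
q, k, v = q_proj(x), k_proj(x), v_proj(x)
```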
I think that #312 is getting there @jramapuram, it's a lot cleaner to my eyes. Something I've seen, related to your curves above: it's not just deepnorm, the post-normalization path does not play well with ViT either. GPT is fine with this normalization path; I don't know if that's a known fact, I would need to check the literature. Since deepnorm is a subset of the post-normalization code path, it makes a little more sense, or at least deepnorm is not alone.
OK, beyond #312 which cleans things up, it looks like (judging from timm, here) layernorm requires specific treatment for ViT + post-norm: the weight is initialized to a very small value (vs. 1 typically). Since in our case post-norm and deepnorm (same residual codepath) both fail with ViT but work well with GPT, that could explain it. I'll give it a shot.
I've not forgotten this @jramapuram. It turns out that for vision / post-norm, Swin v2 already solved this (related to the message above), see their paper: the initial weights need to be scaled way down. I'll try to implement this in xformers when I get the time.
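A minimal sketch of that fix, assuming it boils down to scaling the post-norm LayerNorm weights way down at init (the exact value below is illustrative, not taken from the Swin v2 paper):

```python
import torch.nn as nn

def scale_down_layernorm_init(model: nn.Module, value: float = 1e-4):
    # Post-norm ViT: start with a near-zero LayerNorm gain so each residual
    # branch contributes very little early in training.
    for module in model.modules():
        if isinstance(module, nn.LayerNorm) and module.elementwise_affine:
            nn.init.constant_(module.weight, value)
            nn.init.zeros_(module.bias)
```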
🐛 Bug
I'm trying to create a 1:1 config that can train a stable ViT-B with the MAE config (from appendix A.2).
Maybe I'm missing something (highly plausible), but when I use xformers instead of timm it creates an unstable training scenario [over numerous trials] with exactly the same hyper-parameters (batch_size=4096 + CutMix + MixUp + label smoothing + AdamW[0.9, 0.95], lr=1e-4 [with the scaling rule, of course], lr warmup + cosine decay, skipping weight decay for bias/CLS/pos_embed, etc.).
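For reference, a rough sketch of that recipe in plain PyTorch (the parameter-name filters such as `cls_token`/`pos_embed` and the weight-decay value are assumptions, not the exact config I'm running):

```python
import math
import torch

def build_optimizer(model, base_lr=1e-4, batch_size=4096, weight_decay=0.05):
    lr = base_lr * batch_size / 256  # linear lr scaling rule
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Skip weight decay for 1D params (norms, biases), CLS token and pos embed.
        if p.ndim == 1 or name.endswith(".bias") or "cls_token" in name or "pos_embed" in name:
            no_decay.append(p)
        else:
            decay.append(p)
    return torch.optim.AdamW(
        [{"params": decay, "weight_decay": weight_decay},
         {"params": no_decay, "weight_decay": 0.0}],
        lr=lr, betas=(0.9, 0.95),
    )

def warmup_cosine(step, warmup_steps, total_steps):
    # Linear warmup followed by cosine decay, usable with torch.optim.lr_scheduler.LambdaLR.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))
```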
xformers ViT-B Config
xformers ViT-B
Command
To Reproduce
Steps to reproduce the behavior: