jramapuram opened this issue 2 years ago
Thanks for the issue, just saw it, looking into it!
@jramapuram can you elaborate on your config? Do you have Triton installed, for instance? Could you share a print(model) here, to be sure of which parts are actually instantiated? After a quick look it seems there could be a part where the gradients are not handled at the same precision as vanilla torch; at least that's my #1 hypothesis.
many thanks for the detailed issue and code snippets, this is perfect
If you're using Triton, could you test out installing a recent dev package? `pip install triton==1.1.2.dev20220106`
Also @jramapuram, could you confirm that this is with torch AMP (fp16)?
cc @dianaml0, @fmassa, is that something that you've seen? I remember @xwhan saw that at some point, but I thought it had been fixed. I just did a quick check in the triton code: we keep the data as fp32 in the softmax and layernorm cases when AMP is activated, which should lead to a precision similar to pytorch (layernorm is a bit below). It looks like a vanishing-gradient problem, and the parts involved here are very standard (MLP and scaled_dot_product attention), so I'm wondering whether it could be somewhere else in the code, or if the timm ViT adds some parameter-less normalization for instance. I'm not seeing this on the CIFAR example that we host.
edit: adding some more context and info
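For reference, a quick way to sanity-check that claim (a minimal sketch, assuming a CUDA device; this only exercises the standard torch AMP casting rules, nothing xformers-specific):

```python
import torch
import torch.nn.functional as F

with torch.cuda.amp.autocast():
    h = torch.randn(4, 4, device="cuda") @ torch.randn(4, 4, device="cuda")
    print(h.dtype)                         # torch.float16: matmul autocasts down
    print(torch.softmax(h, dim=-1).dtype)  # torch.float32: softmax autocasts up
    print(F.layer_norm(h, (4,)).dtype)     # torch.float32: layer_norm stays fp32
```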
@jramapuram the eps parameter for LayerNorm is not the same between timm and xformers (1e-5 vs. 1e-6). It's a long shot, but since your issue could be related to vanishing gradients, that could explain it. Fixing that now.
Filling in details:

* AMP FP16 ✅
* triton==1.1.1 (can test 1.1.2.dev20220106 👍)
* Will try the LayerNorm eps; good find! Might be relevant for AMP.

Instantiated model print to STDOUT: https://gist.github.com/jramapuram/d284e0f261d3fdb15c213dd929d272b9
I can repro the problem with the minimal microViT example actually (prior to the linked PRs), just need to wait long enough. Testing right now with the changes from the linked PRs
Seems fine with the updated eps @jramapuram, let me know if it fixes your issue?
Training now; will update here :)
```python
import torch
from torch import nn


def update_ln_eps(module: nn.Module, new_eps: float):
    """Recursively update the eps of every LayerNorm in the module tree."""
    from xformers.triton.layer_norm import FusedLayerNorm

    if isinstance(module, torch.nn.modules.LayerNorm):
        module.eps = new_eps

    # The fused Triton variant stores the value under a different attribute.
    if isinstance(module, FusedLayerNorm):
        module.epsilon = new_eps

    for _, child in module.named_children():
        update_ln_eps(child, new_eps)
```
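(For reference, applying this to a whole model to match timm's default would then just be `update_ln_eps(model, 1e-6)`, with `model` being the instantiated ViT.)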
@blefaudeux: Unfortunately this does not seem to have fixed it for me 😬. Not sure if the scaling from microViT --> ViT-B ImageNet might be causing issues that are not easily evident.
With the LN eps fix, using the function above:

With Triton 1.1.2.dev20220106 (tested with pip freeze to validate):

Commit d4c28fbbb881753e7855d08d121c85878a72b775 (tried with and without triton 1.1.2.dev20220106):

For sanity I also tried swapping back to TIMM again, and it is still working 😬
Ouch, this is not good. The issue auto-closed it seems, but I'm keeping it open; I'll try to dig a bit more.
@jramapuram, to try to pinpoint this a little better (and if you have time), could you try in an environment which does not have Triton? A few parts will then default-switch to PyTorch; if you don't see the issue there, I would know where to look (softmax and layernorm). I can confirm that it does not happen on CIFAR and a smaller ViT unfortunately; it would have been nice to have an easy repro.

edit: adding more context. I'm testing with pure pytorch layers right now, and I'm not seeing any difference so far, so that might not be a good explanation.

Else, I can think of:
* different init strategies for the weights (probable, but kind of unlikely that it explains this I think)
  * init is different indeed, see for instance, while xformers mostly follows the pytorch defaults
* shared weights in the projection
  * the projection seems to follow the same structure, an n x 3n matrix + bias, nothing different here
* different pre/post normalization
  * nope, pre-norm in both cases
In short, I don't see much difference except for the weight init (provided my home test with pytorch vs. triton parts is confirmed on your end @jramapuram). Since AMP training is notoriously a little finicky, maybe that could explain it? Not super intuitive to me, but I'm having a deeper look.
```python
from timm.models.layers import trunc_normal_
from torch import nn


def _init_vit_weights(module: nn.Module):
    """Transformer weight initialization from TIMM; apply with model.apply(_init_vit_weights)."""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)
```
@blefaudeux: are there any xformers linear layers that don't inherit `nn.Linear` and might be missed by this init function?
Thanks for the great suggestions btw!
Will try vanilla pytorch (without triton) on ImageNet for my own sanity as well 😅
PreNorm ✅
I do the custom TIMM init already (see the code above, which distills this); will also try a lower std (std=0.01) as well.
Ahh, I didn't know about the init on your side, so this rules it out also!
> @blefaudeux: are there any xformers linear layers that don't inherit `nn.Linear` and might be missed by this init function?
No, I don't think so, although FusedMLP uses a normal nn.Linear but fuses the dropout/bias/activation (so the bias init would be missed there). It does not seem like you're using FusedMLP though, so that should not be the case.
> Thanks for the great suggestions btw!
No problem, this is a little perplexing to be honest, but we'll root it out!
Seeing your curves, this does seem a little different from what I was seeing prior to the eps adjustment: validation accuracy was collapsing in the microViT / CIFAR example, but over many steps, while yours looks like a complete breakdown where one update completely breaks the model. It really looks like a raw fp16 representation problem; an underflow or overflow would look like that.
This is what a faulty normalization floor looked like (eps = 1e-5, pre/post correction); not really what you're seeing, unless it's a logging issue (not logging often enough, but my guess is no, since I'm seeing your steps axis and you seem to log per step).
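To make the eps-floor effect concrete, a toy sketch (magnitudes assumed, not taken from the actual runs): once the activation variance gets close to eps, the LayerNorm denominator is dominated by the floor and the output is damped well below unit variance.

```python
import torch

x = torch.randn(1024) * 1e-3  # tiny-variance activations, var ~ 1e-6
for eps in (1e-5, 1e-6):
    y = (x - x.mean()) / torch.sqrt(x.var(unbiased=False) + eps)
    # eps=1e-5 damps the output std to ~0.3, eps=1e-6 only to ~0.7
    print(f"eps={eps:.0e} -> output std {y.std().item():.2f}")
```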
hmm turns out I was testing with rotary embeddings turned on, and they make a huge difference
Lower std on trunc normal init (0.01):
Without triton:
```
WARNING:root:Triton is not available, some optimizations will not be enabled.
Error No module named 'triton'
```
FusedMLP:

```
(feedforward): FusedMLP(
  (mlp): Sequential(
    (0): Linear(in_features=768, out_features=3072, bias=True)
    (1): FusedDropoutBias(
      (pytorch_activation): GELU()
    )
    (2): Linear(in_features=3072, out_features=768, bias=True)
    (3): FusedDropoutBias(
      (pytorch_activation): Passthrough()
    )
  )
)
```
Thanks @jramapuram, that's very informative: no issues with the triton layers whatsoever, so the problem is in a pure pytorch definition... :/
Checking for what could cause an underflow or overflow between Timm's implementation and ours: it looks like the sqrt(d) normalization is done post-hoc in timm (see), while we do it prior to computing the attention (see). cc @fmassa @dianaml0, thoughts?
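To illustrate (a sketch with assumed shapes, not either library's actual code): both orderings compute the same attention logits, but the intermediate product lives at a very different fp16 dynamic range.

```python
import torch

q, k = torch.randn(2, 16, 64), torch.randn(2, 16, 64)
d = q.shape[-1]

attn_pre = (q / d ** 0.5) @ k.transpose(-2, -1)   # scale q before the product
attn_post = (q @ k.transpose(-2, -1)) / d ** 0.5  # scale the product afterwards
assert torch.allclose(attn_pre, attn_post, atol=1e-5)
# In fp16 the unscaled q @ k.T can overflow (max ~65504) before the division
# is applied, so where the 1/sqrt(d) lands can matter even though the fp32
# results are identical.
```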
@jramapuram can you check out https://github.com/facebookresearch/xformers/pull/225? I'm trying to find a test to repro this issue.
Oh @jramapuram, something else which could be significant: check out this paper (end of page 3) and the matching timm-vit model; Ross drops the hidden layer in the last block citing "training unstabilities"! I'll add an option to xformers to do this from the config side (pass "multiplier" == -1), it could well be the reason.
nevermind, I just realized that you were already doing that..
Running with the issue_219 branch 😢:
```
Branch 'issue_219' set up to track remote branch 'issue_219' from 'origin'.
...
Building wheels for collected packages: xformers
  Building wheel for xformers (setup.py): started
  Building wheel for xformers (setup.py): still running...
  Building wheel for xformers (setup.py): finished with status 'done'
  Created wheel for xformers: filename=xformers-0.0.9-cp39-cp39-linux_x86_64.whl size=1337425 sha256=8b0825f014b9859ad3786042973af9d389b1031c68a1eb734cb7c4b78050ce08
  Stored in directory: /mnt/tmp/pip-ephem-wheel-cache-a4wszhv8/wheels/10/9b/72/7597306f87828c97afa86adf56f4c78db426e51c2ee02d2f66
Successfully built xformers
Installing collected packages: xformers
  Attempting uninstall: xformers
    Found existing installation: xformers 0.0.7
    Uninstalling xformers-0.0.7:
      Successfully uninstalled xformers-0.0.7
Successfully installed xformers-0.0.9
```
Thanks a lot for the test, it's really perplexing... @fmassa has been using some xformers blocks a lot with imagenet, but not the whole model, and I don't remember him hitting this. I'll check the weight inits as soon as I get the time, sorry for the delay.
Hey @jramapuram, back to you! We can exchange by mail if that helps; I'd really love to get to the bottom of this. I'm adding DeepNet to xformers, and some of the issues mentioned there with pre-LN look like they could apply here (although that does not explain why Timm's implementation does not face the same instabilities). Current thoughts are that it could be related to weight init (I know you're handling that already, but maybe a bug/issue?), or a different layer-wise LR for instance, with a grep that fails on the xformers model because of different names? Would it be possible to share more of the training code?
Happy to chat more via email @blefaudeux !
Hey there, I just had time to have another look, weight-init focused. It seems that _init_vit_weights() does not apply to the attention projection steps, which are wrapped in the InProjContainer. It's visible if you change _init_vit_weights to:
```python
from timm.models.layers import trunc_normal_
from torch import nn


def _init_vit_weights(module: nn.Module):
    """Transformer weight initialization from TIMM, logging what gets touched."""
    if isinstance(module, nn.Linear):
        print(f"Initializing {module}")
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)
        print(f"Initializing {module}")
    elif len(list(module.children())) == 0:
        # Leaf modules that fall through the checks above never get initialized.
        print(f" Module {module} not initialized")
```
Unless I'm mistaken, it looks like the InProjContainer is skipped, since it does not expose the projection step as an nn.Linear but as an nn.Parameter (see). This means that these weights/biases are initialized according to the defaults (which are here), and should match fairseq, but not timm.
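A minimal illustration of why the init function skips it (the container below is hypothetical, not the actual xformers class): a module holding its projection as a raw nn.Parameter never matches the isinstance(module, nn.Linear) check.

```python
import torch
import torch.nn as nn

class RawProj(nn.Module):
    """Keeps the projection as a raw nn.Parameter, as described above."""
    def __init__(self, dim: int):
        super().__init__()
        self.in_proj_weight = nn.Parameter(torch.empty(3 * dim, dim))

module = RawProj(8)
print(isinstance(module, nn.Linear))      # False -> trunc_normal_ never applied
print(len(list(module.children())) == 0)  # True  -> flagged "not initialized"
```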
Quick questions, picking your brains @fmassa @dianaml0 @jramapuram:
Sorry for the delayed response. Is it possible to expose the InProjContainer such that it can be init-ed by a check on nn.Linear?

I'm also pulling in pre_post_norm_fix and testing with the above init, @blefaudeux. Will keep you apprised.
> Sorry for the delayed response. Is it possible to expose the InProjContainer such that it can be init-ed by a check on nn.Linear?
Totally, https://github.com/facebookresearch/xformers/tree/issue_219_param_init_fix is doing that (trying to get to a better projection block, while staying flexible enough to support NLP). It's still buggy though as per the unit test checks; not completely sure why as of now, trying to fix that soonish.
@blefaudeux: prelim results from pre_post_norm_fix are looking good! Will post a final graph comparing to timm when done!

Good thoughts on the InProjContainer! Here is the list of modules not updated, for reference (I'm sure you found this already :)):
```
Module Identity() not initialized
Module Dropout(p=0.0, inplace=False) not initialized
Module InProjContainer() not initialized
Module DropPath() not initialized
Module FusedLayerNorm() not initialized  # <-- should probably be caught by the LayerNorm init call as well; the defaults are fine, but just in case someone wants a different init here
```
- Would it make sense to expose an init method in the xformer model which covers some of these options (and handles all the children modules properly)? Either a global init fn or a way for the layers to inherit from their baseline counterparts, whichever makes sense.
@blefaudeux: looks within tolerance! Thanks! Closing this issue.
Thanks a lot @jramapuram for bearing with me on this and for all the testing, really appreciated!
@blefaudeux: any thoughts on what might be missing an init with DeepNorm? [Running into NaNs around 50 epochs in again.] Tried the following:
It still works fine with pre-norm though.
Hey @jramapuram, thanks for the report! I'm guessing this is with AMP? Not sure right now, but I'll have a look!

edit: can you tell me a bit more about the hyper-params, the LR for instance? Is it possible that this NaNs after the warmup because the LR is too big? Can you trace it to the gradient exploding, or is it something else?
It's kind of interesting that the instability shows up with DeepNorm there, as it is supposed to help with instabilities :D Paper for reference. It seems to work fine with the microViT example that we have (training a small ViT on CIFAR).
Ah, so one issue is that the DeepNorm init looks for the value projection, but in the self-attention case there's only one weight module for Q/K/V, and it does not get scaled. See the proj definition and how the weight init works. Fixing that.
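For context, a hedged sketch of what a DeepNorm-style init does (parameter names like "v_proj" are assumptions here, not the actual xformers code); for an encoder-only model the paper scales the value/output/FFN weights with gain beta = (8N)^(-1/4):

```python
import torch.nn as nn

def deepnorm_init(model: nn.Module, num_layers: int):
    beta = (8 * num_layers) ** -0.25
    for name, param in model.named_parameters():
        if param.ndim > 1 and any(k in name for k in ("v_proj", "out_proj", "ffn")):
            nn.init.xavier_normal_(param, gain=beta)
    # The pitfall above: with a fused in-projection there is a single stacked
    # QKV weight (e.g. "in_proj_weight"), so the "v_proj" check never fires
    # and the value weights keep their default, unscaled init.
```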
Sorry for the delay! Testing now, but blocked due to the triton version bump:
```
AttributeError: module 'triton.language' has no attribute 'constexpr'
```

Guessing this is due to https://github.com/facebookresearch/xformers/pull/272 -- is the recommended version `pip install triton==2.0.0.dev20220430`? Didn't see anything in the docs.
Need to see if we can bump to triton2 without breaking deepspeed.
Bumped to triton2, but am now blocked via https://github.com/facebookresearch/xformers/issues/290
Will test without triton and report back
@blefaudeux: Disabled triton, pulled master, and still NaN-ing (with DeepNorm). The only init in this scenario is for CLS and pos_embed. Same ViT-B as above.
> can you tell me a bit more about the LR for instance, hyper params?
Yup, same params as described in MAE appendix (shown below). Might be worthwhile kicking off a similar run on your side?
> Disabled triton, pulled master, and still NaN-ing (with DeepNorm). The only init in this scenario is for CLS and pos_embed. Same ViT-B as above.
>
> Yup, same params as described in MAE appendix (shown below). Might be worthwhile kicking off a similar run on your side?
Thanks for the test and report! So I assume this is with the same LR schedule (and possibly grad clipping) as pre-norm: pre-norm works and is at parity with Timm, but deepnorm NaNs (I presume following a gradient explosion). Is clipping the gradients not enough? I don't have an easy way to repro myself as of now (doable, but it would take a lot of time), so trying to think it out is another option... It's kind of peculiar, since deepnorm should actually stabilize the training, and the paper claims dynamics similar to pre-norm if I remember correctly.
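For reference, the usual AMP-safe clipping pattern I'd expect here (a minimal runnable sketch, not the actual training code; grads must be unscaled before clipping or the max_norm is meaningless):

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device="cuda")
with torch.cuda.amp.autocast():
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # bring grads back to fp32 scale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
```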
Thanks for the prompt response!
Happy to test things for you as well for repro :)
Thanks @jramapuram for the clarification; I must have missed something in the paper, I'll have another look. It would be great if I could come up with a test which does not involve ImageNet, that's a little too big for a regression check. I'll see what I can find :)
I'm thinking that there may be a remaining issue with the weight init in the self-attention case, where the fan_in/fan_out dimensions would be skewed because we merge the 3 projections (see the sketch below). Having a look ASAP, reasonably within a couple of days.
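A small illustration of that skew (ViT-B dimensions assumed): gain-based schemes like Xavier read fan_in/fan_out off the weight shape, so a fused 3n x n QKV weight is treated as having fan_out = 3n instead of n.

```python
import math

embed_dim = 768  # ViT-B width
fan_in = embed_dim
for fan_out in (embed_dim, 3 * embed_dim):  # separate vs fused QKV projection
    std = math.sqrt(2.0 / (fan_in + fan_out))  # xavier_normal_ std
    print(f"fan_out={fan_out}: std={std:.4f}")
# fan_out=768:  std=0.0361
# fan_out=2304: std=0.0255 -> the fused weight gets a ~30% tighter init
```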
Still no joy on ef6de0faa8bacf91d9fb83ed733ffa0546d85db9 😬. Here I init only the pos_embed and cls_token with trunc_normal_(std=0.02) and use DeepNorm:
Edit: updated curves to compare to prenorm.
Oh yes, on the current main branch nothing has landed addressing this yet. Could you try https://github.com/facebookresearch/xformers/pull/303 by any chance? I can try to start something later today, but I'm a little bit underwater atm :/
🐛 Bug
I'm trying to create a 1:1 config that can train a stable ViT-B with the MAE config (from appendix A.2).
Maybe I'm missing something (highly plausible), but when I use xformers instead of timm it creates an unstable training scenario [over numerous trials] with exactly the same hyper-parameters (batch_size=4096 + cutmix + mixup + label smoothing + AdamW[0.9, 0.95], lr=1e-4 [with scaling rule ofc], lr warmup + cosine decay, skip bias/CLS/pos_embed weight decay, etc.).
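For reference, a hedged sketch of that recipe (the helper and names below are illustrative, not my actual training code):

```python
import torch

def param_groups(model, weight_decay):
    """Split params so biases/norms/CLS/pos_embed skip weight decay."""
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if p.ndim == 1 or name.endswith(("cls_token", "pos_embed")):
            no_decay.append(p)
        else:
            decay.append(p)
    return [{"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0}]

base_lr, batch_size = 1e-4, 4096
lr = base_lr * batch_size / 256  # the linear LR scaling rule mentioned above

model = torch.nn.Linear(8, 8)  # stand-in for the instantiated ViT-B
# 0.05 is a placeholder: the weight decay value is not stated in this thread
optimizer = torch.optim.AdamW(param_groups(model, weight_decay=0.05),
                              lr=lr, betas=(0.9, 0.95))
```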
xformers ViT-B Config
xformers ViT-B Command
To Reproduce
Steps to reproduce the behavior: