facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

xformers ViT-B ImageNet MAE + Deepnorm training instability #219

Open jramapuram opened 2 years ago

jramapuram commented 2 years ago

πŸ› Bug

I'm trying to create a 1:1 config that can train a stable ViT-B with the MAE config (from appendix A.2).

Maybe I'm missing something (highly plausible), but when I use xformers instead of timm, training becomes unstable (consistently, over numerous trials) with exactly the same hyper-parameters: batch_size=4096, CutMix, Mixup, label smoothing, AdamW (betas 0.9, 0.95), lr=1e-4 (with the LR scaling rule, of course), LR warmup + cosine decay, no weight decay on bias/CLS/pos_embed, etc.

[image]
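The learning-rate settings listed above can be made concrete with a short sketch. It assumes that the "scaling rule" is the linear one used in the MAE paper (effective lr = base lr × batch_size / 256) and a per-step linear warmup followed by cosine decay to zero; `lr_at_step` and the step counts are hypothetical and not taken from the reporter's training code.

```python
import math


def lr_at_step(step: int, base_lr: float, batch_size: int,
               warmup_steps: int, total_steps: int) -> float:
    """Linearly scaled peak LR with linear warmup, then cosine decay (illustrative)."""
    # Assumed linear scaling rule, as in MAE: peak = base_lr * batch_size / 256.
    peak_lr = base_lr * batch_size / 256
    if step < warmup_steps:
        # Linear warmup from 0 up to peak_lr.
        return peak_lr * step / warmup_steps
    # Cosine decay from peak_lr down to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


# base_lr=1e-4 with batch_size=4096 gives a peak LR of about 1.6e-3 at the end of warmup.
print(lr_at_step(step=1000, base_lr=1e-4, batch_size=4096,
                 warmup_steps=1000, total_steps=10000))
```

With batch_size=4096 the linear rule multiplies the base LR by 16, which is why the peak is an order of magnitude above the nominal lr=1e-4.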

xformers ViT-B Config

reversible: False
block_type: "encoder"
num_layers: 12
dim_model: 768
layer_norm_style: "pre"

multi_head_config:
  num_heads: 12
  residual_dropout: 0.1  # (1) tried without this, (2) swapping this for DropPath, (3) with regular dropout
  use_rotary_embeddings: False

  attention:
    name: "scaled_dot_product"
    dropout: 0.0
    causal: False

feedforward_config:
  name: "MLP"
  dropout: 0.0
  activation: "gelu"
  hidden_layer_multiplier: 4

xformers ViT-B

"""A simple ViT-B in xformers."""

import typing as t
from pathlib import Path
import yaml

import torch
from torch import nn
from timm.models.vision_transformer import DropPath
from timm.models.layers.patch_embed import PatchEmbed
from timm.models.layers.weight_init import trunc_normal_
from xformers.factory import xFormer, xFormerConfig

def _init_vit_weights(module: nn.Module):
    """Transformer weight initialization from TIMM."""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)

class ViT(nn.Module):
    """Vision transformer + head module."""

    def __init__(
        self,
        img_size: int = 224,
        in_chans: int = 3,
        patch_size: int = 16,
        num_classes: int = 1000,
        drop_path_rate: float = 0,
        transformer_config_file: t.Union[Path, str] = "configs/vit_b.yaml",
    ):
        """A standard ViT module."""
        super().__init__()

        # read the model config
        with open(transformer_config_file, "rb") as fileptr:
            self.model_config = yaml.load(fileptr, Loader=yaml.FullLoader)

        # embed_dim = self.model_config["block_config"]["dim_model"]
        embed_dim = self.model_config["dim_model"]
        print(self.model_config)

        # build the patch embedding model
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten=True,
        )
        self.num_patches = self.patch_embed.num_patches

        # Build the tokens / position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(  # +1 for CLS token
            torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
        )

        # build the backbone
        self.backbone = xFormer.from_config(xFormerConfig([self.model_config]))

        # Swap dropout with drop-path
        # Also tried (1) without this, (2) without dropout.
        if drop_path_rate > 0:
            dpr_idx = 0
            dpr = [
                x.item()
                for x in torch.linspace(0, drop_path_rate, len(self.backbone.encoders))
            ]
            for layer in self.backbone.encoders:
                if hasattr(layer.mha, "resid_drop"):
                    setattr(layer.mha, "resid_drop", DropPath(dpr[dpr_idx]))
                    dpr_idx += 1

        # build the head network
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim, eps=1e-6),
            nn.Linear(embed_dim, num_classes)
        )

        # do ViT initializations
        self.init_weights()

    def init_weights(self):
        """Initialize layers, pos_embed and CLS for ViTs."""
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(_init_vit_weights)

    def forward(self, inputs: torch.Tensor) -> t.Dict[str, torch.Tensor]:
        """Infer variates and return a dict with repr and logits.

        Example sizing:
        patches = [2, 196, 768] --> [2, 197, 768] (after CLS)
        representation = [2, 197, 768]
        logits = [2, 197, 1000]
        CLS = [2, 1000]  # select first of 197

        """
        patches = self.patch_embed(inputs)
        cls_token = self.cls_token.expand(inputs.shape[0], -1, -1)
        out = torch.cat((cls_token, patches), dim=1)
        out = out + self.pos_embed

        representation = self.backbone(out)
        logits = self.head(representation)
        return {
            "representation": representation.detach(),
            "logits": logits,
            "CLS": logits[:, 0],
        }

Command

import timm

vit_b = ViT()
vit_b_timm = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=1000)

# number_of_parameters(): local helper (not shown) that counts the model's parameters

number_of_parameters(vit_b)
Out[18]: 86567656

number_of_parameters(vit_b_timm)
Out[19]: 86567656

To Reproduce

Steps to reproduce the behavior:

  1. Train a ViT-B with xformers --> unstable
  2. Using same training setup with timm --> stable
  3. 😒
blefaudeux commented 2 years ago

Thanks for the issue, just saw it; looking into it!

blefaudeux commented 2 years ago

@jramapuram can you elaborate on your config? Do you have Triton installed, for instance? Could you share a print(model) here, to be sure of which parts are actually instantiated? After a quick look, it seems that some part may not be handling the gradients with the same precision as PyTorch; at least that's my #1 hypothesis.

blefaudeux commented 2 years ago

many thanks for the detailed issue and code snippets, this is perfect

blefaudeux commented 2 years ago

If you're using Triton, could you test installing a recent dev package? pip install triton==1.1.2.dev20220106

blefaudeux commented 2 years ago

Also @jramapuram, could you confirm that this is with torch AMP (fp16)?

blefaudeux commented 2 years ago

cc @dianaml0, @fmassa, is that something that you've seen? I remember @xwhan saw this at some point, but I thought it was fixed. I just did a quick check in the Triton code, and we keep the data type as fp32 in the softmax and layernorm when AMP is activated, which should give a precision similar to PyTorch's (layernorm is a bit below). It looks like a vanishing gradient problem, and the parts here are very standard (MLP and scaled_dot_product attention), so I'm wondering whether the cause could be somewhere else in the code, or whether the timm ViT adds some parameter-less normalization, for instance. I'm not seeing this on the CIFAR example that we host.

edit: adding some more context and info

blefaudeux commented 2 years ago

@jramapuram the eps parameter for LayerNorm is not the same between timm and xformers (1e-6 in timm vs. 1e-5 in xformers). It's a long shot, but since your issue could be related to vanishing gradients, it could explain it. Fixing that.
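To see why eps can matter: it is added to the variance before the square root, so it acts as a floor on the normalizer. When a token's variance is of the same order as eps, the normalized output is noticeably shrunk. A minimal pure-Python sketch (no learnable affine weight/bias; `layer_norm` is an illustrative helper, not the xformers or timm code):

```python
import math


def layer_norm(x: list[float], eps: float) -> list[float]:
    """LayerNorm without the affine part: (x - mean) / sqrt(var + eps)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance, as LayerNorm uses
    return [(v - mean) / math.sqrt(var + eps) for v in x]


# A low-variance input: values of +/-1e-3 around 0, so var = 1e-6.
x = [1e-3, -1e-3] * 4
for eps in (1e-5, 1e-6):
    out = layer_norm(x, eps)
    # Max |output|: about 0.30 for eps=1e-5 and 0.71 for eps=1e-6 (it would be 1.0 with no eps).
    print(eps, max(abs(v) for v in out))
```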

jramapuram commented 2 years ago

Filling in details:

Instantiated model print to STDOUT: https://gist.github.com/jramapuram/d284e0f261d3fdb15c213dd929d272b9

blefaudeux commented 2 years ago

> Filling in details:
>
> * AMP FP16 ✅
> * triton==1.1.1 (can test 1.1.2.dev20220106 👍)
> * Will try the layernorm eps; good find! Might be relevant for AMP
>
> Instantiated model print to STDOUT: https://gist.github.com/jramapuram/d284e0f261d3fdb15c213dd929d272b9

I can actually repro the problem with the minimal microViT example (prior to the linked PRs); I just need to wait long enough. Testing right now with the changes from the linked PRs.

blefaudeux commented 2 years ago

Seems fine with the updated eps, @jramapuram; let me know if it fixes your issue.

jramapuram commented 2 years ago

Training now; will update here :)

import torch
from torch import nn


def update_ln_eps(module: nn.Module, new_eps: float):
    """Recurse and update LN eps with this value."""
    from xformers.triton.layer_norm import FusedLayerNorm

    if isinstance(module, torch.nn.modules.LayerNorm):
        module.eps = new_eps

    if isinstance(module, FusedLayerNorm):
        module.epsilon = new_eps

    for _, child in module.named_children():
        update_ln_eps(child, new_eps)
jramapuram commented 2 years ago

@blefaudeux: Unfortunately this does not seem to fix it for me 😬. Not sure if the scaling from microViT to ViT-B on ImageNet is causing some issues that are not easily evident.

With the LN fix using the function above: [image]

With Triton 1.1.2.dev20220106 (validated with pip freeze): [image]

Commit d4c28fbbb881753e7855d08d121c85878a72b775 (tried with and without triton 1.1.2.dev20220106): [image]

For sanity, I also tried swapping back to timm, and it still works 😬 [image]

blefaudeux commented 2 years ago

Ouch, this is not good. The issue seems to have auto-closed, but I'm keeping it open; I'll try to dig a bit more.

blefaudeux commented 2 years ago

@jramapuram, to pinpoint this a little better (if you have time), could you try in an environment which does not have Triton? A few parts will fall back to PyTorch by default; if you don't see the issue there, then I would know where to look (well, softmax and layernorm).

blefaudeux commented 2 years ago

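Otherwise I can think of:

* different init strategies for the weights (probable, but I think it's kind of unlikely that it explains this)
* shared weights in the projection
* different pre/post normalization

I can confirm that it does not happen on CIFAR with a smaller ViT, unfortunately; it would have been nice to have an easy repro.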

edit: adding more context

blefaudeux commented 2 years ago


Testing with pure PyTorch layers right now, and I'm not seeing any difference so far, so this might not be the explanation.

blefaudeux commented 2 years ago

> Otherwise I can think of:
>
> * different init strategies for the weights (probable, but I think it's kind of unlikely that it explains this)

The init is indeed different, see for instance, while xformers mostly follows the PyTorch defaults.

> * shared weights in the projection

The projection seems to follow the same structure, an n x 3n matrix + bias; nothing different here.

> * different pre/post normalization

Nope, pre-norm in both cases.

In short, I don't see much difference (provided that my home test with PyTorch vs. Triton parts is confirmed on your end, @jramapuram) except for the weight init. Since AMP training is notoriously a little finicky, maybe that could explain it? Not super intuitive to me, but I'm having a deeper look.

jramapuram commented 2 years ago

Thanks for the great suggestions btw!

blefaudeux commented 2 years ago
> * Will try vanilla PyTorch (without Triton) on ImageNet for my own sanity as well 😅
> * PreNorm ✅
> * I do the custom TIMM init already (see _init_vit_weights in the code above); will also try a lower std (std=0.01).

Ahh, I didn't know about the init on your side, so this rules it out too!

> @blefaudeux: are there any xformers linear layers that don't inherit from nn.Linear that might be missed by this init function?

No, I don't think so, although FusedMLP uses a normal nn.Linear but fuses the dropout/bias/activation (so the bias init would be missed). It does not seem like you're using FusedMLP, so it should not be the case here.

> Thanks for the great suggestions btw!

No problem, this is a little perplexing to be honest, but we'll root it out!

blefaudeux commented 2 years ago

Seeing your curves, it does seem a little different from what I was seeing prior to the eps adjustment: validation accuracy was collapsing in the microViT example / CIFAR, but over many steps, while yours looks like a complete breakdown, where a single update completely breaks the model. It really looks like a raw fp16 representation problem; an underflow or overflow would look like that.

This is what a faulty normalization floor looked like (eps = 1e-5, pre/post correction). It's not really what you're seeing, unless it's a logging issue (not logging often enough, but my guess is no, since I'm seeing your steps axis and you seem to log every step): [image]
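For context on the overflow/underflow hypothesis: fp16 has a largest finite value of 65504 and a smallest positive subnormal of 2^-24 (about 6e-8), so any activation or gradient that leaves that range becomes inf or 0 in one step. The sketch below emulates fp16 rounding with Python's struct "e" (half-precision) format; `to_fp16` is an illustrative helper, not part of xformers or timm.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack values beyond the fp16 range; a GPU cast yields +/-inf instead.
        return float("inf") if x > 0 else float("-inf")


for value in (1.0, 65504.0, 70000.0, 1e-4, 1e-8):
    # 70000.0 overflows to inf and 1e-8 underflows to 0.0.
    print(f"{value!r:>8} -> {to_fp16(value)!r}")
```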

blefaudeux commented 2 years ago

hmm turns out I was testing with rotary embeddings turned on, and they make a huge difference

jramapuram commented 2 years ago

Lower std on trunc normal init (0.01): [image]

Without triton: [image]

WARNING:root:Triton is not available, some optimizations will not be enabled.
Error No module named 'triton'

FusedMLP: [image]

        (feedforward): FusedMLP(
          (mlp): Sequential(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): FusedDropoutBias(
              (pytorch_activation): GELU()
            )
            (2): Linear(in_features=3072, out_features=768, bias=True)
            (3): FusedDropoutBias(
              (pytorch_activation): Passthrough()
            )
          )
        )
blefaudeux commented 2 years ago

Thanks @jramapuram, that's very informative: so there are no issues with the Triton layers whatsoever, and the problem is in a pure PyTorch definition... :/

blefaudeux commented 2 years ago

Checking for what could cause an underflow or overflow between timm's implementation and ours, it looks like the sqrt(d) normalization is done post-hoc in timm (see), while we do it prior to computing the attention (see). cc @fmassa @dianaml0, thoughts?

@jramapuram can you check out https://github.com/facebookresearch/xformers/pull/225 ? I'm trying to find a test to repro this issue
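The two orderings are identical in exact arithmetic, since (q/√d)·k = (q·k)/√d, but they produce different intermediate magnitudes when the score matrix is materialized in fp16. A minimal sketch under stated assumptions: fp16 is emulated with Python's struct "e" format, the dot product is rounded to fp16 once (real kernels typically accumulate in fp32 and round the output), and the vectors are made up to trigger the effect. The helper names are hypothetical.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to half precision; out-of-range values become +/-inf, as in a GPU cast."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def score_scale_after(q: list[float], k: list[float]) -> float:
    """Scale after the matmul (the timm ordering described above): (q @ k) / sqrt(d)."""
    d = len(q)
    raw = to_fp16(sum(a * b for a, b in zip(q, k)))  # fp16 matmul output
    return to_fp16(raw / math.sqrt(d))


def score_scale_before(q: list[float], k: list[float]) -> float:
    """Scale q first (the xformers ordering described above): (q / sqrt(d)) @ k."""
    d = len(q)
    q_scaled = [to_fp16(a / math.sqrt(d)) for a in q]
    return to_fp16(sum(a * b for a, b in zip(q_scaled, k)))


# d = 64 with large activations: the unscaled dot product is 64 * 40 * 40 = 102400 > 65504.
q = [40.0] * 64
k = [40.0] * 64
print(score_scale_after(q, k))   # inf: the raw scores overflowed before scaling
print(score_scale_before(q, k))  # 12800.0
```

In this toy case pre-scaling q is the safer order; whether that matters for the instability in this issue is exactly the open question above.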

blefaudeux commented 2 years ago

Oh @jramapuram, something else which could be significant: you can check out this paper (end of page 3) and the matching timm ViT model; Ross drops the hidden layer in the last block, citing "training instabilities"! I'll add an option to xformers to be able to do this from the config side (pass "multiplier" == -1); it could well be the reason.

Never mind, I just realized that you were already doing that.

jramapuram commented 2 years ago

Running with issue_219 branch 😒 :

[image]
Branch 'issue_219' set up to track remote branch 'issue_219' from 'origin'.
...
Building wheels for collected packages: xformers
  Building wheel for xformers (setup.py): started
  Building wheel for xformers (setup.py): still running...
  Building wheel for xformers (setup.py): finished with status 'done'
  Created wheel for xformers: filename=xformers-0.0.9-cp39-cp39-linux_x86_64.whl size=1337425 sha256=8b0825f014b9859ad3786042973af9d389b1031c68a1eb734cb7c4b78050ce08
  Stored in directory: /mnt/tmp/pip-ephem-wheel-cache-a4wszhv8/wheels/10/9b/72/7597306f87828c97afa86adf56f4c78db426e51c2ee02d2f66
Successfully built xformers
Installing collected packages: xformers
  Attempting uninstall: xformers
    Found existing installation: xformers 0.0.7
    Uninstalling xformers-0.0.7:
      Successfully uninstalled xformers-0.0.7
Successfully installed xformers-0.0.9
blefaudeux commented 2 years ago

Thanks a lot for the test, it's really perplexing... @fmassa has been using some xformers blocks a lot with ImageNet, but not the whole model, and I don't remember him getting this. I'll check the weight inits as soon as I get the time; sorry for the delay.

blefaudeux commented 2 years ago

Hey @jramapuram, back to you! We can exchange by mail if that helps; I'd really love to get to the bottom of this. I'm adding DeepNet to xformers, and some of the issues it mentions with pre-LN look like they could apply here (although that does not explain why timm's implementation does not face the same instabilities). My current thoughts are that it could be related to weight init (I know that you're handling that already, but maybe a bug/issue there?), or to a different layer-wise LR, for instance, with a grep that fails on the xformers model because of different parameter names. Would it be possible to share more of the training code?

jramapuram commented 2 years ago

Happy to chat more via email @blefaudeux !

blefaudeux commented 2 years ago

Hey there, I just had time to have another look, focused on weight init. It seems that _init_vit_weights() does not apply to the attention projection steps, which are wrapped in the InProjContainer. This becomes visible if you change _init_vit_weights to:

def _init_vit_weights(module: nn.Module):
    """Transformer weight initialization from TIMM."""

    if isinstance(module, nn.Linear):
        print(f"Initializing {module}")
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)
        print(f"Initializing {module}")
    elif len(list(module.children())) == 0:
        print(f" Module {module} not initialized")

Unless I'm mistaken, it looks like the InProjContainer is skipped, since it does not expose the projection step as an nn.Linear but as a raw nn.Parameter (see).

This means that these weights/biases are initialized according to the defaults, which are here, and should match fairseq (but not timm).
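A minimal reproduction of the pattern described above, using a hypothetical PackedQKV module as a stand-in (it is not xformers' actual InProjContainer): because the packed projection is stored as raw nn.Parameter tensors rather than an nn.Linear, an isinstance(module, nn.Linear) check inside model.apply() never touches it, and the weights keep whatever default init the module chose.

```python
import torch
from torch import nn


class PackedQKV(nn.Module):
    """Hypothetical packed q/k/v projection stored as raw Parameters (not an nn.Linear)."""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(3 * dim, dim))
        self.bias = nn.Parameter(torch.empty(3 * dim))
        # Default init for this sketch: Xavier-uniform weights, constant bias.
        nn.init.xavier_uniform_(self.weight)
        nn.init.constant_(self.bias, 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Returns the concatenated [q, k, v] projections.
        return nn.functional.linear(x, self.weight, self.bias)


def linear_only_init(module: nn.Module) -> None:
    """Same nn.Linear branch as _init_vit_weights: only nn.Linear instances are matched."""
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def packed_aware_init(module: nn.Module) -> None:
    """Also handles the packed projection explicitly."""
    linear_only_init(module)
    if isinstance(module, PackedQKV):
        nn.init.trunc_normal_(module.weight, std=0.02)
        nn.init.zeros_(module.bias)


model = nn.Sequential(PackedQKV(768), nn.Linear(768, 768))
model.apply(linear_only_init)
print(model[0].weight.std().item())  # still the Xavier-uniform default, about 0.0255 here
model.apply(packed_aware_init)
print(model[0].weight.std().item())  # about 0.02 after the explicit init
```

Exposing the packed projection as a real nn.Linear(dim, 3 * dim) would be the other way to make the existing isinstance check catch it.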

Quick questions, picking your brains @fmassa @dianaml0 @jramapuram:

jramapuram commented 2 years ago

Sorry for the delayed response. Is it possible to expose the InProjContainer so that it can be initialized by an nn.Linear isinstance check?

I'm also pulling in pre_post_norm_fix and testing with the above init @blefaudeux . Will keep you apprised.

blefaudeux commented 2 years ago


Totally, https://github.com/facebookresearch/xformers/tree/issue_219_param_init_fix does that (trying to get to a better projection block that is still flexible enough to support NLP). It's still buggy as per the unit tests though, and I'm not completely sure why as of now; trying to fix that soonish.

jramapuram commented 2 years ago

@blefaudeux: prelim results from pre_post_norm_fix are looking good! Will post final graph comparing to timm when done!

[image]

Good thoughts on the InProjContainer! Here is the list of modules not updated, for reference (I'm sure you found this already :))

 Module Identity() not initialized
 Module Dropout(p=0.0, inplace=False) not initialized
 Module InProjContainer() not initialized
 Module DropPath() not initialized
 Module FusedLayerNorm() not initialized  # <-- probably should be caught by LayerNorm init call as well, I realize the defaults are fine, but just in case someone wants a different init here.
jramapuram commented 2 years ago
> Would it make sense to expose an init method in the xformers model which covers some of these options? (and handles all the children modules properly)

Either a global init fn or a way for the layers to inherit from their baseline counterparts, whichever makes sense.

jramapuram commented 2 years ago
[image]

@blefaudeux : looks within tolerance! Thanks! Closing this issue.

blefaudeux commented 2 years ago

Thanks a lot @jramapuram for bearing with me on this and for all the testing, really appreciated!

jramapuram commented 2 years ago

@blefaudeux: any thoughts on what might be missing an init with DeepNorm? I'm running into NaNs again around 50 epochs in 😭. Tried the following:

  1. Same init as above.
  2. No init as I know that deepnorm has a custom beta term for weights and alpha for layernorm.

Still is working fine with prenorm though

blefaudeux commented 2 years ago

Hey @jramapuram, thanks for the report! I'm guessing that this is with AMP? Not sure right now, but I'll have a look!

edit: can you tell me a bit more about the LR, for instance, and the other hyper-params? Is it possible that this NaNs after the warmup because the LR is too big, for instance? Can you trace it to the gradients exploding, or is it something else?

blefaudeux commented 2 years ago

It's kind of interesting that the instability is in DeepNorm, since it is supposed to help with instabilities :D (paper for reference). It seems to work fine with the microViT example that we have (training a small ViT on CIFAR).
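For reference, DeepNorm replaces the pre-norm residual with a post-norm update x ← LN(α·x + f(x)) and scales the init of some weights by a gain β. The constants below are assumed to be the DeepNet paper's encoder-only values, α = (2N)^(1/4) and β = (8N)^(-1/4) with N encoder layers; please double-check them against the paper. This is a pure-Python illustration with hypothetical helper names, not the xformers implementation.

```python
import math


def deepnorm_constants(num_layers: int) -> tuple[float, float]:
    """Encoder-only DeepNorm constants (assumed from the DeepNet paper)."""
    alpha = (2 * num_layers) ** 0.25  # residual up-weighting inside the post-LN
    beta = (8 * num_layers) ** -0.25  # init gain for the FFN, value and output projections
    return alpha, beta


def deepnorm_residual(x: list[float], fx: list[float], alpha: float,
                      eps: float = 1e-5) -> list[float]:
    """Post-LN residual update LayerNorm(alpha * x + f(x)), without the affine part."""
    h = [alpha * a + b for a, b in zip(x, fx)]
    mean = sum(h) / len(h)
    var = sum((v - mean) ** 2 for v in h) / len(h)
    return [(v - mean) / math.sqrt(var + eps) for v in h]


alpha, beta = deepnorm_constants(12)  # ViT-B has 12 encoder layers
print(f"alpha={alpha:.4f}, beta={beta:.4f}")  # alpha=2.2134, beta=0.3195
```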

blefaudeux commented 2 years ago

Ah, so one issue is that the DeepNorm init looks for the value projection, but in the self-attention case there's only one weight module for all three projections, so it does not get scaled. See the proj definition and how the weight init works. Fixing that.
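A sketch of what such a fix could look like when q/k/v share one packed weight: initialize each row block separately so that only the value block receives the β gain. This assumes a (3·d, d) weight laid out as [q; k; v] along the first dimension and Xavier-normal init per block; the function name is hypothetical, and the actual xformers change may differ.

```python
import torch
from torch import nn


def deepnorm_init_packed_qkv(weight: torch.Tensor, beta: float) -> None:
    """Init a packed (3*d, d) [q; k; v] weight: gain 1 for q and k, gain beta for v."""
    dim = weight.shape[1]
    if weight.shape[0] != 3 * dim:
        raise ValueError("expected a (3 * dim, dim) packed projection weight")
    with torch.no_grad():
        for block, gain in enumerate((1.0, 1.0, beta)):
            # Each slice is (dim, dim), so Xavier sees fan_in = fan_out = dim.
            nn.init.xavier_normal_(weight[block * dim:(block + 1) * dim], gain=gain)


# ViT-B sizes: d = 768, beta for 12 encoder layers.
packed = nn.Parameter(torch.empty(3 * 768, 768))
deepnorm_init_packed_qkv(packed, beta=(8 * 12) ** -0.25)
```

Slicing per block also sidesteps a second subtlety: calling Xavier on the whole packed matrix would compute its fans from the (3·d, d) shape rather than from a single (d, d) projection.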

jramapuram commented 2 years ago

Sorry for the delay! Testing now, but blocked due to the triton version bump:

AttributeError: module 'triton.language' has no attribute 'constexpr' 

Guessing this is due to https://github.com/facebookresearch/xformers/pull/272. Is the recommended version pip install triton==2.0.0.dev20220430? I didn't see anything in the docs.

Need to see if we can bump to Triton 2 without breaking DeepSpeed.

jramapuram commented 2 years ago

Bumped to triton2 but am now blocked via https://github.com/facebookresearch/xformers/issues/290 😭

jramapuram commented 2 years ago

Will test without triton and report back

jramapuram commented 2 years ago

@blefaudeux : Disabled triton, pulled master and still NaN-ing (with DeepNorm). The only init in this scenario is for CLS and pos_embed. Same ViT-B as above.

> can you tell me a bit more about the LR for instance, hyper params ?

Yup, same params as described in the MAE appendix (shown below). Might be worthwhile kicking off a similar run on your side? [image]

blefaudeux commented 2 years ago


Thanks for the test and report! So I assume that this is with the same LR schedule (and possibly grad clipping) as pre-norm: pre-norm works and is at parity with timm, but DeepNorm NaNs (presumably following a gradient explosion). Is clipping the gradients not enough? I don't have an easy way to repro this myself as of now (doable, but it would take a lot of time), so thinking it through is another option... It's kind of peculiar, since DeepNorm should actually stabilize training, and the paper claims training dynamics similar to pre-norm, if I remember correctly.
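On the clipping question: with AMP, gradients need to be unscaled before clipping, otherwise the threshold is applied to the loss-scaled values. A minimal sketch of one training step with PyTorch's GradScaler and clip_grad_norm_; train_step, the toy model and max_norm=1.0 are placeholders, not the MAE recipe. The returned pre-clip norm is also a cheap signal for spotting a gradient explosion before the NaNs appear.

```python
import torch
from torch import nn


def train_step(model: nn.Module, batch: torch.Tensor, target: torch.Tensor,
               optimizer: torch.optim.Optimizer, scaler, max_norm: float = 1.0) -> float:
    """One AMP-aware step; returns the total grad norm measured before clipping."""
    device_type = batch.device.type
    amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=scaler.is_enabled()):
        loss = nn.functional.cross_entropy(model(batch), target)
    scaler.scale(loss).backward()
    # Unscale first so that max_norm applies to the true gradients, not the loss-scaled ones.
    scaler.unscale_(optimizer)
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    # When enabled, GradScaler skips optimizer.step() if it found inf/NaN gradients.
    scaler.step(optimizer)
    scaler.update()
    return grad_norm.item()


# Toy usage on CPU with AMP disabled; on GPU, use CUDA tensors and GradScaler(enabled=True).
model = nn.Linear(8, 3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95))
scaler = torch.cuda.amp.GradScaler(enabled=False)
print(train_step(model, torch.randn(4, 8), torch.tensor([0, 1, 2, 0]), optimizer, scaler))
```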

jramapuram commented 2 years ago

Thanks for the prompt response!

Happy to test things for you as well for repro :)

blefaudeux commented 2 years ago

Thanks @jramapuram for the clarification, I must have missed something in the paper; I'll have another look. It would be great if I could come up with a test which does not involve ImageNet, since that's a little too big for a regression check. I'll see what I can find :)

blefaudeux commented 2 years ago

I'm thinking that there may be a remaining issue with the weight init in the self-attention case, where the dimensions for fan_in/fan_out would be skewed because we merge the 3 projections. I'll have a look asap, but realistically within a couple of days.
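To make the fan_in/fan_out concern concrete: Xavier (Glorot) normal init uses std = gain · sqrt(2 / (fan_in + fan_out)), and for a weight of shape (out, in) PyTorch takes fan_in = in and fan_out = out. If the fans are computed on a packed (3d, d) q/k/v matrix instead of one (d, d) projection, fan_out becomes 3d and the std shrinks by a factor of 1/√2. A small arithmetic sketch (xavier_std is an illustrative helper):

```python
import math


def xavier_std(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    """Standard deviation used by Xavier/Glorot normal init."""
    return gain * math.sqrt(2.0 / (fan_in + fan_out))


d = 768  # ViT-B embedding dim
separate = xavier_std(fan_in=d, fan_out=d)    # one (d, d) projection
packed = xavier_std(fan_in=d, fan_out=3 * d)  # fans taken from the packed (3d, d) weight
# The ratio is 1/sqrt(2), about 0.707, independently of d.
print(f"separate={separate:.5f} packed={packed:.5f} ratio={packed / separate:.3f}")
```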

jramapuram commented 2 years ago

Still no joy on ef6de0faa8bacf91d9fb83ed733ffa0546d85db9 😬.

Here I only init the pos_embed and cls_token with trunc_normal_(std=0.02) and use DeepNorm: [image] [image]

Edit: updated curves to compare to prenorm. [image]

blefaudeux commented 2 years ago


Oh yes, for the current main branch, nothing addressing this has landed yet. Could you try https://github.com/facebookresearch/xformers/pull/303 by any chance? I can try to start something later today, but I'm a little bit underwater atm :/