facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

xformers ViT-B ImageNet MAE + Deepnorm training instability #219

Open jramapuram opened 2 years ago

jramapuram commented 2 years ago

🐛 Bug

I'm trying to create a 1:1 config that can train a stable ViT-B with the MAE config (from appendix A.2).

Maybe I'm missing something (highly plausible), but when I use xformers instead of timm it creates an unstable training scenario [over numerous trials] with exactly the same hyper-parameters (batch_size=4096 + cutmix + mixup + label smoothing + AdamW[0.9, 0.95], lr=1e-4 [with scaling rule ofc], lr warmup + cosine decay, skip bias/CLS/pos_embed weight decay, etc, etc).
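
For readers trying to reproduce the setup, here is a minimal sketch of the learning-rate part of the recipe described above (linear scaling rule, warmup, cosine decay). The base lr of 1e-4 and batch size of 4096 come from this post; the reference batch of 256, the 20 warmup epochs and the 300 total epochs are assumptions for illustration, not values stated in this thread.

```python
import math


def scaled_lr(base_lr: float, batch_size: int, reference_batch: int = 256) -> float:
    """Linear scaling rule: lr = base_lr * batch_size / reference_batch."""
    return base_lr * batch_size / reference_batch


def lr_at_epoch(
    epoch: float,
    peak_lr: float,
    warmup_epochs: float,
    total_epochs: float,
    min_lr: float = 0.0,
) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay down to min_lr."""
    if epoch < warmup_epochs:
        return peak_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# base lr and batch size from the post; warmup/total epochs are assumed values
peak = scaled_lr(1e-4, 4096)
for epoch in (0, 10, 20, 160, 300):
    print(epoch, lr_at_epoch(epoch, peak, warmup_epochs=20, total_epochs=300))
```

The schedule is written per epoch for readability; a per-step version would pass fractional epochs.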

image

xformers ViT-B Config

reversible: False
block_type: "encoder"
num_layers: 12
dim_model: 768
layer_norm_style: "pre"

multi_head_config:
  num_heads: 12
  residual_dropout: 0.1  # (1) tried without this, (2) swapping this for DropPath, (3) with regular dropout
  use_rotary_embeddings: False

  attention:
    name: "scaled_dot_product"
    dropout: 0.0
    causal: False

feedforward_config:
  name: "MLP"
  dropout: 0.0
  activation: "gelu"
  hidden_layer_multiplier: 4

xformers ViT-B

"""A simple ViT-B in xformers."""

import typing as t
from pathlib import Path
import yaml

import torch
from torch import nn
from timm.models.vision_transformer import DropPath
from timm.models.layers.patch_embed import PatchEmbed
from timm.models.layers.weight_init import trunc_normal_
from xformers.factory import xFormer, xFormerConfig

def _init_vit_weights(module: nn.Module):
    """Transformer weight initialization from TIMM."""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)

class ViT(nn.Module):
    """Vision transformer + head module."""

    def __init__(
        self,
        img_size: int = 224,
        in_chans: int = 3,
        patch_size: int = 16,
        num_classes: int = 1000,
        drop_path_rate: float = 0,
        transfomer_config_file: t.Union[Path, str] = "configs/vit_b.yaml",
    ):
        """A standard ViT module."""
        super().__init__()

        # read the model config
        with open(transfomer_config_file, "rb") as fileptr:
            self.model_config = yaml.load(fileptr, Loader=yaml.FullLoader)

        # embed_dim = self.model_config["block_config"]["dim_model"]
        embed_dim = self.model_config["dim_model"]
        print(self.model_config)

        # build the patch embedding model
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten=True,
        )
        self.num_patches = self.patch_embed.num_patches

        # Build the tokens / position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(  # +1 for CLS token
            torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
        )

        # build the backbone
        self.backbone = xFormer.from_config(xFormerConfig([self.model_config]))

        # Swap dropout with drop-path
        # Also tried (1) without this, (2) without dropout.
        if drop_path_rate > 0:
            dpr_idx = 0
            dpr = [
                x.item()
                for x in torch.linspace(0, drop_path_rate, len(self.backbone.encoders))
            ]
            for layer in self.backbone.encoders:
                if hasattr(layer.mha, "resid_drop"):
                    setattr(layer.mha, "resid_drop", DropPath(dpr[dpr_idx]))
                    dpr_idx += 1

        # build the head network
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim, eps=1e-6),
            nn.Linear(embed_dim, num_classes)
        )

        # do ViT initializations
        self.init_weights()

    def init_weights(self):
        """Initialize layers, pos_embed and CLS for ViTs."""
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(_init_vit_weights)

    def forward(self, inputs: torch.Tensor) -> t.Dict[str, torch.Tensor]:
        """Infer variates and return a dict with repr and logits.

        Example sizing:
        patches = [2, 196, 768] --> [2, 197, 768] (after CLS)
        representation = [2, 197, 768]
        logits = [2, 197, 1000]
        CLS = [2, 1000]  # select first of 197

        """
        patches = self.patch_embed(inputs)
        cls_token = self.cls_token.expand(inputs.shape[0], -1, -1)
        out = torch.cat((cls_token, patches), dim=1)
        out = out + self.pos_embed

        representation = self.backbone(out)
        logits = self.head(representation)
        return {
            "representation": representation.detach(),
            "logits": logits,
            "CLS": logits[:, 0],
        }
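
The "skip bias/CLS/pos_embed weight decay" part of the setup is not shown in the model code above. A rough sketch of how the parameter groups could be built follows; `split_weight_decay_groups` is a hypothetical helper (not part of xformers or timm), and matching by parameter name is an assumption about how the exclusion was done.

```python
from typing import Any, Dict, Iterable, List, Sequence, Tuple


def split_weight_decay_groups(
    named_params: Iterable[Tuple[str, Any]],
    weight_decay: float,
    skip_names: Sequence[str] = ("cls_token", "pos_embed"),
) -> List[Dict[str, Any]]:
    """Split parameters into a decayed group and a non-decayed group.

    Biases and any parameter whose last name component is in `skip_names`
    get weight_decay=0.0; everything else gets `weight_decay`.
    """
    decay, no_decay = [], []
    for name, param in named_params:
        # frozen parameters do not need to be handed to the optimizer
        if not getattr(param, "requires_grad", True):
            continue
        leaf = name.rsplit(".", 1)[-1]  # e.g. "head.1.bias" -> "bias"
        if leaf == "bias" or leaf in skip_names:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

The returned list has the shape expected for per-parameter options in `torch.optim.AdamW`, e.g. `AdamW(split_weight_decay_groups(model.named_parameters(), wd), lr=lr, betas=(0.9, 0.95))`. Exempting LayerNorm scales as well is a common variation, but it is not stated in this thread.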

Command

import timm

vit_b = ViT()
vit_b_timm = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=1000)

number_of_parameters(vit_b)
Out[18]: 86567656

number_of_parameters(vit_b_timm)
Out[19]: 86567656
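
As a sanity check on the number above, the count can be reproduced by hand. This sketch assumes the standard ViT-B/16 layout: biases on every linear layer, two LayerNorms per block, and one LayerNorm before the classification head. `vit_param_count` is a hypothetical helper written for this note.

```python
def vit_param_count(
    dim: int = 768,
    depth: int = 12,
    mlp_ratio: int = 4,
    patch: int = 16,
    in_chans: int = 3,
    img: int = 224,
    num_classes: int = 1000,
) -> int:
    """Count the parameters of a plain ViT from its dimensions."""

    def linear(n_in: int, n_out: int) -> int:
        return n_in * n_out + n_out  # weight + bias

    layer_norm = 2 * dim  # scale + shift
    num_patches = (img // patch) ** 2

    # the patch embedding conv is equivalent to a linear layer on flattened patches
    patch_embed = linear(in_chans * patch * patch, dim)
    tokens = dim + (num_patches + 1) * dim  # CLS token + position embeddings
    attention = 3 * linear(dim, dim) + linear(dim, dim)  # Q, K, V + output projection
    mlp = linear(dim, mlp_ratio * dim) + linear(mlp_ratio * dim, dim)
    block = 2 * layer_norm + attention + mlp
    head = layer_norm + linear(dim, num_classes)

    return patch_embed + tokens + depth * block + head


print(vit_param_count())  # 86567656
```

Matching totals only show that the two models have the same number of weights; they say nothing about initialization or where the normalization sits, which is what the rest of this thread is about.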

To Reproduce

Steps to reproduce the behavior:

  1. Train a ViT-B with xformers --> unstable
  2. Using same training setup with timm --> stable
  3. 😢

jramapuram commented 2 years ago

> oh yes for current main branch, nothing landed addressing this yet. Could you try https://github.com/facebookresearch/xformers/pull/303 by any chance ? I can try to start something later today, but a little bit underwater atm :/

No worries! Will give that a shot :) [feel better!]

I added the reference pre-norm graphs above. Differences are basically:

  1. CLS + pos_embed init only : i.e. use default xformer init
  2. CLS + pos_embed + weight init
  3. CLS + pos_embed + weight_init + LN init

blefaudeux commented 2 years ago

oh wow, it's pretty clear indeed, thanks @jramapuram. #303 is definitely fixing a small bug, but I doubt that it really explains this, so I'll dive back into deepnorm. I may actually have a repro: with the recent metaformer+cifar10, deepnorm does not work either, but I thought that was because of the decidedly different model structure. I'll give it a second look, sorry for the delay

blefaudeux commented 2 years ago

hmm, I did spend some time on that and found nothing obviously wrong, it's really perplexing. I'll give IN a shot. If you have the option, would it be possible to test this without AMP, just in case it's a matter of numerical accuracy (which would not be caught by the grad scaler if not NaN) ?

blefaudeux commented 2 years ago

Just in case @jramapuram, could you check that you're using triton == 2.0.0.dev20220403 ? I'm bumping into numerical stability issues with newer versions of triton (being worked on) and the symptoms are very similar (accuracy crash but not NaNs)
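
A quick way to check which Triton build is installed, sketched with the standard library only. `installed_version` and `dev_build_date` are hypothetical helpers written for this note; the only assumption is that nightly versions follow the `X.Y.Z.devYYYYMMDD` pattern seen in this thread.

```python
import re
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def dev_build_date(version: str) -> Optional[int]:
    """Extract YYYYMMDD from a version such as '2.0.0.dev20220403', else None."""
    match = re.search(r"\.dev(\d{8})$", version)
    return int(match.group(1)) if match else None


version = installed_version("triton")
if version is None:
    print("triton is not installed")
else:
    print(f"triton {version}, dev build date: {dev_build_date(version)}")
```

Comparing the extracted dates makes it easy to tell whether a given install is older or newer than the build mentioned above.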

blefaudeux commented 2 years ago

(no, I've not forgotten that issue.. ). I would love to be able to repro on something a little smaller than a full-blown IN + training over a couple of nodes, so I'm documenting that here. Attached is a minGPT training setup with pre/post/deepnorm (8-layer transformer, 25M params). Deepnorm doesn't converge to a solution as good as the others, but there is no catastrophic failure for any of them.

Screenshot from 2022-05-23 13-23-39

jramapuram commented 2 years ago

> Just in case @jramapuram, could you check that you're using triton == 2.0.0.dev20220403 ? I'm bumping into numerical stability issues with newer versions of triton (being worked on) and the symptoms are very similar (accuracy crash but not NaNs)

Thanks for keeping this in mind @blefaudeux. Just checked, using triton==2.0.0.dev20220430 -- I can drop down and test!

Re the minGPT: I'm surprised there is a perf drop -- does the test loss / negative log-likelihood follow the same trend?

blefaudeux commented 2 years ago

20220430 was fine; the ones after that were broken, but that was fixed by https://github.com/openai/triton/commit/205a493b10a5112ec1fccdbe9d59fe9f172e027d, so it's back to being good at the moment! Re minGPT: I can check the other metrics. As mentioned in another thread, I think it may be due to the distribution being hardcoded right now for deepnorm. That is not very readable or hackable, and not a great design overall; I'd like to come up with something better and more explicit (for instance a couple of possible inits as part of the xformers config, with deepnorm respecting that). It's always possible to init from the outside, but that is tied to parameter naming conventions (not super clear right now), and it kind of negates the point of supporting deepnorm to begin with, I think
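
For context on what the hardcoded values refer to: the DeepNet paper gives, for an encoder-only model with N layers, a residual scale alpha = (2N)^(1/4) and an init gain beta = (8N)^(-1/4). The sketch below just evaluates those two formulas; whether the xformers deepnorm code path uses exactly these constants is not stated in this thread, and `deepnorm_encoder_constants` is a hypothetical helper.

```python
from typing import Tuple


def deepnorm_encoder_constants(num_layers: int) -> Tuple[float, float]:
    """Encoder-only DeepNet constants: (alpha, beta) for `num_layers` layers.

    alpha scales the residual branch input: x_out = LayerNorm(alpha * x + f(x)).
    beta scales the init gain of the feedforward, value and output projections.
    """
    alpha = (2 * num_layers) ** 0.25
    beta = (8 * num_layers) ** -0.25
    return alpha, beta


alpha, beta = deepnorm_encoder_constants(12)  # the ViT-B depth used in this issue
print(f"alpha={alpha:.4f}, beta={beta:.4f}")
```

In the paper, beta applies to the value and output projections but not to the query and key projections, which is one reason a config-level choice of init would be easier to follow than name-based initialization from the outside.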

jramapuram commented 2 years ago

Unfortunately no joy @blefaudeux. I tried:

  1. Triton downgrade + pos + cls init
  2. Triton downgrade + pos + cls + weight init
  3. triton==2.0.0.dev20220430 + pos + cls init
  4. triton==2.0.0.dev20220430 + pos + cls + weight init

image

blefaudeux commented 2 years ago

thanks a bunch @jramapuram ! I've a draft PR getting ready which rewrites a lot of the input projections (something we discussed earlier) + explicit handling of a couple of init methods (optional, users are still free to do as they please), I'm hoping that it solves this. To give an insight, I think that this setting is not well handled and could be the culprit (deepnorm assumes a different projection per Q/K/V, and the default here should probably be "true" I believe)

blefaudeux commented 2 years ago

I think that #312 is getting there @jramapuram, it's a lot cleaner to my eyes. Something I've seen, related to your curves above, is that it's not just deepnorm: the post-normalization path does not play well with ViT. GPT is fine with this normalization path; I don't know if it's a known fact, I would need to check the literature. Since deepnorm is a subset of the post-normalization code path, it makes a little more sense, or at least it's not alone

blefaudeux commented 2 years ago

ok, beyond #312 which cleans things up, it looks like (given Timm, here) layernorm requires a specific treatment for ViT+Post, the weight is initialized to a very small value (vs. 1 typically). Since in our case Post & Deepnorm (same residual codepath) both fail with ViT but work well with GPT, it could explain why. I'll give that a shot

blefaudeux commented 2 years ago

I've not forgotten that @jramapuram, turns out that for vision / post norm Swin v2 already solved this (related to the message above), see their paper. The initial weights need to be scaled way down, I'll try to implement this in xformers when I get the time
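
A minimal sketch of the idea in the last two comments, i.e. starting the LayerNorm scale at a small value instead of 1 for the post-norm path. This is not the xformers implementation; `shrink_layernorm_scales` is a hypothetical helper and the default `init_value` of 1e-4 is a placeholder, not a value taken from this thread.

```python
from torch import nn


def shrink_layernorm_scales(model: nn.Module, init_value: float = 1e-4) -> int:
    """Set every affine LayerNorm weight in `model` to `init_value` (bias to 0).

    Returns the number of LayerNorm modules that were re-initialized.
    """
    count = 0
    for module in model.modules():
        if isinstance(module, nn.LayerNorm) and module.elementwise_affine:
            nn.init.constant_(module.weight, init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            count += 1
    return count


# toy stand-in for a post-norm residual branch: sublayer followed by LayerNorm
toy_branch = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
print(shrink_layernorm_scales(toy_branch))
```

In a real model this would presumably be restricted to the LayerNorms that sit at the end of the residual branches (and run after the generic weight init), since the final norm before the classification head likely should keep its usual scale of 1.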