huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0

[BUG] ViT finetuning eval accuracy is too high running on TPU (bits_and_tpu branch) #960

Closed: eliahuhorwitz closed this issue 3 years ago.

eliahuhorwitz commented 3 years ago

Hey, I've been fine-tuning ViT on different datasets (cifar100, oxford_pets, etc.) on Google TRC TPUs, specifically a v3 TPU VM, using the bits_and_tpu branch. The fine-tuning results look odd. On CIFAR-100 the eval top-1 accuracy reaches 94.19 within 17 epochs (one run even reached 94.44). These numbers are closer to the JFT-300M results than to the ImageNet-21k results. The original ViT paper (screenshot below) reports 93.04 for a setup similar to mine, and the google-research GitHub repo (also attached below) reports 93.29. Even more surprising, I get the 94.x results when I turn off the image augmentations.

[Screenshots: ViT paper results table; google-research GitHub repo results]
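When an eval number looks too good, one cheap cross-check is to recompute it outside the training loop and compare it with the `top1`/`top5` values the script logs. A minimal sketch, assuming the eval logits and labels have been dumped to plain Python lists; `topk_accuracy` is a hypothetical helper written for illustration, not a timm API:

```python
from typing import Dict, Sequence


def topk_accuracy(logits: Sequence[Sequence[float]],
                  targets: Sequence[int],
                  ks: Sequence[int] = (1, 5)) -> Dict[int, float]:
    """Top-k accuracy in percent, computed on CPU in plain Python."""
    if len(logits) != len(targets):
        raise ValueError("need exactly one row of logits per target")
    if not targets:
        raise ValueError("need at least one sample")
    hits = {k: 0 for k in ks}
    for row, target in zip(logits, targets):
        # Rank class indices by descending score; the stable sort keeps
        # the lower class index first on ties.
        ranked = sorted(range(len(row)), key=lambda c: -row[c])
        for k in ks:
            if target in ranked[:k]:
                hits[k] += 1
    n = len(targets)
    return {k: 100.0 * hits[k] / n for k in ks}


# Tiny hand-made example with 3 classes and 4 samples.
logits = [
    [0.1, 0.7, 0.2],  # predicts class 1
    [0.5, 0.3, 0.2],  # predicts class 0
    [0.2, 0.3, 0.5],  # predicts class 2
    [0.6, 0.1, 0.3],  # predicts class 0, class 2 second
]
targets = [1, 2, 2, 2]
print(topk_accuracy(logits, targets, ks=(1, 2)))
```

If the recomputed top-1 on the full test split matches the logged value, the metric code is probably not the culprit, and attention can shift to the data and training setup.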

To make sure I hadn't introduced a bug into the codebase, I cloned a fresh copy of the repo and ran my tests against it. I start fine-tuning with:

python3 launch_xla.py --num-devices 1 finetune.py ~/tensorflow_datasets --dataset tfds/cifar100:3.0.2 --opt sgd --epochs 1000 --workers 1 --val-split test --mixup 0 --cutmix 0 --opt-eps=1e-08 --train-interpolation=bicubic --warmup-lr=1e-06 --lr 0.004 -b 128 --num-classes 100 --model vit_large_patch32_384
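The command passes `--lr 0.004`, `--warmup-lr=1e-06`, and `--epochs 1000`, while `--sched`, `--warmup-epochs`, and `--min-lr` stay at the script's defaults (`cosine`, 5, and 1e-5). The sketch below is a simplified stand-alone approximation of a linear-warmup cosine schedule, not timm's `CosineLRScheduler` itself. It assumes the cosine phase is indexed by the raw epoch (no warmup offset) and ignores cycles, LR noise, and cooldown. Under those assumptions it shows that, because the decay is stretched over 1000 epochs, the LR is still close to its peak at epoch 17:

```python
import math


def cosine_lr_with_warmup(epoch: int, base_lr: float, warmup_lr: float,
                          warmup_epochs: int, min_lr: float,
                          total_epochs: int) -> float:
    """Per-epoch LR: linear warmup, then half-cosine decay to min_lr.

    Simplified sketch: the cosine phase uses the raw epoch index, and
    cycles, noise, and cooldown epochs are ignored.
    """
    if epoch < warmup_epochs:
        # Linear ramp from warmup_lr toward base_lr.
        return warmup_lr + epoch * (base_lr - warmup_lr) / warmup_epochs
    progress = epoch / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Values from the command line plus the script's defaults.
cfg = dict(base_lr=0.004, warmup_lr=1e-6, warmup_epochs=5,
           min_lr=1e-5, total_epochs=1000)
for epoch in (0, 4, 5, 17, 100, 500, 1000):
    print(f"epoch {epoch:4d}: lr = {cosine_lr_with_warmup(epoch, **cfg):.6f}")
```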

My finetune.py is just a copy of the train script with one change to how the model is created. I comment out this:

# model = create_model(
#     args.model,
#     pretrained=args.pretrained,
#     num_classes=args.num_classes,
#     drop_rate=args.drop,
#     drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
#     drop_path_rate=args.drop_path,
#     drop_block_rate=args.drop_block,
#     global_pool=args.gp,
#     bn_tf=args.bn_tf,
#     bn_momentum=args.bn_momentum,
#     bn_eps=args.bn_eps,
#     scriptable=args.torchscript,
#     checkpoint_path=args.initial_checkpoint)

and put this in its place:

model = timm.create_model(args.model, pretrained=True, num_classes=args.num_classes)
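Comparing the two calls shows exactly what the swap changes. A small sketch, assuming the parsed args are the script's parser defaults plus the flags on the command line above (no YAML config file), builds the keyword arguments each call would pass and diffs them. It never imports or calls timm, so it only compares argument values, not model behaviour:

```python
from argparse import Namespace

# Assumed parsed args for the issue's command line: --model and --num-classes
# come from the command, everything else is the script's parser default
# (in particular, --pretrained is not on the command line, so it stays False).
args = Namespace(
    model='vit_large_patch32_384', pretrained=False, num_classes=100,
    drop=0.0, drop_connect=None, drop_path=None, drop_block=None, gp=None,
    bn_tf=False, bn_momentum=None, bn_eps=None, torchscript=False,
    initial_checkpoint='',
)

# Keyword arguments of the commented-out create_model(...) call.
original_kwargs = dict(
    pretrained=args.pretrained,
    num_classes=args.num_classes,
    drop_rate=args.drop,
    drop_connect_rate=args.drop_connect,
    drop_path_rate=args.drop_path,
    drop_block_rate=args.drop_block,
    global_pool=args.gp,
    bn_tf=args.bn_tf,
    bn_momentum=args.bn_momentum,
    bn_eps=args.bn_eps,
    scriptable=args.torchscript,
    checkpoint_path=args.initial_checkpoint,
)
# Keyword arguments of the replacement timm.create_model(...) call.
replacement_kwargs = dict(pretrained=True, num_classes=args.num_classes)

# Arguments passed by both calls whose values differ.
changed = {k: (original_kwargs[k], replacement_kwargs[k])
           for k in original_kwargs.keys() & replacement_kwargs.keys()
           if original_kwargs[k] != replacement_kwargs[k]}
# Arguments the replacement call no longer passes at all.
dropped = sorted(original_kwargs.keys() - replacement_kwargs.keys())

print('changed:', changed)
print('no longer passed:', dropped)
```

Under that assumption the only shared argument whose value changes is `pretrained` (False in the original call, True in the replacement). The other ten keyword arguments are simply no longer passed.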

The full script is below:

#!/usr/bin/env python3
""" ImageNet Training Script

This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.

This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)

NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import time
import yaml
import os
import logging
from collections import OrderedDict
from datetime import datetime
from dataclasses import replace
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.utils

from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Monitor, Tracker,\
    TrainState, TrainServices, TrainCfg, CheckpointManager, AccuracyTopK, AvgTensor, distribute_bn
from timm.data import create_dataset, create_transform_v2, create_loader_v2, resolve_data_config,\
    PreprocessCfg, AugCfg, MixupCfg, AugMixDataset
from timm.models import create_model, safe_model_name, convert_splitbn_model
from timm.loss import *
from timm.optim import optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import setup_default_logging, random_seed, get_outdir, unwrap_model
import timm
_logger = logging.getLogger('train')

# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# Dataset / Model parameters
parser.add_argument('data_dir', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                    help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--train-split', metavar='NAME', default='train',
                    help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='NAME', default='validation',
                    help='dataset validation split (default: validation)')
parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet50")')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=None, metavar='N',
                    help='number of label classes (Model default if None)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
                    help='Image patch size (default: None => model default)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=256, metavar='N',
                    help='input batch size for training (default: 256)')
parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                    help='validation batch size override (default: None)')

# Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
                    help='weight decay (default: 0.0001)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
                    help='Gradient clipping mode. One of ("norm", "value", "agc")')

# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate (default: 0.1)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
                    help='amount to decay each learning rate cycle (default: 0.5)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit, cycles enabled if > 1')
parser.add_argument('--lr-k-decay', type=float, default=1.0,
                    help='learning rate k-decay for cosine/poly (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=300, metavar='N',
                    help='number of epochs to train (default: 300)')
parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
                    help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=100, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')

# Augmentation & regularization parameters
parser.add_argument('--num-aug-repeats', type=int, default=3, metavar='N',
                    help='number of repeated augmentations (default: 3)')

parser.add_argument('--no-aug', action='store_true', default=False,
                    help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                    help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
                    help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
                    help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
                    help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: None)')
parser.add_argument('--aug-splits', type=int, default=0,
                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd-loss', action='store_true', default=False,
                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--bce-loss', action='store_true', default=False,
                    help='Enable BCE loss w/ Mixup/CutMix use.')
parser.add_argument('--bce-target-thresh', type=float, default=None,
                    help='Threshold for binarizing softened BCE targets (default: None, disabled)')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                    help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.0,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=0.0,
                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
                    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
                    help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')

# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='reduce',
                    help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
                    help='Enable separate BN layers per augmentation split.')

# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                    help='decay factor for model weights moving average (default: 0.9998)')

# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
                    help='number of checkpoints to keep (default: 10)')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many training processes to use (default: 4)')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
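parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--experiment', default='', type=str, metavar='NAME',
                    help='name of train experiment, name of sub-folder for output')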
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='convert model torchscript for inference')
parser.add_argument('--force-cpu', action='store_true', default=False,
                    help='Force CPU to be used even if HW accelerator exists.')
parser.add_argument('--log-wandb', action='store_true', default=False,
                    help='log training and validation metrics to wandb')

def _parse_args():
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
            parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text

def main():
    setup_default_logging()
    args, args_text = _parse_args()

    dev_env = initialize_device(force_cpu=args.force_cpu, amp=args.amp, channels_last=args.channels_last)
    if dev_env.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 device per process. Process %d, total %d.'
                     % (dev_env.global_rank, dev_env.world_size))
    else:
        _logger.info('Training with a single process on 1 device.')

    random_seed(args.seed, 0)  # Set all random seeds the same for model/state init (mandatory for XLA)

    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    assert args.aug_splits == 0 or args.aug_splits > 1, 'A split of 1 makes no sense'

    train_state = setup_train_task(args, dev_env, mixup_active)
    train_cfg = train_state.train_cfg

    # Set random seeds across ranks differently for train
    # FIXME perhaps keep the same and just set diff seeds for dataloader worker process? what about TFDS?
    random_seed(args.seed, dev_env.global_rank)

    data_config, loader_eval, loader_train = setup_data(
        args,
        unwrap_model(train_state.model).default_cfg,
        dev_env,
        mixup_active)

    # setup checkpoint manager
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    checkpoint_manager = None
    output_dir = None
    if dev_env.primary:
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
        checkpoint_manager = CheckpointManager(
            hparams=vars(args),
            checkpoint_dir=output_di
    try:
        for epoch in range(train_state.epoch, train_cfg.num_epochs):
            if dev_env.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)
            if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
                if loader_train.mixup_enabled:
                    loader_train.mixup_enabled = False

            train_metrics = train_one_epoch(
                state=train_state,
                services=services,
                loader=loader_train,
                dev_env=dev_env,
            )

            if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if dev_env.primary:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(train_state.model, args.dist_bn == 'reduce', dev_env)

            eval_metrics = evaluate(
                train_state.model,
                train_state.eval_loss,
                loader_eval,
                services.monitor,
                dev_env)

            if train_state.model_ema is not None:
                if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(train_state.model_ema, args.dist_bn == 'reduce', dev_env)

                ema_eval_metrics = evaluate(
                    train_state.model_ema.module,
                    train_state.eval_loss,
                    loader_eval,
                    services.monitor,
                    dev_env,
                    phase_suffix='EMA')
                eval_metrics = ema_eval_metrics

            if train_state.lr_scheduler is not None:
                # step LR for next epoch
                train_state.lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if services.monitor is not None:
                services.monitor.write_summary(
                    index=epoch,
                    results=dict(train=train_metrics, eval=eval_metrics))

            if checkpoint_manager is not None:
                # save proper checkpoint with eval metric
                best_checkpoint = checkpoint_manager.save_checkpoint(train_state, eval_metrics)
                best_metric, best_epoch = best_checkpoint.sort_key, best_checkpoint.epoch

            train_state = replace(train_state, epoch=epoch + 1)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))

def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):

    # model = create_model(
    #     args.model,
    #     pretrained=args.pretrained,
    #     num_classes=args.num_classes,
    #     drop_rate=args.drop,
    #     drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
    #     drop_path_rate=args.drop_path,
    #     drop_block_rate=args.drop_block,
    #     global_pool=args.gp,
    #     bn_tf=args.bn_tf,
    #     bn_momentum=args.bn_momentum,
    #     bn_eps=args.bn_eps,
    #     scriptable=args.torchscript,
    #     checkpoint_path=args.initial_checkpoint)
    model = timm.create_model(args.model, pretrained=True, num_classes=args.num_classes)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = mod
    # FIXME move into updater?
    lr_scheduler, num_epochs = create_scheduler(args, train_state.updater.optimizer)
    if lr_scheduler is not None and train_state.epoch > 0:
        lr_scheduler.step(train_state.epoch)

    # setup loss function
    if args.jsd_loss:
        assert args.aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=args.aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    eval_loss_fn = nn.CrossEntropyLoss()
    dev_env.to_device(train_loss_fn, eval_loss_fn)

    if dev_env.primary:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    train_cfg = TrainCfg(
        num_epochs=num_epochs,
        log_interval=args.log_interval,
        recovery_interval=args.recovery_interval,
    )

    train_state = replace(
        train_state,
        lr_scheduler=lr_scheduler,
        train_loss=train_loss_fn,
        eval_loss=eval_loss_fn,
        train_cfg=train_cfg,
    )

    return train_state

def setup_data(args, default_cfg, dev_env: DeviceEnv, mixup_active: bool):
    data_config = resolve_data_config(vars(args), default_cfg=default_cfg, verbose=dev_env.primary)

    # create the train and eval datasets
    dataset_train = create_dataset(
        args.dataset,
        root=args.data_dir, split=args.train_split, is_training=True,
        batch_size=args.batch_size, repeats=args.epoch_repeats)

    dataset_eval = create_dataset(
        args.dataset,
        root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

    # setup mixup / cutmix
    mixup_cfg = None
    if mixup_active:
        mixup_cfg = MixupCfg(
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            label_smoothing=args.smoothing, num_classes=args.num_classes)

    # wrap dataset in AugMix helper
    if args.aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=args.aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']

    if args.no_aug:
        train_aug_cfg = None
    else:
        train_aug_cfg = AugCfg(
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            ratio_range=args.rat
    dataset_eval.transform = create_transform_v2(
        cfg=eval_pp_cfg, is_training=False, normalize=normalize_in_transform)

    eval_workers = args.workers
    if 'tfds' in args.dataset:
        # FIXME reduce validation issues when using TFDS w/ workers and distributed training
        eval_workers = min(2, args.workers)
    loader_eval = create_loader_v2(
        dataset_eval,
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        normalize=not normalize_in_transform,
        pp_cfg=eval_pp_cfg,
        num_workers=eval_workers,
        pin_memory=args.pin_mem,
    )
    return data_config, loader_eval, loader_train

def train_one_epoch(
        state: TrainState,
        services: TrainServices,
        loader,
        dev_env: DeviceEnv,
):
    tracker = Tracker()
    loss_meter = AvgTensor()  # FIXME move loss meter into task specific TaskMetric

    state.model.train()
    state.updater.reset()  # zero-grad

    step_end_idx = len(loader) - 1
    tracker.mark_iter()
    for step_idx, (sample, target) in enumerate(loader):
        tracker.mark_iter_data_end()

        # FIXME move forward + loss into model 'task' wrapper
        with dev_env.autocast():
            output = state.model(sample)
            loss = state.train_loss(output, target)

        state.updater.apply(loss)

        tracker.mark_iter_step_end()

        state.updater.after_step(
            after_train_step,
            state,
            services,
            dev_env,
            step_idx,
            step_end_idx,
            tracker,
            loss_meter,
            (output, target, loss),
        )

        tracker.mark_iter()
        # end for

    if hasattr(state.updater.optimizer, 'sync_lookahead'):
        state.updater.optimizer.sync_lookahead()

    return OrderedDict([('loss', loss_meter.compute().item())])

def after_train_step(
        state: TrainState,
        services: TrainServices,
        dev_env: DeviceEnv,
        step_idx: int,
        step_end_idx: int,
        tracker: Tracker,
        loss_meter: AvgTensor,
        tensors: Tuple[torch.Tensor, ...],
):
    """
    After the core loss / backward / gradient apply step, we perform all non-gradient related
    activities here including updating meters, metrics, performing logging, and writing checkpoints.

    Many / most of these operations require tensors to be moved to CPU, they should not be done
    every step and for XLA use they should be done via the optimizer step_closure. This function includes

            loss_avg = loss_meter.compute()
            if services.monitor is not None:
                lr_avg = state.updater.get_average_lr()
                services.monitor.log_step(
                    'Train',
                    step=step_idx,
                    step_end=step_end_idx,
                    epoch=state.epoch,
                    loss=loss_avg.item(),
                    rate=tracker.get_avg_iter_rate(global_batch_size),
                    lr=lr_avg,
                )

        if services.checkpoint is not None and cfg.recovery_interval and (
                end_step or (step_idx + 1) % cfg.recovery_interval == 0):
            services.checkpoint.save_recovery(state.epoch, batch_idx=step_idx)

        if state.lr_scheduler is not None:
            # FIXME perform scheduler update here or via updater after_step call?
            state.lr_scheduler.step_update(num_updates=state.step_count_global)

def evaluate(
        model: nn.Module,
        loss_fn: nn.Module,
        loader,
        logger: Monitor,
        dev_env: DeviceEnv,
        phase_suffix: str = '',
        log_interval: int = 10,
):

    tracker = Tracker()
    losses_m = AvgTensor()
    accuracy_m = AccuracyTopK()  # FIXME move loss and accuracy modules into task specific TaskMetric obj

    model.eval()

    end_idx = len(loader) - 1
    tracker.mark_iter()
    with torch.no_grad():
        for step_idx, (sample, target) in enumerate(loader):
            tracker.mark_iter_data_end()
            last_step = step_idx == end_idx

            with dev_env.autocast():
                output = model(sample)
                if isinstance(output, (tuple, list)):
                    output = output[0]
                loss = loss_fn(output, target)

            # FIXME, explicitly marking step for XLA use since I'm not using the parallel xm loader
            # need to investigate whether parallel loader wrapper is helpful on tpu-vm or only use for 2-vm setup.
            if dev_env.type_xla:
                dev_env.mark_step()
            elif dev_env.type_cuda:
                dev_env.synchronize()

            # FIXME uncommenting this fixes race btw model `output`/`loss` and loss_m/accuracy_m meter input
            # for PyTorch XLA GPU use.
            # This issue does not exist for normal PyTorch w/ GPU (CUDA) or PyTorch XLA w/ TPU.
            # loss.item()

            tracker.mark_iter_step_end()
            losses_m.update(loss, output.size(0))
            accuracy_m.update(output, target)

            if last_step or step_idx % log_interval == 0:
                top1, top5 = accuracy_m.compute().values()
                loss_avg = losses_m.compute()
                logger.log_step(
                    'Eval',
                    step=step_idx,
                    step_end=end_idx,
                    loss=loss_avg.item(),
                    top1=top1.item(),
                    top5=top5.item(),
                    phase_suffix=phase_suffix,
                )
            tracker.mark_iter()

    top1, top5 = accuracy_m.compute().values()
    results = OrderedDict([('loss', losses_m.compute().item()), ('top1', top1.item()), ('top5', top5.item())])
    return results

def _mp_entry(*args):
    main()

if __name__ == '__main__':
    main()
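`AccuracyTopK` (used in `evaluate` above) comes from the bits_and_tpu branch, and its internals aren't shown in this excerpt. As a reference point when sanity-checking the `top1`/`top5` numbers it reports, here is a minimal pure-Python sketch of running top-k accuracy accumulated across batches. The class and function names are hypothetical; this is not the branch's implementation.

```python
from typing import Dict, List, Sequence, Tuple


def topk_hit(logits: Sequence[float], target: int, k: int) -> bool:
    """True if `target` is among the k highest-scoring classes."""
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return target in ranked[:k]


class RunningTopK:
    """Hypothetical accumulator: counts hits over all samples seen, not per-batch averages."""

    def __init__(self, ks: Tuple[int, ...] = (1, 5)):
        self.ks = ks
        self.hits: Dict[int, int] = {k: 0 for k in ks}
        self.count = 0

    def update(self, batch_logits: List[Sequence[float]], batch_targets: List[int]) -> None:
        for logits, target in zip(batch_logits, batch_targets):
            for k in self.ks:
                self.hits[k] += int(topk_hit(logits, target, k))
        self.count += len(batch_targets)

    def compute(self) -> Dict[str, float]:
        # Percentages, on the same 0-100 scale as eval_top1 / eval_top5 in the CSV below.
        return {f"top{k}": 100.0 * self.hits[k] / self.count for k in self.ks}


meter = RunningTopK()
meter.update([[0.1, 0.9, 0.0, 0.0, 0.0, 0.0]], [1])  # target ranked 1st: top-1 and top-5 hit
meter.update([[0.5, 0.4, 0.3, 0.2, 0.1, 0.0]], [4])  # target ranked 5th: top-5 hit only
print(meter.compute())  # {'top1': 50.0, 'top5': 100.0}
```

Accumulating hit counts rather than averaging per-batch percentages keeps a smaller final batch from being over-weighted.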

Here is a summary of the output above (I stopped the run once I saw the accuracy was too high):

epoch,train_loss,eval_loss,eval_top1,eval_top5
0,4.732754230499268,4.729395389556885,0.8700000047683716,4.989999771118164
1,3.210913896560669,1.154198408126831,85.3699951171875,97.27999877929688
2,1.6976295709609985,0.4715765118598938,90.56999969482422,98.86000061035156
3,1.5128341913223267,0.3998292088508606,91.66999816894531,99.20999908447266
4,1.4772536754608154,0.370217889547348,92.33999633789062,99.32999420166016
5,1.4140307903289795,0.3580523431301117,92.80999755859375,99.36000061035156
6,1.390270709991455,0.34456485509872437,93.0199966430664,99.37999725341797
7,1.3623195886611938,0.3357977569103241,93.36000061035156,99.39999389648438
8,1.3307034969329834,0.33426693081855774,93.14999389648438,99.43999481201172
9,1.307023048400879,0.3217673897743225,93.47000122070312,99.45999908447266
10,1.3035824298858643,0.32201898097991943,93.66999816894531,99.48999786376953
11,1.2851903438568115,0.329518586397171,93.41999816894531,99.38999938964844
12,1.2727124691009521,0.32014748454093933,93.66999816894531,99.43999481201172
13,1.2688237428665161,0.31492725014686584,93.88999938964844,99.45999908447266
14,1.2594046592712402,0.3136151432991028,93.95999908447266,99.44999694824219
15,1.2442022562026978,0.3131980299949646,93.65999603271484,99.45999908447266
16,1.2306550741195679,0.3129279613494873,93.72999572753906,99.41999816894531
17,1.2250698804855347,0.31124258041381836,94.19999694824219,99.47000122070312
18,1.2192376852035522,0.3087320327758789,94.15999603271484,99.50999450683594
19,1.2128868103027344,0.3063335418701172,94.15999603271484,99.5
20,1.1995835304260254,0.307146817445755,94.06999969482422,99.43000030517578
21,1.2054955959320068,0.30594122409820557,94.08999633789062,99.5

And here is a graph of a similar run with slightly different hyperparameters, which I let run for longer (it reached 94.44!!!): [Screenshot: graph of the longer run]

I've made sure to start a clean machine for this, with a fresh download of cifar100 from TFDS, and of course, a fresh clone of the codebase.

The above results also make me completely doubt the results I have been getting for my own models that use this codebase/pretrained models. I am working now on trying to reproduce this on a GPU, but I don't have access to the same amount of compute so this is going to be more challenging.

Am I somehow missing something or doing something wrong in the fine-tuning script? Could these be real results? Or do you think there is some bug in the XLA/TPU side of things?

Do you have any recommendations as to where I should start looking for a solution?

Thanks, Eliahu

rwightman commented 3 years ago

@eliahuhorwitz it looks like you are doing single-node, single-worker training, so this likely isn't a concern, but be aware that if you do distributed training, you should always confirm your validation results on a single node with the validation script afterwards. The distributed validation results will be a bit different due to padding of the batch, etc.
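To make the padding effect concrete: `torch.utils.data.DistributedSampler` with `drop_last=False` pads the index list by wrapping around until it divides evenly across replicas, so a few samples get evaluated twice. The sketch below re-implements that index layout in plain Python (shuffling omitted) on a toy 10-sample set. The helper name and the per-sample correctness pattern are made up for illustration, and the loaders in the bits_and_tpu branch may pad differently.

```python
import math
from typing import List


def padded_rank_indices(n: int, num_replicas: int, rank: int) -> List[int]:
    """Index layout of DistributedSampler(shuffle=False, drop_last=False), re-implemented."""
    total = math.ceil(n / num_replicas) * num_replicas
    # Wrap around the index list until its length is divisible by num_replicas.
    padded = (list(range(n)) * math.ceil(total / n))[:total]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return padded[rank:total:num_replicas]


n, world = 10, 4
correct = [False, False] + [True] * 8  # hypothetical per-sample results: 80% accuracy
seen = [i for r in range(world) for i in padded_rank_indices(n, world, r)]

single_device_acc = 100.0 * sum(correct) / n
distributed_acc = 100.0 * sum(correct[i] for i in seen) / len(seen)
print(len(seen), single_device_acc, round(distributed_acc, 2))  # 12 80.0 66.67
```

The two wrong samples happen to be the ones repeated as padding, so the distributed number is pulled down. With real dataset sizes only a handful of samples repeat and the effect is small, but re-running single-device validation on the checkpoint removes it entirely.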

You can always double-check the sanity of the bits and tpu results by using the same checkpoints and validating on a GPU with the master branch, which is well tested.

For the official CIFAR training the vit authors used 98% of the train split https://github.com/google-research/vision_transformer/blob/main/vit_jax/configs/common.py#L93 ... unfortunately you can split 98 and then split again for multi-node train due to limitations in splitting up the samples for even distribution across distributed nodes right now.
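A rough sketch of the arithmetic behind a 98% split followed by sharding across workers, using the CIFAR-100 train split (50,000 images). It uses simple floor rounding for the percent cut and contiguous per-worker blocks; both are assumptions for illustration, not how TFDS or the bits_and_tpu loaders actually slice.

```python
from typing import List, Tuple


def percent_split(n: int, pct: int) -> Tuple[range, range]:
    """Split indices 0..n-1 into the first pct% (train) and the remainder (holdout)."""
    cut = n * pct // 100  # floor rounding (assumed)
    return range(0, cut), range(cut, n)


def contiguous_shards(indices: range, num_workers: int) -> List[range]:
    """Give each worker a contiguous block; the first `rem` workers get one extra sample."""
    base, rem = divmod(len(indices), num_workers)
    shards, start = [], indices.start
    for w in range(num_workers):
        size = base + (1 if w < rem else 0)
        shards.append(range(start, start + size))
        start += size
    return shards


train, holdout = percent_split(50_000, 98)
print(len(train), len(holdout))                       # 49000 1000
print([len(s) for s in contiguous_shards(train, 8)])  # [6125, 6125, 6125, 6125, 6125, 6125, 6125, 6125]
print([len(s) for s in contiguous_shards(train, 3)])  # [16334, 16333, 16333]
```

When shard sizes come out uneven, workers can end up with differently sized or differently counted final batches; evening them out by padding or dropping samples brings back the padding effect mentioned earlier in the thread.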

The other factor is that the default weights are the ImageNet-21k 300epoch variants from the 'How to train your ViT' paper, not the original, 94.1 is the CIFAR-100 result for that paper for L/16 and 93.2 for B/16, L32 wasn't used, but the R50+L/32 hybrid had 93.9. Augmentation was off for the transfer runs in that paper.

One of the main observations in that paper was that when pre-training vision transformers with more augmentation + regularization, the results roughly match using an order of magnitude more data (i.e., in1k -> 21k and 21k -> JFT-300M, compared to the original paper), so your results aren't that crazy.

eliahuhorwitz commented 3 years ago

Hey @rwightman, thanks for the quick response!

you should always confirm your validation results on a single node with the validation script afterwards

Generally speaking, I do use distributed training. What would I need to do to change the validation to run on a single node? Also, would this still be needed for larger datasets (e.g., ImageNet-1k)?

You can always double-check the sanity of the bits and tpu results by using the same checkpoints and validating on a GPU with the master branch, which is well tested.

I'll try this! Is it also possible to move the validation to the CPU and run it with the master code? That way I could incorporate it into my training rather than manually switching machines.

unfortunately you can split 98 and then split

I'm assuming you meant "unfortunately you can't split 98 and then split"? In any case, this 2% split shouldn't make that much of a difference, right?

The other factor is that the default weights are the ImageNet-21k 300epoch variants from the 'How to train your ViT' paper, not the original, 94.1 is the CIFAR-100 result for that paper for L/16 and 93.2 for B/16, L32 wasn't used, but the R50+L/32 hybrid had 93.9. Augmentation was off for the transfer runs in that paper.

I was sure I was using the original ImageNet weights (30 epochs) used in ViT, not the ones from 'How to train your ViT'. Furthermore, when I looked at 'How to train your ViT', I assumed the VTAB results were on VTAB-1k, not the entire VTAB dataset 😓 (and hence that I should compare them to the ones from the original ViT paper, e.g. as seen below). Also, is it possible to get the hyperparameters used to train this (and the other "best models", as you refer to them in the paper)? And do you happen to recall roughly how long it would take to train ViT-B/32 or L/32 on ImageNet-1k and ImageNet-21k? [Screenshot: results table from the original ViT paper]

Wouldn't 94.44 still be too high when using just the above script? That puts it almost on par with ViT-H/14 and nearly in second place for CIFAR-100 SOTA.

So just to sum up: apart from maybe some small difference due to distributed validation, you don't think these high results are due to a bug or some faulty handling of TPU training on my end?

Thanks again for the detailed and blazing fast response!

rwightman commented 3 years ago

@eliahuhorwitz it is on the high side for L32, and yeah the L32 weights should still be the older 21k ones since there were no L32 weights for the new paper.

However, some of the L16 results for the new paper are well past 95 test accuracy in the index.csv, so it's not completely insane for good runs to be possible. You can explore all the details of the pretrain hparams + transfer by looking at the index.csv used in the notebook linked below. The highest R50+L32 CIFAR-100 test accuracy is listed as 94.6 ...

I'm also not 100% clear on why the VTAB table is different, or whether those numbers come from a different train/val split than the index.csv transfer results. Lots of variables.

https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb

Yeah, using the :98% split, and then splitting that across N distributed nodes doesn't work right now.

Doing val on one node can be challenging, especially on TPU where all nodes need to do the same thing. I've tried setting up barriers, and it can be a bit flaky, with hangs or timeouts on the idle nodes. You could probably run the same eval on all of them and throw out the results for all but rank 0. In either case, it's easier to just verify whether there are any problems by running validation in isolation on a single CPU or GPU device, to confirm there are no problems with your code or my bits_and_tpu code for this case. I've certainly used it a lot for ImageNet and larger.

rwightman commented 3 years ago

@eliahuhorwitz Looking at this while fixing #961: there is no problem at all here, the validation is correct. None of the problematic situations I was concerned about can arise with num_devices=1 and workers=1, and even if they did, the impact would be smaller than the difference between what you expected and what you got.

I ran your train command with the default train script and the tfds dataset, and reproduced your accuracy numbers per epoch exactly. Nice to see the reproducibility.

I exported just the state dict (clean_checkpoint.py) so it's loadable in the master branch, and ran it on GPU with both the torch CIFAR dataset and tfds. Both were within 0.02 of the original train-time validation (94.17 vs 94.19), which is well within expected variation across hardware types.
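For scale, a quick back-of-the-envelope check of the 94.17 vs 94.19 gap, assuming the standard 10,000-image CIFAR-100 test split: the gap is two images, roughly a tenth of the binomial standard error of the accuracy estimate itself. This is only a sketch of the arithmetic, not a statistical test of hardware differences.

```python
import math

N_TEST = 10_000   # CIFAR-100 test split size
TPU_TOP1 = 94.19  # train-time eval on TPU (percent)
GPU_TOP1 = 94.17  # re-validation on GPU (percent)

# Net difference in correctly classified test images.
diff_images = round((TPU_TOP1 - GPU_TOP1) / 100 * N_TEST)

# Binomial standard error of an accuracy estimate, in percentage points.
p = TPU_TOP1 / 100
std_err_pts = 100 * math.sqrt(p * (1 - p) / N_TEST)

print(diff_images, round(std_err_pts, 2))  # 2 0.23
```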

So the most likely reason is that you are fine-tuning from the in1k fine-tuned 384x384 weights and not the 224x224 in21k weights as in the paper. You can do the latter: use vit_large_patch32_224_in21k, but you do have to hand-code an img_size=384 into the create_model call, since only the vit and mlp models accept an img_size arg. It will then interpolate the pos embed for you.
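Roughly what "interpolate the pos embed" means here: at 224x224 with 32x32 patches, ViT-L/32 has a 7x7 patch grid (49 position embeddings plus one for the class token); at 384x384 it needs a 12x12 grid. Per the comment above, the model would be created with `vit_large_patch32_224_in21k` plus `img_size=384`, and timm does the resize internally in PyTorch. The pure-Python sketch below only illustrates the idea: it resizes the patch-grid embeddings with corner-aligned bilinear interpolation and keeps the class-token embedding unchanged. The function name, interpolation mode, and corner alignment are simplifying assumptions, not timm's exact implementation.

```python
import math
from typing import List

Vector = List[float]


def resize_pos_embed_grid(pos_embed: List[Vector], new_grid: int) -> List[Vector]:
    """Resize [cls, patch_0, ..., patch_{g*g-1}] (row-major) to a new_grid x new_grid grid."""
    cls_tok, patches = pos_embed[0], pos_embed[1:]
    old_grid = math.isqrt(len(patches))
    assert old_grid * old_grid == len(patches), "expected a square patch grid"
    dim = len(cls_tok)

    def sample(r: float, c: float) -> Vector:
        # Bilinear sample of the old grid at fractional coordinate (r, c).
        r0, c0 = int(math.floor(r)), int(math.floor(c))
        r1, c1 = min(r0 + 1, old_grid - 1), min(c0 + 1, old_grid - 1)
        dr, dc = r - r0, c - c0
        out = []
        for d in range(dim):
            top = patches[r0 * old_grid + c0][d] * (1 - dc) + patches[r0 * old_grid + c1][d] * dc
            bot = patches[r1 * old_grid + c0][d] * (1 - dc) + patches[r1 * old_grid + c1][d] * dc
            out.append(top * (1 - dr) + bot * dr)
        return out

    # Corner-aligned mapping: new corners land exactly on old corners.
    scale = (old_grid - 1) / (new_grid - 1) if new_grid > 1 else 0.0
    resized = [sample(i * scale, j * scale) for i in range(new_grid) for j in range(new_grid)]
    return [cls_tok] + resized  # class-token embedding is carried over unchanged


grid_224, grid_384 = 224 // 32, 384 // 32  # 7, 12
# Toy 2-d embeddings: each patch embedding is just its (row, col) position.
old = [[-1.0, -1.0]] + [[float(r), float(c)] for r in range(grid_224) for c in range(grid_224)]
new = resize_pos_embed_grid(old, grid_384)
print(len(old), len(new))  # 50 145
```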

Another possibility is that the authors were using 98% of the train split for training, the remaining 2% for eval (checkpoint selection), and then the final test set was evaluated once. It wasn't clear from the first paper whether that was the case. But I know this group likes to avoid evaluating on the test set and sometimes makes their own val splits when there are none.
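One way to see why the evaluation protocol matters: the `eval_top1` column in the summary posted earlier in this thread was computed on the CIFAR-100 test split every epoch, so reporting the best epoch from it effectively selects the checkpoint on the test set. The sketch below (values rounded to two decimals from that table) compares the best-epoch number with the final-epoch number and with the mean of the last five epochs. It only quantifies that selection effect on this one run; it says nothing about how much a 98%/2% protocol would change the training itself.

```python
# eval_top1 per epoch (0-21), rounded from the CSV summary in this thread.
eval_top1 = [
    0.87, 85.37, 90.57, 91.67, 92.34, 92.81, 93.02, 93.36, 93.15, 93.47, 93.67,
    93.42, 93.67, 93.89, 93.96, 93.66, 93.73, 94.20, 94.16, 94.16, 94.07, 94.09,
]

best_epoch = max(range(len(eval_top1)), key=eval_top1.__getitem__)
best = eval_top1[best_epoch]
last = eval_top1[-1]
last5_mean = sum(eval_top1[-5:]) / 5

print(best_epoch, best, last, round(last5_mean, 3))  # 17 94.2 94.09 94.136
```

On this run, picking the best test epoch accounts for only about 0.1 points of the roughly 1.2 point gap to the paper's 93.04, so it is one more variable alongside the weights mismatch rather than the main explanation.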

When I was working with the ViT authors on the How to train your ViT paper, I pointed out that I'd found transfer results were better going 21k -> 1k -> other target dataset than 21k -> other target dataset. It wasn't explored much as it wasn't the focus of that paper.

I'm closing this and will move on to fixing #961 with the checkpoints I have and will add some tweaks for the split handling in tfds and torch dataset support for this and a few other datasets.

eliahuhorwitz commented 3 years ago

@rwightman Terrific, thanks! Also, it may be worth adding some clarity in the README or documentation regarding the different ViT checkpoints and how their results may differ drastically from the figures in the paper. I started by reproducing the numbers from the paper, and once I got them I started changing things and comparing my results to that baseline. If you want, I am happy to write something and open a PR.

rwightman commented 3 years ago

@eliahuhorwitz It'd be helpful to have some info, yes, but I'm not exactly sure where it would go (to be obvious) or how to write it so it provides clarity instead of more confusion. It feels like a non-trivial effort to fully cover the different checkpoints (there are those in timm as the defaults and also many that exist outside of timm that can be loaded via .npz files from the Google repo). I've got a million other tasks to get through, so it's not a high priority for me at the moment. Open to a PR if it adds clarity, but I don't know quite where to put it...