NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch
BSD 3-Clause "New" or "Revised" License

torch.cuda.amp > apex.amp #818

Open mcarilli opened 4 years ago

mcarilli commented 4 years ago

For a while now, my main focus has been moving mixed precision functionality into PyTorch core. It was merged about a month ago (https://pytorch.org/docs/master/amp.html, https://pytorch.org/docs/master/notes/amp_examples.html) and is now usable via master or the nightly pip/conda packages. (Unfortunately, the full feature set did not make the 1.5 release.)

torch.cuda.amp is more flexible and intuitive, and the native integration brings more future optimizations into scope. Also, torch.cuda.amp fixes many of apex.amp's known pain points. Some things native amp can handle that apex amp can't:

If all you want is to try mixed precision, and you're comfortable using a recent PyTorch, you don't need Apex.

Damiox commented 4 years ago

Can torch.cuda.amp be used for inference only, on an FP32 model? See https://github.com/NVIDIA/apex/issues/750 and https://github.com/NVIDIA/apex/issues/809. I couldn't find an example in https://pytorch.org/docs/master/notes/amp_examples.html. Maybe just wrap the model call with https://pytorch.org/docs/master/amp.html#torch.cuda.amp.autocast?

mcarilli commented 4 years ago

Yes. torch.cuda.amp.autocast can be enabled wherever you want and affects only ops invoked within enabled regions. autocast and torch.cuda.amp.GradScaler are modular codewise. During training, you should use both (autocast selects per-op precision, and GradScaler scales gradients) but for inference, GradScaler is not necessary, and you can use autocast by itself. Also, the model does not need to be altered to use autocast for (regions of) inference forward passes (model leaves may be FP32).
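A minimal sketch of the inference-only pattern, assuming a small hypothetical `nn.Sequential` model in place of a real network: the model stays FP32, only the forward pass runs under `autocast`, and no `GradScaler` is created. The `enabled=use_amp` flag is an addition so the snippet falls back to plain FP32 on a machine without a GPU. Newer PyTorch releases also accept `torch.amp.autocast("cuda")` and may warn that the `torch.cuda.amp` spelling is deprecated.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # CUDA autocast only does anything on a GPU

# Hypothetical FP32 model standing in for a real network.
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).to(device).eval()
x = torch.randn(32, 64, device=device)

# Inference only: autocast alone, no GradScaler needed.
with torch.no_grad(), autocast(enabled=use_amp):
    logits = model(x)  # Linear layers run in float16 when autocast is active

# autocast never touched the model: its parameters are still float32.
param_dtypes = {p.dtype for p in model.parameters()}
print(logits.dtype, param_dtypes)
```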

vince62s commented 4 years ago

Before we dive into torch.cuda.amp, should we expect a behavior change relative to this issue: https://github.com/NVIDIA/apex/issues/475? Thanks.

Damiox commented 4 years ago

@mcarilli What about opt_level (O1, O2, etc.)? I can't tell whether that's natively supported by torch.cuda.amp; it looks like there's no opt_level option. If so, which opt_level does autocast use by default?

Damiox commented 4 years ago

Another question: will this be supported by TorchScript?

blizda commented 4 years ago

How can I migrate from apex.amp to torch.cuda.amp if I already have a pre-trained model with the apex wrapper? Can Apex-wrapped models now be loaded like regular PyTorch models?

mcarilli commented 4 years ago

@Damiox torch.cuda.amp.autocast is similar to O1 in that it casts function inputs on the fly without touching model weights. However, unlike apex O1, autocast only causes casting behavior in regions the context manager is explicitly enabled. Disabled regions can also be nested in enabled regions:

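import torch
from torch.cuda.amp import autocast
a, b = torch.randn(8, 8, device="cuda"), torch.randn(8, 8, device="cuda")
with autocast():
    c = a @ b  # autocasted ops: this matmul runs in float16
    with autocast(enabled=False):
        # ops that run in input dtypes as usual, so cast c back to float32 first
        d = c.float() @ b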

autocast should work with JIT tracing if you run your tracing pass under autocast, because the trace will record the casts. I don't think it works with scripting yet; we need to make sure scripting properly parses Python context managers.
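A sketch of tracing under autocast as described above, assuming a toy model and example input (both hypothetical). The `cache_enabled=False` argument is an assumption drawn from later PyTorch guidance on tracing with autocast, not from this thread, and it was added in a later release than the one discussed here. Whether the recorded casts match what you expect is worth checking on your own model.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

# Hypothetical model and example input.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU()).to(device).eval()
example = torch.randn(4, 16, device=device)

# Run the tracing pass inside an autocast region so the trace can record the casts.
# cache_enabled=False keeps autocast's weight-cast cache out of the trace (assumption).
with torch.no_grad(), autocast(enabled=use_amp, cache_enabled=False):
    traced = torch.jit.trace(model, example)

with torch.no_grad():
    out = traced(example)
print(out.dtype)
```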

@blizda If you have a model state dict from any source saved on disk, you shouldn't need to do anything special to migrate to native Amp. Create a model in default precision (FP32), call model.load_state_dict(saved_dict), and begin training as shown in the native amp examples. autocast does not touch the model at all; it only affects op execution. GradScaler is self-contained; it doesn't alter model or optimizer structure.

After migrating to native amp, for bitwise-accurate saving/restoring, include calls to saved = scaler.state_dict() and scaler.load_state_dict(saved) alongside your usual state_dict/load_state_dict calls.
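A sketch of the migration and checkpointing steps, assuming a hypothetical `make_model()` factory and an in-memory buffer standing in for a checkpoint file on disk. The saved state dict loads into a plain FP32 model, one native-amp step runs, and the scaler's state is saved and restored alongside the model and optimizer.

```python
import io
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

def make_model():
    # Hypothetical stand-in for the real network, created in default FP32.
    return nn.Linear(8, 2).to(device)

# Migrate: a state dict from any source loads into a plain FP32 model.
saved_dict = make_model().state_dict()  # pretend this was loaded from disk
model = make_model()
model.load_state_dict(saved_dict)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=use_amp)

# One native-amp training step.
x = torch.randn(4, 8, device=device)
y = torch.randint(0, 2, (4,), device=device)
optimizer.zero_grad()
with autocast(enabled=use_amp):
    loss = nn.functional.cross_entropy(model(x), y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# Save the scaler state alongside the model and optimizer state.
buffer = io.BytesIO()
torch.save({"model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

# Restore everything into fresh objects.
restored_model = make_model()
restored_model.load_state_dict(checkpoint["model"])
restored_opt = torch.optim.SGD(restored_model.parameters(), lr=0.1)
restored_opt.load_state_dict(checkpoint["optimizer"])
restored_scaler = GradScaler(enabled=use_amp)
restored_scaler.load_state_dict(checkpoint["scaler"])
```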

mcarilli commented 4 years ago

@vince62s apex.optimizers.FusedAdam and torch.optim.Adam should both work out of the box with native Amp following the documented control flow (create model in default precision aka fp32). If you also need gradient clipping, see the example.

However, there may be a problem with apex.optimizers.FusedAdam that we never bottomed out on. I'm not sure what it could be because we use it internally and it works. If apex.optimizers.FusedAdam does not improve end to end performance vs torch.optim.Adam, definitely prefer torch.optim.Adam.

I don't believe apex.contrib.optimizers.FusedAdam will work, because it takes control of gradient scaling in an incompatible way. Frankly, I don't know what's still using it at this point.
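For the gradient-clipping case, the documented pattern is to unscale the gradients before clipping so the threshold applies to their true magnitudes. A sketch with `torch.optim.Adam`, assuming a toy model, random data, and a hypothetical `max_norm` of 1.0:

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = nn.Linear(8, 2).to(device)  # hypothetical model, default FP32
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=use_amp)
max_norm = 1.0  # hypothetical clipping threshold

for _ in range(3):
    x = torch.randn(16, 8, device=device)
    y = torch.randint(0, 2, (16,), device=device)
    optimizer.zero_grad()
    with autocast(enabled=use_amp):
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()
    # Unscale in place first, so clipping sees the true gradient magnitudes.
    scaler.unscale_(optimizer)
    # Returns the pre-clipping total norm and clips the gradients in place.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    # step() skips the update if unscale_ found inf/NaN gradients.
    scaler.step(optimizer)
    scaler.update()
```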

ysystudio commented 4 years ago

It seems that after training with torch.cuda.amp's autocast(), the model's dtype is still FP32. If I want to deploy the model, do I need to convert it to FP16 manually? It's a little confusing.

trytolose commented 4 years ago

@mcarilli It's clear how to switch from O1, but how can I use O2 optimization with torch.cuda.amp?

mcarilli commented 4 years ago

@ysystudio Autocast does not touch the model object itself, so its dtype (param type) remains as you created it (leaving it to default FP32 is recommended). Save the trained model then deploy it in whatever format you want.
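A sketch of the two deployment choices, assuming a hypothetical small model: keep the FP32 checkpoint (and optionally run it under autocast at inference time), or make an explicit FP16 copy with `.half()`. The FP16 copy is an illustration, not something the thread prescribes; whether pure FP16 inference is accurate enough depends on the network.

```python
import copy
import torch
import torch.nn as nn

# Hypothetical trained model: after autocast training, its params are still FP32.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
fp32_state = model.state_dict()  # save this checkpoint as usual

# Option A: deploy the FP32 model as-is (optionally running it under autocast).
# Option B: make an explicit FP16 copy for deployment; the original stays FP32.
fp16_model = copy.deepcopy(model).half()

if torch.cuda.is_available():
    fp16_model = fp16_model.cuda()
    with torch.no_grad():
        out = fp16_model(torch.randn(4, 8, device="cuda", dtype=torch.float16))
```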

@trytolose O2 isn't a thing in torch.cuda.amp. O2 is more brittle (it does not make any per-op casting decisions), so it isn't fit for upstream (or anyone, tbh). We are identifying performance opportunities that don't endanger convergence and upstreaming them gradually. Please prefer native amp for stability and future-proofing: the native implementation will get faster as you update PyTorch, without changes to your network. We already observe that torch.cuda.amp is often faster than apex O1 due to reduced Python overhead.

SeungjunNah commented 4 years ago

@mcarilli Thanks for the great work on mixed precision! I'm trying both apex.amp and torch.cuda.amp, and both turn out to be effective in terms of memory reduction and speed. But currently torch.cuda.amp.GradScaler seems a bit limited compared to apex. For example, in apex we can set max_loss_scale in amp.initialize(), but I can't find such a feature in GradScaler. There are also many other options in apex.amp that torch.cuda.amp doesn't currently support. Will they be implemented in torch.cuda.amp?

mcarilli commented 4 years ago

@SeungjunNah The options available in native Amp are a better representation of what users should control. Some apex options, like opt-level O2, are unsafe for general use. If an option is present in apex amp but not in native amp, it's probably not an important knob for users to experiment with, and including it would make the API more cluttered and confusing. For example, I'm not aware of any network where setting max_loss_scale was required for convergence. If you have evidence that max_loss_scale is required, I can add it.

In general, torch.cuda.amp tries to add support for use cases that people complained were unsupported in Apex, and hide options that people should not or did not care about.

SeungjunNah commented 4 years ago

@mcarilli I use max_loss_scale to avoid gradient explosion when training my models in this line of my repository. (loss scaling here) From my experiments, I know that gradients usually explode when the scale factor reaches 2048 for similar tasks, and capping it at 1024 works. Otherwise, amp would skip a gradient-overflowed batch every N intervals, and I don't want to lose a batch during training if possible. (The PyTorch master docs say optimizer.step() is skipped when infs/NaNs are found.)

A workaround could be to recompute the loss scaling until the overflow is avoided, but I couldn't figure out how to implement that myself.

I'd appreciate it if you could add a max_loss_scale option to torch.cuda.amp.

mcarilli commented 4 years ago

amp would skip a gradient-overflowed batch every N intervals

That's true, but N is a large value (2000 by default). After the initial few iterations where GradScaler calibrates, it settles to a steady state where step skipping should only occur once every 2000 iterations (when it attempts a higher scale value). Generally, letting GradScaler dynamically find a steady state scale value is the best approach. Skipping one out of every 2000 iterations should have a negligible effect on both convergence and performance.
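To make the steady-state claim concrete, here is a pure-Python replay of the dynamic scaling rule (no GPU needed). The defaults mirror GradScaler's documented defaults (`init_scale=2**16`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`); `simulate_dynamic_scaling` and its `overflow_at` threshold are hypothetical, modeling a network that overflows at any scale of 2048 or more, echoing the figure from the previous comment.

```python
def simulate_dynamic_scaling(num_iters, overflow_at, init_scale=2.0**16,
                             growth_factor=2.0, backoff_factor=0.5,
                             growth_interval=2000):
    """Replay GradScaler-style dynamic loss scaling (hypothetical helper).

    overflow_at models the network: any scale >= overflow_at produces
    inf/NaN gradients. Returns the final scale and the indices of the
    iterations whose optimizer step would have been skipped.
    """
    scale, good_steps, skipped = init_scale, 0, []
    for it in range(num_iters):
        if scale >= overflow_at:
            # inf/NaN found: skip optimizer.step(), back off, reset the counter.
            skipped.append(it)
            scale *= backoff_factor
            good_steps = 0
        else:
            good_steps += 1
            if good_steps == growth_interval:
                # growth_interval clean steps in a row: probe a larger scale.
                scale *= growth_factor
                good_steps = 0
    return scale, skipped


final_scale, skipped = simulate_dynamic_scaling(10_000, overflow_at=2048.0)
print(final_scale)  # 1024.0
print(skipped)      # [0, 1, 2, 3, 4, 5, 2006, 4007, 6008, 8009]
```

After six calibration skips the scale settles at 1024, and from then on exactly one step in every 2001 is skipped, when the scaler probes 2048.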

What you're suggesting is more like "static loss scaling": locking the scale to a user-defined value rather than letting GradScaler adjust it dynamically. This is also possible (though not recommended) with the native API without an additional max_loss_scale constructor arg: call scaler.update(1024.) instead of scaler.update() at the end of each iteration.
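A sketch of static loss scaling with the native API, assuming a toy model and random data. Passing a float to `scaler.update()` pins the scale; also setting `init_scale=1024.0` is an extra assumption so the first iteration doesn't start at the default 2**16. Steps with inf/NaN gradients are still skipped.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = nn.Linear(8, 2).to(device)  # hypothetical FP32 model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Start at the desired static value instead of the default 2**16 (assumption).
scaler = GradScaler(init_scale=1024.0, enabled=use_amp)

for _ in range(5):
    x = torch.randn(16, 8, device=device)
    y = torch.randint(0, 2, (16,), device=device)
    optimizer.zero_grad()
    with autocast(enabled=use_amp):
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)   # still skipped if inf/NaN gradients are found
    scaler.update(1024.0)    # a float argument pins the scale instead of adjusting it
```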

SeungjunNah commented 4 years ago

Ok, skipping with 1/2000 ratio doesn't hurt practically. I wanted to see if there were ways to control the number of iterations completely, though. Thanks for the explanation!

Quetzalcohuatl commented 4 years ago

@mcarilli I just watched a video saying you can use FusedAdam, FusedSGD, etc. for a faster optimizer when using amp. How do we use these with native amp in PyTorch 1.6? Thanks.

Vincent717 commented 3 years ago

@mcarilli Hi, thanks for your great work! In my task, opt-level O2 trains faster than opt-level O1 with no loss in accuracy. Is there any workaround to get O2-like behavior? Can I just cast the model weights to FP16 (except batch norm, etc.) before training, like this:

model = convert_most_weights_to_half(model)  # hypothetical helper: cast all but batch norm etc. to FP16
with autocast():
    output = model(input)
    loss = loss_fn(output, target)
loss.backward()
optimizer.step()

ImMrMa commented 3 years ago

I can't find an example that tests ImageNet performance with torch.cuda.amp. In my case, I tested ResNet50 on ImageNet with NVIDIA's DALI dataloader and torch.cuda.amp, but only reached ~0.68 after 90 epochs. Here is my code:


    import torch
    import torchvision.transforms as transforms
    from PIL import Image
    import io
    import argparse
    import os
    import random
    import shutil
    import time
    import warnings
    import torch.nn as nn
    import torch.nn.parallel
    import torch.backends.cudnn as cudnn
    import torch.distributed as dist
    import torch.optim
    import torch.multiprocessing as mp
    import torch.utils.data
    import torch.utils.data.distributed
    import torchvision.datasets as datasets
    import torchvision.models as models
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator
    from nvidia.dali.pipeline import Pipeline
    import nvidia.dali.ops as ops
    import nvidia.dali.types as types
    from torch.cuda.amp import GradScaler,autocast

    class HybridTrainPipe(Pipeline):
        def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                    shard_id, num_shards, dali_cpu=False):
            super(HybridTrainPipe, self).__init__(
                batch_size, num_threads, device_id, seed=12 + device_id
            )
            self.input = ops.FileReader(
                file_root=data_dir,
                shard_id=shard_id,
                num_shards=num_shards,
                random_shuffle=True,
            )
            # Let the user decide which pipeline works best for the RN version they run
            dali_device = 'cpu' if dali_cpu else 'gpu'
            decoder_device = 'cpu' if dali_cpu else 'mixed'
            # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
            # without additional reallocations
            device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
            host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
            if dali_cpu:
                self.decode = ops.ImageDecoder(device=dali_device, output_type=types.RGB)
            else:
                self.decode = ops.ImageDecoderRandomCrop(device=decoder_device, output_type=types.RGB,
                                                    device_memory_padding=device_memory_padding,
                                                    host_memory_padding=host_memory_padding,
                                                    random_aspect_ratio=[0.8, 1.25],
                                                    random_area=[0.1, 1.0],
                                                    num_attempts=100)
            self.res = ops.RandomResizedCrop(
                device=dali_device,
                size=[crop, crop],
                interp_type=types.INTERP_LINEAR,
                random_aspect_ratio=[0.75, 4.0 / 3.0],
                random_area=[0.08, 1.0],
                num_attempts=100,
            )
            self.cmnp = ops.CropMirrorNormalize(
                device="gpu",
                output_layout=types.NCHW,
                crop=(crop, crop),
                mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
            )
            self.coin = ops.CoinFlip(probability=0.5)
            print('DALI "{0}" variant'.format(dali_device))

        def define_graph(self):
            rng = self.coin()
            self.jpegs, self.labels = self.input(name="Reader")
            images = self.decode(self.jpegs)
            images = self.res(images)
            output = self.cmnp(images.gpu(), mirror=rng)
            labels_gpu=self.labels.gpu()
            return [output, labels_gpu]

    class HybridValPipe(Pipeline):
        def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                    size, shard_id, num_shards):
            super(HybridValPipe, self).__init__(
                batch_size, num_threads, device_id, seed=12 + device_id
            )
            self.input = ops.FileReader(file_root=data_dir,
                                        shard_id=shard_id,
                                        num_shards=num_shards,
                                        random_shuffle=False,
                                        pad_last_batch=True)
            self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
            self.res = ops.Resize(device="gpu",
                                resize_shorter=size)
            self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                                dtype=types.FLOAT,
                                                crop=(crop, crop),
                                                output_layout=types.NCHW,
                                                mean=[0.485 * 255,0.456 * 255,0.406 * 255],
                                                std=[0.229 * 255,0.224 * 255,0.225 * 255])

        def define_graph(self):
            self.jpegs, self.labels = self.input(name="Reader")
            images = self.decode(self.jpegs)
            images = self.res(images)
            output = self.cmnp(images)
            return [output, self.labels.gpu()]

    model_names = sorted(name for name in models.__dict__
                        if name.islower() and not name.startswith("__")
                        and callable(models.__dict__[name]))

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR',
                        help='path to dataset')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=model_names,
                        help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
    parser.add_argument('-j', '--workers', default=96, type=int, metavar='N',
                        help='number of data loading workers (default: 96)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=512, type=int,
                        metavar='N',
                        help='mini-batch size (default: 512), this is the total '
                            'batch size of all GPUs on the current node when '
                            'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('-p', '--print-freq', default=30, type=int,
                        metavar='N', help='print frequency (default: 30)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--world-size', default=-1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use.')
    parser.add_argument('--multiprocessing-distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                            'N processes per node, which has N GPUs. This is the '
                            'fastest way to use PyTorch for either single node or '
                            'multi node data parallel training')
    parser.add_argument('--dali_cpu', action='store_true',
                        help='run DALI image decoding and resizing on the CPU')
    parser.add_argument('--amp', action='store_true',
                        help='use torch.cuda.amp mixed precision training')
    best_acc1 = 0

    def main():
        args = parser.parse_args()

        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True
            warnings.warn('You have chosen to seed training. '
                        'This will turn on the CUDNN deterministic setting, '
                        'which can slow down your training considerably! '
                        'You may see unexpected behavior when restarting '
                        'from checkpoints.')

        if args.gpu is not None:
            warnings.warn('You have chosen a specific GPU. This will completely '
                        'disable data parallelism.')

        if args.dist_url == "env://" and args.world_size == -1:
            args.world_size = int(os.environ["WORLD_SIZE"])

        args.distributed = args.world_size > 1 or args.multiprocessing_distributed

        ngpus_per_node = torch.cuda.device_count()
        if args.multiprocessing_distributed:
            # Since we have ngpus_per_node processes per node, the total world_size
            # needs to be adjusted accordingly
            args.world_size = ngpus_per_node * args.world_size
            # Use torch.multiprocessing.spawn to launch distributed processes: the
            # main_worker process function
            mp.spawn(main_worker, nprocs=ngpus_per_node,
                    args=(ngpus_per_node, args))
        else:
            # Simply call main_worker function
            main_worker(args.gpu, ngpus_per_node, args)

    def main_worker(gpu, ngpus_per_node, args):
        global best_acc1
        args.gpu = gpu

        if args.gpu is not None:
            print("Use GPU: {} for training".format(args.gpu))

        if args.distributed:
            if args.dist_url == "env://" and args.rank == -1:
                args.rank = int(os.environ["RANK"])
            if args.multiprocessing_distributed:
                # For multiprocessing distributed training, rank needs to be the
                # global rank among all the processes
                args.rank = args.rank * ngpus_per_node + gpu
            dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                    world_size=args.world_size, rank=args.rank)
        # create model
        if args.pretrained:
            print("=> using pre-trained model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=True)
        else:
            print("=> creating model '{}'".format(args.arch))
            model = models.__dict__[args.arch]()
        # model=nn.DataParallel(model)
        if not torch.cuda.is_available():
            print('using CPU, this will be slow')
        elif args.distributed:
            # For multiprocessing distributed, DistributedDataParallel constructor
            # should always set the single device scope, otherwise,
            # DistributedDataParallel will use all available devices.
            if args.gpu is not None:
                torch.cuda.set_device(args.gpu)
                model.cuda(args.gpu)
                # When using a single GPU per process and per
                # DistributedDataParallel, we need to divide the batch size
                # ourselves based on the total number of GPUs we have
                args.batch_size = int(args.batch_size / ngpus_per_node)
                args.workers = int(
                    (args.workers + ngpus_per_node - 1) / ngpus_per_node)
                model = torch.nn.parallel.DistributedDataParallel(
                    model, device_ids=[args.gpu])
            else:
                model.cuda()
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                model = torch.nn.parallel.DistributedDataParallel(model)
        elif args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model = model.cuda(args.gpu)
        else:
            # DataParallel will divide and allocate batch_size to all available GPUs
            if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
                model.features = torch.nn.DataParallel(model.features)
                model.cuda()
            else:
                model = torch.nn.DataParallel(model).cuda()

        # define loss function (criterion) and optimizer
        criterion = nn.CrossEntropyLoss().cuda(args.gpu)

        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                if args.gpu is None:
                    checkpoint = torch.load(args.resume)
                else:
                    # Map model to be loaded to specified single gpu.
                    loc = 'cuda:{}'.format(args.gpu)
                    checkpoint = torch.load(args.resume, map_location=loc)
                    print(loc)
                args.start_epoch = checkpoint['epoch']
                best_acc1 = checkpoint['best_acc1']
                # if args.gpu is not None:
                #     # best_acc1 may be from a checkpoint from a different GPU
                #     best_acc1 = best_acc1.to(args.gpu)
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                    .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        cudnn.benchmark = True
        crop_size = 224
        val_size = 256
        # Data loading code

        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        pipe = HybridTrainPipe(batch_size=args.batch_size,
                            num_threads=args.workers,
                            device_id=args.rank,
                            data_dir=traindir,
                            crop=crop_size,
                            dali_cpu=args.dali_cpu,
                            shard_id=args.rank,
                            num_shards=args.world_size)
        pipe.build()
        train_loader = DALIClassificationIterator(pipe, reader_name="Reader")

        pipe = HybridValPipe(batch_size=args.batch_size,
                            num_threads=args.workers,
                            device_id=args.rank,
                            data_dir=valdir,
                            crop=crop_size,
                            size=val_size,
                            shard_id=args.rank,
                            num_shards=args.world_size)
        pipe.build()
        val_loader = DALIClassificationIterator(pipe, reader_name="Reader")

        if args.evaluate:
            validate(val_loader, model, criterion, args)
            return

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args)

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch, args)

            # evaluate on validation set
            acc1 = validate(val_loader, model, criterion, args)
            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                        and args.rank % ngpus_per_node == 0):
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
            train_loader.reset()
            val_loader.reset()

    def to_python_float(t):
        if hasattr(t, 'item'):
            return t.item()
        else:
            return t[0]

    def train(train_loader, model, criterion, optimizer, epoch, args):
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        train_loader_len=int(train_loader._size / args.batch_size)

        # switch to train mode
        model.train()
        if args.amp:
            scaler = GradScaler()
        end = time.time()
        for i, dict_data in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            images = dict_data[0]['data']
            target = dict_data[0]['label'].squeeze().long()

            # compute output
            if args.amp:
                with autocast():
                    output = model(images)
                    loss = criterion(output, target)
            else:
                output = model(images)
                loss = criterion(output, target)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            if args.amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            # measure elapsed time
            if i%args.print_freq == 0:
                # Every print_freq iterations, check the loss, accuracy, and speed.
                # For best performance, it doesn't make sense to print these metrics every
                # iteration, since they incur an allreduce and some host<->device syncs.

                # Measure accuracy
                prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

                # Average loss and accuracy across processes for logging
                if args.distributed:
                    reduced_loss = reduce_tensor(loss.data,args)
                    prec1 = reduce_tensor(prec1,args)
                    prec5 = reduce_tensor(prec5,args)
                else:
                    reduced_loss = loss.data

                # to_python_float incurs a host<->device sync
                losses.update(to_python_float(reduced_loss), images.size(0))
                top1.update(to_python_float(prec1), images.size(0))
                top5.update(to_python_float(prec5), images.size(0))

                torch.cuda.synchronize()
                batch_time.update((time.time() - end)/args.print_freq)
                end = time.time()

                if args.rank == 0:
                    print('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Speed {3:.3f} ({4:.3f})\t'
                        'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        epoch, i, train_loader_len,
                        args.world_size*args.batch_size/batch_time.val,
                        args.world_size*args.batch_size/batch_time.avg,
                        batch_time=batch_time,
                        loss=losses, top1=top1, top5=top5))

    def validate(val_loader, model, criterion, args):
        batch_time = AverageMeter('Time', ':6.3f')
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        val_loader_len=int(val_loader._size / args.batch_size)

        # switch to evaluate mode
        model.eval()

        with torch.no_grad():
            end = time.time()
            for i, dict_data in enumerate(val_loader):
                images = dict_data[0]['data']
                target = dict_data[0]['label'].squeeze().long()

                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

                if args.distributed:
                    reduced_loss = reduce_tensor(loss.data,args)
                    prec1 = reduce_tensor(prec1,args)
                    prec5 = reduce_tensor(prec5,args)
                else:
                    reduced_loss = loss.data

                losses.update(to_python_float(reduced_loss), images.size(0))
                top1.update(to_python_float(prec1), images.size(0))
                top5.update(to_python_float(prec5), images.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if args.rank == 0 and i % args.print_freq == 0:
                    print('Test: [{0}/{1}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Speed {2:.3f} ({3:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        i, val_loader_len,
                        args.world_size * args.batch_size / batch_time.val,
                        args.world_size * args.batch_size / batch_time.avg,
                        batch_time=batch_time, loss=losses,
                        top1=top1, top5=top5))

            # TODO: this should also be done with the ProgressMeter
            print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
                .format(top1=top1, top5=top5))

        return top1.avg

    def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename, 'model_best.pth.tar')

    class AverageMeter(object):
        """Computes and stores the average and current value"""

        def __init__(self, name, fmt=':f'):
            self.name = name
            self.fmt = fmt
            self.reset()

        def reset(self):
            self.val = 0
            self.avg = 0
            self.sum = 0
            self.count = 0

        def update(self, val, n=1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

        def __str__(self):
            fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
            return fmtstr.format(**self.__dict__)

    class ProgressMeter(object):
        def __init__(self, num_batches, meters, prefix=""):
            self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
            self.meters = meters
            self.prefix = prefix

        def display(self, batch):
            entries = [self.prefix + self.batch_fmtstr.format(batch)]
            entries += [str(meter) for meter in self.meters]
            print('\t'.join(entries))

        def _get_batch_fmtstr(self, num_batches):
            num_digits = len(str(num_batches // 1))
            fmt = '{:' + str(num_digits) + 'd}'
            return '[' + fmt + '/' + fmt.format(num_batches) + ']'

    def adjust_learning_rate(optimizer, epoch, args):
        """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
        lr = args.lr * (0.1 ** (epoch // 30))
        print(lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def accuracy(output, target, topk=(1,)):
        """Computes the accuracy over the k top predictions for the specified values of k"""
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res

    def reduce_tensor(tensor, args):
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= args.world_size
        return rt

    if __name__ == '__main__':
        main()

yl-guo-working commented 2 years ago

@mcarilli I found one case where we might need min_loss_scale. In my training with AMP, the first several iterations quite often have NaN gradients, so the first usable scale value ends up around 0.0325 (or something like that). Does such a small scale value make sense?
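For reference, a scale that low is what repeated backoff produces. By default `torch.cuda.amp.GradScaler` starts at `init_scale=2.**16`, multiplies the scale by `backoff_factor=0.5` (and skips the optimizer step) on every iteration whose gradients contain infs or NaNs, and multiplies it by `growth_factor=2.0` after `growth_interval=2000` consecutive clean iterations. Halving 2**16 twenty-one times gives 2**-5 = 0.03125, in the range mentioned above. GradScaler has no `min_loss_scale` argument, so the following is only a pure-Python model of that update rule with a hypothetical `min_scale` floor added for illustration; `ScaleModel` and `min_scale` are made-up names, not part of PyTorch.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ScaleModel:
    """Pure-Python model of dynamic loss-scale updates (a sketch, not GradScaler itself).

    Defaults mirror torch.cuda.amp.GradScaler's documented defaults.
    `min_scale` is a hypothetical floor; GradScaler has no such argument.
    """
    scale: float = 2.0 ** 16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
    min_scale: Optional[float] = None
    _clean_steps: int = field(default=0, init=False, repr=False)

    def update(self, found_inf: bool) -> float:
        if found_inf:
            # Non-finite gradients: back off and restart the clean-step count.
            self.scale *= self.backoff_factor
            self._clean_steps = 0
            if self.min_scale is not None:
                # Hypothetical floor: never let the scale drop below min_scale.
                self.scale = max(self.scale, self.min_scale)
        else:
            # Clean step: grow only after growth_interval consecutive clean steps.
            self._clean_steps += 1
            if self._clean_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._clean_steps = 0
        return self.scale


if __name__ == "__main__":
    # 21 consecutive non-finite iterations, starting from the default scale.
    plain = ScaleModel()
    floored = ScaleModel(min_scale=1.0)
    for _ in range(21):
        plain.update(found_inf=True)
        floored.update(found_inf=True)
    print(plain.scale)    # 0.03125, i.e. 2**-5
    print(floored.scale)  # 1.0
```

A floor only limits how small the scale gets; with GradScaler, any iteration whose unscaled gradients contain infs or NaNs still has its optimizer step skipped.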

lminer commented 2 years ago

"O2" is stable for me where "O1" and native amp give me NaNs. It would be really nice if there were some way to replicate O2 behavior using native torch.cuda.amp. I've tried casting all batch norms to FP32, but that didn't do it, so I guess something else is happening under the hood.