NVIDIA / retinanet-examples

Fast and accurate object detection with end-to-end GPU optimization
BSD 3-Clause "New" or "Revised" License

early stopping #346

Closed Egorundel closed 8 months ago

Egorundel commented 9 months ago

Hi! Can you tell me whether there is an early-stopping feature in your repository, so that the model stops training once the metrics have not improved for a certain number of iterations?

Egorundel commented 8 months ago

train.py with early stopping:

from statistics import mean
from math import isfinite
import torch
from torch.optim import SGD, AdamW
from torch.optim.lr_scheduler import LambdaLR
from apex import amp, optimizers
from apex.parallel import DistributedDataParallel as ADDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler, autocast
from .backbones.layers import convert_fixedbn_model

from .data import DataIterator, RotatedDataIterator
from .dali import DaliDataIterator
from .utils import ignore_sigint, post_metrics, Profiler
from .infer import infer

def train(model, state, path, annotations, val_path, val_annotations, resize, max_size, jitter, batch_size, iterations,
          val_iterations, lr, warmup, milestones, gamma, rank=0, world=1, mixed_precision=True, with_apex=False,
          use_dali=True, verbose=True, metrics_url=None, logdir=None, rotate_augment=False, augment_brightness=0.0,
          augment_contrast=0.0, augment_hue=0.0, augment_saturation=0.0, regularization_l2=0.0001, rotated_bbox=False,
          absolute_angle=False):
    'Train the model on the given dataset'

    # Prepare model
    nn_model = model
    stride = model.stride

    # Early-stopping state: stop when mAP has not improved for `patience` validation runs
    best_mAP = 0
    count_it = 0
    patience = 3
    early_stop = False

    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.to(memory_format=torch.channels_last).cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(), lr=lr, weight_decay=regularization_l2, momentum=0.9)

    is_master = rank == 0
    if with_apex:
        loss_scale = "dynamic" if use_dali else "128.0"
        model, optimizer = amp.initialize(model, optimizer,
                                        opt_level='O2' if mixed_precision else 'O0',
                                        keep_batchnorm_fp32=True,
                                        loss_scale=loss_scale,
                                        verbosity=is_master)

    if world > 1:
        model = DDP(model, device_ids=[rank]) if not with_apex else ADDP(model)
    model.train()

    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        # Linear warmup to the base lr, then step decay by gamma at each milestone
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma ** len([m for m in milestones if m <= train_iter])

    scheduler = LambdaLR(optimizer, schedule)
    if 'scheduler' in state:
        scheduler.load_state_dict(state['scheduler'])

    # Prepare dataset
    if verbose: print('Preparing dataset...')
    if rotated_bbox:
        if use_dali: raise NotImplementedError("This repo does not currently support DALI for rotated bbox detections.")
        data_iterator = RotatedDataIterator(path, jitter, max_size, batch_size, stride,
                                            world, annotations, training=True, rotate_augment=rotate_augment,
                                            augment_brightness=augment_brightness,
                                            augment_contrast=augment_contrast, augment_hue=augment_hue,
                                            augment_saturation=augment_saturation, absolute_angle=absolute_angle)
    else:
        data_iterator = (DaliDataIterator if use_dali else DataIterator)(
            path, jitter, max_size, batch_size, stride,
            world, annotations, training=True, rotate_augment=rotate_augment, augment_brightness=augment_brightness,
            augment_contrast=augment_contrast, augment_hue=augment_hue, augment_saturation=augment_saturation)
    if verbose: print(data_iterator)

    if verbose:
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else 'GPU' if world == 1 else 'GPUs'))
        print('     batch: {}, precision: {}'.format(batch_size, 'mixed' if mixed_precision else 'full'))
        print(' BBOX type:', 'rotated' if rotated_bbox else 'axis aligned')
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    if is_master and logdir is not None:
        from torch.utils.tensorboard import SummaryWriter
        if verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)

    scaler = GradScaler(enabled=mixed_precision)
    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
    while iteration < iterations and not early_stop:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            if iteration >= iterations:
                break

            # for early stopping
            if count_it >= patience:
                print("Early stopping at iteration:", iteration)
                early_stop = True
                break

            # Forward pass
            profiler.start('fw')

            optimizer.zero_grad()
            if with_apex:
                cls_loss, box_loss = model([data.contiguous(memory_format=torch.channels_last), target])
            else:
                with autocast(enabled=mixed_precision):
                    cls_loss, box_loss = model([data.contiguous(memory_format=torch.channels_last), target])
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            if with_apex:
                with amp.scale_loss(cls_loss + box_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
            else:
                scaler.scale(cls_loss + box_loss).backward()
                scaler.step(optimizer)
                scaler.update()

            scheduler.step()

            # Reduce all losses
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean().clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))

            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            if is_master and (profiler.totals['train'] > 60 or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration, iterations, len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'], batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size / profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if is_master and logdir is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate, iteration)
                    del box_loss, focal_loss

                if metrics_url:
                    post_metrics(metrics_url, {
                        'focal loss': mean(cls_losses),
                        'box loss': mean(box_losses),
                        'im_s': batch_size / profiler.means['train'],
                        'lr': learning_rate
                    })

                # Save model weights
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            if val_annotations and (iteration == iterations or iteration % val_iterations == 0):
                stats = infer(model, val_path, None, resize, max_size, batch_size, annotations=val_annotations,
                            mixed_precision=mixed_precision, is_master=is_master, world=world, use_dali=use_dali,
                            with_apex=with_apex, is_validation=True, verbose=False, rotated_bbox=rotated_bbox)
                model.train()
                if is_master and logdir is not None and stats is not None:
                    writer.add_scalar(
                        'Validation_Precision/mAP', stats[0], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP@0.50IoU', stats[1], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP@0.75IoU', stats[2], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP (small)', stats[3], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP (medium)', stats[4], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP (large)', stats[5], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (max 1 Dets)', stats[6], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (max 10 Dets)', stats[7], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (max 100 Dets)', stats[8], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (small)', stats[9], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (medium)', stats[10], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (large)', stats[11], iteration)

                # Update the early-stopping counter; this must run even when
                # TensorBoard logging is disabled, so gate it on stats alone
                if is_master and stats is not None:
                    mAP = stats[0]

                    print("best_mAP before:", best_mAP)
                    print("mAP:", mAP)

                    if mAP > best_mAP:
                        best_mAP = mAP
                        count_it = 0
                    else:
                        count_it += 1
                    print("count_it:", count_it)
                    print("best_mAP after:", best_mAP)

            if (iteration == iterations and not rotated_bbox) or (iteration > iterations and rotated_bbox):
                break

        # for early stopping
        if early_stop:
            break

    if is_master and logdir is not None:
        writer.close()

patience is the number of validation runs (one every --val-iters iterations) without improvement after which training stops; set it to whatever value you need.

If the mAP metric does not improve for patience consecutive validation runs, training stops early.
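
For clarity, here is the same pattern in isolation (a minimal sketch; train_one_iteration and evaluate_mAP are hypothetical placeholders for your training step and validation routine):

best_mAP = 0.0
count_it = 0
patience = 3  # validations without improvement before stopping

for iteration in range(1, max_iterations + 1):
    train_one_iteration()                      # hypothetical training step
    if iteration % val_iterations == 0:
        mAP = evaluate_mAP()                   # hypothetical validation returning mAP
        if mAP > best_mAP:
            best_mAP = mAP
            count_it = 0                       # reset: the metric improved
        else:
            count_it += 1                      # one more validation without improvement
        if count_it >= patience:
            print("Early stopping at iteration:", iteration)
            break

With --val-iters 1000 and patience = 3, for example, training stops after 3000 consecutive iterations with no mAP improvement.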