train.py with early stopping:
```python
from statistics import mean
from math import isfinite
import torch
from torch.optim import SGD, AdamW
from torch.optim.lr_scheduler import LambdaLR
from apex import amp, optimizers
from apex.parallel import DistributedDataParallel as ADDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler, autocast

from .backbones.layers import convert_fixedbn_model
from .data import DataIterator, RotatedDataIterator
from .dali import DaliDataIterator
from .utils import ignore_sigint, post_metrics, Profiler
from .infer import infer


def train(model, state, path, annotations, val_path, val_annotations, resize, max_size, jitter, batch_size, iterations,
          val_iterations, lr, warmup, milestones, gamma, rank=0, world=1, mixed_precision=True, with_apex=False,
          use_dali=True, verbose=True, metrics_url=None, logdir=None, rotate_augment=False, augment_brightness=0.0,
          augment_contrast=0.0, augment_hue=0.0, augment_saturation=0.0, regularization_l2=0.0001, rotated_bbox=False,
          absolute_angle=False):
    'Train the model on the given dataset'

    # Prepare model
    nn_model = model
    stride = model.stride

    # for early stopping
    best_mAP = 0
    count_it = 0
    patience = 3
    early_stop = False

    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.to(memory_format=torch.channels_last).cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(), lr=lr, weight_decay=regularization_l2, momentum=0.9)

    is_master = rank == 0
    if with_apex:
        loss_scale = "dynamic" if use_dali else "128.0"
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level='O2' if mixed_precision else 'O0',
                                          keep_batchnorm_fp32=True,
                                          loss_scale=loss_scale,
                                          verbosity=is_master)

    if world > 1:
        model = DDP(model, device_ids=[rank]) if not with_apex else ADDP(model)
    model.train()

    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma ** len([m for m in milestones if m <= train_iter])

    scheduler = LambdaLR(optimizer, schedule)
    if 'scheduler' in state:
        scheduler.load_state_dict(state['scheduler'])

    # Prepare dataset
    if verbose: print('Preparing dataset...')
    if rotated_bbox:
        if use_dali: raise NotImplementedError("This repo does not currently support DALI for rotated bbox detections.")
        data_iterator = RotatedDataIterator(path, jitter, max_size, batch_size, stride,
                                            world, annotations, training=True, rotate_augment=rotate_augment,
                                            augment_brightness=augment_brightness,
                                            augment_contrast=augment_contrast, augment_hue=augment_hue,
                                            augment_saturation=augment_saturation, absolute_angle=absolute_angle)
    else:
        data_iterator = (DaliDataIterator if use_dali else DataIterator)(
            path, jitter, max_size, batch_size, stride,
            world, annotations, training=True, rotate_augment=rotate_augment, augment_brightness=augment_brightness,
            augment_contrast=augment_contrast, augment_hue=augment_hue, augment_saturation=augment_saturation)
    if verbose: print(data_iterator)

    if verbose:
        print(' device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else 'GPU' if world == 1 else 'GPUs'))
        print(' batch: {}, precision: {}'.format(batch_size, 'mixed' if mixed_precision else 'full'))
        print(' BBOX type:', 'rotated' if rotated_bbox else 'axis aligned')
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    if is_master and logdir is not None:
        from torch.utils.tensorboard import SummaryWriter
        if verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)

    scaler = GradScaler(enabled=mixed_precision)
    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
    while iteration < iterations and not early_stop:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            if iteration >= iterations:
                break

            # for early stopping
            if count_it >= patience:
                print("Early stopping at iteration:", iteration)
                early_stop = True
                break

            # Forward pass
            profiler.start('fw')
            optimizer.zero_grad()
            if with_apex:
                cls_loss, box_loss = model([data.contiguous(memory_format=torch.channels_last), target])
            else:
                with autocast(enabled=mixed_precision):
                    cls_loss, box_loss = model([data.contiguous(memory_format=torch.channels_last), target])
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            if with_apex:
                with amp.scale_loss(cls_loss + box_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
            else:
                scaler.scale(cls_loss + box_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            scheduler.step()

            # Reduce all losses
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean().clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))

            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            if is_master and (profiler.totals['train'] > 60 or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration, iterations, len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'], batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size / profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if is_master and logdir is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate, iteration)
                    del box_loss, focal_loss

                if metrics_url:
                    post_metrics(metrics_url, {
                        'focal loss': mean(cls_losses),
                        'box loss': mean(box_losses),
                        'im_s': batch_size / profiler.means['train'],
                        'lr': learning_rate
                    })

                # Save model weights
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            if val_annotations and (iteration == iterations or iteration % val_iterations == 0):
                stats = infer(model, val_path, None, resize, max_size, batch_size, annotations=val_annotations,
                              mixed_precision=mixed_precision, is_master=is_master, world=world, use_dali=use_dali,
                              with_apex=with_apex, is_validation=True, verbose=False, rotated_bbox=rotated_bbox)
                model.train()
                if is_master and logdir is not None and stats is not None:
                    writer.add_scalar(
                        'Validation_Precision/mAP', stats[0], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP@0.50IoU', stats[1], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP@0.75IoU', stats[2], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP (small)', stats[3], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP (medium)', stats[4], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP (large)', stats[5], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (max 1 Dets)', stats[6], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (max 10 Dets)', stats[7], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (max 100 Dets)', stats[8], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (small)', stats[9], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (medium)', stats[10], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (large)', stats[11], iteration)

                    # for early stopping
                    mAP = stats[0]
                    print("best_mAP before:", best_mAP)
                    print("mAP:", mAP)
                    if mAP > best_mAP:
                        best_mAP = mAP
                        count_it = 0
                    else:
                        count_it += 1
                    print("count_it:", count_it)
                    print("best_mAP after:", best_mAP)

            if (iteration == iterations and not rotated_bbox) or (iteration > iterations and rotated_bbox):
                break

        # for early stopping
        if early_stop:
            break

    if is_master and logdir is not None:
        writer.close()
```
`patience` is the number of validation runs (one every `--val-iters` iterations) without improvement after which training stops; set it to whatever value you need. If the validation mAP does not improve for `patience` consecutive validations, training stops early.
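For reference, the patience mechanism added above boils down to a small counter that resets whenever the validation mAP improves. Here is a minimal, self-contained sketch of the same idea (the `EarlyStopping` class name and the example mAP values are illustrative, not part of this repo), in case you want to reuse it outside train.py:

```python
# Minimal sketch of the patience-based early stopping used in the snippet above.
# The class name and the example values are illustrative, not from the repo.
class EarlyStopping:
    def __init__(self, patience=3):
        self.patience = patience  # validations without improvement to tolerate
        self.best_map = 0.0       # best validation mAP seen so far
        self.count = 0            # consecutive validations without improvement

    def step(self, current_map):
        """Record one validation result; return True when training should stop."""
        if current_map > self.best_map:
            self.best_map = current_map
            self.count = 0
        else:
            self.count += 1
        return self.count >= self.patience


if __name__ == '__main__':
    # Toy usage: mAP improves twice, then stalls for three validations.
    stopper = EarlyStopping(patience=3)
    for map_value in [0.10, 0.15, 0.14, 0.15, 0.13]:
        if stopper.step(map_value):
            print('stop: mAP has not improved for', stopper.patience, 'validations')
            break
```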
Hi! Can you tell me whether your repository has an early stopping feature, so that the model stops training once the metrics have not improved over a certain number of iterations?