With apex enabled, it looks like the code all-reduces only a detached copy of the loss, and then takes a gradient step on the original loss from the current process only.
Put another way: the code currently uses multiple GPUs to produce e.g. 8 loss values, averages them for logging purposes only, and takes a gradient step on just one of them. I've highlighted the relevant line from train.py below:
optim.zero_grad()
main_loss = net(inputs)
if args.apex:
    log_main_loss = main_loss.clone().detach_()
    torch.distributed.all_reduce(log_main_loss,  # this looks wrong
                                 torch.distributed.ReduceOp.SUM)
    log_main_loss = log_main_loss / args.world_size
else:
    main_loss = main_loss.mean()
    log_main_loss = main_loss.clone().detach_()
train_main_loss.update(log_main_loss.item(), batch_pixel_size)
if args.fp16:
    with amp.scale_loss(main_loss, optim) as scaled_loss:
        scaled_loss.backward()
else:
    main_loss.backward()
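
For what it's worth, here is a minimal sketch of one way to make the optimizer step reflect the loss from all processes rather than just the local one. It assumes torch.distributed.init_process_group() has already been called and that net is not already wrapped in a DistributedDataParallel module (which would itself average gradients across processes during backward()). This is only an illustration of the idea, not necessarily the right fix for train.py:

import torch.distributed as dist

def average_gradients(model, world_size):
    # All-reduce each parameter's gradient and divide by the number of
    # processes, so every rank steps on the same averaged gradient.
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size

# After main_loss.backward() (or scaled_loss.backward() in the fp16 branch):
average_gradients(net, args.world_size)  # net / args.world_size as in train.py
optim.step()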