chexiangying opened this issue 1 year ago
Hi @chexiangying. Would you be able to provide a bit more information regarding the model, how to repro the issue, and elaborate on what you meant by the precision of your model has decreased? Thanks!
The DDP version is:
model, optimizer = amp.initialize(model, optimizer,
                                  num_losses=len(task2scaler),
                                  enabled=opts.optimizer["fp16"], opt_level='O2')
model = DDP(model, device_ids=[get_local_rank()], output_device=get_local_rank(),
            find_unused_parameters=False)
model.train()
# to compute training statistics
loss_moving_averagetors = {}
grad_norm = 0

optimizer.zero_grad()
optimizer.step()
for step, (name, batch) in enumerate(meta_loader):
    task = name.split('--')[0]
    loss_dict = model(batch, task=task, compute_loss=True)
    loss = sum(list(loss_dict.values()))
    loss_dict['total_loss'] = loss
    loss_dict = {k: v.item() for k, v in loss_dict.items()}
    delay_unscale = (step + 1) % gradient_accumulation_steps != 0
    with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale,
                        loss_id=task2scaler[name]) as scaled_loss:
        scaled_loss.backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        global_step += 1
        # learning rate scheduling
        lr_ratio = get_lr_sched(global_step, AttrDict(opts.optimizer))
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['init_lr'] * lr_ratio
        if global_step % 200 == 0:
            LOGGER.info({name: averagetor.val
                         for name, averagetor in loss_moving_averagetors.items()})
        # update model params
        if grad_norm_init != -1:
            grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                        grad_norm_init)
            TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
        optimizer.step()
        optimizer.zero_grad()
and the FSDP version is:
model = FSDP(model,
             auto_wrap_policy=t5_auto_wrap_policy,
             mixed_precision=mp_policy,
             device_id=torch.cuda.current_device(),
             sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,  # ZERO2
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE)
# LOGGER.info(model)
LOGGER.info(f"Model size after FSDP: {model_mem_size(model)} GB")

LOGGER.info("build_optimizer")
original_optimizer = build_optimizer(model, AttrDict(opts.optimizer))
if previous_optimizer_state:
    checkpoint_optimizer = torch.load(
        previous_optimizer_state, map_location="cpu"
    )
    original_optimizer.load_state_dict(checkpoint_optimizer)
    del checkpoint_optimizer
LOGGER.info("build_optimizer")

scaler = ShardedGradScaler()
model.train()
original_optimizer.zero_grad()

if get_rank() == 0:
    save_training_meta(opts)
    pbar = tqdm(
        total=opts.optimizer["num_train_steps"], initial=global_step
    )
    add_log_to_file(join(opts.output_dir, "log", "log.txt"))
else:
    LOGGER.disabled = True
    pbar = NoOp()

grad_norm_init = opts.optimizer["grad_norm"]
global_batch_size = args.train_micro_batch_size_per_gpu * n_gpu
LOGGER.info(" true batch size is = %d", global_batch_size)
# to compute training statistics
loss_moving_averagetors = defaultdict(RunningMeter)
# train parameters
gradient_accumulation_steps = opts.optimizer["gradient_accumulation_steps"]
num_train_steps = opts.optimizer["num_train_steps"]
valid_steps = opts.schedule["valid_steps"]

timers = get_timers()
timers("dataloader").start()
timers("trainStep").start()
for step, (name, batch) in enumerate(meta_loader, start=global_step):
    timers("dataloader").end()
    task = name.split("--")[0]
    timers('forward').start()
    loss_dict = model(batch, task=task, compute_loss=True)
    timers("forward").end()
    loss = sum(list(loss_dict.values()))
    loss_dict["total_loss"] = loss
    loss_dict = {k: v.item() for k, v in loss_dict.items()}
    scaler.scale(loss).backward()
    for k, v in loss_dict.items():
        loss_moving_averagetors[f"loss_{name}/{k}"](v)
    if (step + 1) % gradient_accumulation_steps == 0:
        # scaler.unscale_(original_optimizer)
        if grad_norm_init != -1:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), grad_norm_init
            )
        set_lrs(opts, original_optimizer, global_step)
        scaler.step(original_optimizer)
        scaler.update()
        original_optimizer.zero_grad()
        global_step += 1
    if every(global_step, opts.save_every):
        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, save_policy
        ):
            cpu_state = model.state_dict()
            optimizer_state = original_optimizer.state_dict()
        if dist.get_rank() == 0:
            currEpoch = str(step) + ".pt"
            model_save_name = "model_" + currEpoch
            optimizer_save_name = "optimizer_" + currEpoch
            torch.save(cpu_state, os.path.join(opts.output_dir, "ckpt", model_save_name))
            torch.save(optimizer_state, os.path.join(opts.output_dir, "ckpt", optimizer_save_name))
I think the accuracy of the two should be equivalent, shouldn't it? I first pretrain with the same model, optimizer, and training config, and then finetune on the VQA task. Before computing the loss, I use dist.all_gather to synchronize the model outputs across ranks; could this affect the training result?
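What I mean by synchronizing the outputs is roughly the sketch below (tensor and function names are illustrative, not my exact code):

import torch
import torch.distributed as dist

def gather_outputs(logits: torch.Tensor) -> torch.Tensor:
    # Collect the per-rank model outputs so every rank sees the full batch.
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(logits) for _ in range(world_size)]
    dist.all_gather(gathered, logits)
    # dist.all_gather does not carry gradients for the copies received from
    # other ranks, so the local output is put back to keep its autograd history.
    gathered[dist.get_rank()] = logits
    return torch.cat(gathered, dim=0)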
Maybe one difference is from using FP16 for the amp + DDP versus using BF16 for FSDP?
No, it is not BF16; FSDP is also FP16.
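For reference, the mp_policy I pass to FSDP is an all-FP16 policy along these lines (a sketch of my config, so take the exact fields as an assumption):

import torch
from torch.distributed.fsdp import MixedPrecision

# All-FP16: parameters, gradient reduction, and buffers are cast to float16.
mp_policy = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)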
Does your model have batch norm modules or any buffers?
From the amp O2 documentation, I see that O2 keeps batch norm in FP32, and also I see from your initial issue that you are casting buffers down to low precision (BF16 in your original but assuming FP16 given your most recent comment). I am not sure if amp O2 casts buffers at all.
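One quick way to check is to print the dtype of every floating-point buffer under both setups (a small diagnostic sketch; report_buffer_dtypes is just an illustrative helper):

import torch

def report_buffer_dtypes(model: torch.nn.Module) -> None:
    # Print the dtype of each registered buffer, e.g. BatchNorm running stats.
    for name, buf in model.named_buffers():
        if torch.is_floating_point(buf):
            print(f"{name}: {buf.dtype}")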
For reference, here is what the FSDP documentation says about mixed_precision and BatchNorm:
mixed_precision (Optional[MixedPrecision]): A ``MixedPrecision`` instance
describing the mixed precision training config to be used. ``MixedPrecision``
supports configuring parameter, buffer, and gradient communication dtype. Note
that only floating point data is cast to the reduced precision. This allows
users potential memory saving and training speedup while trading off
accuracy during model training. If ``None``, no mixed precision is applied.
Note that if ``mixed_precision`` is enabled for FSDP model that
contains ``BatchNorm`` with ``auto_wrap_policy``, FSDP will take
care to disable mixed precision for ``BatchNorm`` units by wrapping
them separately in their own FSDP unit with ``mixed_precision=None``.
This is done because several ``BatchNorm`` kernels do not implement
reduced type support at the moment. If individually wrapping the model,
users must take care to set ``mixed_precision=None`` for
``BatchNorm`` units.
(Default: ``None``)
Unless apex.normalization.fused_layer_norm inherits from _BatchNorm, FSDP will not disable mixed precision for it.
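If you want to force those norm modules to run in full precision, one option is to wrap them in their own FSDP units with mixed_precision=None before wrapping the root module. A sketch, assuming the class in question is apex.normalization.FusedLayerNorm and reusing model, mp_policy, and t5_auto_wrap_policy from your snippet:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from apex.normalization import FusedLayerNorm  # assumed norm class

# Collect the norm modules first, then replace them, so the module tree is not
# mutated while it is being traversed.
to_wrap = []
for parent in model.modules():
    for child_name, child in parent.named_children():
        if isinstance(child, FusedLayerNorm):
            to_wrap.append((parent, child_name, child))

for parent, child_name, child in to_wrap:
    setattr(parent, child_name,
            FSDP(child, mixed_precision=None,
                 device_id=torch.cuda.current_device()))

# The root wrap keeps the FP16 policy for everything else.
model = FSDP(model,
             auto_wrap_policy=t5_auto_wrap_policy,
             mixed_precision=mp_policy,
             device_id=torch.cuda.current_device())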
Another thing I noticed in your code is that you are using the local torch.nn.utils.clip_grad_norm_() even for FSDP. Since gradients are sharded, the gradient norm computed by torch.nn.utils.clip_grad_norm_() will not take all gradient elements into account. Perhaps you can try using FullyShardedDataParallel.clip_grad_norm_() instead and see if that helps your loss.
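Concretely, your gradient-accumulation branch would then look roughly like this (a sketch based on your snippet; note the unscale_ call so that clipping operates on the true, unscaled gradients):

if (step + 1) % gradient_accumulation_steps == 0:
    # Unscale first so the clip threshold applies to the real gradient values.
    scaler.unscale_(original_optimizer)
    if grad_norm_init != -1:
        # FSDP's method all-reduces the norm across ranks over the sharded grads.
        grad_norm = model.clip_grad_norm_(grad_norm_init)
    set_lrs(opts, original_optimizer, global_step)
    scaler.step(original_optimizer)  # skips the unscale it already did
    scaler.update()
    original_optimizer.zero_grad()
    global_step += 1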
I have tried this advice, but FSDP still does not reach a precision that aligns with DDP.
Have you tried setting reduce_dtype=torch.float32? This will have some negative impact on throughput, but having the gradient reduction accumulate and output in float32 may help. (In my understanding, apex does not do the gradient reduction in float16.)
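That is, keep parameters and buffers in FP16 and change only the reduction dtype (a sketch, mirroring the policy shown earlier):

import torch
from torch.distributed.fsdp import MixedPrecision

# Parameters/buffers stay FP16; the gradient all-reduce accumulates in FP32.
mp_policy = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float16,
)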
🐛 Describe the bug
I used FSDP + ShardedGradScaler to train my model. Compared with apex.amp + DDP, the precision of my model has decreased.
The DDP and FSDP training loops are as shown in the snippets above.
What is the possible reason?
Versions
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu