pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

ddp vs fsdp #91879

Open chexiangying opened 1 year ago

chexiangying commented 1 year ago

🐛 Describe the bug

I trained my model with FSDP + ShardedGradScaler. Compared with apex.amp + DDP, the precision of my model has decreased.

The DDP setup looks like this:

model, optimizer = amp.initialize(model, optimizer,
                                  num_losses=len(task2scaler),
                                  enabled=opts.optimizer["fp16"], opt_level='O2')

model = DDP(model, device_ids=[get_local_rank()], output_device=get_local_rank(), find_unused_parameters=False)
with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale,
                    loss_id=task2scaler[name]) as scaled_loss:

    scaled_loss.backward()

and the FSDP setup looks like this:

model = FSDP(
    model,
    auto_wrap_policy=t5_auto_wrap_policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        # Gradient communication precision.
        reduce_dtype=torch.bfloat16,
        # Buffer precision.
        buffer_dtype=torch.bfloat16,
    ),
    device_id=torch.cuda.current_device(),
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,  # ZERO2
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)

What could be the reason?

Versions

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu

wz337 commented 1 year ago

Hi @chexiangying. Would you be able to provide a bit more information regarding the model, how to repro the issue, and elaborate on what you meant by the precision of your model has decreased? Thanks!

chexiangying commented 1 year ago

The DDP version is:

    model, optimizer = amp.initialize(model, optimizer,
                                      num_losses=len(task2scaler),
                                      enabled=opts.optimizer["fp16"], opt_level='O2')

    model = DDP(model, device_ids=[get_local_rank()], output_device=get_local_rank(), find_unused_parameters=False)
    model.train()

    # to compute training statistics
    loss_moving_averagetors = {}
    grad_norm = 0
    optimizer.zero_grad()
    optimizer.step()

    for step, (name, batch) in enumerate(meta_loader):

        task = name.split('--')[0]

        loss_dict = model(batch, task=task, compute_loss=True)

        loss = sum(list(loss_dict.values()))
        loss_dict['total_loss'] = loss
        loss_dict = {k:v.item() for k,v in loss_dict.items()}

        delay_unscale = (step+1) % gradient_accumulation_steps != 0

        with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale,
                            loss_id=task2scaler[name]) as scaled_loss:

            scaled_loss.backward()
        if (step + 1) % gradient_accumulation_steps == 0:

            global_step += 1
            # learning rate scheduling

            lr_ratio = get_lr_sched(global_step, AttrDict(opts.optimizer))
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['init_lr'] * lr_ratio

            if global_step % 200 == 0:    
                LOGGER.info({name : averagetor.val for name, averagetor in loss_moving_averagetors.items()})   
            # update model params
            if grad_norm_init != -1:
                grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                            grad_norm_init)
                TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
            optimizer.step()
            optimizer.zero_grad()

and the FSDP version is:

    model = FSDP(model,
                 auto_wrap_policy=t5_auto_wrap_policy,
                 mixed_precision=mp_policy,
                 device_id=torch.cuda.current_device(),
                 sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,  # ZERO2
                 backward_prefetch=BackwardPrefetch.BACKWARD_PRE)
    # LOGGER.info(model)
    LOGGER.info(f"Model size after FSDP: {model_mem_size(model)} GB")
    LOGGER.info("build_optimizer")
    original_optimizer = build_optimizer(model, AttrDict(opts.optimizer))
    if previous_optimizer_state:
        checkpoint_optimizer = torch.load(
            previous_optimizer_state, map_location="cpu"
        )
        original_optimizer.load_state_dict(checkpoint_optimizer)
        del checkpoint_optimizer
    LOGGER.info("build_optimizer")
    scaler = ShardedGradScaler()

    model.train()
    original_optimizer.zero_grad()
    if get_rank() == 0:
        save_training_meta(opts)
        pbar = tqdm(
            total=opts.optimizer["num_train_steps"], initial=global_step
        )
        add_log_to_file(join(opts.output_dir, "log", "log.txt"))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
    grad_norm_init = opts.optimizer["grad_norm"]
    global_batch_size = args.train_micro_batch_size_per_gpu*n_gpu
    LOGGER.info("  true batch size is = %d", global_batch_size)
    # to compute training statistics
    loss_moving_averagetors = defaultdict(RunningMeter)

    # train parameters
    gradient_accumulation_steps = opts.optimizer["gradient_accumulation_steps"]
    num_train_steps = opts.optimizer["num_train_steps"]

    valid_steps = opts.schedule["valid_steps"]
    timers = get_timers()
    timers("dataloader").start()
    timers("trainStep").start()

    for step, (name, batch) in enumerate(meta_loader, start=global_step):
        timers("dataloader").end()
        task = name.split("--")[0]
        timers('forward').start()
        loss_dict = model(batch, task=task, compute_loss=True)
        timers("forward").end()
        loss = sum(list(loss_dict.values()))
        loss_dict["total_loss"] = loss
        loss_dict = {k: v.item() for k, v in loss_dict.items()}

        scaler.scale(loss).backward()

        for k, v in loss_dict.items():
            loss_moving_averagetors[f"loss_{name}/{k}"](v)
        if (step + 1) % gradient_accumulation_steps == 0:

            # scaler.unscale_(original_optimizer)
            if grad_norm_init != -1:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), grad_norm_init
                )
            set_lrs(opts, original_optimizer, global_step)
            scaler.step(original_optimizer)
            scaler.update()

            original_optimizer.zero_grad()
            global_step += 1

            if every(global_step, opts.save_every):

                save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
                with FSDP.state_dict_type(
                        model, StateDictType.FULL_STATE_DICT, save_policy
                ):
                    cpu_state = model.state_dict()
                    optimizer_state = original_optimizer.state_dict()
                if dist.get_rank() == 0:
                    currEpoch = (
                            str(step) + ".pt"
                    )
                    model_save_name = "model_" + currEpoch
                    optimizer_save_name = "optimizer_" + currEpoch
                    torch.save(cpu_state, os.path.join(opts.output_dir, "ckpt", model_save_name))
                    torch.save(optimizer_state, os.path.join(opts.output_dir, "ckpt", optimizer_save_name))

I think the accuracy of the two should be equivalent, right? I first pretrain using the same model, optimizer, and training config, and then fine-tune on a VQA task. Before computing the loss, I use dist.all_gather to synchronize the model output. Does this affect the training result?
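
One detail in the FSDP loop above may be worth checking: scaler.unscale_(original_optimizer) is commented out, so clip_grad_norm_ runs on gradients that are still multiplied by the loss scale, and scaler.step() only unscales them afterwards. The PyTorch AMP documentation recommends unscaling before clipping. The pure-Python sketch below uses toy numbers, an assumed loss scale of 1024, and a hypothetical helper clip_by_norm (no real scaler is involved) to show how the order changes the effective gradient:

```python
import math


def l2_norm(values):
    """L2 norm of a flat list of floats."""
    return math.sqrt(sum(v * v for v in values))


def clip_by_norm(grads, max_norm):
    """Hypothetical helper: rescale grads so their L2 norm is at most max_norm."""
    coef = min(1.0, max_norm / (l2_norm(grads) + 1e-6))
    return [g * coef for g in grads]


loss_scale = 1024.0      # assumed value; a real grad scaler picks this dynamically
max_norm = 1.0
true_grads = [3.0, 4.0]  # unscaled toy gradients, L2 norm = 5
scaled_grads = [g * loss_scale for g in true_grads]

# Order in the loop above: clip the still-scaled gradients, unscale afterwards.
clip_then_unscale = [g / loss_scale for g in clip_by_norm(scaled_grads, max_norm)]

# Recommended order: unscale first, then clip.
unscale_then_clip = clip_by_norm(true_grads, max_norm)

print(l2_norm(clip_then_unscale))  # about max_norm / loss_scale
print(l2_norm(unscale_then_clip))  # about max_norm
```

In this toy case the effective clipping threshold becomes max_norm / loss_scale, so every update is far smaller than intended. Whether this explains the gap here is an assumption that would need to be verified by re-enabling the unscale_ call.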

awgu commented 1 year ago

Maybe one difference is from using FP16 for the amp + DDP versus using BF16 for FSDP?
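
To illustrate why the two dtypes can behave differently: FP16 has 5 exponent bits and 10 explicit mantissa bits, while BF16 has 8 exponent bits and 7 explicit mantissa bits, so BF16 trades resolution for range. The sketch below emulates both roundings with the standard library only; to_fp16 and to_bf16 are hypothetical helper names, not PyTorch APIs.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x):
    """Round a Python float to bfloat16 (float32 with the low 16 bits rounded off)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # keep sign, exponent, top 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


x = 1.001
print(to_fp16(x))  # 1.0009765625: spacing near 1.0 is 2**-10
print(to_bf16(x))  # 1.0: spacing near 1.0 is 2**-7, so the 0.001 is lost

# Range: 65504 is the largest finite FP16 value; BF16 reaches much further.
print(to_fp16(65504.0))
print(to_bf16(70000.0))
```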

chexiangying commented 1 year ago

No BF16; the FSDP run also uses FP16.

awgu commented 1 year ago

Does your model have batch norm modules or any buffers?

From the amp O2 documentation, I see that O2 keeps batch norm in FP32, and also I see from your initial issue that you are casting buffers down to low precision (BF16 in your original but assuming FP16 given your most recent comment). I am not sure if amp O2 casts buffers at all.

chexiangying commented 1 year ago

  1. My model has a batch norm module (using apex.fused_layer_norm).
  2. Yes, I use FP16, and FSDP still gets lower accuracy than DDP + apex. I think the AMP behavior is probably the same. The docstring from torch/distributed/fsdp/fully_sharded_data_parallel.py is quoted below. According to it, FSDP disables mixed precision for BatchNorm.
    mixed_precision (Optional[MixedPrecision]): A ``MixedPrecision`` instance
            describing the mixed precision training config to be used. ``MixedPrecision``
            supports configuring parameter, buffer, and gradient communication dtype. Note
            that only floating point data is cast to the reduced precision. This allows
            users potential memory saving and training speedup while trading off
            accuracy during model training. If ``None``, no mixed precision is applied.
            Note that if ``mixed_precision`` is enabled for FSDP model that
            contains ``BatchNorm`` with ``auto_wrap_policy``, FSDP will take
            care to disable mixed precision for ``BatchNorm`` units by wrapping
            them separately in their own FSDP unit with ``mixed_precision=None``.
            This is done because several ``BatchNorm`` kernels do not implement
            reduced type support at the moment. If individually wrapping the model,
            users must take care to set ``mixed_precision=None`` for
            ``BatchNorm`` units.
            (Default: ``None``)

awgu commented 1 year ago

Unless apex.normalization.fused_layer_norm inherits from _BatchNorm, FSDP will not disable mixed precision for it.

Another thing I noticed in your code is that you are using the local torch.nn.utils.clip_grad_norm_() even for FSDP. Since gradients are sharded, the gradient norm used in torch.nn.utils.clip_grad_norm_() will not take into account all gradient elements. Perhaps, you can try using FullyShardedDataParallel.clip_grad_norm_() instead and see if that helps your loss.
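
A small pure-Python sketch of the point about sharded gradients, using toy numbers and two hypothetical ranks (no real process group is involved). Each rank only sees the norm of its own shard. A shard-aware clip has to combine the per-rank squared norms, which in a real run would take a cross-rank reduction, before computing one shared clipping coefficient.

```python
import math


def l2_norm(values):
    """L2 norm of a flat list of floats."""
    return math.sqrt(sum(v * v for v in values))


full_grad = [3.0, 4.0, 12.0, 84.0]       # toy flattened gradient, L2 norm = 85
shards = [full_grad[:2], full_grad[2:]]  # what two hypothetical ranks each hold
max_norm = 10.0

# What a local clip sees: one norm per rank, each below the true global norm.
local_norms = [l2_norm(s) for s in shards]

# What a shard-aware clip needs: the sum of squared local norms across ranks.
global_norm = math.sqrt(sum(n * n for n in local_norms))

local_coefs = [min(1.0, max_norm / n) for n in local_norms]
global_coef = min(1.0, max_norm / global_norm)

print(local_norms, global_norm)  # each local norm is below the global norm of 85
print(local_coefs, global_coef)  # rank 0 is not clipped at all with the local norm
```

With the local norm, rank 0 leaves its shard untouched while rank 1 rescales its shard, so the clipped gradient no longer points in the original direction.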

chexiangying commented 1 year ago

I tried that advice, but the FSDP run still does not match the precision of the DDP run.

awgu commented 1 year ago

Have you tried setting reduce_dtype=torch.float32? This will have some negative impact on throughput, but having the gradient reduction accumulate and output in float32 might help. (As I understand it, Apex does not perform the gradient reduction in float16.)
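
As a deliberately exaggerated toy of why the reduction dtype matters, the sketch below sums 4096 contributions of 1.0 while rounding the running total to FP16 or FP32 after every addition. It uses only the standard library; round_fp16, round_fp32, and reduce_sum are hypothetical helpers. A real all-reduce combines far fewer values per element and does not necessarily accumulate sequentially, so this shows the mechanism rather than the magnitude.

```python
import struct


def round_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def round_fp32(x):
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def reduce_sum(contributions, round_fn):
    """Sum sequentially, rounding the running total to the target precision."""
    total = 0.0
    for c in contributions:
        total = round_fn(total + round_fn(c))
    return total


contributions = [1.0] * 4096
fp16_sum = reduce_sum(contributions, round_fp16)
fp32_sum = reduce_sum(contributions, round_fp32)

# FP16 has 11 significant bits, so above 2048 the spacing between values is 2
# and adding 1.0 rounds back to the same total.
print(fp16_sum)  # 2048.0
print(fp32_sum)  # 4096.0
```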