microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

Backward time grows linearly with the number of calls to _zero3_consolidated_16bit_state_dict #5332

Open i4never opened 8 months ago

i4never commented 8 months ago

I'm training a model with EMA (exponential moving average) weights, so the module state dict has to be gathered after each optimizer step. When ZeRO stage 3 is enabled, model_engine.backward gets slower as the step count grows.

[image attached in the original issue]

Training loop:

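from time import time

# args, dataset, ema, tb_writer, is_global_rank_0 and update_ema are the
# reporter's own objects; num_epochs and compute_loss are placeholders
# for code the report elides.
train_steps = 0          # optimizer steps taken
micro_steps = 0          # micro-batches processed
backward_seconds = 0.0
running_loss = []
for _ in range(num_epochs):
    for x, y in dataset:
        loss = compute_loss(model_engine, x, y)  # forward pass (elided)
        ts = time()
        model_engine.backward(loss / args.grad_acc_steps)
        running_loss.append(loss.detach().item())
        backward_seconds += time() - ts
        micro_steps += 1
        # ...
        if micro_steps % args.grad_acc_steps == 0:
            model_engine.step()

            # Under ZeRO-3 this must be called on every rank; it returns the
            # consolidated state dict on rank 0 and None on the other ranks.
            if model_engine.zero_optimization_stage() == 3:
                sd = model_engine._zero3_consolidated_16bit_state_dict()
            elif is_global_rank_0():
                sd = model_engine.module_state_dict()

            if is_global_rank_0():
                update_ema(ema, sd)

            tb_writer.add_scalar(
                'Train/Backward time ms',
                backward_seconds * 1000,
                train_steps,
            )
            backward_seconds = 0.0
            train_steps += 1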
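The report does not show update_ema. What follows is a minimal sketch of what such a helper might look like. The name, signature and default decay are hypothetical, and it assumes ema is a plain dict keyed like sd whose values support scalar arithmetic (floats or floating-point torch tensors). It should only be called on rank 0, where sd is not None.

```python
import copy


def update_ema(ema, sd, decay=0.999):
    """Hypothetical helper: blend the gathered state dict `sd` into `ema`.

    Assumes `ema` is a plain dict keyed like `sd` and that values support
    scalar arithmetic (floats, or floating-point torch tensors).
    """
    for key, value in sd.items():
        if key not in ema:
            # First update: start the average from a private copy of the weights.
            ema[key] = copy.deepcopy(value)
        else:
            # ema <- decay * ema + (1 - decay) * current
            ema[key] = decay * ema[key] + (1.0 - decay) * value
```

The ZeRO-3 path returns fp16 tensors. Converting each value to fp32 first (for example value.float()) keeps the running average from being accumulated in half precision.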

If the call to model_engine._zero3_consolidated_16bit_state_dict() is removed, backward time does not grow and everything works fine.
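Removing the call makes the slowdown go away. One way to look for state that each call leaves behind is to count the hooks registered on the module tree before and after it. Below is a hypothetical diagnostic helper, not a DeepSpeed API. It reads torch.nn.Module's private hook dictionaries, which are implementation details that may differ between PyTorch versions, and it only sees module-level hooks, not hooks registered directly on tensors.

```python
def count_module_hooks(model):
    """Count module-level hooks on `model` and all of its submodules.

    Reads torch.nn.Module's private hook dicts; attributes that are
    missing (or None) in a given PyTorch version are treated as empty.
    """
    hook_attrs = (
        "_forward_pre_hooks",
        "_forward_hooks",
        "_backward_pre_hooks",
        "_backward_hooks",
    )
    total = 0
    for module in model.modules():
        for attr in hook_attrs:
            total += len(getattr(module, attr, None) or {})
    return total
```

ZeRO-3 registers its own hooks at initialization, so the absolute number is not meaningful. What matters is whether count_module_hooks(model_engine.module) rises after every consolidation call. A constant count rules out only this particular explanation.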

DeepSpeed version: 0.12.6. Tried on 8x A800 and 8x RTX 3090 GPUs.

i4never commented 8 months ago
from collections import OrderedDict

import deepspeed
from deepspeed import comm as dist


def zero3_consolidated_16bit_state_dict(model_engine):
    """
    Get a full non-partitioned state_dict with fp16 weights on cpu.
    Important: this function must be called on all ranks and not just rank 0.
    This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:
    1. consolidates the weights from different partitions on gpu0
    2. works on one layer at a time to require as little gpu0 memory as possible, by
    moving the already consolidated weights to cpu
    3. takes care to keep the shared params shared when gradually copying the params to cpu
    Returns:
        a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks
    """
    if not model_engine.zero_optimization_partition_weights():
        raise ValueError("this function requires ZeRO-3 mode")

    state_dict = OrderedDict() if dist.get_rank() == 0 else None
    shared_params = {}

    def get_layer_state_dict(module, prefix=""):
        # gather one layer at a time to be memory-efficient
        # must use modifier_rank=0 to release GPU memory after each layer gathered
        # see_memory_usage("before GatheredParameters", force=True)
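        # Experiment: enabled=False skips the gather, so the parameters stay partitioned
        # and the state dict built on rank 0 does not hold the full weights.
        # This variant is only useful for the timing comparison.
        with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=None,
                                               enabled=False):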
            if dist.get_rank() == 0:
                # handle params
                for name, param in module.named_parameters(recurse=False):
                    if param is None:
                        continue
                    key = prefix + name
                    # can't rely on param.data_ptr() as it will be reused as weights gets
                    # gathered and reduced, but param.ds_id is unique across all zero weights
                    # (and shared params will have the same param.ds_id)
                    if param.ds_id in shared_params:
                        # shared weights
                        # print(f"`{key}` is shared with `{shared_params[param.ds_id]}`")
                        state_dict[key] = state_dict[shared_params[param.ds_id]]
                    else:
                        state_dict[key] = param.detach().cpu()
                        shared_params[param.ds_id] = key
                    # print(f"param {param.ds_id} {param.shape} {key} ")

                # now buffers - not sure if need to take care of potentially shared weights here
                for name, buf in module.named_buffers(recurse=False):
                    if (buf is not None and name not in module._non_persistent_buffers_set):
                        state_dict[prefix + name] = buf.detach().cpu()
        # see_memory_usage("after GatheredParameters", force=True)

        for name, child in module.named_children():
            if child is not None:
                get_layer_state_dict(child, prefix + name + ".")

    get_layer_state_dict(model_engine.module, prefix="")

    return state_dict

# ... in the training loop:
if model_engine.zero_optimization_stage() == 3:
    sd = zero3_consolidated_16bit_state_dict(model_engine)
# ...

New finding after some digging: if I set enabled=False in deepspeed.zero.GatheredParameters, backward time no longer grows. Is some unexpected op being attached to the autograd graph during GatheredParameters? Backward has to traverse the whole graph, so anything that accumulates there would make every backward pass slower.
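One way to test the "something is attached to the graph" hypothesis directly is to measure how large the autograd graph is at each step. Below is a minimal sketch of a hypothetical helper (not part of DeepSpeed or PyTorch). It walks every node reachable from loss.grad_fn by following PyTorch's next_functions, which holds (node, input_index) pairs where node may be None.

```python
def count_autograd_nodes(grad_fn):
    """Count distinct autograd nodes reachable from `grad_fn`.

    Follows each node's `next_functions` ((node, input_index) pairs,
    node may be None), i.e. the graph that backward() will traverse.
    """
    if grad_fn is None:
        return 0
    # Keep the node objects alive while walking so their id() values stay unique.
    seen = {}
    stack = [grad_fn]
    while stack:
        node = stack.pop()
        if node is None or id(node) in seen:
            continue
        seen[id(node)] = node
        for next_node, _ in getattr(node, "next_functions", ()):
            stack.append(next_node)
    return len(seen)
```

Log count_autograd_nodes(loss.grad_fn) just before model_engine.backward(...), once with the consolidation call and once without. If the number grows from step to step, the graph itself is accumulating nodes. If it stays flat, the extra time is more likely spent in work that runs during backward rather than in extra graph nodes.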
