microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

GPU memory isn't released after deleting tensors in optimizer.bit16_groups #6729

Open wheresmyhair opened 2 weeks ago

wheresmyhair commented 2 weeks ago

I'm developing a PEFT algorithm that basically does the following:

Say the training process has 30 steps in total,

  1. For global step 0~9: train lmhead + layer_0
  2. For global step 10~19: train lmhead + layer_1
  3. For global step 20~29: train lmhead + layer_0

The key point is that, after each switch, the optimizer states of lmhead should be kept, while the states of the body layers should be deleted. For example, the step counter in the lmhead state should go from 0 to 29, while the step counter for the body layers restarts from 0 after every switch, even if the layer has been selected before.

In this case, the parameter groups look like:
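To make the intended bookkeeping concrete, here is a minimal pure-Python sketch of the schedule above. It is only an illustration: the function name `simulate_step_counters`, the plain dict standing in for the optimizer state, and the layer names are assumptions, and no DeepSpeed code is involved. The counters count completed updates, so lmhead ends at 30 after global steps 0 to 29.

```python
from typing import Dict, List


def simulate_step_counters(schedule: List[str], steps_per_phase: int) -> List[Dict[str, int]]:
    """Return a snapshot of the per-group step counters after every global step.

    `schedule` lists the active body layer for each phase (hypothetical names).
    """
    state: Dict[str, int] = {}  # stands in for the optimizer state
    history: List[Dict[str, int]] = []
    for layer in schedule:
        # On every switch, drop the state of all body layers but keep lmhead.
        state = {name: step for name, step in state.items() if name == "lmhead"}
        for _ in range(steps_per_phase):
            for name in ("lmhead", layer):
                state[name] = state.get(name, 0) + 1
            history.append(dict(state))
    return history


history = simulate_step_counters(["layer_0", "layer_1", "layer_0"], steps_per_phase=10)
print(history[9])   # end of phase 1
print(history[19])  # end of phase 2
print(history[29])  # end of phase 3: layer_0 was selected before, but restarts
```

The last snapshot shows the desired behavior: lmhead keeps accumulating steps, while layer_0 starts over even though it was trained in the first phase.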

            optimizer_grouped_parameters = [
                {
                    # this should always be lmhead:
                    # `requires_grad` and `not in active_layers_names` rules out all body layers
                    # `in decay_parameters` rules out ln
                    "params": [
                        p for n, p in opt_model.named_parameters() if (
                            n not in self.active_layers_names and n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    # this should always be ln (outside of body layers)
                    "params": [
                        p for n, p in opt_model.named_parameters() if (
                            n not in self.active_layers_names and n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
                {
                    # selected body layers with decay 
                    "params": [
                        p for n, p in opt_model.named_parameters() if (
                            n in self.active_layers_names and n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    # selected body layers without decay
                    "params": [
                        p for n, p in opt_model.named_parameters() if (
                            n in self.active_layers_names and n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]     

The first two groups contain the parameters whose states should be kept, while the last two groups change at every switch.
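For comparison, this is roughly what "keep the first groups, reset the others" looks like with a plain `torch.optim.AdamW`, outside of DeepSpeed. It is a sketch under assumptions: the two `Linear` modules are hypothetical stand-ins for lmhead and a body layer, and `reset_group_state` is a helper written for this example. It does not carry over directly to ZeRO, where the inner optimizer holds flattened fp32 partitions rather than the model parameters.

```python
import torch

# Hypothetical stand-ins for the real model: one "lmhead" and one body layer.
lmhead = torch.nn.Linear(4, 4)
layer_0 = torch.nn.Linear(4, 4)

optimizer = torch.optim.AdamW([
    {"params": list(lmhead.parameters())},   # state is kept across switches
    {"params": list(layer_0.parameters())},  # state is reset on every switch
])

# One training step so that AdamW creates per-parameter state.
loss = layer_0(lmhead(torch.randn(2, 4))).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()


def reset_group_state(opt: torch.optim.Optimizer, group_indices) -> None:
    """Drop the optimizer state (step, exp_avg, ...) of the given param groups."""
    for idx in group_indices:
        for p in opt.param_groups[idx]["params"]:
            opt.state.pop(p, None)


reset_group_state(optimizer, [1])

kept = all(p in optimizer.state for p in lmhead.parameters())
dropped = all(p not in optimizer.state for p in layer_0.parameters())
print(kept, dropped)
```

After the reset, AdamW lazily recreates the state of the second group on its next `step()`, so its step counter starts again from zero.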

One approach I came up with is to partially "re-init" the optimizer at the beginning of each step that performs a switch. I modified my Hugging Face trainer based on the DeepSpeed optimizer's __init__ method:

    def _reinit_deepspeed_zero_optimizer_params(self, optimizer: DeepSpeedZeroOptimizer):
        num_non_lisa_body_layer_pgs = len(self.optimizer.param_groups) - len(LISA_BODY_LAYER_PARAM_GROUPS_IDX)
        objs = [
            optimizer.bit16_groups, 
            optimizer.round_robin_bit16_groups, 
            optimizer.round_robin_bit16_indices, 
            optimizer.round_robin_bit16_meta, 
            optimizer.bit16_groups_flat, 
            optimizer.groups_padding, 
            optimizer.parallel_partitioned_bit16_groups, 
            optimizer.single_partition_of_fp32_groups, 
            optimizer.partition_size, 
            optimizer.params_in_partition, 
            optimizer.params_not_in_partition, 
            optimizer.first_offset
        ]
        for obj in objs:
            del obj[num_non_lisa_body_layer_pgs:]
        empty_cache()
        torch.cuda.empty_cache()
        gc.collect()

        for i, param_group in enumerate(optimizer.optimizer.param_groups):
            if i in range(num_non_lisa_body_layer_pgs):
                # skip lmhead, ln, etc.
                continue

            ## same as deepspeed/runtime/zero/stage_1_and_2.py DeepSpeedZeroOptimizer.__init__ below

            partition_id = dist.get_rank(group=optimizer.real_dp_process_group[i])

            # push this group to list before modify
            # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group
            trainable_parameters = []
            for param in param_group['params']:
                if param.requires_grad:
                    param.grad_accum = None
                    trainable_parameters.append(param)
            optimizer.bit16_groups.append(trainable_parameters)

            # not sure why apex was cloning the weights before flattening
            # removing cloning here

            see_memory_usage(f"Before moving param group {i} to CPU")
            # move all the parameters to cpu to free up GPU space for creating flat buffer

            # Create temp CPU param copies, free accelerator tensors
            orig_group_numel = 0
            for param in optimizer.bit16_groups[i]:
                orig_group_numel += param.numel()
                param.cpu_data = param.data.cpu()
                param.data = torch.empty(1).to(param.device)

            empty_cache()
            see_memory_usage(f"After moving param group {i} to CPU", force=False)

            # Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
            # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
            # For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging
            # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
            if optimizer.round_robin_gradients:
                round_robin_tensors, round_robin_indices = optimizer._round_robin_reorder(
                    optimizer.bit16_groups[i], dist.get_world_size(group=optimizer.real_dp_process_group[i]))
            else:
                round_robin_tensors = optimizer.bit16_groups[i]
                round_robin_indices = list(range(len(optimizer.bit16_groups[i])))

            optimizer.round_robin_bit16_groups.append(round_robin_tensors)
            optimizer.round_robin_bit16_indices.append(round_robin_indices)

            # Create meta tensors list, ordered according to round_robin_tensors
            meta_tensors = []
            for param in round_robin_tensors:
                meta_tensors.append(torch.zeros_like(param.cpu_data, device="meta"))
            optimizer.round_robin_bit16_meta.append(meta_tensors)

            # create flat buffer in CPU
            flattened_buffer = optimizer.flatten_dense_tensors_aligned(
                optimizer.round_robin_bit16_groups[i],
                optimizer.nccl_start_alignment_factor * dist.get_world_size(group=optimizer.real_dp_process_group[i]),
                use_cpu_data=True)

            # free temp CPU params
            for param in optimizer.bit16_groups[i]:
                del param.cpu_data

            # Move CPU flat tensor to the accelerator memory.
            optimizer.bit16_groups_flat.append(flattened_buffer.to(get_accelerator().current_device_name()))
            del flattened_buffer

            see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)

            # Record padding required for alignment
            if partition_id == dist.get_world_size(group=optimizer.real_dp_process_group[i]) - 1:
                padding = optimizer.bit16_groups_flat[i].numel() - orig_group_numel
            else:
                padding = 0
            optimizer.groups_padding.append(padding)

            if dist.get_rank(group=optimizer.real_dp_process_group[i]) == 0:
                see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False)

            # set model bit16 weight to slices of flattened buffer
            optimizer._update_model_bit16_weights(i)

            # divide the flat weights into near equal partition equal to the data parallel degree
            # each process will compute on a different part of the partition
            data_parallel_partitions = optimizer.get_data_parallel_partitions(optimizer.bit16_groups_flat[i], i)
            optimizer.parallel_partitioned_bit16_groups.append(data_parallel_partitions)

            # verify that data partition start locations are 4-byte aligned
            for partitioned_data in data_parallel_partitions:
                assert (partitioned_data.data_ptr() % (2 * optimizer.nccl_start_alignment_factor) == 0)

            # A partition of the fp32 master weights that will be updated by this process.
            # Note that the params in single_partition_of_fp32_groups is cloned and detached
            # from the origin params of the model.
            if not optimizer.fp16_master_weights_and_gradients:
                weights_partition = optimizer.parallel_partitioned_bit16_groups[i][partition_id].to(
                    optimizer.device).clone().float().detach()
            else:
                weights_partition = optimizer.parallel_partitioned_bit16_groups[i][partition_id].to(
                    optimizer.device).clone().half().detach()

            if optimizer.cpu_offload:
                weights_partition = get_accelerator().pin_memory(weights_partition)

            optimizer.single_partition_of_fp32_groups.append(weights_partition)

            # Set local optimizer to have flat params of its own partition.
            # After this, the local optimizer will only contain its own partition of params.
            # In that case, the local optimizer only saves the states(momentum, variance, etc.) related to its partition's params(zero stage1).
            optimizer.single_partition_of_fp32_groups[
                i].requires_grad = True  # keep this in case internal optimizer uses it
            param_group['params'] = [optimizer.single_partition_of_fp32_groups[i]]

            partition_size = len(optimizer.bit16_groups_flat[i]) / dist.get_world_size(group=optimizer.real_dp_process_group[i])
            params_in_partition, params_not_in_partition, first_offset = optimizer.get_partition_info(
                optimizer.round_robin_bit16_groups[i], partition_size, partition_id)

            optimizer.partition_size.append(partition_size)
            optimizer.params_in_partition.append(params_in_partition)
            optimizer.params_not_in_partition.append(params_not_in_partition)
            optimizer.first_offset.append(first_offset)

However, I found that del obj does not free the memory, as the memory profiling result below shows:

[Image: memory profiling result]

I noticed that the tensors the arrows point at are created at this line:

            # Move CPU flat tensor to the accelerator memory.
            optimizer.bit16_groups_flat.append(flattened_buffer.to(get_accelerator().current_device_name()))

Are there any insights?

jomayeri commented 1 week ago

The del obj doesn't work because other variables still reference the tensors, most likely the linked hp_params.
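The general Python behavior behind this can be shown without any GPU. The sketch below is an illustration only: `FakeTensor` is a hypothetical stand-in for a tensor, and `other_owner` plays the role of whatever attribute still holds a reference. Deleting a slice of a list only removes the list's references; an object is freed once no reference to it remains.

```python
import gc
import weakref


class FakeTensor:
    """Hypothetical stand-in for a GPU tensor; only its lifetime matters here."""


groups = [FakeTensor(), FakeTensor()]
other_owner = {"hp_param": groups[1]}  # a second reference held elsewhere
probe_free = weakref.ref(groups[0])
probe_held = weakref.ref(groups[1])

del groups[0:]  # same pattern as `del obj[n:]` in the snippet above
gc.collect()

freed_without_other_ref = probe_free() is None
freed_while_still_referenced = probe_held() is None

other_owner.clear()  # drop the last remaining reference
gc.collect()
freed_after_clearing = probe_held() is None

print(freed_without_other_ref, freed_while_still_referenced, freed_after_clearing)
```

With real tensors the same rule applies to the underlying storage: `torch.cuda.empty_cache()` can only return memory whose tensors are no longer referenced anywhere.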

tjruwase commented 1 week ago

@wheresmyhair, we have put some effort into enabling better memory management. Please see the following links for relevance to your scenario:

  1. https://github.com/microsoft/DeepSpeed/blob/fc4e73370d84af5242996a90b32b3ffce8e6b922/deepspeed/runtime/engine.py#L414
  2. https://deepspeed.readthedocs.io/en/latest/zero3.html#gpu-memory-management