pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

Running Larger Models (Pipeline Parallel/Model Parallel) #678

Closed: aflah02 closed this issue 1 month ago

aflah02 commented 1 month ago

🚀 Feature

Hi! I want to fine-tune a Llama 7B model, which does not fit on one GPU. The library supports DDP, but that only helps if the model, optimizer state, etc. all fit on a single GPU, and here they don't. I was wondering whether there is a way to support something like pipeline parallelism, which you can easily get for a HF Transformers model by setting device_map='auto' when loading it. I suppose FSDP and the like would require much more work than this.
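To put rough numbers on why a 7B model does not fit on one GPU, here is some back-of-the-envelope arithmetic. It assumes 7e9 parameters stored in fp32 and plain Adam (two state tensors per parameter), and it ignores activations and framework overhead; `training_memory_gb` is a hypothetical helper written only for this illustration. The `grad_samples` term models vanilla per-sample clipping, which materializes one gradient per example in the batch.

```python
# Rough, illustrative memory estimate for fine-tuning with Adam in fp32.
# Assumptions: 4 bytes per value, 1 GB = 1e9 bytes, no activations/overhead.

BYTES_PER_FP32 = 4


def training_memory_gb(num_params: float, batch_size: int, per_sample_grads: bool) -> dict:
    """Return an approximate memory breakdown in GB."""
    weights = num_params * BYTES_PER_FP32
    grads = num_params * BYTES_PER_FP32
    adam_states = 2 * num_params * BYTES_PER_FP32  # exp_avg and exp_avg_sq
    # Vanilla per-sample clipping stores one full gradient per example
    grad_samples = batch_size * num_params * BYTES_PER_FP32 if per_sample_grads else 0.0
    breakdown = {
        "weights": weights,
        "grads": grads,
        "adam_states": adam_states,
        "grad_samples": grad_samples,
    }
    breakdown = {name: size / 1e9 for name, size in breakdown.items()}
    breakdown["total"] = sum(breakdown.values())
    return breakdown


if __name__ == "__main__":
    for per_sample in (False, True):
        est = training_memory_gb(7e9, batch_size=8, per_sample_grads=per_sample)
        print(f"per_sample_grads={per_sample}: ~{est['total']:.0f} GB")
    # per_sample_grads=False: ~112 GB
    # per_sample_grads=True: ~336 GB
```

Under these assumptions the weights, gradients and Adam state alone come to about 112 GB, and materialized per-sample gradients for a batch of 8 add roughly another 224 GB, so the model has to be split across GPUs no matter how the data parallelism is arranged.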

aflah02 commented 1 month ago

I modified the clip_and_accumulate function in DPOptimizer as follows:

    # Modified DPOptimizer.clip_and_accumulate in opacus/optimizers/optimizer.py;
    # torch, _check_processed_flag and _mark_as_processed are already available
    # at module level in that file.
    def clip_and_accumulate(self):
        """
        Performs gradient clipping.
        Stores clipped and aggregated gradients into `p.summed_grad`.
        """

        # HACKY PARALLEL: with the model split across GPUs, the grad samples live
        # on different devices, so the per-sample norms are combined on one device.
        number_of_gpus = torch.cuda.device_count()
        devices = [f"cuda:{i}" for i in range(number_of_gpus)]

        # Randomly choose the device on which to combine the norms
        randomly_chosen_device = devices[torch.randint(0, len(devices), (1,)).item()]

        if len(self.grad_samples[0]) == 0:
            # Empty batch (the original code used self.grad_samples[0].device)
            per_sample_clip_factor = torch.zeros((0,), device=randomly_chosen_device)
        else:
            # Per-parameter norms are computed on each parameter's own device...
            per_param_norms = [
                g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
            ]
            # ...then moved to one device so they can be stacked
            per_param_norms = [g.to(randomly_chosen_device) for g in per_param_norms]
            per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
            per_sample_clip_factor = (
                self.max_grad_norm / (per_sample_norms + 1e-6)
            ).clamp(max=1.0)

        for p in self.params:
            _check_processed_flag(p.grad_sample)
            grad_sample = self._get_flat_grad_sample(p)

            # Copy the clip factors to this parameter's device; the original
            # per_sample_clip_factor stays on randomly_chosen_device
            clip_factor = per_sample_clip_factor.to(grad_sample.device)
            grad = torch.einsum("i,i...", clip_factor, grad_sample)

            if p.summed_grad is not None:
                p.summed_grad += grad
            else:
                p.summed_grad = grad

            _mark_as_processed(p.grad_sample)

This works with naive pipeline parallelism, i.e. when the model is split across GPUs by loading it with device_map='auto'.
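The key point of the patch is that each sample's norm is taken over *all* parameters, so norms computed on different GPUs must meet on one device before the clip factor can be formed, and the factor must then travel back to each parameter's device. The pure-Python sketch below reproduces only that arithmetic (no torch, no real devices); the function name and the toy two-layer data are hypothetical, and the device placement is described in comments only.

```python
# Pure-Python sketch of the arithmetic in the patched clip_and_accumulate.
# Names and data are hypothetical; no real devices are involved.
import math
from typing import Dict, List

Vec = List[float]


def clip_and_accumulate(
    grad_samples: Dict[str, List[Vec]], max_grad_norm: float
) -> Dict[str, Vec]:
    """grad_samples maps a parameter name to its per-sample flattened grads."""
    names = list(grad_samples)
    batch_size = len(grad_samples[names[0]])

    # Per-parameter, per-sample L2 norms (on GPUs: computed on each param's device)
    per_param_norms = [
        [math.sqrt(sum(x * x for x in g)) for g in grad_samples[name]] for name in names
    ]
    # Combine across parameters (on GPUs: after moving the norms to one device)
    per_sample_norms = [
        math.sqrt(sum(norms[i] ** 2 for norms in per_param_norms))
        for i in range(batch_size)
    ]
    clip_factors = [min(max_grad_norm / (norm + 1e-6), 1.0) for norm in per_sample_norms]

    # Clipped sum over the batch for each parameter (on GPUs: the clip factors
    # are first copied back to that parameter's device)
    summed = {}
    for name in names:
        samples = grad_samples[name]
        summed[name] = [
            sum(c * g[k] for c, g in zip(clip_factors, samples))
            for k in range(len(samples[0]))
        ]
    return summed


if __name__ == "__main__":
    # Two-sample batch; "layer0" imagined on cuda:0 and "layer1" on cuda:1
    grads = {
        "layer0": [[3.0, 0.0], [0.0, 0.6]],
        "layer1": [[4.0], [0.8]],
    }
    # Sample 0 has total norm 5 and is scaled by ~0.2; sample 1 has norm ~1
    print(clip_and_accumulate(grads, max_grad_norm=1.0))
```

Note that neither layer's gradients alone decide how much sample 0 is clipped: its norm of 5 only appears once the contributions from both "devices" are combined.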

HuanyuZhang commented 1 month ago

Thanks for the quick move! The code makes sense to me. We do plan to support both FSDP and pipeline parallelism, and we expect some initial deliveries in late Q4 or early Q1. We agree that pipeline parallelism is a good starting point.

Just a couple of quick comments:

  1. It might be better to build your change on top of ghost clipping (GC) mode rather than the vanilla GradSampleModule, since GC is a must-have for large models. The change should be very similar (with GC, the per-sample norms live in the GradSampleModule rather than in the optimizer).
  2. Supporting micro-batch PP (the first figure in this link) remains an interesting, though more involved, problem; it would greatly improve QPS.

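For context on item 1: ghost clipping avoids materializing per-sample gradients by computing each sample's gradient norm directly from a layer's input activations and output gradients. The sketch below checks the identity this relies on for a linear layer applied to a sequence of T tokens, where the per-sample weight gradient is G = sum_t g_t a_t^T and ||G||_F^2 = sum_{t,s} (a_t . a_s)(g_t . g_s). It is pure Python, illustrates the idea only (it is not Opacus's GradSampleModule code), and the function names are hypothetical.

```python
# Ghost-norm identity for one linear layer and one sample.
# acts[t]      : input activation a_t of token t (length d_in)
# out_grads[t] : gradient g_t = dL/dy_t of the layer output (length d_out)
from typing import List

Vec = List[float]


def dot(u: Vec, v: Vec) -> float:
    return sum(x * y for x, y in zip(u, v))


def materialized_norm_sq(acts: List[Vec], out_grads: List[Vec]) -> float:
    """||G||_F^2 with G = sum_t g_t a_t^T built explicitly (d_out x d_in)."""
    d_out, d_in = len(out_grads[0]), len(acts[0])
    G = [
        [sum(g[j] * a[k] for a, g in zip(acts, out_grads)) for k in range(d_in)]
        for j in range(d_out)
    ]
    return sum(x * x for row in G for x in row)


def ghost_norm_sq(acts: List[Vec], out_grads: List[Vec]) -> float:
    """Same quantity from T x T dot products, without ever forming G."""
    T = len(acts)
    return sum(
        dot(acts[t], acts[s]) * dot(out_grads[t], out_grads[s])
        for t in range(T)
        for s in range(T)
    )


if __name__ == "__main__":
    acts = [[1.0, 2.0], [0.0, 1.0]]                 # T=2 tokens, d_in=2
    out_grads = [[1.0, 0.0, 1.0], [2.0, 1.0, 0.0]]  # d_out=3
    print(materialized_norm_sq(acts, out_grads), ghost_norm_sq(acts, out_grads))  # 23.0 23.0
```

The ghost version needs only T x T dot products per sample instead of a d_out x d_in matrix, which is cheaper whenever T^2 is small relative to d_in * d_out, as is typical for the large linear layers in a 7B model.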
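For item 2, assuming the referenced figure shows a GPipe-style schedule: the batch is split into M micro-batches that flow through S pipeline stages in a staggered fill-and-drain pattern. The naive device_map='auto' setup is the M = 1 case, where only one GPU is busy at any time. This forward-only sketch (hypothetical helper names, equal time per stage assumed) lists which micro-batch each stage processes per clock tick and computes the idle "bubble" fraction (S - 1) / (M + S - 1).

```python
# Forward-pass GPipe-style schedule: at tick t, stage s runs micro-batch t - s.
from typing import List, Optional


def gpipe_forward_schedule(num_stages: int, num_microbatches: int) -> List[List[Optional[int]]]:
    """One row per clock tick; entry s is the micro-batch on stage s (None = idle)."""
    schedule = []
    for tick in range(num_stages + num_microbatches - 1):
        row = [
            tick - s if 0 <= tick - s < num_microbatches else None
            for s in range(num_stages)
        ]
        schedule.append(row)
    return schedule


def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Fraction of stage-ticks spent idle, assuming equal per-stage cost."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


if __name__ == "__main__":
    for tick, row in enumerate(gpipe_forward_schedule(num_stages=4, num_microbatches=4)):
        print(tick, row)
    print(bubble_fraction(4, 1))  # 0.75: naive PP keeps 3 of 4 GPUs idle
    print(bubble_fraction(4, 8))
```

For DP-SGD the per-sample norms in such a schedule would still need contributions from every stage, as in the patch above, which is part of what makes this case more involved.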
aflah02 commented 1 month ago

Thanks for the confirmation!