microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] clip_grad_norm for zero_optimization mode is not working #6767

Open chengmengli06 opened 1 week ago

chengmengli06 commented 1 week ago

set "gradient_clipping" in deepspeed does not work, look into the source code in deepspeed.runtime.engine.DeepSpeedEngine,in line 2101

    def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
        if self.gradient_clipping() > 0.0:
            if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()):
                self.clip_fp32_gradients()
            elif self.amp_enabled():
                # AMP's recommended way of doing clipping
                # https://nvidia.github.io/apex/advanced.html#gradient-clipping
                master_params = amp.master_params(self.optimizer)
                clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)
        self.optimizer.step()

Thus, when zero_optimization() is enabled, neither branch is taken here, and gradient clipping does nothing at all!

tjruwase commented 1 week ago

@chengmengli06, this is an incorrect reading of the code. Gradient clipping is handled inside the respective optimizer implementations (see the config sketch after the list), such as:

  1. bf16 optim
  2. fp16 optim
  3. zero
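
For reference, a minimal sketch of how this looks from the user's side: "gradient_clipping" is set in the DeepSpeed config and the clipping is applied inside the wrapped optimizer's step(), not in the _take_model_step snippet quoted above. The batch size, ZeRO stage, and model below are placeholder values, not a definitive setup.

    import torch
    import deepspeed

    # "gradient_clipping" sets the max global gradient norm; under ZeRO it is
    # enforced inside the ZeRO optimizer's step(), not in _take_model_step.
    # train_batch_size, stage, and the model are placeholder values.
    ds_config = {
        "train_batch_size": 32,
        "gradient_clipping": 1.0,
        "zero_optimization": {"stage": 2},
    }

    model = torch.nn.Linear(128, 128)  # placeholder model
    engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )

    # Typical training step: engine.backward(loss); engine.step()
    # The clipping to max_norm=1.0 happens inside optimizer.step().
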
chengmengli06 commented 1 week ago

I found it, and verified that it does work under ZeRO stage 2 mode. Thanks!

chengmengli06 commented 1 week ago

@tjruwase another question: how do I log the pre-clip and post-clip gradient norms to TensorBoard? Is there any interface to get the gradient norms before and after clipping?
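
For context, in plain PyTorch (without the DeepSpeed engine) I would get the pre-clip norm from the return value of torch.nn.utils.clip_grad_norm_ and write it to TensorBoard, roughly as in the sketch below. Under ZeRO the gradients are partitioned across ranks, so this does not carry over directly, which is why I am asking whether the engine exposes an interface for these norms. The model, threshold, and log directory are placeholders.

    import torch
    from torch.utils.tensorboard import SummaryWriter

    model = torch.nn.Linear(16, 1)             # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    writer = SummaryWriter("runs/grad_norms")  # placeholder log dir
    max_norm = 1.0                             # same value as "gradient_clipping"

    for step in range(10):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 16)).pow(2).mean()
        loss.backward()

        # clip_grad_norm_ returns the total norm measured *before* clipping
        pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        # recompute the norm after clipping for comparison
        post_clip = torch.norm(torch.stack(
            [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]))

        writer.add_scalar("grad_norm/pre_clip", pre_clip.item(), step)
        writer.add_scalar("grad_norm/post_clip", post_clip.item(), step)
        optimizer.step()

    writer.close()
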