microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

Crash with cpu offload #707

Open pedrocolon93 opened 3 years ago

pedrocolon93 commented 3 years ago

Hi there! I have been using this configuration:

{
    "zero_allow_untested_optimizer": true,
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e6,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e6,
        "overlap_comm": false,
        "contiguous_gradients": true,
        "cpu_offload": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 5e-5,
            "betas": [0.9, 0.999],
            "eps": 1e-6,
            "weight_decay": 0.01
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 5e-5,
            "warmup_num_steps": 10000
        }
    }
}

I'm using it to train a modified XLNet model (from the transformers library) on four GTX 1080 Ti GPUs.

However, after about 20 iterations, once the gradient scaling has settled and training begins, it crashes in this function:

def complete_grad_norm_calculation_for_cpu_offload(self, params):
    total_norm = 0.0
    norm_type = 2.0
    for p in params:
        if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
            param_id = self.get_param_id(p)
            param_norm = self.norm_for_param_grads[param_id]
            total_norm += param_norm.item()**2

It fails with a KeyError on self.norm_for_param_grads[param_id].

I sidestepped this by wrapping the lookup in a try/except:

try:
    param_norm = self.norm_for_param_grads[param_id]
    total_norm += param_norm.item()**2
except KeyError:
    pass

With that change it continues to train. Does anyone know what is happening?
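The failure mode can be reproduced in isolation with a toy model of the bookkeeping, assuming the explanation given later in this thread: a per-parameter backward hook caches each gradient norm, and the optimizer step later looks up every parameter in the partition. This is a hypothetical sketch, not DeepSpeed's code; `ToyNormCache`, `backward_hook` and `complete_grad_norm` are illustrative names only.

```python
import math


# Hypothetical toy model of ZeRO stage 2 CPU-offload norm bookkeeping
# (illustrative names, not DeepSpeed's implementation).
class ToyNormCache:
    def __init__(self, param_ids):
        self.param_ids = list(param_ids)  # parameters owned by this partition
        self.norm_for_param_grads = {}    # filled only by backward hooks

    def backward_hook(self, param_id, grad_norm):
        # Only runs for parameters that actually receive a gradient.
        self.norm_for_param_grads[param_id] = grad_norm

    def complete_grad_norm(self):
        total_norm = 0.0
        for param_id in self.param_ids:
            # Raises KeyError for any parameter whose hook never fired.
            total_norm += self.norm_for_param_grads[param_id] ** 2
        return math.sqrt(total_norm)


cache = ToyNormCache(param_ids=[0, 1, 2])
cache.backward_hook(0, 3.0)
cache.backward_hook(1, 4.0)
# Parameter 2 was never used in forward(), so its hook never ran.
try:
    cache.complete_grad_norm()
except KeyError as err:
    print("KeyError:", err)  # prints "KeyError: 2"
```

The missing key is simply the id of the parameter that never saw a gradient, which matches the shape of the `KeyError: 8` / `KeyError: 533` tracebacks reported in this issue.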

pedrocolon93 commented 3 years ago

As an aside, I need CPU offload; without it I run out of memory (OOM).

mrgjbd commented 3 years ago

Some nodes in the graph are unused when computing the gradients.
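One way to look for such unused nodes is to run a single forward/backward pass with plain PyTorch (without the DeepSpeed engine, since ZeRO may release `.grad` tensors after reducing them) and list the trainable parameters whose `.grad` is still `None`. A minimal sketch; `ToyModel`, `unused_head` and `find_unused_parameters` are hypothetical names, and the check assumes a freshly constructed model so no stale gradients are present.

```python
import torch
import torch.nn as nn


# Hypothetical model: `unused_head` is trainable but never contributes to
# the output, mimicking a submodule that forward() skips.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.unused_head = nn.Linear(4, 2)

    def forward(self, x):
        return self.encoder(x)


def find_unused_parameters(model: nn.Module) -> list:
    """Names of trainable parameters that got no gradient in the last backward."""
    return [
        name
        for name, p in model.named_parameters()
        if p.requires_grad and p.grad is None
    ]


model = ToyModel()
loss = model(torch.randn(3, 4)).sum()
loss.backward()
print(find_unused_parameters(model))
# ['unused_head.weight', 'unused_head.bias']
```

If those parameters are genuinely not meant to be trained, another option (not confirmed in this thread) is to freeze them with `p.requires_grad_(False)` and pass only the remaining trainable parameters to the optimizer, so ZeRO never tracks them.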

tjruwase commented 3 years ago

@pedrocolon93 thanks for reporting this issue. And thanks @mrgjbd for your suggestion, which I think is correct.

@pedrocolon93 are you using model-parallelism in this training? Also, does the key error happen on all ranks?

pedrocolon93 commented 3 years ago

I'm not sure whether it's happening on all ranks. I believe I'm using a distributed model rather than model parallelism, but I may be wrong.

Soonhwan-Kwon commented 3 years ago

I'm facing the same issue. I'm also using a distributed model across all ranks, and yes, it crashes after exactly 20 iterations.

pedrocolon93 commented 3 years ago

@Soonhwan-Kwon If you need to keep training, patch it with a try/except. It's not an elegant fix, but it will get the job done.

Soonhwan-Kwon commented 3 years ago

Thank you for the suggestion. I'll try it right away and see what happens.

HHousen commented 3 years ago

I am getting this same error. I am not using model-parallelism. (The is_model_parallel_parameter function still returns True because of deepspeed/runtime/pipe/module.py line 246.) https://github.com/huggingface/transformers/pull/9622 fixed a similar crash that happened because of gradient accumulation steps (https://github.com/microsoft/DeepSpeed/issues/671). For me it happens every time after exactly 20 steps. I am using pytorch-lightning with a huggingface/transformers model.

Here is the portion of the traceback involving DeepSpeed:

  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/precision/deepspeed_precision.py", line 30, in pre_optimizer_step
    deepspeed_engine.step()
  File "/usr/local/lib/python3.7/dist-packages/deepspeed/runtime/engine.py", line 959, in step
    self._take_model_step(lr_kwargs)
  File "/usr/local/lib/python3.7/dist-packages/deepspeed/runtime/engine.py", line 914, in _take_model_step
    self.optimizer.step()
  File "/usr/local/lib/python3.7/dist-packages/deepspeed/runtime/zero/stage2.py", line 1379, in step
    self.params_in_partition[i]))
  File "/usr/local/lib/python3.7/dist-packages/deepspeed/runtime/zero/stage2.py", line 881, in complete_grad_norm_calculation_for_cpu_offload
    param_norm = self.norm_for_param_grads[param_id]
KeyError: 8

Soonhwan-Kwon commented 3 years ago

@pedrocolon93 Thank you, it works now with your suggestion (try/except), but it leaves a bad taste. @mrgjbd How can we find the unused nodes when computing gradients and remove them? I would greatly appreciate any advice.

    self._take_model_step(lr_kwargs)
  File "/home/soouee/anaconda3/envs/pytorch_marco/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 914, in _take_model_step
    self.optimizer.step()
  File "/home/soouee/anaconda3/envs/pytorch_marco/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 1379, in step
    self.params_in_partition[i]))
  File "/home/soouee/anaconda3/envs/pytorch_marco/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 881, in complete_grad_norm_calculation_for_cpu_offload
    param_norm = self.norm_for_param_grads[param_id]
KeyError: 533

Soonhwan-Kwon commented 3 years ago

After getting past the KeyError, I encountered this error:

timer has already been started

It keeps happening, and I can't get the model to learn at all.

ghosthamlet commented 3 years ago

@mrgjbd is right; here is a more detailed explanation. The KeyError is caused by unused parameters. If you disable DeepSpeed and use torch.nn.parallel.DistributedDataParallel with find_unused_parameters=False, you may get this error message instead:

    if self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. 
This error indicates that your module has parameters that were not used in producing loss. 
You can enable unused parameter detection by 
(1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; 
(2) making sure all `forward` function outputs participate in calculating loss.
 If you already have done the above two steps, then the distributed data parallel module wasn't able to locate 
the output tensors in the return value of your module's `forward` function. 
Please include the loss function and the structure of the return value of `forward` of your module when 
reporting this issue (e.g. list, dict, iterable).

These errors happen when the model has trainable parameters that are skipped during training. Skipped parameters don't go through the backward pass, so the backward hooks registered for them in self.create_reduce_and_remove_grad_hooks() of ZeRO stage 2 never run, and they never get an entry in norm_for_param_grads. If skipping them is intended, then the hack by @pedrocolon93 is the right way:

try:
    param_norm = self.norm_for_param_grads[param_id]
    total_norm += param_norm.item()**2
except KeyError:
    pass

or, better:

if param_id in self.norm_for_param_grads: 
    param_norm = self.norm_for_param_grads[param_id] 
    total_norm += param_norm.item()**2
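The guarded lookup above can be sketched as a standalone function so its behaviour is easy to check: parameters with no cached gradient norm simply contribute zero. This is a hypothetical illustration working on plain floats rather than tensors; the real DeepSpeed method also all-reduces the partial norm across ranks, which this sketch omits.

```python
# Hypothetical standalone version of the guarded norm computation.
# Works on plain floats; the real method uses tensors (.item()) and also
# all-reduces the partial sum across data-parallel ranks.
def complete_grad_norm_skipping_unused(param_ids, norm_for_param_grads, norm_type=2.0):
    total_norm = 0.0
    for param_id in param_ids:
        # Parameters skipped in forward() have no cached norm: treat as zero.
        if param_id in norm_for_param_grads:
            total_norm += norm_for_param_grads[param_id] ** norm_type
    return total_norm ** (1.0 / norm_type)


print(complete_grad_norm_skipping_unused([0, 1, 2], {0: 3.0, 1: 4.0}))  # 5.0
```

Treating a missing entry as a zero norm is only correct when the parameter really received no gradient; if a hook were missed for some other reason, the global norm (and therefore gradient clipping) would be silently underestimated.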