princeton-nlp / LLM-Shearing

[ICLR 2024] Sheared LLaMA: Accelerating Language Model Pre-training via Structured Pruning
https://arxiv.org/abs/2310.06694
MIT License

NotImplementedError: offload_to_cpu=True and NO_SHARD is not supported yet #23

Closed Longyichen closed 9 months ago

Longyichen commented 9 months ago

This problem occurs when the pruning run finishes and the checkpoint is being saved. Because of it, there is no way to save checkpoints normally.

log(event, f'Running callback {type(cb).__na │
│ ❱ 468 │   │   │   │   cb.run_event(event, self.state, self.logger)             │
│   469 │                                                                        │
│   470 │   def _run_loggers(self, event: Union[Event, str]):                    │
│   471 │   │   loggers = [callback for callback in self.state.callbacks if isin │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/composer/core/call │
│ back.py:96 in run_event                                                        │
│                                                                                │
│    93 │   │   │   logger (Logger): The logger.                                 │
│    94 │   │   """                                                              │
│    95 │   │   event_cb = getattr(self, event.value)                            │
│ ❱  96 │   │   return event_cb(state, logger)                                   │
│    97 │                                                                        │
│    98 │   def init(self, state: State, logger: Logger) -> None:                │
│    99 │   │   """Called on the :attr:`.Event.INIT` event.                      │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/composer/callbacks │
│ /checkpoint_saver.py:294 in batch_checkpoint                                   │
│                                                                                │
│   291 │   def batch_checkpoint(self, state: State, logger: Logger):            │
│   292 │   │   assert callable(self.save_interval)                              │
│   293 │   │   if self.save_interval(state, Event.BATCH_CHECKPOINT) and self.la │
│ ❱ 294 │   │   │   self._save_checkpoint(                                       │
│   295 │   │   │   │   state,                                                   │
│   296 │   │   │   │   logger,                                                  │
│   297 │   │   │   )                                                            │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/composer/callbacks │
│ /checkpoint_saver.py:332 in _save_checkpoint                                   │
│                                                                                │
│   329 │   │   # save the checkpoint to the filename                            │
│   330 │   │   filename_with_placeholders = self.filename.format(state, is_deep │
│   331 │   │                                                                    │
│ ❱ 332 │   │   saved_path = checkpoint.save_checkpoint(                         │
│   333 │   │   │   state=state,                                                 │
│   334 │   │   │   filename=filename_with_placeholders,                         │
│   335 │   │   │   weights_only=self.weights_only,                              │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/composer/utils/che │
│ ckpoint.py:761 in save_checkpoint                                              │
│                                                                                │
│   758 │   │   }                                                                │
│   759 │   else:                                                                │
│   760 │   │   state_dict = {                                                   │
│ ❱ 761 │   │   │   'state': state.state_dict(),                                 │
│   762 │   │   │   'rng': reproducibility.get_rng_state(),                      │
│   763 │   │   }                                                                │
│   764                                                                          │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/composer/core/stat │
│ e.py:891 in state_dict                                                         │
│                                                                                │
│    888 │   │   │   if attribute_name == 'dataset_state':                       │
│    889 │   │   │   │   serialized_value = self._dataset_state_dict()           │
│    890 │   │   │   elif attribute_name == 'model':                             │
│ ❱  891 │   │   │   │   serialized_value = self.get_model_state_dict()          │
│    892 │   │   │   elif attribute_name == 'optimizers':                        │
│    893 │   │   │   │   optimizer = ensure_tuple(attribute_value)[              │
│    894 │   │   │   │   │   0]  # Let's stop pretending. We don't support more  │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/composer/core/stat │
│ e.py:868 in get_model_state_dict                                               │
│                                                                                │
│    865 │   │   """                                                             │
│    866 │   │   if self.fsdp_enabled and self.fsdp_state_dict_type is not None: │
│    867 │   │   │   with fsdp_state_dict_type_context(self.model, state_dict_ty │
│ ❱  868 │   │   │   │   model_state_dict = self.model.state_dict()              │
│    869 │   │   else:                                                           │
│    870 │   │   │   model_state_dict = self.model.state_dict()                  │
│    871                                                                         │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/nn/modules/m │
│ odule.py:1818 in state_dict                                                    │
│                                                                                │
│   1815 │   │   self._save_to_state_dict(destination, prefix, keep_vars)        │
│   1816 │   │   for name, module in self._modules.items():                      │
│   1817 │   │   │   if module is not None:                                      │
│ ❱ 1818 │   │   │   │   module.state_dict(destination=destination, prefix=prefi │
│   1819 │   │   for hook in self._state_dict_hooks.values():                    │
│   1820 │   │   │   hook_result = hook(self, destination, prefix, local_metadat │
│   1821 │   │   │   if hook_result is not None:                                 │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/nn/modules/m │
│ odule.py:1815 in state_dict                                                    │
│                                                                                │
│   1812 │   │   if hasattr(destination, "_metadata"):                           │
│   1813 │   │   │   destination._metadata[prefix[:-1]] = local_metadata         │
│   1814 │   │                                                                   │
│ ❱ 1815 │   │   self._save_to_state_dict(destination, prefix, keep_vars)        │
│   1816 │   │   for name, module in self._modules.items():                      │
│   1817 │   │   │   if module is not None:                                      │
│   1818 │   │   │   │   module.state_dict(destination=destination, prefix=prefi │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/nn/modules/m │
│ odule.py:1722 in _save_to_state_dict                                           │
│                                                                                │
│   1719 │   │   │   │   module                                                  │
│   1720 │   │   """                                                             │
│   1721 │   │   for hook in self._state_dict_pre_hooks.values():                │
│ ❱ 1722 │   │   │   hook(self, prefix, keep_vars)                               │
│   1723 │   │                                                                   │
│   1724 │   │   for name, param in self._parameters.items():                    │
│   1725 │   │   │   if param is not None:                                       │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/utils/_conte │
│ xtlib.py:115 in decorate_context                                               │
│                                                                                │
│   112 │   @functools.wraps(func)                                               │
│   113 │   def decorate_context(*args, **kwargs):                               │
│   114 │   │   with ctx_factory():                                              │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                 │
│   116 │                                                                        │
│   117 │   return decorate_context                                              │
│   118                                                                          │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/distributed/ │
│ fsdp/_state_dict_utils.py:669 in _pre_state_dict_hook                          │
│                                                                                │
│   666 │   │   StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,      │
│   667 │   │   StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,  │
│   668 │   }                                                                    │
│ ❱ 669 │   _pre_state_dict_hook_fn[fsdp_state._state_dict_type](                │
│   670 │   │   fsdp_state,                                                      │
│   671 │   │   module,                                                          │
│   672 │   │   *args,                                                           │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/distributed/ │
│ fsdp/_state_dict_utils.py:271 in _full_pre_state_dict_hook                     │
│                                                                                │
│   268 │   in ``nn.Module``.                                                    │
│   269 │   """                                                                  │
│   270 │   _common_pre_state_dict_hook(module, fsdp_state)                      │
│ ❱ 271 │   _common_unshard_pre_state_dict_hook(                                 │
│   272 │   │   module,                                                          │
│   273 │   │   fsdp_state,                                                      │
│   274 │   │   offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu,     │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/distributed/ │
│ fsdp/_state_dict_utils.py:143 in _common_unshard_pre_state_dict_hook           │
│                                                                                │
│   140 │   Performs the pre-state_dict tasks shared by all state_dict types tha │
│   141 │   ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_ │
│   142 │   """                                                                  │
│ ❱ 143 │   _enter_unshard_params_ctx(                                           │
│   144 │   │   module,                                                          │
│   145 │   │   fsdp_state,                                                      │
│   146 │   │   writeback=False,                                                 │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/distributed/ │
│ fsdp/_state_dict_utils.py:109 in _enter_unshard_params_ctx                     │
│                                                                                │
│   106 │   │   offload_to_cpu=offload_to_cpu,                                   │
│   107 │   │   with_grads=with_grads,                                           │
│   108 │   )                                                                    │
│ ❱ 109 │   fsdp_state._unshard_params_ctx[module].__enter__()                   │
│   110                                                                          │
│   111                                                                          │
│   112 @no_type_check                                                           │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/contextlib.py:135 in __enter__   │
│                                                                                │
│   132 │   │   # they are only needed for recreation, which is not possible any │
│   133 │   │   del self.args, self.kwds, self.func                              │
│   134 │   │   try:                                                             │
│ ❱ 135 │   │   │   return next(self.gen)                                        │
│   136 │   │   except StopIteration:                                            │
│   137 │   │   │   raise RuntimeError("generator didn't yield") from None       │
│   138                                                                          │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/distributed/ │
│ fsdp/_unshard_param_utils.py:171 in _unshard_fsdp_state_params                 │
│                                                                                │
│   168 │   This unshards the parameters for a single FSDP state ``state`` that  │
│   169 │   corresponds to ``module``.                                           │
│   170 │   """                                                                  │
│ ❱ 171 │   _validate_unshard_params_args(                                       │
│   172 │   │   state, writeback, rank0_only, offload_to_cpu, with_grads         │
│   173 │   )                                                                    │
│   174 │   torch.cuda.synchronize()                                             │
│                                                                                │
│ /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torch/distributed/ │
│ fsdp/_unshard_param_utils.py:140 in _validate_unshard_params_args              │
│                                                                                │
│   137 │   if offload_to_cpu and any(                                           │
│   138 │   │   not handle.uses_sharded_strategy for handle in state._handles    │
│   139 │   ):                                                                   │
│ ❱ 140 │   │   raise NotImplementedError(                                       │
│   141 │   │   │   "offload_to_cpu=True and NO_SHARD is not supported yet"      │
│   142 │   │   )                                                                │
│   143 │   if writeback and rank0_only:                                         │
╰────────────────────────────────────────────────────────────────────────────────╯
NotImplementedError: offload_to_cpu=True and NO_SHARD is not supported yet

There are related discussions of this problem online (e.g. https://github.com/huggingface/transformers/issues/24874), and it seems to have been fixed in the transformers library. Are you considering updating the version or adopting another fix?
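
For reference, the combination the check rejects can be reproduced outside of this repo with plain PyTorch FSDP. Below is a minimal sketch of my own (not code from LLM-Shearing or Composer), assuming a single-rank NCCL process group and a PyTorch version from the same era as the traceback; requesting a FULL_STATE_DICT with `offload_to_cpu=True` while the module is wrapped with the `NO_SHARD` strategy hits the same `NotImplementedError`:

```python
# Standalone repro sketch (assumes 1 GPU and a torch ~2.0/2.1 era FSDP; not from this repo).
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    StateDictType,
)

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

# Wrap a toy module with the NO_SHARD strategy (what a 1-GPU run effectively uses).
model = FSDP(
    torch.nn.Linear(16, 16).cuda(),
    sharding_strategy=ShardingStrategy.NO_SHARD,
)

# FULL_STATE_DICT + offload_to_cpu=True is the combination shown in the traceback.
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    model.state_dict()  # raises: offload_to_cpu=True and NO_SHARD is not supported yet

dist.destroy_process_group()
```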

xiamengzhou commented 9 months ago

It seems to be a lack of support in native FSDP for these settings. Since our repo does not use transformers code, we probably cannot easily use that PR? Plus, Composer adds another wrapper on top of FSDP, which might add extra complexity.
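
For what it's worth, at the raw FSDP level the check in `_validate_unshard_params_args` only fires when CPU offload is requested while the handles are unsharded, so either dropping `offload_to_cpu` or using a genuinely sharded strategy avoids it. Here is a hedged sketch of that idea, continuing from the single-rank repro sketch above; whether Composer's `fsdp_config` / checkpointing path exposes these knobs directly is an assumption on my part:

```python
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

# Workaround sketch: keep FULL_STATE_DICT but do not ask FSDP to offload to CPU,
# so the NO_SHARD check in _validate_unshard_params_args is never triggered.
cfg = FullStateDictConfig(offload_to_cpu=False, rank0_only=False)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    state = model.state_dict()  # gathered on GPU; move to CPU manually if needed
cpu_state = {k: v.detach().cpu() for k, v in state.items()}

# Alternatively, running on multiple GPUs with a sharded strategy
# (e.g. ShardingStrategy.FULL_SHARD) also sidesteps the unsupported combination.
```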

Longyichen commented 9 months ago

I am currently encountering this problem when running a pruning experiment on a single GPU. I wanted to work around it by turning off FSDP, but the code seems incompatible with non-FSDP settings and runs into many problems. With FSDP turned on, this error prevents me from getting a checkpoint. Could you share the general configuration of your machines, such as CUDA, FSDP, and other infrastructure? Do you encounter this problem when running on a single GPU?

xiamengzhou commented 9 months ago

Hi, I never experimented with single GPUs. You might need to deactivate the FSDP setup in the yaml scripts to completely turn off FSDP, by setting `fsdp_config=null`.