Open · stas00 opened this issue 3 years ago
I will also post here my experiment to do staggered loading (which appears to make no difference peak-CPU-memory-wise).

I have used 2 approaches. The `flock`-based approach operates correctly; the `barrier`-based one seems to hang in `barrier` for some reason.

```python
# engine.py
def load_checkpoint(self,
                    load_dir,
                    tag=None,
                    load_module_strict=True,
                    load_optimizer_states=True,
                    load_lr_scheduler_states=True):
    """Load training checkpoint

    Arguments:
        load_dir: Required. Directory to load the checkpoint from
        tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file
        load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
        load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
        load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.

    Returns:
        A tuple of ``load_path`` and ``client_state``.

        * ``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
        * ``client_state``: State dictionary used for loading required training states in the client code.
    """

    def load_wrapper(self,
                     load_dir,
                     tag,
                     load_module_strict,
                     load_optimizer_states,
                     load_lr_scheduler_states):
        if tag is None:
            latest_path = os.path.join(load_dir, 'latest')
            if os.path.isfile(latest_path):
                with open(latest_path, 'r') as fd:
                    tag = fd.read().strip()
            else:
                logger.warning(f"Unable to find latest file at {latest_path}, if trying to load latest "
                               "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint.")
                return None, None

        load_path, client_states = self._load_checkpoint(load_dir,
                                                         tag,
                                                         load_module_strict=load_module_strict,
                                                         load_optimizer_states=load_optimizer_states,
                                                         load_lr_scheduler_states=load_lr_scheduler_states)

        if self.zero_optimization() and load_path is not None:
            self._load_zero_checkpoint(load_dir,
                                       tag,
                                       load_optimizer_states=load_optimizer_states)

        return load_path, client_states

    import gc
    import time
    import fcntl  # import added here for completeness; needed by the flock version below

    if self.checkpoint_staggered_load():  # added a new config elsewhere, can just use `if 1` for now
        # version based on flock: ranks take turns holding an exclusive lock on this file
        with open(__file__, "r") as fh:
            fcntl.flock(fh, fcntl.LOCK_EX)
            try:
                print(f"flock! {dist.get_rank()}: loading")
                load_path, client_states = load_wrapper(self, load_dir, tag, load_module_strict,
                                                        load_optimizer_states, load_lr_scheduler_states)
                # time.sleep(1)
            finally:
                fcntl.flock(fh, fcntl.LOCK_UN)

        # # version based on barrier (hangs in barrier - as probably there is another barrier elsewhere)
        # for i in range(dist.get_world_size()):
        #     if i == dist.get_rank():
        #         print(f"barrier! {dist.get_rank()}: loading")
        #         load_path, client_states = load_wrapper(self, load_dir, tag, load_module_strict,
        #                                                 load_optimizer_states, load_lr_scheduler_states)
        #     else:
        #         print(f"barrier! {dist.get_rank()}: waiting for my slot")
        #     dist.barrier()

        gc.collect()
        return load_path, client_states
    else:
        load_path, client_states = load_wrapper(self, load_dir, tag, load_module_strict,
                                                load_optimizer_states, load_lr_scheduler_states)
        gc.collect()
        return load_path, client_states
        # return load_wrapper(self, load_dir, tag, load_module_strict, load_optimizer_states, load_lr_scheduler_states)
```
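For reference, the barrier-based variant can also be written as a standalone helper outside of `engine.py`. This is only a sketch of the idiom (the `staggered` helper name is mine, and it assumes the default process group has already been initialized, e.g. via `torchrun`); the hang reported above most likely comes from another collective elsewhere in the load path, not from the idiom itself:

```python
# staggered_load_sketch.py -- illustrative only, not DeepSpeed code.
# Serializes a memory-heavy section so that only one rank runs it at a time.
import torch.distributed as dist

def staggered(load_fn):
    """Run load_fn on one rank at a time, in rank order, so that only one
    process holds the temporary CPU buffers at any given moment."""
    result = None
    for i in range(dist.get_world_size()):
        if dist.get_rank() == i:
            result = load_fn()  # only this rank does the memory-heavy work now
        dist.barrier()          # everyone syncs before the next rank proceeds
    return result
```

Note that the two variants serialize at different granularities: `flock` only serializes the ranks that see the same lock file (typically those sharing one node, which is what matters for CPU memory), while the barrier loop serializes every rank in the job one after another.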
hi @stas00, any follow-up modifications to solve the redundant loading problem?
I am not sure why you're asking me. I reported the issue, so I'm in the same boat as you are.
And there has been no follow-up since I posted this many moons ago.
cc: @tjruwase
Unfortunately, this is one of the many feature requests that we have lacked the bandwidth to address. @TobiasLee, since this is quite old, I wonder if you still experience this issue and could share its current manifestation. This problem originally came from an attempt (called elastic checkpointing) to support different save-DP and load-DP. That feature has since been disabled and should hopefully be replaced by something better called universal checkpointing. Also, we have tried to avoid this CPU memory explosion with this change. So, can you please share some details of your current experience? Thanks!
Tunji, it's pretty safe to assume that this problem impacts anybody whose available CPU memory is smaller than their total GPU memory, if they load the GPU memory to the brim.
e.g. this is the case with free colab even for something really small scale (the memory of its single GPU is much larger than the CPU memory that remains available once everything is loaded), and even some HPCs have this issue, e.g. one has 512GB of CPU RAM vs 640GB of GPU RAM on each node.
Hello @stas00, have you solved this problem? If there is a solution, I would appreciate it if you could share it.
This is for the Deepspeed team to solve. As far as I know it hasn't been resolved.
cc: @tjruwase
As originally reported in https://github.com/huggingface/transformers/issues/12680, a user can easily train and save checkpoints, yet not have enough RAM to subsequently load that same checkpoint back into memory.
One discussed approach is staggered loading, so that multiple processes won't try to use CPU memory at once, but the main issue appears to be that currently each process loads the zero checkpoints of all ranks in deepspeed, even though most of the time each process uses just its own rank's checkpoint.
In zero2 this is needed when `load_from_fp32_weights` is `True` in `load_state_dict`, to recover the original fp32 weights which are spread out across multiple processes. In zero3 this option is ignored, but if I'm not mistaken it always recovers fp32 weights here.

The other intention for having more than just this rank's state is to be able to dynamically move from, say, DP=4 to DP=2, in which case each process will then need 2 zero checkpoints.
But I think there should also be support for the case where a user is doing a straightforward save/load, is ok with recovering from fp16 weights, and it shouldn't take much more additional memory than it took to save the checkpoint.
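To make the last point concrete, here is a hedged sketch of what rank-local loading of the zero shard could look like. The `load_own_zero_shard` helper and the filename pattern are assumptions for illustration only, not the actual DeepSpeed API or on-disk layout:

```python
# Illustrative sketch: each rank loads only its own optimizer-state shard.
import os
import torch
import torch.distributed as dist

def load_own_zero_shard(load_dir, tag, mp_rank=0):
    """Load only this data-parallel rank's zero shard onto the CPU.

    The filename pattern below is assumed for the sketch; the real layout
    may differ between DeepSpeed versions.
    """
    dp_rank = dist.get_rank()
    ckpt_name = os.path.join(
        load_dir, tag,
        f"zero_pp_rank_{dp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt")
    # map_location='cpu' keeps the temporary buffers on the host; each rank
    # touches exactly one shard instead of world_size shards.
    return torch.load(ckpt_name, map_location='cpu')
```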
Not sure what is the best way to approach this, so would be happy to hear your ideas.
Based on profiling, the main CPU memory-hungry call is:
It does return all of that memory when it's done, so we are only talking about a high temporary memory need.
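For anyone who wants to reproduce the measurement, here is a minimal sketch of tracking peak-RSS growth around a suspect call with the standard `resource` module (the helper name and the example call are illustrative, not the profiling setup used above):

```python
import resource

def peak_rss_growth(fn, *args, **kwargs):
    """Run fn and report how much the process peak RSS grew.

    On Linux ru_maxrss is reported in KiB; it never decreases, so measure
    the call of interest in a fresh process.
    """
    before = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    result = fn(*args, **kwargs)
    after = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    print(f"peak RSS grew by ~{(after - before) / 1024:.1f} MiB")
    return result

# e.g. state = peak_rss_growth(torch.load, ckpt_path, map_location='cpu')
```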
Thank you.
@tjruwase