Open: i4never opened this issue 8 months ago
```python
from collections import OrderedDict

import deepspeed
from deepspeed import comm as dist


def zero3_consolidated_16bit_state_dict(model_engine):
    """
    Get a full non-partitioned state_dict with fp16 weights on cpu.

    Important: this function must be called on all ranks and not just rank 0.

    This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:

    1. consolidates the weights from different partitions on gpu0
    2. works on one layer at a time to require as little gpu0 memory as possible, by
       moving the already consolidated weights to cpu
    3. takes care to keep the shared params shared when gradually copying the params to cpu

    Returns:
        a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks
    """
    if not model_engine.zero_optimization_partition_weights():
        raise ValueError("this function requires ZeRO-3 mode")

    state_dict = OrderedDict() if dist.get_rank() == 0 else None
    shared_params = {}

    def get_layer_state_dict(module, prefix=""):
        # gather one layer at a time to be memory-efficient
        # must use modifier_rank=0 to release GPU memory after each layer gathered
        # see_memory_usage("before GatheredParameters", force=True)
        with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)),
                                               modifier_rank=None,
                                               enabled=False):
            if dist.get_rank() == 0:
                # handle params
                for name, param in module.named_parameters(recurse=False):
                    if param is None:
                        continue
                    key = prefix + name
                    # can't rely on param.data_ptr() as it will be reused as weights gets
                    # gathered and reduced, but param.ds_id is unique across all zero weights
                    # (and shared params will have the same param.ds_id)
                    if param.ds_id in shared_params:
                        # shared weights
                        # print(f"`{key}` is shared with `{shared_params[param.ds_id]}`")
                        state_dict[key] = state_dict[shared_params[param.ds_id]]
                    else:
                        state_dict[key] = param.detach().cpu()
                        shared_params[param.ds_id] = key
                    # print(f"param {param.ds_id} {param.shape} {key} ")

                # now buffers - not sure if need to take care of potentially shared weights here
                for name, buf in module.named_buffers(recurse=False):
                    if buf is not None and name not in module._non_persistent_buffers_set:
                        state_dict[prefix + name] = buf.detach().cpu()
        # see_memory_usage("after GatheredParameters", force=True)

        for name, child in module.named_children():
            if child is not None:
                get_layer_state_dict(child, prefix + name + ".")

    get_layer_state_dict(model_engine.module, prefix="")

    return state_dict
```
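As a quick illustration of point 3 in the docstring, a tied weight (for example an input embedding reused as the output projection) should come back under both keys as one and the same CPU tensor, because duplicates are detected through `param.ds_id`. This is only a sketch: the key names `embed.weight` and `head.weight` are made up, and it assumes `model_engine` wraps a model with such tying.

```python
# Hedged sketch, assuming `model_engine` wraps a model with tied weights;
# the key names below are hypothetical.
sd = zero3_consolidated_16bit_state_dict(model_engine)  # call on ALL ranks
if sd is not None:  # only rank 0 receives the consolidated dict
    # Shared params are deduplicated via `param.ds_id`, so both keys should
    # reference the very same CPU tensor object rather than two copies.
    assert sd["embed.weight"] is sd["head.weight"]
```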
```python
# ...
if model_engine.zero_optimization_stage() == 3:
    sd = zero3_consolidated_16bit_state_dict(model_engine)
# ...
```
New finding after some digging: if I set `enabled=False` in `deepspeed.zero.GatheredParameters`, backward time no longer grows. Is there some unexpected op attached to the graph during `GatheredParameters`? Since backward needs to scan the whole graph, anything left attached to it would make every later backward pass slower.
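One way to pin this down (a sketch only; `compute_loss` and the loop around it are placeholders, not from this issue) is to time `model_engine.backward` on every step and flip the consolidation on and off, optionally under `torch.no_grad()` so nothing from the gather can be recorded on the autograd graph. If the gather is what grows the graph, the per-step backward time should trend upward only when the consolidation runs with autograd enabled.

```python
import time

import torch

def timed_step(model_engine, batch, step, gather_state_dict=True):
    loss = compute_loss(model_engine, batch)  # placeholder for the real loss computation

    start = time.perf_counter()
    model_engine.backward(loss)
    backward_s = time.perf_counter() - start
    model_engine.step()

    sd = None
    if gather_state_dict:
        # The call under suspicion. Running it under no_grad is one way to keep
        # anything it does off the autograd graph; whether that alone is enough
        # to stop the slowdown is exactly what this issue is asking.
        with torch.no_grad():
            sd = zero3_consolidated_16bit_state_dict(model_engine)

    print(f"step {step}: backward {backward_s:.3f}s")
    return sd
```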
For context: I'm training a model with EMA states, which means the module state dict needs to be gathered after every step. When ZeRO stage 3 is enabled, `model_engine.backward` becomes slower as the step count grows (see the training-loop excerpt above). If the call to `model_engine._zero3_consolidated_16bit_state_dict` is disabled, backward time does not grow and everything works fine.

DeepSpeed version: 0.12.6. Tried on 8x A800 and 8x 3090.
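For the EMA use case itself, here is a minimal sketch of what the per-step update could look like, assuming the EMA copy lives in fp32 on CPU on rank 0; the `decay` value and the `ema_state` dict are assumptions, not from the issue.

```python
import torch

def update_ema(model_engine, ema_state, decay=0.999):
    # Must be called on every rank; only rank 0 gets a real dict back.
    sd = zero3_consolidated_16bit_state_dict(model_engine)
    if sd is None:
        return ema_state  # non-zero ranks only participate in the gather

    for key, value in sd.items():
        value = value.float()  # keep the accumulator in fp32 on CPU
        if key not in ema_state:
            ema_state[key] = value.clone()
        else:
            ema_state[key].mul_(decay).add_(value, alpha=1.0 - decay)
    return ema_state
```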