exnx closed this issue 2 weeks ago.
Thanks for sharing this!
More specifically, I was wondering if there's a way to load the DeepSpeed checkpoint without needing to initialize DeepSpeed, i.e. without calling `deepspeed.init_distributed()`?
Sometimes I run into errors with DeepSpeed, such as MPI issues, and when I only want to load checkpoints it would be nice not to need all of that.
In particular, I get this error when calling `deepspeed.init_distributed()`:
```
[2024-08-09 10:26:24,175] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-08-09 10:26:24,175] [INFO] [comm.py:652:init_distributed] Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...
[GPUCB52:3951616] OPAL ERROR: Not initialized in file pmix3x_client.c at line 112
--------------------------------------------------------------------------
The application appears to have been direct launched using "srun",
but OMPI was not built with SLURM's PMI support and therefore cannot
execute. There are several options for building PMI support under
SLURM, depending upon the SLURM version you are using:
version 16.05 or later: you can use SLURM's PMIx support. This
requires that you configure and build SLURM --with-pmix.
Versions earlier than 16.05: you must use either SLURM's PMI-1 or
PMI-2 support. SLURM builds PMI-1 by default, or you can manually
install PMI-2. You must then build Open MPI using --with-pmi pointing
to the SLURM PMI library location.
Please configure as appropriate and try again.
--------------------------------------------------------------------------
*** An error occurred in MPI_Init_thread
*** on a NULL communicator
*** MPI_ERRORS_ARE_FATAL (processes in this communicator will now abort,
***    and potentially your MPI job)
[GPUCB52:3951616] Local abort before MPI_INIT completed completed successfully, but am not able to aggregate error messages, and not able to guarantee that all other processes were killed!
```
Also, what happens if you use ZeRO stage 1, or no ZeRO at all (i.e. ZeRO stage 0)?
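For the `deepspeed.init_distributed()` error above, one possible workaround is sketched here. It assumes that DeepSpeed only falls back to MPI discovery (the "attempting to detect MPI environment" log line) when the usual launcher variables `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT` are missing, so setting them for a single process should skip that step. `set_single_process_env` is a hypothetical helper, and port 29500 is an arbitrary choice.

```python
import os


def set_single_process_env(master_port: int = 29500) -> dict:
    """Hypothetical helper: export launcher-style env vars for a 1-process run.

    Assumption: with all of these present, DeepSpeed's init_distributed()
    skips MPI auto-discovery and uses the env:// style settings instead.
    Returns only the variables this call actually set.
    """
    defaults = {
        "RANK": "0",
        "LOCAL_RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(master_port),  # any free port should do
    }
    newly_set = {}
    for key, value in defaults.items():
        if key not in os.environ:  # never clobber values a real launcher exported
            os.environ[key] = value
            newly_set[key] = value
    return newly_set
```

Call `set_single_process_env()` before `deepspeed.init_distributed()`. This has not been checked against every DeepSpeed version or SLURM setup, so treat it as a starting point.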
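```python
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

checkpoint_dir = "path/to/deepspeed_checkpoint_dir"  # placeholder: the folder passed to save_checkpoint()
# `model` is assumed to be the same (non-DeepSpeed) nn.Module architecture used for training

# Consolidate the ZeRO shards into one FP32 state dict, dropping all the DeepSpeed-specific data
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)
model = model.cpu()  # move to CPU
model.load_state_dict(state_dict)

MODEL_PATH = "your_model_without_DS.pth"
# Save the model's state_dict
torch.save(model.state_dict(), MODEL_PATH)
# Load the model's state_dict (plain PyTorch from here on)
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
```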
My understanding is that `get_fp32_state_dict_from_zero_checkpoint` only works with ZeRO stage 2 or 3 (not stage 0 or 1)?
With ZeRO stage 0 or 1, the model parameters are not partitioned, which means the full FP32 state dictionary is readily accessible without additional reconstruction steps.
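Since no reconstruction is needed in that case, a minimal sketch of reading the weights with plain PyTorch might look like the following. It assumes DeepSpeed's usual on-disk layout: a `latest` file in the checkpoint folder naming the tag, and `<tag>/mp_rank_00_model_states.pt` holding the weights under a `"module"` key. Check this against your own checkpoint folder. `resolve_model_states_file` and `load_module_state_dict` are hypothetical helper names.

```python
from pathlib import Path


def resolve_model_states_file(save_dir, tag=None, mp_rank=0):
    """Locate the per-model-parallel-rank weights file (assumed layout).

    save_dir/latest                            -> text file holding the tag
    save_dir/<tag>/mp_rank_00_model_states.pt  -> model weights
    """
    save_dir = Path(save_dir)
    if tag is None:
        latest = save_dir / "latest"
        if not latest.is_file():
            raise FileNotFoundError(f"no 'latest' file in {save_dir}; pass tag explicitly")
        tag = latest.read_text().strip()
    path = save_dir / str(tag) / f"mp_rank_{mp_rank:02d}_model_states.pt"
    if not path.is_file():
        raise FileNotFoundError(path)
    return path


def load_module_state_dict(save_dir, tag=None):
    """Load the weights with torch.load, without initializing DeepSpeed."""
    import torch  # imported here so the path helper above has no torch dependency

    # weights_only=False (PyTorch >= 1.13) because DeepSpeed checkpoints also store
    # non-tensor metadata; only do this for checkpoints you trust.
    checkpoint = torch.load(
        resolve_model_states_file(save_dir, tag), map_location="cpu", weights_only=False
    )
    # Assumption: the module weights are stored under the "module" key.
    return checkpoint["module"]
```

`model.load_state_dict(load_module_state_dict("path/to/checkpoint_dir"))` would then take the place of the `get_fp32_state_dict_from_zero_checkpoint` call. These are the module's own weights, so with fp16/bf16 training they may be stored in half precision rather than FP32, and unpickling may still require DeepSpeed to be installed if the file references DeepSpeed classes.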
@ycool, thanks for providing a workaround.
Closing this issue.
Hello, is there a way to load a DeepSpeed checkpoint without DeepSpeed? I'd like to just load it with standard PyTorch. Or do we need to convert it first? DeepSpeed adds some overhead that I'd like to avoid when running zero-shot evals downstream.