lean-dojo / ReProver

Retrieval-Augmented Theorem Provers for Lean
https://leandojo.org
MIT License

DeepSpeed checkpoint error when using the released Hugging Face checkpoints #71

Closed: Sara-Rajaee closed this issue 1 month ago

Sara-Rajaee commented 2 months ago

Hi, I want to run the experiments in "Retrieving Premises for All Proof States" using your Hugging Face checkpoints (without training the retriever myself). However, when I use the provided command, I get a DeepSpeed checkpoint error. I would appreciate it if you could help me address this.

This is the command I use:

python retrieval/main.py predict --config retrieval/confs/cli_lean4_random.yaml --ckpt_path kaiyuy/leandojo-lean4-retriever-tacgen-byt5-small --trainer.logger.name predict_retriever_random --trainer.logger.save_dir logs/predict_retriever_random

This is the error I get:

Restoring states from the checkpoint path at kaiyuy/leandojo-lean4-retriever-tacgen-byt5-small
[2024-08-20 01:49:18,235] [WARNING] [engine.py:2796:load_checkpoint] Unable to find latest file at kaiyuy/leandojo-lean4-retriever-tacgen-byt5-small/latest, if trying to load latest checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint.
rank0: Traceback (most recent call last):
rank0:   File "/local/mnt/workspace/srajaee/wiai_trial/benchmarks/math_reasoning/ReProver/retrieval/main.py", line 25, in <module>
rank0:   File "/local/mnt/workspace/srajaee/wiai_trial/benchmarks/math_reasoning/ReProver/retrieval/main.py", line 20, in main
rank0:     cli = CLI(PremiseRetriever, RetrievalDataModule)
rank0:   File "/opt/mamba/envs/wiai/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 353, in __init__
rank0:   File "/opt/mamba/envs/wiai/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 642, in _run_subcommand
rank0:   File "/opt/mamba/envs/wiai/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 845, in predict
rank0:     return call._call_and_handle_interrupt(
rank0:   File "/opt/mamba/envs/wiai/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 41, in _call_and_handle_interrupt
rank0:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
rank0:   File "/opt/mamba/envs/wiai/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 91, in launch
rank0:     return function(*args, **kwargs)
rank0:   File "/opt/mamba/envs/wiai/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 887, in _predict_impl
rank0:     results = self._run(model, ckpt_path=ckpt_path)
rank0:   File "/opt/mamba/envs/wiai/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 962, in _run
rank0:   File "/opt/mamba/envs/wiai/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 395, in _restore_modules_and_callbacks
rank0:   File "/opt/mamba/envs/wiai/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 82, in resume_start
rank0:     loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
rank0:   File "/opt/mamba/envs/wiai/lib/python3.10/site-packages/pytorch_lightning/strategies/deepspeed.py", line 792, in load_checkpoint
rank0:     raise MisconfigurationException(
rank0: lightning_fabric.utilities.exceptions.MisconfigurationException: DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint or a single checkpoint file with Trainer(strategy=DeepSpeedStrategy(load_full_weights=True)).
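The WARNING line shows DeepSpeed looking for a file named latest inside the path given to --ckpt_path. A DeepSpeed checkpoint saved without an explicit tag is a local directory whose latest file names a tag subdirectory, while kaiyuy/leandojo-lean4-retriever-tacgen-byt5-small is a Hugging Face Hub repository id, not such a directory. As a rough pre-flight check, a hypothetical helper like looks_like_deepspeed_checkpoint below could flag the mismatch before a run is launched. It is a heuristic sketch assuming that directory layout; it is not part of ReProver, Lightning, or DeepSpeed.

```python
from pathlib import Path


def looks_like_deepspeed_checkpoint(ckpt_path: str) -> bool:
    """Hypothetical pre-flight check (not part of ReProver, Lightning, or DeepSpeed).

    Returns True only if ckpt_path is a local directory containing a `latest`
    file whose contents name an existing tag subdirectory, which is the layout
    DeepSpeed looks for when no explicit checkpoint tag is given.
    """
    path = Path(ckpt_path)
    latest = path / "latest"
    if not (path.is_dir() and latest.is_file()):
        # e.g. a Hugging Face Hub id like "owner/model-name" that is not a local dir
        return False
    tag = latest.read_text().strip()
    # An empty tag would make path / tag == path, so reject it explicitly.
    return bool(tag) and (path / tag).is_dir()


if __name__ == "__main__":
    hub_id = "kaiyuy/leandojo-lean4-retriever-tacgen-byt5-small"
    # False unless a local directory with this relative path and layout exists.
    print(looks_like_deepspeed_checkpoint(hub_id))
```

The check only inspects the local filesystem, so it says nothing about whether the tag directory's contents are complete; it just catches the case in this issue, where a Hub id is passed where a DeepSpeed checkpoint directory is expected.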

yangky11 commented 2 months ago

Here, --ckpt_path expects a DeepSpeed checkpoint rather than a Hugging Face checkpoint. One option is to train the model yourself (python retrieval/main.py fit ...). Another workaround is to convert the Hugging Face checkpoint to a DeepSpeed checkpoint by running the training script for 1 step with a learning rate of 0 (using something like retrieval/confs/cli_dummy.yaml).
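The zero-learning-rate workaround relies on the fact that, in typical optimizers, every change to the weights is scaled by the learning rate. A single step at lr = 0 therefore leaves the weights exactly as loaded from Hugging Face, while the training run still writes a checkpoint in DeepSpeed's format. The sketch below illustrates this with one AdamW-style update on plain Python floats. It assumes an AdamW-style optimizer for illustration only; the function adamw_step is hypothetical and is not ReProver's training code.

```python
import math


def adamw_step(params, grads, m, v, step, lr,
               beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """Apply one AdamW update to scalar parameters stored in plain lists.

    m and v (first/second moment estimates) are updated in place; the new
    parameter values are returned as a fresh list.
    """
    new_params = []
    for i, (p, g) in enumerate(zip(params, grads)):
        m[i] = beta1 * m[i] + (1 - beta1) * g
        v[i] = beta2 * v[i] + (1 - beta2) * g * g
        m_hat = m[i] / (1 - beta1 ** step)  # bias correction; step starts at 1
        v_hat = v[i] / (1 - beta2 ** step)
        p = p * (1 - lr * weight_decay)  # decoupled weight decay, scaled by lr
        p = p - lr * m_hat / (math.sqrt(v_hat) + eps)  # Adam step, scaled by lr
        new_params.append(p)
    return new_params


params = [0.5, -1.25, 2.0]
grads = [0.1, -0.3, 0.7]
m, v = [0.0] * 3, [0.0] * 3
frozen = adamw_step(params, grads, m, v, step=1, lr=0.0)
print(frozen == params, m != [0.0] * 3)  # True True
```

Note that the moment estimates m and v still change, so the converted checkpoint carries one step's worth of optimizer state even though the weights are untouched. Buffers that are updated in the forward pass rather than by the optimizer (for example BatchNorm running statistics) are not scaled by the learning rate and would still change in models that have them.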

Sara-Rajaee commented 1 month ago

Thanks for your answer. I'll close the issue.