lean-dojo / ReProver

Retrieval-Augmented Theorem Provers for Lean
https://leandojo.org
MIT License

DeepSpeed checkpoints for the retriever? #15

Closed: bollu closed this issue 1 year ago

bollu commented 1 year ago

Thanks a lot for the superb code!

I want to rerun the predict script. To do this, I downloaded the Hugging Face model released for the retriever and then ran

$ python3 -m retrieval.main predict --config retrieval/confs/cli_random.yaml --ckpt_path ./data/model-retriever 

where the directory model-retriever contains the Hugging Face model:

t-sibhat@hostname:~/projects/premise-selection/premise-selection/src/Projects/premise-selection/ReProver$ ls -R data/model-retriever/
data/model-retriever/:
latest

data/model-retriever/latest:
README.md  config.json  generation_config.json  pytorch_model.bin  special_tokens_map.json  tokenizer_config.json

The above command fails with the following traceback:

$ python3 -m retrieval.main predict --config retrieval/confs/cli_random.yaml --ckpt_path ./data/model-retriever 
2023-07-05 22:30:53.243 | INFO     | common:__init__:198 - Building the corpus from data/leandojo_benchmark/corpus.jsonl
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[rank: 0] Global seed set to 3407
initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/1
[2023-07-05 22:31:00,708] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented
[2023-07-05 22:31:00,899] [WARNING] [deepspeed.py:638:_auto_select_batch_size] Tried to infer the batch size for internal deepspeed logging from the `train_dataloader()`. To ensure DeepSpeed logging remains correct, please manually pass the plugin with the batch size, `Trainer(strategy=DeepSpeedStrategy(logging_batch_size_per_gpu=batch_size))`.
Enabling DeepSpeed BF16.
2023-07-05 22:31:01.811 | INFO     | retrieval.datamodule:_load_data:49 - Loading data from data/leandojo_benchmark/random/train.json
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 92962/92962 [00:13<00:00, 6929.21it/s]
2023-07-05 22:31:24.555 | INFO     | retrieval.datamodule:_load_data:99 - Loaded 204018 examples.
2023-07-05 22:31:24.571 | INFO     | retrieval.datamodule:_load_data:49 - Loading data from data/leandojo_benchmark/random/val.json
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1191.80it/s]
2023-07-05 22:31:26.304 | INFO     | retrieval.datamodule:_load_data:99 - Loaded 4253 examples.
2023-07-05 22:31:26.304 | INFO     | retrieval.datamodule:_load_data:49 - Loading data from data/leandojo_benchmark/random/test.json
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:00<00:00, 7514.59it/s]
2023-07-05 22:31:26.625 | INFO     | retrieval.datamodule:_load_data:99 - Loaded 4516 examples.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Restoring states from the checkpoint path at ./data/model-retriever/
[2023-07-05 22:31:29,833] [WARNING] [engine.py:2594:load_checkpoint] Unable to find latest file at ./data/model-retriever/latest, if trying to load latest checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint.
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/t-sibhat/projects/premise-selection/premise-selection/src/Projects/premise-selection/ReProver/retrieval/main.py", line 24, in <module>
    main()
  File "/home/t-sibhat/projects/premise-selection/premise-selection/src/Projects/premise-selection/ReProver/retrieval/main.py", line 19, in main
    cli = CLI(PremiseRetriever, RetrievalDataModule)
  File "/home/t-sibhat/.local/lib/python3.8/site-packages/pytorch_lightning/cli.py", line 353, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/t-sibhat/.local/lib/python3.8/site-packages/pytorch_lightning/cli.py", line 642, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/t-sibhat/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 845, in predict
    return call._call_and_handle_interrupt(
  File "/home/t-sibhat/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 41, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/t-sibhat/.local/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 91, in launch
    return function(*args, **kwargs)
  File "/home/t-sibhat/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 887, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "/home/t-sibhat/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 962, in _run
    self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
  File "/home/t-sibhat/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 395, in _restore_modules_and_callbacks
    self.resume_start(checkpoint_path)
  File "/home/t-sibhat/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 82, in resume_start
    loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
  File "/home/t-sibhat/.local/lib/python3.8/site-packages/pytorch_lightning/strategies/deepspeed.py", line 792, in load_checkpoint
    raise MisconfigurationException(
lightning_fabric.utilities.exceptions.MisconfigurationException: DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint or a single checkpoint file with `Trainer(strategy=DeepSpeedStrategy(load_full_weights=True))`.

I gather that I am using the Hugging Face models incorrectly and that they need to be made compatible with DeepSpeed somehow. Could you help me run the retrieval script with the released models?

yangky11 commented 1 year ago

Hi,

Those are Hugging Face Transformers checkpoints, which are meant to be used in your own pipeline.

To run our pipeline, you can use our PyTorch Lightning checkpoints here
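
For readers hitting the same error, here is a minimal sketch (not part of the ReProver codebase) of telling the two checkpoint layouts apart before passing a path to --ckpt_path. It rests on two assumptions taken from this thread: a Hugging Face export is a directory holding config.json plus a weights file, as in the listing above, and a DeepSpeed checkpoint directory holds a plain file named `latest`, as the warning in the log suggests. The function name `describe_checkpoint` and the labels it returns are hypothetical, and the check is a file-name heuristic only.

```python
from pathlib import Path

# Weight file names commonly found in a Hugging Face export (assumption;
# the listing in this thread shows pytorch_model.bin).
HF_WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors")


def describe_checkpoint(path):
    """Guess a checkpoint layout from file names alone (heuristic)."""
    p = Path(path)
    if p.is_file():
        # A single checkpoint file, e.g. a Lightning .ckpt file.
        return "single-file"
    if not p.is_dir():
        return "missing"
    # The DeepSpeed warning above looks for a *file* named `latest`.
    if (p / "latest").is_file():
        return "deepspeed-dir"
    has_config = (p / "config.json").is_file()
    has_weights = any((p / name).is_file() for name in HF_WEIGHT_FILES)
    if has_config and has_weights:
        return "huggingface-dir"
    return "unknown"


if __name__ == "__main__":
    for candidate in ("data/model-retriever", "data/model-retriever/latest"):
        print(candidate, "->", describe_checkpoint(candidate))
```

With the layout shown in the question, `latest` is a directory rather than a file, so the first path would be reported as "unknown" and the second as "huggingface-dir". That matches the failure above: neither is something the DeepSpeed strategy can restore.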

bollu commented 1 year ago

Super, thanks a lot!