lean-dojo / ReProver

Retrieval-Augmented Theorem Provers for Lean
https://leandojo.org
MIT License

RuntimeError: CUDA error: unspecified launch failure #68

Closed · yifan123 closed 2 months ago

yifan123 commented 3 months ago

Hi, thanks for your work. I want to reproduce the training process of the Premise Retriever. Training completes without issues, but the run crashes during evaluation (the validation loop at the end of epoch 0). I followed the installation instructions in the README, and the error looks like an out-of-bounds GPU memory access.

My script:

python retrieval/main.py fit --config retrieval/confs/cli_lean4_random.yaml --trainer.logger.name train_retriever_random --trainer.logger.save_dir logs/train_retriever_random

Env:

lean-dojo 2.0.3
torch 2.3.0
deepspeed 0.14.5
ReProver: newest
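
For reference, the GPU model and CUDA build aren't listed above; a quick way to capture them (these are standard PyTorch and nvidia-smi calls, nothing specific to ReProver) would be:

```bash
# Report the PyTorch version, the CUDA toolkit PyTorch was built against,
# and the GPU model; then the GPU name and driver version via nvidia-smi.
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.get_device_name(0))"
nvidia-smi --query-gpu=name,driver_version --format=csv
```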

Outputs:

wandb: logging graph, to disable use `wandb.watch(log_graph=False)`
2024-08-09 10:47:44.360 | INFO     | retrieval.model:on_fit_start:150 - Logging to logs/train_retriever_random
SLURM auto-requeueing enabled. Setting signal handlers.
Time to load fused_adam op: 40.49061441421509 seconds
2024-08-09 10:47:44.667 | INFO     | retrieval.model:reindex_corpus:189 - Re-indexing the retrieval corpus
100%|██████████| 2828/2828 [11:37<00:00,  4.05it/s]
Epoch 0: 100%|██████████| 45966/45966 [5:46:06<00:00,  2.21it/s, v_num=hg6u]
2024-08-09 16:45:35.598 | INFO     | retrieval.model:reindex_corpus:189 - Re-indexing the retrieval corpus
100%|██████████| 2828/2828 [14:25<00:00,  3.27it/s]
Traceback (most recent call last):      | 9/67 [00:29<03:08,  0.31it/s]
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_stage
    self.fit_loop.run()
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 141, in run
    self.on_advance_end(data_fetcher)
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 295, in on_advance_end
    self.val_loop.run()
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 410, in validation_step
    return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 640, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/research/ai4math/DeepSpeed/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
    ret_val = func(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^
 File "/mnt/petrelfs/liujie.dispatch/research/ai4math/DeepSpeed/deepspeed/runtime/engine.py", line 1846, in forward
    loss = self.module(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 633, in wrapped_forward
    out = method(*_args, **_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/research/ai4math/ReProver/retrieval/model.py", line 221, in validation_step
    context_emb = self._encode(batch["context_ids"], batch["context_mask"])
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/research/ai4math/ReProver/common.py", line 320, in get_nearest_premises
    scores[j].append(similarities[j, i].item())
                     ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/petrelfs/liujie.dispatch/research/ai4math/ReProver/retrieval/main.py", line 25, in <module>
    main()
  File "/mnt/petrelfs/liujie.dispatch/research/ai4math/ReProver/retrieval/main.py", line 20, in main
    cli = CLI(PremiseRetriever, RetrievalDataModule)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/cli.py", line 394, in __init__
    self._run_subcommand(self.subcommand)
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/cli.py", line 701, in _run_subcommand
    fn(**fn_kwargs)
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 60, in _call_and_handle_interrupt
    trainer._teardown()
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1009, in _teardown
    self.strategy.teardown()
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/strategies/ddp.py", line 419, in teardown
    super().teardown()
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/strategies/parallel.py", line 133, in teardown
    super().teardown()
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 531, in teardown
    _optimizers_to_device(self.optimizers, torch.device("cpu"))
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/lightning_fabric/utilities/optimizer.py", line 28, in _optimizers_to_device
    _optimizer_to_device(opt, device)
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/lightning_fabric/utilities/optimizer.py", line 34, in _optimizer_to_device
    optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 52, in apply_to_collection
    return _apply_to_collection_slow(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 104, in _apply_to_collection_slow
    v = _apply_to_collection_slow(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 96, in _apply_to_collection_slow
    return function(data, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/lightning_fabric/utilities/apply_func.py", line 103, in move_data_to_device
    return apply_to_collection(batch, dtype=_TransferableDataType, function=batch_to)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 64, in apply_to_collection
    return function(data, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/liujie.dispatch/miniconda3/envs/ReProver/lib/python3.11/site-packages/lightning_fabric/utilities/apply_func.py", line 97, in batch_to
    data_output = data.to(device, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
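
As the error message itself suggests, a first debugging step is to re-run with synchronous CUDA kernel launches so the Python stack trace points at the call that actually fails. A minimal sketch, re-using the exact command from above (the environment variable is a standard PyTorch/CUDA debugging switch and only improves the diagnostics; it is not expected to fix the crash by itself):

```bash
# Same training command as above, but with synchronous kernel launches so the
# "unspecified launch failure" is reported at the offending call rather than
# at a later, unrelated API call.
CUDA_LAUNCH_BLOCKING=1 \
python retrieval/main.py fit \
    --config retrieval/confs/cli_lean4_random.yaml \
    --trainer.logger.name train_retriever_random \
    --trainer.logger.save_dir logs/train_retriever_random
```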

yangky11 commented 3 months ago

Unfortunately, I wasn't able to reproduce the problem. This kind of error can be caused by a particular combination of GPU model and PyTorch/HuggingFace/DeepSpeed/CUDA versions. I'd suggest retrying in a fresh conda environment with the latest versions of all dependencies installed.
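
For concreteness, a rough sketch of that retry; the environment name and package list here are illustrative (pytorch-lightning and transformers are inferred from the traceback and the HuggingFace mention), and the actual installation steps in the ReProver README take precedence:

```bash
# Create a clean conda environment and install the latest released versions
# of the dependencies named in this thread. Adjust to match the README.
conda create -y -n reprover-fresh python=3.11
conda activate reprover-fresh
pip install --upgrade torch pytorch-lightning deepspeed transformers lean-dojo
# ...then reinstall ReProver's own requirements per its README and re-run training.
```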