datamllab / rlcard

Reinforcement Learning / AI Bots in Card (Poker) Games - Blackjack, Leduc, Texas, DouDizhu, Mahjong, UNO.
http://www.rlcard.org
MIT License
2.78k stars, 615 forks

Fix RuntimeError when running example from docs\toy-examples.md #284

Closed Clarivy closed 1 year ago

Clarivy commented 1 year ago

Description:

When running the following example from docs\toy-examples.md,

python3 examples/run_rl.py --env blackjack --algorithm dqn --log_dir experiments/blackjack_dqn_result/

I encountered a RuntimeError:

Traceback (most recent call last):
  File "xx\rlcard\examples\run_rl.py", line 181, in <module>
    train(args)
  File "xx\rlcard\examples\run_rl.py", line 86, in train
    agent.feed(ts)
  File "xx\rlcard\agents\dqn_agent.py", line 139, in feed
    self.train()
  File "xx\rlcard\agents\dqn_agent.py", line 235, in train
    self.save_checkpoint(self.save_path)
  File "xx\rlcard\agents\dqn_agent.py", line 322, in save_checkpoint
    torch.save(self.checkpoint_attributes(), path + '/' + filename)
  File "xx\torch\serialization.py", line 440, in save
    with _open_zipfile_writer(f) as opened_zipfile:
  File "xx\torch\serialization.py", line 315, in _open_zipfile_writer
    return container(name_or_buffer)
  File "xx\torch\serialization.py", line 288, in __init__
    super().__init__(torch._C.PyTorchFileWriter(str(name)))
RuntimeError: File experiments/blackjack_dqn_result//checkpoint_dqn.pt cannot be opened.

The issue seems to come from line 322 of agents\dqn_agent.py, where save_checkpoint builds the checkpoint path by concatenating path, '/', and filename. Because --log_dir already ends with a slash, the resulting path contains a double slash (experiments/blackjack_dqn_result//checkpoint_dqn.pt), and a hard-coded '/' separator is not guaranteed to work on every OS. A possible solution is to use os.path.join instead of string concatenation.

Steps to reproduce:

  1. Clone the repository and navigate to the project directory.
  2. Run the following command:
    python3 examples/run_rl.py --env blackjack --algorithm dqn --log_dir experiments/blackjack_dqn_result/

Expected result:

The example runs successfully without any RuntimeError.

Actual result:

A RuntimeError occurs when the script tries to save the checkpoint file.

Possible solution:

Following this Stack Overflow answer (https://stackoverflow.com/a/2953843/17576960), replace the string concatenation on line 322 in agents\dqn_agent.py with os.path.join, like so:

torch.save(self.checkpoint_attributes(), os.path.join(path, filename))

This should resolve the issue and make the path compatible with different operating systems.
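To illustrate the difference, here is a minimal standalone sketch comparing the two ways of building the path. It uses only the standard library. The helper name build_checkpoint_path is hypothetical and not part of rlcard; the directory and file name are taken from the traceback above.

```python
import os


def build_checkpoint_path(path, filename="checkpoint_dqn.pt"):
    """Hypothetical helper (not in rlcard): join directory and file name."""
    # os.path.join only inserts a separator when `path` does not already
    # end with one, so a trailing slash in --log_dir does not get doubled.
    return os.path.join(path, filename)


# Same value as the --log_dir argument in the failing command (trailing slash).
log_dir = "experiments/blackjack_dqn_result/"

# Current behavior: manual concatenation always adds another '/'.
concatenated = log_dir + "/" + "checkpoint_dqn.pt"

# Proposed behavior: let os.path.join decide whether a separator is needed.
joined = build_checkpoint_path(log_dir)

print(concatenated)
print(joined)
```

With the trailing slash in --log_dir, the concatenated form contains "//" while the joined form does not. Whether the double slash alone triggers the RuntimeError may depend on the platform and PyTorch version, so this is a sketch of the path handling only, not a confirmed root-cause analysis.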

daochenzha commented 1 year ago

@Clarivy Thank you for the fix!