pfnet / pfrl

PFRL: a PyTorch-based deep reinforcement learning library
MIT License

Fixed save/load problem on dqn.py #184

Closed jmribeiro closed 1 year ago

jmribeiro commented 1 year ago

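Saving and loading the DQN agent did not save/load four needed attributes (the ones the snapshot methods below handle explicitly: t, optim_t, _cumulative_steps, and replay_buffer).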

As a result, the agent performed differently when trained and evaluated in a single uninterrupted run than when it was saved, the program was killed and restarted, and the agent was loaded from disk.

Fig 1 - Training without checkpoints (i.e., the same program run from start to finish) [plot]

Fig 2 - Training with checkpoints (i.e., program killed every t steps and agents loaded from disk) [plot]
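One way a lost step counter can change behaviour is through any schedule that depends on it. The sketch below is only an illustration under assumptions: it assumes an epsilon-greedy exploration rate that decays linearly with the step count, and the function name linear_decay_epsilon and all parameter values are hypothetical, not taken from this PR.

```python
def linear_decay_epsilon(t, start=1.0, end=0.1, decay_steps=10_000):
    """Exploration rate after t agent steps (linear decay, then constant).

    Hypothetical schedule and values, used only for illustration.
    """
    if t >= decay_steps:
        return end
    return start + (end - start) * t / decay_steps


# Uninterrupted run: the step counter keeps its value.
t = 8_000
eps_continuous = linear_decay_epsilon(t)

# Checkpointed run where t is not persisted: after loading,
# the counter starts again from 0 and exploration jumps back up.
t_after_reload = 0
eps_reloaded = linear_decay_epsilon(t_after_reload)

print(eps_continuous, eps_reloaded)
```

Under these assumptions the reloaded agent explores far more than the uninterrupted one at the same point in training, which would be consistent with the gap between the two figures.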

My proposed solution (working, but applied only to the DQN agent) adds new save_snapshot and load_snapshot methods to the agent's class. The original save and load methods are left unchanged, so the replay buffer is not written to disk on every regular save:

    # Methods added to the DQN agent class; the module needs `import os` and `import torch`.
    def save_snapshot(self, dirname: str) -> None:
        # First store whatever the regular save() already stores.
        self.save(dirname)
        # Then store the counters and the replay buffer that save() leaves out.
        torch.save(self.t, os.path.join(dirname, "t.pt"))
        torch.save(self.optim_t, os.path.join(dirname, "optim_t.pt"))
        torch.save(
            self._cumulative_steps, os.path.join(dirname, "_cumulative_steps.pt")
        )
        self.replay_buffer.save(os.path.join(dirname, "replay_buffer.pkl"))

    def load_snapshot(self, dirname: str) -> None:
        # Restore whatever the regular load() restores.
        self.load(dirname)
        # Then restore the counters and the replay buffer written by save_snapshot().
        self.t = torch.load(os.path.join(dirname, "t.pt"))
        self.optim_t = torch.load(os.path.join(dirname, "optim_t.pt"))
        self._cumulative_steps = torch.load(
            os.path.join(dirname, "_cumulative_steps.pt")
        )
        self.replay_buffer.load(os.path.join(dirname, "replay_buffer.pkl"))

This change works as intended: training resumes properly after the agent is reloaded from disk.

Fig 3 - Training with checkpoints, new patch (i.e., program killed every t steps and agents loaded from disk) [plot]
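The difference between a plain load and a snapshot load can be checked with a small round trip. The following is a minimal sketch, not the PFRL code: ToyAgent is a hypothetical stand-in class, JSON files replace torch.save and the replay buffer's own save method, and the file names are made up for the example.

```python
import json
import os
import tempfile


class ToyAgent:
    """Hypothetical stand-in for an agent; not PFRL's DQN class."""

    def __init__(self):
        self.weights = [0.0, 0.0]  # stands in for model parameters
        self.t = 0                 # step counter
        self.optim_t = 0           # update counter
        self.replay_buffer = []    # stands in for the replay buffer

    def step(self, transition):
        """Record one transition and update the weights every second step."""
        self.t += 1
        self.replay_buffer.append(transition)
        if self.t % 2 == 0:
            self.optim_t += 1
            self.weights[0] += 0.1

    def save(self, dirname):
        """Store the weights only, like a save() that omits training state."""
        os.makedirs(dirname, exist_ok=True)
        with open(os.path.join(dirname, "weights.json"), "w") as f:
            json.dump(self.weights, f)

    def load(self, dirname):
        with open(os.path.join(dirname, "weights.json")) as f:
            self.weights = json.load(f)

    def save_snapshot(self, dirname):
        """Store the weights plus the counters and the replay buffer."""
        self.save(dirname)
        state = {
            "t": self.t,
            "optim_t": self.optim_t,
            "replay_buffer": self.replay_buffer,
        }
        with open(os.path.join(dirname, "state.json"), "w") as f:
            json.dump(state, f)

    def load_snapshot(self, dirname):
        self.load(dirname)
        with open(os.path.join(dirname, "state.json")) as f:
            state = json.load(f)
        self.t = state["t"]
        self.optim_t = state["optim_t"]
        self.replay_buffer = state["replay_buffer"]


agent = ToyAgent()
for i in range(5):
    agent.step([i, 0, 1.0])  # [observation, action, reward]

with tempfile.TemporaryDirectory() as d:
    agent.save_snapshot(d)

    plain = ToyAgent()
    plain.load(d)            # weights restored, counters and buffer lost

    resumed = ToyAgent()
    resumed.load_snapshot(d)  # full training state restored

print(plain.t, len(plain.replay_buffer))
print(resumed.t, resumed.optim_t, len(resumed.replay_buffer))
```

In this toy setup, the agent restored with load starts again from step 0 with an empty buffer, while the one restored with load_snapshot continues from the state it had when it was saved.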

muupan commented 1 year ago

/test

pfn-ci-bot commented 1 year ago

Successfully created a job for commit 8fc26f4:

jmribeiro commented 1 year ago

/test

@muupan There seems to be a problem with the tests in test_acer.py (unrelated to these changes).
