datamllab / rlcard

Reinforcement Learning / AI Bots in Card (Poker) Games - Blackjack, Leduc, Texas, DouDizhu, Mahjong, UNO.
http://www.rlcard.org
MIT License

Add storing and restoring RL agent checkpoints #280

Closed · kaiks closed this 1 year ago

kaiks commented 1 year ago

This PR introduces checkpoints for RL agents (DQN and NFSP).

A checkpoint is data describing the complete agent state (network weights and training parameters) at a point during training.

To save an agent, you can either set a checkpoint path so the agent is saved automatically every n training steps, or save it manually (an example follows further below). For the automatic route, specifying the save_path and save_every parameters during agent instantiation:

agent = DQNAgent(
    num_actions=env.num_actions,
    state_shape=env.state_shape[0],
    mlp_layers=[64, 64],
    device=device,
    save_path=args.log_dir,
    save_every=500,
)

will save the training progress every 500 steps to a single file.

A complete example of loading an agent looks as follows:

import torch
from rlcard.agents import DQNAgent

# Restore the agent from a saved checkpoint file
checkpoint = torch.load('experiments/dqn_checkpoint_v2/checkpoint.pt')
agent = DQNAgent.from_checkpoint(checkpoint=checkpoint)

Training can then be resumed, or the agent's attributes can be inspected and debugged.
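
Since a checkpoint is a plain dict, it can also be inspected directly before restoring. A minimal sketch (only the agent_type key is confirmed by this PR; everything else in the dict depends on the agent implementation):

import torch

checkpoint = torch.load('experiments/dqn_checkpoint_v2/checkpoint.pt')
print(checkpoint['agent_type'])    # e.g. 'DQNAgent'
print(sorted(checkpoint.keys()))   # remaining layout depends on the agent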

Use cases for checkpoints that I can see include resuming interrupted training, inspecting and debugging agent state, and keeping the best agent found so far. For instance, you can save the highest-reward agent on every evaluation step when using the run_rl.py script by manually saving the agent state:

if current_reward > best_reward:
    best_reward = current_reward
    agent.save_checkpoint(path, filename='best_checkpoint.pt')
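
For context, a sketch of how that check could sit in the evaluation loop of run_rl.py (the loop shape, best_reward, and the args fields here are assumptions, not the script's actual code; tournament is rlcard's evaluation helper):

from rlcard.utils import tournament

best_reward = float('-inf')
for episode in range(args.num_episodes):
    # ... training step ...
    if episode % args.evaluate_every == 0:
        current_reward = tournament(env, args.num_eval_games)[0]
        if current_reward > best_reward:
            best_reward = current_reward
            agent.save_checkpoint(args.log_dir, filename='best_checkpoint.pt')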

As a next step, it would be possible to streamline loading the models, for instance by extracting and modifying the load_model function found in evaluate.py:

import os

from rlcard.agents import DQNAgent, NFSPAgent

def load_agent_from_checkpoint(checkpoint_dict):
    if checkpoint_dict["agent_type"] == "NFSPAgent":
        return NFSPAgent.from_checkpoint(checkpoint_dict)
    elif checkpoint_dict["agent_type"] == "DQNAgent":
        return DQNAgent.from_checkpoint(checkpoint_dict)
    else:
        raise ValueError("Unknown agent type: {}".format(checkpoint_dict["agent_type"]))

def load_model(model_path, env=None, position=None, device=None):
    if os.path.isfile(model_path):  # Torch model
        import torch
        agent = torch.load(model_path, map_location=device)
        if "agent_type" in agent:  # a checkpoint dict rather than a pickled agent
            agent = load_agent_from_checkpoint(agent)
        agent.set_device(device)
    elif os.path.isdir(model_path):  # CFR model
        from rlcard.agents import CFRAgent
        agent = CFRAgent(env, model_path)
        agent.load()
    elif model_path == 'random':  # Random model
        from rlcard.agents import RandomAgent
        agent = RandomAgent(num_actions=env.num_actions)
    else:  # A model in the model zoo
        from rlcard import models
        agent = models.load(model_path).agents[position]
    return agent

But this requires more thought and would blow the PR up a bit.

In principle there's nothing stopping us from adding similar serialization for every other agent type and generalizing the loading of saved agent checkpoints.
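
A sketch of what that generalization could look like (not part of this PR; the registry dict is hypothetical, and only DQN and NFSP currently implement from_checkpoint):

from rlcard.agents import DQNAgent, NFSPAgent

# Hypothetical registry keyed on the checkpoint's agent_type field; any agent
# that implements save_checkpoint / from_checkpoint could be registered here.
AGENT_CLASSES = {
    'DQNAgent': DQNAgent,
    'NFSPAgent': NFSPAgent,
}

def load_agent_from_checkpoint(checkpoint_dict):
    agent_type = checkpoint_dict['agent_type']
    if agent_type not in AGENT_CLASSES:
        raise ValueError('Unknown agent type: {}'.format(agent_type))
    return AGENT_CLASSES[agent_type].from_checkpoint(checkpoint_dict)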

daochenzha commented 1 year ago

@kaiks Thank you for the contribution! I have carefully reviewed the PR. It looks great. I just added some minor comments. Your proposal of modifying evaluate.py also makes a lot of sense. Please consider submitting another PR for it as well. Have a great day!

kaiks commented 1 year ago

Hi @daochenzha, thank you for your feedback! I'm glad to hear the change looks useful. I'm not seeing the comments you mentioned. Maybe you forgot to publish the review?

I'll try to submit a follow-up PR with the discussed changes later, closer to the end of this week or next.

daochenzha commented 1 year ago

@kaiks Yes, I forgot to publish it. You should be able to see it now.

kaiks commented 1 year ago

@daochenzha thanks for the review. I addressed your feedback.

daochenzha commented 1 year ago

@kaiks LGTM, thank you!