google-deepmind / open_spiel

OpenSpiel is a collection of environments and algorithms for research in general reinforcement learning and search/planning in games.

Q-learning is a loser? #1180

Closed · StepHaze closed this issue 7 months ago

StepHaze commented 9 months ago

Hello, I'm trying to create a strong Mancala bot. I chose Q-learning:

```python
# Let's do independent Q-learning in Mancala, and play it against random.
# RL is based on python/examples/independent_tabular_qlearning.py.
from open_spiel.python import rl_environment
from open_spiel.python import rl_tools
from open_spiel.python.algorithms import tabular_qlearner

# Create the environment.
env = rl_environment.Environment("mancala")
num_players = env.num_players
num_actions = env.action_spec()["num_actions"]

# Create the agents.
agents = [
    tabular_qlearner.QLearner(player_id=idx, num_actions=num_actions)
    for idx in range(num_players)
]

# Train the Q-learning agents in self-play.
for cur_episode in range(100000):
  if cur_episode % 1000 == 0:
    print(f"Episodes: {cur_episode}")
  time_step = env.reset()
  while not time_step.last():
    player_id = time_step.observations["current_player"]
    agent_output = agents[player_id].step(time_step)
    time_step = env.step([agent_output.action])
  # Episode is over, step all agents with final info state.
  for agent in agents:
    agent.step(time_step)
```

And then a game against a random agent:

```python
# Evaluate the Q-learning agent against a random agent.
from open_spiel.python.algorithms import random_agent

eval_agents = [
    agents[0],
    random_agent.RandomAgent(1, num_actions, "Entropy Master 2000")
]

time_step = env.reset()
while not time_step.last():
  print("")
  print(env.get_state)
  player_id = time_step.observations["current_player"]
  # Note the evaluation flag. A Q-learner will set epsilon=0 here.
  agent_output = eval_agents[player_id].step(time_step, is_evaluation=True)
  print(f"Agent {player_id} chooses {env.get_state.action_to_string(agent_output.action)}")
  time_step = env.step([agent_output.action])

print("")
print(env.get_state)
print(time_step.rewards)
```

What really surprised me is that a TRAINED agent loses to a random agent. How come? Could anyone explain this to me, please?

lanctot commented 9 months ago

Hi @StepHaze,

It's probably due to the size of Mancala? (I would guess something similar would happen with chess.) The state space is too large, so most states are probably only visited once (or a very small number of times). Tabular methods will only work on small games (e.g. Tic-Tac-Toe has ~4500 states total).

Try self-play DQN instead? It's basically Q-learning with function approximation and a few extra things (experience replay, target network, etc.)
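
One quick way to confirm the coverage problem is to count how many distinct information states the tabular agent actually stored during training. This is only a rough diagnostic sketch: it assumes the Q-table lives in the `QLearner`'s private `_q_values` dict (an implementation detail of `tabular_qlearner.py`), so treat the attribute name as an assumption.

```python
# After training: how many distinct Mancala states did player 0's Q-table see?
# Assumes the tabular QLearner stores its table in the private `_q_values`
# attribute (implementation detail of tabular_qlearner.py).
num_states_seen = len(agents[0]._q_values)
print(f"Distinct states in the Q-table: {num_states_seen}")
```

If that number is tiny compared to the number of reachable Mancala states, the agent is effectively acting randomly in most positions it meets at evaluation time.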

DmitriMedved commented 9 months ago

Hi @StepHaze,

> Try self-play DQN instead? It's basically Q-learning with function approximation and a few extra things (experience replay, target network, etc.)

I'm interested in it too. Can you please provide us with an example for DQN?

lanctot commented 9 months ago

Sure. In OpenSpiel all the agents follow the same API (they are subclasses of the agent base class). So, 99% of the code above can be re-used; you just have to change the instantiation of the agent to DQN instead of `tabular_qlearner.QLearner`.

There is a full example here.

lanctot commented 9 months ago

Note the example above is a bit old and still uses the TF1-based DQN.

It can be easily swapped for the PyTorch DQN or JAX DQN if you prefer.
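
For example, with the PyTorch version the swap looks roughly like this. This is a sketch only: it assumes the PyTorch `dqn.DQN` constructor takes the same keyword arguments as the TF1 version (minus the session), and the hyperparameter values below are placeholders, not tuned settings.

```python
# Same training/evaluation loops as above; only the agent construction changes.
from open_spiel.python.pytorch import dqn  # PyTorch-based DQN

info_state_size = env.observation_spec()["info_state"][0]

agents = [
    dqn.DQN(
        player_id=idx,
        state_representation_size=info_state_size,  # assumed keyword name
        num_actions=num_actions,
        hidden_layers_sizes=[64, 64],     # placeholder network size
        replay_buffer_capacity=100000,    # placeholder capacity
        batch_size=32)                    # placeholder batch size
    for idx in range(num_players)
]
```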

StepHaze commented 9 months ago

Hi @lanctot, thank you Marc. I always thought that a trained neural network (the weights) should be stored in an external file and loaded every time the bot starts. Am I correct?

lanctot commented 9 months ago

Sounds like you are talking about checkpoints. The DQN supports saving and loading the networks themselves, but not the replay buffer. But it's certainly not enabled by default. It's something you have to do manually in your training script (check the save + load methods in dqn.py). Or if you use JAX/PyTorch you can just serialize the agent via pickle and save/load that (which would then include the replay buffer).
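
For the pickle route, a minimal sketch (the file name is just an example):

```python
import pickle

# After training: serialize the agents (networks plus replay buffers).
with open("mancala_dqn_agents.pkl", "wb") as f:
  pickle.dump(agents, f)

# When the bot starts: restore the trained agents.
with open("mancala_dqn_agents.pkl", "rb") as f:
  agents = pickle.load(f)
```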