jonathan-laurent / AlphaZero.jl

A generic, simple and fast implementation of DeepMind's AlphaZero algorithm.
https://jonathan-laurent.github.io/AlphaZero.jl/stable/
MIT License

StackOverflowError on cyclic state graph game #47

Open blacksph3re opened 3 years ago

blacksph3re commented 3 years ago

I am currently implementing a board game called Tak. In this game it is possible to move stones around, so a stone can be moved back and forth indefinitely. Theoretically, a terminal state (at least a draw) is reachable from every state by choosing the right actions. Practically, the MCTS decides to loop infinitely, resulting in:

Initializing a new AlphaZero environment

  Initial report

    Number of network parameters: 159,457
    Number of regularized network parameters: 156,736
    Memory footprint per MCTS node: 24056 bytes

  Running benchmark: AlphaZero against MCTS (1000 rollouts)

StackOverflowError:
Stacktrace:
  [1] check_win(board::Array{Union{Nothing, Tuple{Main.tak.TakEnv.Stone, Main.tak.TakEnv.Player}}, 3}, active_player::Main.tak.TakEnv.Player)
    @ Main.tak.TakEnv ~/Programming/tak/src/TakEnv.jl:622
  [2] play!(g::Main.tak.TakInterface.TakGame, action_idx::Int64)
    @ Main.tak.TakInterface ~/Programming/tak/src/TakInterface.jl:81
  [3] run_simulation!(env::AlphaZero.MCTS.Env{Tuple{BitVector, Main.tak.TakEnv.Player}, AlphaZero.MCTS.RolloutOracle{Main.tak.TakInterface.TakSpec}}, game::Main.tak.TakInterface.TakGame; η::Vector{Float64}, root::Bool)
    @ AlphaZero.MCTS ~/.julia/packages/AlphaZero/eAGva/src/mcts.jl:214
  [4] run_simulation!(env::AlphaZero.MCTS.Env{Tuple{BitVector, Main.tak.TakEnv.Player}, AlphaZero.MCTS.RolloutOracle{Main.tak.TakInterface.TakSpec}}, game::Main.tak.TakInterface.TakGame; η::Vector{Float64}, root::Bool) (repeats 11808 times)
    @ AlphaZero.MCTS ~/.julia/packages/AlphaZero/eAGva/src/mcts.jl:218
  [5] explore!(env::AlphaZero.MCTS.Env{Tuple{BitVector, Main.tak.TakEnv.Player}, AlphaZero.MCTS.RolloutOracle{Main.tak.TakInterface.TakSpec}}, game::Main.tak.TakInterface.TakGame, nsims::Int64)
    @ AlphaZero.MCTS ~/.julia/packages/AlphaZero/eAGva/src/mcts.jl:243
  [6] think(p::MctsPlayer{AlphaZero.MCTS.Env{Tuple{BitVector, Main.tak.TakEnv.Player}, AlphaZero.MCTS.RolloutOracle{Main.tak.TakInterface.TakSpec}}}, game::Main.tak.TakInterface.TakGame)
    @ AlphaZero ~/.julia/packages/AlphaZero/eAGva/src/play.jl:198
  [7] think
    @ ~/.julia/packages/AlphaZero/eAGva/src/play.jl:259 [inlined]
  [8] play_game(gspec::Main.tak.TakInterface.TakSpec, player::TwoPlayers{MctsPlayer{AlphaZero.MCTS.Env{Tuple{BitVector, Main.tak.TakEnv.Player}, AlphaZero.Batchifier.BatchedOracle{AlphaZero.Batchifier.var"#6#7"}}}, MctsPlayer{AlphaZero.MCTS.Env{Tuple{BitVector, Main.tak.TakEnv.Player}, AlphaZero.MCTS.RolloutOracle{Main.tak.TakInterface.TakSpec}}}}; flip_probability::Float64)
    @ AlphaZero ~/.julia/packages/AlphaZero/eAGva/src/play.jl:308
  [9] (::AlphaZero.var"#simulate_game#70"{TwoPlayers{MctsPlayer{AlphaZero.MCTS.Env{Tuple{BitVector, Main.tak.TakEnv.Player}, AlphaZero.Batchifier.BatchedOracle{AlphaZero.Batchifier.var"#6#7"}}}, MctsPlayer{AlphaZero.MCTS.Env{Tuple{BitVector, Main.tak.TakEnv.Player}, AlphaZero.MCTS.RolloutOracle{Main.tak.TakInterface.TakSpec}}}}, AlphaZero.Benchmark.var"#5#9"{ProgressMeter.Progress}, Simulator{AlphaZero.Benchmark.var"#4#8"{Env{Main.tak.TakInterface.TakSpec, SimpleNet, Tuple{BitVector, Main.tak.TakEnv.Player}}, AlphaZero.Benchmark.Duel}, AlphaZero.Benchmark.var"#net#6"{Env{Main.tak.TakInterface.TakSpec, SimpleNet, Tuple{BitVector, Main.tak.TakEnv.Player}}, AlphaZero.Benchmark.Duel}, typeof(record_trace)}, Main.tak.TakInterface.TakSpec, SimParams})(sim_id::Int64)
    @ AlphaZero ~/.julia/packages/AlphaZero/eAGva/src/simulations.jl:232
 [10] macro expansion
    @ ~/.julia/packages/AlphaZero/eAGva/src/util.jl:187 [inlined]
 [11] (::AlphaZero.Util.var"#9#10"{AlphaZero.var"#68#69"{AlphaZero.Benchmark.var"#5#9"{ProgressMeter.Progress}, Simulator{AlphaZero.Benchmark.var"#4#8"{Env{Main.tak.TakInterface.TakSpec, SimpleNet, Tuple{BitVector, Main.tak.TakEnv.Player}}, AlphaZero.Benchmark.Duel}, AlphaZero.Benchmark.var"#net#6"{Env{Main.tak.TakInterface.TakSpec, SimpleNet, Tuple{BitVector, Main.tak.TakEnv.Player}}, AlphaZero.Benchmark.Duel}, typeof(record_trace)}, Main.tak.TakInterface.TakSpec, SimParams, AlphaZero.var"#48#49"{Channel{Any}}, AlphaZero.var"#make#65"{Channel{Any}}}, UnitRange{Int64}, typeof(vcat), ReentrantLock})()
    @ AlphaZero.Util ~/.julia/packages/ThreadPools/P1NVV/src/macros.jl:259

The part of the stack trace above [4] (the frames in my implementation) varies; the likely cause is run_simulation!, which ends up in infinite recursion. From my understanding, UCT should place a lower weight on states that have been visited many times and should thus, by exclusion, eventually perform actions that lead to new states and finally to a terminal state. As the depth of a normal game is roughly 100 moves, this mechanism should have kicked in long before 11,808 recursions. If I prohibit movement actions altogether, training works fine. I am not sure how to approach this problem; does anyone have experience with it?

jonathan-laurent commented 3 years ago

Thanks for your interest in AlphaZero.jl!

I think you nailed it in the title of this issue: MCTS can indeed currently fail to terminate in environments whose state graphs have loops. The reason is that in every MCTS simulation, when selecting a node for expansion, MCTS keeps simulating steps until either the game has ended or it reaches a state that has not been expanded yet. Note that contrary to many MCTS implementations, we treat the MCTS tree as a DAG and share information between nodes that correspond to the same state (concretely, the DAG is implemented as a hash table indexed by states).
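To illustrate why this diverges, here is a self-contained toy sketch (hypothetical code, not the actual implementation in src/mcts.jl): a "game" with two non-terminal states whose only action flips between them. Once both states are expanded, the descent never finds an unexpanded state and the recursion overflows the stack, just like in the trace above.

```julia
# Toy model of the selection phase. `expanded` plays the role of the
# state-indexed hash table that implements the DAG.
expanded = Set{Int}()

function select!(state::Int, depth::Int = 0)
    if !(state in expanded)
        push!(expanded, state)  # expand this state and stop the descent
        return depth
    end
    # Already expanded: keep stepping. The only action flips 1 <-> 2,
    # so once both states are expanded the recursion never terminates.
    return select!(state == 1 ? 2 : 1, depth + 1)
end

select!(1)  # expands state 1
select!(2)  # expands state 2
select!(1)  # StackOverflowError: every reachable state is already expanded
```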

I can see several fixes here, and it should not be too hard for you to implement one of them (and even propose a patch) before I decide on the best solution to integrate into the main branch:

- Limit the depth of MCTS simulations in run_simulation!.
- Detect loops during a simulation and treat any revisited state as terminal, e.g. as a draw (sketched below).
- Keep a depth counter on MCTS nodes.
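For instance, the loop-detection option could look roughly like this, continuing the toy example from above (again a hypothetical sketch, not a patch against mcts.jl): track the states seen within the current simulation and end the simulation with a draw on any revisit.

```julia
# Loop detection within a single simulation: `seen` holds the states already
# visited on the current descent; any revisit ends the simulation as a draw.
function select_with_loop_check!(state::Int, seen = Set{Int}(), depth::Int = 0)
    state in seen && return :draw  # loop detected: cut the simulation here
    push!(seen, state)
    if !(state in expanded)
        push!(expanded, state)
        return depth
    end
    return select_with_loop_check!(state == 1 ? 2 : 1, seen, depth + 1)
end

select_with_loop_check!(1)  # returns :draw instead of overflowing the stack
```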

Do you have any thoughts on this?

MichaelGreisberger commented 3 years ago

I have encountered the same problem with a game I am implementing. I solved it by tracking the depth of a state and ending the game after a certain number of moves; games ended in this way are counted as draws. This approach does not require any changes to AlphaZero.jl, as you can simply implement it in GI.game_terminated(::GameEnv). One drawback is that the state itself is used to track depth, so the hashes of two otherwise equal states may differ. A more general solution would be great, as this problem occurs in many games.
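Concretely, the trick looks roughly like this (a minimal sketch: the MyGame type, its fields, and natural_end are made up for illustration; GI.game_terminated and GI.play! are the interface hooks):

```julia
import AlphaZero.GI

const MAX_MOVES = 200  # arbitrary cap; games reaching it count as draws

# Hypothetical environment: only the parts relevant to the trick are shown.
mutable struct MyGame <: GI.AbstractGameEnv
    board  :: Matrix{Int}  # placeholder board representation
    nmoves :: Int          # moves played so far; must be part of the state
end

natural_end(g::MyGame) = false  # stand-in for the game's real win/draw test

# The timeout lives entirely in game_terminated, so AlphaZero.jl itself
# needs no changes:
GI.game_terminated(g::MyGame) = natural_end(g) || g.nmoves >= MAX_MOVES

function GI.play!(g::MyGame, action)
    # ... apply `action` to g.board here ...
    g.nmoves += 1
end
```

Since the move counter has to be visible in the state for MCTS to see it, two otherwise identical positions reached at different depths hash differently, which is the drawback mentioned above.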

jonathan-laurent commented 3 years ago

I believe that keeping track of time in the state is indeed a good solution; in fact, this is exactly what I am doing in my grid-world example: https://github.com/jonathan-laurent/AlphaZero.jl/blob/14c62e766460a7977c3985c2c42b7540f72db884/games/grid-world/game.jl#L20

This does defeat the purpose of implementing state sharing in MCTS, but most MCTS implementations do not offer state sharing anyway. Regarding ending the game after a fixed number of moves: it is in general a good trick, although it can introduce subtle biases in the value function if the neural network does not see the timer (e.g. the network believes something very bad is going to happen when it does not, because the game times out before the bad thing happens).

I think that at some point the best solution is simply for me to offer an alternative, API-compatible implementation of MCTS that does not implement sharing, and to let the user choose. Such an implementation could also be more memory efficient in some cases.

MichaelGreisberger commented 3 years ago

I guess it would still be possible to enable state sharing by overriding the default hash function to exclude the time field. In other languages like Java this can be done quite easily, but I don't know about Julia. I also guess a tree-based data structure without state sharing would be a great addition, as it enables simple cache eviction. Thank you for mentioning that hiding the time from the network introduces a bias; I had not thought about this before and will look into adding it to the network's input.
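It turns out this is just as easy in Julia: Dict (and hence a state-indexed hash table) consults Base.hash and Base.==, and both can be overloaded for a custom state type. A sketch with made-up names (whether this is actually safe for the rest of the pipeline is exactly the question discussed here):

```julia
# State type whose hashing and equality ignore the move counter, so that
# transpositions reached at different depths still map to the same MCTS node.
struct TimedState
    board  :: Vector{Int}  # placeholder board encoding
    nmoves :: Int          # excluded from hash/equality below
end

Base.hash(s::TimedState, h::UInt) = hash(s.board, h)
Base.:(==)(a::TimedState, b::TimedState) = a.board == b.board

TimedState([1, 2], 5) == TimedState([1, 2], 9)  # true
```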

blacksph3re commented 3 years ago

First of all, thanks for your insightful thoughts :)

Enabling state sharing while keeping the time variable in the state, as Michael suggested, could yield weird results. Declaring a timed-out state a draw is an inaccurate result, and it propagates a long way back through the MCTS tree; state sharing would, in my understanding, persist this inaccuracy even further. I don't know how that plays out in practice; maybe the mathematical inaccuracy is made up for by the increased memory efficiency. The cleanest implementation would IMHO be to limit the depth of the MCTS search in the run_simulation! function, as Jonathan suggested, rather than hacking the hash function.

I think I might try out multiple options, but I will start with the timer: it seems to carry the lowest implementation overhead and, by my rationale, introduces the least algorithmic inaccuracy.

The loop-detection option is actually a bit problematic. In my case the assumption is valid, but I don't know if it is okay to assume in general that running a loop should never happen. Especially when playing outside of self-play this could be tricky, as an opponent could exploit this "blind spot". I think a threshold (e.g. visiting the same state 10 times) would be better for these cases. Also, in non-deterministic games it might actually make sense to run a loop, since a specific transition in that loop might lead to a very good state with low probability, so you just try multiple times.

Implementing the depth counter on the MCTS nodes sounds a bit tricky to me, and I wonder whether it wouldn't introduce a bug. I could imagine a scenario where the first thing the MCTS does is jump into the loop. Now the only known child of that state is the loop action, and selecting it would keep you in the loop. I think the counter should instead be stored along (s, a) pairs, so as to count edge visits rather than node visits, and be used to disable an edge? But I am not sure about that. The argument against disabling loops completely also applies here.
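To make the (s, a) idea concrete, here is a sketch of what I have in mind (purely hypothetical, not based on the mcts.jl internals): count how often each edge is taken within one simulation and mask edges that exceed a threshold.

```julia
const EDGE_LIMIT = 10  # allow revisiting a state, but not unboundedly often

# One table per simulation: (state, action) -> times taken on this descent.
function pick_action(state::Int, legal_actions, edge_visits::Dict{Tuple{Int,Int},Int})
    allowed = [a for a in legal_actions
               if get(edge_visits, (state, a), 0) < EDGE_LIMIT]
    isempty(allowed) && return nothing  # everything masked: end the descent
    a = first(allowed)                  # stand-in for the actual UCT choice
    edge_visits[(state, a)] = get(edge_visits, (state, a), 0) + 1
    return a
end
```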

Also, exposing the timer to the network doesn't seem out of this world: I could imagine allocating a couple of inputs for "30, 15, 10, 5 moves left". That way the value inaccuracy could be compensated for.
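Building on the timed-state sketch above, something like this is what I have in mind (GI.vectorize_state is the state-encoding hook of the game interface; MySpec and the board encoding are made up):

```julia
struct MySpec <: GI.AbstractGameSpec end  # placeholder game spec

# Reusing TimedState and MAX_MOVES from the sketches above:
function GI.vectorize_state(::MySpec, s::TimedState)
    moves_left = MAX_MOVES - s.nmoves
    # Threshold features of the form "at least k moves left":
    timer = Float32[moves_left >= k for k in (30, 15, 10, 5)]
    return vcat(Float32.(s.board), timer)
end
```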

Thanks a lot for your input. Should I keep the issue open for the time being?

jonathan-laurent commented 3 years ago

I agree that having a timer on the environment side is probably the right move here. Let's keep this issue open until, at the very least, the current behavior is better documented.

Honestly, I think I should also offer a version of MCTS without sharing. The advantages of the current solution are:

- Statistics are shared between nodes that correspond to the same state, so information gathered in transpositions is pooled rather than duplicated.
- Reaching an already-expanded state costs no additional network evaluation or rollout.

However, the price to pay for this is heavy:

- MCTS can fail to terminate on games whose state graphs have cycles, as this issue shows.
- Full states must be kept as keys of the hash table, which inflates the memory footprint per node.
- Common workarounds, such as adding a move counter to the state, defeat the sharing anyway.