jonathan-laurent / AlphaZero.jl

A generic, simple and fast implementation of Deepmind's AlphaZero algorithm.
https://jonathan-laurent.github.io/AlphaZero.jl/stable/
MIT License

AZ much worse than generic solution for simple game #193

Open 70Gage70 opened 1 year ago

70Gage70 commented 1 year ago

I'm trying to train AZ on single-player 21. You have a shuffled deck of cards, and at each step you either "take" a card (adding its value to your total, with Ace = 1, 2 = 2, ..., face cards = 10) or "stop" and receive your current total as the reward; going over 21 ends the game with a reward of 0. The obvious strategy is to take whenever the expected value of a draw would leave your total <= 21 and stop otherwise, which gives an average reward of roughly 14. I defined the game, used the exact training parameters from gridworld.jl, and this is the result:

[figure: benchmark_reward plot]

I don't understand (i) why the rewards are much lower than 14 and (ii) why AZ performs worse than the network.
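
(For reference, a quick back-of-the-envelope check of the strategy, sketched below with my own variable names: the average value of a card in a full deck is 340/52 ≈ 6.5, so the threshold strategy keeps drawing until the total reaches the mid-teens, and the Monte Carlo script at the bottom then puts its average reward at roughly 14.)

# Back-of-the-envelope check of the threshold strategy, assuming the next
# draw is an average card from a full 52-card deck (ignoring cards already drawn).
nonface = [i for j = 1:4 for i = 1:10]   # 40 cards worth 1..10 (Ace = 1)
face = [10 for j = 1:4 for i = 1:3]      # 12 face cards, each worth 10
deck = vcat(nonface, face)               # 52 cards, total value 340

mean_draw = sum(deck) / length(deck)     # ≈ 6.54
println("mean draw ≈ ", round(mean_draw, digits = 2))
println("keep taking while total <= ", floor(Int, 21 - mean_draw))  # ≈ 14, so stop in the mid-teens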

Game

using AlphaZero 
using CommonRLInterface 
const RL = CommonRLInterface
import Random as RNG

# using StaticArrays
# using Crayons

const NONFACE_CARDS = [i for j = 1:4 for i = 1:10]  # 4 suits x values 1..10 (Ace = 1)
const FACE_CARDS = [10 for j = 1:4 for i = 1:3]     # 4 suits x {J, Q, K}, each worth 10
const STANDARD_DECK = map(UInt8, vcat(NONFACE_CARDS, FACE_CARDS))  # 52 cards

### MANDATORY INTERFACE

# state = "what the player should look at"
mutable struct Env21 <: AbstractEnv
    deck::Vector{UInt8} 
    state::UInt8 # points
    reward::UInt8
    terminated::Bool
end

function RL.reset!(env::Env21)
    env.deck = RNG.shuffle(STANDARD_DECK)
    env.state = 0
    env.reward = 0
    env.terminated = false

    return nothing
end

function Env21()
    deck = RNG.shuffle(STANDARD_DECK)
    state = 0
    reward = 0
    terminated = false

    return Env21(deck, state, reward, terminated)
end

RL.actions(env::Env21) = [:take, :stop]
RL.observe(env::Env21) = env.state
RL.terminated(env::Env21) = env.terminated

function RL.act!(env::Env21, action)
    if action == :take
        draw = popfirst!(env.deck)
        env.state += draw

        if env.state >= 22 
            env.reward = 0
            env.state = 0 ######################### okay?
            env.terminated = true
        end
    elseif action == :stop
        env.reward = env.state
        env.terminated = true
    else
        error("Invalid action $action")
    end

    return env.reward
end

### TESTING

# env = Env21()
# reset!(env)
# rsum = 0.0
# while !terminated(env)
#     global rsum += act!(env, rand(actions(env))) 
# end
# @show rsum

### MULTIPLAYER INTERFACE

RL.players(env::Env21) = [1]
RL.player(env::Env21) = 1 

### Optional Interface

RL.observations(env::Env21) = map(UInt8, collect(0:21))
RL.clone(env::Env21) = Env21(copy(env.deck), copy(env.state), copy(env.reward), copy(env.terminated))
RL.state(env::Env21) = env.state
RL.setstate!(env::Env21, new_state) = (env.state = new_state)
RL.valid_action_mask(env::Env21) = BitVector([1, 1])

### AlphaZero Interface

function GI.render(env::Env21)
  println(env.deck)
  println(env.state)
  println(env.reward)
  println(env.terminated)

  return nothing
end

function GI.vectorize_state(env::Env21, state)
  # one-hot encoding of the current point total (0..21)
  v = zeros(Float32, 22)
  v[state + 1] = 1

  return v
end

const action_names = ["take", "stop"]

function GI.action_string(env::Env21, a)
  idx = findfirst(==(a), RL.actions(env))
  return isnothing(idx) ? "?" : action_names[idx]
end

function GI.parse_action(env::Env21, s)
  idx = findfirst(==(s), action_names)
  return isnothing(idx) ? nothing : RL.actions(env)[idx]
end

function GI.read_state(env::Env21)
  return env.state
end

GI.heuristic_value(::Env21) = 0.

GameSpec() = CommonRLInterfaceWrapper.Spec(Env21())
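
For comparison with the plot above, the commented-out test can also be turned into a small random-policy baseline for the environment (a sketch; the helper name random_baseline is mine, it is not part of the code above):

### RANDOM BASELINE (sketch)

# Average episode reward of a uniformly random take/stop policy,
# for comparison against the AlphaZero and network benchmarks.
function random_baseline(n_trials)
    total = 0.0
    for _ in 1:n_trials
        env = Env21()
        r = 0.0
        while !terminated(env)
            r += act!(env, rand(actions(env)))  # act! returns the (possibly zero) reward
        end
        total += r
    end
    return total / n_trials
end

# random_baseline(10_000)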

Canonical strategy

import Random as RNG

const NONFACE_CARDS = [i for j = 1:4 for i = 1:10]
const FACE_CARDS = [10 for j = 1:4 for i = 1:3]
const STANDARD_DECK = map(UInt8, vcat(NONFACE_CARDS, FACE_CARDS))

function mc_run()
    deck = RNG.shuffle(STANDARD_DECK)
    score = 0
    while true
        expected_score = score + sum(deck)/length(deck)  # expected total after one more draw from the remaining deck

        if expected_score >= 22
            return score
        else
            score = score + popfirst!(deck)
            if score >= 22
                return 0
            end
        end
    end

end

function mc(n_trials)
    score = 0 
    for i = 1:n_trials
        score = score + mc_run()
    end
    return score/n_trials
end

mc(10000)
jonathan-laurent commented 1 year ago

I don't have time to look too deeply but here are a few remarks:

My advice:

70Gage70 commented 1 year ago

Thanks for the tips, appreciate it. I'll certainly take a look at the MCTS.