JuliaPOMDP / DiscreteValueIteration.jl

Value iteration solver for MDPs
Other
20 stars 12 forks source link

default max_iterations too low #45

Open zsunberg opened 3 years ago

zsunberg commented 3 years ago

Right now the default `max_iterations` is very low (100), so the algorithm may stop before it has converged. This can lead to confusing results.

zsunberg commented 3 years ago

Here is a simple example where the number of iterations is too low:

using POMDPs
using QuickPOMDPs: QuickPOMDP
using POMDPSimulators: RolloutSimulator
using QMDP

###################
# Problem 2: Cancer
###################

# Four-state cancer screening model: the patient is :healthy, has :in_situ or
# :invasive cancer, or is in the absorbing :death state. The agent may :wait,
# :test or :treat each step and receives a boolean observation.
#
# NOTE(review): `SparseCat` and `Deterministic` do not appear to be brought into
# scope by the `using` lines visible above this block — confirm which package is
# expected to provide them before running this script.
cancer = QuickPOMDP(
    states = [:healthy, :in_situ, :invasive, :death],
    actions = [:wait, :test, :treat],
    observations = [true, false],

    # Next-state distribution. The action only matters in the two diseased
    # states, and only through whether it is :treat.
    transition = function (s, a)
        s == :healthy && return SparseCat([:healthy, :in_situ], [0.98, 0.02])
        if s == :in_situ
            a == :treat && return SparseCat([:healthy, :in_situ], [0.6, 0.4])
            return SparseCat([:in_situ, :invasive], [0.9, 0.1])
        end
        if s == :invasive
            a == :treat &&
                return SparseCat([:healthy, :death, :invasive], [0.2, 0.2, 0.6])
            return SparseCat([:invasive, :death], [0.4, 0.6])
        end
        # Any remaining state (i.e. :death) is absorbing.
        return Deterministic(:death)
    end,

    # Observation distribution given the action taken and the resulting state.
    observation = function (a, sp)
        if a == :test
            # Noisy test result for :healthy and :in_situ, always `true` otherwise.
            sp == :healthy && return SparseCat([true, false], [0.05, 0.95])
            sp == :in_situ && return SparseCat([true, false], [0.8, 0.2])
            return Deterministic(true)
        end
        # Without a test, `true` is observed only when treating a diseased state.
        positive = a == :treat && sp in (:in_situ, :invasive)
        return Deterministic(positive)
    end,

    # Per-step reward: nothing once dead, otherwise determined by the action.
    reward = function (s, a)
        s == :death && return 0.0
        a == :wait && return 1.0
        a == :test && return 0.8
        a == :treat && return 0.1
        # No branch matched: yields `nothing`, as an if/elseif chain without `else`.
        return nothing
    end,

    discount = 0.99,
    initialstate = Deterministic(:healthy),
    isterminal = s -> s == :death,
)

# Solve the same `cancer` problem twice with verbose output, changing only the
# iteration cap: first with max_iterations=100 (the value the issue text above
# describes as too low), then with max_iterations=1000 for comparison.
qmdp_100 = solve(QMDPSolver(verbose=true, max_iterations=100), cancer)
# Visual separator between the two solver logs.
println("\n\n***************************************\n\n")
qmdp_1000 = solve(QMDPSolver(verbose=true, max_iterations=1000), cancer)