JuliaPOMDP / DiscreteValueIteration.jl

Value iteration solver for MDPs

Solve does not work with MDPs using rewards of form R(s, a) and R(s, a, s') #58

Closed kevinbradner closed 5 months ago

kevinbradner commented 5 months ago

I am trying to modify the MDP used in the Julia Academy tutorial on MDPs. I get an error from the vanilla DiscreteValueIteration solve function after redefining my MDP's reward function R with the signature function R(s, a) instead of function R(s, a=missing).

I started from the original notebook and made some minimal changes to recreate the error. My version of the notebook is here.

@req reward(::P,::S,::A,::S) is a line from the requirements section of the vanilla solve function in this repo. Based on that line, it looks like I need to define a reward that takes a (S, A, S') triple, but this does not seem to be the case when using the function.

The details are in my notebook linked above, but when I run the following lines:

solver = ValueIterationSolver(max_iterations=30);
policy = solve(solver, mdp)

I get this error:

MethodError: no method matching R(::Main.var"workspace#17".State)

Closest candidates are:
  R(::Any, !Matched::Any)
   @ Main.var"workspace#17" ~/Decision-Making-Under-Uncertainty/notebooks/1-MDPs backup 1.jl#==#f7814a66-23c8-4782-ba06-755397af87db:1

    show_requirements(::POMDPLinter.RequirementSet)@requirements_interface.jl:213
    macro expansion@requirements_interface.jl:100[inlined]
    var"#solve#3"(::Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, ::typeof(POMDPs.solve), ::DiscreteValueIteration.ValueIterationSolver, ::QuickPOMDPs.QuickMDP{Main.var"workspace#17".GridWorld, Main.var"workspace#17".State, Main.var"workspace#17".Action, @NamedTuple{stateindex::Dict{Main.var"workspace#17".State, Int64}, isterminal::typeof(Main.var"workspace#17".termination), render::typeof(Main.var"workspace#17".plot_grid_world), transition::typeof(Main.var"workspace#17".T), reward::typeof(Main.var"workspace#17".R), states::Vector{Main.var"workspace#17".State}, actions::Vector{Main.var"workspace#17".Action}, discount::Float64, initialstate::Vector{Main.var"workspace#17".State}, actionindex::Dict{Main.var"workspace#17".Action, Int64}}})@vanilla.jl:71
    DiscreteValueIteration@vanilla.jl:64[inlined]
    top-level scope@Local: 1[inlined]

The error suggests that the issue has something to do with the requirements macro. The requirements list as well as the rest of the code in vanilla.jl makes it look like rewards with larger parameter lists should work here. I'm pretty new to Julia, so I may just be misunderstanding something here. If anyone can tell me whether I am missing something, it would be greatly appreciated.

WhiffleFish commented 5 months ago

It seems that R(s) is called in the transition function, which causes the failure. When you remove the default missing value for the action argument of R, you also remove the single-argument method R(s).
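
The point about default arguments can be illustrated with a minimal sketch in plain Julia. The names `r_default` and `r_explicit` are hypothetical stand-ins for the notebook's R, and the reward values are made up. A default argument makes Julia define an extra method with fewer arguments, so dropping the default also drops that method.

```julia
# Hypothetical stand-ins for the notebook's R; the reward values are made up.

# A default value makes Julia generate two methods:
# r_default(s) and r_default(s, a).
r_default(s, a=missing) = s == 0 ? 10.0 : 0.0

# Without the default, only the two-argument method exists.
r_explicit(s, a) = s == 0 ? 10.0 : 0.0

println(length(methods(r_default)))   # number of methods for r_default
println(length(methods(r_explicit)))  # number of methods for r_explicit

# Calling the missing one-argument method raises a MethodError,
# the same kind of error as in the report above.
err = try
    r_explicit(0)
catch e
    e
end
println(err isa MethodError)
```

One way to keep a call such as R(s) working after removing the default is to define the one-argument method explicitly alongside the two-argument one.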

kevinbradner commented 5 months ago

Thanks, that certainly shows one area where my signature would cause an issue. Since R seems to be called with more arguments in other places, do you know the expected semantics for R(s)?

With that said, I'm still confused about the stack trace described in my earlier message. It's not necessarily an issue for this repo, but it would be helpful for my future Julia work if anyone can help explain that.

WhiffleFish commented 5 months ago

Sure thing, we set up some default fallbacks for reward, but it seems that we don't include reward(m, s) as a fallback.

As for the stack trace, I can see how it would be confusing that reward(::P,::S,::A,::S) appears to be a required method. However, because of these fallbacks, whenever you define reward(m,s,a), reward(m,s,a,sp) is automatically defined as well. So you don't need to define the extra method manually.
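
The fallback pattern described above can be sketched in plain Julia. This is only an illustration of the idea, not the POMDPs.jl source; `ToyMDP`, `my_reward`, and the reward values are hypothetical.

```julia
# Sketch of a fallback chain; not the actual POMDPs.jl code.
# `ToyMDP` and `my_reward` are hypothetical names.
struct ToyMDP end

# Generic fallback: a call that includes the next state `sp` forwards
# to the three-argument form and ignores `sp`.
my_reward(m, s, a, sp) = my_reward(m, s, a)

# The user defines only the (m, s, a) form.
my_reward(m::ToyMDP, s::Int, a::Symbol) = a == :right ? 1.0 : 0.0

m = ToyMDP()

# The four-argument call is handled by the fallback.
println(my_reward(m, 1, :right, 2))

# No (m, s) fallback exists in this sketch, so a two-argument call
# would throw a MethodError; `applicable` checks this without throwing.
println(applicable(my_reward, m, 1))
```

With this layout, a solver that calls the four-argument form works with a model that defines only the three-argument form, while a state-only call still has no method to dispatch to.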

Hope this helps

kevinbradner commented 5 months ago

Ok sure, thanks again for the information. I took the time to edit my notebook and test it a moment ago, and the DiscreteValueIteration code runs. I'll go ahead and close the issue.