sisl / Crux.jl

Julia library for deep reinforcement learning
MIT License

solve failure for POMDPs.jl models #5

Closed: WhiffleFish closed this issue 1 year ago

WhiffleFish commented 1 year ago
using POMDPModels
using POMDPs
using Flux
using Crux

mdp = SimpleGridWorld()
as = actions(mdp)
S = state_space(mdp)

A() = DiscreteNetwork(Chain(Dense(Crux.dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, length(as))), as)
V() = ContinuousNetwork(Chain(Dense(Crux.dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, 1)))

𝒮_ppo = PPO(π=ActorCritic(A(), V()), S=S, N=10_000, ΔN=1_000)
π_ppo = solve(𝒮_ppo, mdp)

Initially this fails because cpucall has no method for StaticArrays (SimpleGridWorld states are static vectors).
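A minimal sketch of why widening the method signature helps, using only Base Julia. The functions to_cpu_narrow and to_cpu are hypothetical stand-ins, not Crux's actual cpucall (whose real signature is not shown in this thread). A UnitRange and a view play the role of an SVector here: all three are AbstractArrays but not Arrays.

```julia
# Hypothetical stand-ins for a helper like Crux's `cpucall`; not Crux code.

# A method restricted to `Array` has no match for other array types.
to_cpu_narrow(x::Array) = Array(x)

# Widening to `AbstractArray` covers any array type (static vectors,
# ranges, views, ...) by materialising a plain `Array`.
to_cpu(x::AbstractArray) = Array(x)
to_cpu(x::Number) = x  # scalars pass through unchanged

r = 1:3                      # AbstractArray, but not an Array
v = view([1, 2, 3, 4], 2:3)  # a SubArray, likewise not an Array

# The narrow method throws a MethodError on the range.
ok = try
    to_cpu_narrow(r)
    true
catch err
    err isa MethodError ? false : rethrow()
end
```

The widened method is the "covering all bases" option: one AbstractArray method handles every array type, at the cost of always copying into a dense Array.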

After adding a method that handles this (or, to cover all cases, one for AbstractArray), a second failure occurs in src/sampler.jl at line 93: https://github.com/ancorso/Crux.jl/blob/c32fd8ca94437c991eb0a9bb54686531d8542d23/src/sampler.jl#L93

This can be fixed by switching that line to the POMDPs.jl generative interface: sp, r = @gen(:sp,:r)(sampler.mdp, sampler.s, args...; kwargs...)
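A minimal sketch of the same generative call outside Crux's sampler, assuming POMDPs.jl and POMDPModels.jl are installed. The state, action, and RNG below are illustrative choices; they are not necessarily what the sampler passes through args....

```julia
using POMDPs, POMDPModels, Random

mdp = SimpleGridWorld()
rng = MersenneTwister(1)  # fixed seed for reproducibility

# Illustrative inputs: the first state and the first action of the model.
s = first(states(mdp))
a = first(actions(mdp))

# Generative interface: sample the next state and the reward in one call.
# For models like SimpleGridWorld that define `transition` and `reward`,
# `@gen` derives the sample from those definitions.
sp, r = @gen(:sp, :r)(mdp, s, a, rng)
```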

I'm not sure what side effects this may have on some POMDPGym environments, but it does at least allow SimpleGridWorld to run.

ancorso commented 1 year ago

These seem like useful changes! Would you like to submit a PR? Otherwise I'll probably make them in my next round of updates, but that may take a while.