WhiffleFish closed this issue 1 year ago.
```julia
using POMDPModels
using POMDPs
using Flux
using Crux

mdp = SimpleGridWorld()
as = actions(mdp)
S = state_space(mdp)

# Actor: policy network over the discrete grid-world actions
A() = DiscreteNetwork(Chain(Dense(Crux.dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, length(as))), as)
# Critic: scalar state-value network
V() = ContinuousNetwork(Chain(Dense(Crux.dim(S)..., 64, relu), Dense(64, 64, relu), Dense(64, 1)))

𝒮_ppo = PPO(π=ActorCritic(A(), V()), S=S, N=10_000, ΔN=1_000)
π_ppo = solve(𝒮_ppo, mdp)  # fails for SimpleGridWorld, see below
```
Initially this fails because `cpucall` has no method for StaticArrays, which SimpleGridWorld uses for its states.
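A minimal sketch of the dispatch problem, assuming the existing `cpucall` methods only cover plain `Array` inputs (Crux's actual definition is not reproduced here). `to_cpu_narrow` and `to_cpu` are hypothetical stand-ins, not Crux API: the narrow version raises a `MethodError` for an `SVector` state, while the broader `AbstractArray` method converts it into a plain `Array`.

```julia
using StaticArrays

# Hypothetical stand-in for a CPU-conversion helper that only accepts plain Arrays.
to_cpu_narrow(x::Array) = x

# Hypothetical broader version: any AbstractArray (e.g. an SVector) is copied into an Array.
to_cpu(x::AbstractArray) = Array(x)
# Plain Arrays are already CPU arrays, so return them unchanged (more specific method wins).
to_cpu(x::Array) = x

s = SVector(1, 2)  # SimpleGridWorld-style state, an SVector{2,Int}

try
    to_cpu_narrow(s)
catch err
    # No method of to_cpu_narrow matches an SVector argument.
    println("narrow method failed with MethodError: ", err isa MethodError)
end

println(to_cpu(s))  # a plain Vector{Int}
```

Dispatching on `AbstractArray` rather than a list of concrete types means any other array wrapper a model might use (views, ranges, static arrays) is also handled, at the cost of one copy per call.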
After adding a method that handles this (or covering all bases with `AbstractArray`), another failure occurs here: https://github.com/ancorso/Crux.jl/blob/c32fd8ca94437c991eb0a9bb54686531d8542d23/src/sampler.jl#L93.
This can be corrected by switching to `sp, r = @gen(:sp,:r)(sampler.mdp, sampler.s, args...; kwargs...)` from POMDPs.jl.
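For reference, a small standalone sketch of the same `@gen(:sp, :r)` call on SimpleGridWorld, outside Crux. The state, action, and RNG here are chosen for illustration only; the exact arguments Crux forwards through `args...` at sampler.jl#L93 are not reproduced.

```julia
using POMDPs
using POMDPModels
using Random

mdp = SimpleGridWorld()
rng = MersenneTwister(1)

# Draw an initial state (an SVector{2,Int}) and pick an arbitrary action.
s = rand(rng, initialstate(mdp))
a = first(actions(mdp))

# @gen(:sp, :r) samples the next state and reward. SimpleGridWorld does not
# implement `gen` directly, so POMDPs.jl builds the sample from `transition` and `reward`.
sp, r = @gen(:sp, :r)(mdp, s, a, rng)

println("s = ", s, ", a = ", a, ", sp = ", sp, ", r = ", r)
```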
I'm not entirely sure what side effects this may have on some POMDPGym environments, but it at least allows GridWorld to run.
These seem like useful changes! Would you like to submit a PR? Otherwise I will probably make them in my next round of updates, but that may be a while.