Open tejank10 opened 6 years ago
These look so much better than the RL code I've seen in Python-based frameworks. One thing I noticed, though, is that functions like `zero_grad!` and `update_target!` are implemented with nested loops over individual parameter elements, so they would have to be rewritten for GPU. Instead, maybe use a single loop over the parameter arrays and an in-place broadcast inside the loop? Or maybe the outer loop can also be replaced with `mapleaves!`?
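A rough sketch of the single-loop, in-place-broadcast idea suggested above. It uses plain arrays as stand-ins for the parameter and gradient arrays, so it is not the repository's code and does not call any Flux API; the `τ` keyword is an assumption added for illustration (with the default `τ = 1` the update is a plain copy).

```julia
# Hypothetical stand-ins: `grads`, `target` and `source` are collections of
# plain arrays, not Flux parameter objects.

# Zero every gradient array with one in-place broadcast per array,
# instead of looping over individual elements.
function zero_grad!(grads)
    for g in grads
        g .= 0
    end
    return grads
end

# Copy (or blend) source parameters into the target parameters in place.
# `τ` is an illustrative assumption: τ = 1 overwrites the target completely.
function update_target!(target, source; τ = 1f0)
    for (t, s) in zip(target, source)
        t .= τ .* s .+ (1 - τ) .* t
    end
    return target
end

grads = [ones(Float32, 2, 3), ones(Float32, 3)]
zero_grad!(grads)

target = [zeros(Float32, 2)]
source = [ones(Float32, 2)]
update_target!(target, source; τ = 0.5f0)
```

Because each loop body is a single broadcast over a whole array, the same code should in principle also work when the arrays live on the GPU, which is the point of the suggestion.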
Fixed it. Can we add those two functions (or just `zero_grad!` maybe) to Flux?
How about turning this into a Julia package? Makes the dependencies easier. It would be nice to start training a standard model in a couple of lines, and even better if I can quickly demo a trained model.
Would be great to make sure they are all GPU compatible as well.
Cool, I'll start working on it.
What's going on with lines like https://github.com/tejank10/Flux-baselines/blob/master/dqn/duel-dqn.jl#L46? Is it a matter of Flux not being able to deal with the complicated broadcast expression if you write it out with dots? (If so we should make sure that's fixed in Flux-on-v0.7).
Works with dots, fixed it
I am trying to adapt this code to a problem of mine, but I get an error (on Julia 1.0). Could anyone give a hint on how to deal with this? The line of code is
logπ = log.(mean(π .* action, 1) + 1f-10)
and π comes directly from the Flux model (I use the AC example)
π = softmax(policy(base_out))
MethodError: no method matching mean(::TrackedArray{…,Array{Float64,1}}, ::Int64)
Closest candidates are:
mean(!Matched::Union{Function, Type}, ::Any) at /Users/osx/buildbot/slave/package_osx64/build/usr/share/julia/stdlib/v1.0/Statistics/src/Statistics.jl:58
mean(::TrackedArray; dims) at /Users/jonnorberg/.julia/packages/Flux/jbpWo/src/tracker/array.jl:226
mean(::AbstractArray; dims) at /Users/osx/buildbot/slave/package_osx64/build/usr/share/julia/stdlib/v1.0/Statistics/src/Statistics.jl:128
Not sure why the mean function doesn't recognize the ::TrackedArray or how I could adjust it to work.
thanks for any hints
If you're on Julia 1.0, `mean` takes `dims` as a kwarg.
Major thanks. I also noticed that I need to change `+` to `.+` in two places to avoid an error message.
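For reference, a minimal sketch of the two changes discussed above (`dims` as a keyword argument and `.+` instead of `+`). It uses plain `Float32` matrices with made-up values rather than the tracked output of a Flux model, and the name `probs` is a stand-in for `π` from the original line.

```julia
using Statistics

# Made-up stand-ins for the policy output and the one-hot actions
# (rows = actions, columns = samples).
probs  = Float32[0.2 0.5; 0.8 0.5]
action = Float32[1 0; 0 1]

# Julia 1.0 style: `dims` is a keyword argument, and adding a scalar
# to an array needs the broadcasting `.+`.
logπ = log.(mean(probs .* action, dims = 1) .+ 1f-10)
```

The result keeps the reduced dimension, so `logπ` here is a 1×2 matrix rather than a vector.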
Hey @MikeInnes, if you are back, could you please review the code? The new models I have added are Dueling DQN, Advantage Actor-Critic, and DDPG. Also, all the previous work done on DQN has been moved to the dqn directory.