Closed: harwiltz closed this issue 1 year ago
> In essence, it looks to me like there is lots of repetition among the various DQN implementations and I think there should be more abstraction.
Yeah, it was intentional at first to keep the implementation of each algorithm in a single file. But we ended up with a lot of duplicate code among all the DQN variants, so some common parts were moved into `common.jl`. I agree there's still a lot of duplicate code there and we can do more.
I get your idea of adding a `bootstrap_func` to simplify the current implementations. I think in Julia people would prefer to implement it like this:
```julia
struct DQNLearner{B, L}
    bootstrap::B
    loss::L
    # ... other fields
end

const CategoricalDQNLearner = DQNLearner{BootstrapForC51}
const QRDQNLearner = DQNLearner{BootstrapForQR}
```
Though this looks more concise, I have several concerns:

1. Different variants have really different parameters, so I'm afraid the definition of `DQNLearner` would be quite general.
2. The definition of the `bootstrap` step is not very clear to me. I'm not sure whether each DQN variant can be generalized with such a step of similar parameters.

Maybe the first step of the refactor is to try to extract the `bootstrap` part from the existing implementations into `common.jl`. Then the rest would be relatively easy: we put each implementation into a submodule and generalize them further. What do you think?
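To make the extraction idea concrete, here is a minimal, self-contained sketch of what a shared training step in `common.jl` could look like. All names here (`ToyLearner`, `generic_update!`, `td_bootstrap`, `mse`) are hypothetical placeholders, not existing package code:

```julia
# Hypothetical sketch: every DQN variant follows the same skeleton,
# differing only in how the target and the loss are computed.
struct ToyLearner{B, L}
    bootstrap::B   # computes the regression target
    loss::L        # compares a prediction with the target
end

# In the real package this would also update network parameters in place;
# here it just returns the loss value for illustration.
function generic_update!(learner, prediction, batch)
    target = learner.bootstrap(batch)            # variant-specific target
    return learner.loss(prediction, target)      # variant-specific loss
end

# Toy instantiation: a 1-step TD target with an MSE loss.
td_bootstrap(batch) = batch.reward .+ 0.99 .* batch.next_q
mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

learner = ToyLearner(td_bootstrap, mse)
batch = (reward = [1.0, 0.0], next_q = [2.0, 1.0])
generic_update!(learner, [2.5, 1.5], batch)  # returns a scalar loss
```

Swapping in a different `bootstrap`/`loss` pair would then give another variant without touching the shared skeleton.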
Firstly, I very much agree with your first step, so I'll get on that!
> Different variants have really different parameters, so I'm afraid the definition of `DQNLearner` would be quite general.
Yeah I guess that's true. It would be really nice if it was possible to have inheritance of some sort, but as I understand it, this is not possible in Julia. I gotta get out of the OOP mindset a little bit.
> The definition of the `bootstrap` step is not very clear to me. I'm not sure whether each DQN variant can be generalized with such a step of similar parameters.
My idea is basically that each DQN variant (that I know of) has to learn some type of Q-value representation, and it learns via temporal differences (so perhaps this is mainly a TD learning abstraction more than a DQN one). The bootstrap step simply computes the target for whatever the Q-function loss is (i.e., it computes one of the arguments to `loss_func`).
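As an illustration of "bootstrap computes one argument of the loss", here is a hedged sketch of the classic (expected-value) DQN bootstrap; `dqn_bootstrap` is a hypothetical name, not code from this package:

```julia
# Hypothetical sketch: the classic 1-step DQN bootstrap, computing the
# TD target r + γ * max_a Q(s', a) that is then fed to the loss.
# q_next holds target-network Q-values at next states, one column per sample.
function dqn_bootstrap(reward, terminal, q_next; γ = 0.99)
    reward .+ γ .* (1 .- terminal) .* vec(maximum(q_next; dims = 1))
end

q_next = [1.0 3.0; 2.0 0.5]                # 2 actions × 2 samples
dqn_bootstrap([0.0, 1.0], [0, 1], q_next)  # → [1.98, 1.0]
```

C51 and QR-DQN would replace this with a categorical projection or quantile targets, respectively, while the rest of the update stays the same.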
As a more concrete example, consider the *Statistics and Samples in Distributional RL* paper. It suggests that all distributional RL algorithms should perform updates of the form
So step 3 would be the `bootstrap_func`. Every distributional algorithm should have this form. It also generalizes to non-distributional algorithms, which can be seen as a special case:
I think ultimately we can abstract each of these steps (maybe in a similar way to how hooks are implemented?) and save lots of code repetition.
Thanks for the link and your explanation! Now I have a much better understanding. I think such an abstraction will not only save lots of code repetition but also help in understanding existing algorithms and even in implementing new ones. Let me know if you need any help understanding the existing implementations when refactoring the DQN variants in this package.
I've started refactoring the DQN implementations, but I'm fairly new to Julia so I'd appreciate your feedback about whether this is a good idea or not.
In essence, it looks to me like there is lots of repetition among the various DQN implementations and I think there should be more abstraction. My first goal is to abstract the bootstrapping function, so structs would have a `bootstrap_func` member. This will be particularly useful in the distributional DQN algorithms, where C51/Rainbow and QR-DQN differ pretty much only by the loss function and the bootstrapping. My research would actually even be simplified by having this abstraction for expected-value RL as well.

My biggest concern with this is how the "subtyping" would work, as multiple dispatch is still a little mysterious to me. So far, I made an `abstract type AbstractDQNLearner` that all of the DQN implementations inherit from, and I made a method in `common.jl` like so:

Now suppose there was a constructor (say, the `DQNLearner` constructor) that I argue can be used to initialize both C51 and QR-DQN by just changing the `bootstrap_func` and `loss_func` parameters. Will we ever need to distinguish these two algorithms by multiple dispatch? For instance, is there any downside to having

Then we wouldn't be able to have separate `update!(::QRDQNLearner, ...)` methods like we have now, for example. Personally, this module design makes more sense to me, because it seems like the most natural way to reuse code.