JuliaReinforcementLearning / ReinforcementLearning.jl

A reinforcement learning package for Julia
https://juliareinforcementlearning.org

Refactor of DQN Algorithms #557

Closed harwiltz closed 1 year ago

harwiltz commented 2 years ago

I've started refactoring the DQN implementations, but I'm fairly new to Julia so I'd appreciate your feedback about whether this is a good idea or not.

In essence, it looks to me like there is lots of repetition among the various DQN implementations and I think there should be more abstraction. My first goal is to abstract the bootstrapping function, so structs would have a bootstrap_func member. This will be particularly useful in the distributional DQN algorithms, where C51/Rainbow and QR-DQN differ pretty much only in their loss function and bootstrapping. My own research would actually be simplified by having this abstraction for expected-value RL as well.

My biggest concern with this is how the "subtyping" would work, as multiple dispatch is still a little mysterious to me. So far, I've made an abstract type AbstractDQNLearner that all of the DQN implementations inherit from, and I added a method in common.jl like so:

bootstrap_func(::AbstractDQNLearner, r, γ, t, q′, n) = r .+ γ^n .* (1 .- t) .* q′
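
For example, a toy call might look like this (DummyLearner and the numbers are made up purely for illustration, reusing the definition above):

struct DummyLearner <: AbstractDQNLearner end

r  = [1.0f0, 0.5f0]    # batch of rewards
t  = [0.0f0, 1.0f0]    # terminal flags (1 = episode ended)
q′ = [2.0f0, 3.0f0]    # target-network values at the next states

bootstrap_func(DummyLearner(), r, 0.99f0, t, q′, 1)  # => 1-step TD targets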

Now suppose there were a constructor (say, the DQNLearner constructor) that, I'd argue, could be used to initialize both C51 and QR-DQN by just changing the bootstrap_func and loss_func parameters. Will we ever need to distinguish these two algorithms by multiple dispatch? For instance, is there any downside to having

module CategoricalDQN
export CategoricalDQNLearner

function bootstrap_func ... end
function loss_func ... end

function CategoricalDQNLearner(...) # basically the same arguments as DQNLearner
    DQNLearner(..., loss_func = loss_func, bootstrap_func = bootstrap_func, ...)
end
end

module QRDQN
export QRDQNLearner

function bootstrap_func ... end
function loss_func ... end # we would just move quantile_huber_loss here

function QRDQNLearner(...) # basically the same arguments as DQNLearner
    DQNLearner(..., loss_func = loss_func, bootstrap_func = bootstrap_func, ...)
end
end
Then we wouldn't be able to have separate update!(::QRDQNLearner, ...) methods like we have now, for example. Personally, this module design makes more sense to me, because it seems like the most natural way to reuse code.

findmyway commented 2 years ago

In essence, it looks to me like there is lots of repetition among the various DQN implementations and I think there should be more abstraction.

Yeah, it was intentional at first to keep the implementation of each algorithm in a single file. But then we ended up with a lot of duplicate code among all the DQN variants, so some common parts were moved into common.jl. I agree there's still a lot of duplicated code there and we can do more.

I get your idea of adding a bootstrap_func to simplify the current implementations. I think in Julia people would prefer to implement it like this:

struct DQNLearner{B, L}
    bootstrap::B
    loss::L
    # ... other fields
end

const CategoricalDQNLearner = DQNLearner{BootstrapForC51}
const QRDQNLearner = DQNLearner{BootstrapForQR}
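
And regarding the dispatch question: since these aliases are still distinct parametric types, we could keep separate update! methods. A rough sketch, assuming BootstrapForC51 and BootstrapForQR are (say) empty structs defined before the aliases above:

# Dispatch distinguishes the variants through the type parameter.
update!(::DQNLearner, batch) = "generic DQN update"
update!(::CategoricalDQNLearner, batch) = "C51-specific update"
update!(::QRDQNLearner, batch) = "QR-DQN-specific update"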

Though this looks more concise, I have several concerns:

  1. Different variants have really different parameters, so I'm afraid the definition of DQNLearner would be quite general.
  2. The definition of the bootstrap step is not very clear to me. I'm not sure whether each DQN variant can be generalized with such a step of similar parameters.

Maybe the first step of the refactor is to extract the bootstrap part from the existing implementations into common.jl. Then the rest would be relatively easy: we put each implementation into a submodule and generalize them further. What do you think?

harwiltz commented 2 years ago

Firstly, I very much agree with your first step, so I'll get on that!

Different variants have really different parameters, so I'm afraid the definition of DQNLearner would be quite general.

Yeah, I guess that's true. It would be really nice if it were possible to have inheritance of some sort, but as I understand it, that's not possible in Julia. I gotta get out of the OOP mindset a little bit.

The definition of the bootstrap step is not very clear to me. I'm not sure whether each DQN variant can be generalized with such a step of similar parameters.

My idea is basically that each DQN variant (that I know of) has to learn some type of Q-value representation and learns via temporal differences (so... perhaps this is mainly a TD-learning abstraction more than a DQN one). The bootstrap step simply computes the target for whatever the Q-function loss is (i.e., it computes one of the arguments to loss_func).
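
In rough Julia, the generic TD update I have in mind would look something like this (target_q_values, q_values, and the γ and n fields are all hypothetical names):

function update!(learner::AbstractDQNLearner, batch)
    r, t, s′ = batch.reward, batch.terminal, batch.next_state
    q′ = target_q_values(learner, s′)  # hypothetical helper: target-network values
    target = bootstrap_func(learner, r, learner.γ, t, q′, learner.n)
    loss = learner.loss_func(q_values(learner, batch.state), target)
    # ... followed by a gradient step on loss
end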

As a more concrete example, consider the Statistics and Samples in Distributional RL paper. It suggests that all distributional RL algorithms should perform updates of the form

  1. Transform the set of statistics you're modeling into a probability distribution
  2. Sample from that distribution
  3. Transform the samples with some sort of Bellman operator
  4. Extract the new statistics from the resulting distribution

So step 3 would be the bootstrap_func. It should be the case that every distributional algorithm has this form. It also generalizes to non-distributional algorithms, which can be seen as a special case:

  1. Transform the set of statistics (the expected value) to a probability distribution (point mass at the expected value)
  2. Sample from that distribution (trivial)
  3. Transform (standard Bellman update)
  4. Extract new statistics (also trivial)

I think ultimately we can abstract each of these steps (maybe in a similar way to how hooks are implemented?) and save lots of code repetition.
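
To illustrate, here is a very rough sketch of the expected-value special case (all names are made up):

# A toy version of the four-step update for plain DQN; the "statistics"
# here are just the expected Q-value.
to_distribution(q) = [q]                        # 1. point mass at the mean
sample_from(dist) = dist                        # 2. trivial for a point mass
bellman(s, r, γ, t) = r .+ γ .* (1 .- t) .* s   # 3. this is the bootstrap_func
extract_statistics(s) = sum(s) / length(s)      # 4. back to the mean

# Composing the four steps yields the TD target fed to loss_func:
td_target(q′, r, γ, t) =
    extract_statistics(bellman(sample_from(to_distribution(q′)), r, γ, t))

A distributional variant would swap out steps 1, 2, and 4 while keeping the same Bellman/bootstrap pipeline.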

findmyway commented 2 years ago

Thanks for the link and your explanation! Now I have a much better understanding. I think such an abstraction will not only save a lot of code repetition but also help in understanding existing algorithms and even implementing new ones. Let me know if you need any help understanding the existing implementations while refactoring the DQN variants in this package.