coax-dev / coax

Modular framework for Reinforcement Learning in Python
https://coax.readthedocs.io
MIT License
168 stars 17 forks

Quantile Q-Learning Implementation #4

Closed frederikschubert closed 3 years ago

frederikschubert commented 3 years ago

This PR adds a QuantileQ class with function types 3 and 4, which accept a number of quantiles together with the state (and action), as well as a QuantileQLearning class. The QuantileQ function could be merged into the Q class, which would simplify the user-facing API; however, more work is needed to incorporate the QuantileQLearning class into the QLearning class. I just wanted to validate that this is the correct approach for implementing the IQN.
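For context, a minimal sketch of what such a quantile-aware Q-function forward pass might look like, with the state and quantile fractions as inputs and one Q-value per quantile as output. All names, shapes, and the linear stand-in layers here are illustrative assumptions, not the actual coax API; the cosine embedding follows the IQN paper:

```python
import jax.numpy as jnp

EMBED_DIM = 64  # cosine-embedding size used in the IQN paper (assumed here)

def cosine_embedding(quantiles):
    # IQN-style feature map: phi_j(tau) = cos(pi * j * tau), j = 0..EMBED_DIM-1
    j = jnp.arange(EMBED_DIM)
    return jnp.cos(jnp.pi * j * quantiles[..., None])  # (num_quantiles, EMBED_DIM)

def quantile_q(S, quantiles, w_s, w_tau, w_out):
    # Hypothetical "type 3"-style forward pass: state S plus quantile
    # fractions in, one Q-value per quantile out. Plain linear maps stand
    # in for the real network.
    s_feat = S @ w_s                                  # (batch, EMBED_DIM)
    tau_feat = cosine_embedding(quantiles) @ w_tau    # (num_quantiles, EMBED_DIM)
    h = s_feat[:, None, :] * tau_feat[None, :, :]     # (batch, num_quantiles, EMBED_DIM)
    return jnp.squeeze(h @ w_out, -1)                 # (batch, num_quantiles)
```

The multiplicative combination of state features and quantile features is the interaction used in the IQN paper; a real implementation would replace the linear maps with the actual torso and head networks.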

Some documentation for the quantile Huber loss is still missing, and the notebooks need to be added and tuned.

Closes https://github.com/coax-dev/coax/issues/3

frederikschubert commented 3 years ago

So the initial implementation is functional and I added a CartPole notebook to showcase the API, but there are some refactorings left to do.

I also think that there need to be some stubs added and I still need to write the documentation. @KristianHolsheimer What do you think about my approach and proposed API?

frederikschubert commented 3 years ago

I fixed the tests and refactored the QuantileQ class. I think I will leave the class as is and add the CategoricalQ in another PR if I find the time. I will now try to add some stubs, examples for IQN and QR-DQN and ping you if that is complete.

KristianHolsheimer commented 3 years ago

Thanks so much for putting this together!

> Generalize QuantileQ into a DistributionalQ and create QuantileQ and CategoricalQ subclasses. This will enable the implementation of C51.

Have you had a look at StochasticQ? I believe that one implements C51.
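For reference, the core of the C51 parametrization mentioned here is a softmax distribution over a fixed grid of return atoms, with the scalar Q-value recovered as the expectation. This is a sketch with assumed constants, not coax's actual implementation:

```python
import jax
import jax.numpy as jnp

NUM_ATOMS = 51               # the "51" in C51
V_MIN, V_MAX = -10.0, 10.0   # illustrative value_range

# Fixed support of the return distribution.
atoms = jnp.linspace(V_MIN, V_MAX, NUM_ATOMS)

def categorical_q_value(logits):
    # C51 outputs one logit per atom; the scalar Q-value is the
    # expectation of the atoms under the softmax distribution.
    probs = jax.nn.softmax(logits, axis=-1)
    return probs @ atoms
```

This is also why a value_range is required for the categorical case but not for the quantile case: the atoms are fixed in value space, whereas quantile regression places its support adaptively.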

> What do you think about my approach and proposed API?

I have a question about the design. Would it be possible to implement the quantile inputs internally as in:

def forward(S, is_training):
    if is_training:
        # Sample fresh quantile fractions during training (IQN-style).
        quantiles = jax.random.uniform(hk.next_rng_key(), (num_quantiles,))
    else:
        # Use a fixed, evenly spaced grid for evaluation.
        quantiles = jnp.linspace(0, 1, num_quantiles)
    ...

Or would that make it harder to define the loss function etc?

frederikschubert commented 3 years ago

Oh, I did not see StochasticQ! I will change the QuantileQ function to be in line with it. I can then add a return_quantiles argument to its forward pass, so that it looks like this:

def __call__(self, s, a=None, return_quantiles=False):
    quantiles = self.quantile_func(s=s, a=a, is_training=False)
    ...
    return (x, quantiles) if return_quantiles else x

Would this be ok? Or do you want to circumvent function types 3 and 4 altogether? That would be possible, but then the forward-pass function would have to create and return the quantiles itself.
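Such a forward pass, where the function samples its own quantile fractions and returns them alongside the values, might look like the following. This is a sketch under assumptions: the names and the linear stand-in network are illustrative, not coax code:

```python
import jax
import jax.numpy as jnp

NUM_QUANTILES = 32  # illustrative

def forward(rng, S, w, is_training):
    if is_training:
        # Sample fresh quantile fractions on each call (IQN-style).
        quantiles = jax.random.uniform(rng, (NUM_QUANTILES,))
    else:
        # Fixed midpoint grid for deterministic evaluation.
        quantiles = (jnp.arange(NUM_QUANTILES) + 0.5) / NUM_QUANTILES
    # A linear map stands in for the real quantile network.
    x = (S @ w)[:, None] * quantiles[None, :]   # (batch, NUM_QUANTILES)
    return x, quantiles
```

Returning the quantiles keeps the function pure: the loss can then weight the TD errors by the same fractions that produced the value estimates, without any hidden state.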

KristianHolsheimer commented 3 years ago

Yeah sorry, I didn't document StochasticQ well enough. There's a simple example on frozen lake here, but that's about it.

If we can do without type 3 and 4, I would definitely prefer that. But I think you're in a better position to judge which option would be cleaner.

One alternative option might be to store the quantiles in the function state, e.g.

import jax
import jax.numpy as jnp
import haiku as hk

NUM_QUANTILES = 51

@hk.transform_with_state
def f(S, is_training):
    if is_training:
        quantiles = jax.random.uniform(hk.next_rng_key(), (NUM_QUANTILES,))
    else:
        quantiles = jnp.linspace(0, 1, NUM_QUANTILES)

    # Expose `quantiles` without changing the return signature.
    hk.set_state('quantiles', quantiles)  

    # The main neural net.
    value = hk.Sequential([
        # Something that involves `quantiles` ...
        hk.Linear(13),
    ])

    return value(S)

params, fn_state = f.init(jax.random.PRNGKey(42), S=jnp.ones((3, 7)), is_training=True)

print(fn_state)
# FlatMap({
#   '~': FlatMap({
#          'quantiles': DeviceArray([0.289, 0.827, 0.229, ...], dtype=float32),
#        }),
# })

frederikschubert commented 3 years ago

Thanks for the pointer to the example! I like the stochastic value function visualization and will try to add something similar for the quantile version.

I did not know about the function-state API, but I think I prefer the pure-function approach of returning the Q-values and the quantiles. I will remove the additional function types.

frederikschubert commented 3 years ago

@KristianHolsheimer I refactored everything to use the StochasticQ class by introducing an EmpiricalQuantileDist distribution that is parametrized by the samples from the quantile function and also carries the quantile fractions for the loss computation. StochasticQ now uses quantile regression if no value_range is given.

What do you think about this approach?
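For the loss side of this refactoring, the quantile Huber loss referenced earlier combines the Huber loss with an asymmetric weight derived from the quantile fractions. A minimal sketch (parameter names and the kappa default are assumptions, not the coax signature):

```python
import jax.numpy as jnp

def quantile_huber_loss(td_errors, quantile_fractions, kappa=1.0):
    # Huber part: quadratic for |u| <= kappa, linear beyond.
    abs_u = jnp.abs(td_errors)
    huber = jnp.where(
        abs_u <= kappa,
        0.5 * td_errors ** 2,
        kappa * (abs_u - 0.5 * kappa),
    )
    # Asymmetric quantile weight: |tau - 1{u < 0}| penalizes over- and
    # under-estimation differently per quantile fraction tau.
    weight = jnp.abs(quantile_fractions - (td_errors < 0).astype(jnp.float32))
    return jnp.mean(weight * huber / kappa)
```

Carrying the fractions on the distribution object, as described above, is what makes this weight computable inside the loss.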

frederikschubert commented 3 years ago

Thank you for looking into my contribution!

I will still have to add a stub for the IQN and add some unit tests for the new functionality in StochasticQ and QLearning. Also, I need to take a look at the documentation now that the API is finalized. But then this should be good to go.

Do you want me to squash the commits in the end so that the history in the main branch is clean?

KristianHolsheimer commented 3 years ago

> Thank you for looking into my contribution!

No worries, and thank you for the well-written PR!

> I will still have to add a stub for the IQN and add some unit tests for the new functionality in StochasticQ and QLearning. Also, I need to take a look at the documentation now that the API is finalized. But then this should be good to go.

Okay, shall we include that in this PR as well? Also, you probably noticed this yourself, but just in case... I noticed that the doc/examples/cartpole/iqn.rst script wasn't included in doc/examples/cartpole/index.rst.

> Do you want me to squash the commits in the end so that the history in the main branch is clean?

Yes, will do.

frederikschubert commented 3 years ago

I added the stub, some documentation and a test that StochasticQ works with both quantile function types. Do you want me to add more tests or should we leave it as is? Otherwise I think this PR is ready to merge.

frederikschubert commented 3 years ago

No worries! I just fixed the quantile huber loss docs, because I saw that they were not included in the documentation and there was an error. So now it should be really good to go. :)