So the initial implementation is functional and I added a CartPole notebook to showcase the API, but there are some refactorings left to do:

- Add a `quantile_func` that receives the states and actions and an `is_training` flag to produce the quantiles. This would allow an easy implementation of QR-DQN as well as FQF. It would also make risk-sensitive policies using CVaR distortions possible.
- Generalize `QuantileQ` into a `DistributionalQ` and create `QuantileQ` and `CategoricalQ` subclasses. This will enable the implementation of C51.

I also think that some stubs need to be added, and I still need to write the documentation. @KristianHolsheimer What do you think about my approach and proposed API?
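To make the first refactoring concrete, here is a rough sketch of what such a `quantile_func` could look like (illustrative only, not the PR's actual code; it assumes it runs inside a Haiku transform so that `hk.next_rng_key()` is available):

```python
import jax
import jax.numpy as jnp
import haiku as hk

def quantile_func(S, A, is_training, num_quantiles=32):
    """Sketch: produce the quantile fractions that condition the network.

    S and A would be used by e.g. FQF's fraction-proposal network; in the
    plain IQN case only the fractions themselves are needed.
    """
    if is_training:
        # During training, sample the fractions uniformly (as in IQN).
        quantiles = jax.random.uniform(hk.next_rng_key(), (num_quantiles,))
    else:
        # At evaluation time, use a fixed grid for deterministic output.
        quantiles = jnp.linspace(0.0, 1.0, num_quantiles)
    # A CVaR distortion for risk-sensitive policies would rescale the
    # fractions, e.g. `quantiles * alpha` to focus on the lower tail.
    return quantiles
```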
I fixed the tests and refactored the `QuantileQ` class. I think I will leave the class as is and add the `CategoricalQ` in another PR if I find the time. I will now try to add some stubs and examples for IQN and QR-DQN, and ping you once that is complete.
Thanks so much for putting this together!
> Generalize `QuantileQ` into a `DistributionalQ` and create `QuantileQ` and `CategoricalQ` subclasses. This will enable the implementation of C51.
Have you had a look at `StochasticQ`? I believe that one implements C51.
> What do you think about my approach and proposed API?
I have a question about the design. Would it be possible to implement the quantile inputs internally as in:
```python
def forward(S, is_training):
    if is_training:
        quantiles = jax.random.uniform(hk.next_rng_key(), (num_quantiles,))
    else:
        quantiles = jnp.linspace(0, 1, num_quantiles)
    ...
```
Or would that make it harder to define the loss function, etc.?
Oh, I did not see `StochasticQ`! I will change the `QuantileQ` function to be in line with it. I can then add a `return_quantiles` argument to its forward pass, so that it looks like this:
```python
def __call__(self, s, a=None, return_quantiles=False):
    quantiles = self.quantile_func(s=s, a=a, is_training=False)
    ...
    return (x, quantiles) if return_quantiles else x
```
Would this be ok? Or do you want to circumvent the function types 3 and 4 altogether? That would be possible, but then the forward-pass function has to create and return the quantiles.
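For illustration, usage of the proposed flag might then look like this (hypothetical, assuming a `QuantileQ` instance `q`):

```python
x = q(s, a)                                    # just the quantile values
x, quantiles = q(s, a, return_quantiles=True)  # values plus the fractions
```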
Yeah sorry, I didn't document `StochasticQ` well enough. There's a simple example on FrozenLake here, but that's about it.
If we can do without type 3 and 4, I would definitely prefer that. But I think you're in a better position to judge which option would be cleaner.
One alternative option might be to store the quantiles in the function state, e.g.
```python
import jax
import jax.numpy as jnp
import haiku as hk

NUM_QUANTILES = 51

@hk.transform_with_state
def f(S, is_training):
    if is_training:
        quantiles = jax.random.uniform(hk.next_rng_key(), (NUM_QUANTILES,))
    else:
        quantiles = jnp.linspace(0, 1, NUM_QUANTILES)

    # Expose `quantiles` without changing the return signature.
    hk.set_state('quantiles', quantiles)

    # The main neural net.
    value = hk.Sequential([
        # Something that involves `quantiles` ...
        hk.Linear(13),
    ])
    return value(S)

params, fn_state = f.init(jax.random.PRNGKey(42), S=jnp.ones((3, 7)), is_training=True)
print(fn_state)
# FlatMap({
#   '~': FlatMap({
#     'quantiles': DeviceArray([0.289, 0.827, 0.229, ...], dtype=float32),
#   }),
# })
```
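Continuing that sketch, a caller would read the stored quantiles back out of the state returned by `apply`:

```python
rng = jax.random.PRNGKey(0)
out, new_state = f.apply(params, fn_state, rng, S=jnp.ones((3, 7)), is_training=True)
quantiles = new_state['~']['quantiles']  # shape (NUM_QUANTILES,)
```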
Thanks for the pointer to the example! I like the stochastic value function visualization and will try to add something similar for the quantile version.

I did not know about the function-state API, but I think I prefer the pure-function approach of returning the Q-values and the quantiles. I will remove the additional function types.
@KristianHolsheimer I refactored everything to use the `StochasticQ` class by introducing an `EmpiricalQuantileDist` distribution that is parametrized by the samples from the quantile function and also carries the quantile fractions for the loss computation. The `StochasticQ` class uses quantile regression if no `value_range` is given.

What do you think about this approach?
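As a rough sketch of the idea (the actual implementation in the PR may differ), such a distribution would essentially carry two arrays:

```python
from dataclasses import dataclass
import jax.numpy as jnp

@dataclass
class EmpiricalQuantileDist:
    """Empirical distribution over quantile-function samples (sketch)."""
    values: jnp.ndarray              # sampled quantile values, shape (batch, num_quantiles)
    quantile_fractions: jnp.ndarray  # the fractions tau in [0, 1] that produced them

    def mean(self):
        # The expected value is approximated by averaging the quantile samples.
        return self.values.mean(axis=-1)
```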
Thank you for looking into my contribution!

I will still have to add a stub for the IQN and add some unit tests for the new functionality in `StochasticQ` and `QLearning`. I also need to have a look at the documentation now that the API is finalized. But then this should be good to go.

Do you want me to squash the commits in the end so that the history in the main branch is clean?
> Thank you for looking into my contribution!
No worries and thank you for the well-written PR!
> I will still have to add a stub for the IQN and add some unit tests for the new functionality in `StochasticQ` and `QLearning`. I also need to have a look at the documentation now that the API is finalized. But then this should be good to go.
Okay, shall we include that in this PR as well? Also, you probably noticed this yourself, but just in case: I noticed that the `doc/examples/cartpole/iqn.rst` script wasn't included in `doc/examples/cartpole/index.rst`.
> Do you want me to squash the commits in the end so that the history in the main branch is clean?
Yes, will do.
I added the stub, some documentation, and a test that `StochasticQ` works with both quantile function types. Do you want me to add more tests, or should we leave it as is? Otherwise I think this PR is ready to merge.
No worries! I just fixed the quantile Huber loss docs, because I saw that they were not included in the documentation and there was an error. So now it should really be good to go. :)
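For reference, the quantile Huber loss from QR-DQN/IQN (Dabney et al.) weights a standard Huber loss asymmetrically by the quantile fraction; a minimal JAX sketch (not the PR's exact code):

```python
import jax.numpy as jnp

def quantile_huber_loss(td_errors, tau, kappa=1.0):
    """rho^kappa_tau(u) = |tau - 1{u < 0}| * L_kappa(u) / kappa."""
    abs_u = jnp.abs(td_errors)
    # Huber loss: quadratic below the threshold kappa, linear above it.
    huber = jnp.where(abs_u <= kappa,
                      0.5 * jnp.square(td_errors),
                      kappa * (abs_u - 0.5 * kappa))
    # Asymmetric weight: over- and under-estimation errors are penalized
    # according to the quantile fraction tau.
    weight = jnp.abs(tau - (td_errors < 0.0).astype(td_errors.dtype))
    return weight * huber / kappa
```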
This PR adds a `QuantileQ` class with function types 3 and 4 that accept a number of quantiles together with the state (and action), as well as a `QuantileQLearning` class. The `QuantileQ` function could be merged into the `Q` class, which would simplify the user-facing API. However, some more work needs to be done to incorporate the `QuantileQLearning` class into the `QLearning` class. I just wanted to validate that this is the correct approach to take to implement the IQN.

Some documentation for the quantile Huber loss is still missing, and the notebooks need to be added and tuned.
Closes https://github.com/coax-dev/coax/issues/3