google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

[Question] Best way to implement stochastic reinforcement learning actors? #179

Closed mmcenta closed 4 years ago

mmcenta commented 4 years ago

Hello!

First of all, thanks for this project; it is a lifesaver! I wanted to get familiar with JAX, so I decided to implement a few deep reinforcement learning algorithms as a side project. I initially approached the problem by subclassing flax.nn.Module as follows:

import flax.nn as nn
import jax.numpy as jnp
import jax.random as random
from jax.nn import one_hot

# _MLP is the author's own MLP helper module (its definition is not shown here).
class MLPCategoricalActor(nn.Module):
    def apply(self, obs, act, action_space=None, rng=None,
              hidden_sizes=(64, 64), activation_fn=nn.tanh, output_fn=None):
        assert action_space is not None, "Action space must be specified."
        if rng is None:
            rng = nn.make_rng()  # fall back to the key from the nn.stochastic context
        act_dim = action_space.n
        logits = _MLP(obs, sizes=list(hidden_sizes) + [act_dim], activation_fn=activation_fn)
        pi = random.categorical(rng, logits)  # sampled actions
        logp_all = nn.log_softmax(logits)     # log-probabilities of every action
        # Log-probability of the given actions and of the sampled actions.
        logp = jnp.multiply(one_hot(act, act_dim), logp_all).sum(axis=1)
        logp_pi = jnp.multiply(one_hot(pi, act_dim), logp_all).sum(axis=1)
        return pi, logp, logp_pi
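For comparison, the same computation can be written as a plain JAX function that takes the PRNG key as an explicit argument, which avoids any clash over a parameter named rng. This is a minimal sketch, not Flax API: categorical_actor and params are hypothetical names, and a single linear layer (params = (W, b)) stands in for the _MLP helper, whose definition the issue does not show.

```python
import jax
import jax.numpy as jnp

def categorical_actor(params, obs, act, key):
    """Sample actions and compute log-probs for a categorical policy.

    Hypothetical sketch: one linear layer stands in for the MLP in the issue.
    """
    W, b = params
    logits = obs @ W + b                      # (batch, act_dim)
    act_dim = logits.shape[-1]
    pi = jax.random.categorical(key, logits)  # sampled actions, shape (batch,)
    logp_all = jax.nn.log_softmax(logits)     # log-probs of every action
    # Log-prob of the given actions and of the sampled actions.
    logp = (jax.nn.one_hot(act, act_dim) * logp_all).sum(axis=1)
    logp_pi = (jax.nn.one_hot(pi, act_dim) * logp_all).sum(axis=1)
    return pi, logp, logp_pi

# Usage: the caller owns the key and splits it before each use.
key = jax.random.PRNGKey(0)
key, init_key, sample_key = jax.random.split(key, 3)
obs_dim, act_dim, batch = 4, 3, 5
params = (jax.random.normal(init_key, (obs_dim, act_dim)), jnp.zeros(act_dim))
obs = jnp.ones((batch, obs_dim))
act = jnp.zeros(batch, dtype=jnp.int32)
pi, logp, logp_pi = categorical_actor(params, obs, act, sample_key)
```

Because the function is pure, calling it twice with the same sample_key returns the same actions, which is what makes runs reproducible.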

I quickly ran into an error about a duplicate parameter 'rng'. Digging into the Flax code, I discovered that (1) dropout is not implemented as a module, as I had believed, and (2) the ModuleFrame has an rng parameter that I apparently can't access. I came up with three solutions:

  1. Get rng from the nn.stochastic context, but that would require wrapping the entire training function with it, which seems a little weird to me.

  2. Use the same solution as in the VAE example and pass rng each time as a positional argument.

  3. Mix both solutions: try to get rng from a kwarg and, if that fails, fall back to the context. This could cause problems if someone sets the kwarg with a call to partial...
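Option 2 amounts to threading the key through every call and splitting it before each use. A minimal sketch, assuming a fixed logits vector in place of a real policy and a hypothetical helper rollout_actions (not part of Flax), shows why this is reproducible: the same seed always yields the same action sequence.

```python
import jax
import jax.numpy as jnp

def rollout_actions(key, logits, num_steps):
    """Sample num_steps actions, using a fresh subkey for every step.

    Hypothetical helper: fixed logits stand in for a policy network.
    """
    actions = []
    for _ in range(num_steps):
        # Never reuse a key: split off a subkey and carry the rest forward.
        key, subkey = jax.random.split(key)
        actions.append(int(jax.random.categorical(subkey, logits)))
    return actions, key  # return the advanced key so the caller can continue

logits = jnp.log(jnp.array([0.2, 0.3, 0.5]))
run_a, _ = rollout_actions(jax.random.PRNGKey(42), logits, num_steps=10)
run_b, _ = rollout_actions(jax.random.PRNGKey(42), logits, num_steps=10)
```

The cost of this style is that every function on the path from the training loop to the sampling call needs a key argument, which is the plumbing boilerplate the question is trying to avoid.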

How would you go about this? My main concerns are code reusability and reproducibility.

levskaya commented 4 years ago

Hi!

Basically, something like (1) is the canonical way we do it! We use a with nn.stochastic(prng_key): context around the model evaluation inside the training-step function, with a top-level PRNG key fed into the training-step function from a split in the outer training loop. The important thing is that the training-step function is jitted in its entirety; the one thing not to do is to jit across an nn.stochastic context. Provided it's used inside a jit and the PRNG keys are fed in like any other function argument, there should be no trouble using it. Using the stochastic context can save a lot of PRNG-key plumbing boilerplate in models.
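The nn.stochastic context belongs to the old flax.nn API, so it is not reproduced here, but the surrounding structure can be sketched in plain JAX: split a top-level key in the training loop and pass the per-step key into a fully jitted training step as an ordinary argument. In this sketch the model (a linear layer with input dropout), the loss, and the names loss_fn, train_step, and dropout_rate are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

dropout_rate = 0.1  # hypothetical hyperparameter

def loss_fn(params, x, y, key):
    # Inverted dropout on the inputs, driven by the key passed in.
    keep = jax.random.bernoulli(key, 1.0 - dropout_rate, x.shape)
    x = jnp.where(keep, x / (1.0 - dropout_rate), 0.0)
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit  # jit the whole step; the key is just another argument
def train_step(params, x, y, key):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y, key)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss

key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
params = {"w": jax.random.normal(init_key, (3, 1)), "b": jnp.zeros(1)}
x = jnp.ones((8, 3))
y = jnp.ones((8, 1))
losses = []
for step in range(5):
    # Split outside the jitted function and feed the fresh key in.
    key, step_key = jax.random.split(key)
    params, loss = train_step(params, x, y, step_key)
    losses.append(float(loss))
```

Since the key enters train_step as an argument rather than being captured from an outer scope, each step sees new randomness while the whole run stays reproducible from the initial seed.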

I've typed up a quick demo of the canonical way in a colab at https://colab.research.google.com/drive/1eDXEVd8NPXgaSwn7jEMxsZTDVHNUtHYK

Let me know if that helps or if anything remains unclear!

mmcenta commented 4 years ago

Thank you for the amazing answer; I will get to work on my project right away! Everything is clear. You even covered passing the PRNG key itself when initializing a Module following (3), which is something I had problems with.

I'm closing the issue, thanks for the help!