google / dopamine

Dopamine is a research framework for fast prototyping of reinforcement learning algorithms.
https://github.com/google/dopamine
Apache License 2.0

NoisyNets implementation issues #189

Open · pseudo-rnd-thoughts opened this issue 2 years ago

pseudo-rnd-thoughts commented 2 years ago

I'm implementing my own RL framework in JAX to better understand RL algorithms, and I have found your code very helpful.

Looking at the NoisyNets implementation, on lines 316 and 317 (https://github.com/google/dopamine/blob/master/dopamine/jax/networks.py), the same rng_key is used each time noise is generated, meaning that no 'new' noise is generated each time an input is passed to the layer. In effect, I think the layer just applies a fixed linear transform.
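As a quick aside on JAX randomness (a standalone sketch, not Dopamine code): a PRNGKey carries no hidden state, so reusing the same key always reproduces the same sample, which is why noise drawn from a key fixed at construction can never change between calls.

import jax

key = jax.random.PRNGKey(42)
# The same key always yields the same sample -- there is no hidden RNG state.
print(jax.random.normal(key, (3,)))
print(jax.random.normal(key, (3,)))   # identical to the line above
# Fresh randomness requires splitting into new keys:
key1, key2 = jax.random.split(key)
print(jax.random.normal(key1, (3,)))  # differs from key2's sample
print(jax.random.normal(key2, (3,)))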

Here is a short test example demonstrating the issue:

import jax
import numpy as np

from dopamine.jax.networks import NoisyNetwork

if __name__ == '__main__':
    rng = jax.random.PRNGKey(1)
    rng, rng_net_def, rng_net_param = jax.random.split(rng, num=3)

    net_def = NoisyNetwork(rng_key=rng_net_def, eval_mode=False)
    net_params = net_def.init(rng_net_param, x=np.zeros(10), features=3)

    state = np.random.random(10)
    # Both calls print the same output: the layer's rng_key is fixed at
    # construction, so the same noise is applied on every forward pass.
    print(net_def.apply(net_params, x=state, features=3))
    print(net_def.apply(net_params, x=state, features=3))

If this is indeed an issue, here is the code I implemented for my framework:

from typing import Sequence

import jax
import numpy as onp
import jax.numpy as jnp
from flax import linen as nn

class NoisyDense(nn.Module):
    features: int

    use_bias: bool = True

    @staticmethod
    @jax.jit
    def _f(x: jnp.ndarray) -> jnp.ndarray:
        # See (10) and (11) in Fortunato et al. (2018).
        return jnp.multiply(jnp.sign(x), jnp.power(jnp.abs(x), 0.5))

    @nn.compact
    def __call__(self, inputs: onp.ndarray, eval_mode: bool = True, rng: jnp.ndarray = None) -> jnp.ndarray:
        if eval_mode:  # Turn off noise during evaluation
            w_epsilon = jnp.zeros(shape=(inputs.shape[0], self.features), dtype=onp.float32)
            b_epsilon = jnp.zeros(shape=(self.features,), dtype=onp.float32)
        else:  # Factored gaussian noise in (10) and (11) in Fortunato et al. (2018).
            p_key, q_key = jax.random.split(rng)
            p, q = jax.random.normal(p_key, [inputs.shape[0], 1]), jax.random.normal(q_key, [1, self.features])
            f_p, f_q = self._f(p), self._f(q)
            w_epsilon, b_epsilon = f_p * f_q, jnp.squeeze(f_q)  # outer product via broadcasting

        def _mu_init(key: jnp.ndarray, shape: Sequence[int]):
            # Initialization of mean noise parameters (Section 3.2)
            mean = 1 / jnp.power(inputs.shape[0], 0.5)
            return jax.random.uniform(key, minval=-mean, maxval=mean, shape=shape)

        def _sigma_init(_key: jnp.ndarray, shape: Sequence[int], dtype=jnp.float32):
            # Initialization of sigma noise parameters (Section 3.2)
            return jnp.ones(shape, dtype) * (0.1 / onp.sqrt(inputs.shape[0]))

        # See (8) and (9) in Fortunato et al. (2018) for output computation.
        w_mu = self.param('kernel_mu', _mu_init, (inputs.shape[0], self.features))
        w_sigma = self.param('kernel_sigma', _sigma_init, (inputs.shape[0], self.features))
        out = jnp.matmul(inputs, w_mu + jnp.multiply(w_sigma, w_epsilon))

        if self.use_bias:
            b_mu = self.param('bias_mu', _mu_init, (self.features,))
            b_sigma = self.param('bias_sigma', _sigma_init, (self.features,))
            out = out + b_mu + jnp.multiply(b_sigma, b_epsilon)
        return out

Here is some similar test code:

if __name__ == '__main__':
    rng = jax.random.PRNGKey(1)
    rng, rng_net_def, rng_net_param = jax.random.split(rng, num=3)

    net_def = NoisyDense(features=2)
    net_params = net_def.init(rng_net_param, onp.zeros(10))

    state = onp.random.random(10)
    # eval_mode defaults to True, so the first call applies no noise; the next
    # two calls use different rng keys and therefore produce different outputs.
    print(net_def.apply(net_params, inputs=state))
    print(net_def.apply(net_params, inputs=state, eval_mode=False, rng=rng_net_def))
    print(net_def.apply(net_params, inputs=state, eval_mode=False, rng=rng))

I would have submitted this as a pull request, but I noticed that you are not accepting merges.

young-geng commented 2 years ago

I also realized that this might be an issue. If we want to resample noise, we should either explicitly pass in a new rng every time or use self.make_rng to ensure that RNGs are split correctly.
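To make the self.make_rng suggestion concrete, here is a minimal sketch (not Dopamine's code; the layer name, initialisers and the 'noise' rng collection are made up for illustration) where fresh factored noise is drawn on every apply from an rng stream supplied by the caller:

import jax
import jax.numpy as jnp
from flax import linen as nn

class MakeRngNoisyDense(nn.Module):
    """Hypothetical noisy dense layer that resamples noise via self.make_rng."""
    features: int

    @nn.compact
    def __call__(self, x, eval_mode: bool = False):
        def _f(v):
            return jnp.sign(v) * jnp.sqrt(jnp.abs(v))

        in_features = x.shape[0]
        bound = 1.0 / jnp.sqrt(in_features)
        w_mu = self.param('kernel_mu',
                          lambda k, s: jax.random.uniform(k, s, minval=-bound, maxval=bound),
                          (in_features, self.features))
        w_sigma = self.param('kernel_sigma',
                             lambda k, s: jnp.full(s, 0.1 / jnp.sqrt(in_features)),
                             (in_features, self.features))
        if eval_mode:
            w_epsilon = jnp.zeros((in_features, self.features))
        else:
            # Fresh factored noise on every call, drawn from the 'noise' rng stream.
            p_key, q_key = jax.random.split(self.make_rng('noise'))
            f_p = _f(jax.random.normal(p_key, (in_features, 1)))
            f_q = _f(jax.random.normal(q_key, (1, self.features)))
            w_epsilon = f_p * f_q
        return x @ (w_mu + w_sigma * w_epsilon)

if __name__ == '__main__':
    rng = jax.random.PRNGKey(0)
    params_key, noise_key, apply_key_1, apply_key_2 = jax.random.split(rng, num=4)

    net_def = MakeRngNoisyDense(features=2)
    net_params = net_def.init({'params': params_key, 'noise': noise_key}, jnp.ones(5))

    # Different 'noise' rngs at apply time give different (resampled) outputs.
    print(net_def.apply(net_params, jnp.ones(5), rngs={'noise': apply_key_1}))
    print(net_def.apply(net_params, jnp.ones(5), rngs={'noise': apply_key_2}))

With this pattern the caller controls resampling by passing a fresh rngs={'noise': ...} on each apply, instead of baking a single key into the module.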

pseudo-rnd-thoughts commented 2 years ago

Flax linen module attributes cannot be updated after construction, so the only way to get "new" random noise is to pass the PRNG key in as an argument to the call, as I have done in my example code.

agarwl commented 2 years ago

Edit: I understood the original comment incorrectly -- it was pointing out the correlated noise in lines 316 & 317. It's unclear how much impact this has on performance, but we will fix it -- thanks for pointing it out! Also, this should fix it:

rng_p, rng_q = jax.random.split(self.rng_key, num=2)
p = NoisyNetwork.sample_noise(rng_p, [x.shape[0], 1])
q = NoisyNetwork.sample_noise(rng_q, [1, features])

I am not sure if this is a bug -- as @young-geng mentioned, if we want to resample noise, then we need to pass an explicit rng every time, as is done in the FullRainbowNetwork here. That said, this does seem like a documentation issue about how we expect NoisyNets to work. @psc-g for further visibility

Here's a simplified example to verify that explicitly passing rng works:

import time

import jax
import jax.numpy as jnp
from flax import linen as nn

from dopamine.jax.networks import NoisyNetwork


class DummyNetwork(nn.Module):
  """Dummy network for testing NoisyNets."""

  @nn.compact
  def __call__(self, x, eval_mode=False, key=None):
    if key is None:
      key = jax.random.PRNGKey(int(time.time() * 1e6))
    return NoisyNetwork(rng_key=key, eval_mode=eval_mode)(x, features=2)

def create_noisy_net_and_eval(num_runs=5):
  network_def = DummyNetwork()
  x = jnp.ones(5)
  rng = jax.random.PRNGKey(0)
  rng1, rng = jax.random.split(rng, 2)
  params = network_def.init(rng1, x=x)
  for i in range(num_runs):
    rng1, rng = jax.random.split(rng)
    print(f'rng{i}', network_def.apply(params, x=x, key=rng1))
>> create_noisy_net_and_eval()
rng0 [ 0.49825954 -0.5264382 ]
rng1 [ 0.3296632  -0.56998575]
rng2 [ 0.5706229  -0.42372862]
rng3 [ 0.5419281  -0.47531918]
rng4 [ 0.52439386 -0.46529555]
psc-g commented 2 years ago

hi, thanks for raising this! i agree with what rishabh pointed out. i believe that once the rngs used for p and q are uncorrelated, it is working as expected (i.e. it works even if a new rng is not passed in every time)

pseudo-rnd-thoughts commented 2 years ago

@agarwl Thanks, I hadn't spotted that the FullRainbowNetwork implementation passes a new rng key to the noisy network each time, so you are correct. With the modification that you propose, the noisy network works as expected.

But since eval_mode and rng_key are attributes of the network, the API is potentially misleading: these are really values that need to be passed to the call function every time. Conversely, features, use_bias and kernel_init should not be modified after init. This is the reason I moved these variables from the init to the call, and vice versa, in my implementation.

@psc-g I may be wrong, but I think a new rng should be passed every time (when eval_mode = False): if new noise is not added on each call, then all that is happening is that a fixed linear transformation is applied to the inputs. In my eyes, this defeats the point of the noisy network heuristic, which is to both increase the stability/resilience of the network and increase the "observations" seen by the network.
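To spell that out with a standalone sketch (toy numbers, not code from either framework): if epsilon is drawn from a key that is fixed at construction, the layer is equivalent to a plain dense layer with a constant effective kernel.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)                       # fixed once, never split again
w_mu, w_sigma = jnp.ones((4, 2)), 0.1 * jnp.ones((4, 2))

def noisy_forward(x):
    # Noise is "resampled" here, but always from the same key, so it never changes.
    w_epsilon = jax.random.normal(key, (4, 2))
    return x @ (w_mu + w_sigma * w_epsilon)

w_effective = w_mu + w_sigma * jax.random.normal(key, (4, 2))
x = jnp.arange(4.0)
print(jnp.allclose(noisy_forward(x), x @ w_effective))   # True: just a fixed dense layer
print(jnp.allclose(noisy_forward(x), noisy_forward(x)))  # True: identical on every call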

young-geng commented 2 years ago

Now I see that it passes in a new RNG key every time, so I believe I was wrong about the noise not being resampled, and the implementation should be correct. Sorry for the confusion.