instadeepai / Mava

🦁 A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX
Apache License 2.0

[FEATURE] Use numpyro #1098

Open sash-a opened 1 month ago

sash-a commented 1 month ago

Feature

Switch distribution libraries...again!

The problem with tfp is that if we want to run non-shared parameters we need to vmap the apply function (over params and observations). This means it would have to return a jax array of tfp distributions, and since a tfp distribution is not a jax type this cannot work. But numpyro's distribution objects are jax types and are vmappable! So we can just use them as a drop-in replacement. Here's a proof of concept:

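import jax.numpy as jnp
import jax
import flax.linen as nn
import numpyro

class Network(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Return a distribution object directly from the network
        return numpyro.distributions.Categorical(logits=nn.Dense(5)(x))

n_agents = 4

# One PRNG key per agent, so each agent gets its own parameters
key = jax.random.PRNGKey(3)
keys = jax.random.split(key, n_agents)

# The same observation for every agent, stacked to shape (n_agents, 5)
x = jnp.arange(5, dtype=float)
xs = x[jnp.newaxis].repeat(n_agents, axis=0)

net = Network()
# vmap init so the params have a leading n_agents axis (non-shared parameters)
params = jax.vmap(net.init)(keys, xs)

# vmap apply over params and observations; the output is a batched numpyro distribution
dist = jax.jit(jax.vmap(net.apply))(params, xs)
action = dist.sample(key, (n_agents,))  # Array([2, 3, 3, 2], dtype=int32)
dist.log_prob(action)  # Array([-1.9792106 , -1.3051271 , -0.10164165, -2.1678243 ], dtype=float32)
dist.entropy()  # Array([0.52041006, 1.3267    , 0.37860775, 0.91515315], dtype=float32)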

Replacing numpyro.distributions.Categorical with a tfp.Categorical gives the following error:

ValueError: Attempt to convert a value (<object object at 0x7fa1a191bfa0>) with an unsupported type (<class 'object'>) to a Tensor.

This happens because tfp distributions are objects which are not jax types.
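For anyone wondering what "being a jax type" means in practice, here is a minimal sketch using only JAX. vmap can only return pytrees whose leaves are arrays, so an object has to be registered as a pytree node before it can cross the vmap boundary. PlainCategorical and PytreeCategorical are hypothetical toy classes written for this illustration; they are not how numpyro or tfp implement their distributions, which is considerably more involved.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class PlainCategorical:
    """Hypothetical toy class that JAX knows nothing about."""

    def __init__(self, logits):
        self.logits = logits


@register_pytree_node_class
class PytreeCategorical:
    """Hypothetical toy distribution registered as a pytree node."""

    def __init__(self, logits):
        self.logits = logits

    def tree_flatten(self):
        # Children are the array leaves; this toy class has no static aux data.
        return (self.logits,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def log_prob(self, value):
        log_p = jax.nn.log_softmax(self.logits, axis=-1)
        return jnp.take_along_axis(log_p, value[..., None], axis=-1)[..., 0]


def make_pytree_dist(logits):
    return PytreeCategorical(logits)


def make_plain_dist(logits):
    return PlainCategorical(logits)


n_agents, n_actions = 4, 5
logits = jnp.arange(n_agents * n_actions, dtype=jnp.float32).reshape(n_agents, n_actions)

# The registered class can be returned from vmap: its leaves are stacked per agent.
dist = jax.vmap(make_pytree_dist)(logits)
print(dist.logits.shape)  # (4, 5)

actions = jnp.zeros(n_agents, dtype=jnp.int32)
print(dist.log_prob(actions).shape)  # (4,)

# The unregistered class is a single opaque leaf, which vmap cannot return.
plain_rejected = False
try:
    jax.vmap(make_plain_dist)(logits)
except Exception:  # exact exception type and message depend on the JAX version
    plain_rejected = True
print(plain_rejected)  # True
```

The registered class flattens to its logits array, so vmap can stack one set of logits per agent and rebuild the object afterwards. The unregistered one is treated as a single leaf that is not an array, so JAX has nothing it can batch.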