patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Unstable PPO? #596

Open stergiosba opened 8 months ago

stergiosba commented 8 months ago

Hello Patrick,

I am implementing the PPO algorithm for a custom environment. I first wanted to test things out on a standard example, so I chose CartPole-v1 as implemented in gymnax.

I compared the Flax implementation from gymnax-blines, which is stable, tested, and solves many environments, with an Equinox-based solution. Essentially, I removed the flax.TrainState and the Flax NN model and replaced them with an eqx.Module-based model.

The code runs without errors, but something is silently going wrong and I don't know what. Specifically, even though the agent is learning, the following "problems" remain:

  1. The Flax implementation converges to the optimum much faster with exactly the same configuration parameters (learning rate, clipping ratio, etc.) and the same agent model (layers/initialization).
  2. The Flax implementation stays at the optimum, while the Equinox agent is unstable: it gets close to the maximum but oscillates. For instance, the maximum return in CartPole-v1 is 500, and the Equinox agent can go from 499 to 300 to 150 and so on in consecutive training epochs, even late in training. For the same training length (number of epochs), the Flax implementation locks in at 500 from some point onwards.

I checked how the randomness progresses, and I get exactly the same keys as the Flax implementation, since I start from the same seed and split the keys at exactly the same places.

Could it be something about how the gradients update the model parameters? My input is batched, so I know from the docs that I have to use jax.vmap. The problem is that if I use jax.vmap and then define the optimizer state as:

opt_state = optimizer.init(eqx.filter(equinox_model, eqx.is_array))

I then get None in the gradients, because the output of jax.vmap is of type types.FunctionType, which eqx.is_array does not treat as an array. This pushed me to use eqx.filter_vmap and to get the gradients with eqx.filter_value_and_grad, but am I messing things up somewhere?

FYI, the model is defined as:

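from typing import List

import equinox as eqx
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import tensorflow_probability.substrates.jax as tfp

class Agent(eqx.Module):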
    critic: List
    actor: List

    def __init__(self, env, env_params, key):
        obs_shape = env.observation_space(env_params).shape
        keys = jrandom.split(key, 6)

        # critic similar to actor but with out_dim=1... 
        # keys[0:2] are used in the critic layers.

        self.actor = [
            eqx.filter_vmap(
                eqx.nn.Linear(
                    jnp.array(obs_shape).prod(),
                    64,
                    key=keys[3],
                ),
                in_axes=0,
            ),
            jnn.relu,
            eqx.filter_vmap(
                eqx.nn.Linear(
                    64, 
                    64, 
                    key=keys[4]
                ), 
                in_axes=0),
            jnn.relu,
            eqx.filter_vmap(
                eqx.nn.Linear(
                    64, 
                    env.num_actions, 
                    key=keys[5]
                ),
                in_axes=0,
            ),
        ]

    def get_action_and_value(self, x):
        x_a = x
        x_v = x
        for layer in self.actor:
            x_a = layer(x_a)

        for layer in self.critic:
            x_v = layer(x_v)

        pi = tfp.distributions.Categorical(logits=x_a)

        return x_v, pi

Any suggestions would be greatly appreciated. Thanks!

lockwo commented 8 months ago

Hard to say exactly what the issue is from this description (if you have an MVC, that would be helpful). I usually define the neural network for a single input and then vmap outside of it (instead of internally), but that's a smaller point. You could also check out: https://github.com/patrick-kidger/rl-test/blob/master/src/algorithms.py#L61.

Not super helpful in the near term, but I am working on https://github.com/sotetsuk/pgx/issues/1059 which will result in a stable equinox PPO version (but the PR for that issue won't be done in the near term).

stergiosba commented 8 months ago

Hey @lockwo. Thanks for the comment.

I have seen the PPO algo that Patrick wrote in the past as well as other implementations like CleanRL.

I reported this because I thought there was an internal issue with Equinox. Of course, I realize that most probably my revisions, not Equinox, are what broke the algorithm, but I wanted a second opinion.

I am still working on the issue and will update with new info as I go.

patrick-kidger commented 8 months ago

> Hard to say exactly what the issue is from this description (if you have an MVC, that would be helpful). I usually define the neural network for a single input and then vmap outside of it (instead of internally), but that's a smaller point.

I would note that this design pattern also means you don't need to worry about the gradient issue you described: "I then get None in the gradients because the output of jax.vmap is of type: types.FunctionType and that is not detectable by eqx.is_array." (But if you do want to use this pattern, then using filter_vmap as you have is indeed the appropriate solution.)

One possible difference between the implementations may be that Equinox and Flax use different initialisations for the same seed (e.g. sampling from a normal vs a uniform distribution).

stergiosba commented 8 months ago

> One possible difference between the implementations may be that Equinox and Flax use different initialisations for the same seed (e.g. sampling from a normal vs a uniform distribution).

I have ensured that both models start with the same initialization and all PPO parameters are of course the same.

What I did:

What is baffling is that even though the first batch of data is the same, the gradients begin to diverge ever so slightly. After the first batch, I cannot keep the inputs identical, because the two models choose different actions and therefore receive different observations, etc.
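Small divergences like this are easier to quantify with a per-leaf comparison. A minimal sketch, assuming both sets of gradients have first been converted to pytrees with the same structure (e.g. Flax kernels transposed to the Equinox layout); `max_abs_diff` and the toy gradient dicts are hypothetical.

```python
import jax
import jax.numpy as jnp


def max_abs_diff(tree_a, tree_b):
    """Largest elementwise |a - b| across two pytrees with identical structure."""
    diffs = jax.tree_util.tree_map(lambda a, b: jnp.max(jnp.abs(a - b)), tree_a, tree_b)
    return max(float(d) for d in jax.tree_util.tree_leaves(diffs))


grads_a = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([0.5])}
grads_b = {"w": jnp.array([1.0, 2.0 + 1e-6]), "b": jnp.array([0.5])}
print(max_abs_diff(grads_a, grads_b))  # on the order of float32 rounding error
```

`None` entries from `eqx.filter_grad` are skipped by `tree_map`, as long as both trees have them in the same places.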

The result is again that the Equinox model oscillates around the optimal and then makes huge dips in performance while the Flax model converges.

I also tried enabling x64 in the JAX config; still the same.
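For reference, enabling 64-bit mode is a one-line config change. A minimal sketch; it is best done at startup, since arrays created before the update keep their 32-bit dtype.

```python
import jax
import jax.numpy as jnp

# Enable double precision for arrays created from here on.
jax.config.update("jax_enable_x64", True)

print(jnp.ones(3).dtype)  # float64
```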

Any other thoughts would be greatly appreciated. Still looking into this.

stergiosba commented 8 months ago

I tried defining the model for a single observation and vmapping outside. I get the same performance as before (exactly the same, to the last decimal place). The model is the following:

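import equinox as eqx
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom

class Agent(eqx.Module):
    critic: eqx.nn.Sequential
    actor: eqx.nn.Sequential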

    def __init__(self, env, env_params, key):
        obs_shape = env.observation_space(env_params).shape
        keys = jrandom.split(key, 6)

        self.critic = eqx.nn.Sequential(
            [
                eqx.nn.Linear(
                    jnp.array(obs_shape).prod(),
                    64,
                    key=keys[0],
                ),
                eqx.nn.Lambda(jnn.relu),
                eqx.nn.Linear(64, 64, key=keys[1]),
                eqx.nn.Lambda(jnn.relu),
                eqx.nn.Linear(64, 1, key=keys[2]),
            ]
        )

        self.actor = eqx.nn.Sequential(
            [
                eqx.nn.Linear(
                    jnp.array(obs_shape).prod(),
                    64,
                    key=keys[3],
                ),
                eqx.nn.Lambda(jnn.relu),
                eqx.nn.Linear(64, 64, key=keys[4]),
                eqx.nn.Lambda(jnn.relu),
                eqx.nn.Linear(64, env.num_actions, key=keys[5]),
            ]
        )

    @eqx.filter_jit
    def __call__(self, x):
        return self.critic(x), self.actor(x)

If you meant something other than these two Sequential models, please explain further. Thanks again.

lockwo commented 8 months ago

That is what I meant. It was mostly a design pattern note; I wouldn’t expect any numerical differences.

stergiosba commented 8 months ago

> That is what I meant. It was mostly a design pattern note; I wouldn’t expect any numerical differences.

Yeah, I was just being paranoid and checked everything.

I managed to make it work and it is stable when I do the following hacky thing:

I create a namedtuple train state as follows:

from collections import namedtuple

eqxTrainState = namedtuple("eqxTrainState", ["params", "static", "tx", "opt_state"])

Then, basically, I carry this train state around everywhere instead of the Equinox model. When needed, I reconstruct the model from params and static and then do a forward pass. I will now test whether I can do this without the partition.

I attached the following image so you can get an idea of what I had been seeing so far (green) and what the performance is now (orange), for Equinox vs Flax (blue).

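[Figure: eqx_vs_flax, training curves for Equinox before the fix (green), Equinox after the fix (orange), and Flax (blue)]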

As can be seen, the "correct" Equinox and the Flax implementations are nearly identical at the beginning. At some points there are differences, but is that to be expected? This remains a bit of a mystery. The dips are also sharper for the Equinox agent, but it is much better than the green tragedy :)

Edit: it works without the partition as well.