coax-dev / coax

Modular framework for Reinforcement Learning in python
https://coax.readthedocs.io
MIT License

PPOClip grad update seems to cause inf update #5

Open glmcdona opened 2 years ago

glmcdona commented 2 years ago

Describe the bug

Hey Kris, love your framework! I'm working with a custom environment, and your discrete-action unit test works perfectly locally. Don't spend much time investigating this yet; I'm just filing this issue in case something jumps out at you as the problem. I plan to keep debugging it.

During the first PPOClip update with the custom gym environment, the model weights are changed to +/-inf even though the gradients are all finite.
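A quick way to spot this is to check every parameter leaf for non-finite values right after the update (a sketch, reusing the ppo_clip object from the repro script below):

import jax
import jax.numpy as jnp

# True means the leaf is still fully finite; False means it contains inf/nan
print(jax.tree_util.tree_map(lambda x: bool(jnp.all(jnp.isfinite(x))), ppo_clip._pi.params))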

Expected behavior

...
adv = np.random.rand(32)
grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
print("grads", grads)
print(ppo_clip._pi.params)
metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
print(ppo_clip._pi.params)

Results in:

grads FlatMapping({
  'linear': FlatMapping({
              'b': DeviceArray([ 0.0477 , -0.02505, -0.05048,  0.02798], dtype=float16),
              'w': DeviceArray([[ 0.01338 , -0.01921 , -0.01038 ,  0.01622 ],
                                [ 0.02406 , -0.01683 , -0.02039 ,  0.01316 ],
                                [ 0.0332  , -0.0227  , -0.03108 ,  0.02061 ],
                                ...,
                                [ 0.02452 , -0.00956 , -0.01997 ,  0.005024],
                                [ 0.010025,  0.001724, -0.03467 ,  0.02295 ],
                                [ 0.01886 , -0.01413 , -0.01494 ,  0.01022 ]], dtype=float16),
            }),
})
FlatMapping({
  'linear': FlatMapping({
              'w': DeviceArray([[-1.0124e-02,  3.4389e-03,  2.9316e-03,  6.5498e-03],
                                [ 3.3302e-03, -1.7233e-03, -3.0422e-03, -1.8060e-04],
                                [-2.8908e-05, -3.3131e-03, -6.1073e-03,  6.5804e-03],
                                ...,
                                [-2.5597e-03,  7.3471e-03, -3.6221e-03, -5.6801e-03],
                                [-7.3471e-03, -3.7746e-03,  5.8746e-03,  6.1531e-03],
                                [-1.1940e-03,  6.9733e-03, -5.0507e-03,  3.4218e-03]], dtype=float16),
              'b': DeviceArray([0., 0., 0., 0.], dtype=float16),
            }),
})
FlatMapping({
  'linear': FlatMapping({
              'b': DeviceArray([-0.001002,  0.000978,  0.001001, -0.001007], dtype=float16),
              'w': DeviceArray([[-0.01111  ,  0.004448 ,  0.00386  ,  0.00551  ],
                                [ 0.002354 , -0.0007563, -0.002048 , -0.001162 ],
                                [-0.001021 , -0.002335 , -0.005104 ,  0.005558 ],
                                ...,
                                [-0.003561 ,  0.008224 , -0.002628 ,       -inf],
                                [-0.00828  ,       -inf,  0.006874 ,  0.00515  ],
                                [-0.002203 ,  0.00804  , -0.004086 ,  0.002493 ]], dtype=float16),
            }),
})

Here is the full repro script, taken from the Pong PPO example and slightly modified. It won't run as-is for anyone else because of the custom environment, and the policy and value networks are dummy examples rather than the ones that would actually be used:

import os
from luxai2021.env.lux_env import LuxEnvironment, LuxEnvironmentTeam
from luxai2021.game.game import Game
from luxai2021.game.actions import *
from luxai2021.game.constants import LuxMatchConfigs_Default

from luxai2021.env.agent import Agent, AgentWithTeamModel
import numpy as np

from agent import TeamAgent

# set some env vars
os.environ.setdefault('JAX_PLATFORM_NAME', 'cpu')     # tell JAX to use the CPU
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1'  # don't use all gpu mem
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'              # tell XLA to be quiet

import gym
import jax
import coax
import haiku as hk
import jax.numpy as jnp
from optax import adam

# the name of this script
name = 'ppo'

configs = LuxMatchConfigs_Default

player = TeamAgent(mode="train")
opponent = Agent()

env = LuxEnvironment(configs=configs,
                     learning_agent=player,
                     opponent_agent=opponent)
env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")

def func_pi(S, is_training):
    n_actions = 4
    out = {'logits': hk.Linear(n_actions)(hk.Flatten()(S)) }
    return out

def func_v(S, is_training):
    h = jnp.ravel(hk.Linear(1)(hk.Flatten()(S)))
    return h

'''
def func_pi(S, is_training):
    #print(env.action_space.shape)
    n_filters = 5
    n_actions = 4
    n_layers = 3

    h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
    for layer in range(n_layers):
        h = jax.nn.relu(h + hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(h))

    print('h', type(h), h.shape)
    h_head = (h * S[:,:1]).reshape(h.shape[0], h.shape[1], -1).sum(-1) # torch.Size([1, N_LAYERS])
    h_head_actions = hk.Linear(n_actions)(h_head)
    print('h_head_actions', type(h_head_actions), h_head_actions.shape)
    #print(h_head_actions)

    out = {'logits': h_head_actions}

    return out

def func_v(S, is_training):
    n_filters = 5
    n_layers = 3

    h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
    for layer in range(n_layers):
        h = jax.nn.relu(hk.Conv2D(n_filters, kernel_shape=3, stride=2, data_format='NCHW')(h))

    h = hk.Flatten()(h)
    h = jax.nn.relu(hk.Linear(64)(h))
    h = jnp.ravel(hk.Linear(1, w_init=jnp.zeros)(h))

    return h
'''

# function approximators
pi = coax.Policy(func_pi, env)
v = coax.V(func_v, env)

# target networks
pi_behavior = pi.copy()
v_targ = v.copy()

# policy regularizer (avoid premature exploitation)
entropy = coax.regularizers.EntropyRegularizer(pi, beta=0.001)

# updaters
simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=adam(3e-4))
ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=adam(3e-4))

# reward tracer and replay buffer
tracer = coax.reward_tracing.NStep(n=5, gamma=0.99)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)

# run episodes
max_episode_steps = 400
while env.T < 3000000:
    s = env.reset()

    for t in range(max_episode_steps):
        print(t)
        a, logp = pi_behavior(s, return_logp=True)
        s_next, r, done, info = env.step(a)

        # trace rewards and add transition to replay buffer
        tracer.add(s, a, r, done, logp)
        while tracer:
            buffer.add(tracer.pop())

        # learn
        if len(buffer) >= buffer.capacity:
            num_batches = int(4 * buffer.capacity / 32)  # 4 epochs per round
            for i in range(num_batches):
                transition_batch = buffer.sample(32)
                grads, function_state, metrics, td_error = simpletd.grads_and_metrics(transition_batch)
                metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)

                adv = np.random.rand(32)
                grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
                print("grads", grads)
                print(ppo_clip._pi.params)
                metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
                print(ppo_clip._pi.params)
                exit()  # stop after the first update so the params can be inspected
                env.record_metrics(metrics_pi)
                env.record_metrics(metrics_v)

            buffer.clear()

            # sync target networks
            pi_behavior.soft_update(pi, tau=0.1)
            v_targ.soft_update(v, tau=0.1)

        if done:
            break

        s = s_next

    # generate an animated GIF to see what's going on
    if env.period(name='generate_gif', T_period=10000) and env.T > 50000:
        T = env.T - env.T % 10000  # round to 10000s
        coax.utils.generate_gif(
            env=env, policy=pi, resize_to=(320, 420),
            filepath=f"./data/gifs/{name}/T{T:08d}.gif")

royerk commented 2 years ago

Hello,

To add to @glmcdona's report, I'm getting the exact same issue, but with a Box action space (if that makes any difference). After the update on the first minibatch, the networks are filled with NaNs.

I will try to replicate this with a classic gym env (by the way, I think the Pendulum-v0 used in the examples is deprecated).
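A rough sketch of the replication attempt I have in mind (untested; I'm assuming Pendulum-v1 and the {'mu', 'logvar'} policy head convention from the coax pendulum example):

import gym
import coax
import haiku as hk
import jax.numpy as jnp
from optax import adam

env = gym.make('Pendulum-v1')  # Box action space, like my setup

def func_pi(S, is_training):
    # minimal Gaussian policy head: mean and log-variance per action dimension
    X = hk.Flatten()(S)
    mu = hk.Linear(env.action_space.shape[0])(X)
    logvar = hk.Linear(env.action_space.shape[0])(X)
    return {'mu': mu, 'logvar': logvar}

def func_v(S, is_training):
    return jnp.ravel(hk.Linear(1)(hk.Flatten()(S)))

pi = coax.Policy(func_pi, env)
v = coax.V(func_v, env)
v_targ = v.copy()

simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=adam(3e-4))
ppo_clip = coax.policy_objectives.PPOClip(pi, optimizer=adam(3e-4))
# ... followed by the same tracer / buffer / update loop as in the script above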

glmcdona commented 2 years ago

This error only occurs with the optax adam optimizer; the workaround is to use the sgd optimizer instead. The error does not reproduce with TestPPOClip->test_update_discrete() or the example Pong PPO with the adam optimizer. Maybe close this issue unless a reliable repro can be created?
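For reference, the workaround is just swapping the optimizer in the repro script above (same pi and entropy objects as before):

from optax import sgd

# sgd in place of adam avoids the inf/nan weights on the first update
ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=sgd(3e-4))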

KristianHolsheimer commented 2 years ago

Hi Geoff! Thanks for telling me about this one.

It's very surprising that replacing optax.adam by optax.sgd seems to help. Perhaps the adam accumulators are contaminated by a non-finite gradient somewhere?
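One way to check that hypothesis would be to count the non-finite entries in the gradients and in the optimizer state around the failing update (a sketch; I'm assuming the updater exposes its optax state as ppo_clip.optimizer_state):

import jax
import jax.numpy as jnp

# number of non-finite entries per leaf (adam's mu/nu accumulators live in the optimizer state)
count_bad = lambda tree: jax.tree_util.tree_map(lambda x: int(jnp.sum(~jnp.isfinite(x))), tree)
print(count_bad(grads))
print(count_bad(ppo_clip.optimizer_state))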

Would it be possible to share a Colab notebook?