google / dopamine

Dopamine is a research framework for fast prototyping of reinforcement learning algorithms.
https://github.com/google/dopamine
Apache License 2.0

JAX Dqn SpaceInvader network params are wrong #181

Closed pseudo-rnd-thoughts closed 2 years ago

pseudo-rnd-thoughts commented 2 years ago

I'm using the very helpful set of pre-trained JAX agents. However, while training with the DQN one, I found an issue with one of the SpaceInvaders network params.

For dqn/SpaceInvaders/1/ckpt.199, the output layer has 1200 outputs rather than the expected 6:

Conv_0 (8, 8, 4, 32) (32,)
Conv_1 (4, 4, 32, 64) (64,)
Conv_2 (3, 3, 64, 64) (64,)
Dense_0 (7744, 512) (512,)
Dense_1 (512, 1200) (1200,)

The normal version looks like this (this is dqn/SpaceInvaders/2/ckpt.199):

Conv_0 (8, 8, 4, 32) (32,)
Conv_1 (4, 4, 32, 64) (64,)
Conv_2 (3, 3, 64, 64) (64,)
Dense_0 (7744, 512) (512,)
Dense_1 (512, 6) (6,)

These appear to be Quantile network params, as those have 1200 outputs as well.
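
The 1200 figure is consistent with a quantile head if one assumes 200 quantiles per action for the 6 SpaceInvaders actions. The value 200 is an assumption here (I believe it is the default `num_atoms` of Dopamine's JAX quantile agent, but I have not verified it against these checkpoints). A minimal sketch of that arithmetic, where `infer_atoms_per_action` is a hypothetical helper and not part of Dopamine:

```python
def infer_atoms_per_action(output_units, num_actions):
    """Return outputs per action if output_units divides evenly, else None."""
    if num_actions <= 0 or output_units % num_actions != 0:
        return None
    return output_units // num_actions


NUM_ACTIONS = 6  # SpaceInvaders action count, as stated in this issue

# A plain DQN head has one output (Q-value) per action.
dqn_head = infer_atoms_per_action(6, NUM_ACTIONS)
# The odd checkpoint has 1200 outputs, i.e. 200 per action.
bad_head = infer_atoms_per_action(1200, NUM_ACTIONS)

print(dqn_head, bad_head)
```

A result of 200 outputs per action does not prove the checkpoint came from a quantile agent; it only shows the shape is compatible with one.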

The rest of the dqn/SpaceInvaders network params are completely normal and work fine. I have no idea how or why this would have happened. As a workaround, everyone can just use the other 4 agents.
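
To pick out the usable checkpoints before loading them into a network, one option is to check the final layer's shapes against the environment's action count. The sketch below assumes the params are a dict of layers with `kernel` and `bias` entries exposing `.shape`, as in the listings above. `output_layer_mismatches` and `fake_layer` are hypothetical names, and `SimpleNamespace` objects stand in for the real arrays so the example runs without downloading anything:

```python
from types import SimpleNamespace


def output_layer_mismatches(params, num_actions, output_layer="Dense_1"):
    """Return a list of shape problems for the final layer (empty if OK)."""
    layer = params.get(output_layer)
    if layer is None:
        return [f"missing layer {output_layer}"]
    problems = []
    kernel_shape = tuple(layer["kernel"].shape)
    bias_shape = tuple(layer["bias"].shape)
    # The last kernel dimension is the number of network outputs.
    if kernel_shape[-1] != num_actions:
        problems.append(f"kernel {kernel_shape}: expected {num_actions} outputs")
    if bias_shape != (num_actions,):
        problems.append(f"bias {bias_shape}: expected ({num_actions},)")
    return problems


def fake_layer(kernel_shape, bias_shape):
    """Stand-in for a loaded layer; only the shapes matter here."""
    return {
        "kernel": SimpleNamespace(shape=kernel_shape),
        "bias": SimpleNamespace(shape=bias_shape),
    }


# Shapes copied from the two listings in this issue.
run_1 = {"Dense_1": fake_layer((512, 1200), (1200,))}
run_2 = {"Dense_1": fake_layer((512, 6), (6,))}

problems_run_1 = output_layer_mismatches(run_1, num_actions=6)
problems_run_2 = output_layer_mismatches(run_2, num_actions=6)

for name, problems in (("1", problems_run_1), ("2", problems_run_2)):
    print(f"run {name}:", problems if problems else "OK")
```

With real checkpoints, the same function could be called on `network_params['params']` for each run directory, skipping any run that returns a non-empty list.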

Not sure what you want to do with this info; I just wanted to report it as a bug.

Replication code

import os
import pickle

import flax
import gym
from dopamine.jax.networks import NatureDQNNetwork
from flax.training import checkpoints
import jax.numpy as jnp

# Download the pre-trained DQN SpaceInvaders checkpoints (requires gsutil).
os.system('gsutil -m cp -r gs://download-dopamine-rl/jax/dqn/SpaceInvaders .')

env = gym.make('SpaceInvaders-v4')
network_def = NatureDQNNetwork(num_actions=env.action_space.n)

# Load the checkpoint from run 1, the one with the unexpected output layer.
with open('SpaceInvaders/1/ckpt.199', 'rb') as file:
    param_data = pickle.load(file)

# Convert the pre-Linen parameters to the Linen layout.
network_params = flax.core.FrozenDict({
    'params': checkpoints.convert_pre_linen(
        param_data['online_params']).unfreeze()
})

# Print each layer's kernel and bias shapes.
for layer, data in network_params['params'].items():
    print(layer, data['kernel'].shape, data['bias'].shape)

# Apply the network to a dummy observation of all ones.
obs = jnp.ones((84, 84, 4))
network_def.apply(network_params, obs)

psc-g commented 2 years ago

hi, thanks for reporting this! i'll look into this and get back to you. it's possible i copied over the incorrect checkpoint.