Replication code
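import os
import pickle

import flax
import gym
import jax.numpy as jnp
from dopamine.jax.networks import NatureDQNNetwork
from flax.training import checkpoints

# Download the DQN SpaceInvaders runs into ./SpaceInvaders/<seed>/.
os.system('gsutil -m cp -r gs://download-dopamine-rl/jax/dqn/SpaceInvaders .')

env = gym.make('SpaceInvaders-v4')
network_def = NatureDQNNetwork(num_actions=env.action_space.n)  # expected: 6

# Seed 1 is the problematic checkpoint; use 2 to compare with a normal one.
with open('SpaceInvaders/1/ckpt.199', 'rb') as file:
    param_data = pickle.load(file)

# Convert the pre-Linen parameter layout to the Linen layout the network uses.
network_params = flax.core.FrozenDict({
    'params': checkpoints.convert_pre_linen(
        param_data['online_params']).unfreeze()
})

# Print the kernel and bias shapes of every layer.
for layer, data in network_params['params'].items():
    print(layer, data['kernel'].shape, data['bias'].shape)

# Dummy 84x84 observation with a stack of 4 frames.
obs = jnp.ones((84, 84, 4))
# For seed 1, Dense_1 has 1200 outputs instead of the expected 6.
network_def.apply(network_params, obs)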
I'm using the very helpful set of pre-trained JAX agents; however, while training the DQN one, I found an issue with the network params of one of the SpaceInvaders agents.
For dqn/SpaceInvaders/1/ckpt.199, the output layer has 1200 outputs rather than the expected 6:
Conv_0 (8, 8, 4, 32) (32,)
Conv_1 (4, 4, 32, 64) (64,)
Conv_2 (3, 3, 64, 64) (64,)
Dense_0 (7744, 512) (512,)
Dense_1 (512, 1200) (1200,)
The normal version (dqn/SpaceInvaders/2/ckpt.199) looks like this:
Conv_0 (8, 8, 4, 32) (32,)
Conv_1 (4, 4, 32, 64) (64,)
Conv_2 (3, 3, 64, 64) (64,)
Dense_0 (7744, 512) (512,)
Dense_1 (512, 6) (6,)
These appear to be Quantile network params, since the Quantile agents' output layers also have 1200 outputs (6 actions × 200 quantiles).
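If the seed-1 file really does hold a Quantile head, it might still be usable as a greedy policy. The sketch below assumes (unconfirmed) that the 1200 outputs are laid out action-major, so they reshape to (6 actions, 200 quantiles), and that each action's Q-value is the mean of its quantiles, as in QR-DQN. The function name quantile_q_values is hypothetical, and NumPy stands in for jax.numpy.

```python
import numpy as np


def quantile_q_values(features, kernel, bias, num_actions=6, num_quantiles=200):
    """Turn a 1200-wide output layer into one Q-value per action.

    Assumption: the layer is a QR-DQN style head whose outputs reshape to
    (num_actions, num_quantiles) in action-major order.
    features: (512,) activations coming out of Dense_0 (after ReLU).
    kernel:   (512, num_actions * num_quantiles) weights of Dense_1.
    bias:     (num_actions * num_quantiles,) bias of Dense_1.
    """
    expected = num_actions * num_quantiles
    if kernel.shape[-1] != expected:
        raise ValueError(f"expected {expected} outputs, got {kernel.shape[-1]}")
    outputs = features @ kernel + bias                       # (1200,)
    quantiles = outputs.reshape(num_actions, num_quantiles)  # (6, 200)
    return quantiles.mean(axis=1)                            # (6,) mean over quantiles


# Usage with random stand-in weights (not the real checkpoint values).
rng = np.random.default_rng(0)
q = quantile_q_values(rng.standard_normal(512),
                      rng.standard_normal((512, 1200)),
                      rng.standard_normal(1200))
print(q.shape, int(q.argmax()))
```

The greedy action is then the argmax over the 6 averaged values; if the real layout turns out to be quantile-major instead, the reshape would need to be (num_quantiles, num_actions) with the mean taken over axis 0.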
The rest of the dqn/SpaceInvaders network params are completely normal and work fine. I have no idea how or why this happened; as a workaround, everyone can just use the other 4 agents.
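To pick out the usable agents without running the network, a shape check on the output layer is enough. A minimal sketch, assuming the params are already in the converted Linen layout printed above (a mapping from layer name to {'kernel', 'bias'}) and that Dense_1 is the output layer; output_layer_ok and usable_seeds are hypothetical helper names.

```python
import numpy as np


def output_layer_ok(params, num_actions, head="Dense_1"):
    """True if the head layer maps to exactly num_actions outputs."""
    kernel = params[head]["kernel"]
    bias = params[head]["bias"]
    return kernel.shape[-1] == num_actions and bias.shape == (num_actions,)


def usable_seeds(params_by_seed, num_actions):
    """Return the seeds whose output layer matches the action space."""
    return [seed for seed, params in sorted(params_by_seed.items())
            if output_layer_ok(params, num_actions)]


# Stand-in params reproducing only the Dense_1 shapes reported above.
def fake_params(num_outputs):
    return {"Dense_1": {"kernel": np.zeros((512, num_outputs)),
                        "bias": np.zeros(num_outputs)}}


params_by_seed = {1: fake_params(1200), 2: fake_params(6), 3: fake_params(6)}
print(usable_seeds(params_by_seed, num_actions=6))
```

With real checkpoints, each entry would be network_params['params'] from the replication code, loaded once per seed directory.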
Not sure what you want to do with this info; I just wanted to report it as a bug.