google / dopamine

Dopamine is a research framework for fast prototyping of reinforcement learning algorithms.
https://github.com/google/dopamine
Apache License 2.0

"ValueError: Expected dict, got FrozenDict" when using ReDO #223

Open Zarzard opened 21 hours ago

Zarzard commented 21 hours ago

Hello there,

I was running the following command, trying to reproduce the results in the ReDO paper:

python3 -um dopamine.labs.redo.train --base_dir /tmp --gin_files dopamine/labs/redo/configs/dqn_dense.gin

I added the following lines to the dqn_dense.gin file; there are no other changes to the code:

atari_lib.create_atari_environment.game_name = 'DemonAttack'
RecycledDQNAgent.reset_mode = 'neurons'
NeuronRecycler.dead_neurons_threshold = 0.1
NeuronRecycler.reset_period = 1

However, I got the following error:

......
File "/home/joe/dopamine/dopamine/labs/redo/weight_recyclers.py", line 557, in recycle_dead_neurons
    new_mu = reset_momentum_fn(opt_state[0][1], incoming_mask)
ValueError: Expected dict, got FrozenDict({
    params: {
        Conv_0: {
            bias: None,
            kernel: Traced<ShapedArray(bool[8,8,4,32])>with<DynamicJaxprTrace(level=1/0)>,
        },
        Conv_1: {
            bias: None,
            kernel: Traced<ShapedArray(bool[4,4,32,64])>with<DynamicJaxprTrace(level=1/0)>,
        },
        Conv_2: {
            bias: None,
            kernel: Traced<ShapedArray(bool[3,3,64,64])>with<DynamicJaxprTrace(level=1/0)>,
        },
        Dense_0: {
            bias: None,
            kernel: Traced<ShapedArray(bool[7744,512])>with<DynamicJaxprTrace(level=1/0)>,
        },
        final_layer: {
            bias: None,
            kernel: Traced<ShapedArray(float32[512,6])>with<DynamicJaxprTrace(level=1/0)>,
        },
    },
}).

When RecycledDQNAgent.reset_mode in dqn_dense.gin is set to None (i.e., not using ReDO), there is no error.
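For context, here is a minimal sketch of what appears to be going on (an assumed reproduction, not the Dopamine code itself): jax.tree_util.tree_map requires every pytree argument to use the same container type at each node, so mapping a plain-dict optimizer state together with a FrozenDict mask raises exactly this error.

import jax
import jax.numpy as jnp
import flax

# A plain-dict moment next to a FrozenDict mask: the node types differ,
# so tree_map cannot flatten the second tree against the first.
mu = {'Dense_0': {'kernel': jnp.ones((4, 2))}}
mask = flax.core.freeze({'Dense_0': {'kernel': jnp.zeros((4, 2), dtype=bool)}})

jax.tree_util.tree_map(lambda m, keep: m * keep, mu, mask)
# ValueError: Expected dict, got FrozenDict({...})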

Zarzard commented 19 hours ago

After modifying the following lines in weight_recyclers.py:

new_mu = reset_momentum_fn(opt_state[0][1], incoming_mask)
new_mu = reset_momentum_fn(new_mu, outgoing_mask)
new_nu = reset_momentum_fn(opt_state[0][2], incoming_mask)
new_nu = reset_momentum_fn(new_nu, outgoing_mask)

to:

new_mu = reset_momentum_fn(opt_state[0][1], flax.core.frozen_dict.unfreeze(incoming_mask))
new_mu = reset_momentum_fn(new_mu, flax.core.frozen_dict.unfreeze(outgoing_mask))
new_nu = reset_momentum_fn(opt_state[0][2], flax.core.frozen_dict.unfreeze(incoming_mask))
new_nu = reset_momentum_fn(new_nu, flax.core.frozen_dict.unfreeze(outgoing_mask))

and converting online_params and grad in the apply_updates_jitted function in recycled_dqn_agents.py to plain dict objects, the code finally runs. However, I haven't run the full 10M training steps yet, so I don't know whether my fix for the ValueError affects performance.
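In case it helps others who hit this, below is a self-contained sketch of why the unfreeze calls resolve the mismatch. The reset_momentum_fn here is an assumed stand-in (the real one lives in weight_recyclers.py and may differ); the point is only that once both pytrees share the same container type, zeroing the moments of recycled neurons via tree_map works.

import jax
import jax.numpy as jnp
import flax

def reset_momentum_fn(momentum, mask):
  # Assumed behaviour: zero the moment entries selected by the boolean mask.
  return jax.tree_util.tree_map(
      lambda m, is_reset: jnp.where(is_reset, 0.0, m), momentum, mask)

mu = {'Dense_0': {'kernel': jnp.ones((4, 2))}}  # plain-dict Adam moment
incoming_mask = flax.core.freeze(
    {'Dense_0': {'kernel': jnp.zeros((4, 2), dtype=bool).at[:, 0].set(True)}})

# Without unfreeze this raises "Expected dict, got FrozenDict"; with it the
# container types match and the masked columns of the moment are zeroed.
new_mu = reset_momentum_fn(mu, flax.core.frozen_dict.unfreeze(incoming_mask))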