ikostrikov / rlpd

MIT License

Flax FrozenDict: dict.copy() takes no keyword arguments #6

Open BurgerAndreas opened 9 months ago

BurgerAndreas commented 9 months ago

Reproduce error

Running

```bash
XLA_PYTHON_CLIENT_PREALLOCATE=false python train_finetuning_pixels.py --env_name=cheetah-run-v0 \
    --start_training 5000 \
    --max_steps 300000 \
    --config=configs/rlpd_pixels_config.py \
    --project_name=rlpd_vd4rl
```

I am getting: `TypeError: dict.copy() takes no keyword arguments`.
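The message comes from Python's built-in `dict.copy`, which accepts no arguments at all. `FrozenDict.copy`, on the other hand, accepts `add_or_replace=...`, so code written against a `FrozenDict` breaks once `init` hands back a plain `dict`. A minimal pure-Python reproduction, assuming the failing call has the shape `params.copy(add_or_replace=...)` (the helper name and params tree below are hypothetical):

```python
# Minimal reproduction without Flax: if Module.init returns a plain dict,
# a FrozenDict-style call such as params.copy(add_or_replace=...) fails,
# because dict.copy() accepts no arguments.


def try_frozendict_style_copy(params, updates):
    """Hypothetical helper: attempt a FrozenDict-style copy on `params`.

    Returns the TypeError message if the call fails, otherwise None.
    """
    try:
        params.copy(add_or_replace=updates)
    except TypeError as err:
        return str(err)
    return None


# Hypothetical params tree, shaped like Module.init(...)["params"] as a plain dict.
actor_params = {"pixel_encoder": {"kernel": [1.0]}, "head": {"bias": [0.0]}}
critic_encoder = {"kernel": [2.0]}

message = try_frozendict_style_copy(actor_params, {"pixel_encoder": critic_encoder})
print(message)
```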

Possible fix

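In `rlpd/rlpd/agents/drq/drq_learner.py`, wrap the parameters returned by `init` in a `FrozenDict`:

```python
import flax.core.frozen_dict as frozen_dict
# Line 121: wrap the actor params so FrozenDict methods such as .copy(add_or_replace=...) work again.
actor_params = frozen_dict.FrozenDict(actor_def.init(actor_key, observations)["params"])
# Line 145: same for the critic params.
critic_params = frozen_dict.FrozenDict(critic_def.init(critic_key, observations, actions)["params"])
```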

jren03 commented 5 months ago

I believe this is related to Flax's migration from `FrozenDict` to regular Python dictionaries as the return type of `Module.init`, according to the issue here. See the migration note here for Flax 0.7.1 onwards. I'm not sure which lines are erroring for you, but another possible workaround is to use the `flax.core.frozen_dict` utility functions described here, which work on both `FrozenDict` and plain `dict` inputs.
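In the spirit of that workaround, here is a minimal sketch of a helper that behaves like `FrozenDict.copy(add_or_replace=...)` but works whether `init` returned a `FrozenDict` or a plain `dict`. The name `copy_with` is hypothetical and it always returns a plain `dict`; if I remember correctly, Flax ships a similar utility, `flax.core.frozen_dict.copy`, which is worth checking before rolling your own.

```python
from collections.abc import Mapping
from typing import Any


def copy_with(params: Mapping[str, Any], add_or_replace: Mapping[str, Any]) -> dict:
    """Hypothetical helper mimicking FrozenDict.copy(add_or_replace=...).

    Accepts any Mapping (a plain dict or a Flax FrozenDict) and returns a new
    plain dict in which top-level keys from `add_or_replace` replace or extend
    those in `params`. The input mapping is never mutated (the copy is shallow).
    """
    return {**params, **add_or_replace}


# Hypothetical params trees (plain dicts, as newer Flax versions return).
actor_params = {"pixel_encoder": {"kernel": [1.0]}, "head": {"bias": [0.0]}}
critic_params = {"pixel_encoder": {"kernel": [2.0]}, "q": {"bias": [0.5]}}

# Example: copy the critic's encoder into the actor params without touching the originals.
new_actor_params = copy_with(actor_params, {"pixel_encoder": critic_params["pixel_encoder"]})
```

Returning a plain `dict` matches the newer Flax default; if downstream code still expects a `FrozenDict`, wrapping the result with `frozen_dict.FrozenDict(...)`, as in the possible fix above, should work.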

Hope this is helpful!