MichaelTMatthews / Craftax_Baselines

MIT License
Problem with policy save file in view_ppo_agent.py #2

Open lbarazza opened 4 months ago

lbarazza commented 4 months ago

When I run the PPO baseline on my M1 Mac with python ppo.py --save_policy, I get ValueError: Unrecognized name format while the policy is being saved in the _save_network function. Despite the error, the policy file is still written, but its name contains an extra ".0", which breaks file type recognition in view_ppo_agent.py and prevents the policy visualization from running. I may be missing something or running it the wrong way, but in case this is a bug, I resolved it by changing config["TOTAL_TIMESTEPS"] to str(int(config["TOTAL_TIMESTEPS"])) on line 677 of ppo.py and line 97 of view_ppo_agent.py.

Also, I can't load save files from ppo_rnn.py runs to visualize them with view_ppo_agent.py: I always get KeyError: 'Dense_7'. Is the visualization code only meant to be used with policies from ppo.py? Thank you
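For anyone hitting the same thing, here is a minimal sketch of how the ".0" can end up in the file name and why the int cast removes it. It assumes TOTAL_TIMESTEPS is stored as a float (for example 1e6); the "policy_" prefix and the policy_name helper are hypothetical and only stand in for whatever path ppo.py actually builds.

```python
import os

# Assumption: the timestep count is held as a float, e.g. parsed from "1e6".
config = {"TOTAL_TIMESTEPS": 1e6}


def policy_name(total_timesteps, prefix="policy_"):
    """Hypothetical helper: build a save name without a float suffix."""
    return prefix + str(int(total_timesteps))


# Formatting the float directly keeps the trailing ".0" in the name.
naive = "policy_" + str(config["TOTAL_TIMESTEPS"])

# Anything that splits on the last dot then treats ".0" as a file extension.
stem, ext = os.path.splitext(naive)

# Casting to int first gives a name with no dot in it.
fixed = policy_name(config["TOTAL_TIMESTEPS"])

print(naive)  # policy_1000000.0
print(ext)    # .0
print(fixed)  # policy_1000000
```

The same cast has to be applied wherever the name is rebuilt for loading, otherwise the save and load paths will not match.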

MichaelTMatthews commented 4 months ago

Hi lbarazza, thanks for raising this. It's on my radar to fix. The orbax checkpointing I used is now deprecated, so I want to refactor the whole saving/loading code while I'm at it. There is currently no way to view an RNN policy, but I will push a script to do this soon.