Closed by edwhu 6 months ago
Which versions of Orbax and Flax are you using? I get the following error when I try to run train.py:
Traceback (most recent call last):
  File "/home/shane/src/tdmpc2-jax/tdmpc2_jax/train.py", line 99, in train
    with ocp.CheckpointManager(
TypeError: 'CheckpointManager' object does not support the context manager protocol
Things run properly if I replace the context manager with the line below and fix the indentation accordingly, as in the commit I just pushed.
mngr = ocp.CheckpointManager(checkpoint_path, options=options, item_names=('agent', 'global_step'))
Also, what do you think about replacing TensorBoard with WandB? I'm hesitant to introduce TensorFlow as a dependency if I can help it (it is required by flax.metrics.tensorboard).
I'm using the newest version of orbax, 0.5.14, which asks us to manage everything within the mngr context (see https://orbax.readthedocs.io/en/latest/api_refactor.html). I initially did it the way you edited, since it looked cleaner and avoided indenting the training loop. But that code raised a deprecation warning saying it will no longer be supported as of August 1, 2024.
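For anyone caught between versions: the TypeError above only means that the installed CheckpointManager class does not define `__enter__`/`__exit__`. Below is a minimal sketch of a version-tolerant wrapper that uses only the standard library. `managed`, `LegacyManager`, and `ModernManager` are hypothetical names for illustration, not part of Orbax, and the sketch assumes the manager exposes a `close()` method.

```python
import contextlib


def managed(mngr):
    """Return something usable in a `with` statement for any manager.

    If the manager already implements the context manager protocol it is
    returned unchanged; otherwise it is wrapped in contextlib.closing, which
    calls mngr.close() on exit. Assumes mngr has a close() method.
    """
    if hasattr(mngr, "__enter__") and hasattr(mngr, "__exit__"):
        return mngr
    return contextlib.closing(mngr)


class LegacyManager:
    """Hypothetical stand-in for a manager without __enter__/__exit__."""

    def __init__(self):
        self.closed = False
        self.saved_steps = []

    def save(self, step, item):
        # Record the step only; a real manager would write a checkpoint.
        self.saved_steps.append(step)

    def close(self):
        self.closed = True


class ModernManager(LegacyManager):
    """Hypothetical stand-in for a manager that supports `with`."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # do not swallow exceptions


if __name__ == "__main__":
    for cls in (LegacyManager, ModernManager):
        with managed(cls()) as mngr:
            mngr.save(0, {"global_step": 0})
        print(cls.__name__, "closed:", mngr.closed)
```

In train.py this would amount to wrapping the existing constructor call, e.g. `with managed(ocp.CheckpointManager(...)) as mngr:`, though simply upgrading orbax-checkpoint is the more direct fix.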
For TensorBoard vs WandB: I prefer TensorBoard first, since it doesn't cause vendor lock-in. I also personally like using WandB, but they've recently become very restrictive with their storage and usage quotas.
I think it's important to have TensorBoard as the default since it is always free and self-hosted. I think we should support multiple types of logging through an abstract logger class that can write to both TensorBoard and WandB, but I haven't implemented that yet.
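A rough sketch of what such an abstract logger could look like, not what the repo implements. All class names here (`Logger`, `MemoryLogger`, `TensorBoardLogger`, `WandbLogger`, `MultiLogger`) are hypothetical. The two real backends import their libraries lazily, so TensorFlow or wandb is only needed if that backend is actually constructed; they assume `flax.metrics.tensorboard.SummaryWriter.scalar(tag, value, step)` and `wandb.init(...).log(data, step=...)`.

```python
import abc


class Logger(abc.ABC):
    """Common interface: log a dict of scalar metrics at a given step."""

    @abc.abstractmethod
    def log_scalars(self, step, metrics):
        ...

    def close(self):
        pass


class MemoryLogger(Logger):
    """Dependency-free backend that keeps metrics in a list (handy for tests)."""

    def __init__(self):
        self.records = []

    def log_scalars(self, step, metrics):
        self.records.append((step, dict(metrics)))


class TensorBoardLogger(Logger):
    def __init__(self, log_dir):
        # Lazy import: only required when this backend is used.
        from flax.metrics import tensorboard
        self._writer = tensorboard.SummaryWriter(log_dir)

    def log_scalars(self, step, metrics):
        for tag, value in metrics.items():
            self._writer.scalar(tag, value, step)

    def close(self):
        self._writer.close()


class WandbLogger(Logger):
    def __init__(self, project):
        # Lazy import: only required when this backend is used.
        import wandb
        self._run = wandb.init(project=project)

    def log_scalars(self, step, metrics):
        self._run.log(dict(metrics), step=step)

    def close(self):
        self._run.finish()


class MultiLogger(Logger):
    """Fans each call out to several backends."""

    def __init__(self, *loggers):
        self._loggers = loggers

    def log_scalars(self, step, metrics):
        for logger in self._loggers:
            logger.log_scalars(step, metrics)

    def close(self):
        for logger in self._loggers:
            logger.close()


if __name__ == "__main__":
    logger = MultiLogger(MemoryLogger(), MemoryLogger())
    logger.log_scalars(0, {"episode_reward": 1.5})
    logger.close()
```

With this shape, TensorBoard can stay the default while WandB becomes opt-in, e.g. `MultiLogger(TensorBoardLogger(log_dir), WandbLogger(project))` when both are wanted.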
Makes sense! Upgraded orbax-checkpoint and everything works now. Merged.
Also updated the README in the develop branch.