ShaneFlandermeyer / tdmpc2-jax

Jax/Flax Implementation of TD-MPC2

Orbax checkpointing, Tensorboard Logging #4

Closed: edwhu closed this 6 months ago

edwhu commented 6 months ago
ShaneFlandermeyer commented 6 months ago

Which versions of Orbax and Flax are you using? I get the following error when I try to run train.py:

Traceback (most recent call last):
  File "/home/shane/src/tdmpc2-jax/tdmpc2_jax/train.py", line 99, in train
    with ocp.CheckpointManager(
TypeError: 'CheckpointManager' object does not support the context manager protocol

Things run properly if I replace the context manager with the line below and fix the indentation accordingly, as in the commit I just pushed.

mngr = ocp.CheckpointManager(checkpoint_path, options=options, item_names=('agent', 'global_step'))

Also, what do you think about replacing TensorBoard with WandB? I'm hesitant to introduce TensorFlow as a dependency if I can help it (it is required by flax.metrics.tensorboard).

edwhu commented 6 months ago

I'm using the newest version of Orbax, 0.5.14, which asks us to manage everything within the mngr context (https://orbax.readthedocs.io/en/latest/api_refactor.html). I initially did it the way you edited, since it looked cleaner and avoided indenting the training loop, but the code then raised a deprecation warning saying that this usage will no longer be supported as of August 1, 2024.
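One way to stay compatible with both behaviors described in this thread (a CheckpointManager that is not a context manager, as in the traceback above, and the 0.5.x manager that expects to be used as one) is to register the manager on a contextlib.ExitStack. The helper below, enter_manager, is hypothetical and not part of Orbax; it assumes that a manager lacking the context-manager protocol still exposes a close() method.

```python
import contextlib


def enter_manager(stack: contextlib.ExitStack, mngr):
    """Register `mngr` on `stack` so it is cleaned up when the stack unwinds.

    Hypothetical helper (not an Orbax API). If the manager's type implements
    the context-manager protocol, that protocol is used directly. Otherwise
    we assume it has a close() method and wrap it with contextlib.closing.
    Returns the object to use inside the block.
    """
    manager_type = type(mngr)
    if hasattr(manager_type, "__enter__") and hasattr(manager_type, "__exit__"):
        # Newer behavior: let the manager's own __enter__/__exit__ run.
        return stack.enter_context(mngr)
    # Older behavior: no context protocol, so close() is called on exit.
    return stack.enter_context(contextlib.closing(mngr))
```

In train.py this could be used as `with contextlib.ExitStack() as stack:` followed by `mngr = enter_manager(stack, ocp.CheckpointManager(checkpoint_path, options=options, item_names=('agent', 'global_step')))`, with the training loop moved into a separate function called from inside that block, so the loop itself does not gain an extra level of indentation.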

edwhu commented 6 months ago

As for TensorBoard vs. WandB, I prefer TensorBoard, since it doesn't cause vendor lock-in. I personally like using W&B too, but they've recently become very restrictive with their storage and usage quotas.

I think it's important to have TensorBoard as the default, since it is always free and self-hosted. We should eventually support multiple logging backends through an abstract logger class that can write to both TensorBoard and W&B, but I haven't implemented that yet.
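A minimal sketch of the abstract logger class suggested above might look like the following. All class names (Logger, TensorBoardLogger, WandbLogger, InMemoryLogger, MultiLogger) are hypothetical. The TensorBoard backend assumes flax.metrics.tensorboard.SummaryWriter (which, as noted earlier in the thread, pulls in TensorFlow) and the W&B backend assumes wandb.init / run.log; both are imported lazily so that neither library is required unless that backend is actually selected.

```python
import abc
from typing import Dict, List, Optional, Sequence, Tuple


class Logger(abc.ABC):
    """Minimal interface: named scalar metrics indexed by training step."""

    @abc.abstractmethod
    def log_scalars(self, scalars: Dict[str, float], step: int) -> None:
        ...

    def close(self) -> None:
        """Release backend resources; a no-op unless a backend overrides it."""


class TensorBoardLogger(Logger):
    def __init__(self, log_dir: str):
        # Lazy import: flax.metrics.tensorboard depends on TensorFlow.
        from flax.metrics import tensorboard
        self._writer = tensorboard.SummaryWriter(log_dir)

    def log_scalars(self, scalars, step):
        for name, value in scalars.items():
            self._writer.scalar(name, value, step)

    def close(self):
        self._writer.flush()
        self._writer.close()


class WandbLogger(Logger):
    def __init__(self, project: str, config: Optional[dict] = None):
        # Lazy import: wandb is only needed when this backend is chosen.
        import wandb
        self._run = wandb.init(project=project, config=config)

    def log_scalars(self, scalars, step):
        self._run.log(dict(scalars), step=step)

    def close(self):
        self._run.finish()


class InMemoryLogger(Logger):
    """Dependency-free backend that just records (step, name, value) tuples."""

    def __init__(self):
        self.records: List[Tuple[int, str, float]] = []

    def log_scalars(self, scalars, step):
        for name, value in scalars.items():
            self.records.append((step, name, float(value)))


class MultiLogger(Logger):
    """Fans every call out to several backends, e.g. TensorBoard and W&B."""

    def __init__(self, loggers: Sequence[Logger]):
        self._loggers = list(loggers)

    def log_scalars(self, scalars, step):
        for logger in self._loggers:
            logger.log_scalars(scalars, step)

    def close(self):
        for logger in self._loggers:
            logger.close()
```

With this shape, TensorBoard can stay the default (for example, TensorBoardLogger("runs/exp"), a hypothetical path), and W&B becomes an opt-in addition through MultiLogger([TensorBoardLogger(...), WandbLogger(project=...)]) without the training loop needing to know which backends are active.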

ShaneFlandermeyer commented 6 months ago

Makes sense! Upgraded orbax-checkpoint and everything works now. Merged.

ShaneFlandermeyer commented 6 months ago

Also updated the README in the develop branch.