I want to log metrics from the following PPO training script, written in JAX, in real time. Currently I can only log the related metrics and plot them with `plt.plot(out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1))` after training has finished. The training is launched as follows:
```python
train_jit = jax.jit(make_train(config))
out = jax.block_until_ready(train_jit(rng))
```
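The offline curve `out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1)` can also be built up incrementally on the host. A minimal sketch, assuming each update's metric contains `"returned_episode_returns"` with the environment axis last, and that it is handed to a hypothetical host function `log_update` (for example via `jax.debug.callback(log_update, metric)` inside the jitted update step). Averaging over the last axis and appending updates in order reproduces the same flattened curve.

```python
import numpy as np

# Host-side curve that grows while training runs. Once every update has
# arrived in order, it equals returns.mean(-1).reshape(-1) on the stacked output.
live_curve = []


def log_update(metric):
    # `metric` is one update's metric pytree (hypothetical layout), assumed to
    # contain "returned_episode_returns" with the env axis last.
    returns = np.asarray(metric["returned_episode_returns"])
    per_point = returns.mean(-1).reshape(-1)  # same reduction as the offline plot
    live_curve.extend(per_point.tolist())
    print(f"{len(live_curve)} points logged, latest mean return {per_point[-1]:.3f}")
```

Concatenating each update's flattened means in update order gives the same sequence as reshaping the fully stacked array, so `live_curve` can be plotted or forwarded to a logger at any point during training.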
Please help me modify the following training code so that it logs its status in real time. Note that the train function needs to stay jitted. The following is the original training function code:
In the implementation, the whole training process runs inside the single jitted call `jax.block_until_ready(train_jit(rng))`, which makes it hard to log metrics as training progresses. How can I modify the code so that I can log the training status in real time?
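One common approach is `jax.debug.callback`, which calls a host-side Python function from inside compiled code. Below is a minimal sketch, assuming the training loop is a `jax.lax.scan` over update steps (as in many JAX PPO scripts). Each update's summary is sent to the host while `train` stays jitted. `make_train`, `update_step`, `log_metrics`, and the fake returns are hypothetical stand-ins, not the original PPO code.

```python
import jax
import jax.numpy as jnp

# Host-side log filled by the callback; plain Python, outside of jit.
history = []


def log_metrics(step, mean_return):
    # Runs on the host each time the compiled loop reaches the callback.
    step, mean_return = int(step), float(mean_return)
    history.append((step, mean_return))
    print(f"update {step}: mean episode return = {mean_return:.3f}")


def make_train(num_updates, num_envs=4):
    # Hypothetical stand-in for make_train(config): each "update" only draws
    # fake per-env returns so the logging path can be shown end to end.
    def train(rng):
        def update_step(rng, step):
            rng, key = jax.random.split(rng)
            returns = jax.random.normal(key, (num_envs,)) + step
            metric = {"returned_episode_returns": returns}
            # Send this update's summary to the host while still inside jit.
            jax.debug.callback(log_metrics, step, returns.mean())
            return rng, metric

        rng, metrics = jax.lax.scan(update_step, rng, jnp.arange(num_updates))
        return {"runner_state": rng, "metrics": metrics}

    return train


train_jit = jax.jit(make_train(num_updates=5))
out = jax.block_until_ready(train_jit(jax.random.PRNGKey(0)))
jax.effects_barrier()  # wait until every pending callback has run
```

The callback body is ordinary Python, so it can print, append to a list, or call an experiment logger. Keep it cheap, because it runs once per update. The step index is logged with each entry because unordered callbacks are not guaranteed to arrive in sequence. `jax.debug.callback` also accepts `ordered=True` when strict ordering is required. The full stacked metrics are still returned in `out["metrics"]` as before.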