Official implementation for the paper "CoVO-MPC: Theoretical Analysis of Sampling-based MPC and Optimal Covariance Design" accepted by L4DC 2024. CoVO-MPC is an optimal sampling-based MPC algorithm.
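The tracer output is caused by the jitted function: inside a function compiled with jax.jit, values are abstract tracers rather than concrete arrays, so a standard logging call only sees the tracer. jax.debug.print() might also interfere with JIT compilation and lead to a significant performance drop. You can refer to the jaxrl implementation of logging with TensorBoard.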
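One way to act on this advice, sketched under assumptions, is to return the metrics from the jitted function and log them on the host, where they are concrete arrays. The names `update` and `info` and the toy loss are hypothetical and are not taken from this repository or from jaxrl. The sketch writes to a file with Python's `logging` module instead of TensorBoard to stay dependency-free.

```python
import logging

import jax
import jax.numpy as jnp

logger = logging.getLogger("metrics_outside_jit")
logger.setLevel(logging.DEBUG)
handler = logging.FileHandler("metrics_outside_jit.log", mode="w")
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)


@jax.jit
def update(params, batch):
    # Hypothetical per-sample loss of shape (320,); stands in for the real objective.
    total_loss = (batch * params) ** 2
    # Gradient of mean((batch * params) ** 2) with respect to params.
    grad = jnp.mean(2.0 * batch * batch * params)
    new_params = params - 0.1 * grad
    # Return metrics as outputs instead of logging inside the traced code.
    info = {"total_loss_mean": total_loss.mean()}
    return new_params, info


params = jnp.asarray(1.0, dtype=jnp.float32)
batch = jnp.ones((320,), dtype=jnp.float32)
for step in range(3):
    params, info = update(params, batch)
    # Outside jit the value is a concrete array, so float() works.
    logger.debug("step=%d total_loss_mean=%.6f", step, float(info["total_loss_mean"]))
```

Because the logging call sits outside the compiled function, it adds no host callback to the compiled computation. The cost is a device-to-host transfer each time `float()` is called.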
❓ Issue
When I wanted to log total_loss[0] (an array of shape (320,)) or total_loss_mean (a scalar value) like this:
The log will be:
🤔 Possible solutions
I read the JAX FAQ entry how-can-i-convert-a-jax-tracer-to-a-numpy-array. It says that if you wish to print a traced value at runtime for debugging purposes, you might consider using jax.debug.print(). So I just use jax.debug.print() to print the values.

jax.debug.print() is pretty good, but the output cannot be saved into the log file (constant strings can be saved into log files, but DynamicJaxprTrace values cannot), even when I tried this:

    file_handler = logging.FileHandler('debug.log')
    file_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    jax.debug.print = logger.debug
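Assigning `logger.debug` to `jax.debug.print` means the logger is called at trace time, so it receives tracers rather than runtime values. A possible alternative, shown here as a minimal sketch and not as a confirmed fix for this repository, is `jax.debug.callback`, which runs a Python function on the host with concrete values. The function `compute_loss`, the helper `log_loss`, and the toy loss are hypothetical.

```python
import logging

import jax
import jax.numpy as jnp

logger = logging.getLogger("jit_debug")
logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler("debug.log", mode="w")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(file_handler)


def log_loss(total_loss_mean, first_loss):
    # Runs on the host at execution time with concrete arrays, not tracers.
    logger.debug(
        "total_loss_mean=%s total_loss[0].shape=%s", total_loss_mean, first_loss.shape
    )


@jax.jit
def compute_loss(x):
    # Hypothetical loss with shape (4, 320), so total_loss[0] has shape (320,).
    total_loss = x ** 2
    total_loss_mean = total_loss.mean()
    # Stage a host callback; the arguments are materialized when the function runs.
    jax.debug.callback(log_loss, total_loss_mean, total_loss[0])
    return total_loss_mean


result = compute_loss(jnp.full((4, 320), 2.0, dtype=jnp.float32))
result.block_until_ready()
jax.effects_barrier()  # wait for pending host callbacks before reading the log
```

Each callback forces a device-to-host transfer, so this is best kept for debugging runs rather than left in a hot training loop.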