LeCAR-Lab / CoVO-MPC

Official implementation for the paper "CoVO-MPC: Theoretical Analysis of Sampling-based MPC and Optimal Covariance Design" accepted by L4DC 2024. CoVO-MPC is an optimal sampling-based MPC algorithm.
https://lecar-lab.github.io/CoVO-MPC/
Apache License 2.0

🐛 Unable to log DynamicJaxprTrace, use jax.debug.print() instead #2

Closed: bzx20 closed this issue 1 year ago

bzx20 commented 1 year ago

❓ Issue

When I try to log `total_loss[0]` (an array of shape (320,)) or `total_loss_mean` (a scalar) like this:

    import logging
    import jax.numpy as jnp

    def _log_loss(total_loss):
        total_loss_mean = jnp.mean(total_loss[0])
        logging.info(total_loss_mean)  # called from inside a jitted function

the log shows:

    [INFO]: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=3/0)>
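A minimal sketch of why this happens: ordinary Python code inside a `jax.jit` function runs only while JAX traces the function, so it receives an abstract tracer instead of a number, and it does not run again on later calls that reuse the compiled version. The function name `mean_loss`, the list `seen`, and the input shape (2, 320), chosen so that `total_loss[0]` has the (320,) shape from the issue, are hypothetical.

```python
import jax
import jax.numpy as jnp

seen = []  # records what a plain Python side effect receives inside jit


@jax.jit
def mean_loss(total_loss):
    m = jnp.mean(total_loss[0])
    # This line runs only during tracing, so `m` is a tracer here,
    # exactly what logging.info printed in the issue.
    seen.append(str(m))
    return m


x = jnp.ones((2, 320))
mean_loss(x)  # first call: traces, runs the append, then compiles
mean_loss(x)  # same shape/dtype: reuses the compiled code, no append
```

After both calls `seen` holds a single tracer description, while the value returned from `mean_loss` is a concrete array that can be logged normally outside the jitted function.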

🤔 Possible solutions

1. I read the JAX FAQ entry how-can-i-convert-a-jax-tracer-to-a-numpy-array. It says that if you wish to print a traced value at runtime for debugging purposes, you might consider using `jax.debug.print()`. So I used `jax.debug.print()` to print the values.

2. `jax.debug.print()` works well, but its output is not saved to the log file (constant strings are written to the log file, but traced values are not), even when I tried this:

    
    import logging
    import jax

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    file_handler = logging.FileHandler('debug.log')
    file_handler.setLevel(logging.DEBUG)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    # Attempt: redirect jax.debug.print to the logger
    jax.debug.print = logger.debug



3. Maybe there exists a better method than `jax.debug.print()`.

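One possible alternative, sketched here as an assumption rather than the approach used in this repository: `jax.debug.callback` calls an ordinary Python function on the host with concrete arrays, so that function can pass the value to `logger.debug` and the result reaches the file handler. The logger name `covo_debug`, the helper `_log_mean`, the list `logged`, and the input values are hypothetical. Like `jax.debug.print()`, this still adds a host callback to the compiled function, so it carries a runtime cost.

```python
import logging

import jax
import jax.numpy as jnp

logger = logging.getLogger("covo_debug")  # hypothetical logger name
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.FileHandler("debug.log"))

logged = []  # keep the values in memory as well, for easy inspection


def _log_mean(m):
    # Runs on the host with a concrete 0-d array, so float() works here.
    value = float(m)
    logged.append(value)
    logger.debug("total_loss_mean = %f", value)


@jax.jit
def mean_loss(total_loss):
    m = jnp.mean(total_loss[0])
    jax.debug.callback(_log_mean, m)  # executed at runtime, not at trace time
    return m


mean_loss(jnp.full((2, 320), 3.0))
jax.effects_barrier()  # wait until the callback has finished
```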
jc-bao commented 1 year ago

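This is caused by the jitted function: inside `jax.jit`, `total_loss_mean` is an abstract tracer rather than a concrete value, so `logging.info` can only print the tracer's description. `jax.debug.print()` adds a host callback to the compiled function, which can lead to a significant performance drop. You can refer to the jaxrl implementation of logging with TensorBoard.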
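A sketch of that pattern, assuming the jitted step returns its metrics and the host logs them after each call. This is not jaxrl's actual code: `update_step`, `train.log`, `history`, and the input values are hypothetical, and the standard `logging` module stands in for a TensorBoard writer.

```python
import logging

import jax
import jax.numpy as jnp

logging.basicConfig(filename="train.log", level=logging.INFO)


@jax.jit
def update_step(total_loss):
    # A real update step would also compute gradients and new parameters.
    info = {"total_loss_mean": jnp.mean(total_loss[0])}
    return info  # return metrics instead of logging them inside jit


history = []
for step in range(3):
    # device_get turns the returned arrays into host (NumPy) values.
    info = jax.device_get(update_step(jnp.full((2, 320), float(step))))
    for name, value in info.items():
        history.append((step, name, float(value)))
        logging.info("step %d: %s = %.4f", step, name, float(value))
        # With TensorBoard this would be e.g. writer.add_scalar(name, value, step)
```

Because logging happens outside the jitted function, the compiled code contains no callbacks, and the values are ordinary numbers by the time they reach the logger.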