AkshitaB opened 1 year ago
Did you solve it?
I'm trying too, but it's not working. What I really want to see is the intermediate output of each layer of the model. When I print a value, what comes out is a traced array (from the dynamic jaxpr trace) that shows only its shape, not its contents. `jax.debug.print` doesn't work either. Even when I disable JIT in infer.py, the result is the same:
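```python
import jax
from absl import logging

# Excerpt from t5x's infer.py, with JIT disabled around the inference call.
logging.info('Running inference on %d batches.', checkpoint_period)
with jax.disable_jit():
  infer_result = infer_fn(model_ds.enumerate(), rng=chunk_rng)
```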
Is there a way to inspect the actual values in t5x? Thank you.
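For reference, here is a minimal, self-contained sketch of the general JAX behavior described above. It uses a toy two-layer function (`two_layer_mlp` is hypothetical and is not t5x code), so it illustrates the mechanism only and does not diagnose why it fails inside t5x. Under `jax.jit`, a plain `print()` runs once at trace time and shows only a tracer's shape and dtype; `jax.debug.print` is staged into the compiled computation and prints runtime values. Returning intermediates as extra outputs is another way to get concrete arrays back on the host.

```python
import jax
import jax.numpy as jnp


def two_layer_mlp(params, x):
    """Toy two-layer network (hypothetical; not a t5x model)."""
    h = jnp.tanh(x @ params["w1"])
    # Under jit, `h` is a tracer: this print runs once, at trace time,
    # and shows only shape/dtype, not the numbers.
    print("trace-time print:", h)
    # jax.debug.print is part of the compiled computation, so it prints
    # the concrete runtime value each time the function executes.
    jax.debug.print("layer 1 mean = {m}", m=h.mean())
    y = h @ params["w2"]
    # Returning intermediates as extra outputs hands concrete arrays
    # back to the caller after the jitted call finishes.
    return y, {"layer1": h, "layer2": y}


params = {"w1": jnp.ones((3, 4)), "w2": jnp.ones((4, 2))}
x = jnp.arange(6.0).reshape(2, 3)

y, intermediates = jax.jit(two_layer_mlp)(params, x)
print(intermediates["layer1"])  # a concrete array, not a tracer
```

Since t5x models are written in Flax, Flax's `Module.apply(..., capture_intermediates=True, mutable=["intermediates"])` is a related option for collecting per-layer outputs; wiring that into t5x's model wrappers is not shown here.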
What is the best way to debug and inspect values when running train.py? I set the flags as follows, but I still get traced values in the training loop.
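As a sketch of two generic JAX approaches for a training loop, assuming a hypothetical `train_step` (an SGD step on a linear model, not t5x's actual train step): `jax.debug.callback` sends concrete values to a host-side Python function even inside `jit`, and running the step under `jax.disable_jit()` executes it eagerly so ordinary `print()` or `pdb` see real arrays. The names `record`, `captured`, and `train_step` are illustrative only, and whether either approach reaches the right code path in t5x's train.py is not verified here.

```python
import jax
import jax.numpy as jnp
import numpy as np

captured = []  # host-side list filled by the callback below


def record(loss):
    # Runs on the host with a concrete value, even when called from jit.
    captured.append(np.asarray(loss))


@jax.jit
def train_step(w, x, y):
    """Hypothetical single SGD step on a linear model (not t5x code)."""
    def loss_fn(w):
        return jnp.mean((x @ w - y) ** 2)

    loss, grad = jax.value_and_grad(loss_fn)(w)
    # ordered=True keeps callbacks in execution order across steps.
    jax.debug.callback(record, loss, ordered=True)
    return w - 0.1 * grad, loss


w = jnp.zeros(3)
x = jnp.ones((4, 3))
y = jnp.ones(4)

# 1) Normal jitted loop: values are observed through the callback.
for _ in range(3):
    w, loss = train_step(w, x, y)
jax.effects_barrier()  # wait for pending callbacks before reading `captured`

# 2) With JIT disabled, the step runs eagerly, so values inside
#    train_step are concrete arrays (print() and pdb work normally).
with jax.disable_jit():
    w_eager, loss_eager = train_step(w, x, y)
print(captured, float(loss_eager))
```

The callback route keeps compiled performance and is usually preferable for periodic logging; disabling JIT is slower but makes step-by-step debugging straightforward.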