google-research / t5x

Apache License 2.0

Debugging with pdb and disabling JIT #959

Open · AkshitaB opened this issue 1 year ago

AkshitaB commented 1 year ago

What is the best way to debug and inspect values when running train.py? I set the flags as follows, but inside the training loop I still get traced values instead of concrete arrays.

JAX_DISABLE_JIT=1 python3 -m t5x.train \
  ...
  --gin.partitioning.PjitPartitioner.use_cpu_pjit=False \
  --gin.train.use_gda=True \
...
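A minimal sketch, assuming a recent JAX release, of one likely reason traced values still appear: JAX_DISABLE_JIT / jax.disable_jit() only turns off jit compilation, while other transformations such as jax.grad (which a training step uses to compute gradients) still pass tracers into the function. The toy function `layer` below is hypothetical, not T5X code. It records the type of value it receives under each setting and also calls jax.debug.print, which is staged into the computation and prints runtime values even under jit.

```python
import jax
import jax.numpy as jnp

seen = []  # type name of the value `layer` receives on each call

def layer(x):
    # Plain Python side effects run only while tracing under jit, so they
    # see an abstract tracer; with JIT disabled they see a concrete array.
    seen.append(type(x).__name__)
    # jax.debug.print is part of the computation, so it prints runtime
    # values even inside jit (and inside grad).
    jax.debug.print("layer input: {x}", x=x)
    return jnp.tanh(x).sum()

x = jnp.arange(3.0)
jitted = jax.jit(layer)

jitted(x)                # traced by jit: a *Tracer type is recorded
with jax.disable_jit():
    jitted(x)            # JIT disabled: a concrete array type is recorded
    jax.grad(layer)(x)   # grad still traces, even with JIT disabled

print(seen)
```

The exact tracer class names vary between JAX versions, so the point to check is only whether "Tracer" appears in each recorded name. If the value you want to inspect sits under a gradient, Python `print` or pdb will show a tracer regardless of the JIT setting, while jax.debug.print still shows the numbers.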
switiz commented 1 year ago

Did you solve it?

I'm trying the same thing, but it isn't working. What I really want is to see the intermediate values of each layer of the model. When I print a value, I get a traced array (a DynamicJaxprTracer) and only its shape is shown. jax.debug.print doesn't work either. Disabling JIT in infer.py, as below, gives the same result.

infer.py

      logging.info('Running inference on %d batches.', checkpoint_period)
      with jax.disable_jit():
        infer_result = infer_fn(model_ds.enumerate(), rng=chunk_rng)

Is there a way to inspect these values in t5x? Thank you.
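One way to collect per-layer values, sketched under the assumption that you can edit the model code: call jax.debug.callback inside each layer so the runtime value is copied to the host. The two-layer `mlp`, the `record` helper, and the `captured` dict below are hypothetical stand-ins, not T5X or Flax code; in a Flax module the same call would go inside `__call__`.

```python
import numpy as np
import jax
import jax.numpy as jnp

captured = {}  # host-side store: layer name -> concrete numpy array

def record(name):
    # Returns a host callback that copies the runtime value into `captured`.
    def _store(value):
        captured[name] = np.asarray(value)
    return _store

def mlp(params, x):
    for i, (w, b) in enumerate(params):
        x = jnp.tanh(x @ w + b)
        # Runs on the host with the concrete value, even inside jit.
        jax.debug.callback(record(f"layer_{i}"), x)
    return x

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = [
    (jax.random.normal(key1, (4, 8)), jnp.zeros(8)),
    (jax.random.normal(key2, (8, 2)), jnp.zeros(2)),
]
out = jax.jit(mlp)(params, jnp.ones((1, 4)))
jax.effects_barrier()  # wait for all pending callbacks before reading `captured`

for name, value in captured.items():
    print(name, value.shape)
```

Because the callback receives concrete values at run time, this does not require disabling JIT. If you can call the Flax module directly, `Module.apply(..., capture_intermediates=True, mutable=["intermediates"])` is an alternative that returns each submodule's output without editing the layers.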