kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

How to do print debugging inside the model #135

Closed: jiasenlu closed this issue 3 years ago

jiasenlu commented 3 years ago

I am new to JAX, and thanks for this great codebase. Is there a way to do print debugging inside the model? Or do you have a preferred approach? Thanks!

kingoflolz commented 3 years ago

https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#jax.experimental.host_callback.id_print
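A note for later readers: `jax.experimental.host_callback` (including the linked `id_print`) has since been deprecated in newer JAX releases, and `jax.debug.print` / `jax.debug.callback` now fill the same role. Below is a minimal sketch, assuming a recent JAX version. The function `dense_relu` and the helper `record_stats` are hypothetical toy examples, not code from mesh-transformer-jax. The key point is that a plain `print()` inside a jitted function only shows an abstract tracer once at trace time, while `jax.debug.print` shows the runtime value on every call.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical sink for values captured by jax.debug.callback.
recorded = []


def record_stats(a):
    # Runs on the host with the concrete runtime value of `a`.
    arr = np.asarray(a)
    recorded.append((arr.shape, float(arr.max())))


# Hypothetical toy "layer" (dense projection + ReLU), used only for illustration.
@jax.jit
def dense_relu(x, w, b):
    h = x @ w + b
    # A plain print() here would only show a tracer, once, during tracing.
    # jax.debug.print prints the actual runtime value each time the function runs.
    jax.debug.print("pre-activation mean = {m}", m=h.mean())
    out = jax.nn.relu(h)
    # jax.debug.callback passes runtime values to any Python function
    # (useful for logging to a list or file, or for inspecting values in a debugger).
    jax.debug.callback(record_stats, out)
    return out


if __name__ == "__main__":
    x = jnp.ones((2, 4))
    w = jnp.full((4, 3), 0.5)
    b = jnp.zeros(3)
    out = dense_relu(x, w, b)
    jax.effects_barrier()  # wait until pending debug callbacks have run
    print(recorded)
```

Because these are ordinary JAX operations, the same calls should also work inside a Haiku module's `__call__`, since Haiku functions are traced by JAX like any other code.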