Thanks for the great work.
I found an interesting use of jax.lax.scan in your code. Applying it to p_train_step runs p_train_step substeps times in succession, and it seems that this does not affect the training result. What is the benefit compared to normal training (i.e., calling p_train_step once per step, without wrapping it in jax.lax.scan)?
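For concreteness, the pattern I am asking about looks roughly like the sketch below. This is only a minimal illustration, not the repository's code: train_step, multi_step, the least-squares loss, and the toy data are hypothetical placeholders, and the pmap part of p_train_step is left out.

```python
import jax
import jax.numpy as jnp


def train_step(params, batch):
    # Hypothetical single SGD step on a least-squares loss (a stand-in for p_train_step).
    x, y = batch

    def loss_fn(p):
        return jnp.mean((x @ p - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    return params - 0.1 * grads, loss


@jax.jit
def multi_step(params, batches):
    # lax.scan threads params through train_step once per slice along the leading axis.
    return jax.lax.scan(train_step, params, batches)


substeps = 4  # assumed value, for illustration only
key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (substeps, 8, 3))  # one toy batch per substep
ys = jnp.sum(xs, axis=-1)
params0 = jnp.zeros(3)

# Scanned version: a single jitted call performs all substeps.
scanned_params, scanned_losses = multi_step(params0, (xs, ys))

# "Normal" version: one Python-level call per step.
loop_params = params0
for i in range(substeps):
    loop_params, _ = train_step(loop_params, (xs[i], ys[i]))
```

In this toy setup both paths end with the same parameters up to floating-point tolerance, which matches my observation that the training result is unchanged.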