google-research / vdm

Apache License 2.0

What is the purpose of `substeps` hyperparameter? #2

Closed: baofff closed this issue 2 years ago.

baofff commented 2 years ago

Thanks for the great work. I found an interesting use of `jax.lax.scan` in your code. Applying it to `p_train_step` runs `p_train_step` `substeps` times in succession, and it does not seem to affect the training result. What is the benefit compared to normal training (i.e., calling `p_train_step` once per update without wrapping it in `jax.lax.scan`)?
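To illustrate the pattern the question describes, here is a minimal sketch. It is not the vdm implementation, and it rests on a few assumptions: a hypothetical `train_step(params, batch)` that performs one SGD update on a toy least-squares model, `jax.jit` on a single device in place of `jax.pmap`, and batches stacked along a leading axis of length `substeps`. `jax.lax.scan` threads `params` through as the carry and applies the step once per slice, so the result should match an ordinary Python loop over the same batches.

```python
import jax
import jax.numpy as jnp


def train_step(params, batch):
    """Hypothetical single SGD update on a toy least-squares model (not vdm's step)."""
    x, y = batch
    loss, grads = jax.value_and_grad(lambda p: jnp.mean((x @ p - y) ** 2))(params)
    # Return (new_carry, per-step output), the signature jax.lax.scan expects.
    return params - 0.1 * grads, loss


@jax.jit
def multi_step(params, batches):
    """Apply train_step once per slice along the leading (substeps) axis."""
    return jax.lax.scan(train_step, params, batches)


def loop_step(params, batches, substeps):
    """Reference: the same updates issued one at a time from Python."""
    xs, ys = batches
    losses = []
    for i in range(substeps):
        params, loss = train_step(params, (xs[i], ys[i]))
        losses.append(loss)
    return params, jnp.stack(losses)


# Toy data: `substeps` batches stacked along axis 0.
substeps, bs, dim = 4, 16, 3
kx, kw = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(kx, (substeps, bs, dim))
ys = xs @ jax.random.normal(kw, (dim,))
params0 = jnp.zeros(dim)

p_scan, l_scan = multi_step(params0, (xs, ys))
p_loop, l_loop = loop_step(params0, (xs, ys), substeps)
print("max |param diff|:", float(jnp.max(jnp.abs(p_scan - p_loop))))
```

The two parameter vectors should agree up to floating-point rounding: scanning changes how the updates are issued, not which updates are applied.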

baofff commented 2 years ago

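I've found the benefit: this hyperparameter makes JAX compile multiple updates together into a single program, so each call from Python runs `substeps` updates and the per-call overhead is paid once instead of once per update. This makes training faster.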
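One rough way to observe the speed-up is to time both variants. The sketch below again uses a hypothetical toy `train_step` and `jax.jit` rather than the repository's `p_train_step` with `jax.pmap`; `bench` is likewise an illustrative helper, not part of vdm. The per-update version makes one call into compiled code per update, while the scanned version makes one call per `substeps` updates, because the whole loop is compiled into a single XLA program.

```python
import time

import jax
import jax.numpy as jnp


def train_step(params, batch):
    """Hypothetical toy SGD step (least squares), standing in for a real train step."""
    x, y = batch
    loss, grads = jax.value_and_grad(lambda p: jnp.mean((x @ p - y) ** 2))(params)
    return params - 1e-2 * grads, loss


single = jax.jit(train_step)  # one compiled call per update


@jax.jit
def multi(params, batches):
    # One compiled call runs every update along the leading axis.
    return jax.lax.scan(train_step, params, batches)


def bench(substeps=50, n_iters=20, dim=8, bs=32):
    """Return (seconds per update, single), (seconds per update, scan), and final params."""
    kx, kw = jax.random.split(jax.random.PRNGKey(0))
    xs = jax.random.normal(kx, (substeps, bs, dim))
    ys = xs @ jax.random.normal(kw, (dim,))
    batch_list = [(xs[i], ys[i]) for i in range(substeps)]  # pre-sliced, outside timing
    params = jnp.zeros(dim)

    # Warm up both functions so compilation time is excluded.
    single(params, batch_list[0])[0].block_until_ready()
    multi(params, (xs, ys))[0].block_until_ready()

    t0 = time.perf_counter()
    p_single = params
    for _ in range(n_iters):
        for b in batch_list:
            p_single, _ = single(p_single, b)
    p_single.block_until_ready()  # wait for async dispatch to finish
    t_single = (time.perf_counter() - t0) / (n_iters * substeps)

    t0 = time.perf_counter()
    p_scan = params
    for _ in range(n_iters):
        p_scan, _ = multi(p_scan, (xs, ys))
    p_scan.block_until_ready()
    t_scan = (time.perf_counter() - t0) / (n_iters * substeps)

    return t_single, t_scan, p_single, p_scan


if __name__ == "__main__":
    t_single, t_scan, _, _ = bench()
    print(f"per update: single={t_single * 1e6:.1f} us, scan={t_scan * 1e6:.1f} us")
```

Absolute numbers depend on the hardware and model size. The gap is typically largest when each update is cheap relative to the per-call dispatch overhead, and it shrinks as the per-update compute grows.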