Currently, the training and sampling phases take around 3.5 minutes on an A100 node with an Ice Lake CPU for an RQSpline model with 10 layers and [128,128] conditioners.
See whether replacing make_flow with a scan can solve this compilation problem. This could very well be an issue with how distrax forms the bijector, in which case there won't be much I can do.
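A minimal sketch of the scan idea follows. It does not use distrax or make_flow. Instead, a hypothetical toy affine coupling layer (`layer_forward`, `init_layer`) stands in for an RQSpline layer, and the log-determinant is not tracked. The sketch assumes every layer has identical parameter shapes, because `jax.lax.scan` needs the per-layer parameters stacked along a leading axis. A Python loop unrolls `N_LAYERS` copies of the layer into the traced graph, while `lax.scan` traces the layer body once.

```python
import jax
import jax.numpy as jnp

N_LAYERS = 10  # matches the 10-layer model in the note
DIM = 2        # hypothetical data dimension
HIDDEN = 128   # single hidden layer of width 128 (the real conditioner is [128, 128])


def init_layer(key):
    """Hypothetical parameters for one toy coupling layer."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (DIM // 2, HIDDEN)) * 0.1,
        "w2": jax.random.normal(k2, (HIDDEN, 2 * (DIM - DIM // 2))) * 0.1,
    }


def layer_forward(params, x):
    """Toy affine coupling: transform the second half of x conditioned on the first half."""
    x1, x2 = x[..., : DIM // 2], x[..., DIM // 2 :]
    h = jnp.tanh(x1 @ params["w1"])
    shift, log_scale = jnp.split(h @ params["w2"], 2, axis=-1)
    y2 = x2 * jnp.exp(log_scale) + shift
    # Reverse the feature order so the next layer conditions on the other half.
    return jnp.concatenate([x1, y2], axis=-1)[..., ::-1]


keys = jax.random.split(jax.random.PRNGKey(0), N_LAYERS)
layer_params = [init_layer(k) for k in keys]
# Stack per-layer params: each leaf gets a leading axis of length N_LAYERS.
stacked_params = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *layer_params)


@jax.jit
def flow_unrolled(params_list, x):
    # Python loop: the layer is traced N_LAYERS times, so the graph grows with depth.
    for p in params_list:
        x = layer_forward(p, x)
    return x


@jax.jit
def flow_scanned(stacked, x):
    # lax.scan traces the layer body once and iterates over the leading param axis.
    def body(carry, p):
        return layer_forward(p, carry), None

    x, _ = jax.lax.scan(body, x, stacked)
    return x


x = jax.random.normal(jax.random.PRNGKey(1), (4, DIM))
y_loop = flow_unrolled(layer_params, x)
y_scan = flow_scanned(stacked_params, x)
```

Both versions compute the same map. Only the traced program size differs. Timing each first call (which includes compilation) for growing `N_LAYERS` is one way to check whether the scan version removes the compile-time growth before porting the idea to the real RQSpline flow.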