Closed kazewong closed 7 months ago
Currently the NF model uses a basic for-loop for forward and inverse passes:

    def forward(
        self, x: Float[Array, " n_dim"]
    ) -> tuple[Float[Array, " n_dim"], Float]:
        log_det = 0.0
        for layer in self.layers:
            x, log_det_i = layer(x)
            log_det += log_det_i
        return x, log_det
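Because a Python for-loop is unrolled when the function is traced under `jit`, this could introduce fairly significant compilation time as the number of layers grows. Convert this to a scan (`jax.lax.scan`) to reduce the compilation time.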
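A minimal sketch of what the conversion might look like, not the actual model code. It assumes every layer has the same structure, so the per-layer parameters can be stacked along a leading axis and scanned over. `affine_layer`, `forward_loop`, `forward_scan`, and the stacked `(log_scale, shift)` layout are hypothetical stand-ins for the real layers.

```python
import jax
import jax.numpy as jnp


def affine_layer(params, x):
    """Hypothetical elementwise affine layer: y = x * exp(log_scale) + shift."""
    log_scale, shift = params
    y = x * jnp.exp(log_scale) + shift
    log_det = jnp.sum(log_scale)  # log|det J| of an elementwise scaling
    return y, log_det


def forward_loop(stacked_params, x):
    """Reference version: a Python for-loop, unrolled at trace time."""
    log_scales, shifts = stacked_params
    log_det = 0.0
    for i in range(log_scales.shape[0]):
        x, log_det_i = affine_layer((log_scales[i], shifts[i]), x)
        log_det += log_det_i
    return x, log_det


def forward_scan(stacked_params, x):
    """Same computation with jax.lax.scan: the layer body is traced once."""

    def body(carry, layer_params):
        x, log_det = carry
        x, log_det_i = affine_layer(layer_params, x)
        return (x, log_det + log_det_i), None

    # scan slices each leaf of stacked_params along its leading (layer) axis.
    (x, log_det), _ = jax.lax.scan(body, (x, jnp.zeros(())), stacked_params)
    return x, log_det


# Toy setup (assumed sizes): 8 layers acting on 3-dimensional inputs.
n_layers, n_dim = 8, 3
key_scale, key_shift, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
stacked_params = (
    0.1 * jax.random.normal(key_scale, (n_layers, n_dim)),
    jax.random.normal(key_shift, (n_layers, n_dim)),
)
x0 = jax.random.normal(key_x, (n_dim,))

y_loop, log_det_loop = jax.jit(forward_loop)(stacked_params, x0)
y_scan, log_det_scan = jax.jit(forward_scan)(stacked_params, x0)
```

Both versions should agree up to floating-point error. For the inverse pass, `jax.lax.scan` accepts `reverse=True` to walk the stacked layers in the opposite order. The real layers would first need their parameters stacked into a single pytree with a leading layer axis, which only works if all layers share the same structure.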