kazewong / flowMC

Normalizing-flow enhanced sampling package for probabilistic inference in Jax
https://flowmc.readthedocs.io/en/main/
MIT License

Use scan to reduce NF compilation time #153

Closed by kazewong 7 months ago

kazewong commented 7 months ago

Currently, the NF model uses a plain Python for-loop over its layers in both the forward and inverse passes:

    def forward(
        self, x: Float[Array, " n_dim"]
    ) -> tuple[Float[Array, " n_dim"], Float]:
        log_det = 0.0
        for layer in self.layers:
            x, log_det_i = layer(x)
            log_det += log_det_i
        return x, log_det

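This can add significant compilation time: under `jit`, a Python for-loop is unrolled during tracing, so every layer is traced and compiled separately and the compiled program grows with the number of layers. Convert the loop to a `jax.lax.scan`, which traces the loop body only once, to reduce the compilation time.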
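A minimal sketch of the idea, not the flowMC implementation: it assumes every layer has the same structure, so the per-layer parameters can be stacked along a leading axis and passed to `jax.lax.scan`. The elementwise affine layer and the names `affine_layer`, `affine_layer_inverse`, `forward_loop`, `forward_scan`, and `inverse_scan` are hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp


def affine_layer(params, x):
    # Hypothetical toy layer: y = x * exp(log_scale) + shift.
    log_scale, shift = params
    return x * jnp.exp(log_scale) + shift, jnp.sum(log_scale)


def affine_layer_inverse(params, y):
    # Exact inverse of affine_layer, with the negated log-determinant.
    log_scale, shift = params
    return (y - shift) * jnp.exp(-log_scale), -jnp.sum(log_scale)


def forward_loop(stacked_params, x):
    # Baseline: the Python loop is unrolled layer by layer when traced.
    log_scales, shifts = stacked_params
    log_det = 0.0
    for i in range(log_scales.shape[0]):
        x, log_det_i = affine_layer((log_scales[i], shifts[i]), x)
        log_det += log_det_i
    return x, log_det


def forward_scan(stacked_params, x):
    # Scan version: the step function is traced once and reused per layer.
    def step(carry, layer_params):
        x, log_det = carry
        x, log_det_i = affine_layer(layer_params, x)
        return (x, log_det + log_det_i), None

    init = (x, jnp.zeros(()))
    (x, log_det), _ = jax.lax.scan(step, init, stacked_params)
    return x, log_det


def inverse_scan(stacked_params, y):
    # Inverse pass: visit the layers in reverse order via reverse=True.
    def step(carry, layer_params):
        y, log_det = carry
        y, log_det_i = affine_layer_inverse(layer_params, y)
        return (y, log_det + log_det_i), None

    init = (y, jnp.zeros(()))
    (y, log_det), _ = jax.lax.scan(step, init, stacked_params, reverse=True)
    return y, log_det


# Toy setup (assumed sizes): 8 layers acting on a 4-dimensional input.
n_layers, n_dim = 8, 4
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
stacked_params = (
    0.1 * jax.random.normal(k1, (n_layers, n_dim)),  # log-scales, one row per layer
    jax.random.normal(k2, (n_layers, n_dim)),        # shifts, one row per layer
)
x = jax.random.normal(k3, (n_dim,))

y_loop, ld_loop = jax.jit(forward_loop)(stacked_params, x)
y_scan, ld_scan = jax.jit(forward_scan)(stacked_params, x)
x_back, ld_inv = jax.jit(inverse_scan)(stacked_params, y_scan)
```

Both forward versions should agree up to floating-point tolerance, and `inverse_scan` should recover the original input. The compile-time benefit depends on how expensive each layer is to trace and on the number of layers. Applying this to the real model would also require the layers to be stackable, which this sketch simply assumes.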