danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

feat: matmul bijection #151

Closed. nstarman closed this 4 months ago.

nstarman commented 5 months ago

Good for whitening covariate data.
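
For context, here is a minimal, library-agnostic sketch (plain JAX, not the implementation in this PR) of the forward/inverse/log-det triple a matmul bijection needs:

import jax.numpy as jnp

def transform(arr, x):
    # Forward pass: multiply by an (assumed invertible) square matrix.
    return arr @ x

def inverse(arr, y):
    # Solving on the fly keeps the inverse consistent with the forward
    # transform, at O(n^3) cost per call.
    return jnp.linalg.solve(arr, y)

def log_abs_det(arr):
    # log|det J| of the forward transform; O(n^3) for a general matrix.
    return jnp.linalg.slogdet(arr)[1]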

danielward27 commented 5 months ago

Whilst the flexibility/generality of what you suggest could be convenient, it would come with some issues. If a user were to optimize the bijection, then after parameter updates the forward and inverse transformations would no longer correspond, and furthermore nothing constrains the matrix to be positive definite.
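
To make the first concern concrete, here is a toy illustration (a hypothetical parameterization, not flowjax code) of a bijection storing the matrix and its inverse as independent trainable parameters:

import jax.numpy as jnp

arr = jnp.array([[2.0, 0.0], [0.5, 3.0]])
arr_inv = jnp.linalg.inv(arr)

# Stand-ins for an optimizer updating each parameter independently.
arr = arr + 0.1
arr_inv = arr_inv - 0.05

# The stored "inverse" no longer inverts the stored matrix.
print(jnp.allclose(arr_inv @ arr, jnp.eye(2)))  # False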

For a whitening transformation I would currently suggest using TriangularAffine along with the Cholesky decomposition of the covariance matrix:

from flowjax.bijections import TriangularAffine
from flowjax.distributions import MultivariateNormal
import jax
import jax.numpy as jnp
import jax.random as jr

# A correlated 2D Gaussian to act as the "unwhitened" data.
loc = jnp.ones(2)
covariance_matrix = jnp.array([[2.0, 0.5], [0.5, 3.0]])
mvn = MultivariateNormal(loc=loc, covariance=covariance_matrix)
samps = mvn.sample(jr.PRNGKey(2), (10000,))

# We could equally just take mvn.bijection, but to be explicit:
cov_cholesky = jnp.linalg.cholesky(covariance_matrix)
tri_aff = TriangularAffine(loc=loc, arr=cov_cholesky)

# The inverse maps the correlated samples back to (approximately)
# standard normal samples, i.e. it whitens them.
whitened_samps = jax.vmap(tri_aff.inverse)(samps)

This also has the benefit of a fast log-determinant calculation (for a triangular matrix, the log absolute determinant is just the sum of the log absolute values of the diagonal), and similar computational complexity in the forward and inverse directions. If this doesn't suit your needs, let me know some more details and hopefully we can figure out a solution.
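
As a quick sanity check, continuing the snippet above with standard jax.numpy calls, the whitened samples should be approximately zero-mean with identity covariance, and the log-determinant reduces to a sum over the Cholesky diagonal:

print(jnp.cov(whitened_samps, rowvar=False))  # approximately the identity
print(whitened_samps.mean(axis=0))            # approximately zero

# For a triangular matrix, log|det| is the sum of log-abs-diagonal entries, O(n).
print(jnp.sum(jnp.log(jnp.abs(jnp.diag(cov_cholesky)))))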