Closed by mdmould 8 months ago.
There is a bug in how `flowjax.bijections.Concatenate` segments the inputs passed to its `transform`/`inverse` methods.
E.g., the following will fail, with the specific error depending on the shapes of the stacked bijections:
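```python
import jax.numpy as jnp
from flowjax.bijections import Affine, Concatenate

n = 3
# n identical 1-D affine bijections (loc=0, default scale), i.e. identity maps
bijections = (Affine(loc=jnp.zeros(1)),) * n
# Concatenate them into a single bijection acting on an input of length n
bijection = Concatenate(bijections)
# Expected to return jnp.ones(n) unchanged; instead this raises an error
bijection.transform(jnp.ones(n))
```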
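To make the expected behaviour concrete, here is a minimal sketch (not flowjax's actual implementation) of how an input can be segmented for stacked bijections: each bijection receives a contiguous slice whose length equals its own size along the concatenation axis, so the split points are the running totals of those sizes. The helper names `split_indices` and `segment` are hypothetical. The last lines show one plausible way such a bug can arise (using the sizes themselves, rather than their cumulative sum, as split points); this is an assumption for illustration, not a confirmed diagnosis of the flowjax code.

```python
import jax.numpy as jnp
import numpy as np


def split_indices(sizes):
    """Cut points for splitting a concatenated input into consecutive segments.

    For segment sizes [a, b, c] the cut points are [a, a + b]: the running
    total of the sizes, with the final total (the full length) dropped.
    """
    return np.cumsum(sizes)[:-1].tolist()


def segment(x, sizes, axis=0):
    """Hypothetical helper: split x along `axis` into pieces of lengths `sizes`."""
    if sum(sizes) != x.shape[axis]:
        raise ValueError(
            f"sizes sum to {sum(sizes)}, but x has length {x.shape[axis]} along axis {axis}"
        )
    return jnp.split(x, split_indices(sizes), axis=axis)


# Sizes matching the repro above: three bijections, each of shape (1,).
sizes = [1, 1, 1]
x = jnp.ones(3)

pieces = segment(x, sizes)
print([p.shape for p in pieces])  # [(1,), (1,), (1,)]

# Passing the sizes directly as cut points (instead of their running total)
# silently yields segments of the wrong shape: an empty one and an oversized one.
wrong = jnp.split(x, sizes[:-1])
print([p.shape for p in wrong])  # [(1,), (0,), (2,)]
```

With the wrong cut points, the failure only surfaces later, when a bijection receives a segment whose shape does not match its own, which would be consistent with the error varying with the shapes of the stacked bijections.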