danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

Concatenate fails when stacking more than two bijections #130

Closed: mdmould closed this issue 8 months ago

mdmould commented 8 months ago

There is a bug in how flowjax.bijections.Concatenate splits the inputs passed to its transform and inverse methods into the segments handled by each stacked bijection.

E.g., the following will fail, with the specific error depending on the shapes of the stacked bijections:

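```python
import jax.numpy as jnp
from flowjax.bijections import Affine, Concatenate

n = 3  # fails for n > 2
bijections = (Affine(loc=jnp.zeros(1)),) * n  # n identical 1-dimensional affine bijections
bijection = Concatenate(bijections)  # stack them into one n-dimensional bijection

bijection.transform(jnp.ones(n))  # fails; the error depends on the stacked shapes
```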
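For reference, here is a minimal sketch of how shape-based segmenting can work. It does not use the flowjax source: plain functions stand in for bijections, each segment is assumed to be 1-D with a known size, and split_points and concatenate_transform are hypothetical helpers. The split indices need to be the running totals of the segment sizes, not the sizes themselves. Passing the sizes directly happens to give the right answer for two bijections (the single split point equals the first size either way). That would be consistent with the failure only appearing for three or more, but this is a guess about the cause, not a confirmed diagnosis.

```python
import itertools

import jax.numpy as jnp


def split_points(sizes):
    """Hypothetical helper: indices at which to split a vector into segments of `sizes`.

    These are the running totals of all but the last size, e.g. (1, 1, 1) -> [1, 2].
    """
    return list(itertools.accumulate(sizes[:-1]))


def concatenate_transform(fns, sizes, x):
    """Hypothetical helper: apply fns[i] to the i-th segment of x and join the results."""
    segments = jnp.split(x, split_points(sizes))
    return jnp.concatenate([f(seg) for f, seg in zip(fns, segments)])


def shift(s):
    """Stand-in for a bijection: add one to every element."""
    return s + 1.0


y = concatenate_transform([shift] * 3, (1, 1, 1), jnp.ones(3))

# Using the sizes themselves as split indices gives the wrong segments for three parts:
# jnp.split(x, [1, 1]) yields x[:1], x[1:1] (empty) and x[1:].
wrong = jnp.split(jnp.ones(3), [1, 1])
```

With the running totals, each shift sees exactly one element and y has shape (3,). With the raw sizes, the second segment is empty and the third has two elements, so any per-segment bijection expecting shape (1,) would fail.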