mdmould closed this issue 7 months ago.
Calls to `transform`, `inverse`, etc. for the `Concatenate` bijection break JAX tracing (e.g., via `jit`), because of the integer array `split_idxs` passed to `array_split`.
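For example, this fails:

```python
import jax
import jax.numpy as jnp
from flowjax.bijections import Affine, Concatenate

# Two 1-dimensional affine bijections, concatenated into a 2-dimensional one.
bijections = (Affine(loc=jnp.zeros(1)),) * 2
bijection = Concatenate(bijections)

# Fails: split_idxs is an int array passed to array_split during tracing.
jax.jit(jax.vmap(bijection.transform))(jnp.ones((1, 2)))
```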
whereas `Stack` works (because it splits input arrays equally):

```python
import jax
import jax.numpy as jnp
from flowjax.bijections import Affine, Stack

# Two scalar affine bijections, stacked into a 2-dimensional one.
bijections = (Affine(),) * 2
bijection = Stack(bijections)

# Works under jit and vmap.
jax.jit(jax.vmap(bijection.transform))(jnp.ones((1, 2)))
```
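The equal-split case can be illustrated with plain JAX. The sketch below is not flowjax's `Stack` implementation, and the function `split_equally_and_recombine` is hypothetical. It only shows that when the split is given as a Python int section count, the split points are known at trace time, so `jit` and `vmap` compose without trouble.

```python
import jax
import jax.numpy as jnp


def split_equally_and_recombine(x):
    # Hypothetical helper, for illustration only.
    # The section count 2 is a plain Python int, so the split points are
    # fixed at trace time (len(x) must be divisible by 2).
    left, right = jnp.split(x, 2)
    # Apply a different elementwise map to each half, then rejoin them.
    return jnp.concatenate([left * 10, right + 1])


# vmap maps over the leading (batch) axis; each row has length 4.
out = jax.jit(jax.vmap(split_equally_and_recombine))(jnp.ones((1, 4)))
# out == [[10., 10., 2., 2.]]
```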
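A sufficient fix seems to be to convert `split_idxs` to a sequence of Python ints (e.g., a tuple), so that `array_split` receives concrete split points rather than a JAX array.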
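A minimal sketch of that fix, assuming a simplified stand-in for `Concatenate`: the class `ConcatenateSketch` and its `fns` and `sizes` arguments are hypothetical and not part of flowjax. It computes the cumulative split points with `itertools.accumulate` and keeps them as a tuple of Python ints, so `jnp.array_split` never receives a JAX array of indices.

```python
import itertools

import jax
import jax.numpy as jnp


class ConcatenateSketch:
    """Hypothetical stand-in for a Concatenate bijection (not flowjax's class).

    Applies fns[i] to its own contiguous slice of the input, where sizes[i]
    is the length of that slice.
    """

    def __init__(self, fns, sizes):
        self.fns = tuple(fns)
        # Cumulative sizes without the total, e.g. sizes (1, 2, 3) -> (1, 3).
        # Stored as a tuple of Python ints rather than a jnp array, which is
        # the form the issue reports as breaking jit tracing.
        self.split_idxs = tuple(itertools.accumulate(sizes))[:-1]

    def transform(self, x):
        parts = jnp.array_split(x, self.split_idxs)
        return jnp.concatenate([f(p) for f, p in zip(self.fns, parts)])


# Two elementwise affine maps, each acting on a length-1 slice.
sketch = ConcatenateSketch([lambda u: 2 * u + 1, lambda u: u - 3], sizes=[1, 1])
out = jax.jit(jax.vmap(sketch.transform))(jnp.ones((1, 2)))
# out == [[3., -2.]]
```

Because the split points are plain ints, they are fixed when the function is traced; using different block sizes means building a new object, which then traces separately.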