danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

Concatenate transform incompatible with jax tracing #132

Closed: mdmould closed this issue 7 months ago

mdmould commented 8 months ago

Calling `transform`, `inverse`, etc. on a `Concatenate` bijection breaks JAX tracing (e.g. under `jax.jit`), because `split_idxs` is stored as an integer array and passed to `array_split`.
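The underlying constraint can be seen with plain JAX, independently of flowjax: `jnp.array_split` needs concrete split points to work out its output shapes while tracing. The sketch below assumes the index array ends up as a traced value under `jit`; `split_first` and `try_traced_split` are hypothetical helpers, not part of flowjax.

```python
import jax
import jax.numpy as jnp


def split_first(x, idxs):
    """Return the first chunk of x split at idxs (hypothetical helper)."""
    first, _ = jnp.array_split(x, idxs)
    return first


x = jnp.arange(3.0)

# Split points given as a tuple of Python ints are static (hashable and
# marked with static_argnums), so output shapes are known at trace time.
static_first = jax.jit(split_first, static_argnums=1)(x, (1,))


def try_traced_split(x, idxs):
    """Return the exception name if jit-ing with traced split points fails."""
    try:
        jax.jit(split_first)(x, idxs)
    except Exception as err:  # JAX raises a tracer/concretization error here
        return type(err).__name__
    return None


# Here the split points are a jnp array argument, so they become tracers.
traced_error = try_traced_split(x, jnp.array([1]))

print(static_first)  # [0.]
print(traced_error)
```

The static version succeeds because the split point is a plain Python int; the traced version fails because the sizes of the chunks would depend on a runtime value.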

For example, this fails:

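```python
import jax
import jax.numpy as jnp
from flowjax.bijections import Affine, Concatenate

# Two 1-dimensional affine bijections, concatenated into a 2-dimensional one.
bijections = (Affine(loc=jnp.zeros(1)),) * 2
bijection = Concatenate(bijections)

# Fails when traced under jit.
jax.jit(jax.vmap(bijection.transform))(jnp.ones((1, 2)))
```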

whereas the analogous example with `Stack` works (because it splits the input array into equal parts):

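```python
import jax
import jax.numpy as jnp
from flowjax.bijections import Affine, Stack

# Two scalar affine bijections, stacked into a 2-dimensional one.
bijections = (Affine(),) * 2
bijection = Stack(bijections)

# Works: Stack splits the input into equal parts.
jax.jit(jax.vmap(bijection.transform))(jnp.ones((1, 2)))
```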

A sufficient fix appears to be converting `split_idxs` to a sequence (e.g. a tuple of Python ints) rather than an array.
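The suggested fix can be sketched without flowjax: compute the split points once as a tuple of Python ints (here via `itertools.accumulate`) so they stay static under `jit` and `vmap`. The functions `shift`, `scale` and `concat_transform` are hypothetical stand-ins for the component bijections and for `Concatenate.transform`; this is not flowjax's actual implementation.

```python
import itertools

import jax
import jax.numpy as jnp


# Hypothetical stand-ins for two 1-dimensional bijections (not flowjax API).
def shift(x):
    return x + 1.0


def scale(x):
    return 2.0 * x


fns = (shift, scale)
sizes = (1, 1)  # dimension of each component

# Split points as a tuple of Python ints, e.g. sizes (2, 3, 1) -> (2, 5).
split_idxs = tuple(itertools.accumulate(sizes[:-1]))


def concat_transform(x):
    """Apply each function to its own slice of x and rejoin the results."""
    parts = jnp.array_split(x, split_idxs)
    return jnp.concatenate([f(p) for f, p in zip(fns, parts)])


# Traces fine: split_idxs is known at trace time, not a traced array.
out = jax.jit(jax.vmap(concat_transform))(jnp.array([[1.0, 3.0]]))
print(out)  # [[2. 6.]]
```

Keeping the split points as plain Python data means they are baked into the traced computation as constants, so every chunk has a static shape.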