I've encountered a small error when implementing Distrax bijectors with Flax conditioner NNs, and I also have a question about best practices for using Distrax with Flax.

import functools
from typing import Any, Mapping, Optional, Sequence

import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax import random

# Type aliases used in the annotations below (assumed definitions).
Array = jax.Array
KwArgs = Mapping[str, Any]


class Conditioner(nn.Module):
    event_shape: Sequence[int]
    num_bijector_params: int
    hidden_dims: Sequence[int]

    @nn.compact
    def __call__(self, z: Array, h: Array) -> Array:
        # Condition on both the masked event `z` and the shared features `h`.
        h = jnp.concatenate((z.flatten(), h.flatten()), axis=0)
        for hidden_dim in self.hidden_dims:
            h = nn.Dense(hidden_dim)(h)
            h = nn.relu(h)
        y = nn.Dense(np.prod(self.event_shape) * self.num_bijector_params)(h)
        y = y.reshape(tuple(self.event_shape) + (self.num_bijector_params,))
        return y


class MyModel(nn.Module):
    hidden_dims: Sequence[int]
    num_flows: int
    num_bins: int
    event_shape: Sequence[int]
    conditioner: Optional[KwArgs] = None

    @nn.compact
    def __call__(self, x, y: Optional[Array] = None):
        # base distribution
        output_dim = np.prod(self.event_shape)
        base = distrax.Independent(
            distrax.Normal(loc=jnp.zeros(output_dim), scale=jnp.ones(output_dim)),
            len(self.event_shape),
        )

        # bijector
        # Number of parameters for the rational-quadratic spline:
        # - `num_bins` bin widths
        # - `num_bins` bin heights
        # - `num_bins + 1` knot slopes
        # for a total of `3 * num_bins + 1` parameters.
        num_bijector_params = 3 * self.num_bins + 1

        layers = []
        # Alternating boolean mask over the event dimensions.
        mask = jnp.arange(0, np.prod(self.event_shape)) % 2
        mask = jnp.reshape(mask, self.event_shape)
        mask = mask.astype(bool)

        def bijector_fn(params: Array):
            return distrax.RationalQuadraticSpline(
                params, range_min=-3.0, range_max=3.0
            )

        h = x.flatten()

        # shared feature extractor
        for hidden_dim in self.hidden_dims:
            h = nn.Dense(hidden_dim)(h)
            h = nn.relu(h)

        for i in range(self.num_flows):
            conditioner = Conditioner(
                event_shape=self.event_shape,
                num_bijector_params=num_bijector_params,
                **(self.conditioner or {}),
            )
            layer = distrax.MaskedCoupling(
                mask=mask,
                bijector=bijector_fn,
                conditioner=functools.partial(conditioner, h=h),
            )
            layers.append(layer)
            mask = ~mask  # flip the mask between consecutive layers

        bijector = distrax.Inverse(distrax.Chain(layers))
        transformed = distrax.Transformed(base, bijector)

        if y is not None:
            return transformed, transformed.log_prob(y)
        else:
            return transformed


model = MyModel(
    hidden_dims=[64, 32],
    num_flows=3,
    num_bins=8,
    event_shape=(6,),
    conditioner={'hidden_dims': [64, 32]},
)

variables = model.init(random.PRNGKey(0), jnp.empty((28, 28, 1)), y=jnp.empty((6,)))
dist = model.apply(variables, jnp.ones((28, 28, 1)))
dist.event_shape

Which raises the following error:
JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)
Thankfully, evaluating log probs, i.e., dist.log_prob(jnp.zeros(6,)), runs without any error.
Any idea why this is happening? Am I doing something wrong when constructing the model?
On that note, I've also found that if I initialize the parameters without passing y (i.e., calling model.init with only the input x), the parameters for the conditioner are not instantiated. To fix this, I've used the workaround of evaluating the log prob of some dummy data when initializing the model, as in the model.init call above, which passes y=jnp.empty((6,)).
But this feels a little hacky to me and suggests that perhaps I am doing something wrong in my model definition. Do you have a set of best practices for using Flax with Distrax (now that Haiku is deprecated)?
The full setup above is also available in this Colab notebook: https://colab.research.google.com/drive/1RLRZul_pHnglcT_-YZ7mcuKLU1qd3w5O?usp=sharing
Thanks for the help!