google-deepmind / distrax

Apache License 2.0

Distrax transformations raise 'jax.core' has no attribute 'unitvar' error #191

Closed thomaspinder closed 2 years ago

thomaspinder commented 2 years ago

TL;DR: After upgrading from Jax v0.3.5 to v0.3.16, Distrax raises an AttributeError when a bijector from the bijectors module (here, the inverse of a `Lambda` bijector) is called.

Context: Within GPJax, one use of Distrax is for parameter bijections within an optimisation routine. Similar to Haiku, we structure parameters as a dictionary of dictionaries, and each parameter's value is then transformed using jax.tree_util.tree_map.

Issue: After upgrading Jax/Jaxlib and Distrax to their latest versions, running this parameter transformation routine raises the following error:

gpjax/parameters.py:185: in transform
    return jax.tree_util.tree_map(lambda param, trans: trans(param), params, transform_map)
../../miniconda3/envs/gpjax_bump/lib/python3.9/site-packages/jax/_src/tree_util.py:205: in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../miniconda3/envs/gpjax_bump/lib/python3.9/site-packages/jax/_src/tree_util.py:205: in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
gpjax/parameters.py:185: in <lambda>
    return jax.tree_util.tree_map(lambda param, trans: trans(param), params, transform_map)
../../miniconda3/envs/gpjax_bump/lib/python3.9/site-packages/distrax-0.1.2-py3.9.egg/distrax/_src/bijectors/lambda_bijector.py:112: in inverse
    return self._inverse(y)
../../miniconda3/envs/gpjax_bump/lib/python3.9/site-packages/distrax-0.1.2-py3.9.egg/distrax/_src/utils/transformations.py:119: in wrapped
    out = _interpret_inverse(jaxpr, consts, *args)

The issue appears to be that Distrax cannot access Jax's unitvar. However, this object is present in Jax's core module (see here).
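To check which situation applies in a given environment, a short diagnostic sketch can test whether the installed JAX still defines `jax.core.unitvar` and print the jaxpr that Distrax has to interpret for the Softplus forward function. It assumes only `jax` is installed (Distrax is not imported); the exact jaxpr printout may differ between JAX versions.

```python
import jax
import jax.core
import jax.numpy as jnp

print("jax version:", jax.__version__)

# Distrax 0.1.2's inverse interpreter calls write(jax.core.unitvar, ...);
# on a JAX release without this attribute, that call raises AttributeError.
has_unitvar = hasattr(jax.core, "unitvar")
print("jax.core.unitvar present:", has_unitvar)

# Trace the Softplus forward function to the jaxpr that Distrax interprets.
closed = jax.make_jaxpr(lambda x: jnp.log(1.0 + jnp.exp(x)))(jnp.array([1.0]))
print(closed)

# Names of the primitives in the traced forward computation.
primitives = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print("primitives:", primitives)
```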

Minimal reproducible example:

import jax
import jax.numpy as jnp
import distrax as dx

Softplus = dx.Lambda(lambda x: jnp.log(1.0 + jnp.exp(x)))

params = {'kernel': {'lengthscale': jnp.array([1.]), 'variance': jnp.array([0.5])},
          'likelihood': {'variance': jnp.array([1.])}}

unconstrainers = {'kernel': {'lengthscale': Softplus.inverse, 'variance': Softplus.inverse},
                  'likelihood': {'variance': Softplus.inverse}}

# The line below raises the AttributeError
unconstrained_params = jax.tree_util.tree_map(lambda param, trans: trans(param), params, unconstrainers)

I've had a look at the underlying Distrax code and can't see an easy way to resolve this. Do you know how this issue can be resolved? We'd really like to increment the Jax/Jaxlib version in our dependencies but can't until this is fixed.

If it's helpful, here is the corresponding jaxpr for the issue:

jaxpr = { lambda ; a:f64[1]. let
    b:f64[1] = exp a
    c:f64[1] = add 1.0 b
    d:f64[1] = log c
  in (d,) }, consts = [], args = (DeviceArray([1.], dtype=float64),), read = <function _interpret_inverse.<locals>.read at 0x7f7a285c4af0>, write = <function _interpret_inverse.<locals>.write at 0x7f7a285c4820>

    def _interpret_inverse(jaxpr, consts, *args):
      """Interprets and executes the inverse of `jaxpr`."""
      env = {}

      def read(var):
        return var.val if isinstance(var, jax.core.Literal) else env[var]
      def write(var, val):
        env[var] = val

>     write(jax.core.unitvar, jax.core.unit)
E     AttributeError: module 'jax.core' has no attribute 'unitvar'

../../miniconda3/envs/gpjax_bump/lib/python3.9/site-packages/distrax-0.1.2-py3.9.egg/distrax/_src/utils/transformations.py:250: AttributeError
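The failing frame shows that Distrax builds the inverse of a forward-only `Lambda` by tracing the forward function to a jaxpr and interpreting that jaxpr in reverse. The sketch below illustrates the idea in plain JAX. It is not Distrax's implementation: `make_inverse` is a hypothetical helper, it only has inverse rules for `exp`, `log` and `add` with one constant operand, and it assumes the traced function has a single input and a single output. It never writes a `unitvar` placeholder, so it does not depend on that attribute existing.

```python
import jax
import jax.numpy as jnp


def _is_literal(atom):
    # Jaxpr operands are Vars or Literals; checking the class name avoids
    # depending on where Literal is exported in a particular JAX version.
    return type(atom).__name__ == "Literal"


def make_inverse(fn, example_x):
    """Hypothetical helper: invert a single-input chain of exp/log/add."""
    closed = jax.make_jaxpr(fn)(example_x)
    jaxpr = closed.jaxpr
    (x_var,) = jaxpr.invars
    (y_var,) = jaxpr.outvars

    def inverse(y):
        known = dict(zip(jaxpr.constvars, closed.consts))  # x-independent values
        depends = {x_var}  # vars whose value depends on x

        def on_path(atom):
            return not _is_literal(atom) and atom in depends

        def read(atom):
            return atom.val if _is_literal(atom) else known[atom]

        # Forward pass: evaluate equations that do not depend on x and collect
        # the ones that lie on the path from x to the output.
        path = []
        for eqn in jaxpr.eqns:
            if any(on_path(v) for v in eqn.invars):
                depends.update(eqn.outvars)
                path.append(eqn)
            else:
                outs = eqn.primitive.bind(*map(read, eqn.invars), **eqn.params)
                if not eqn.primitive.multiple_results:
                    outs = [outs]
                known.update(zip(eqn.outvars, outs))

        # Reverse pass: push y back through each equation on the path.
        env = {y_var: y}
        for eqn in reversed(path):
            out = env[eqn.outvars[0]]
            name = eqn.primitive.name
            if name == "exp":
                env[eqn.invars[0]] = jnp.log(out)
            elif name == "log":
                env[eqn.invars[0]] = jnp.exp(out)
            elif name == "add":
                a, b = eqn.invars
                if on_path(a) and not on_path(b):
                    env[a] = out - read(b)
                elif on_path(b) and not on_path(a):
                    env[b] = out - read(a)
                else:
                    raise NotImplementedError("add of two x-dependent values")
            else:
                raise NotImplementedError(f"no inverse rule for {name!r}")
        return env[x_var]

    return inverse


softplus = lambda x: jnp.log(1.0 + jnp.exp(x))
x = jnp.array([1.0])
inverse_softplus = make_inverse(softplus, x)
print(inverse_softplus(softplus(x)))  # approximately [1.]
```

Evaluating constant equations eagerly in the forward pass means the reverse pass only has to handle equations that actually depend on the input.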
thomaspinder commented 2 years ago

This is resolved by explicitly defining the inverse transformation:

Softplus = dx.Lambda(
    forward=lambda x: jnp.log(1 + jnp.exp(x)),
    inverse=lambda x: jnp.log(jnp.exp(x) - 1.0),
)
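An explicit inverse can be sanity-checked with a round trip before it is passed to `dx.Lambda`. The sketch below uses plain JAX (Distrax is not imported) and assumes float32 inputs. The second pair, built from `jax.nn.softplus` and `jnp.expm1`, is a suggested alternative rather than part of the original comment; either pair could be passed as `forward=`/`inverse=` in the same way as above.

```python
import jax
import jax.numpy as jnp

# Explicit pair from the comment above.
forward = lambda x: jnp.log(1 + jnp.exp(x))
inverse = lambda y: jnp.log(jnp.exp(y) - 1.0)

# Alternative pair (not from the thread): softplus via jax.nn and an inverse
# via expm1, which avoids rounding exp(y) to a float near 1 before subtracting 1.
stable_forward = jax.nn.softplus
stable_inverse = lambda y: jnp.log(jnp.expm1(y))

# Maximum round-trip error on a few moderate inputs.
x = jnp.array([-2.0, 0.0, 1.0, 5.0])
print(jnp.max(jnp.abs(inverse(forward(x)) - x)))
print(jnp.max(jnp.abs(stable_inverse(stable_forward(x)) - x)))

# Far in the negative tail, the naive pair can drift while the stable pair does not.
x_tail = jnp.array([-12.0])
print(inverse(forward(x_tail)), stable_inverse(stable_forward(x_tail)))
```

For float32, `1 + jnp.exp(-12.0)` is already rounded to the nearest representable value above 1, so the naive round trip loses precision that `jax.nn.softplus` and `jnp.expm1` retain.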