tl;dr: After upgrading JAX from v0.3.5 to v0.3.16, Distrax raises an AttributeError (module 'jax.core' has no attribute 'unitvar') when inverting a bijector.
Context:
Within GPJax, one use of Distrax is for parameter bijections within an optimisation routine. Similar to Haiku, we represent the parameters as a dictionary of dictionaries, and each parameter's value is then transformed using jax.tree_util.tree_map.
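To make the parameter structure concrete, here is a minimal sketch of the pattern. It assumes hypothetical parameter names and uses plain softplus functions as stand-ins for the Distrax bijectors GPJax actually uses; only the tree_map call mirrors the line from gpjax/parameters.py shown in the traceback.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter set, structured like Haiku: a dict of dicts of arrays.
params = {
    "kernel": {"lengthscale": jnp.array([1.0]), "variance": jnp.array([1.0])},
    "likelihood": {"obs_noise": jnp.array([1.0])},
}

def softplus(x):
    # Maps the real line onto positive values.
    return jnp.log(1.0 + jnp.exp(x))

def inverse_softplus(y):
    # Inverse of softplus for y > 0.
    return jnp.log(jnp.expm1(y))

# Same tree structure as params; each leaf is a callable standing in for a bijector.
constrain_map = jax.tree_util.tree_map(lambda _: softplus, params)
unconstrain_map = jax.tree_util.tree_map(lambda _: inverse_softplus, params)

def transform(params, transform_map):
    # Apply each leaf's transform to the matching parameter leaf.
    return jax.tree_util.tree_map(lambda param, trans: trans(param), params, transform_map)

constrained = transform(params, constrain_map)
recovered = transform(constrained, unconstrain_map)
print(recovered)
```

Functions are not registered pytree nodes, so tree_map treats each callable in the transform map as a leaf and pairs it with the parameter at the same path.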
Issue:
After upgrading JAX/jaxlib and Distrax to their latest versions, running this parameter transformation routine raises the following error:
gpjax/parameters.py:185: in transform
    return jax.tree_util.tree_map(lambda param, trans: trans(param), params, transform_map)
../../miniconda3/envs/gpjax_bump/lib/python3.9/site-packages/jax/_src/tree_util.py:205: in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../miniconda3/envs/gpjax_bump/lib/python3.9/site-packages/jax/_src/tree_util.py:205: in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
gpjax/parameters.py:185: in <lambda>
    return jax.tree_util.tree_map(lambda param, trans: trans(param), params, transform_map)
../../miniconda3/envs/gpjax_bump/lib/python3.9/site-packages/distrax-0.1.2-py3.9.egg/distrax/_src/bijectors/lambda_bijector.py:112: in inverse
    return self._inverse(y)
../../miniconda3/envs/gpjax_bump/lib/python3.9/site-packages/distrax-0.1.2-py3.9.egg/distrax/_src/utils/transformations.py:119: in wrapped
    out = _interpret_inverse(jaxpr, consts, *args)
It appears that the issue is Distrax being unable to find JAX's unitvar. However, unitvar is present in JAX's core module (see here).
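One way to check whether the installed JAX actually exposes the names that Distrax 0.1.2's interpreter relies on is a quick diagnostic along these lines (a sketch; has_unitvar is a hypothetical helper, not part of JAX or Distrax):

```python
import jax
import jax.core

def has_unitvar() -> bool:
    # _interpret_inverse in Distrax 0.1.2 calls write(jax.core.unitvar, jax.core.unit),
    # so both attributes must exist for that line to run.
    return hasattr(jax.core, "unitvar") and hasattr(jax.core, "unit")

print(jax.__version__, has_unitvar())
```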
I've had a look at the underlying Distrax code and can't see an easy way to resolve this. Do you know how this issue can be resolved? We'd really like to bump the JAX/jaxlib version in our dependencies but can't until this is fixed.
In case it's helpful, here is the corresponding jaxpr for the issue.
jaxpr = { lambda ; a:f64[1]. let
    b:f64[1] = exp a
    c:f64[1] = add 1.0 b
    d:f64[1] = log c
  in (d,) }, consts = [], args = (DeviceArray([1.], dtype=float64),), read = <function _interpret_inverse.<locals>.read at 0x7f7a285c4af0>, write = <function _interpret_inverse.<locals>.write at 0x7f7a285c4820>

    def _interpret_inverse(jaxpr, consts, *args):
      """Interprets and executes the inverse of `jaxpr`."""
      env = {}

      def read(var):
        return var.val if isinstance(var, jax.core.Literal) else env[var]

      def write(var, val):
        env[var] = val

>     write(jax.core.unitvar, jax.core.unit)
E     AttributeError: module 'jax.core' has no attribute 'unitvar'

../../miniconda3/envs/gpjax_bump/lib/python3.9/site-packages/distrax-0.1.2-py3.9.egg/distrax/_src/utils/transformations.py:250: AttributeError
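The jaxpr above is softplus, log(1 + exp(a)), and _interpret_inverse inverts it by walking the equations backwards and undoing each primitive. The following is a minimal sketch of that idea, not Distrax's implementation: interpret_inverse is a hypothetical helper that handles only exp, log and add with one literal operand, and it never touches jax.core.unitvar. It detects literals by their val attribute rather than importing jax.core.Literal.

```python
import jax
import jax.numpy as jnp

def softplus(x):
    return jnp.log(1.0 + jnp.exp(x))

def interpret_inverse(closed_jaxpr, y):
    """Hypothetical mini interpreter: inverts a single-input, single-output jaxpr."""
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(var):
        # Literals carry their constant in .val; other variables live in env.
        return var.val if hasattr(var, "val") else env[var]

    env[jaxpr.outvars[0]] = y
    # Undo each equation, starting from the last one.
    for eqn in reversed(jaxpr.eqns):
        out = env[eqn.outvars[0]]
        name = eqn.primitive.name
        if name == "exp":
            env[eqn.invars[0]] = jnp.log(out)
        elif name == "log":
            env[eqn.invars[0]] = jnp.exp(out)
        elif name == "add":
            a, b = eqn.invars
            # Assumes one operand is a literal constant (e.g. the 1.0 above).
            if hasattr(a, "val"):
                env[b] = out - a.val
            else:
                env[a] = out - read(b)
        else:
            raise NotImplementedError(f"no inverse rule for {name}")
    return env[jaxpr.invars[0]]

x = jnp.array([1.0])
closed = jax.make_jaxpr(softplus)(x)
print(closed)
print(interpret_inverse(closed, softplus(x)))
```

Without jax_enable_x64 the traced types are f32 rather than the f64 shown in the traceback, but the primitives are the same.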