aidancrilly opened 1 month ago
This stemmed from the work on allowing non-array inputs and outputs, which then shifted to only allowing non-differentiable inputs (see more: https://github.com/patrick-kidger/equinox/pull/734). As a result, if the function had no static outputs, it would error. This is now fixed on main by the previous PR, and I see:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> import equinox
>>>
>>> def func(params):
...     output = {'a' : params['a']**2, 'b' : params['b']*params['a'], 'c' : params['b']**2}
...     return output
...
>>> jac_filter = equinox.filter_jacrev(func)
>>> jac_unfiltered = jax.jacrev(func)
>>> params = {'a' : 1.0*jnp.ones((1,)), 'b' : 2.0*jnp.ones((1,))}
>>>
>>> print(equinox.filter_jit(func)(params))
{'a': Array([1.], dtype=float32), 'b': Array([2.], dtype=float32), 'c': Array([4.], dtype=float32)}
>>>
>>> print(jac_unfiltered(params))
{'a': {'a': Array([[2.]], dtype=float32), 'b': Array([[0.]], dtype=float32)}, 'b': {'a': Array([[2.]], dtype=float32), 'b': Array([[1.]], dtype=float32)}, 'c': {'a': Array([[0.]], dtype=float32), 'b': Array([[4.]], dtype=float32)}}
>>> print(jac_filter(params))
{'a': {'a': Array([[2.]], dtype=float32), 'b': Array([[0.]], dtype=float32)}, 'b': {'a': Array([[2.]], dtype=float32), 'b': Array([[1.]], dtype=float32)}, 'c': {'a': Array([[0.]], dtype=float32), 'b': Array([[4.]], dtype=float32)}}
```
when I run it with these changes.
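To read the nested Jacobian printed in the transcript: with plain `jax.jacrev`, the outer keys follow the *output* pytree and the inner keys follow the *input* pytree, and each block has shape `output_leaf.shape + input_leaf.shape`. The sketch below uses only plain JAX (no Equinox) and reuses `func` and `params` from the transcript; the printing loop is just for illustration.

```python
import jax
import jax.numpy as jnp


def func(params):
    # Same function as in the transcript.
    return {"a": params["a"] ** 2, "b": params["b"] * params["a"], "c": params["b"] ** 2}


params = {"a": 1.0 * jnp.ones((1,)), "b": 2.0 * jnp.ones((1,))}
jac = jax.jacrev(func)(params)

# jac[out_key][in_key] is the derivative of func(params)[out_key]
# with respect to params[in_key]; its shape is out.shape + in.shape = (1, 1).
# For example, jac["b"]["a"] is d(b * a)/da = b = 2.0.
for out_key, row in jac.items():
    for in_key, block in row.items():
        print(f"d {out_key} / d {in_key}: shape={block.shape}, value={block.ravel()}")
```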
Understood, thank you. I'll use main for now and look forward to the next release.
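For the Hessian transformation, plain `jax.hessian` likewise accepts a dict of arrays and returns a result nested twice by the input keys, `hess[i][j]`. A minimal sketch using plain JAX only; `loss` is a hypothetical scalar function chosen for illustration and is not part of the original MWE.

```python
import jax
import jax.numpy as jnp


def loss(params):
    # Hypothetical scalar function of a dict-of-arrays pytree:
    # f(a, b) = a**2 * b + b**3, summed to a scalar.
    return jnp.sum(params["a"] ** 2 * params["b"] + params["b"] ** 3)


params = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}

# jax.hessian(f) is jacfwd(jacrev(f)); for a dict input the result is nested
# twice by the input keys: hess[i][j] = d^2 f / (d params[i] d params[j]).
hess = jax.hessian(loss)(params)

# Analytically at a=1, b=2: f_aa = 2b = 4, f_ab = f_ba = 2a = 2, f_bb = 6b = 12.
print(hess["a"]["a"], hess["a"]["b"], hess["b"]["a"], hess["b"]["b"])
```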
Hi,
Firstly, thanks for this great library.
Secondly, apologies, as I am clearly misunderstanding how the filters work for the Jacobian and Hessian transformations. Here is an MWE of my issue:
The input and output of the function are PyTrees of JAX arrays, and the docs say "The inputs and outputs may be arbitrary PyTrees."
Both the filter_jit call and the unfiltered jax.jacrev call run without error. However, the filtered case (equinox.filter_jacrev) throws an error:
What is the correct usage here? An example in the docs would really help out!
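One way to picture what "filtering" means for a Jacobian is to split the input pytree into differentiable leaves and static leaves, differentiate with respect to the first half only, and close over the second. The sketch below is a hypothetical, simplified illustration in plain JAX, not Equinox's implementation: the names `is_inexact_array`, `partition`, `combine` and `sketch_filter_jacrev` are chosen here (Equinox does expose similarly named helpers, `eqx.is_inexact_array`, `eqx.partition` and `eqx.combine`). It assumes, as Equinox documents for `filter_grad`, that only floating-point arrays in the first argument are differentiated. It filters inputs only and does not handle non-array outputs, which is the case PR #734 dealt with.

```python
import jax
import jax.numpy as jnp


def is_inexact_array(x):
    # Treat only floating-point (or complex) JAX arrays as differentiable.
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.inexact)


def partition(tree):
    # Split one pytree into two with the same structure: differentiable
    # leaves (None elsewhere) and static leaves (None elsewhere).
    dynamic = jax.tree_util.tree_map(lambda x: x if is_inexact_array(x) else None, tree)
    static = jax.tree_util.tree_map(lambda x: None if is_inexact_array(x) else x, tree)
    return dynamic, static


def combine(dynamic, static):
    # Merge the halves back together, taking whichever leaf is not None.
    return jax.tree_util.tree_map(
        lambda d, s: s if d is None else d, dynamic, static, is_leaf=lambda x: x is None
    )


def sketch_filter_jacrev(fun):
    # Hypothetical stand-in for a filtered jacrev: differentiate only the
    # inexact arrays of the single argument, holding everything else static.
    def wrapper(tree):
        dynamic, static = partition(tree)
        return jax.jacrev(lambda dyn: fun(combine(dyn, static)))(dynamic)

    return wrapper


def func(params):
    # The static "label" entry is carried along but never differentiated.
    return {"a": params["a"] ** 2, "b": params["b"] * params["a"], "c": params["b"] ** 2}


params = {"a": 1.0 * jnp.ones((1,)), "b": 2.0 * jnp.ones((1,)), "label": "demo"}
jac = sketch_filter_jacrev(func)(params)
# Static leaves appear as None in each inner dict, e.g. jac["c"]["label"] is None,
# while jac["c"]["b"] holds d(b**2)/db = 2b = 4.0.
```

The design choice here is that `None` marks "nothing to differentiate": JAX treats `None` as an empty pytree node, so the differentiable half can be passed straight to `jax.jacrev`, and the static leaves come back as `None` in the result.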