patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

docs request for filter_jacs/hessian #741

Open aidancrilly opened 1 month ago

aidancrilly commented 1 month ago

Hi,

Firstly, thanks for this great library.

Secondly, apologies: I am clearly misunderstanding how the filters work for the Jacobian and Hessian transformations. Here is an MWE of my issue:

```python
import jax
import jax.numpy as jnp
import equinox

def func(params):
    output = {'a' : params['a']**2, 'b' : params['b']*params['a'], 'c' : params['b']**2}
    return output

jac_filter = equinox.filter_jacrev(func)
jac_unfiltered = jax.jacrev(func)
params = {'a' : 1.0*jnp.ones((1,)), 'b' : 2.0*jnp.ones((1,))}

print(equinox.filter_jit(func)(params))

print(jac_unfiltered(params))
print(jac_filter(params))
```

The input and output of the function are PyTrees of JAX arrays, and the docs say "The inputs and outputs may be arbitrary PyTrees."
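For reference, when both the input and the output are dicts, `jax.jacrev` returns a nested dict indexed first by output key and then by input key, and each block has shape `output_shape + input_shape`. The sketch below uses plain JAX only (no equinox) on the same `func`, and also applies `jax.hessian` to a scalar function of the same `params`; `loss` is a hypothetical function chosen for illustration, not something from the issue.

```python
import jax
import jax.numpy as jnp


def func(params):
    return {"a": params["a"] ** 2, "b": params["b"] * params["a"], "c": params["b"] ** 2}


params = {"a": 1.0 * jnp.ones((1,)), "b": 2.0 * jnp.ones((1,))}

# Jacobian of a dict -> dict function is nested as jac[out_key][in_key].
jac = jax.jacrev(func)(params)
# Each block has shape out.shape + in.shape == (1,) + (1,) == (1, 1).
# d(b * a)/da = b and d(b * a)/db = a.
print(jac["b"]["a"], jac["b"]["b"])


# Hypothetical scalar loss, used only to show how jax.hessian nests for dict inputs.
def loss(params):
    return jnp.sum(params["a"] ** 2 * params["b"])


hess = jax.hessian(loss)(params)
# hess[k1][k2] holds d^2 loss / (d k1 d k2); here each block has shape (1, 1).
print(hess["a"]["a"], hess["a"]["b"], hess["b"]["b"])
```

With `a = 1` and `b = 2`, the Jacobian blocks agree with the `jac_unfiltered` output shown later in the thread, and the Hessian entries are `2b`, `2a`, and `0` respectively.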

Both the filter_jit call and the unfiltered jax.jacrev run without error. However, the filtered jacrev case throws an error:

```
  File "f:\Anaconda3\envs\lagradept\Lib\site-packages\equinox\_ad.py", line 452, in __call__
    out = combine(dynamic_out, static_out)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "f:\Anaconda3\envs\lagradept\Lib\site-packages\equinox\_filters.py", line 200, in combine
    return jtu.tree_map(_combine, *pytrees, is_leaf=_is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "f:\Anaconda3\envs\lagradept\Lib\site-packages\jax\_src\tree_util.py", line 319, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
                             ^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Expected dict, got None.
```

What is the correct usage here? An example in the docs would really help out!
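One way to picture what the `filter_*` transformations do with their first argument: split the PyTree into floating-point array leaves (which get differentiated) and everything else (held fixed), differentiate a closure over the fixed part, and return a Jacobian with the same nesting as the input. The sketch below implements that idea in plain JAX with hypothetical helpers `split_inexact`, `merge` and `jacrev_filtered`; it is a simplified illustration under those assumptions, not equinox's actual implementation (equinox itself exposes `eqx.partition` and `eqx.combine` for the split and merge steps).

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def _is_float_array(x):
    # Selects JAX arrays with a floating-point (inexact) dtype; everything else is static.
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.inexact)


def split_inexact(tree):
    # Hypothetical helper: two trees of the same structure, with None in unselected slots.
    dynamic = jtu.tree_map(lambda x: x if _is_float_array(x) else None, tree)
    static = jtu.tree_map(lambda x: None if _is_float_array(x) else x, tree)
    return dynamic, static


def merge(dynamic, static):
    # Hypothetical helper: at each leaf position, keep whichever value is not None.
    return jtu.tree_map(
        lambda d, s: s if d is None else d,
        dynamic,
        static,
        is_leaf=lambda x: x is None,
    )


def jacrev_filtered(fn):
    # Hypothetical wrapper: differentiate fn only w.r.t. the float-array leaves of its argument.
    def wrapper(tree):
        dynamic, static = split_inexact(tree)
        return jax.jacrev(lambda d: fn(merge(d, static)))(dynamic)

    return wrapper


def powered(params):
    # params["power"] is a plain Python int, so it is treated as static.
    return {"y": params["w"] ** params["power"]}


params = {"w": jnp.array([3.0]), "power": 2}
jac = jacrev_filtered(powered)(params)
print(jac["y"]["w"])      # d(w**2)/dw evaluated at w = 3
print(jac["y"]["power"])  # None: the static leaf is not differentiated
```

Because `None` is an empty PyTree node in JAX, the static slots pass through `jax.jacrev` untouched and reappear as `None` in the result, which mirrors how non-differentiable inputs are excluded from a filtered Jacobian.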

lockwo commented 1 month ago

This stemmed from the issues with allowing non-array inputs and outputs, and support has just been narrowed to only allow non-differentiable inputs (see more: https://github.com/patrick-kidger/equinox/pull/734). Previously, a function with no static outputs would raise this error. This is now fixed on main by the previous PR, and I see:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> import equinox
>>> 
>>> def func(params):
...     output = {'a' : params['a']**2, 'b' : params['b']*params['a'], 'c' : params['b']**2}
...     return output
... 
>>> jac_filter = equinox.filter_jacrev(func)
>>> jac_unfiltered = jax.jacrev(func)
>>> params = {'a' : 1.0*jnp.ones((1,)), 'b' : 2.0*jnp.ones((1,))}
>>> 
>>> print(equinox.filter_jit(func)(params))
{'a': Array([1.], dtype=float32), 'b': Array([2.], dtype=float32), 'c': Array([4.], dtype=float32)}
>>> 
>>> print(jac_unfiltered(params))
{'a': {'a': Array([[2.]], dtype=float32), 'b': Array([[0.]], dtype=float32)}, 'b': {'a': Array([[2.]], dtype=float32), 'b': Array([[1.]], dtype=float32)}, 'c': {'a': Array([[0.]], dtype=float32), 'b': Array([[4.]], dtype=float32)}}
>>> print(jac_filter(params))
{'a': {'a': Array([[2.]], dtype=float32), 'b': Array([[0.]], dtype=float32)}, 'b': {'a': Array([[2.]], dtype=float32), 'b': Array([[1.]], dtype=float32)}, 'c': {'a': Array([[0.]], dtype=float32), 'b': Array([[4.]], dtype=float32)}}
```

That is the output when I run it with these changes: the filtered and unfiltered Jacobians match.
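The traceback ends in equinox's `combine` (in `equinox/_filters.py`), which calls `jtu.tree_map` over `dynamic_out` and `static_out`. The JAX-only reproduction below is an assumption about the failure mode, inferred only from that traceback: `tree_map` requires every extra tree to share the first tree's structure, so a bare `None` standing in for "no static outputs" fails, whereas a same-shaped tree of `None` leaves combines cleanly. `combine_like` is a hypothetical, simplified stand-in for equinox's `combine`, and this is not necessarily how the fix on main works.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def combine_like(*trees):
    # Hypothetical simplified combine: at each leaf position, keep the first non-None value.
    def pick(*leaves):
        for leaf in leaves:
            if leaf is not None:
                return leaf
        return None

    return jtu.tree_map(pick, *trees, is_leaf=lambda x: x is None)


dynamic_out = {"a": jnp.array([1.0]), "b": jnp.array([2.0]), "c": jnp.array([4.0])}

# Case 1 (assumed failure mode): "no static outputs" represented by a bare None.
# tree_map cannot flatten None up to the dict structure of dynamic_out.
try:
    combine_like(dynamic_out, None)
    bare_none_failed = False
except ValueError as err:
    bare_none_failed = True
    print("ValueError:", err)

# Case 2: a tree with the same structure whose leaves are all None combines fine.
static_out = jtu.tree_map(lambda _: None, dynamic_out)
combined = combine_like(dynamic_out, static_out)
print(combined)
```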

aidancrilly commented 1 month ago

Understood, thank you. Will use main for now and look forward to the next release.