pfackeldey opened this issue 6 months ago.
I've managed to rephrase my compute task from the "array of PyTrees" paradigm to the "PyTree of arrays" paradigm. Then I could use `jax.lax.scan`, similar to the scan-over-layers example you provided in the documentation! :)
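A minimal sketch of the "PyTree of arrays" plus `jax.lax.scan` pattern, in plain JAX. The modifier form `hist * (1 + theta * delta)` and the names `apply_one` and `apply_all` are hypothetical illustrations, not evermore's API.

```python
import jax
import jax.numpy as jnp

# Hypothetical setup: a 3-bin histogram and 4 systematic uncertainties.
n_bins = 3

# "Array of PyTrees": one small dict per systematic, held in a Python list.
systematics = [
    {"theta": jnp.array(0.1 * i), "delta": jnp.full((n_bins,), 0.05 * (i + 1))}
    for i in range(4)
]

# "PyTree of arrays": a single dict whose leaves gain a leading axis of size 4.
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *systematics)


def apply_one(hist, syst):
    # Hypothetical multiplicative modifier: scale every bin by (1 + theta * delta).
    return hist * (1.0 + syst["theta"] * syst["delta"]), None


@jax.jit
def apply_all(hist, stacked):
    # lax.scan traces `apply_one` once and loops over the leading axis of `stacked`.
    hist, _ = jax.lax.scan(apply_one, hist, stacked)
    return hist


hist = jnp.array([10.0, 20.0, 30.0], dtype=jnp.float32)
result = apply_all(hist, stacked)
```

Because the loop body is traced only once, the traced program no longer grows with the number of systematics; the trade-off is that every stacked entry must share the same structure and shapes.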
Awesome, I'm glad you got that figured out! Indeed that'd be the best answer.
For your "PS" question, probably `eqx.tree_equal`. Inside a JIT'd function it will return a traced boolean value.
Thank you very much for your reply!
I'm already looking into `eqx.tree_equal`, but I don't yet fully understand how to use it in a JIT'd function.
Essentially I want to check whether two trees are unequal, and my current attempts don't get me there:
```python
import equinox as eqx
import jax
import jax.numpy as jnp

t1 = {"a": jnp.array([0.0]), "b": jnp.array([0.0])}
t2 = {"a": jnp.array([1.0]), "b": jnp.array([1.0])}

# What I want:
# a function `f` that returns True if two trees are unequal
#
# def f(t1, t2):
#     return ?
# f(t1, t2)  # -> True
# f(t1, t1)  # -> False
# jax.jit(f)(t1, t2)  # -> True
# jax.jit(f)(t1, t1)  # -> False

# This won't work because of TracerBoolConversionError
def f1(t1, t2):
    return not eqx.tree_equal(t1, t2)

f1(t1, t2)  # -> True
f1(t1, t1)  # -> False
jax.jit(f1)(t1, t2)  # -> TracerBoolConversionError
jax.jit(f1)(t1, t1)  # -> TracerBoolConversionError

# These (f2, f3) don't return the correct comparison
def f2(t1, t2):
    return eqx.tree_equal(t1, t2) is False

f2(t1, t2)  # -> False
f2(t1, t1)  # -> False
jax.jit(f2)(t1, t2)  # -> False
jax.jit(f2)(t1, t1)  # -> False

def f3(t1, t2):
    return eqx.tree_equal(t1, t2) is not True

f3(t1, t2)  # -> True
f3(t1, t1)  # -> True
jax.jit(f3)(t1, t2)  # -> True
jax.jit(f3)(t1, t1)  # -> True
```
Am I missing the obvious solution here? :D
At the end of the day I'd like to throw an error if these two trees are not exactly the same.
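Why `f2` and `f3` misbehave: `is` checks object identity, and with array leaves `eqx.tree_equal` hands back a JAX boolean array rather than the Python singletons `True`/`False`. So `... is False` is always `False` and `... is not True` is always `True`. A plain-JAX illustration, using `jnp.all(a == b)` as a stand-in for the array-valued result of `eqx.tree_equal`:

```python
import jax.numpy as jnp

# Stand-in for the array-valued result that eqx.tree_equal returns.
equal = jnp.all(jnp.array([0.0]) == jnp.array([1.0]))

identity_false = equal is False  # identity check against the Python singleton
identity_not_true = equal is not True  # always True for an array object
value_false = bool(equal) is False  # converting the concrete value works outside jit
```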
Update:
I think I just found the solution:
```python
def f(t1, t2):
    return eqx.error_if(t1, eqx.tree_equal(t1, t2) != True, "Error, t1 and t2 are unequal")

f(t1, t2)  # -> Error
f(t1, t1)  # -> pass
jax.jit(f)(t1, t2)  # -> Error
jax.jit(f)(t1, t1)  # -> pass
```
Yup, that's it! When you're inside a `jax.jit`, all computations are deferred, and Python is just used to build up the computation graph. So the erroring also has to be deferred.
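To illustrate why the check has to be deferred: under `jax.jit` the arguments are tracers, so a Python `if` on a traced value raises `jax.errors.TracerBoolConversionError`, whereas a traced operation such as `jnp.where` simply becomes part of the computation graph. The function names below are made up for this sketch.

```python
import jax
import jax.numpy as jnp


def python_branch(x):
    # A Python `if` needs a concrete bool, which a tracer cannot provide.
    if x > 0:
        return x
    return -x


def traced_branch(x):
    # jnp.where is recorded in the graph and evaluated when the compiled code runs.
    return jnp.where(x > 0, x, -x)


x = jnp.array(-2.0)
eager = python_branch(x)  # works: x is concrete outside jit
jitted = jax.jit(traced_branch)(x)

try:
    jax.jit(python_branch)(x)
    raised = False
except jax.errors.TracerBoolConversionError:
    raised = True
```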
FWIW you can do a logical inverse with `jnp.invert`. (`not` is special-cased by Python, so it's not overridable by JAX.)
Great! If that's fine with you, I'll close the issue, as I consider all my questions solved :)
Dear @patrick-kidger,
I am currently developing a fitting library (https://github.com/pfackeldey/evermore) for binned likelihood fits. These types of fits are typically performed in high-energy physics analyses.
A common pattern is to evaluate a large number (O(100-1000)) of systematic uncertainties that affect the bins of a histogram. You can think of these systematic uncertainties as functions that may vary each bin of such a histogram.
Here is some pseudo-code to explain this in more detail:
This principle works great for a small number of systematic uncertainties. However, with many systematic uncertainties, compile time becomes non-negligible.
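The pseudo-code referred to above is not preserved here, but one common source of growing compile time can be illustrated: a Python `for` loop over N modifiers is unrolled during tracing, so the program XLA has to compile grows with N, whereas `jax.lax.scan` traces its body once. In this sketch the modifier `scale_bins` is hypothetical, and the equation count from `jax.make_jaxpr` is used only as a rough proxy for compile cost.

```python
import jax
import jax.numpy as jnp


def scale_bins(hist, delta):
    # Hypothetical per-systematic modifier acting on every bin.
    return hist * (1.0 + delta)


def unrolled(hist, deltas):
    # Python loop: each iteration adds its own operations to the traced graph.
    for i in range(deltas.shape[0]):
        hist = scale_bins(hist, deltas[i])
    return hist


def scanned(hist, deltas):
    # lax.scan: the body is traced once, independent of deltas.shape[0].
    hist, _ = jax.lax.scan(lambda h, d: (scale_bins(h, d), None), hist, deltas)
    return hist


def n_eqns(fn, n_systematics, n_bins=5):
    hist = jnp.ones((n_bins,))
    deltas = jnp.zeros((n_systematics, n_bins))
    # Count the top-level equations in the traced program.
    return len(jax.make_jaxpr(fn)(hist, deltas).jaxpr.eqns)


counts = {n: (n_eqns(unrolled, n), n_eqns(scanned, n)) for n in (10, 100)}
```

Actual XLA compile time depends on more than the equation count, but it tends to grow with the size of the unrolled graph.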
My question is whether you have tips on how to reduce the compile time for this pattern. Is it possible to use `jax.lax.scan` or `eqx.filter_vmap` to apply a set of functions to one `jax.Array`? (I think the equivalent in deep learning models would be a scan over layers of different kinds.)
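One possible approach to the "scan over layers of different kinds" question, sketched in plain JAX and not necessarily what the maintainer would recommend: give each systematic an integer `kind` and a parameter, stack those into arrays, and dispatch with `jax.lax.switch` inside `jax.lax.scan`. Each branch is traced once, so the traced program grows with the number of distinct kinds rather than the number of systematics. The two modifier kinds below are hypothetical.

```python
import jax
import jax.numpy as jnp


# Hypothetical modifier kinds; all branches must share the same signature and output type.
def multiplicative(hist, param):
    return hist * (1.0 + param)


def additive(hist, param):
    return hist + param


BRANCHES = [multiplicative, additive]


def step(hist, xs):
    kind, param = xs
    # lax.switch selects the branch at runtime from the (traced) integer index.
    return jax.lax.switch(kind, BRANCHES, hist, param), None


@jax.jit
def apply_all(hist, kinds, params):
    # Scan over the stacked (kind, param) pairs, one systematic per step.
    hist, _ = jax.lax.scan(step, hist, (kinds, params))
    return hist


hist = jnp.array([10.0, 20.0, 30.0], dtype=jnp.float32)
kinds = jnp.array([0, 1, 0], dtype=jnp.int32)
params = jnp.array([0.1, 2.0, -0.5], dtype=jnp.float32)
result = apply_all(hist, kinds, params)
```

Kinds whose parameters have different shapes would need padding, or one stacked array per kind; that design choice is left open here.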
Best, Peter
PS: I'd also like to ask: what's the best way to check whether two `eqx.Module` instances with `jax.Array` leaves are identical (in a jitted function)?