patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

How to reduce compile time of loops over "PyTree Callables"? #685

Open pfackeldey opened 6 months ago

pfackeldey commented 6 months ago

Dear @patrick-kidger,

I am currently developing a fitting library (https://github.com/pfackeldey/evermore) for binned likelihood fits. These types of fits are typical in high-energy physics analyses.

A common pattern is to evaluate a large number (O(100-1000)) of systematic uncertainties that affect the bins of a histogram. You can think of these systematic uncertainties as functions that may vary each bin of such a histogram.

Here is some pseudo-code to explain this in more detail:

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

class systematic1(eqx.Module):
  def __call__(self, hist: Array) -> Array:
    # implement the variation that `systematic1` does to `hist` here
    # (illustrative placeholder: scale every bin up by 10%)
    return hist * 1.1

class systematic2(eqx.Module):
  def __call__(self, hist: Array) -> Array:
    # implement the variation that `systematic2` does to `hist` here
    # (illustrative placeholder: shift every bin up by 1)
    return hist + 1.0

histogram = jnp.array([10.0, 20.0, 30.0])

@eqx.filter_jit
def vary_histogram(hist: Array, *systematics: eqx.Module) -> Array:
  # this Python loop is unrolled during tracing, so the compiled
  # program grows with the number of systematics
  for systematic in systematics:
    hist = systematic(hist)
  return hist

# apply variations
varied_histogram = vary_histogram(histogram, systematic1(), systematic2())

This approach works well for a small number of systematic uncertainties. However, with many systematic uncertainties, compile time becomes non-negligible.
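To make the compile-time concern concrete, here is a minimal sketch using a toy multiplicative "systematic" (an assumption, not evermore's actual model). It counts the equations in the traced program with jax.make_jaxpr: the Python loop contributes equations for every systematic, while jax.lax.scan traces its body once. Equation count is only a rough proxy for compile time.

```python
import jax
import jax.numpy as jnp

def vary_unrolled(hist, scales):
    # Python loop: unrolled at trace time, so the traced program
    # contains one multiplication per systematic
    for s in scales:
        hist = hist * s
    return hist

def vary_scanned(hist, scales):
    # lax.scan traces `step` once and loops over the leading axis of `scales`
    def step(h, s):
        return h * s, None
    hist, _ = jax.lax.scan(step, hist, scales)
    return hist

hist = jnp.array([10.0, 20.0, 30.0])
program_size = {}
for n in (10, 100):
    scales = jnp.full((n,), 1.01)
    unrolled = jax.make_jaxpr(vary_unrolled)(hist, list(scales))
    scanned = jax.make_jaxpr(vary_scanned)(hist, scales)
    # (number of equations with the Python loop, number with lax.scan)
    program_size[n] = (len(unrolled.jaxpr.eqns), len(scanned.jaxpr.eqns))
    print(n, program_size[n])
```

The scanned version requires every step to have the same structure, which is why stacking parameters into a single PyTree of arrays is the usual prerequisite for scanning.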

My question is whether you have tips on how to reduce the compile time for this pattern. Is it possible to use jax.lax.scan or eqx.filter_vmap to apply a set of functions to one jax.Array?

(I think the deep-learning equivalent would be a scan over layers of different kinds.)

Best, Peter

PS: In addition, I'd like to ask what's the best way to check if two instances of eqx.Module(s) with jax.Array(s) as leaves are identical (in a jitted function)?

pfackeldey commented 6 months ago

I've managed to rephrase my computation from an array-of-PyTrees to a PyTree-of-arrays paradigm. Then I could use jax.lax.scan, similar to the scan-over-layers example in the documentation! :)

patrick-kidger commented 6 months ago

Awesome, I'm glad you got that figured out! Indeed that'd be the best answer.

For your "PS" question, probably eqx.tree_equal. It will return a JIT'd (i.e. traced) boolean value.

pfackeldey commented 6 months ago

Thank you very much for your reply! I'm already looking into eqx.tree_equal, but I don't yet fully understand how to use it in a JIT'd function. Essentially, I want to check whether two trees are unequal, and I can't currently get there:

import equinox as eqx
import jax
import jax.numpy as jnp

t1 = {"a": jnp.array([0.0]), "b": jnp.array([0.0])}
t2 = {"a": jnp.array([1.0]), "b": jnp.array([1.0])}

# What I want:
# a function `f` that returns True if two trees are unequal
# 
# def f(t1, t2):
#   return ?
# f(t1, t2) # -> True
# f(t1, t1) # -> False
# jax.jit(f)(t1,t2) # -> True
# jax.jit(f)(t1,t1) # -> False

# This won't work because of TracerBoolConversionError
def f1(t1, t2):
  return not eqx.tree_equal(t1, t2)

f1(t1, t2) # -> True
f1(t1, t1) # -> False
jax.jit(f1)(t1,t2) # -> TracerBoolConversionError
jax.jit(f1)(t1,t1) # -> TracerBoolConversionError

# These (f2, f3) don't return the correct comparison
def f2(t1, t2):
   return eqx.tree_equal(t1, t2) is False

f2(t1, t2) # -> False
f2(t1, t1) # -> False
jax.jit(f2)(t1,t2) # -> False
jax.jit(f2)(t1,t1) # -> False

def f3(t1, t2):
   return eqx.tree_equal(t1, t2) is not True

f3(t1, t2) # -> True
f3(t1, t1) # -> True
jax.jit(f3)(t1,t2) # -> True
jax.jit(f3)(t1,t1) # -> True

Am I missing the obvious solution here? :D

At the end of the day I'd like to throw an error if these two trees are not exactly the same.

pfackeldey commented 6 months ago

Update:

I think I just found the solution:

def f(t1, t2):
  return eqx.error_if(t1, eqx.tree_equal(t1, t2) != True, "Error, t1 and t2 are unequal")

f(t1, t2) # -> Error
f(t1, t1) # -> pass
jax.jit(f)(t1,t2) # -> Error
jax.jit(f)(t1,t1) # -> pass
patrick-kidger commented 6 months ago

Yup, that's it! When you're inside a jax.jit then all computations are deferred, and Python is just used to build up the computation graph. So the erroring also has to be deferred.

FWIW you can take a logical inverse with jnp.invert. (`not` is special-cased by Python, so it can't be overridden by JAX.)

pfackeldey commented 6 months ago

Great! If that's fine with you, I will close the issue, as I consider all my questions solved :)