Open TC01 opened 2 months ago
It seems that this can be fixed by explicitly converting `poi` to a float before checking if it's in the hypotest cache here. I printed out the types, and at some point in the test run `poi` switched from being a numpy `float64` to a jaxlib type; I can trace further to see exactly why that happens and then maybe submit a PR.
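For context, here is a minimal sketch of the failure mode being described (the names below are illustrative stand-ins, not pyhf's actual cache code): a dict-backed cache needs hashable keys, so a plain float works while an array type does not, and a `float()` cast normalizes the key.

```python
# Illustrative sketch, not pyhf's code: dict caches require hashable
# keys. numpy float64 scalars are hashable; JAX arrays are not.

cache = {}

def cached_hypotest(poi):
    # Casting to float normalizes whatever scalar type the active
    # backend hands us into a hashable key (the fix proposed above).
    key = float(poi)
    if key not in cache:
        cache[key] = key**2  # stand-in for the expensive computation
    return cache[key]

class ArrayLike:
    """Stand-in for an unhashable backend array type."""
    __hash__ = None  # mimics e.g. a JAX array, which cannot be a dict key

    def __init__(self, value):
        self.value = value

    def __float__(self):
        return float(self.value)

print(cached_hypotest(1.0))             # hashable key: fine
print(cached_hypotest(ArrayLike(1.0)))  # only works thanks to the float() cast
```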
The first thing we'll need to do is understand why there aren't test failures, which will probably mean revisiting PR #1274. Writing a failing test would be a good start, so that a PR can make it pass.
Ah, I didn't realize there was a test for this! Does that get run with all the backends? When I get a chance I can try running that locally too.
Nope, which likely explains why it wasn't caught. (Adding the backend fixture to the test will have it run on all the backends.)
@matthewfeickert I saw the PR, and I think we need to swap the way we're approaching this. Here's my suggestion: instead of type-casting, we need to add shims across each library and move some functions into our tensorlib to make them backend-dependent (or use a shim to swap them out as needed, like we do for `scipy.optimize`).
See this example:
```python
from functools import lru_cache
import time
import timeit

import jax
import jax.numpy as jnp
import tensorflow as tf


def slow(n):
    time.sleep(1)
    return n**2


# Cached/compiled variants of the same function
fast = lru_cache(maxsize=None)(slow)
fast_jax = jax.jit(slow)
fast_tflow = tf.function(jit_compile=True)(slow)

value = 5
print('slow')
print(timeit.timeit(lambda: [slow(value), slow(value), slow(value), slow(value), slow(value)], number=1))
print('fast')
print(timeit.timeit(lambda: [fast(value), fast(value), fast(value), fast(value), fast(value)], number=1))

value = jnp.array(5)
print('slow, jax')
print(timeit.timeit(lambda: [slow(value), slow(value), slow(value), slow(value), slow(value)], number=1))
print('fast, jax')
print(timeit.timeit(lambda: [fast_jax(value), fast_jax(value), fast_jax(value), fast_jax(value), fast_jax(value)], number=1))

value = tf.convert_to_tensor(5)
print('slow, tensorflow')
print(timeit.timeit(lambda: [slow(value), slow(value), slow(value), slow(value), slow(value)], number=1))
print('fast, tensorflow')
print(timeit.timeit(lambda: [fast_tflow(value), fast_tflow(value), fast_tflow(value), fast_tflow(value), fast_tflow(value)], number=1))
```
which outputs
```console
$ python cache.py
slow
5.012567336
fast
1.0029977690000003
slow, jax
5.043927394000001
fast, jax
1.0195144690000006
slow, tensorflow
5.017408181999997
fast, tensorflow
1.0631543910000012
```
So we can definitely cache those values by JIT-ing for the toms748 scan here, and that's probably what we want to do. My suggestion would be that we support `pyhf.tensor.jit` with something similar to the signature of `jax.jit` across all backends (yes, even numpy, though for numpy it would be an `lru_cache`).
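To make the suggestion concrete, here is one possible shape for such a shim (`pyhf.tensor.jit` is only proposed above and does not exist; `make_jit` and the backend names below are hypothetical): a dispatcher that hands back the backend's native compiler, with `lru_cache` standing in as the numpy "JIT".

```python
# Hypothetical sketch of a backend-aware jit shim. numpy has no
# compiler, so memoization via lru_cache is the closest equivalent;
# array backends would delegate to their native compiler.
from functools import lru_cache

def make_jit(backend_name):
    if backend_name == "numpy":
        return lambda fn: lru_cache(maxsize=None)(fn)
    if backend_name == "jax":
        import jax  # assumed available when the jax backend is active
        return jax.jit
    raise NotImplementedError(backend_name)

calls = 0

def slow_square(n):
    global calls
    calls += 1  # count how often the real computation runs
    return n**2

fast = make_jit("numpy")(slow_square)
results = [fast(5) for _ in range(5)]
print(results, calls)  # five identical results from a single real call
```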
> we need to add in shims across each lib and move some functions into our tensorlib instead to make them backend-dependent
Okay, sounds good. Let's start up a separate series of PRs to do this.
Summary
Hello; perhaps this is known, but I thought I'd file a bug report just in case. I was testing the `upper_limits` API and discovered that the example given in the documentation doesn't seem to work with the JAX backend. It fails with a complaint about an unhashable array type (see the traceback). If I switch to the numpy backend, as shown in the documentation, it runs fine.

I see this on both EL7 in an ATLAS environment (`StatAnalysis,0.3,latest`) and on my own desktop (Fedora 38); in both cases I have the same pyhf version (0.7.6) and I manually installed `jax[CPU] == 0.4.26` on top of that.

I should add that things work fine with JAX if I use the version of `upper_limits` where I pass in a range of mu values to scan, so I guess maybe some extra type conversion is needed to go from the JAX array type to a list or something hashable?

OS / Environment
Steps to Reproduce
Install pyhf and JAX through pip; then try to run the example in the documentation, but with the JAX backend instead of numpy:
Expected Results
Ideally the example would run without crashing (as it does with the numpy backend).
Actual Results
pyhf Version