scikit-hep / pyhf

pure-Python HistFactory implementation with tensors and autodiff
https://pyhf.readthedocs.io/
Apache License 2.0

toms748_scan doesn't work with JAX backend #2466

Open TC01 opened 2 months ago

TC01 commented 2 months ago

Summary

Hello! Perhaps this is known, but I thought I'd file a bug report just in case. I was testing the upper_limits API and discovered that the example given in the documentation doesn't work with the JAX backend: it fails with a complaint about an unhashable array type (see the traceback below). If I switch to the numpy backend, as shown in the documentation, it runs fine.

I see this both on EL7 in an ATLAS environment (StatAnalysis,0.3,latest) and on my own desktop (Fedora 38); in both cases I have the same pyhf version (0.7.6) with jax[cpu]==0.4.26 installed manually on top of it.

I should add that things work fine with JAX if I use the version of upper_limits where I pass in an explicit range of mu values to scan, so I suspect some extra type conversion is needed to go from the JAX array type to a list or something else hashable.
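
For reference, this is the variant that does work for me with JAX (the scan grid follows the documentation's example; the exact grid values are arbitrary):

import numpy as np
import pyhf

pyhf.set_backend("jax")
model = pyhf.simplemodels.uncorrelated_background(
    signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
)
data = pyhf.tensorlib.astensor([51, 48] + model.config.auxdata)
# passing an explicit scan grid avoids the toms748 root-finding code path
obs_limit, exp_limits = pyhf.infer.intervals.upper_limits.upper_limit(
    data, model, scan=np.linspace(0, 5, 21)
)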

OS / Environment

# Linux
$ cat /etc/os-release
NAME="Fedora Linux"
VERSION="38 (Thirty Eight)"
ID=fedora
VERSION_ID=38
VERSION_CODENAME=""
PLATFORM_ID="platform:f38"
PRETTY_NAME="Fedora Linux 38 (Thirty Eight)"
ANSI_COLOR="0;38;2;60;110;180"
LOGO=fedora-logo-icon
CPE_NAME="cpe:/o:fedoraproject:fedora:38"
DEFAULT_HOSTNAME="fedora"
HOME_URL="https://fedoraproject.org/"
DOCUMENTATION_URL="https://docs.fedoraproject.org/en-US/fedora/f38/system-administrators-guide/"
SUPPORT_URL="https://ask.fedoraproject.org/"
BUG_REPORT_URL="https://bugzilla.redhat.com/"
REDHAT_BUGZILLA_PRODUCT="Fedora"
REDHAT_BUGZILLA_PRODUCT_VERSION=38
REDHAT_SUPPORT_PRODUCT="Fedora"
REDHAT_SUPPORT_PRODUCT_VERSION=38
SUPPORT_END=2024-05-14

Steps to Reproduce

Install pyhf and JAX through pip, then run the example from the documentation with the JAX backend instead of numpy:

import pyhf

pyhf.set_backend("jax")
model = pyhf.simplemodels.uncorrelated_background(
    signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
)
observations = [51, 48]
data = pyhf.tensorlib.astensor(observations + model.config.auxdata)
obs_limit, exp_limits = pyhf.infer.intervals.upper_limits.toms748_scan(
    data, model, 0.0, 5.0, rtol=0.01
)

Expected Results

Ideally the example would run without crashing (as it does with the numpy backend).

Actual Results

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/bjr/.local/lib/python3.11/site-packages/pyhf/infer/intervals/upper_limits.py", line 130, in toms748_scan
    toms748(f, bounds_low, bounds_up, args=(level, 0), k=2, xtol=atol, rtol=rtol)
  File "/usr/lib64/python3.11/site-packages/scipy/optimize/_zeros_py.py", line 1374, in toms748
    result = solver.solve(f, a, b, args=args, k=k, xtol=xtol, rtol=rtol,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/scipy/optimize/_zeros_py.py", line 1229, in solve
    fc = self._callf(c)
         ^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/scipy/optimize/_zeros_py.py", line 1083, in _callf
    fx = self.f(x, *self.args)
         ^^^^^^^^^^^^^^^^^^^^^
  File "/home/bjr/.local/lib/python3.11/site-packages/pyhf/infer/intervals/upper_limits.py", line 95, in f
    f_cached(poi)[0] - level
    ^^^^^^^^^^^^^
  File "/home/bjr/.local/lib/python3.11/site-packages/pyhf/infer/intervals/upper_limits.py", line 80, in f_cached
    if poi not in cache:
       ^^^^^^^^^^^^^^^^
TypeError: unhashable type: 'jaxlib.xla_extension.ArrayImpl'
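
The failing check reproduces in isolation, independent of pyhf: JAX arrays are deliberately unhashable, so any dict membership test on one raises this same TypeError (a minimal demonstration):

import jax.numpy as jnp

cache = {}
poi = jnp.asarray(1.0)
poi in cache  # TypeError: unhashable type: 'jaxlib.xla_extension.ArrayImpl'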

pyhf Version

$ pyhf --version
pyhf, version 0.7.6

TC01 commented 2 months ago

It seems this can be fixed by explicitly converting poi to a float before checking whether it's in the hypotest cache here. I printed out the types, and at some point during the scan poi switches from a numpy float64 to a jaxlib array type; I can trace further to see exactly why that happens and then maybe submit a PR.
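
A sketch of what I mean (illustrative only; the hypotest arguments here are placeholders, and the real f_cached in upper_limits.py forwards more options):

def f_cached(poi):
    poi = float(poi)  # plain Python floats are hashable; JAX array scalars are not
    if poi not in cache:
        cache[poi] = hypotest(poi, data, model, return_expected_set=True)
    return cache[poi]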

matthewfeickert commented 2 months ago

The first thing we'll need to do is understand why there aren't test failures in

https://github.com/scikit-hep/pyhf/blob/64ab2646b836ae69404929b7e2b7ba04fb87d492/tests/test_infer.py#L26-L57

which will probably mean revisiting PR #1274. Writing a failing test would be a good start, so that a PR can then make it pass.

TC01 commented 2 months ago

Ah, I didn't realize there was a test for this! Does that get run with all the backends? When I get a chance I can try running that locally too.

kratsg commented 2 months ago

Ah, I didn't realize there was a test for this! Does that get run with all the backends? When I get a chance I can try running that locally too.

Nope, which likely explains why it wasn't caught. (Adding the backend fixture to the test will have it run on all the backends.)
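
A sketch of what such a failing test could look like (the backend fixture is the parametrized one from tests/conftest.py; the assertion is a placeholder):

import pyhf

def test_toms748_scan_backends(backend):
    # the `backend` fixture runs this test once per tensorlib
    model = pyhf.simplemodels.uncorrelated_background(
        signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
    )
    data = pyhf.tensorlib.astensor([51, 48] + model.config.auxdata)
    # currently raises TypeError under the jax backend
    obs_limit, exp_limits = pyhf.infer.intervals.upper_limits.toms748_scan(
        data, model, 0.0, 5.0, rtol=0.01
    )
    assert obs_limit > 0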

kratsg commented 2 months ago

@matthewfeickert I saw the PR, and I think we need to swap the way we're approaching this. Here's my suggestion instead of type-casting: we need to add shims across each lib and move some functions into our tensorlib to make them backend-dependent (or use a shim to swap them out as needed, like we do for scipy.optimize).

See this example:

from functools import lru_cache
import time
import timeit

import jax
import jax.numpy as jnp
import tensorflow as tf

def slow(n):
    time.sleep(1)  # stand-in for an expensive computation
    return n**2

# Python-level memoization: requires hashable arguments
fast = lru_cache(maxsize=None)(slow)

# JIT compilation: the sleep runs once, at trace time; compiled calls
# skip the Python body entirely
fast_jax = jax.jit(slow)
fast_tflow = tf.function(jit_compile=True)(slow)

def bench(func, value):
    # time five repeated calls with the same argument
    return timeit.timeit(lambda: [func(value) for _ in range(5)], number=1)

value = 5
print('slow')
print(bench(slow, value))
print('fast')
print(bench(fast, value))

value = jnp.array(5)
print('slow, jax')
print(bench(slow, value))
print('fast, jax')
print(bench(fast_jax, value))

value = tf.convert_to_tensor(5)
print('slow, tensorflow')
print(bench(slow, value))
print('fast, tensorflow')
print(bench(fast_tflow, value))

which outputs

$ python cache.py
slow
5.012567336
fast
1.0029977690000003
slow, jax
5.043927394000001
fast, jax
1.0195144690000006
slow, tensorflow
5.017408181999997
fast, tensorflow
1.0631543910000012

So we can definitely cache those values for the toms748 scan by JIT-ing, and that's probably what we want to do. My suggestion would be that we support pyhf.tensor.jit with a signature similar to jax.jit across all backends (yes, even numpy, where it would fall back to an lru_cache).
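
As a strawman for what the shim layer could look like (module layout and names here are hypothetical, just to make the idea concrete):

from functools import lru_cache

# hypothetical per-backend jit shims; a pyhf.tensor.jit entry point
# would dispatch to whichever matches the active tensorlib
def _jit_numpy(func):
    # numpy has no compiler, so memoization stands in for compilation caching
    return lru_cache(maxsize=None)(func)

def _jit_jax(func):
    import jax
    return jax.jit(func)

def _jit_tensorflow(func):
    import tensorflow as tf
    return tf.function(func, jit_compile=True)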

matthewfeickert commented 2 months ago

we need to add in shims across each lib and move some functions into our tensorlib instead to make them backend-dependent

Okay, sounds good. Let's start a separate series of PRs to do this.