scikit-hep / pyhf

pure-Python HistFactory implementation with tensors and autodiff
https://pyhf.readthedocs.io/
Apache License 2.0

example script to run pyhf on JAX w/ GPU #792

Open lukasheinrich opened 4 years ago

lukasheinrich commented 4 years ago

Tested on Colab:

!pip install 'pyhf[jax]' > /dev/null
import json
import time

import jax
import pyhf

pyhf.set_backend(pyhf.tensor.jax_backend())
with open("./pyhf_likelihood.json") as spec_file:
    w = pyhf.Workspace(json.load(spec_file))
m = w.model()
d = w.data(m)
# JIT-compile the model methods that the fit calls repeatedly
m.logpdf = jax.jit(m.logpdf)
m.expected_data = jax.jit(m.expected_data)

# first call triggers JIT compilation
pyhf.infer.hypotest(1.0, d, m)
print(jax.devices())  # check that a GPU device is visible

# time a second call, which reuses the compiled functions
start = time.time()
result = pyhf.infer.hypotest(1.0, d, m)
delta = time.time() - start
print(result)
print(delta)
kratsg commented 4 years ago

Note: on Ubuntu machines (and similar), follow https://github.com/google/jax#installation, which gives an install command that automatically detects the CUDA and Python versions:

pip install --upgrade https://storage.googleapis.com/jax-releases/`nvidia-smi | sed -En "s/.* CUDA Version: ([0-9]*)\.([0-9]*).*/cuda\1\2/p"`/jaxlib-0.1.52-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl jax
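To make the one-liner less opaque: the first sed expression turns the "CUDA Version: X.Y" field of nvidia-smi into a tag such as cuda110, the second turns the "Python 3.8.5" printed by python3 -V into cp38, and both are spliced into the wheel URL. The Python sketch below mirrors that string handling. The function names and the sample nvidia-smi line are illustrative assumptions, and jaxlib 0.1.52 is simply the version pinned in the command.

```python
import re

# Base URL and platform suffix as they appear in the pinned install command.
BASE_URL = "https://storage.googleapis.com/jax-releases"
PLATFORM_SUFFIX = "none-manylinux2010_x86_64.whl"


def cuda_tag(nvidia_smi_output: str) -> str:
    """'... CUDA Version: 11.0 ...' -> 'cuda110' (same as the first sed expression)."""
    match = re.search(r" CUDA Version: ([0-9]*)\.([0-9]*)", nvidia_smi_output)
    if match is None:
        raise ValueError("no 'CUDA Version: X.Y' found in nvidia-smi output")
    major, minor = match.groups()
    return f"cuda{major}{minor}"


def python_tag(python_version_output: str) -> str:
    """'Python 3.8.5' -> 'cp38' (same as the second sed expression)."""
    match = re.search(r"Python ([0-9]*)\.([0-9]*)", python_version_output)
    if match is None:
        raise ValueError("no 'Python X.Y' found in python3 -V output")
    major, minor = match.groups()
    return f"cp{major}{minor}"


def jaxlib_wheel_url(
    nvidia_smi_output: str, python_version_output: str, jaxlib_version: str = "0.1.52"
) -> str:
    """Assemble the wheel URL that the shell one-liner passes to pip."""
    cuda = cuda_tag(nvidia_smi_output)
    cp = python_tag(python_version_output)
    return f"{BASE_URL}/{cuda}/jaxlib-{jaxlib_version}-{cp}-{PLATFORM_SUFFIX}"


# Illustrative inputs, not captured from a real machine.
smi_line = "| NVIDIA-SMI 450.51.06    Driver Version: 450.51.06    CUDA Version: 11.0     |"
print(jaxlib_wheel_url(smi_line, "Python 3.8.5"))
# https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.52-cp38-none-manylinux2010_x86_64.whl
```

One design difference from the shell version: when sed finds no match it prints nothing and pip receives a malformed URL, whereas this sketch raises an error instead.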
kratsg commented 4 years ago

The following works for me:

# Run with: XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda python run.py
# (XLA_FLAGS tells XLA where to find the CUDA installation)
import time

import jax
import pyhf

pyhf.set_backend(pyhf.tensor.jax_backend())
model = pyhf.simplemodels.hepdata_like([50.], [100.], [20.])
# observed count followed by the auxiliary data of the constraint terms
data = [150.] + model.config.auxdata
# JIT-compile the model methods that the fit calls repeatedly
model.logpdf = jax.jit(model.logpdf)
model.expected_data = jax.jit(model.expected_data)
# warm-up call triggers JIT compilation
pyhf.infer.hypotest(1.0, data, model)
print(jax.devices())

start = time.time()
result = pyhf.infer.hypotest(1.0, data, model)
delta = time.time() - start
print(result)
print(delta)

with the following output on pyhf==0.5.1:

WARNING:pyhf.infer.test_statistics:qmu test statistic used for fit configuration with POI bounded at zero.
Use the qmu_tilde test statistic (pyhf.infer.test_statistics.qmu_tilde) instead.
WARNING:pyhf.infer.test_statistics:qmu test statistic used for fit configuration with POI bounded at zero.
Use the qmu_tilde test statistic (pyhf.infer.test_statistics.qmu_tilde) instead.
[GpuDevice(id=0)]
WARNING:pyhf.infer.test_statistics:qmu test statistic used for fit configuration with POI bounded at zero.
Use the qmu_tilde test statistic (pyhf.infer.test_statistics.qmu_tilde) instead.
WARNING:pyhf.infer.test_statistics:qmu test statistic used for fit configuration with POI bounded at zero.
Use the qmu_tilde test statistic (pyhf.infer.test_statistics.qmu_tilde) instead.
[0.510833]
0.02041482925415039
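The data vector passed to hypotest is the observed main-channel counts followed by the auxiliary data of the constraint terms (data = [150.] + model.config.auxdata). Assuming pyhf's shapesys convention for a per-bin background uncertainty, where the auxiliary measurement is tau = (b / sigma_b)**2 with b the nominal background and sigma_b its uncertainty, the numbers can be checked by hand; if the assumption holds they should match model.config.auxdata. The helper shapesys_auxdata is hypothetical, not a pyhf function.

```python
from typing import List, Sequence


def shapesys_auxdata(bkg: Sequence[float], bkg_uncertainty: Sequence[float]) -> List[float]:
    """Per-bin auxiliary data tau = (b / sigma_b)**2 (assumed shapesys convention)."""
    if len(bkg) != len(bkg_uncertainty):
        raise ValueError("bkg and bkg_uncertainty must have the same length")
    return [(b / sigma) ** 2 for b, sigma in zip(bkg, bkg_uncertainty)]


# hepdata_like([50.], [100.], [20.]): one bin with b = 100, sigma_b = 20
data = [150.0] + shapesys_auxdata([100.0], [20.0])
print(data)  # [150.0, 25.0]

# uncorrelated_background(bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0])
aux = shapesys_auxdata([50.0, 52.0], [3.0, 7.0])
print([round(tau, 2) for tau in aux])  # [277.78, 55.18]
```

The same construction appears with two bins in data = [51, 48] + model.config.auxdata, where the auxiliary part would be the two tau values computed last.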
matthewfeickert commented 1 year ago

Updating this example for pyhf v0.7.0, with JAX CUDA installed via

python -m pip install \
    --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
    'jax[cuda]==0.3.25'

then:

# issue_297.py
import pyhf
import jax
import time

pyhf.set_backend("jax")
model = pyhf.simplemodels.uncorrelated_background(
    signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
)
data = [51, 48] + model.config.auxdata
test_mu = 1.0

print(f"JAX Devices: {jax.devices()}")

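# JIT-compile the model methods that hypotest calls repeatedly
model.logpdf = jax.jit(model.logpdf)
model.expected_data = jax.jit(model.expected_data)
# the first call includes JIT compilation time
print("First run of hypotest")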
start = time.time()
pyhf.infer.hypotest(test_mu, data, model)
delta = time.time() - start
print(f"{delta=}\n")

for idx in range(2):
    print(f"{idx+1} run with JIT")
    start = time.time()
    result = pyhf.infer.hypotest(test_mu, data, model)
    delta = time.time() - start

    print(f"{result=}")
    print(f"{delta=}\n")

gives

$ python issue_297.py
JAX Devices: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
First run of hypotest
delta=1.7161650657653809

1 run with JIT
result=DeviceArray(0.05251497, dtype=float64)
delta=0.04859566688537598

2 run with JIT
result=DeviceArray(0.05251497, dtype=float64)
delta=0.057520389556884766

and, remembering https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html and https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices, restricting JAX to the CPU with JAX_PLATFORMS="cpu" gives

$ JAX_PLATFORMS="cpu" python issue_297.py
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
JAX Devices: [CpuDevice(id=0)]
First run of hypotest
delta=0.6355447769165039

1 run with JIT
result=DeviceArray(0.05251497, dtype=float64)
delta=0.01805400848388672

2 run with JIT
result=DeviceArray(0.05251497, dtype=float64)
delta=0.019098520278930664

This, together with recent discussions we had about comparing the results of https://arxiv.org/abs/2301.05676 on GPU, reminds me that we need an easy way to choose from the command line whether to run on CPU or GPU. We should also revisit Issue #1248 and think about how to make running on the GPU more performant: for this small model, the JIT-compiled CPU runs above (about 0.02 s) are faster than the GPU runs (about 0.05 s).
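As a possible starting point for choosing the device from the command line, a script could accept a flag that sets the JAX_PLATFORMS environment variable (the same variable used for the CPU run above) before JAX is imported. This is a sketch, not existing pyhf CLI behaviour: the --device flag and the select_platform and main names are hypothetical.

```python
import argparse
import os
from typing import List, Optional


def select_platform(argv: Optional[List[str]] = None) -> str:
    """Parse a hypothetical --device flag and set JAX_PLATFORMS accordingly.

    'cpu' forces JAX onto the CPU, like running with JAX_PLATFORMS="cpu";
    'auto' leaves the environment untouched, so JAX uses its default device
    (the GPU if one is available) or whatever the shell already set.
    """
    parser = argparse.ArgumentParser(description="Run a pyhf hypotest on a chosen device")
    parser.add_argument("--device", choices=("auto", "cpu"), default="auto")
    args = parser.parse_args(argv)
    if args.device == "cpu":
        os.environ["JAX_PLATFORMS"] = "cpu"
    return args.device


def main(argv: Optional[List[str]] = None) -> None:
    device = select_platform(argv)
    # Import JAX only after JAX_PLATFORMS is set, so that the choice takes effect.
    import jax
    import pyhf

    pyhf.set_backend("jax")
    print(f"--device={device}: JAX Devices: {jax.devices()}")
```

A script would call main() under an if __name__ == "__main__": guard and be run as python issue_297.py --device cpu. Only cpu and auto are offered here because the platform names accepted for GPUs have varied between JAX versions.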