lukasheinrich opened this issue 4 years ago
Note: on Ubuntu machines etc., follow https://github.com/google/jax#installation, where an install command with automatic CUDA and Python version detection is
pip install --upgrade https://storage.googleapis.com/jax-releases/`nvidia-smi | sed -En "s/.* CUDA Version: ([0-9]*)\.([0-9]*).*/cuda\1\2/p"`/jaxlib-0.1.52-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl jax
The following works for me:
# XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda python run.py
import pyhf
import jax
import time

pyhf.set_backend(pyhf.tensor.jax_backend())

model = pyhf.simplemodels.hepdata_like([50.], [100.], [20.])
data = [150.] + model.config.auxdata

# JIT-compile the model's hot paths
model.logpdf = jax.jit(model.logpdf)
model.expected_data = jax.jit(model.expected_data)

# first call triggers JIT compilation (warm-up)
pyhf.infer.hypotest(1.0, data, model)

print(jax.devices())

# time a second, already-compiled call
start = time.time()
result = pyhf.infer.hypotest(1.0, data, model)
print(result)
delta = time.time() - start
print(delta)
with the following output, on pyhf==0.5.1:
WARNING:pyhf.infer.test_statistics:qmu test statistic used for fit configuration with POI bounded at zero.
Use the qmu_tilde test statistic (pyhf.infer.test_statistics.qmu_tilde) instead.
WARNING:pyhf.infer.test_statistics:qmu test statistic used for fit configuration with POI bounded at zero.
Use the qmu_tilde test statistic (pyhf.infer.test_statistics.qmu_tilde) instead.
[GpuDevice(id=0)]
WARNING:pyhf.infer.test_statistics:qmu test statistic used for fit configuration with POI bounded at zero.
Use the qmu_tilde test statistic (pyhf.infer.test_statistics.qmu_tilde) instead.
WARNING:pyhf.infer.test_statistics:qmu test statistic used for fit configuration with POI bounded at zero.
Use the qmu_tilde test statistic (pyhf.infer.test_statistics.qmu_tilde) instead.
[0.510833]
0.02041482925415039
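As an aside on these timings: JAX dispatches work asynchronously, so wall-clock measurements are most meaningful when the result is explicitly synced. A minimal sketch of timing the jitted logpdf with block_until_ready, reusing model, data, and time from the example above (just an illustration, not part of the original benchmark):

import jax.numpy as jnp

pars = jnp.asarray(model.config.suggested_init())
data_arr = jnp.asarray(data)

# model.logpdf was jitted above; the first call triggers compilation
model.logpdf(pars, data_arr).block_until_ready()

start = time.time()
# block_until_ready waits for the device computation to actually finish
model.logpdf(pars, data_arr).block_until_ready()
print(time.time() - start)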
Updating this example for pyhf v0.7.0 and using JAX CUDA from

python -m pip install \
    --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
    'jax[cuda]==0.3.25'

then:
# issue_297.py
import pyhf
import jax
import time

pyhf.set_backend("jax")

model = pyhf.simplemodels.uncorrelated_background(
    signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
)
data = [51, 48] + model.config.auxdata
test_mu = 1.0

print(f"JAX Devices: {jax.devices()}")

model.logpdf = jax.jit(model.logpdf)
model.expected_data = jax.jit(model.expected_data)

print("First run of hypotest")
start = time.time()
pyhf.infer.hypotest(test_mu, data, model)
delta = time.time() - start
print(f"{delta=}\n")

for idx in range(2):
    print(f"{idx+1} run with JIT")
    start = time.time()
    result = pyhf.infer.hypotest(test_mu, data, model)
    delta = time.time() - start
    print(f"{result=}")
    print(f"{delta=}\n")
gives
$ python issue_297.py
JAX Devices: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
First run of hypotest
delta=1.7161650657653809
1 run with JIT
result=DeviceArray(0.05251497, dtype=float64)
delta=0.04859566688537598
2 run with JIT
result=DeviceArray(0.05251497, dtype=float64)
delta=0.057520389556884766
and, remembering https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html and https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices (see the placement sketch after the output below), forcing the CPU-only backend gives
$ JAX_PLATFORMS="cpu" python issue_297.py
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
JAX Devices: [CpuDevice(id=0)]
First run of hypotest
delta=0.6355447769165039
1 run with JIT
result=DeviceArray(0.05251497, dtype=float64)
delta=0.01805400848388672
2 run with JIT
result=DeviceArray(0.05251497, dtype=float64)
delta=0.019098520278930664
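For per-array control, rather than the process-wide JAX_PLATFORMS switch, the jax.device_put documentation linked above describes pinning a buffer to a chosen device, with computation then following the data. A minimal, pyhf-independent sketch:

import jax
import jax.numpy as jnp

print(jax.devices())          # all devices JAX can see
cpu0 = jax.devices("cpu")[0]  # first CPU device

x = jnp.ones((1000,))
x_cpu = jax.device_put(x, cpu0)  # place this array's buffer on the CPU
y = jnp.square(x_cpu)            # the computation runs where the data lives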
This, and recent discussions we had about comparing against the results of https://arxiv.org/abs/2301.05676 on GPU, reminds me that we need easy ways from the command line to indicate whether we want to run on the CPU or the GPU. We should also revisit Issue #1248 and think about how to make running on the GPU more performant.
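One possible shape for such a switch, purely as a sketch (the --device flag and script are hypothetical, not an existing pyhf option), is to set JAX_PLATFORMS from a command-line argument before JAX initialises its backends:

# sketch_device_flag.py -- hypothetical, not part of pyhf
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--device", choices=["cpu", "gpu"], default="cpu")
args = parser.parse_args()

# JAX_PLATFORMS must be set before jax initialises its backends
os.environ["JAX_PLATFORMS"] = args.device

import jax  # noqa: E402
import pyhf  # noqa: E402

pyhf.set_backend("jax")
print(f"JAX Devices: {jax.devices()}")

which would then be invoked as, e.g., python sketch_device_flag.py --device gpu.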
Tested on Colab.