matthewfeickert / distributed-inference-with-pyhf-and-funcX

Example code for vCHEP 2021 paper "Distributed statistical inference with pyhf enabled through funcX"
MIT License

Run pyhf on an Expanse GPU Node #5

BenGalewsky opened 3 years ago

BenGalewsky commented 3 years ago

Thanks to an XSEDE start-up grant, we have an allocation on SDSC's Expanse GPU supercomputer.

This issue will record notes on getting the service to run there.

BenGalewsky commented 3 years ago

Launching an Interactive Node

 srun --partition=gpu-debug --pty --account=nsa106 --ntasks-per-node=10 --nodes=1 --mem=96G --gpus=1 -t 00:30:00 --wait=0 --export=ALL /bin/bash

Setting up Environment

module load python/3.8.5
module load py-pip

Install Jax

python -m pip install --upgrade jax jaxlib==0.1.65+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

Configure CUDA Environment

export XLA_FLAGS="--xla_gpu_cuda_data_dir=/cm/shared/apps/spack/gpu/opt/spack/linux-centos8-skylake_avx512/gcc-8.3.1/nvhpc-20.9-tpwyy4iik6rsikls5ikkvzrttcnc7ytd/Linux_x86_64/20.9/cuda/11.0"
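
As a quick sanity check (a minimal sketch, not part of the original notes) that the CUDA build of jaxlib actually sees the GPU after the export above, something like this can be run in the same shell:

import jax
import jax.numpy as jnp

# Listing devices should show a GPU entry rather than a CPU fallback.
print(jax.devices())

# Force an XLA compile and execution on the device.
x = jnp.ones((1000, 1000))
print(float(jnp.dot(x, x).sum()))
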
matthewfeickert commented 3 years ago

I'm having some trouble getting the modules set up correctly for Python. The interactive nodes make it seem like the module environment should carry over correctly, given the --export=ALL in

srun --partition=gpu-debug --pty --account=nsa106 --ntasks-per-node=10 --nodes=1 --mem=20G --gpus=1 -t 00:30:00 --wait=0 --export=ALL /bin/bash

but if I try to do

[feickert@login01 ~]$ module load gpu

Inactive Modules:
  1) gcc

[feickert@login01 ~]$ module load cuda
[feickert@login01 ~]$ module load python/3.8.5
Lmod has detected the following error:  These module(s) or extension(s) exist but cannot be loaded as requested: "python/3.8.5"
   Try: "module spider python/3.8.5" to see how to load the module(s).

[feickert@login01 ~]$ module spider python/3.8.5

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
  python: python/3.8.5
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    You will need to load all module(s) on any one of the lines below before the "python/3.8.5" module is available to load.

      cpu/0.15.4  gcc/10.2.0
      cpu/0.15.4  intel/19.1.1.217

    Help:
      The Python programming language.

so it seems like a compiler that can work with the GPU environment is needed. :/

matthewfeickert commented 3 years ago

These are basically my logs of continued attempts at getting an environment with CUDA and Python 3.8 working (still failing at this point).

$ srun --partition=gpu-debug --pty --account=nsa106 --ntasks-per-node=10 --nodes=1 --mem=20G --gpus=1 -t 00:30:00 --wait=0 --export=ALL /bin/bash
srun: job 2463381 queued and waiting for resources
srun: job 2463381 has been allocated resources
$ module purge
$ module restore
Resetting modules to system default. Reseting $MODULEPATH back to system default. All extra directories will be removed from $MODULEPATH.
$ module list

Currently Loaded Modules:
  1) shared   2) slurm/expanse/20.02.3   3) gpu/0.15.4   4) DefaultModules

$ module load cuda
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Thu_Jun_11_22:26:38_PDT_2020
Cuda compilation tools, release 11.0, V11.0.194
Build cuda_11.0_bu.TC445_37.28540450_0
$ nvidia-smi --list-gpus
GPU 0: Tesla V100-SXM2-32GB (UUID: GPU-e592d2b3-0e03-99f1-8092-2ff65e7f8660)

At this point CUDA is available and working, but Python now is not

$ module load python/3.8.5
Lmod has detected the following error:  These module(s) or extension(s) exist but cannot be loaded as requested: "python/3.8.5"
   Try: "module spider python/3.8.5" to see how to load the module(s).

$ module spider python/3.8.5

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
  python: python/3.8.5
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

    You will need to load all module(s) on any one of the lines below before the "python/3.8.5" module is available to load.

      cpu/0.15.4  gcc/10.2.0
      cpu/0.15.4  intel/19.1.1.217

    Help:
      The Python programming language.

However, loading CPU modules deactivates CUDA (makes sense)

$ module load cpu/0.15.4

Inactive Modules:
  1) cuda

$ module load gcc/10.2.0
$ module load python/3.8.5
$ module load py-pip
$ python -m pip install --upgrade pip setuptools wheel
$ python -m pip install virtualenv
$ virtualenv pyhf-funcx-tests
$ . pyhf-funcx-tests/bin/activate
(pyhf-funcx-tests) $ python -m pip install --upgrade pip setuptools wheel
(pyhf-funcx-tests) $ python -m pip install pyhf[jax]
(pyhf-funcx-tests) $ python -m pip install --upgrade jaxlib==0.1.66+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
(pyhf-funcx-tests) $ python -m pip list
Package     Version
----------- --------------
absl-py     0.12.0
attrs       21.2.0
click       7.1.2
flatbuffers 2.0
jax         0.2.13
jaxlib      0.1.66+cuda110
jsonpatch   1.32
jsonpointer 2.1
jsonschema  3.2.0
numpy       1.20.3
opt-einsum  3.3.0
pip         20.2
pyhf        0.6.1
pyrsistent  0.17.3
PyYAML      5.4.1
scipy       1.6.3
setuptools  56.2.0
six         1.16.0
tqdm        4.60.0
wheel       0.36.2

As expected, JAX is unable to find a working CUDA installation in this environment, and searching for a viable path to use for XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda hasn't turned up anything yet.

(pyhf-funcx-tests) $ curl -sL https://raw.githubusercontent.com/matthewfeickert/nvidia-gpu-ml-library-test/main/jax_detect_GPU.py -o jax_detect_GPU.py
$ python jax_detect_GPU.py
2021-05-12 00:44:14.778668: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen2/gcc-10.2.0/py-pip-20.2-isdbspkz2tvx3sddoinfki6fl6cy4cdc/lib:/cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen2/gcc-10.2.0/python-3.8.5-naaw62bifnds2nzcxoazwqbb5bok4u4r/lib:/cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen/gcc-8.3.1/gcc-10.2.0-n7su7jf54rc7l2ozegds5xksy6qhrjin/lib64:/cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen/gcc-8.3.1/gcc-10.2.0-n7su7jf54rc7l2ozegds5xksy6qhrjin/lib:/cm/shared/apps/slurm/current/lib64/slurm:/cm/shared/apps/slurm/current/lib64
2021-05-12 00:44:15.052180: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen2/gcc-10.2.0/py-pip-20.2-isdbspkz2tvx3sddoinfki6fl6cy4cdc/lib:/cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen2/gcc-10.2.0/python-3.8.5-naaw62bifnds2nzcxoazwqbb5bok4u4r/lib:/cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen/gcc-8.3.1/gcc-10.2.0-n7su7jf54rc7l2ozegds5xksy6qhrjin/lib64:/cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen/gcc-8.3.1/gcc-10.2.0-n7su7jf54rc7l2ozegds5xksy6qhrjin/lib:/cm/shared/apps/slurm/current/lib64/slurm:/cm/shared/apps/slurm/current/lib64
2021-05-12 00:44:15.059053: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen2/gcc-10.2.0/py-pip-20.2-isdbspkz2tvx3sddoinfki6fl6cy4cdc/lib:/cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen2/gcc-10.2.0/python-3.8.5-naaw62bifnds2nzcxoazwqbb5bok4u4r/lib:/cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen/gcc-8.3.1/gcc-10.2.0-n7su7jf54rc7l2ozegds5xksy6qhrjin/lib64:/cm/shared/apps/spack/cpu/opt/spack/linux-centos8-zen/gcc-8.3.1/gcc-10.2.0-n7su7jf54rc7l2ozegds5xksy6qhrjin/lib:/cm/shared/apps/slurm/current/lib64/slurm:/cm/shared/apps/slurm/current/lib64
XLA backend type: gpu

Number of GPUs found on system: 1

Active GPU index: 0
Active GPU name: Tesla V100-SXM2-32GB
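
The "Could not load dynamic library 'libcudart.so.11.0'" warnings above can also be checked independently of JAX. A minimal sketch (using the soname from the warning) that asks the dynamic loader, which does honour LD_LIBRARY_PATH, whether it can resolve the CUDA runtime:

import ctypes

# The exact soname comes from the dso_loader warnings above.
try:
    ctypes.CDLL("libcudart.so.11.0")
    print("libcudart.so.11.0: resolvable by the dynamic loader")
except OSError as err:
    print(f"libcudart.so.11.0: not resolvable ({err})")
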
BenGalewsky commented 3 years ago

It does look like you can have CUDA and Python 3.6

$ module load gpu
$ python --version
Python 3.6.8
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Thu_Jun_11_22:26:38_PDT_2020
Cuda compilation tools, release 11.0, V11.0.194
Build cuda_11.0_bu.TC445_37.28540450_0

I know that pyhf abandoned python 3.6 recently... does this do us any good?

matthewfeickert commented 3 years ago

I know that pyhf abandoned python 3.6 recently... does this do us any good?

Sadly, not really. We should be using pyhf v0.6.1 for these tests. :/

matthewfeickert commented 3 years ago

@BenGalewsky Thanks to support from Mahidhar on my XSEDE ticket ("150421 : XUP: Possible to have CUDA and Python 3.7+ on EXPANSE GPU nodes?"), who advised:

We also have an anaconda installed version of python which gives you version 3.8.5. The following should work:

module reset
module load gpu
module load anaconda3

and indeed, the following works:


$ srun --partition=gpu-debug --pty --account=nsa106 --ntasks-per-node=10 --nodes=1 --mem=20G --gpus=1 -t 00:30:00 --wait=0 --export=ALL /bin/bash
$ module purge
$ module restore
$ module load cuda
$ module load anaconda3
$ python --version --version
Python 3.8.5 (default, Sep  4 2020, 07:30:14)
[GCC 7.3.0]
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Thu_Jun_11_22:26:38_PDT_2020
Cuda compilation tools, release 11.0, V11.0.194
Build cuda_11.0_bu.TC445_37.28540450_0
$ conda create -n pyhf-funcx
Collecting package metadata (current_repodata.json): done
Solving environment: done

==> WARNING: A newer version of conda exists. <==
  current version: 4.9.2
  latest version: 4.10.1

Please update conda by running

    $ conda update -n base -c defaults conda

## Package Plan ##

  environment location: /home/feickert/.conda/envs/pyhf-funcx

Proceed ([y]/n)? y

Preparing transaction: done
Verifying transaction: done
Executing transaction: done
#
# To activate this environment, use
#
#     $ conda activate pyhf-funcx
#
# To deactivate an active environment, use
#
#     $ conda deactivate

$ conda activate pyhf-funcx

CommandNotFoundError: Your shell has not been properly configured to use 'conda activate'.
To initialize your shell, run

    $ conda init <SHELL_NAME>

Currently supported shells are:
  - bash
  - fish
  - tcsh
  - xonsh
  - zsh
  - powershell

See 'conda init --help' for more information and options.

IMPORTANT: You may need to close and restart your shell after running 'conda init'.
$ conda init bash

We trust you have received the usual lecture from the local System
Administrator. It usually boils down to these three things:

    #1) Respect the privacy of others.
    #2) Think before you type.
    #3) With great power comes great responsibility.

PIN+Yubi:

Here we run into a problem, as conda init bash needs root access. Instead, if you just exit and log back in, things work:

$ srun --partition=gpu-debug --pty --account=nsa106 --ntasks-per-node=10 --nodes=1 --mem=20G --gpus=1 -t 00:30:00 --wait=0 --export=ALL /bin/bash
$ module purge
$ module restore
$ module load cuda
$ module load anaconda3
$ conda activate pyhf-funcx
(pyhf-funcx) $ python -m pip install --upgrade pip setuptools wheel
(pyhf-funcx) $ python -m pip install pyhf[jax]
(pyhf-funcx) $ python -m pip install --upgrade jaxlib==0.1.66+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
curl -sL https://raw.githubusercontent.com/matthewfeickert/nvidia-gpu-ml-library-test/main/jax_detect_GPU.py -o jax_detect_GPU.py
(pyhf-funcx) $ python jax_detect_GPU.py
XLA backend type: gpu

Number of GPUs found on system: 1

Active GPU index: 0
Active GPU name: Tesla V100-SXM2-32GB
Also actual training works too. :+1:

(pyhf-funcx) $ curl -sL https://raw.githubusercontent.com/matthewfeickert/nvidia-gpu-ml-library-test/main/jax_example_datasets.py -o jax_example_datasets.py
(pyhf-funcx) $ curl -sL https://raw.githubusercontent.com/matthewfeickert/nvidia-gpu-ml-library-test/main/jax_MNIST.py -o jax_MNIST.py
(pyhf-funcx) $ python jax_MNIST.py
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/

Starting training...
Epoch 0 in 1.78 sec
Training set accuracy 0.8719333410263062
Test set accuracy 0.8805000185966492
Epoch 1 in 0.24 sec
Training set accuracy 0.8979166746139526
Test set accuracy 0.9032000303268433
Epoch 2 in 0.23 sec
Training set accuracy 0.9092333316802979
Test set accuracy 0.9143000245094299
Epoch 3 in 0.23 sec
Training set accuracy 0.9170833230018616
Test set accuracy 0.9221000671386719
Epoch 4 in 0.23 sec
Training set accuracy 0.9226666688919067
Test set accuracy 0.9279000163078308
Epoch 5 in 0.23 sec
Training set accuracy 0.9272000193595886
Test set accuracy 0.929900050163269
Epoch 6 in 0.23 sec
Training set accuracy 0.9323333501815796
Test set accuracy 0.932900071144104
Epoch 7 in 0.23 sec
Training set accuracy 0.9357166886329651
Test set accuracy 0.9364000558853149
Epoch 8 in 0.23 sec
Training set accuracy 0.9387833476066589
Test set accuracy 0.9394000172615051
Epoch 9 in 0.23 sec
Training set accuracy 0.9425833225250244
Test set accuracy 0.9419000744819641

yay! :rocket:

I'll create a Conda environment.yml file to make this step reproducible.

BenGalewsky commented 3 years ago

Creating an endpoint...

$ module load py-pip
$ python -m pip install funcx-endpoint

I then created an endpoint with

$ funcx-endpoint configure pyhf

And then updated the ~/.funcx/pyhf/config.py to

from funcx_endpoint.endpoint.utils.config import Config
from funcx_endpoint.executors import HighThroughputExecutor
from parsl.providers import SlurmProvider
from parsl.launchers import SrunLauncher
from parsl.addresses import address_by_hostname

# PLEASE UPDATE user_opts BEFORE USE
user_opts = {
    'expanse': {
        'worker_init': 'source ~/setup_funcx_test_env.sh',
        'scheduler_options': '#SBATCH --gpus=1',
    }
}

config = Config(
    executors=[
        HighThroughputExecutor(
            label='Expanse_GPU',
            address=address_by_hostname(),
            provider=SlurmProvider(
                'gpu-debug',  # Partition / QOS
                account='nsa106',
                nodes_per_block=1,
                init_blocks=1,
                mem_per_node=20,
                # string to prepend to #SBATCH blocks in the submit
                # script to the scheduler eg: '#SBATCH --constraint=knl,quad,cache'
                scheduler_options=user_opts['expanse']['scheduler_options'],

                # Command to be run before starting a worker, such as:
                # 'module load Anaconda; source activate parsl_env'.
                worker_init=user_opts['expanse']['worker_init'],

                launcher=SrunLauncher(),
                walltime='00:10:00',
                # increase the command timeouts
                cmd_timeout=120,
            ),
        ),
    ],
)

and created a ~/setup_funcx_test_env.sh as

echo "Setting up FuncX Endpoint for pyhf"
module purge
module restore
module load cuda
module load anaconda3
conda activate pyhf-funcx

Finally, started the endpoint with

funcx-endpoint start pyhf

and made note of the endpoint ID
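
For reference, a hedged sketch of how that endpoint ID is then used from the client side with the funcX SDK (the UUID is a placeholder and exact call signatures may differ between funcx versions):

import time
from funcx.sdk.client import FuncXClient

def where_am_i():
    # Runs on the Expanse worker, not on the submitting client.
    import platform
    return platform.node()

fxc = FuncXClient()
endpoint_id = "REPLACE-WITH-ENDPOINT-UUID"  # the ID noted above
function_id = fxc.register_function(where_am_i)
task_id = fxc.run(endpoint_id=endpoint_id, function_id=function_id)

while True:
    try:
        print(fxc.get_result(task_id))  # raises while the task is still pending
        break
    except Exception:
        time.sleep(5)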

matthewfeickert commented 3 years ago

For a simple test of just running

$ python fit_analysis.py -c config/1Lbb.json

with the following diff from main at the moment

diff --git a/fit_analysis.py b/fit_analysis.py
index 4228655..6e01f31 100644
--- a/fit_analysis.py
+++ b/fit_analysis.py
@@ -11,6 +11,8 @@ from pyhf.contrib.utils import download
 def prepare_workspace(data):
     import pyhf

+    pyhf.set_backend("jax")
+
     return pyhf.Workspace(data)

@@ -19,6 +21,8 @@ def infer_hypotest(workspace, metadata, patches):

     import pyhf

+    pyhf.set_backend("jax")
+
     tick = time.time()
     model = workspace.model(
         patches=patches,
@@ -98,6 +102,7 @@ def main(args):

     # execute patch fits across workers and retrieve them when done
     n_patches = len(patchset.patches)
+    n_patches = 5
     tasks = {}
     for patch_idx in range(n_patches):
         patch = patchset.patches[patch_idx]

the endpoint seems to be having problems finding the right CUDA libs

$ python fit_analysis.py -c config/1Lbb.json
prepare: waiting-for-ep
prepare: waiting-for-ep
prepare: waiting-for-ep
prepare: waiting-for-ep
--------------------
<pyhf.workspace.Workspace object at 0x1555407e5360>
inference: waiting-for-ep
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x1555370e8d90>
inference: Internal: libdevice not found at ./libdevice.10.bc
Task C1N2_Wh_hbb_1000_150 complete, there are 1 results now
Task C1N2_Wh_hbb_1000_200 complete, there are 2 results now
Task C1N2_Wh_hbb_1000_250 complete, there are 3 results now
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x155539540ee0>
inference: Internal: libdevice not found at ./libdevice.10.bc
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x1555370e8d90>
inference: Internal: libdevice not found at ./libdevice.10.bc
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x155539540ee0>
inference: Internal: libdevice not found at ./libdevice.10.bc
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x1555370e8d90>
inference: Internal: libdevice not found at ./libdevice.10.bc
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x155539540ee0>
inference: Internal: libdevice not found at ./libdevice.10.bc
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x1555370e8d90>
inference: Internal: libdevice not found at ./libdevice.10.bc
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x155539540ee0>
inference: Internal: libdevice not found at ./libdevice.10.bc

Note to self: Recheck in morning if this makes sense.
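
For the record, "libdevice not found at ./libdevice.10.bc" usually means XLA cannot locate the CUDA toolkit directory on the worker. A possible workaround to try (a sketch, not verified in this thread) is to point XLA at the toolkit before JAX initializes in the worker environment, using the CUDA path from the earlier notes:

import os

# Must be set before JAX initializes its GPU backend; the path below is the
# CUDA toolkit location from the earlier notes and may differ per node image.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_cuda_data_dir=/cm/shared/apps/spack/gpu/opt/spack/"
    "linux-centos8-skylake_avx512/gcc-8.3.1/"
    "nvhpc-20.9-tpwyy4iik6rsikls5ikkvzrttcnc7ytd/Linux_x86_64/20.9/cuda/11.0"
)

import jax
print(jax.devices())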

BenGalewsky commented 3 years ago

Do I need to set an env var for CUDA?

matthewfeickert commented 3 years ago

Do I need to set an env var for CUDA?

I don't think so, as I'm able to do the following

$ ssh EXPANSE 
Welcome to Bright release         9.0

                                                        Based on CentOS Linux 8
                                                                    ID: #000002

--------------------------------------------------------------------------------

                                 WELCOME TO
                  _______  __ ____  ___    _   _______ ______
                 / ____/ |/ // __ \/   |  / | / / ___// ____/
                / __/  |   // /_/ / /| | /  |/ /\__ \/ __/
               / /___ /   |/ ____/ ___ |/ /|  /___/ / /___
              /_____//_/|_/_/   /_/  |_/_/ |_//____/_____/

--------------------------------------------------------------------------------

Use the following commands to adjust your environment:

'module avail'            - show available modules
'module add <module>'     - adds a module to your environment for this session
'module initadd <module>' - configure module to be loaded at every login

-------------------------------------------------------------------------------
Last login: Thu May 13 16:10:03 2021 from 23.249.32.244
$ git clone https://github.com/matthewfeickert/nvidia-gpu-ml-library-test.git
$ cd nvidia-gpu-ml-library-test/
$ srun --partition=gpu-debug --pty --account=nsa106 --ntasks-per-node=10 --nodes=1 --mem=20G --gpus=1 -t 00:30:00 --wait=0 --export=ALL /bin/bash
$ cd ~/workarea/distributed-inference-with-pyhf-and-funcX/
$ . setup_expanse_funcx_test_env.sh 
Setting up FuncX Endpoint for pyhf
Resetting modules to system default. Reseting $MODULEPATH back to system default. All extra directories will be removed from $MODULEPATH.

Currently Loaded Modules:
  1) shared   2) slurm/expanse/20.02.3   3) gpu/0.15.4   4) DefaultModules   5) cuda/11.0.2   6) anaconda3/2020.11

(distributed-inference) $ cd -
/home/feickert/nvidia-gpu-ml-library-test
(distributed-inference) $ python jax_detect_GPU.py 
XLA backend type: gpu

Number of GPUs found on system: 1

Active GPU index: 0
Active GPU name: Tesla V100-SXM2-32GB
(distributed-inference) $ python jax_MNIST.py 
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/

Starting training...
Epoch 0 in 1.78 sec
Training set accuracy 0.8719333410263062
Test set accuracy 0.8805000185966492
Epoch 1 in 0.24 sec
Training set accuracy 0.8979166746139526
Test set accuracy 0.9032000303268433
Epoch 2 in 0.23 sec
Training set accuracy 0.9092333316802979
Test set accuracy 0.9143000245094299
Epoch 3 in 0.23 sec
Training set accuracy 0.9170833230018616
Test set accuracy 0.9221000671386719
Epoch 4 in 0.24 sec
Training set accuracy 0.9226500391960144
Test set accuracy 0.9280000329017639
Epoch 5 in 0.23 sec
Training set accuracy 0.927216649055481
Test set accuracy 0.929900050163269
Epoch 6 in 0.23 sec
Training set accuracy 0.9323166608810425
Test set accuracy 0.932900071144104
Epoch 7 in 0.23 sec
Training set accuracy 0.9357333183288574
Test set accuracy 0.936500072479248
Epoch 8 in 0.23 sec
Training set accuracy 0.9387833476066589
Test set accuracy 0.9394000172615051
Epoch 9 in 0.23 sec
Training set accuracy 0.9425833225250244
Test set accuracy 0.9419000744819641

and as no warnings are printed, it is finding CUDA thanks (I think) to

https://github.com/matthewfeickert/distributed-inference-with-pyhf-and-funcX/blob/59516f205373a24ee9c7757833f781cf9c0f2afc/setup_expanse_funcx_test_env.sh#L4
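
A more direct way to confirm which CUDA runtime actually gets loaded (a sketch, not something run above) is to force JAX to initialize and then inspect the process's memory maps:

import jax.numpy as jnp

# Trigger backend initialization and CUDA library loading.
jnp.ones(3).block_until_ready()

# On Linux, /proc/self/maps lists every shared library the process has mapped in.
with open("/proc/self/maps") as maps:
    cudart_paths = sorted({line.split()[-1] for line in maps if "libcudart" in line})
print("\n".join(cudart_paths))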

BenGalewsky commented 3 years ago

This now works! I just needed to upgrade to the provided expanse-environment.yaml and all is well

python fit_analysis.py -c config/1Lbb.json 6.09s user 1.11s system 5% cpu 2:02.14 total

matthewfeickert commented 3 years ago

That sounds good @BenGalewsky, but it seems like things still aren't happy.

(base) feickert@ThinkPad-X1:~$ ssh EXPANSE 
Welcome to Bright release         9.0

                                                        Based on CentOS Linux 8
                                                                    ID: #000002

--------------------------------------------------------------------------------

                                 WELCOME TO
                  _______  __ ____  ___    _   _______ ______
                 / ____/ |/ // __ \/   |  / | / / ___// ____/
                / __/  |   // /_/ / /| | /  |/ /\__ \/ __/
               / /___ /   |/ ____/ ___ |/ /|  /___/ / /___
              /_____//_/|_/_/   /_/  |_/_/ |_//____/_____/

--------------------------------------------------------------------------------

Use the following commands to adjust your environment:

'module avail'            - show available modules
'module add <module>'     - adds a module to your environment for this session
'module initadd <module>' - configure module to be loaded at every login

-------------------------------------------------------------------------------
Last login: Thu May 13 16:30:12 2021 from 23.249.32.244
[feickert@login01 ~]$ cd workarea/distributed-inference-with-pyhf-and-funcX/
$ git branch
  main
* test/try-using-JAX-backend
[feickert@login01 distributed-inference-with-pyhf-and-funcX]$ srun --partition=gpu-debug --pty --account=nsa106 --ntasks-per-node=10 --nodes=1 --mem=20G --gpus=1 -t 00:30:00 --wait=0 --export=ALL /bin/bash
[feickert@exp-7-59 distributed-inference-with-pyhf-and-funcX]$ . setup_expanse_funcx_test_env.sh 
Setting up FuncX Endpoint for pyhf
Resetting modules to system default. Reseting $MODULEPATH back to system default. All extra directories will be removed from $MODULEPATH.

Currently Loaded Modules:
  1) shared   2) slurm/expanse/20.02.3   3) gpu/0.15.4   4) DefaultModules   5) cuda/11.0.2   6) anaconda3/2020.11

(distributed-inference) [feickert@exp-7-59 distributed-inference-with-pyhf-and-funcX]$ python -m pip list | grep jax
jax                 0.2.13
jaxlib              0.1.66+cuda110
(distributed-inference) [feickert@exp-7-59 distributed-inference-with-pyhf-and-funcX]$ python fit_analysis.py -c config/1Lbb.json
prepare: waiting-for-ep
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x155540be6eb0>
prepare: ('There was a problem importing JAX. The jax backend cannot be used.', ModuleNotFoundError("No module named 'jax'"))
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x155540be6eb0>
prepare: ('There was a problem importing JAX. The jax backend cannot be used.', ModuleNotFoundError("No module named 'jax'"))
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x155540be6eb0>
prepare: ('There was a problem importing JAX. The jax backend cannot be used.', ModuleNotFoundError("No module named 'jax'"))
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x155540be6eb0>
prepare: ('There was a problem importing JAX. The jax backend cannot be used.', ModuleNotFoundError("No module named 'jax'"))
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x155540be6eb0>
prepare: ('There was a problem importing JAX. The jax backend cannot be used.', ModuleNotFoundError("No module named 'jax'"))

It seems that JAX isn't getting picked up for some reason now. :?

The only difference between this test branch and main is the use of JAX

$ git diff origin/main 
diff --git a/fit_analysis.py b/fit_analysis.py
index 4228655..3e18a23 100644
--- a/fit_analysis.py
+++ b/fit_analysis.py
@@ -11,6 +11,8 @@ from pyhf.contrib.utils import download
 def prepare_workspace(data):
     import pyhf

+    pyhf.set_backend("jax")
+
     return pyhf.Workspace(data)

@@ -19,6 +21,8 @@ def infer_hypotest(workspace, metadata, patches):

     import pyhf

+    pyhf.set_backend("jax")
+
     tick = time.time()
     model = workspace.model(
         patches=patches,
BenGalewsky commented 3 years ago

Are you running the client on expanse? I'm running it locally on my laptop

BenGalewsky commented 3 years ago

Oh, let me switch to that branch

matthewfeickert commented 3 years ago

Are you running the client on expanse? I'm running it locally on my laptop

Yeah, I was running on Expanse, but I'm getting the same when I run on my laptop

(distributed-inference) $ hostname
ThinkPad-X1
(distributed-inference) $ pip list | grep 'jax\|funcx'
funcx                         0.2.2
funcx-endpoint                0.2.2
jax                           0.2.13
jaxlib                        0.1.66+cuda101
# Note this is 0.1.66+cuda101 and not 0.1.66+cuda110 as I need to adjust it for my laptop GPU
# My GPU isn't doing anything though obviously, so the following is just a "does this environment even work?" check
(distributed-inference) $ curl -sL https://raw.githubusercontent.com/matthewfeickert/nvidia-gpu-ml-library-test/main/jax_detect_GPU.py | python
XLA backend type: gpu

Number of GPUs found on system: 1

Active GPU index: 0
Active GPU name: GeForce GTX 1650 with Max-Q Design
(distributed-inference) $ python fit_analysis.py -c config/1Lbb.json
prepare: waiting-for-ep
WARNING:funcx.sdk.client:We have an exception : <parsl.app.errors.RemoteExceptionWrapper object at 0x7f5d55697fa0>
prepare: ('There was a problem importing JAX. The jax backend cannot be used.', ModuleNotFoundError("No module named 'jax'"))

Strange.
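
Since the jax import happens on the endpoint's workers rather than in the client environment, one way to debug this (a sketch, reusing the submit pattern from the earlier client example) is to ship a tiny function that reports what the worker actually sees:

def env_report():
    # Executed on the funcX worker: report its interpreter and whether jax is importable there.
    import sys
    import importlib.util
    return {
        "executable": sys.executable,
        "jax_found": importlib.util.find_spec("jax") is not None,
    }

# Register and run this with FuncXClient against the Expanse endpoint as in the
# earlier sketch; the result shows which Python environment the workers use.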

BenGalewsky commented 3 years ago

Think I really got it this time

python fit_analysis.py -c config/1Lbb.json  6.26s user 1.10s system 10% cpu 1:10.67 total

Sure enough jaxlib was in my env, but not jax. Should the line in requirements.txt be

pyhf[contrib,jax]~=0.6.1
matthewfeickert commented 3 years ago

Think I really got it this time

python fit_analysis.py -c config/1Lbb.json  6.26s user 1.10s system 10% cpu 1:10.67 total

@BenGalewsky Beautiful. :)

(distributed-inference) feickert@ThinkPad-X1:~/Code/GitHub/IRIS-HEP/distributed-inference-with-pyhf-and-funcX$ time python fit_analysis.py -c config/1Lbb.json
prepare: waiting-for-ep
--------------------
<pyhf.workspace.Workspace object at 0x7feaadb09040>
Task C1N2_Wh_hbb_1000_0 complete, there are 1 results now
Task C1N2_Wh_hbb_1000_100 complete, there are 2 results now
Task C1N2_Wh_hbb_1000_150 complete, there are 3 results now
...
real    1m54.501s
user    0m9.340s
sys 0m1.154s

Thanks very much for taking a look and fixing this.

Sure enough jaxlib was in my env, but not jax. Should the line in requirements.txt be

pyhf[contrib,jax]~=0.6.1

Ah, sorry, this was due to poor communication on my part. In PR #6, I split out the JAX dependencies into jax-requirements.txt to try to make it easier to specify the core dependencies in requirements.txt vs. the backend dependencies (e.g. jax-requirements.txt or a future torch-requirements.txt).

$ cat jax-requirements.txt 
jax==0.2.13
# Place --find-links _before_ jaxlib to satisfy Conda
--find-links https://storage.googleapis.com/jax-releases/jax_releases.html
jaxlib==0.1.66+cuda110

They both get picked up in the expanse-environment.yml

name: distributed-inference
channels:
  - defaults
  - conda-forge
dependencies:
  - python>=3.7,<3.9
  - pip
  - pip:
    - setuptools
    - wheel
    - -r file:requirements.txt
    - -r file:jax-requirements.txt

but I should have flagged you on this change. I should probably rename requirements.txt to something like core-requirements.txt to differentiate it, as requirements.txt carries with it the cultural assumption that it is all-inclusive. (edit: done now in PR #10)