nataliaporqueres opened this issue 1 year ago
Here is our action plan for the week:
Setting up a GPU node and the correct GPU Python kernel on NERSC, with the dependencies for this project (should work out of the box!):
# get on a GPU node first!
# https://docs.nersc.gov/development/languages/python/using-python-perlmutter/
module load cudatoolkit/11.7
module load cudnn/8.9.1_cuda11
module load python
# Verify the versions of cudatoolkit and cudnn are compatible with JAX
module list
# Create a new conda environment
conda create -n jglass python=3.9 pip numpy scipy
# Activate the environment before using pip to install JAX
conda activate jglass
# Install a compatible wheel
pip install --no-cache-dir "jax[cuda11_cudnn82]==0.4.7" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# check that JAX works on the GPU: start Python and run
python
import jax.numpy as jnp
import jax
# should be no error if everything worked
jnp.ones(10)
# should list the GPU device(s)
jax.devices()
exit()
# other dependencies
# glass
git clone https://github.com/glass-dev/glass.git
cd glass
pip install -e .
cd ..
# glass-jax
git clone https://github.com/LSSTDESC/glass-jax.git
cd glass-jax
pip install -e .
cd ..
# pmwd
git clone https://github.com/eelregit/pmwd.git
cd pmwd
pip install -e .
cd ..
# s2fft
git clone https://github.com/astro-informatics/s2fft.git
cd s2fft
pip install -e .
cd ..
# other stuff
conda install -c conda-forge healpy h5py astropy matplotlib
pip install jupyterlab
pip install jax-cosmo
pip install typing-extensions
# setting up the kernel for jupyterhub
# https://docs.nersc.gov/services/jupyter/how-to-guides/
# conda activate jglass
python -m ipykernel install \
--user --name jglass --display-name jglass
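Once the kernel is registered, you can confirm that a notebook is really running inside the jglass environment, and that the editable installs are visible, with a small stdlib-only check like the one below. This is a sketch and not part of the original plan. `report_environment` is a hypothetical helper, and the list of import names (for example `jax_cosmo` for the `jax-cosmo` package) is an assumption you may need to adjust. glass-jax is left out because its import name is not stated here.

```python
import importlib.util
import sys


def report_environment(import_names):
    """Return the running interpreter, its environment prefix, and which modules are importable.

    `import_names` are module names as used in `import ...` (e.g. "jax_cosmo"),
    not pip/conda package names (e.g. "jax-cosmo").
    """
    found = {}
    for name in import_names:
        # Hyphenated names can never be imported; find_spec locates a
        # top-level module without actually importing it.
        found[name] = name.isidentifier() and importlib.util.find_spec(name) is not None
    return {
        "python": sys.executable,  # interpreter behind this kernel
        "prefix": sys.prefix,  # expected to point into the jglass env
        "packages": found,
    }


if __name__ == "__main__":
    info = report_environment(
        ["jax", "glass", "s2fft", "pmwd", "jax_cosmo", "healpy", "h5py", "astropy", "matplotlib"]
    )
    print("python:", info["python"])
    print("prefix:", info["prefix"])
    for name, ok in info["packages"].items():
        print(f"{name:12s} {'ok' if ok else 'MISSING'}")
```

If `prefix` does not point into the jglass environment, the notebook is using a different kernel from the one registered with `ipykernel`.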
Hello, I've just tried this and got:
(nersc-python) login18:c/campagne/Work$ conda create -n jglass python=3.9 pip numpy scipy
Collecting package metadata (current_repodata.json): failed
# >>>>>>>>>>>>>>>>>>>>>> ERROR REPORT <<<<<<<<<<<<<<<<<<<<<<
Traceback (most recent call last):
File "/global/common/software/nersc/pe/conda/23.9.0/Miniconda3-py311_23.5.2-0/lib/python3.11/site-packages/conda/exception_handler.py", line 17, in __call__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
....
Could I have missed something?
Hi,
Could you try starting fresh with the following commands:
conda create -n jglass python pip numpy scipy
conda activate jglass
module load gpu
module load evp-patch
conda install -y "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
conda install -c conda-forge numpyro
Then install the "other dependencies" as above. This should also give you more up-to-date versions of JAX and NumPyro.
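After the conda-forge install, you may want to confirm which JAX version was picked up and whether it actually runs on the GPU rather than silently falling back to the CPU. Below is a minimal sketch. It assumes only that `jax` exposes `__version__`, `default_backend()` and `devices()`, and `describe_jax` is a hypothetical name.

```python
import importlib.util


def describe_jax():
    """Summarize the installed JAX, or return None if JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so this helper works without JAX

    return {
        "version": jax.__version__,
        # "gpu" when the CUDA-enabled jaxlib is active, "cpu" on fallback
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    info = describe_jax()
    if info is None:
        print("JAX is not installed in this environment")
    else:
        print(f"jax {info['version']} on {info['backend']}: {info['devices']}")
        if info["backend"] != "gpu":
            print("warning: no GPU backend; check that the CUDA build of jaxlib was installed")
```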
Hello @AZhou00
Thanks for the tip. In fact, I'm pretty sure the crash was caused by exceeding my disk quota :( When I tried the workflow above, I got the nasty "OSError: [Errno 122] Disk quota exceeded" buried in a flood of error reports. After cleaning up a little, I was able to start the installation :)
Now, since I am not familiar with running on NERSC, I surely did not install it in the right place. May I ask where you install such an environment, and where you run it (i.e., a Jupyter notebook at NERSC)? I know these questions are quite basic, so I hope you don't mind too much. Thanks.
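On Linux, the disk-quota error mentioned above (`[Errno 122]`) is `EDQUOT`. Conda environments and the conda package cache are a common cause. The sketch below uses two hypothetical helpers, `dir_size` and `largest_subdirs`, to list the largest subdirectories of a directory. It assumes environments and the package cache live under `~/.conda`, which may not match your setup. Once you know where the space went, `conda clean --all` is one way to remove cached package archives.

```python
import os
from pathlib import Path


def dir_size(path):
    """Total size in bytes of regular files under `path` (symlinks are not followed)."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for fname in files:
            fpath = os.path.join(root, fname)
            if not os.path.islink(fpath):
                try:
                    total += os.path.getsize(fpath)
                except OSError:
                    pass  # file vanished or is unreadable; skip it
    return total


def largest_subdirs(path, top=10):
    """Return (size_bytes, name) pairs for the biggest immediate subdirectories."""
    entries = [
        (dir_size(p), p.name)
        for p in Path(path).iterdir()
        if p.is_dir() and not p.is_symlink()
    ]
    return sorted(entries, reverse=True)[:top]


if __name__ == "__main__":
    # ~/.conda is an assumed location for user environments and the package cache
    target = Path.home() / ".conda"
    if target.exists():
        for size, name in largest_subdirs(target):
            print(f"{size / 1e9:8.2f} GB  {name}")
```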
I have a Jupyter notebook in Google Colab, but I can't manage to import glass-jax after running
! git clone https://github.com/LSSTDESC/glass-jax.git
and then pip install -e .
Neither glass-jax nor glass_jax is recognized as a module... Any idea?
I have a short use case to test.
For future reference, use the following setup commands to install the required glass-jax environment with GPU support on NERSC 😀
# get on a GPU node first!
module load gpu
module load evp-patch
# Create a new conda environment
conda create -n glass python pip numpy scipy
conda activate glass
conda install -y "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
conda install -c conda-forge numpyro
# -------- then follow the code below --------
# check that JAX works on the GPU: start Python and run
python
import jax.numpy as jnp
import jax
import numpyro
# should be no error if everything worked
jnp.ones(10)
jnp.fft.fft(jnp.array([1.0,2.0,3.0,4.0,5.0]))
# should list the GPU device(s)
jax.devices()
exit()
# other dependencies
# other stuff
conda install -c conda-forge healpy h5py astropy matplotlib pyccl
pip install jupyterlab
pip install jax-cosmo
pip install typing-extensions
# s2fft
git clone https://github.com/astro-informatics/s2fft.git
cd s2fft
pip install -e .
cd ..
# glass
git clone https://github.com/glass-dev/glass.git
cd glass
pip install -e .
cd ..
# glass-jax
git clone https://github.com/LSSTDESC/glass-jax.git
cd glass-jax
pip install -e .
cd ..
# pmwd
git clone https://github.com/eelregit/pmwd.git
cd pmwd
pip install -e .
cd ..
# setting up the kernel for jupyterhub
# https://docs.nersc.gov/services/jupyter/how-to-guides/
# conda activate glass
python -m ipykernel install \
--user --name glass --display-name glass
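The `jnp.fft.fft` line in the check above only confirms that the call runs. A stronger smoke test compares the result against NumPy's FFT. On a GPU node the JAX call goes through the GPU FFT path, which is presumably why the line is in the check. `check_fft` is a hypothetical helper. The tolerances are loose because JAX defaults to single precision.

```python
import numpy as np


def check_fft(fft_fn, n=5, rtol=1e-5, atol=1e-5):
    """Compare a 1-D FFT implementation against NumPy on a small real input."""
    x = np.arange(1.0, n + 1.0)  # [1.0, 2.0, ..., n]
    expected = np.fft.fft(x)
    result = np.asarray(fft_fn(x))
    return bool(np.allclose(result, expected, rtol=rtol, atol=atol))


if __name__ == "__main__":
    print("numpy reference:", check_fft(np.fft.fft))
    try:
        import jax.numpy as jnp
    except ImportError:
        print("jax not installed; skipping")
    else:
        # on a GPU node this exercises JAX's GPU FFT path
        print("jax.numpy:", check_fft(jnp.fft.fft))
```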
Building a lognormal and PM model into a Bayesian pipeline
This sprint aims to adapt a lognormal and a PM (particle-mesh) model to our Gaussian pipeline for field-level modelling.
Contacts: Natalia Porqueres, Francois Lanusse
Day/Time: Monday - Friday
Main communication channel: #desc-sprint-bayesian-pipeline (#desc-bayesian-pipelines-tt secondary)
GitHub repo: https://github.com/LSSTDESC/bayesian-pipelines-cosmology/blob/main/notebooks/forward_model/glass_jax_jitted.ipynb
Zoom room (if applicable): https://stanford.zoom.us/j/99269396161?pwd=ekxIblNMTEFVbXk0Y09DVlovdG5vQT09
Format: Hybrid
Goals and deliverable
Building a lognormal field from the current Gaussian field generator
Connecting a PM simulator to the pipeline
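For the first goal, a common construction is the shifted lognormal transform. It maps a zero-mean Gaussian field g with variance sigma2 to delta = lam * (exp(g - sigma2/2) - 1). The -sigma2/2 term makes the mean of exp(g - sigma2/2) equal to 1, so delta has zero mean, and delta > -lam everywhere. This is a sketch of that transform only, not the pipeline's implementation. It is written in NumPy so it runs anywhere, and `lognormal_from_gaussian` is a hypothetical name. Choosing the Gaussian field's power spectrum so that delta ends up with a target spectrum is a separate step and is not shown.

```python
import numpy as np


def lognormal_from_gaussian(g, sigma2, lam=1.0):
    """Shifted lognormal transform of a zero-mean Gaussian field `g`.

    sigma2 is the variance of g (ideally the model value, not a sample
    estimate); lam is the shift, so the result satisfies delta > -lam.
    """
    g = np.asarray(g)
    # Subtracting sigma2/2 makes <exp(g - sigma2/2)> = 1, so <delta> = 0.
    return lam * (np.exp(g - 0.5 * sigma2) - 1.0)


if __name__ == "__main__":
    rng = np.random.default_rng(42)
    sigma = 0.5
    # stand-in for a Gaussian field on a 256 x 256 grid (white noise here)
    g = rng.normal(0.0, sigma, size=(256, 256))
    delta = lognormal_from_gaussian(g, sigma**2)
    print(f"mean {delta.mean():+.4f}, min {delta.min():.4f} (bound -1)")
```

Because the transform is elementwise, replacing `np` with `jax.numpy` should give a differentiable, `jit`-compatible version, which is what a field-level pipeline needs.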
Resources and skills needed
Familiarity with Python and JAX.