LSSTDESC / SprintWeek2023

Meeting repository for the LSST DESC 2023 Sprint Week @ CMU
Creative Commons Zero v1.0 Universal

Building a lognormal and PM model into a Bayesian pipeline #18

Open nataliaporqueres opened 1 year ago

nataliaporqueres commented 1 year ago

Building a lognormal and PM model into a Bayesian pipeline

The goal of this sprint is to add a lognormal model and a particle-mesh (PM) model to our Gaussian pipeline for field-level modelling.

Contacts: Natalia Porqueres, Francois Lanusse
Day/Time: Monday - Friday
Main communication channel: #desc-sprint-bayesian-pipeline (#desc-bayesian-pipelines-tt secondary)
GitHub repo: https://github.com/LSSTDESC/bayesian-pipelines-cosmology/blob/main/notebooks/forward_model/glass_jax_jitted.ipynb
Zoom room (if applicable): https://stanford.zoom.us/j/99269396161?pwd=ekxIblNMTEFVbXk0Y09DVlovdG5vQT09
Format: Hybrid

Goals and deliverables

- Building a lognormal field from the current Gaussian field generator
- Connecting a PM simulator to the pipeline
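For context on the first goal: a common way to build a lognormal field from a Gaussian one is the pointwise transform delta = shift * (exp(g - sigma^2/2) - 1), where g is the zero-mean Gaussian field with variance sigma^2 and shift = 1 gives a plain lognormal density field. Subtracting sigma^2/2 makes the mean of delta zero, and the two-point functions are related by xi_delta = shift^2 * (exp(xi_g) - 1). Below is a minimal NumPy sketch of that standard transform; the function names are hypothetical, and this is not necessarily how GLASS or glass-jax parametrize their lognormal fields. jax.numpy provides the same functions (asarray, var, exp, expm1, log1p), so swapping the import should make it usable under JAX.

```python
import numpy as np


def lognormal_from_gaussian(g, sigma2=None, shift=1.0):
    """Shifted lognormal field from a zero-mean Gaussian field g.

    delta = shift * (exp(g - sigma2 / 2) - 1) has zero mean and delta > -shift.
    sigma2 is the variance of g; if None, the sample variance is used instead
    (an approximation that is only good for large fields).
    """
    g = np.asarray(g)
    if sigma2 is None:
        sigma2 = np.var(g)
    return shift * (np.exp(g - 0.5 * sigma2) - 1.0)


def lognormal_xi_from_gaussian(xi_g, shift=1.0):
    """Two-point function of delta given that of g: shift**2 * (exp(xi_g) - 1)."""
    return shift**2 * np.expm1(np.asarray(xi_g))


def gaussian_xi_from_lognormal(xi_ln, shift=1.0):
    """Inverse relation: the xi_g needed to reach a target lognormal xi_ln.

    Requires xi_ln / shift**2 > -1.
    """
    return np.log1p(np.asarray(xi_ln) / shift**2)
```

In a field-level pipeline the inverse relation is what you would apply to the target (lognormal) spectrum before drawing the Gaussian field, so that the transformed field ends up with the intended two-point statistics.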

Resources and skills needed

Familiarity with Python and JAX.

EiffL commented 1 year ago

Here is our action plan for the week

Improving GLASS-JAX

Sampling cosmology and initial conditions

Interfacing PM n-body density field
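On the PM side, the n-body particles have to be turned into a density field on a mesh before they can enter the likelihood. A standard way to do this is cloud-in-cell (CIC) assignment, sketched below in NumPy under assumptions of our own: unit-mass particles, a periodic cubic box, and grid node i located at i * box_size / n_mesh. pmwd does its own particle-to-mesh assignment, so this only illustrates the technique and is not its API; the name cic_paint is hypothetical. In JAX, the np.add.at call would become mesh.at[...].add(w).

```python
import numpy as np


def cic_paint(positions, n_mesh, box_size):
    """Deposit unit-mass particles onto a periodic n_mesh**3 grid with CIC weights.

    positions: array of shape (N, 3) in the same length units as box_size.
    Convention (assumption): grid node i sits at i * box_size / n_mesh.
    """
    mesh = np.zeros((n_mesh, n_mesh, n_mesh))
    x = np.asarray(positions, dtype=float) * (n_mesh / box_size)  # grid units
    i0 = np.floor(x).astype(int)  # lower neighbouring node on each axis
    f = x - i0  # fractional offset in [0, 1)
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                offset = np.array([dx, dy, dz])
                # weight = product over axes of (f for the upper node, 1 - f for the lower)
                w = np.prod(np.where(offset == 1, f, 1.0 - f), axis=1)
                idx = (i0 + offset) % n_mesh  # periodic wrap
                # np.add.at accumulates correctly when several particles share a cell
                np.add.at(mesh, (idx[:, 0], idx[:, 1], idx[:, 2]), w)
    return mesh
```

The density contrast then follows as delta = mesh / mesh.mean() - 1.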

AZhou00 commented 1 year ago

Setting up a GPU node and the correct GPU Python kernel on NERSC, with all the dependencies for this project (should work out of the box!):


# get on a GPU node first!
# https://docs.nersc.gov/development/languages/python/using-python-perlmutter/
module load cudatoolkit/11.7
module load cudnn/8.9.1_cuda11
module load python
# Verify the versions of cudatoolkit and cudnn are compatible with JAX
module list
# Create a new conda environment 
conda create -n jglass python=3.9 pip numpy scipy
# Activate the environment before using pip to install JAX
conda activate jglass
# Install a compatible wheel
pip install --no-cache-dir "jax[cuda11_cudnn82]==0.4.7" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# check that JAX works on the GPU
python - <<'EOF'
import jax
import jax.numpy as jnp
# should raise no error if everything worked
print(jnp.ones(10))
# should list the GPU devices
print(jax.devices())
EOF

# other dependencies

# glass
git clone https://github.com/glass-dev/glass.git
cd glass
pip install -e .
cd ..

# glass-jax
git clone https://github.com/LSSTDESC/glass-jax.git
cd glass-jax
pip install -e .
cd ..

# pmwd
git clone https://github.com/eelregit/pmwd.git
cd pmwd
pip install -e .
cd ..

# s2fft
git clone https://github.com/astro-informatics/s2fft.git
cd s2fft
pip install -e .
cd ..

# other stuff
conda install -c conda-forge healpy h5py astropy matplotlib
pip install jupyterlab
pip install jax-cosmo
pip install typing-extensions

# setting up the kernel for jupyterhub
# https://docs.nersc.gov/services/jupyter/how-to-guides/
# conda activate jglass
python -m ipykernel install \
    --user --name jglass --display-name jglass
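Once the environment is built, a quick way to see whether everything is importable, and whether JAX actually picked up the GPU, is a small Python check like the one below. It is a sketch: the list holds the import names we expect for the packages above (for example, jax-cosmo imports as jax_cosmo), glass-jax is left out because its import name is not given in this thread, and the helper names are hypothetical. jax.default_backend() should return 'gpu' on a GPU node.

```python
import importlib.util

# Import names we expect for the packages installed above (assumption;
# glass-jax is omitted because its import name is not given in this thread).
EXPECTED_MODULES = [
    "jax", "jaxlib", "numpy", "scipy", "glass", "pmwd", "s2fft",
    "jax_cosmo", "healpy", "h5py", "astropy", "matplotlib", "typing_extensions",
]


def missing_modules(names):
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def jax_backend():
    """Return JAX's default backend ('cpu', 'gpu', ...), or None if JAX is absent."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so the check also runs without JAX

    return jax.default_backend()


# Example, run inside the activated environment:
# print("missing:", missing_modules(EXPECTED_MODULES))
# print("JAX backend:", jax_backend())  # expect 'gpu' on a GPU node
```

find_spec only locates the modules without importing them, so the check is fast and does not fail on the first broken package.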
jecampagne commented 1 year ago

Hello, I've just tried this and got:

(nersc-python) login18:c/campagne/Work$ conda create -n jglass python=3.9 pip numpy scipy
Collecting package metadata (current_repodata.json): failed

# >>>>>>>>>>>>>>>>>>>>>> ERROR REPORT <<<<<<<<<<<<<<<<<<<<<<

    Traceback (most recent call last):
      File "/global/common/software/nersc/pe/conda/23.9.0/Miniconda3-py311_23.5.2-0/lib/python3.11/site-packages/conda/exception_handler.py", line 17, in __call__
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
.... 

Have I missed something?

AZhou00 commented 1 year ago

Hi,

Could you try starting fresh with the following commands:

conda create -n jglass python pip numpy scipy
conda activate jglass
module load gpu
module load evp-patch
conda install -y "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
conda install -c conda-forge numpyro

Then install the "other dependencies" from the recipe above. This also gets you more up-to-date versions of JAX and NumPyro.

jecampagne commented 1 year ago

Hello @AZhou00

Thanks for the tip. In fact, I'm pretty sure the crash was caused by an exceeded disk quota :( When I tried the workflow above, I got the nasty "OSError: [Errno 122] Disk quota exceeded" buried in a flood of error reports. After cleaning up a little, I was able to start the installation :)

Now, since I am not familiar with running on NERSC, I probably did not install in the right place. May I ask where you install such an environment, and where you run it (i.e. a Jupyter notebook at NERSC)? I know these questions are quite basic, so I hope you don't mind too much. Thanks.
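On Linux, Errno 122 is EDQUOT: the per-user quota on that filesystem is full, and package caches are a common culprit when building conda environments under $HOME. Below is a minimal sketch for finding what uses the space. The helper names are hypothetical, and the cache locations in the usage comment (~/.conda/pkgs for conda, ~/.cache/pip for pip) are assumptions; your conda configuration may put the package cache elsewhere. Cached conda packages can then be cleared with conda clean --all, and pip install --no-cache-dir (as in the first recipe) avoids filling the pip cache.

```python
import os
from pathlib import Path


def dir_size_bytes(path):
    """Total size in bytes of regular files below `path` (symlinks are skipped)."""
    total = 0
    for root, _dirs, files in os.walk(path):  # os.walk does not follow symlinked dirs
        for name in files:
            file_path = os.path.join(root, name)
            if not os.path.islink(file_path):
                total += os.path.getsize(file_path)
    return total


def largest_subdirs(path, top=5):
    """(name, size_in_bytes) for the `top` largest immediate subdirectories of `path`."""
    sizes = [
        (child.name, dir_size_bytes(child))
        for child in Path(path).iterdir()
        if child.is_dir() and not child.is_symlink()
    ]
    return sorted(sizes, key=lambda item: item[1], reverse=True)[:top]


# Example (assumed default cache locations; adjust to your conda/pip config):
# for cache in (Path.home() / ".conda" / "pkgs", Path.home() / ".cache" / "pip"):
#     if cache.is_dir():
#         print(cache, f"{dir_size_bytes(cache) / 1e9:.2f} GB")
# print(largest_subdirs(Path.home()))
```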

jecampagne commented 1 year ago

I have a Jupyter notebook in Google Colab, but I cannot import glass-jax after running

! git clone https://github.com/LSSTDESC/glass-jax.git

and then pip install -e .

Neither glass-jax nor glass_jax is recognized as a module... Do you have an idea why?

I have a short use case to test.
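A few things that may help here. Python import names cannot contain hyphens, so import glass-jax is a syntax error; the name to import is whatever package directory the repository ships. In Colab, !cd glass-jax only affects that single shell command, so use %cd glass-jax before !pip install -e .; and an editable install may only become visible after restarting the runtime. As a fallback, the hedged helper below lists the package directories found in a cloned repository (root or src/ layout) and imports one by putting its parent directory on sys.path. The function names are hypothetical, and the repository path in the usage comment is only an example.

```python
import importlib
import sys
from pathlib import Path


def find_packages_in_repo(repo_dir):
    """Map importable package names to the directory that must be on sys.path.

    Looks for directories containing __init__.py in the repo root and in a
    src/ layout; namespace packages without __init__.py are not detected.
    """
    repo = Path(repo_dir)
    found = {}
    for base in (repo, repo / "src"):
        if not base.is_dir():
            continue
        for child in sorted(base.iterdir()):
            if (child.is_dir() and child.name.isidentifier()
                    and (child / "__init__.py").is_file()):
                found[child.name] = base
    return found


def import_from_repo(repo_dir, name):
    """Put the package's parent directory on sys.path and import it."""
    packages = find_packages_in_repo(repo_dir)
    if name not in packages:
        raise ModuleNotFoundError(
            f"{name!r} not found in {repo_dir}; candidates: {sorted(packages)}")
    base = str(packages[name])
    if base not in sys.path:
        sys.path.insert(0, base)
    importlib.invalidate_caches()  # make newly created/added paths visible
    return importlib.import_module(name)


# Example in Colab (hypothetical path):
# print(find_packages_in_repo("/content/glass-jax"))
```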

AZhou00 commented 1 year ago

For future reference, use the following commands to set up the glass-jax environment with GPU support on NERSC 😀

# get on a GPU node first!
module load gpu
module load evp-patch

# Create a new conda environment
conda create -n glass python pip numpy scipy
conda activate glass
conda install -y "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
conda install -c conda-forge numpyro

# -------- then follow the code below --------
# check that JAX works on the GPU
python - <<'EOF'
import jax
import jax.numpy as jnp
import numpyro
# should raise no error if everything worked
print(jnp.ones(10))
print(jnp.fft.fft(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])))
# should list the GPU devices
print(jax.devices())
EOF

# other dependencies
conda install -c conda-forge healpy h5py astropy matplotlib pyccl

pip install jupyterlab
pip install jax-cosmo
pip install typing-extensions

# s2fft
git clone https://github.com/astro-informatics/s2fft.git
cd s2fft
pip install -e .
cd ..

# glass
git clone https://github.com/glass-dev/glass.git
cd glass
pip install -e .
cd ..

# glass-jax
git clone https://github.com/LSSTDESC/glass-jax.git
cd glass-jax
pip install -e .
cd ..

# pmwd
git clone https://github.com/eelregit/pmwd.git
cd pmwd
pip install -e .
cd ..

# setting up the kernel for jupyterhub
# https://docs.nersc.gov/services/jupyter/how-to-guides/
# conda activate glass
python -m ipykernel install \
    --user --name glass --display-name glass
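For the earlier question about running at NERSC: python -m ipykernel install --user writes a kernel spec (kernels/<name>/kernel.json) into the user Jupyter data directory, and Jupyter, including JupyterHub, then lists it under its display name. A quick sanity check is that the kernel launches the Python from the glass environment. The sketch below assumes the Linux default location (JUPYTER_DATA_DIR if set, else $XDG_DATA_HOME/jupyter, else ~/.local/share/jupyter) and the standard argv and display_name fields of kernel.json; the helper names are hypothetical.

```python
import json
import os
from pathlib import Path


def user_jupyter_data_dir():
    """User-level Jupyter data directory on Linux (where `ipykernel install --user` writes)."""
    if os.environ.get("JUPYTER_DATA_DIR"):
        return Path(os.environ["JUPYTER_DATA_DIR"])
    xdg = os.environ.get("XDG_DATA_HOME")
    base = Path(xdg) if xdg else Path.home() / ".local" / "share"
    return base / "jupyter"


def describe_kernel(name, data_dir=None):
    """Read kernels/<name>/kernel.json and report which Python the kernel starts."""
    root = Path(data_dir) if data_dir is not None else user_jupyter_data_dir()
    spec = json.loads((root / "kernels" / name / "kernel.json").read_text())
    return {"display_name": spec.get("display_name"), "python": spec["argv"][0]}


# Example, after the ipykernel install above:
# print(describe_kernel("glass"))  # "python" should point into the glass env
```

If "python" points at a system interpreter instead of the conda environment, rerun the ipykernel install with the environment activated.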