High-performance numerical integration on the GPU with PyTorch, JAX and TensorFlow
The torchquad module lets you use GPUs for efficient numerical integration with PyTorch and other numerical Python 3 modules. The software is free to use and is designed for the machine learning community and research groups focusing on topics requiring high-dimensional integration.
This project is built with the following packages:
This is a brief guide on how to set up torchquad.
We recommend using conda, especially if you want to use the GPU. For PyTorch, for example, conda automatically sets up CUDA and the cudatoolkit for you. Note that torchquad also works on the CPU, but it is optimized for GPU usage. torchquad's GPU support is tested only on NVIDIA cards with CUDA. We are investigating future support for AMD cards through ROCm.
For a detailed list of required packages and packages for numerical backends, please refer to the conda environment files environment.yml and environment_all_backends.yml. torchquad has been tested with JAX 0.2.25, NumPy 1.19.5, PyTorch 1.10.0 and TensorFlow 2.7.0 on Linux; other versions of the backends should work as well, but some may require additional setup on other platforms such as Windows.
The easiest way to install torchquad is simply to
conda install torchquad -c conda-forge
Alternatively, you can use
pip install torchquad
The PyTorch backend with CUDA support can be installed with
conda install "cudatoolkit>=11.1" "pytorch>=1.9=*cuda*" -c conda-forge -c pytorch
Note that since PyTorch is not yet on conda-forge for Windows, we have explicitly included it here using -c pytorch.
Note also that installing PyTorch with pip may not set it up with CUDA support. Therefore, we recommend using conda.
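To check whether an installed PyTorch actually has CUDA support, a quick diagnostic along the following lines may help. It is a minimal sketch that relies only on standard PyTorch attributes (torch.version.cuda, torch.cuda.is_available, torch.cuda.get_device_name); the helper name cuda_status is hypothetical and not part of torchquad.

```python
import importlib.util


def cuda_status():
    """Describe PyTorch's CUDA support (hypothetical helper, not a torchquad API)."""
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed"
    import torch

    if torch.version.cuda is None:
        # This build has no CUDA, e.g. a CPU-only pip wheel
        return "PyTorch was built without CUDA support"
    if not torch.cuda.is_available():
        # CUDA build, but no usable GPU or driver was found at runtime
        return f"PyTorch was built with CUDA {torch.version.cuda}, but no usable GPU was found"
    return f"CUDA {torch.version.cuda} available on {torch.cuda.get_device_name(0)}"


print(cuda_status())
```

A "built without CUDA support" result usually points to the pip installation issue described above.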
Here are installation instructions for other numerical backends:
conda install "tensorflow>=2.6.0=cuda*" -c conda-forge
pip install "jax[cuda]>=0.2.22" --find-links https://storage.googleapis.com/jax-releases/jax_releases.html # linux only
conda install "numpy>=1.19.5" -c conda-forge
More installation instructions for numerical backends can be found in environment_all_backends.yml and in the backends' documentation, for example https://pytorch.org/get-started/locally/, https://github.com/google/jax/#installation and https://www.tensorflow.org/install/gpu; often there are multiple ways to install them.
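Since torchquad was tested with specific backend versions, it can be useful to compare them with what is installed. The following sketch uses the standard library's importlib.metadata (Python 3.8+); it assumes the backends are installed under the distribution names jax, numpy, torch and tensorflow, which may differ for some conda or GPU-specific builds. The helper installed_backends is hypothetical.

```python
from importlib import metadata

# Versions torchquad has been tested with on Linux (listed in the setup notes above)
TESTED_VERSIONS = {
    "jax": "0.2.25",
    "numpy": "1.19.5",
    "torch": "1.10.0",
    "tensorflow": "2.7.0",
}


def installed_backends():
    """Map each backend distribution name to its installed version, or None if absent."""
    versions = {}
    for dist in TESTED_VERSIONS:
        try:
            versions[dist] = metadata.version(dist)
        except metadata.PackageNotFoundError:
            versions[dist] = None
    return versions


for name, version in installed_backends().items():
    print(f"{name:<10} installed: {version or '-':<10} tested: {TESTED_VERSIONS[name]}")
```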
After installing torchquad and PyTorch through conda or pip, users can test torchquad's correct installation with:
import torchquad
torchquad._deployment_test()
After cloning the repository, developers can check the functionality of torchquad by running the following command in the torchquad/tests directory:
pytest
This is a brief example of how torchquad can be used to compute a simple integral with PyTorch. For a more thorough introduction, please refer to the tutorial section in the documentation.
The full documentation can be found on readthedocs.
# To avoid copying things to GPU memory,
# ideally allocate everything in torch on the GPU
# and avoid non-torch function calls
import torch
from torchquad import MonteCarlo, set_up_backend
# Enable GPU support if available and set the floating point precision
set_up_backend("torch", data_type="float32")
# The function we want to integrate, in this example
# f(x0,x1) = sin(x0) + e^x1 for x0=[0,1] and x1=[-1,1]
# Note that the function needs to support multiple evaluations at once (first
# dimension of x here)
# Expected result here is ~3.2698
def some_function(x):
    return torch.sin(x[:, 0]) + torch.exp(x[:, 1])
# Declare an integrator;
# here we use the simple, stochastic Monte Carlo integration method
mc = MonteCarlo()
# Compute the function integral by sampling 10000 points over domain
integral_value = mc.integrate(
some_function,
dim=2,
N=10000,
integration_domain=[[0, 1], [-1, 1]],
backend="torch",
)
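Conceptually, plain Monte Carlo integration samples N points uniformly in the integration domain, evaluates the function once on all of them (hence the batched "first dimension of x" convention), and multiplies the mean value by the domain volume. The NumPy sketch below illustrates this idea; it is not torchquad's implementation, and monte_carlo_integrate is a hypothetical helper. The expected value of about 3.2698 follows because the integrand is a sum of a function of x0 and a function of x1: 2(1 - cos 1) + (e - 1/e).

```python
import math

import numpy as np


def monte_carlo_integrate(fn, domain, n, seed=0):
    """Plain Monte Carlo estimate of the integral of fn over a box (hypothetical helper)."""
    domain = np.asarray(domain, dtype=float)  # shape (dim, 2): [lower, upper] per dimension
    lower, upper = domain[:, 0], domain[:, 1]
    rng = np.random.default_rng(seed)
    # One row per sample point, matching the batched "first dimension of x" convention
    x = lower + (upper - lower) * rng.random((n, len(domain)))
    volume = np.prod(upper - lower)
    return volume * np.mean(fn(x))


def some_function(x):
    # NumPy version of the integrand used in the torch example above
    return np.sin(x[:, 0]) + np.exp(x[:, 1])


estimate = monte_carlo_integrate(some_function, [[0, 1], [-1, 1]], n=10000)
# Separable integrand: (integral of sin over [0,1]) * 2 + (integral of exp over [-1,1]) * 1
exact = 2 * (1 - math.cos(1)) + (math.e - 1 / math.e)
print(f"estimate={estimate:.4f} exact={exact:.4f} abs. error={abs(estimate - exact):.4f}")
```

The statistical error of such an estimate shrinks only like 1/sqrt(N), which is why Monte Carlo pays off mainly in high dimensions.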
To change the logger verbosity, set the TORCHQUAD_LOG_LEVEL environment variable; for example, export TORCHQUAD_LOG_LEVEL=DEBUG.
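The same setting can be made from within Python, for example in a notebook. This sketch assumes that torchquad reads TORCHQUAD_LOG_LEVEL when it is imported, so the variable is set before the import; the import is guarded only so the snippet also runs where torchquad is not installed.

```python
import os

# Equivalent to `export TORCHQUAD_LOG_LEVEL=DEBUG`; must happen before the import
# (assumption: the log level is read when torchquad is imported)
os.environ["TORCHQUAD_LOG_LEVEL"] = "DEBUG"

try:
    import torchquad  # noqa: F401
except ImportError:
    print("torchquad is not installed")
```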
You can find all available integrators in the documentation.
See the open issues for a list of proposed features (and known issues).
When using GPUs, torchquad scales particularly well with integration methods that are easy to parallelize. For example, the figures below show error and runtime results for integrating the function f(x,y,z) = sin(x * (y+1)²) * (z+1) on a consumer-grade desktop PC.
Runtime results of the integration. Note the far superior scaling on the GPU (solid line) in comparison to the CPU (dashed and dotted) for both methods.
Convergence results of the integration. Note that Simpson quickly reaches floating point precision. Monte Carlo is not competitive here given the low dimensionality of the problem.
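The convergence gap between Simpson and Monte Carlo can be reproduced on a smaller scale. The NumPy sketch below is not torchquad's implementation (simpson_weights and simpson_2d are hypothetical helpers), and it uses the two-dimensional example from the usage section instead of the benchmark function because its exact value is known. It compares a tensor-product composite Simpson rule with plain Monte Carlo at the same number of function evaluations N. For smooth integrands, Simpson's error scales like h^4, i.e. roughly N^(-4/d) in d dimensions, whereas Monte Carlo's scales like N^(-1/2); in low dimensions Simpson therefore wins by a wide margin.

```python
import math

import numpy as np


def simpson_weights(a, b, n):
    """Nodes and composite Simpson weights for n points on [a, b] (n must be odd)."""
    if n < 3 or n % 2 == 0:
        raise ValueError("Simpson's rule needs an odd number of points >= 3")
    nodes = np.linspace(a, b, n)
    h = (b - a) / (n - 1)
    weights = np.ones(n)
    weights[1:-1:2] = 4.0  # odd interior nodes
    weights[2:-1:2] = 2.0  # even interior nodes
    return nodes, weights * h / 3.0


def simpson_2d(fn, domain, n):
    """Tensor-product Simpson rule with n points per dimension (n**2 evaluations)."""
    (x0, w0), (x1, w1) = (simpson_weights(a, b, n) for a, b in domain)
    g0, g1 = np.meshgrid(x0, x1, indexing="ij")
    points = np.stack([g0.ravel(), g1.ravel()], axis=1)  # shape (n*n, 2), batched
    values = fn(points).reshape(n, n)
    return np.einsum("i,j,ij->", w0, w1, values)


def some_function(x):
    return np.sin(x[:, 0]) + np.exp(x[:, 1])


domain = [[0, 1], [-1, 1]]
exact = 2 * (1 - math.cos(1)) + (math.e - 1 / math.e)
rng = np.random.default_rng(0)
for n in (5, 11, 21, 41):
    simpson_err = abs(simpson_2d(some_function, domain, n) - exact)
    # Monte Carlo with the same number of function evaluations (domain volume is 2)
    pts = rng.random((n * n, 2)) * [1.0, 2.0] + [0.0, -1.0]
    mc_err = abs(2.0 * some_function(pts).mean() - exact)
    print(f"N={n * n:5d}  Simpson error={simpson_err:.2e}  Monte Carlo error={mc_err:.2e}")
```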
The project is open to community contributions. Feel free to open an issue or write us an email if you would like to discuss a problem or idea first.
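If you want to contribute, please clone the repository:
git clone https://github.com/esa/torchquad.git
Then set up the development environment from environment_all_backends.yml. This creates a conda environment called torchquad and installs the required dependencies.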
conda env create -f environment_all_backends.yml
conda activate torchquad
Once the installation is done, you are ready to contribute.
Please note that PRs should be created from and into the develop branch. For each release, the develop branch is merged into main.
To submit a change, create a feature branch, commit your changes and push them:
git checkout -b feature/AmazingFeature
git commit -m 'Add some AmazingFeature'
git push origin feature/AmazingFeature
Then open a pull request on the develop branch, not main, and we will have a look at your contribution as soon as we can. (NB: We autoformat every PR with black. Our GitHub actions may create additional commits on your PR for that reason.)
Furthermore, please make sure that your PR passes all automated tests. Review will only happen after that.
Only PRs created on the develop branch with all tests passing will be considered. The only exception to this rule is if you want to update the documentation in relation to the current release on conda / pip. In that case, you may ask to merge directly into main.
Distributed under the GPL-3.0 License. See LICENSE for more information.
If torchquad reports "Error enabling CUDA. cuda.is_available() returned False. CPU will be used.", PyTorch could not find a usable CUDA setup on your system. If you are using conda, you can install the CUDA libraries with conda install cudatoolkit. For more detailed installation instructions, please refer to the PyTorch documentation.
Created by ESA's Advanced Concepts Team
pablo.gomez at esa.int
gabriele.meoni at esa.int
havard.hem.toftevaag at esa.int
Project Link: https://github.com/esa/torchquad