Authors: Dmitrii Kochkov, Jamie A. Smith, Peter Norgaard, Gideon Dresdner, Ayya Alieva, Stephan Hoyer
JAX-CFD is an experimental research project for exploring the potential of machine learning, automatic differentiation and hardware accelerators (GPU/TPU) for computational fluid dynamics. It is implemented in JAX.
To learn more about our general approach, read our paper "Machine learning–accelerated computational fluid dynamics" (PNAS 2021).
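To make the role of automatic differentiation concrete, here is a minimal sketch of the idea in plain JAX. It is not code from jax_cfd: the functions diffuse and loss are hypothetical, and a 1-D periodic heat equation stands in for the Navier-Stokes equations. It uses jax.grad to compute how a simulated final state depends on the viscosity nu, differentiating through every solver step.

```python
import jax
import jax.numpy as jnp


def diffuse(u, nu, dt=1e-3, steps=100):
    """Explicit Euler steps of the periodic 1-D heat equation u_t = nu * u_xx.

    Hypothetical toy solver; assumes a periodic domain [0, 2*pi).
    """
    dx = 2 * jnp.pi / u.shape[0]

    def step(u, _):
        # Second-order finite-difference Laplacian with periodic wrap-around.
        lap = (jnp.roll(u, -1) - 2 * u + jnp.roll(u, 1)) / dx**2
        return u + dt * nu * lap, None

    u_final, _ = jax.lax.scan(step, u, None, length=steps)
    return u_final


def loss(nu, u0, target):
    """Mean squared mismatch between the simulated and the target final state."""
    return jnp.mean((diffuse(u0, nu) - target) ** 2)


x = jnp.arange(64) * (2 * jnp.pi / 64)
u0 = jnp.sin(x)
target = diffuse(u0, 0.5)  # synthetic target generated with nu = 0.5

# jax.grad differentiates through all 100 solver steps with respect to nu.
grad_fn = jax.grad(loss)
g = grad_fn(0.1, u0, target)
print(g)  # negative: increasing nu moves the simulation toward the target
```

Because jax.grad differentiates through the whole lax.scan loop, the same pattern extends to fitting parameters of a model embedded inside a solver step.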
The "notebooks" directory contains several demonstrations of using the JAX-CFD code.
Demos of different simulation setups:
Reproduce results from our PNAS paper:
JAX-CFD is organized into the following submodules:
- jax_cfd.base: core finite volume/difference methods for CFD, written in JAX.
- jax_cfd.spectral: core pseudospectral methods for CFD, written in JAX.
- jax_cfd.ml: machine learning augmented models for CFD, written in JAX and Haiku.
- jax_cfd.data: data processing utilities for preparing, evaluating and post-processing data created with JAX-CFD, written in Xarray and Pillow.
A base install with pip install jax-cfd only requires NumPy, SciPy and JAX. To install dependencies for the other submodules, use one of:
pip install jax-cfd[ml]
pip install jax-cfd[data]
pip install jax-cfd[complete]
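The finite-difference methods in jax_cfd.base and the pseudospectral methods in jax_cfd.spectral differ in how they approximate spatial derivatives. As a rough illustration only (plain NumPy, not the jax_cfd API; the helpers central_difference and spectral_derivative are hypothetical), the sketch below compares a second-order central difference with an FFT-based spectral derivative on a periodic 1-D grid.

```python
import numpy as np

# Periodic 1-D grid on [0, 2*pi) with n uniformly spaced points.
n = 32
length = 2 * np.pi
dx = length / n
x = np.arange(n) * dx

u = np.sin(3 * x)
exact = 3 * np.cos(3 * x)


def central_difference(u, dx):
    """Second-order central difference (u[i+1] - u[i-1]) / (2 dx), periodic."""
    return (np.roll(u, -1) - np.roll(u, 1)) / (2 * dx)


def spectral_derivative(u, length):
    """Pseudospectral derivative: multiply Fourier coefficients by i*k."""
    n = u.size
    # Angular wavenumbers 0, 1, ..., n//2 for a domain of the given length.
    k = 2 * np.pi * np.fft.rfftfreq(n, d=length / n)
    return np.fft.irfft(1j * k * np.fft.rfft(u), n=n)


fd_error = np.max(np.abs(central_difference(u, dx) - exact))
spectral_error = np.max(np.abs(spectral_derivative(u, length) - exact))
print(f"finite difference: {fd_error:.2e}, spectral: {spectral_error:.2e}")
```

For a smooth periodic field like this one, the spectral derivative is accurate to roughly machine precision, while the central-difference error decreases only as O(dx²) as the grid is refined.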
JAX-CFD is currently focused on unsteady turbulent flows:
TODO: add a notebook explaining our numerical models in more depth.
In the long term, we're interested in expanding JAX-CFD to implement methods relevant for related research, e.g.,
We would welcome collaboration on any of these! Please reach out (either on GitHub or by email) to coordinate before starting significant work.
Other differentiable CFD codes compatible with deep learning:
JAX for science:
Did we miss something? Please let us know!
If you use our finite volume method (FVM) or ML models, please cite:
@article{Kochkov2021-ML-CFD,
author = {Kochkov, Dmitrii and Smith, Jamie A. and Alieva, Ayya and Wang, Qing and Brenner, Michael P. and Hoyer, Stephan},
title = {Machine learning{\textendash}accelerated computational fluid dynamics},
volume = {118},
number = {21},
elocation-id = {e2101784118},
year = {2021},
doi = {10.1073/pnas.2101784118},
publisher = {National Academy of Sciences},
issn = {0027-8424},
URL = {https://www.pnas.org/content/118/21/e2101784118},
eprint = {https://www.pnas.org/content/118/21/e2101784118.full.pdf},
journal = {Proceedings of the National Academy of Sciences}
}
If you use our spectral code, please cite:
@article{Dresdner2022-Spectral-ML,
doi = {10.48550/ARXIV.2207.00556},
url = {https://arxiv.org/abs/2207.00556},
author = {Dresdner, Gideon and Kochkov, Dmitrii and Norgaard, Peter and Zepeda-Núñez, Leonardo and Smith, Jamie A. and Brenner, Michael P. and Hoyer, Stephan},
title = {Learning to correct spectral methods for simulating turbulent flows},
publisher = {arXiv},
year = {2022},
copyright = {arXiv.org perpetual, non-exclusive license}
}
To locally install for development:
git clone https://github.com/google/jax-cfd.git
cd jax-cfd
pip install jaxlib
pip install -e ".[complete]"
Then to manually run the test suite:
pytest -n auto jax_cfd --dist=loadfile --ignore=jax_cfd/base/validation_test.py