
JAXopt

Installation | Documentation | Examples | Cite us

⚠️ JAXopt is being merged into Optax. As a result, JAXopt is now in maintenance mode and no new features will be implemented. ⚠️

Hardware accelerated, batchable and differentiable optimizers in JAX.

Installation

To install the latest release of JAXopt, use the following command:

$ pip install jaxopt

To install the development version, use the following command instead:

$ pip install git+https://github.com/google/jaxopt

Alternatively, JAXopt can be installed from source with the following command:

$ python setup.py install

Cite us

Our implicit differentiation framework is described in this paper. To cite it:

@article{jaxopt_implicit_diff,
  title={Efficient and Modular Implicit Differentiation},
  author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy 
    and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian 
    and Vert, Jean-Philippe},
  journal={arXiv preprint arXiv:2105.15183},
  year={2021}
}

Disclaimer

JAXopt is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.