jax_verify is a library containing JAX implementations of many widely used neural network verification techniques.
If you just want to get started verifying your neural networks with jax_verify, the main thing to know is that we provide a simple, consistent interface for a variety of verification algorithms:
output_bounds = jax_verify.verification_technique(network_fn, input_bounds)
Here, network_fn is any JAX function, input_bounds defines bounds over possible inputs to network_fn, and output_bounds will be the computed bounds over possible outputs of network_fn. verification_technique can be one of many algorithms implemented in jax_verify, such as interval_bound_propagation or crown_bound_propagation.
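To make the inputs and outputs of this interface concrete, here is a minimal hand-written sketch of what interval bound propagation computes, in plain JAX, for a small ReLU multilayer perceptron. It is not jax_verify's implementation: the explicit layer list params, the helpers interval_affine, interval_relu, ibp_mlp and mlp, and the random weights are all hypothetical, chosen only for illustration. jax_verify derives the equivalent computation from an arbitrary network_fn rather than from a list of layers.

```python
import jax
import jax.numpy as jnp


def interval_affine(lower, upper, w, b):
    """Propagates the box [lower, upper] through x -> x @ w + b."""
    center = (upper + lower) / 2.0
    radius = (upper - lower) / 2.0
    out_center = center @ w + b
    # |w| maps the input radius to the largest possible output deviation.
    out_radius = radius @ jnp.abs(w)
    return out_center - out_radius, out_center + out_radius


def interval_relu(lower, upper):
    """ReLU is monotone, so it can be applied to each endpoint separately."""
    return jax.nn.relu(lower), jax.nn.relu(upper)


def ibp_mlp(params, lower, upper):
    """Hypothetical helper: bounds for an MLP with ReLU between layers."""
    for i, (w, b) in enumerate(params):
        lower, upper = interval_affine(lower, upper, w, b)
        if i < len(params) - 1:  # no ReLU after the final (logit) layer
            lower, upper = interval_relu(lower, upper)
    return lower, upper


def mlp(params, x):
    """The concrete network whose outputs the bounds over-approximate."""
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:
            x = jax.nn.relu(x)
    return x


# Hypothetical two-layer network with random weights, for illustration only.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = [
    (jax.random.normal(k1, (4, 8)), jnp.zeros(8)),
    (jax.random.normal(k2, (8, 3)), jnp.zeros(3)),
]
x0 = jnp.array([0.5, -0.2, 0.1, 0.9])
eps = 0.1

# Input bounds: an L-infinity ball of radius eps around x0.
out_lower, out_upper = ibp_mlp(params, x0 - eps, x0 + eps)

# Sanity check: outputs for random inputs inside the box stay inside the bounds.
samples = x0 + jax.random.uniform(k3, (1000, 4), minval=-eps, maxval=eps)
outputs = jax.vmap(lambda x: mlp(params, x))(samples)
print(out_lower, out_upper)
```

For a single affine layer, pushing the box center through the weights and the radius through their absolute values gives the tightest per-coordinate bounds on that layer's output. Chaining layers this way stays sound but can loosen the bounds quickly, which is one reason tighter methods such as crown_bound_propagation exist.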
The overall approach is to use JAX’s powerful program transformation system, which allows us to analyze general network structures defined by network_fn and then to define corresponding functions for calculating verified bounds for these networks.
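As a rough illustration of the tracing idea above, the sketch below uses only the public jax.make_jaxpr API to turn a small network_fn into its sequence of primitive operations. It is not jax_verify's internal code, and the weights w and b are hypothetical. The exact primitive names printed depend on the JAX version, so the sketch does not assume any particular list.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-layer network; w and b are closed over as constants.
w = jnp.ones((3, 2))
b = jnp.zeros(2)


def network_fn(x):
    return jax.nn.relu(x @ w + b)


# make_jaxpr traces network_fn on an example input and returns the program
# as a closed jaxpr: a list of primitive equations plus their constants.
closed_jaxpr = jax.make_jaxpr(network_fn)(jnp.zeros(3))
print(closed_jaxpr)

# A bound-propagation tool could walk these equations in order and apply a
# bound-computing rule for each primitive it encounters.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```

Because the analysis works on the traced program rather than on a fixed layer format, the same machinery can in principle handle any network_fn that JAX can trace.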
The methods currently provided by jax_verify include:
jax_verify can be installed in two ways:
Stable: Just run pip install jax_verify and you can import jax_verify from any of your Python code.
Latest: Clone this repository and run pip install . from the repository root.
We suggest starting with the minimal examples in the examples/ directory. In particular, all the bound propagation techniques can be run with the run_boundprop.py script:
cd examples/
python3 run_boundprop.py --boundprop_method=interval_bound_propagation
For documentation, please refer to the API reference page.
Contributions of additional verification techniques are very welcome. Please open an issue first to let us know.
All code is made available under the Apache 2.0 License. Model parameters are made available under the Creative Commons Attribution 4.0 International (CC BY 4.0) License. See https://creativecommons.org/licenses/by/4.0/legalcode for more details.
This is not an official Google product.