Eqxvision is a package of popular computer vision model architectures built using Equinox.
Use the package manager pip to install eqxvision.
pip install eqxvision
Requires: python>=3.7
Optional: torch, only if pretrained models are required.
Documentation is available at https://eqxvision.readthedocs.io/en/latest/.
Picking a model and doing a forward pass is as simple as ...
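```python
import jax
import jax.random as jr
import equinox as eqx
from eqxvision.models import alexnet
from eqxvision.utils import CLASSIFICATION_URLS


@eqx.filter_jit
def forward(net, images, key):
    # One PRNG key per image (used by stochastic layers such as dropout).
    keys = jax.random.split(key, images.shape[0])
    # Map the single-image model over the batch; axis_name="batch" lets
    # batch-normalisation layers compute statistics across the batch.
    output = jax.vmap(net, axis_name="batch")(images, key=keys)
    return output


# Load pretrained AlexNet weights (requires torch).
net = alexnet(torch_weights=CLASSIFICATION_URLS["alexnet"])
images = jr.uniform(jr.PRNGKey(0), shape=(1, 3, 224, 224))
output = forward(net, images, jr.PRNGKey(0))
```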
- FCN, DeepLabV3 and LRASPP added as new image segmentation models.
- v0.2.0: … for loading a pretrained model (pretrained weights are passed through the torch_weights argument, as in the example above).
- … torchvision.
- Adversarial examples and others coming soon.

Start with any one of the easy-to-follow tutorials.
- Use @equinox.filter_jit instead of @jax.jit
- Use jax.{v,p}map with axis_name='batch' with models that contain batch normalisation
- Switch to inference mode for evaluation (model = eqx.tree_inference(model, value=True))
- Initialise Optax optimisers with optim.init(eqx.filter(net, eqx.is_array))
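Why axis_name='batch' matters: inside vmap each call sees a single example, so a batch-normalisation layer can only get batch statistics through a collective over a named axis. Layers such as Equinox's BatchNorm take an axis_name argument for this. The function below is a hand-rolled, hypothetical illustration of the mechanism (not Equinox's or eqxvision's implementation); the name given to vmap must match the name the layer reduces over.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def batch_normalise(x, eps=1e-5):
    """Hypothetical per-example function that normalises with batch statistics.

    `x` is a single example of shape (features,). The mean and variance are
    computed with collectives over the axis named "batch", which only exists
    when the caller vmaps with axis_name="batch".
    """
    mean = jax.lax.pmean(x, axis_name="batch")
    var = jax.lax.pmean((x - mean) ** 2, axis_name="batch")
    return (x - mean) / jnp.sqrt(var + eps)


# A batch of 32 examples with 4 features, mean about 3 and std about 2.
x = 3.0 + 2.0 * jr.normal(jr.PRNGKey(0), (32, 4))

# vmap binds the axis name that pmean reduces over.
y = jax.vmap(batch_normalise, axis_name="batch")(x)
print(jnp.round(y.mean(axis=0), 4))  # approximately zero for every feature
```

Calling jax.vmap(batch_normalise)(x) without axis_name leaves "batch" unbound, and JAX raises an error instead of silently normalising each example on its own.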
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
If you plan to modify the code or documentation, please follow the steps below:
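1. Fork the repository and create your branch from dev.
2. If you have changed the documentation, make sure it builds: mkdocs serve
3. Make sure the test suite passes: pytest tests -vvv
4. Make sure your code passes the formatting checks run by the pre-commit hook.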