paganpasta / eqxvision

A Python package of computer vision models for the Equinox ecosystem.
MIT License

Eqxvision is a package of popular computer vision model architectures built using Equinox.


Use the package manager pip to install eqxvision.

pip install eqxvision

requires: python>=3.7

optional: torch, only if pretrained models are required.
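Because torch is only needed for pretrained weights, a script can check for it before asking for them. The following is a minimal sketch that uses a hypothetical helper, `torch_available()`, which is not part of eqxvision. It relies only on the standard-library `importlib.util.find_spec`.

```python
import importlib.util
import sys


def torch_available() -> bool:
    """Hypothetical helper: True if the optional torch dependency can be found."""
    # find_spec returns None when the top-level package is not installed,
    # without actually importing it.
    return importlib.util.find_spec("torch") is not None


# eqxvision requires python>=3.7.
if sys.version_info < (3, 7):
    raise RuntimeError("eqxvision requires python>=3.7")

if torch_available():
    print("torch found: pretrained weights can be loaded")
else:
    print("torch not found: only randomly initialized models are available")
```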


Available at


Picking a model and doing a forward pass is as simple as ...

    import jax
    import jax.random as jr
    import equinox as eqx
    from eqxvision.models import alexnet
    from eqxvision.utils import CLASSIFICATION_URLS

    @eqx.filter_jit
    def forward(net, images, key):
        # one PRNG key per image, consumed by stochastic layers such as dropout
        keys = jr.split(key, images.shape[0])
        # map the single-image model over the leading (batch) axis
        output = jax.vmap(net, axis_name="batch")(images, key=keys)
        return output

    # loads pretrained PyTorch weights, so torch must be installed
    net = alexnet(torch_weights=CLASSIFICATION_URLS['alexnet'])

    images = jr.uniform(jr.PRNGKey(0), shape=(1, 3, 224, 224))
    output = forward(net, images, jr.PRNGKey(0))
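The forward pass relies on two JAX idioms. First, one PRNG key is split into one key per image, so stochastic layers such as dropout draw independent randomness for each example. Second, the vmapped axis is given a name, presumably so that layers needing batch statistics (such as batch normalization) can run collectives over it. The following minimal, pure-JAX sketch shows both idioms. It uses a hypothetical `toy_net` in place of a real eqxvision model.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def toy_net(x, *, key):
    """Hypothetical stand-in for a model's single-example __call__(x, *, key)."""
    # Dropout (p=0.5) driven by this example's own key.
    keep = jr.bernoulli(key, p=0.5, shape=x.shape)
    x = jnp.where(keep, x / 0.5, 0.0)
    # Collective over the named vmap axis: the mean of per-example means.
    batch_mean = jax.lax.pmean(x.mean(), axis_name="batch")
    return x - batch_mean


@jax.jit
def forward(images, key):
    # One key per image, so each example gets independent randomness.
    keys = jr.split(key, images.shape[0])
    # Keyword arguments to a vmapped function are mapped over their leading axis.
    return jax.vmap(toy_net, axis_name="batch")(images, key=keys)


images = jr.uniform(jr.PRNGKey(0), shape=(4, 3, 8, 8))
out = forward(images, jr.PRNGKey(1))
print(out.shape)  # (4, 3, 8, 8)
```

Plain `jax.jit` is enough here because only arrays are passed in. When an Equinox model is an argument, `eqx.filter_jit` is the usual choice, because the model may contain non-array leaves.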

What's New?

Get Started!

Start with any one of these easy-to-follow tutorials.



Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Development Process

If you plan to modify the code or documentation, please follow the steps below:

  1. Fork the repository and create your branch from dev.
  2. If you have modified the code (new feature or bug-fix), please add unit tests.
  3. If you have changed APIs, update the documentation and make sure it builds by running mkdocs serve.
  4. Ensure the test suite passes by running pytest tests -vvv.
  5. Make sure your code passes the formatting checks. These are run automatically by a pre-commit hook.
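Steps 3 to 5 can be run together before opening a pull request. The following is a minimal sketch that uses a hypothetical `run_checks` helper, which is not part of the repository. It makes two assumptions. It uses `mkdocs build` in place of `mkdocs serve`, because `serve` starts a blocking development server. It also assumes that `pre-commit run --all-files` is how the hook is triggered manually.

```python
import subprocess

# Hypothetical pre-submission checks mirroring steps 3 to 5.
CHECKS = [
    ["mkdocs", "build"],                  # documentation builds
    ["pytest", "tests", "-vvv"],          # test suite passes
    ["pre-commit", "run", "--all-files"], # formatting checks pass
]


def run_checks(run=subprocess.run):
    """Run each check in order, stopping at the first failure.

    Returns the failing command, or None if every check succeeded.
    `run` is injectable so the logic can be exercised without the real tools.
    """
    for cmd in CHECKS:
        result = run(cmd)
        if result.returncode != 0:
            return cmd
    return None
```

Calling `run_checks()` from the repository root runs the real commands. These require mkdocs, pytest and pre-commit to be installed.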

