A Flax (Linen) implementation of ResNet (He et al. 2015), Wide ResNet (Zagoruyko & Komodakis 2016), ResNeXt (Xie et al. 2017), ResNet-D (He et al. 2020), and ResNeSt (Zhang et al. 2020). The code is modular so you can mix and match the various stem, residual, and bottleneck implementations.
You can install this package from PyPI:

```sh
pip install jax-resnet
```

Or directly from GitHub:

```sh
pip install --upgrade git+https://github.com/n2cholas/jax-resnet.git
```
See the bottom of `jax-resnet/resnet.py` for the available aliases and options for the ResNet variants (all models are in Flax).
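Pretrained checkpoints from torch.hub are available for the following networks:

- ResNet [18, 34, 50, 101, 152]
- Wide ResNet [50, 101]
- ResNeXt [50, 101]
- ResNeSt [50-Fast, 50, 101, 200, 269]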
The models are tested to have the same intermediate activations and outputs as the torch.hub implementations, except ResNeSt-50 Fast, whose activations don't match exactly but whose final accuracy does.
A pretrained checkpoint for ResNetD-50 is available from fast.ai. The activations do not match exactly, but the final accuracy matches.
```python
import jax.numpy as jnp
from jax_resnet import pretrained_resnest

# Returns the model constructor together with the converted pretrained variables.
ResNeSt50, variables = pretrained_resnest(50)
model = ResNeSt50()
out = model.apply(variables,
                  jnp.ones((32, 224, 224, 3)),  # ImageNet-sized NHWC inputs.
                  mutable=False)  # Ensure `batch_stats` aren't updated.
```
You must install PyTorch yourself (see the official PyTorch installation instructions) to use these functions.
To extract a subset of the model, you can use `Sequential(model.layers[start:end])`. The `slice_variables` function (found in `common.py`) extracts the corresponding subset of the variables dict. Check out its docstring for more information.
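The index bookkeeping behind this kind of slicing can be sketched without Flax. The snippet below is a hypothetical, dependency-free illustration: `slice_layer_variables` is not part of jax-resnet, and for real models you should use `slice_variables` from `common.py`. It assumes each variable collection (e.g. `params`, `batch_stats`) maps submodule names of the form `layers_<i>` to that layer's variables, as Flax names the entries of a `layers` list attribute, and that `start`/`end` are non-negative. Layers without variables (such as activations) simply have no entry.

```python
from typing import Any, Dict, Mapping, Optional


def slice_layer_variables(variables: Mapping[str, Mapping[str, Any]],
                          start: int = 0,
                          end: Optional[int] = None) -> Dict[str, Dict[str, Any]]:
    """Hypothetical sketch: keep layers start..end-1 and renumber them from 0.

    Assumes every collection maps names like 'layers_3' to per-layer variables,
    and that start/end are non-negative (no Python-style negative indices).
    """
    sliced: Dict[str, Dict[str, Any]] = {}
    for collection, layers in variables.items():
        kept: Dict[str, Any] = {}
        for name, layer_vars in layers.items():
            prefix, _, index = name.rpartition('_')  # 'layers_3' -> ('layers', '3')
            i = int(index)
            if i >= start and (end is None or i < end):
                # The sliced Sequential's first layer is layers_0, so shift indices.
                kept[f'{prefix}_{i - start}'] = layer_vars
        sliced[collection] = kept
    return sliced


# Toy variables dict: layer 2 (e.g. an activation) has no variables.
variables = {
    'params': {'layers_0': 'conv0', 'layers_1': 'bn1', 'layers_3': 'conv3'},
    'batch_stats': {'layers_1': 'bn1_stats'},
}
print(slice_layer_variables(variables, start=1, end=4))
# {'params': {'layers_0': 'bn1', 'layers_2': 'conv3'}, 'batch_stats': {'layers_0': 'bn1_stats'}}
```

The renumbering is the key design point: a model built from `model.layers[1:4]` starts counting its own layers at 0, so the variable names must be shifted by `start` to line up.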
The top-1 and top-5 accuracies reported below are on the ImageNet 2012 validation split. The data was preprocessed as in the official PyTorch ImageNet example.
| Model | Size | Top 1 | Top 5 |
|---|---|---|---|
| ResNet | 18 | 69.75% | 89.06% |
| | 34 | 73.29% | 91.42% |
| | 50 | 76.13% | 92.86% |
| | 101 | 77.37% | 93.53% |
| | 152 | 78.30% | 94.04% |
| Wide ResNet | 50 | 78.48% | 94.08% |
| | 101 | 78.88% | 94.29% |
| ResNeXt | 50 | 77.60% | 93.70% |
| | 101 | 79.30% | 94.51% |
| ResNet-D | 50 | 77.57% | 93.85% |
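For reference, the evaluation transform in the official PyTorch ImageNet example resizes the shorter image side to 256 pixels, center-crops 224x224, scales pixel values to [0, 1], and normalizes with the ImageNet per-channel means and standard deviations. Below is a minimal NumPy sketch of the last three steps. It assumes the resize has already been done (e.g. with an image library), the helper name is hypothetical, and unlike torchvision it does not handle images smaller than the crop. The output stays in HWC layout, so a stacked batch is NHWC like the `jnp.ones((32, 224, 224, 3))` input in the usage example.

```python
import numpy as np

# Per-channel RGB statistics used by the official PyTorch ImageNet example.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def center_crop_and_normalize(image: np.ndarray, crop: int = 224) -> np.ndarray:
    """Hypothetical helper: center-crop an HWC uint8 RGB image and normalize it.

    Assumes the image was already resized so its shorter side is 256 pixels.
    """
    height, width, _ = image.shape
    # Round the offsets to mirror torchvision's CenterCrop.
    top = int(round((height - crop) / 2.0))
    left = int(round((width - crop) / 2.0))
    patch = image[top:top + crop, left:left + crop].astype(np.float32) / 255.0
    return (patch - IMAGENET_MEAN) / IMAGENET_STD  # broadcasts over the channel axis


# Stack crops into an NHWC batch, the layout the Flax models take.
images = [np.zeros((256, 341, 3), dtype=np.uint8) for _ in range(2)]
batch = np.stack([center_crop_and_normalize(im) for im in images])
print(batch.shape)  # (2, 224, 224, 3)
```

PyTorch models consume NCHW tensors, so preprocessed batches taken from a torchvision pipeline would need a transpose to NHWC before being passed to these models.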
The ResNeSt validation data was preprocessed as in zhanghang1989/ResNeSt.
| Model | Size | Crop Size | Top 1 | Top 5 |
|---|---|---|---|---|
| ResNeSt-Fast | 50 | 224 | 80.53% | 95.34% |
| ResNeSt | 50 | 224 | 81.05% | 95.42% |
| | 101 | 256 | 82.82% | 96.32% |
| | 200 | 320 | 83.84% | 96.86% |
| | 269 | 416 | 84.53% | 96.98% |
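Both tables report top-1 and top-5 accuracy: under top-k, a prediction counts as correct if the true label is among the k highest-scoring classes. A minimal NumPy sketch of the metric follows; the function name is hypothetical and not part of jax-resnet, and `jax.numpy` offers the same `argsort`, `any`, and `mean` calls if you want to compute it on device.

```python
import numpy as np


def top_k_accuracy(logits: np.ndarray, labels: np.ndarray, k: int) -> float:
    """Fraction of examples whose true label is among the k largest logits.

    logits: (batch, num_classes) scores; labels: (batch,) integer class ids.
    """
    # argsort is ascending, so the last k columns hold the top-k class ids.
    top_k = np.argsort(logits, axis=-1)[:, -k:]
    hits = np.any(top_k == labels[:, None], axis=-1)
    return float(np.mean(hits))


logits = np.array([[0.1, 0.5, 0.4],
                   [0.7, 0.2, 0.1]])
labels = np.array([2, 0])
print(top_k_accuracy(logits, labels, k=1))  # 0.5
print(top_k_accuracy(logits, labels, k=2))  # 1.0
```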