Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
Hello Nicholas, while using a pretrained ResNet-101, I compared the output size of the PyTorch ResNet after layer4 (i.e., the output just before average pooling) for an input batch of one 224×224 image, and it was torch.Size([1, 2048, 28, 28]). However, when I tried to get the same output from your ResNet model in JAX/Flax (I commented out the two lines at the end of the ResNet function, shown below, to get the output before average pooling, which is equivalent to layer4 in PyTorch):

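# Excerpt of the library's ResNet factory; ModuleDef, ConvBlock, ResNetStem,
# and Sequential are defined elsewhere in the jax-resnet package.
from functools import partial
from typing import Callable, Optional, Sequence

import flax.linen as nn

def ResNet(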
    block_cls: ModuleDef,
    stage_sizes: Sequence[int],
    n_classes: int,
    hidden_sizes: Sequence[int] = (64, 128, 256, 512),
    conv_cls: ModuleDef = nn.Conv,
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9),
    conv_block_cls: ModuleDef = ConvBlock,
    stem_cls: ModuleDef = ResNetStem,
    pool_fn: Callable = partial(nn.max_pool,
                                window_shape=(3, 3),
                                strides=(2, 2),
                                padding=((1, 1), (1, 1))),
) -> Sequential:
    conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
    stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls)
    block_cls = partial(block_cls, conv_block_cls=conv_block_cls)

    layers = [stem_cls(), pool_fn] 

    for i, (hsize, n_blocks) in enumerate(zip(hidden_sizes, stage_sizes)):
        for b in range(n_blocks):
            strides = (1, 1) if i == 0 or b != 0 else (2, 2)
            layers.append(block_cls(n_hidden=hsize, strides=strides))
    # Commented out to return the layer4 feature map instead of logits:
    # layers.append(partial(jnp.mean, axis=(1, 2)))  # global average pool
    # layers.append(nn.Dense(n_classes))
    return Sequential(layers)

It produces a different output shape for the same input shape, (1, 224, 224, 3):

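import jax
import jax.numpy as jnp
from jax_resnet import pretrained_resnet

# pretrained_resnet returns a model constructor and its pretrained variables.
ResNet101, variables = pretrained_resnet(101)
model = ResNet101()
model_out = model.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False)
print("pretrained resnet101 output shape:", jax.tree_util.tree_map(lambda x: x.shape, model_out))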

The output shape is (1, 7, 7, 2048). So what is happening at this stage of the ResNet layer structure? I would appreciate any explanation or recommendations.
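The 7×7 result can be reproduced with plain stride arithmetic. The sketch below is an illustration, not library code: it assumes the stem is a 7×7, stride-2 convolution with padding 3 (the standard ResNet stem), uses the max-pool settings from `pool_fn` above, models each residual step as a 3×3 convolution with padding 1, and applies the same stride rule as the loop in `ResNet()` (only the first block of stages 2 to 4 downsamples). `conv_out` and `resnet_feature_sizes` are hypothetical helper names.

```python
from typing import List, Sequence, Tuple


def conv_out(size: int, kernel: int, stride: int, pad: int) -> int:
    """Output size of a convolution or pooling window along one spatial axis."""
    return (size + 2 * pad - kernel) // stride + 1


def resnet_feature_sizes(input_size: int, stage_sizes: Sequence[int]) -> List[Tuple[str, int]]:
    """Spatial size after the stem + max pool and after each residual stage."""
    size = conv_out(input_size, kernel=7, stride=2, pad=3)  # assumed 7x7, stride-2 stem conv
    size = conv_out(size, kernel=3, stride=2, pad=1)        # 3x3, stride-2 max pool (pool_fn)
    sizes = [("pool", size)]
    for i, n_blocks in enumerate(stage_sizes):
        for b in range(n_blocks):
            # Same rule as the ResNet() loop: only the first block of stages 2-4 has stride 2.
            stride = 1 if i == 0 or b != 0 else 2
            size = conv_out(size, kernel=3, stride=stride, pad=1)  # 3x3 conv, padding 1
        sizes.append((f"stage{i + 1}", size))
    return sizes


print(resnet_feature_sizes(224, stage_sizes=(3, 4, 23, 3)))
# [('pool', 56), ('stage1', 56), ('stage2', 28), ('stage3', 14), ('stage4', 7)]
```

For ResNet-101 (`stage_sizes=(3, 4, 23, 3)`), the overall output stride is 32, and 224 / 32 = 7, which matches the Flax output.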

Hi @sarahelsherif, thanks for raising this issue! Could you also paste the PyTorch code that gives you torch.Size([1, 2048, 28, 28]) for comparison?

Thank you @n2cholas, here is the PyTorch code:

from typing import Dict

from torch import Tensor, nn
from torchvision.models import ResNet
from torchvision.models.feature_extraction import create_feature_extractor


class RESNET_Layer_4(nn.Module):
    """Wraps a feature extractor that returns the layer4 output under the key "out"."""

    def __init__(self, backbone: nn.Module) -> None:
        super().__init__()
        self.backbone = backbone

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        backbone_resnet = self.backbone(x)
        print("backbone resnet output shape", backbone_resnet["out"].shape)
        return backbone_resnet


def resnet4(backbone: ResNet) -> RESNET_Layer_4:
    # Map torchvision's "layer4" node to the output key "out".
    return_layers = {"layer4": "out"}
    backbone = create_feature_extractor(backbone, return_layers)
    return RESNET_Layer_4(backbone)

The input batch shape is torch.Size([1, 3, 224, 224]).

from torchvision.models import resnet101

pretrained_resnet = resnet101(pretrained=False, replace_stride_with_dilation=[False, True, True])
r4 = resnet4(pretrained_resnet)
r4 = r4.cuda()

The output is: backbone resnet output shape torch.Size([1, 2048, 28, 28])

Hi @sarahelsherif, I wasn't able to use the code you sent directly, since I do not have create_feature_extractor. Instead, see this example of extracting the backbone in both JAX and PyTorch here. As you can see, they produce the same output shape.

Does that help?
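When comparing shapes across the two frameworks, note that they also order axes differently: PyTorch reports feature maps as (N, C, H, W), while Flax's `nn.Conv` works on channels-last (N, H, W, C) arrays by default. A minimal sketch for putting a PyTorch-shaped array into Flax layout before comparing; `nchw_to_nhwc` is a hypothetical helper name:

```python
import jax.numpy as jnp


def nchw_to_nhwc(x: jnp.ndarray) -> jnp.ndarray:
    """Move the channel axis last: PyTorch (N, C, H, W) -> Flax (N, H, W, C)."""
    return jnp.transpose(x, (0, 2, 3, 1))


torch_style = jnp.zeros((1, 2048, 28, 28))  # shape reported by the dilated PyTorch backbone
print(nchw_to_nhwc(torch_style).shape)      # (1, 28, 28, 2048)
```

So torch.Size([1, 2048, 28, 28]) corresponds to (1, 28, 28, 2048) in Flax layout; it is the spatial size, not the axis order, that differs from (1, 7, 7, 2048).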

Hey @n2cholas, first of all, thank you so much for your help. create_feature_extractor is a utility from TorchVision, which can be imported like this:

from torchvision.models.feature_extraction import create_feature_extractor

And thank you for your example; it helped. I now know why the output shape is different: my PyTorch ResNet replaces the strides of the last two stages with dilation:

pretrained_resnet = resnet101(pretrained=False, replace_stride_with_dilation=[False, True, True])
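Why `replace_stride_with_dilation=[False, True, True]` yields 28×28 follows from the bookkeeping in torchvision's `ResNet._make_layer`: when a stage is dilated, its stride is set to 1 and the running dilation is multiplied by the stride that was removed, while the first block of that stage still uses the previous dilation. The sketch below mirrors that bookkeeping in plain Python, assuming the standard stem plus max pool (overall stride 4 before the first stage); `dilated_stage_schedule` is a hypothetical name, not a torchvision function.

```python
from typing import List, Sequence, Tuple


def dilated_stage_schedule(
    replace_stride_with_dilation: Sequence[bool] = (False, False, False),
) -> Tuple[List[Tuple[int, int, int]], int]:
    """Mirror torchvision's ResNet._make_layer stride/dilation bookkeeping.

    Returns one (stride, first_block_dilation, later_block_dilation) tuple per
    stage, plus the overall output stride of the network.
    """
    dilation = 1
    output_stride = 4  # stride-2 stem conv followed by a stride-2 max pool
    schedule = [(1, 1, 1)]  # stage 1 (layer1) never downsamples
    for dilate in replace_stride_with_dilation:  # stages 2-4 (layer2..layer4)
        stride = 2
        previous_dilation = dilation
        if dilate:
            dilation *= stride  # trade the stride for a larger dilation
            stride = 1
        output_stride *= stride
        schedule.append((stride, previous_dilation, dilation))
    return schedule, output_stride


schedule, output_stride = dilated_stage_schedule([False, True, True])
print(schedule)              # [(1, 1, 1), (2, 1, 1), (1, 1, 2), (1, 2, 4)]
print(224 // output_stride)  # 28
```

With `[False, True, True]`, only stage 2 downsamples, so the output stride is 8 and 224 / 8 = 28; stages 3 and 4 run with dilations 2 and 4 instead. Without any replacement the output stride is 32, which gives 7.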

So my issue with the different output shapes is solved. On the other hand, I would be grateful if you could suggest a way to replace strides with dilation in JAX.

This can definitely be supported; essentially, we would need to apply the logic in torchvision's _make_layer to the ResNetBottleneckBlock. I won't have the bandwidth to work on this for a few weeks, but I would be happy to review a PR if you decide to implement it.

Yes, sure. Thank you so much for your help! I will update you when I implement it.