sarahelsherif opened this issue 2 years ago
Hi @sarahelsherif, thanks for raising this issue! Could you also paste in the PyTorch code that gives you `torch.Size([1, 2048, 28, 28])` for comparison?
Thank you @n2cholas, OK, here is the PyTorch code:
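```python
from typing import Dict

import torch.nn as nn
from torch import Tensor
from torchvision.models import ResNet
from torchvision.models.feature_extraction import create_feature_extractor


class RESNET_Layer_4(nn.Module):
    """Wraps a feature extractor that returns the output of ResNet's layer4."""

    def __init__(self, backbone: nn.Module) -> None:
        super(RESNET_Layer_4, self).__init__()
        self.backbone = backbone

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        backbone_resnet = self.backbone(x)
        print("backbone resnet output shape", backbone_resnet["out"].shape)
        return backbone_resnet


def resnet4(backbone: ResNet) -> RESNET_Layer_4:
    # Expose the output of layer4 (before avgpool and fc) under the key "out"
    return_layers = {"layer4": "out"}
    backbone = create_feature_extractor(backbone, return_layers)
    return RESNET_Layer_4(backbone)
```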
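The input batch (`inp_batch`) has shape `torch.Size([1, 3, 224, 224])`:

```python
from torchvision.models import resnet101

# pretrained=False: the weights are randomly initialized
pretrained_resnet = resnet101(pretrained=False, replace_stride_with_dilation=[False, True, True])
r4 = resnet4(pretrained_resnet)
r4 = r4.cuda()
out_resnet4 = r4(inp_batch)  # inp_batch must be on the same device as the model
```

The output is: `backbone resnet output shape torch.Size([1, 2048, 28, 28])`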
Hi @sarahelsherif, I wasn't able to use the code you sent directly, since I do not have `create_feature_extractor`. Instead, see this example of extracting the backbone in both JAX and PyTorch here. As you can see, they have the same output shape.
Does that help?
Hey @n2cholas, first of all, thank you so much for your help.
About `create_feature_extractor`: it is a utility from TorchVision, which can be imported like this:

```python
from torchvision.models.feature_extraction import create_feature_extractor
```
And thank you for your example, it helped. I now understand why the output shape is different: in PyTorch I replaced strides with dilation in the ResNet:

```python
pretrained_resnet = resnet101(pretrained=False, replace_stride_with_dilation=[False, True, True])
```
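To see where the factor of 4 between 7 and 28 comes from, here is a minimal sketch that tracks only the spatial size through a torchvision-style ResNet. It assumes the standard stem (7x7 conv with stride 2, then a 3x3 max-pool with stride 2) and a stride-2 3x3 conv at the start of layer2, layer3 and layer4; the helper names `conv_out` and `layer4_spatial_size` are hypothetical, not part of any library.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int, dilation: int = 1) -> int:
    # Output-size formula for a convolution or pooling layer (floor convention)
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def layer4_spatial_size(size: int, replace_stride_with_dilation=(False, False, False)) -> int:
    # Stem: 7x7 conv (stride 2, padding 3), then 3x3 max-pool (stride 2, padding 1)
    size = conv_out(size, kernel=7, stride=2, padding=3)
    size = conv_out(size, kernel=3, stride=2, padding=1)
    # layer1 has stride 1, so it keeps the spatial size
    dilation = 1
    for replace in replace_stride_with_dilation:  # layer2, layer3, layer4
        if replace:
            # Stride 2 is traded for a doubled dilation; padding = dilation keeps the size
            dilation *= 2
            size = conv_out(size, kernel=3, stride=1, padding=dilation, dilation=dilation)
        else:
            size = conv_out(size, kernel=3, stride=2, padding=dilation, dilation=dilation)
    return size


print(layer4_spatial_size(224))                       # 7
print(layer4_spatial_size(224, (False, True, True)))  # 28
```

With the default strides the overall output stride is 32 (224 / 32 = 7); replacing the strides of layer3 and layer4 with dilation lowers it to 8 (224 / 8 = 28), while the receptive field keeps growing through the larger dilation.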
So my issue about the different output shapes is now solved. However, I would be grateful if you could suggest a way to replace strides with dilation in JAX.
This can definitely be supported: essentially, we would need to apply the logic in `_make_layer` to the `ResNetBottleneckBlock`. I won't have the bandwidth to work on this for a few weeks, but I would be happy to review a PR if you decide to implement it.
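The stride/dilation bookkeeping that torchvision's `_make_layer` performs can be expressed as a small, framework-independent helper. The sketch below is hypothetical (`block_configs` is not part of either library) and returns, for every bottleneck block, the stride and dilation its 3x3 conv would use: a dilated stage trades its stride of 2 for a doubled dilation, its first block keeps the previous dilation, and the remaining blocks use the new one.

```python
from typing import List, Sequence, Tuple


def block_configs(
    stage_sizes: Sequence[int],
    replace_stride_with_dilation: Sequence[bool] = (False, False, False),
) -> List[List[Tuple[int, int]]]:
    """Hypothetical helper: (stride, dilation) of each block's 3x3 conv, per stage."""
    dilation = 1
    stages = []
    for i, num_blocks in enumerate(stage_sizes):
        stride = 1 if i == 0 else 2  # layer1 keeps resolution; layers 2-4 downsample
        previous_dilation = dilation
        if i > 0 and replace_stride_with_dilation[i - 1]:
            dilation *= stride  # trade the stride for a larger dilation
            stride = 1
        # First block: (possibly reduced) stride with the previous dilation
        blocks = [(stride, previous_dilation)]
        # Remaining blocks: stride 1 with the current dilation
        blocks += [(1, dilation)] * (num_blocks - 1)
        stages.append(blocks)
    return stages


# ResNet-101 has 3, 4, 23 and 3 bottleneck blocks in layer1..layer4
configs = block_configs([3, 4, 23, 3], (False, True, True))
print(configs[3])  # [(1, 2), (1, 4), (1, 4)]
```

Assuming the JAX block's 3x3 conv is a `flax.linen.Conv`, these values could be passed as its `strides` and `kernel_dilation` arguments, with explicit padding equal to the dilation on each side so that a stride-1 conv preserves the spatial size; the block's 1x1 downsampling conv would use the same stride.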
Yes, sure. Thank you so much for your help; I will update you when I implement it.
Hello Nicholas, while using a pretrained ResNet-101, I am comparing the output size of the PyTorch ResNet model after layer 4 (i.e. the output before the average pooling). Running it on an input batch of size [1, 224, 224, 3] gives torch.Size([1, 2048, 28, 28]). However, when I tried to get the same output from your JAX/Flax ResNet model (I removed these 2 commented lines in the RESNET function to get the output before the average pooling, equivalent to PyTorch's layer4), it has a different output shape for the same inp_batch size (1, 224, 224, 3):

pretrained resnet100 size:--> (1, 7, 7, 2048)

So, what happens at this stage in the ResNet layer structure? Kindly reply if you have any explanation or recommendations.
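Note that the two frameworks also use different tensor layouts: the Flax output above is channels-last (NHWC), while PyTorch reports channels-first (NCHW). A minimal sketch for comparing the two, using a dummy NumPy array as a hypothetical stand-in for the Flax output (`jax.numpy.transpose` takes the same axes argument):

```python
import numpy as np

# Hypothetical stand-in for the Flax output: same shape, dummy values
jax_out = np.zeros((1, 7, 7, 2048), dtype=np.float32)  # N, H, W, C

# Reorder axes N, H, W, C -> N, C, H, W to match PyTorch's layout
as_nchw = np.transpose(jax_out, (0, 3, 1, 2))
print(as_nchw.shape)  # (1, 2048, 7, 7)
```

After this reordering, the remaining mismatch is only in the spatial size (7 versus 28), not in the channel count.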