TL-System / plato

A federated learning framework to support scalable and reproducible research
Apache License 2.0

[FR] The implementation of Resnet in Plato needs to be revised #214

Closed CSJDeveloper closed 2 years ago

CSJDeveloper commented 2 years ago

Is your feature request related to a problem? Please describe.

It is strange that Plato's ResNet implementation, plato/models/resnet.py, supports only CIFAR-style datasets with an input size of 32. The pooling layer before the fc layer is F.avg_pool2d, which pools with a fixed kernel size and is therefore less flexible than adaptive pooling.

Maybe the current implementation exists to support 'cut_layer'? Even so, there is no need to reimplement the model ourselves, since torchvision already provides many models. To remove a layer, we can simply replace it with nn.Identity(), which leaves the rest of the model untouched.

Describe the solution you'd like

The most effective approach is to follow torchvision's ResNet implementation, which uses:

    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

to support any input size.
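The difference can be checked with a little arithmetic. The sketch below is pure Python and rests on assumptions not stated in this issue: a CIFAR-style ResNet-18 layout (3x3 stride-1 stem, no max pooling, stage strides 1, 2, 2, 2, 512 output channels) and a fixed pooling kernel of 4. All function names are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of a conv or pooling layer (floor mode, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def final_feature_map(input_size):
    """Side length of the last feature map under the assumed CIFAR-style layout."""
    size = conv_out(input_size, kernel=3, stride=1, padding=1)  # 3x3 stem
    for stride in (1, 2, 2, 2):  # first 3x3 conv of each of the four stages
        size = conv_out(size, kernel=3, stride=stride, padding=1)
    return size


def flattened_features(input_size, channels=512, fixed_kernel=None):
    """Number of features reaching the fc layer.

    fixed_kernel=None models adaptive pooling to 1x1; an integer models
    average pooling with that kernel size and an equal stride.
    """
    size = final_feature_map(input_size)
    if fixed_kernel is None:
        pooled = 1
    else:
        pooled = conv_out(size, kernel=fixed_kernel, stride=fixed_kernel, padding=0)
    return channels * pooled * pooled


for side in (32, 64, 96):
    print(side, flattened_features(side, fixed_kernel=4), flattened_features(side))
```

Under these assumptions, only a 32x32 input produces the 512 features that a 512-input fc layer expects with fixed pooling, while adaptive pooling yields 512 for every input size.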

I know the current ResNet implementation replaces the 7x7 kernel of conv1 with a 3x3 kernel and removes the first max pooling layer, to preserve spatial information for small inputs such as the 32x32 images in CIFAR10.

However, the same effect is easy to achieve with torchvision's implementation by setting:

    encoder.conv1 = nn.Conv2d(3,
                              64,
                              kernel_size=3,
                              stride=1,
                              padding=2,
                              bias=False)
    encoder.maxpool = nn.Identity()

Describe alternatives you've considered

Just reuse torchvision's implementations.

Additional context

In my own work, I use torchvision's implementations directly and modify the models slightly to fit my own requirements. This works well on many different datasets without errors. See models/encoders_register.py in the contrastive_adaptation branch.

baochunli commented 2 years ago

With the recently redesigned model registry, torchvision's ResNet model can be directly used by loading it from torch.hub:

./run -c configs/CIFAR10/fedavg_resnet18_torchhub.yml

Once torchvision releases 0.14 officially, its get_model() can also be used to load torchvision's model directly.