TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License
1.09k stars 95 forks source link

set_scheduler does not work using VotingClassifier #33

Closed Alex-Medium closed 3 years ago

Alex-Medium commented 3 years ago

Hi, I have used torchensemble for classification on CIFAR-10, and it looks like there is a problem with set_scheduler.

I have followed the examples and used 5 resnet-18 models. The results on FusionClassifier are fine, over 95% on CIFAR-10. However, VotingClassifier only achieves a testing accuracy of 90%.

After looking at the source codes on VotingClassifier, I think the problem is related to set_scheduler, could you help fix it?

Sincerely

xuyxu commented 3 years ago

Thanks for reporting @Alex-Medium. Could you provide a running script on your experiment, so that we can reproduce the problem?

Alex-Medium commented 3 years ago

Sure.

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchensemble.fusion import FusionClassifier
from torchensemble.voting import VotingClassifier
from torchensemble.utils.logging import set_logger

class BasicBlock(nn.Module):
    """Residual basic block: two 3x3 convs plus a (possibly projected) skip."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        # Main path: conv -> BN -> relu -> conv -> BN (final relu after the add).
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # The skip is the identity unless the spatial size or channel count
        # changes, in which case a 1x1 projection matches the shapes.
        out_planes = self.expansion * planes
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return relu(mainpath(x) + shortcut(x))."""
        main = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(main + self.shortcut(x))

class ResNet(nn.Module):
    """ResNet backbone for 32x32 inputs (CIFAR-style 3x3 stem, no max-pool)."""

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        # Stem: a single 3x3 conv that keeps the 32x32 resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Four stages; stages 2-4 halve the spatial size via stride 2.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        blocks = []
        for s in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, s))
            # Subsequent blocks see the expanded channel count.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, num_classes) logits."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        # 4x4 average pool collapses the final 4x4 feature map.
        out = F.avg_pool2d(out, 4)
        return self.linear(out.view(out.size(0), -1))

if __name__ == "__main__":

    # Hyper-parameters
    n_estimators = 5
    lr = 1e-1
    weight_decay = 5e-4
    epochs = 200
    n_jobs = n_estimators
    method = "voting"

    # Utils
    batch_size = 128
    data_dir = "../../Dataset/cifar"  # MODIFY THIS IF YOU WANT
    records = []
    torch.manual_seed(0)
    torch.cuda.set_device(0)

    # Load data
    train_transformer = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ]
    )

    test_transformer = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ]
    )

    train_loader = DataLoader(
        datasets.CIFAR10(data_dir, train=True, download=True,
                         transform=train_transformer),
        batch_size=batch_size,
        shuffle=True,
    )

    test_loader = DataLoader(
        datasets.CIFAR10(data_dir, train=False, transform=test_transformer),
        batch_size=batch_size,
        shuffle=True,
    )

    logger = set_logger("classification_cifar10_cnn")

    if method == "fusion":
        model = FusionClassifier(
            estimator=ResNet,
            estimator_args={"block": BasicBlock, "num_blocks": [2, 2, 2, 2]},
            n_estimators=n_estimators,
            cuda=True
        )
    elif method == "voting":
        model = VotingClassifier(
            estimator=ResNet,
            estimator_args={"block": BasicBlock, "num_blocks": [2, 2, 2, 2]},
            n_estimators=n_estimators,
            cuda=True,
            n_jobs=n_estimators
        )
    else:
        raise NotImplementedError

    model.set_optimizer("SGD", lr=lr, weight_decay=weight_decay, momentum=0.9)
    model.set_scheduler("CosineAnnealingLR", T_max=epochs)

    model.fit(train_loader,
              epochs=epochs,
              test_loader=test_loader)
zzzzwj commented 3 years ago

@Alex-Medium OK, I'll check it soon.

Alex-Medium commented 3 years ago

@zzzzwj Maybe the optimizer referenced by the scheduler no longer exists after the parallelization, right?

xuyxu commented 3 years ago

I agree with you @Alex-Medium ;-)

The deepcopy operation on the side of joblib may have corrupted the bindings between optimizers and schedulers in PyTorch. Zhang is currently working on this. Hopefully this bug will be fixed soon.

xuyxu commented 3 years ago

Hi @Alex-Medium, this issue has been fixed in PR #37. Feel free to open another issue if you have any problem ;-)

Alex-Medium commented 3 years ago

Great, thanks!