HomebrewNLP / revlib

Simple and efficient RevNet-Library for PyTorch with XLA and DeepSpeed support and parameter offload
https://github.com/HomebrewNLP/revlib
BSD 2-Clause "Simplified" License

Suggestions for constructing a ResNet with the revlib #5

Closed: taokz closed this issue 2 years ago

taokz commented 2 years ago

Hi

I would like to use this library to build a ResNet20 model. I've tried several times, but I still get a mismatched-dimension error. My model is shown below:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm

import revlib  # provides revlib.ReversibleSequential, used below

hidden_size = [16, 32, 64]

class View(nn.Module):
    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride, norm_layer, conv_layer, option='A'):
        super(BasicBlock, self).__init__()

        self.bn1 = norm_layer(in_planes)
        self.conv1 = conv_layer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = norm_layer(planes)
        self.conv2 = conv_layer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = conv_layer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        print(1, shortcut.size())
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        print(2, out.size())
        out += shortcut
        return out

class ResNet(nn.Module):

    def __init__(self, hidden_size, block, num_blocks, num_classes=10, bn_type='bn',
                 share_affine=False, track_running_stats=True):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.bn_type = bn_type
        if bn_type == 'bn':
            norm_layer = lambda n_ch: nn.BatchNorm2d(n_ch, track_running_stats=track_running_stats)
        elif bn_type == 'gn':
            norm_layer = lambda n_ch: nn.GroupNorm(4, n_ch)  # number of groups (4 here) can be changed
        else:
            raise RuntimeError(f"Unsupported bn_type={bn_type}")
        conv_layer = nn.Conv2d
        first = conv_layer(3, hidden_size[0], kernel_size=3, stride=1, padding=1, bias=False)
        layer1 = self._make_layer(block, hidden_size[0], num_blocks[0], stride=1,
                                       norm_layer=norm_layer, conv_layer=conv_layer)
        layer2 = self._make_layer(block, hidden_size[1], num_blocks[1], stride=2,
                                       norm_layer=norm_layer, conv_layer=conv_layer)
        layer3 = self._make_layer(block, hidden_size[2], num_blocks[2], stride=2,
                                       norm_layer=norm_layer, conv_layer=conv_layer)

        self.rev_layers = revlib.ReversibleSequential(*[layer1, layer2, layer3])

        norm = norm_layer(hidden_size[2] * block.expansion)
        linear = nn.Linear(hidden_size[2] * block.expansion, num_classes)

        self.full_model = nn.Sequential(first, self.rev_layers, nn.ReLU(), norm, \
                                        nn.AdaptiveAvgPool2d((None, 1)), View(), linear)

    def _make_layer(self, block, planes, num_blocks, stride, norm_layer, conv_layer):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, norm_layer, conv_layer))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.full_model(x)
        return out

def init_param(m):
    """Special init for ResNet"""
    if isinstance(m, (_BatchNorm, _InstanceNorm)):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.bias.data.zero_()
    return m

def resnet20(**kwargs):
    model = ResNet(hidden_size, BasicBlock, [3,3,3], **kwargs)
    model.apply(init_param)
    return model

I've tried setting self.in_planes = 8 and hidden_size = [8, 16, 32], respectively, but it still does not work. Could you provide any hints? Also, is it possible to build a model in a forward way, instead of wrapping the reversible model with non-reversible layers like model = nn.Sequential(conv, rev_layer, conv)? I appreciate your help.

ClashLuke commented 2 years ago

For RevNet to work, you must ensure that twice as many features get passed in as each branch uses. Otherwise, RevLib can't split the input into two equal-sized inputs, one for each of the two branches.

To do this, change the output channels of first from hidden_size[0] to hidden_size[0] * 2. Optionally, you could also tell RevLib to feed the same input into both branches by setting split_dim=None when constructing the ReversibleSequential module.

Keep in mind that not only do you need twice as many features in the input, but you also get twice as many features in the output. So the final norm and linear both need to work on twice as many features as well.
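
For reference, here is a minimal, self-contained sketch (simplified branch modules, not your BasicBlock) of the sizing this implies: ReversibleSequential splits its input into two equal streams along the channel dimension, each branch module takes and returns half of the channels, and the concatenated output again carries twice as many features, so the head must be sized accordingly.

import torch
import torch.nn as nn
import revlib

channels = 16
branch = lambda: nn.Conv2d(channels, channels, 3, padding=1, bias=False)  # each branch works on `channels` features
stem = nn.Conv2d(3, channels * 2, 3, padding=1, bias=False)               # stem emits 2x features so the input can be split
rev = revlib.ReversibleSequential(branch(), branch(), branch())
head = nn.Linear(channels * 2, 10)                                        # the output also carries 2x features

x = torch.randn(2, 3, 32, 32)
out = rev(stem(x))
print(out.shape)                         # torch.Size([2, 32, 32, 32])
print(head(out.mean(dim=(2, 3))).shape)  # torch.Size([2, 10])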

Unfortunately, I'm unsure what you mean by "build a model in a forward way." Could you elaborate?

taokz commented 2 years ago

@ClashLuke Thank you a lot for your reply.

Doubling the output channels of first raises a RuntimeError:

File ~\Anaconda3\lib\site-packages\revlib\core.py:183, in additive_coupling_forward(other_stream, fn_out)
    181 fn_out = split_tensor_list(fn_out)
    182 if isinstance(fn_out, torch.Tensor):
--> 183     return other_stream + fn_out
    184 return [other_stream + fn_out[0]] + fn_out[1]

RuntimeError: The size of tensor a (32) must match the size of tensor b (16) at non-singleton dimension 3

I've printed the output sizes in the forward of BasicBlock: layer 1 and layer 2 work fine, but I cannot print the output of layer 3.

There is another problem caused by the doubling: the number of parameters of the RevNet no longer matches the original ResNet. I may not fully understand this library and how RevNet works, and may be implementing RevNet with revlib in the wrong way.

Regarding building a model in a forward way, I would like to know whether I can stack revlib.ReversibleSequential() modules, specifically,

layer1 = self._make_layer(...); layer2 = self._make_layer(...); layer3 = self._make_layer(...)
rev_layers = nn.Sequential(layer1, layer2, layer3)

or

def forward(x):
    out = layer1(x)
    out = layer2(out)
    out = layer3(out)
    return out

where _make_layer(...) returns revlib.ReversibleSequential(*layers). This is different from the previous approach,

self.rev_layers = revlib.ReversibleSequential(*[layer1, layer2, layer3]), where each layer* is an nn.Sequential().

I've tried this in my previous RevNet20 code and got the same RuntimeError, except that the error happens in layer 2 instead of layer 3.

ClashLuke commented 2 years ago

The top problem you faced is that RevNet requires all inputs and outputs to be the same size. As the second layer has more output features than the first, RevNet would have to add a tensor with 32 features to one with 16, which isn't possible.

Think about it like in a ResNet: there, you only have the residual path within each resolution/feature size, not across them. To carry the residual stream across, you usually downsample it (with AvgPool2d, for example) and add the result to the output of your "residual" block. In RevNet, that second step doesn't exist. Instead, you would have to use PixelShuffle and feature padding to arrive at a similar result (see #2).
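
One way to make the PixelShuffle-and-padding idea concrete is the following minimal sketch (plain PyTorch, not part of revlib; the name InvertibleDownsample is made up for this example): PixelUnshuffle, the inverse of PixelShuffle, trades spatial resolution for channels without discarding information, and zero-padding the channel dimension afterwards brings the feature count up to whatever the next stage expects.

import torch
import torch.nn as nn
import torch.nn.functional as F

class InvertibleDownsample(nn.Module):
    # (B, C, H, W) -> (B, out_channels, H/2, W/2); the PixelUnshuffle step is undone exactly by PixelShuffle
    def __init__(self, out_channels: int):
        super().__init__()
        self.out_channels = out_channels
        self.unshuffle = nn.PixelUnshuffle(2)  # halves H and W, multiplies C by 4

    def forward(self, x):
        x = self.unshuffle(x)
        missing = self.out_channels - x.size(1)
        if missing > 0:  # zero-pad extra feature maps if the next stage needs more channels
            x = F.pad(x, (0, 0, 0, 0, 0, missing))
        return x

x = torch.randn(2, 16, 32, 32)
print(InvertibleDownsample(64)(x).shape)  # torch.Size([2, 64, 16, 16])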

The easiest way forward would be to have multiple ReversibleSequential modules, one for each _make_layer() call, and put these into a standard nn.Sequential container. This is how the original RevNet did it. Their method uses marginally more parameters but otherwise gives the same results.
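
As a rough sketch of that layout (simplified blocks rather than your BasicBlock, and the strided 1x1 transition convs are my choice here, not something revlib provides): channels stay constant inside each reversible stage, and plain non-reversible convs between the stages handle the channel and resolution changes.

import torch
import torch.nn as nn
import revlib

def rev_block(channels: int) -> nn.Module:
    # each branch of a coupling must take and return the same number of features
    return nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                         nn.BatchNorm2d(channels),
                         nn.ReLU())

def rev_stage(channels: int, depth: int) -> nn.Module:
    # a stage receives 2 * channels features and splits them into two streams of `channels` each
    return revlib.ReversibleSequential(*[rev_block(channels) for _ in range(depth)])

hidden = [16, 32, 64]
model = nn.Sequential(
    nn.Conv2d(3, hidden[0] * 2, 3, padding=1, bias=False),             # stem emits 2x features
    rev_stage(hidden[0], 3),
    nn.Conv2d(hidden[0] * 2, hidden[1] * 2, 1, stride=2, bias=False),  # non-reversible transition
    rev_stage(hidden[1], 3),
    nn.Conv2d(hidden[1] * 2, hidden[2] * 2, 1, stride=2, bias=False),  # non-reversible transition
    rev_stage(hidden[2], 3),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(hidden[2] * 2, 10))

print(model(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])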

Another alternative would be to avoid this multi-stage assembly and construct one large ReversibleSequential module instead. Using one large block saves memory, and i-RevNet documented achieving marginally worse ImageNet accuracy with this kind of architecture.


Yes, you can define the reversible architecture in forward. However, I'd advise against it, as ReversibleSequential is a thin wrapper around things you have to do anyway.

If you want to do what ReversibleSequential would usually handle for you, you'd have to wrap your modules in ReversibleModules like so: https://github.com/HomebrewNLP/revlib/blob/34dad19318e2f861ea6b0ce263506625a934b568/revlib/core.py#L471-L487

and call these modules one-by-one, just like in a normal nn.Sequential module: https://github.com/HomebrewNLP/revlib/blob/34dad19318e2f861ea6b0ce263506625a934b568/revlib/core.py#L509-L511

taokz commented 2 years ago

Thanks a lot for your detailed explanation! Now I understand why my code does not work.

I also read the source code of RevNet and i-RevNet; they provide a downsampling step to match the dimensions of other_stream and fn_out (the mismatch is caused by the change in the number of channels). If I'm not misunderstanding, it seems that your code does not provide this feature, right? Do you have a plan to add it?

ClashLuke commented 2 years ago

Sorry, I'm not planning to add these, as the most common functions (pooling, pixelshuffle, upsample) are already part of PyTorch.