Adamdad / neumeta

NeuMeta transforms neural networks by allowing a single model to adapt on the fly to different sizes, generating the right weights when needed.
35 stars 4 forks source link

Issue with Performance Drop after Permutation Smoothing on ResNet-18 #3

Open secret-hammer opened 4 days ago

secret-hammer commented 4 days ago

I'm attempting to reproduce the results of applying permutation-based parameter smoothing to the ResNet-18 model, but I'm seeing a significant performance drop after the permutation step. Below are the details of the issue, including the code I used and the experimental results.

Experiment Details:

Code:

Experiment Code:

import models
import loaders
import torch
import torch.nn as nn
import metrics
from utils.train_util import AverageMeter
from utils import permute

def test_model_base(test_loader, model_base, device):
    model_base.eval()
    acc_meter = AverageMeter('Acc', ':6.2f')

    for i, samples in enumerate(test_loader):
        inputs, labels, _ = samples
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.set_grad_enabled(False):
            outputs = model_base(inputs)
            acc = metrics.accuracy(outputs, labels)[0]
            acc_meter.update(acc.item(), inputs.size(0))

    return acc_meter

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    state_dict = torch.load('/nfs196/wjx/projects/PP/outputs/rn18_OH_Ar_base_10/Art/model.pt', weights_only=True)
    model = models.load_model('rn18', num_classes=65)
    model.load_state_dict(state_dict['last_param'])
    model = model.to(device)
    model.eval()

    test_loader = loaders.load_images("/nfs196/hjc/datasets/Office-Home/Art", 'Office-Home', data_type='test', batch_size=512)

    acc_init = test_model_base(test_loader, model, device)

    # Compute the total variation loss for the network
    total_tv = permute.compute_tv_loss_for_network(model, lambda_tv=1.0)
    print("Total Total Variation After Training:", total_tv)

    # Apply permutations to the model's layers and check the total variation
    input_tensor = torch.randn(1, 3, 224, 224).to(device)
    permute_func = permute.PermutationManager(model, input_tensor)
    permute_dict = permute_func.compute_permute_dict()
    model_permute = permute_func.apply_permutations(permute_dict, ignored_keys=[])
    total_tv = permute.compute_tv_loss_for_network(model_permute, lambda_tv=1.0)
    print("Total Total Variation After Permute:", total_tv)

    acc_permute = test_model_base(test_loader, model_permute, device)

    print("Initial accuracy: ", acc_init.avg)
    print("Accuracy after permutation: ", acc_permute.avg)

if __name__ == '__main__':
    main()

ResNet-18 Model Code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, input_channel=3):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(input_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

Observations:

Here is the output from the experiment:

Total Total Variation After Training: tensor(3745.0908, device='cuda:0', grad_fn=<DivBackward0>)
Total Total Variation After Permute: tensor(3479.0449, device='cuda:0', grad_fn=<DivBackward0>)
Initial accuracy:  68.39719817763157
Accuracy after permutation:  1.8953440410392393

Issue:

The accuracy after permutation is unexpectedly much lower than the initial accuracy. The total variation loss decreases after smoothing, but the accuracy drops significantly, which is not the expected behavior. Is there any issue with the permutation procedure or its configuration that could cause such a drastic drop in performance?

Thank you for your help!

Best,
[Secret Hammer]

Adamdad commented 3 days ago

@secret-hammer Hello. I guess the problem is that you need to do BN fuse (fuse the BN parameter into linear layer) before do permutation. We actually do not directly handle the BN layer, but fuse it into linear layer and permute the linear layer instead.

secret-hammer commented 3 days ago

@Adamdad Thank you for your answer! Is there any tool or code available to integrate the BN layer parameters into the linear layer? Thanks again.

Adamdad commented 3 days ago

This should work

def fuse_module(module):
    """
    Recursively fuse all batch normalization layers in the module with their preceding convolutional layers.
    """
    children = list(module.named_children())
    prev_name, prev_module = None, None

    for name, child in children:
        # print(name)
        if isinstance(child, nn.BatchNorm2d) and isinstance(prev_module, nn.Conv2d):
            # Fuse the conv and bn layers, replace the conv layer with the fused layer
            fused_conv = fuse_conv_bn(prev_module, child)
            module._modules[prev_name] = fused_conv

            # Remove the batch normalization layer
            module._modules[name] = nn.Identity()
        elif isinstance(child, nn.BatchNorm2d) and isinstance(prev_module, nn.ConvTranspose2d):
            # Fuse the conv and bn layers, replace the conv layer with the fused layer
            fused_conv = fuse_conv_transpose_bn(prev_module, child)
            module._modules[prev_name] = fused_conv

            # Remove the batch normalization layer
            module._modules[name] = nn.Identity()
        else:
            # Recursively apply to all submodules
            fuse_module(child)

        prev_name, prev_module = name, child
secret-hammer commented 3 days ago

Thank you for the reply. Could you provide the related code for the two core functions, fuse_conv_bn and fuse_conv_transpose_bn?

Adamdad commented 3 days ago
import torch
import torch.nn as nn
# test()

def fuse_conv_bn(conv, bn):
    """
    Fuse convolution and batch normalization layers
    """
    # Extract conv layer parameters
    conv_w = conv.weight
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)

    # Extract bn layer parameters
    bn_rm = bn.running_mean
    bn_rv = bn.running_var
    bn_eps = bn.eps
    bn_w = bn.weight
    bn_b = bn.bias

    # Calculate fused parameters
    inv_var = torch.rsqrt(bn_rv + bn_eps)
    bn_w_div_var = bn_w * inv_var
    bn_bias_sub_rm_w_div_var = bn_b - bn_rm * bn_w_div_var

    fused_conv_weight = conv_w * bn_w_div_var.view(-1, 1, 1, 1)
    fused_conv_bias = conv_b * bn_w_div_var + bn_bias_sub_rm_w_div_var

    # Create and return the fused layer
    fused_conv = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        bias=True
    )
    fused_conv.weight = nn.Parameter(fused_conv_weight)
    fused_conv.bias = nn.Parameter(fused_conv_bias)

    return fused_conv

def fuse_conv_transpose_bn(conv_transpose, bn):
    """
    Fuse ConvTranspose2d and BatchNorm2d layers.
    """
    # Extract conv transpose layer parameters
    conv_transpose_w = conv_transpose.weight
    conv_transpose_b = conv_transpose.bias if conv_transpose.bias is not None else torch.zeros_like(bn.running_mean)

    # Extract bn layer parameters
    bn_rm = bn.running_mean
    bn_rv = bn.running_var
    bn_eps = bn.eps
    bn_w = bn.weight
    bn_b = bn.bias

    # Calculate fused parameters
    inv_var = torch.rsqrt(bn_rv + bn_eps)
    bn_w_div_var = bn_w * inv_var
    bn_bias_sub_rm_w_div_var = bn_b - bn_rm * bn_w_div_var

    # print(conv_transpose_w.shape, bn_w_div_var.shape)
    fused_conv_transpose_weight = conv_transpose_w * bn_w_div_var.view(1, -1, 1, 1)
    fused_conv_transpose_bias = conv_transpose_b * bn_w_div_var + bn_bias_sub_rm_w_div_var

    # Create and return the fused layer
    fused_conv_transpose = nn.ConvTranspose2d(
        conv_transpose.in_channels,
        conv_transpose.out_channels,
        conv_transpose.kernel_size,
        conv_transpose.stride,
        conv_transpose.padding,
        conv_transpose.output_padding,
        groups=conv_transpose.groups,
        dilation=conv_transpose.dilation,
        bias=True
    )
    fused_conv_transpose.weight = nn.Parameter(fused_conv_transpose_weight)
    fused_conv_transpose.bias = nn.Parameter(fused_conv_transpose_bias)

    return fused_conv_transpose
secret-hammer commented 3 days ago

Thank you so much! I'll give it a try as soon as I can. I really appreciate your help and patient explanation!

secret-hammer commented 2 days ago

I modified the ResNet18 model to include a fusion of the Batch Normalization (BN) parameters into the preceding convolution layers. After performing this fusion, I observed the following issues:

  1. Increased Total Variation Loss: After applying the fusion, the total variation loss (tv_loss) significantly increased, which was unexpected.
  2. Smooth Models Lost Original Effect: After applying a smoothing operation to the model (using a permutation technique), the model still lose its original effect, which I cannot fully explain. This might be due to an issue in my experimental setup.

I have included the new experimental code below, which outlines the steps I followed. It includes the fusion of BatchNorm layers, the calculation of total variation loss before and after fusion, and the application of permutations to the model.

Experiment Code:

import loaders

import torch
import torch.nn as nn

import metrics
from utils.train_util import AverageMeter

from utils import permute

def test_model_base(test_loader, model_base, device):
    model_base.eval()
    acc_meter = AverageMeter('Acc', ':6.2f')

    for i, samples in enumerate(test_loader):
        inputs, labels, _ = samples
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.set_grad_enabled(False):
            outputs = model_base(inputs)
            acc = metrics.accuracy(outputs, labels)[0]
            acc_meter.update(acc.item(), inputs.size(0))

    return acc_meter

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    state_dict = torch.load('/nfs196/wjx/projects/PP/outputs/rn18_OH_Ar_base_10/Art/model.pt', weights_only=True)
    model = models.load_model('rn18', num_classes=65)
    model.load_state_dict(state_dict['last_param'])
    model = model.to(device)
    model.eval()

    test_loader = loaders.load_images("/nfs196/hjc/datasets/Office-Home/Art", 'Office-Home', data_type='test', batch_size=512)

    acc_init = test_model_base(test_loader, model, device)

    # Compute the total variation loss for the network
    total_tv = permute.compute_tv_loss_for_network(model, lambda_tv=1.0)
    print("Total Total Variation After Training:", total_tv)

    fuse_module(model)
    acc_fuse = test_model_base(test_loader, model, device)
    total_tv = permute.compute_tv_loss_for_network(model, lambda_tv=1.0)
    print("Total Total Variation After Fuse:", total_tv)

    # Apply permutations to the model's layers and check the total variation
    input_tensor = torch.randn(1, 3, 224, 224).to(device)
    permute_func = permute.PermutationManager(model, input_tensor)
    # Compute the permutation matrix for each clique graph, save as a dict
    permute_dict = permute_func.compute_permute_dict()
    # Apply permutation to the weight
    model_permute = permute_func.apply_permutations(permute_dict, ignored_keys=[])
    total_tv = permute.compute_tv_loss_for_network(model_permute, lambda_tv=1.0)
    print("Total Total Variation After Permute:", total_tv)

    acc_permute = test_model_base(test_loader, model_permute, device)

    print("Initial accuracy: ", acc_init.avg)
    print("Accuracy after fuse", acc_fuse.avg)
    print("Accuracy after permutation: ", acc_permute.avg)

def fuse_module(module):
    """
    Recursively fuse all batch normalization layers in the module with their preceding convolutional layers.
    """
    children = list(module.named_children())
    prev_name, prev_module = None, None

    for name, child in children:
        # print(name)
        if isinstance(child, nn.BatchNorm2d) and isinstance(prev_module, nn.Conv2d):
            # Fuse the conv and bn layers, replace the conv layer with the fused layer
            fused_conv = fuse_conv_bn(prev_module, child)
            module._modules[prev_name] = fused_conv

            # Remove the batch normalization layer
            module._modules[name] = nn.Identity()
        elif isinstance(child, nn.BatchNorm2d) and isinstance(prev_module, nn.ConvTranspose2d):
            # Fuse the conv and bn layers, replace the conv layer with the fused layer
            fused_conv = fuse_conv_transpose_bn(prev_module, child)
            module._modules[prev_name] = fused_conv

            # Remove the batch normalization layer
            module._modules[name] = nn.Identity()
        else:
            # Recursively apply to all submodules
            fuse_module(child)

        prev_name, prev_module = name, child

def fuse_conv_bn(conv, bn):
    """
    Fuse convolution and batch normalization layers
    """
    # Extract conv layer parameters
    conv_w = conv.weight
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)

    # Extract bn layer parameters
    bn_rm = bn.running_mean
    bn_rv = bn.running_var
    bn_eps = bn.eps
    bn_w = bn.weight
    bn_b = bn.bias

    # Calculate fused parameters
    inv_var = torch.rsqrt(bn_rv + bn_eps)
    bn_w_div_var = bn_w * inv_var
    bn_bias_sub_rm_w_div_var = bn_b - bn_rm * bn_w_div_var

    fused_conv_weight = conv_w * bn_w_div_var.view(-1, 1, 1, 1)
    fused_conv_bias = conv_b * bn_w_div_var + bn_bias_sub_rm_w_div_var

    # Create and return the fused layer
    fused_conv = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        bias=True
    )
    fused_conv.weight = nn.Parameter(fused_conv_weight)
    fused_conv.bias = nn.Parameter(fused_conv_bias)

    return fused_conv

def fuse_conv_transpose_bn(conv_transpose, bn):
    """
    Fuse ConvTranspose2d and BatchNorm2d layers.
    """
    # Extract conv transpose layer parameters
    conv_transpose_w = conv_transpose.weight
    conv_transpose_b = conv_transpose.bias if conv_transpose.bias is not None else torch.zeros_like(bn.running_mean)

    # Extract bn layer parameters
    bn_rm = bn.running_mean
    bn_rv = bn.running_var
    bn_eps = bn.eps
    bn_w = bn.weight
    bn_b = bn.bias

    # Calculate fused parameters
    inv_var = torch.rsqrt(bn_rv + bn_eps)
    bn_w_div_var = bn_w * inv_var
    bn_bias_sub_rm_w_div_var = bn_b - bn_rm * bn_w_div_var

    # print(conv_transpose_w.shape, bn_w_div_var.shape)
    fused_conv_transpose_weight = conv_transpose_w * bn_w_div_var.view(1, -1, 1, 1)
    fused_conv_transpose_bias = conv_transpose_b * bn_w_div_var + bn_bias_sub_rm_w_div_var

    # Create and return the fused layer
    fused_conv_transpose = nn.ConvTranspose2d(
        conv_transpose.in_channels,
        conv_transpose.out_channels,
        conv_transpose.kernel_size,
        conv_transpose.stride,
        conv_transpose.padding,
        conv_transpose.output_padding,
        groups=conv_transpose.groups,
        dilation=conv_transpose.dilation,
        bias=True
    )
    fused_conv_transpose.weight = nn.Parameter(fused_conv_transpose_weight)
    fused_conv_transpose.bias = nn.Parameter(fused_conv_transpose_bias)

    return fused_conv_transpose

if __name__ == '__main__':
    main()

This is the result of experiment.(Output in bash)

Total Total Variation After Training: tensor(3745.0908, device='cuda:0', grad_fn=<DivBackward0>)
Total Total Variation After Fuse: tensor(6456.5610, device='cuda:0', grad_fn=<DivBackward0>)
Total Total Variation After Permute: tensor(5958.5391, device='cuda:0', grad_fn=<DivBackward0>)
Initial accuracy:  68.39719835995744
Accuracy after fuse 68.39719835995744
Accuracy after permutation:  2.843016060281792

TSommariva commented 1 day ago
model_permute = permute_func.apply_permutations(permute_dict, ignored_keys=[])

try :

model_permute = permute_func.apply_permutations(permute_dict, ignored_keys=[('conv1.weight', 'in_channels'), ('fc.weight', 'out_channels'), ('fc.bias', 'out_channels')])

That's how they do in their code and it's working for me without any accuracy drop after the smoothing

Adamdad commented 1 day ago

@TSommariva Thank you for the answer. @secret-hammer This could be the problem, because we need to ignore the first layer and last layer, such that the order of the RGB channels and class names are not changed.