secret-hammer opened this issue 4 days ago
@secret-hammer Hello. I guess the problem is that you need to fuse the BN parameters into the preceding linear/convolution layer before doing the permutation. We do not handle the BN layer directly; instead, we fuse it into the linear layer and permute the fused linear layer.
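For reference, this is the standard BN fold (my summary, using the same names as the code below): a layer y = W x + b followed by BN with running statistics bn_rm, bn_rv and affine parameters bn_w, bn_b collapses into a single layer with

W_fused = W * bn_w / sqrt(bn_rv + bn_eps)              # scaled per output channel
b_fused = (b - bn_rm) * bn_w / sqrt(bn_rv + bn_eps) + bn_b

so after replacing (W, b) with (W_fused, b_fused), the BN layer reduces to an identity and the permutation only needs to handle the fused layer.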
@Adamdad Thank you for your answer! Is there any tool or code available to integrate the BN layer parameters into the linear layer? Thanks again.
This should work:
import torch.nn as nn

def fuse_module(module):
    """
    Recursively fuse all batch normalization layers in the module
    with their preceding convolutional layers.
    """
    children = list(module.named_children())
    prev_name, prev_module = None, None
    for name, child in children:
        if isinstance(child, nn.BatchNorm2d) and isinstance(prev_module, nn.Conv2d):
            # Fuse the conv and bn layers; replace the conv layer with the fused layer
            fused_conv = fuse_conv_bn(prev_module, child)
            module._modules[prev_name] = fused_conv
            # Replace the batch normalization layer with a no-op
            module._modules[name] = nn.Identity()
        elif isinstance(child, nn.BatchNorm2d) and isinstance(prev_module, nn.ConvTranspose2d):
            # Same fusion for a transposed convolution followed by bn
            fused_conv = fuse_conv_transpose_bn(prev_module, child)
            module._modules[prev_name] = fused_conv
            # Replace the batch normalization layer with a no-op
            module._modules[name] = nn.Identity()
        else:
            # Recursively apply to all submodules (e.g. Sequential blocks)
            fuse_module(child)
        prev_name, prev_module = name, child
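A quick usage sketch (my own example, assuming a standard torchvision ResNet-18; note that PyTorch also ships torch.nn.utils.fusion.fuse_conv_bn_eval, which performs the same Conv2d + BatchNorm2d fold for eval-mode models):

import torch
import torchvision.models as tvm

model = tvm.resnet18(num_classes=65).eval()  # fusion relies on eval-mode BN statistics
fuse_module(model)
# Every BatchNorm2d should now have been replaced by nn.Identity
assert not any(isinstance(m, torch.nn.BatchNorm2d) for m in model.modules())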
Thank you for the reply. Could you provide the code for the two core functions, fuse_conv_bn and fuse_conv_transpose_bn?
import torch
import torch.nn as nn

def fuse_conv_bn(conv, bn):
    """
    Fuse a Conv2d layer and its following BatchNorm2d layer into a single Conv2d.
    """
    # Extract conv layer parameters
    conv_w = conv.weight
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    # Extract bn layer parameters
    bn_rm = bn.running_mean
    bn_rv = bn.running_var
    bn_eps = bn.eps
    bn_w = bn.weight
    bn_b = bn.bias
    # Fold BN into the conv: y = bn_w * (conv(x) - bn_rm) / sqrt(bn_rv + eps) + bn_b
    inv_std = torch.rsqrt(bn_rv + bn_eps)  # rsqrt of the variance is 1 / std
    scale = bn_w * inv_std
    shift = bn_b - bn_rm * scale
    # Conv2d weights are (out_channels, in_channels, kH, kW): scale per output channel
    fused_conv_weight = conv_w * scale.view(-1, 1, 1, 1)
    fused_conv_bias = conv_b * scale + shift
    # Create and return the fused layer (carry groups/dilation over as well)
    fused_conv = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,
    )
    fused_conv.weight = nn.Parameter(fused_conv_weight)
    fused_conv.bias = nn.Parameter(fused_conv_bias)
    return fused_conv
def fuse_conv_transpose_bn(conv_transpose, bn):
    """
    Fuse a ConvTranspose2d layer and its following BatchNorm2d layer.
    """
    # Extract conv transpose layer parameters
    conv_transpose_w = conv_transpose.weight
    conv_transpose_b = conv_transpose.bias if conv_transpose.bias is not None else torch.zeros_like(bn.running_mean)
    # Extract bn layer parameters
    bn_rm = bn.running_mean
    bn_rv = bn.running_var
    bn_eps = bn.eps
    bn_w = bn.weight
    bn_b = bn.bias
    # Fold BN into the transposed conv, exactly as in fuse_conv_bn
    inv_std = torch.rsqrt(bn_rv + bn_eps)
    scale = bn_w * inv_std
    shift = bn_b - bn_rm * scale
    # ConvTranspose2d weights are (in_channels, out_channels, kH, kW):
    # the per-output-channel scale broadcasts over dim 1 (assumes groups == 1)
    fused_conv_transpose_weight = conv_transpose_w * scale.view(1, -1, 1, 1)
    fused_conv_transpose_bias = conv_transpose_b * scale + shift
    # Create and return the fused layer
    fused_conv_transpose = nn.ConvTranspose2d(
        conv_transpose.in_channels,
        conv_transpose.out_channels,
        conv_transpose.kernel_size,
        conv_transpose.stride,
        conv_transpose.padding,
        conv_transpose.output_padding,
        groups=conv_transpose.groups,
        dilation=conv_transpose.dilation,
        bias=True,
    )
    fused_conv_transpose.weight = nn.Parameter(fused_conv_transpose_weight)
    fused_conv_transpose.bias = nn.Parameter(fused_conv_transpose_bias)
    return fused_conv_transpose
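Note the different broadcast: Conv2d stores weights as (out_channels, in_channels, kH, kW) while ConvTranspose2d stores them as (in_channels, out_channels, kH, kW), hence view(-1, 1, 1, 1) in one case and view(1, -1, 1, 1) in the other. A quick numerical sanity check for both helpers (my own test, assuming eval mode and groups=1):

conv = nn.Conv2d(16, 32, 3, padding=1).eval()
convt = nn.ConvTranspose2d(16, 32, 3, stride=2, padding=1, output_padding=1).eval()
bn = nn.BatchNorm2d(32).eval()
bn.running_mean.uniform_(-1.0, 1.0)  # fake non-trivial running statistics
bn.running_var.uniform_(0.5, 2.0)
x = torch.randn(2, 16, 8, 8)
print(torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn).eval()(x), atol=1e-5))              # expect True
print(torch.allclose(bn(convt(x)), fuse_conv_transpose_bn(convt, bn).eval()(x), atol=1e-5))  # expect True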
Thank you so much! I'll give it a try as soon as I can. I really appreciate your help and patient explanation!
I modified the ResNet-18 model to fuse the Batch Normalization (BN) parameters into the preceding convolution layers. After performing this fusion, I observed the following issue: the total variation loss (tv_loss) significantly increased, which was unexpected. I have included the new experimental code below, which outlines the steps I followed: it fuses the BatchNorm layers, computes the total variation loss before and after fusion, and applies the permutations to the model.
import loaders
import models
import torch
import torch.nn as nn
import metrics
from utils.train_util import AverageMeter
from utils import permute

def test_model_base(test_loader, model_base, device):
    model_base.eval()
    acc_meter = AverageMeter('Acc', ':6.2f')
    for i, samples in enumerate(test_loader):
        inputs, labels, _ = samples
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.set_grad_enabled(False):
            outputs = model_base(inputs)
        acc = metrics.accuracy(outputs, labels)[0]
        acc_meter.update(acc.item(), inputs.size(0))
    return acc_meter

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dict = torch.load('/nfs196/wjx/projects/PP/outputs/rn18_OH_Ar_base_10/Art/model.pt', weights_only=True)
    model = models.load_model('rn18', num_classes=65)
    model.load_state_dict(state_dict['last_param'])
    model = model.to(device)
    model.eval()
    test_loader = loaders.load_images("/nfs196/hjc/datasets/Office-Home/Art", 'Office-Home', data_type='test', batch_size=512)
    acc_init = test_model_base(test_loader, model, device)
    # Compute the total variation loss for the network
    total_tv = permute.compute_tv_loss_for_network(model, lambda_tv=1.0)
    print("Total Total Variation After Training:", total_tv)
    fuse_module(model)
    acc_fuse = test_model_base(test_loader, model, device)
    total_tv = permute.compute_tv_loss_for_network(model, lambda_tv=1.0)
    print("Total Total Variation After Fuse:", total_tv)
    # Apply permutations to the model's layers and check the total variation
    input_tensor = torch.randn(1, 3, 224, 224).to(device)
    permute_func = permute.PermutationManager(model, input_tensor)
    # Compute the permutation matrix for each clique graph, saved as a dict
    permute_dict = permute_func.compute_permute_dict()
    # Apply the permutations to the weights
    model_permute = permute_func.apply_permutations(permute_dict, ignored_keys=[])
    total_tv = permute.compute_tv_loss_for_network(model_permute, lambda_tv=1.0)
    print("Total Total Variation After Permute:", total_tv)
    acc_permute = test_model_base(test_loader, model_permute, device)
    print("Initial accuracy: ", acc_init.avg)
    print("Accuracy after fuse", acc_fuse.avg)
    print("Accuracy after permutation: ", acc_permute.avg)
def fuse_module(module):
    """
    Recursively fuse all batch normalization layers in the module
    with their preceding convolutional layers.
    """
    children = list(module.named_children())
    prev_name, prev_module = None, None
    for name, child in children:
        if isinstance(child, nn.BatchNorm2d) and isinstance(prev_module, nn.Conv2d):
            # Fuse the conv and bn layers; replace the conv layer with the fused layer
            fused_conv = fuse_conv_bn(prev_module, child)
            module._modules[prev_name] = fused_conv
            # Replace the batch normalization layer with a no-op
            module._modules[name] = nn.Identity()
        elif isinstance(child, nn.BatchNorm2d) and isinstance(prev_module, nn.ConvTranspose2d):
            # Same fusion for a transposed convolution followed by bn
            fused_conv = fuse_conv_transpose_bn(prev_module, child)
            module._modules[prev_name] = fused_conv
            # Replace the batch normalization layer with a no-op
            module._modules[name] = nn.Identity()
        else:
            # Recursively apply to all submodules (e.g. Sequential blocks)
            fuse_module(child)
        prev_name, prev_module = name, child
def fuse_conv_bn(conv, bn):
    """
    Fuse a Conv2d layer and its following BatchNorm2d layer into a single Conv2d.
    """
    # Extract conv layer parameters
    conv_w = conv.weight
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    # Extract bn layer parameters
    bn_rm = bn.running_mean
    bn_rv = bn.running_var
    bn_eps = bn.eps
    bn_w = bn.weight
    bn_b = bn.bias
    # Fold BN into the conv: y = bn_w * (conv(x) - bn_rm) / sqrt(bn_rv + eps) + bn_b
    inv_std = torch.rsqrt(bn_rv + bn_eps)  # rsqrt of the variance is 1 / std
    scale = bn_w * inv_std
    shift = bn_b - bn_rm * scale
    # Conv2d weights are (out_channels, in_channels, kH, kW): scale per output channel
    fused_conv_weight = conv_w * scale.view(-1, 1, 1, 1)
    fused_conv_bias = conv_b * scale + shift
    # Create and return the fused layer (carry groups/dilation over as well)
    fused_conv = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,
    )
    fused_conv.weight = nn.Parameter(fused_conv_weight)
    fused_conv.bias = nn.Parameter(fused_conv_bias)
    return fused_conv
def fuse_conv_transpose_bn(conv_transpose, bn):
    """
    Fuse a ConvTranspose2d layer and its following BatchNorm2d layer.
    """
    # Extract conv transpose layer parameters
    conv_transpose_w = conv_transpose.weight
    conv_transpose_b = conv_transpose.bias if conv_transpose.bias is not None else torch.zeros_like(bn.running_mean)
    # Extract bn layer parameters
    bn_rm = bn.running_mean
    bn_rv = bn.running_var
    bn_eps = bn.eps
    bn_w = bn.weight
    bn_b = bn.bias
    # Fold BN into the transposed conv, exactly as in fuse_conv_bn
    inv_std = torch.rsqrt(bn_rv + bn_eps)
    scale = bn_w * inv_std
    shift = bn_b - bn_rm * scale
    # ConvTranspose2d weights are (in_channels, out_channels, kH, kW):
    # the per-output-channel scale broadcasts over dim 1 (assumes groups == 1)
    fused_conv_transpose_weight = conv_transpose_w * scale.view(1, -1, 1, 1)
    fused_conv_transpose_bias = conv_transpose_b * scale + shift
    # Create and return the fused layer
    fused_conv_transpose = nn.ConvTranspose2d(
        conv_transpose.in_channels,
        conv_transpose.out_channels,
        conv_transpose.kernel_size,
        conv_transpose.stride,
        conv_transpose.padding,
        conv_transpose.output_padding,
        groups=conv_transpose.groups,
        dilation=conv_transpose.dilation,
        bias=True,
    )
    fused_conv_transpose.weight = nn.Parameter(fused_conv_transpose_weight)
    fused_conv_transpose.bias = nn.Parameter(fused_conv_transpose_bias)
    return fused_conv_transpose
if __name__ == '__main__':
    main()
These are the experimental results (console output):
Total Total Variation After Training: tensor(3745.0908, device='cuda:0', grad_fn=<DivBackward0>)
Total Total Variation After Fuse: tensor(6456.5610, device='cuda:0', grad_fn=<DivBackward0>)
Total Total Variation After Permute: tensor(5958.5391, device='cuda:0', grad_fn=<DivBackward0>)
Initial accuracy: 68.39719835995744
Accuracy after fuse 68.39719835995744
Accuracy after permutation: 2.843016060281792
Instead of:

model_permute = permute_func.apply_permutations(permute_dict, ignored_keys=[])

try:

model_permute = permute_func.apply_permutations(permute_dict, ignored_keys=[('conv1.weight', 'in_channels'), ('fc.weight', 'out_channels'), ('fc.bias', 'out_channels')])

That's how they do it in their code, and it works for me with no accuracy drop after the smoothing.
@TSommariva Thank you for the answer. @secret-hammer This could be the problem: we need to ignore the first and last layers, so that the order of the RGB input channels and of the class outputs is not changed.
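As a tiny self-contained illustration (hypothetical numbers, not from the repo): permuting the fc layer's output channels reorders the logits, so argmax predictions no longer line up with the original class label indices.

import torch

logits = torch.tensor([[2.0, 0.1, -1.0]])  # the unpermuted model predicts class 0
perm = torch.tensor([2, 0, 1])             # some permutation of the output channels
print(logits.argmax(dim=1))                # tensor([0])
print(logits[:, perm].argmax(dim=1))       # tensor([1]): no longer matches the label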
I'm attempting to reproduce the results of applying permutation-based parameter smoothing to the ResNet-18 model, but I'm seeing a significant performance drop after the permutation step. Below are the details of the issue, including the code I used and the experimental results.

Experiment Details:
- The trained weights are loaded from a checkpoint (model.pt).
- Accuracy is evaluated with test_model_base.
- Permutations are computed and applied with PermutationManager.

Code: the experiment code and the ResNet-18 model setup are shown above.

Observations:
The output from the experiment is shown above.

Issue:
The accuracy after permutation is unexpectedly much lower than the initial accuracy. The total variation loss decreases after smoothing, but the accuracy drops significantly, which is not the expected behavior. Is there any issue with the permutation procedure or its configuration that could cause such a drastic drop in performance?

Thank you for your help!
Best,
[Secret Hammer]