narger-ef / LowMemoryFHEResNet20

Source code for the paper "Encrypted Image Classification with Low Memory Footprint using Fully Homomorphic Encryption"
https://eprint.iacr.org/2024/460
MIT License

Scale factor of convbn and relu #14

Closed · minnow54426 closed 1 month ago

minnow54426 commented 1 month ago

Thank you very much for your method of fitting the input of each ReLU to [-1, 1]. I wanted to compute the ranges myself and compare them against the table at the bottom of page 16, but my results appear to be different. Here they are:

init_layer_relu      min value: -2.649200201034546   max value:  3.3062973022460938   scale factor: 0.30245313974658655
layer1_block0_relu0  min value: -3.8802597522735596  max value:  2.5891757011413574   scale factor: 0.2577147056750699
layer1_block0_relu1  min value: -3.4869346618652344  max value:  5.6907958984375      scale factor: 0.1757223449666445
layer1_block1_relu0  min value: -3.72189998626709    max value:  2.535496473312378    scale factor: 0.2686799762728064
layer1_block1_relu1  min value: -1.9287317991256714  max value:  5.924444198608398    scale factor: 0.16879220505357978
layer1_block2_relu0  min value: -3.0873782634735107  max value:  2.1809089183807373   scale factor: 0.32389941065236755
layer1_block2_relu1  min value: -2.9409966468811035  max value:  5.998693466186523    scale factor: 0.16670296717723732
layer2_block0_relu0  min value: -3.268646240234375   max value:  3.3809189796447754   scale factor: 0.2957775699508385
layer2_block0_relu1  min value: -3.1084036827087402  max value:  4.885541915893555    scale factor: 0.2046855839567804
layer2_block1_relu0  min value: -2.2072482109069824  max value:  1.8167043924331665   scale factor: 0.4530528080433188
layer2_block1_relu1  min value: -2.1643271446228027  max value:  5.152408123016357    scale factor: 0.1940840042412194
layer2_block2_relu0  min value: -2.6352195739746094  max value:  2.7355856895446777   scale factor: 0.36555243135755844
layer2_block2_relu1  min value: -2.8530592918395996  max value:  7.528661727905273    scale factor: 0.13282573133727893
layer3_block0_relu0  min value: -2.3090929985046387  max value:  2.7618658542633057   scale factor: 0.362074066144946
layer3_block0_relu1  min value: -3.213973045349121   max value:  4.215287208557129    scale factor: 0.23723175919542974
layer3_block1_relu0  min value: -2.548794746398926   max value:  3.0651886463165283   scale factor: 0.3262441942037438
layer3_block1_relu1  min value: -4.599428176879883   max value:  7.215237617492676    scale factor: 0.13859557411880555
layer3_block2_relu0  min value: -2.864499568939209   max value:  2.170339822769165    scale factor: 0.34910111729230364
layer3_block2_relu1  min value: -7.797353267669678   max value: 18.783254623413086    scale factor: 0.05323890987206833

My code follows; you can run it freely if you have torch installed (the inference pass is there to make sure the structure is correct).

# Build ResNet block by block
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

ResNet = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
ResNet.eval()

class ResNetPlain(torch.nn.Module):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # For each ReLU, the first value is the running minimum and the second the running maximum
        self.relu_dict = {
            "init_layer_relu": [0, 0],
            "layer1_block0_relu0": [0, 0], 
            "layer1_block0_relu1": [0, 0],
            "layer1_block1_relu0": [0, 0], 
            "layer1_block1_relu1": [0, 0],
            "layer1_block2_relu0": [0, 0], 
            "layer1_block2_relu1": [0, 0],
            "layer2_block0_relu0": [0, 0], 
            "layer2_block0_relu1": [0, 0],
            "layer2_block1_relu0": [0, 0], 
            "layer2_block1_relu1": [0, 0],
            "layer2_block2_relu0": [0, 0], 
            "layer2_block2_relu1": [0, 0],
            "layer3_block0_relu0": [0, 0], 
            "layer3_block0_relu1": [0, 0],
            "layer3_block1_relu0": [0, 0], 
            "layer3_block1_relu1": [0, 0],
            "layer3_block2_relu0": [0, 0], 
            "layer3_block2_relu1": [0, 0],
        }
        self.relu = torch.nn.ReLU(inplace=True)
        self.avgPool = ResNet.avgpool
        self.fc = ResNet.fc
        # Init layer
        self.initLayer_conv = ResNet.conv1
        self.initLayer_bn = ResNet.bn1
        # Layer1 block0
        self.layer1_block0_conv1 = ResNet.layer1[0].conv1
        self.layer1_block0_bn1 = ResNet.layer1[0].bn1
        self.layer1_block0_conv2 = ResNet.layer1[0].conv2
        self.layer1_block0_bn2 = ResNet.layer1[0].bn2
        # Layer1 block1
        self.layer1_block1_conv1 = ResNet.layer1[1].conv1
        self.layer1_block1_bn1 = ResNet.layer1[1].bn1
        self.layer1_block1_conv2 = ResNet.layer1[1].conv2
        self.layer1_block1_bn2 = ResNet.layer1[1].bn2
        # Layer1 block2
        self.layer1_block2_conv1 = ResNet.layer1[2].conv1
        self.layer1_block2_bn1 = ResNet.layer1[2].bn1
        self.layer1_block2_conv2 = ResNet.layer1[2].conv2
        self.layer1_block2_bn2 = ResNet.layer1[2].bn2
        # Layer2 block0
        self.layer2_block0_conv1 = ResNet.layer2[0].conv1
        self.layer2_block0_bn1 = ResNet.layer2[0].bn1
        self.layer2_block0_conv2 = ResNet.layer2[0].conv2
        self.layer2_block0_bn2 = ResNet.layer2[0].bn2
        self.downsample0_conv = ResNet.layer2[0].downsample[0]
        self.downsample0_bn = ResNet.layer2[0].downsample[1]
        # Layer2 block1
        self.layer2_block1_conv1 = ResNet.layer2[1].conv1
        self.layer2_block1_bn1 = ResNet.layer2[1].bn1
        self.layer2_block1_conv2 = ResNet.layer2[1].conv2
        self.layer2_block1_bn2 = ResNet.layer2[1].bn2
        # Layer2 block2
        self.layer2_block2_conv1 = ResNet.layer2[2].conv1
        self.layer2_block2_bn1 = ResNet.layer2[2].bn1
        self.layer2_block2_conv2 = ResNet.layer2[2].conv2
        self.layer2_block2_bn2 = ResNet.layer2[2].bn2
        # Layer3 block0
        self.layer3_block0_conv1 = ResNet.layer3[0].conv1
        self.layer3_block0_bn1 = ResNet.layer3[0].bn1
        self.layer3_block0_conv2 = ResNet.layer3[0].conv2
        self.layer3_block0_bn2 = ResNet.layer3[0].bn2
        self.downsample1_conv = ResNet.layer3[0].downsample[0]
        self.downsample1_bn = ResNet.layer3[0].downsample[1]
        # Layer3 block1
        self.layer3_block1_conv1 = ResNet.layer3[1].conv1
        self.layer3_block1_bn1 = ResNet.layer3[1].bn1
        self.layer3_block1_conv2 = ResNet.layer3[1].conv2
        self.layer3_block1_bn2 = ResNet.layer3[1].bn2
        # Layer3 block2
        self.layer3_block2_conv1 = ResNet.layer3[2].conv1
        self.layer3_block2_bn1 = ResNet.layer3[2].bn1
        self.layer3_block2_conv2 = ResNet.layer3[2].conv2
        self.layer3_block2_bn2 = ResNet.layer3[2].bn2

    def update_min(self, location: str, x: torch.Tensor):
        # Track the running minimum seen at this ReLU input
        self.relu_dict[location][0] = min(self.relu_dict[location][0], x.min().item())

    def update_max(self, location: str, x: torch.Tensor):
        # Track the running maximum seen at this ReLU input
        self.relu_dict[location][1] = max(self.relu_dict[location][1], x.max().item())

    def update(self, location: str, x: torch.Tensor):
        self.update_min(location, x)
        self.update_max(location, x)

    def forward(self, x):
        # Init layer
        x = self.initLayer_conv(x)
        x = self.initLayer_bn(x)
        self.update("init_layer_relu", x)
        x = self.relu(x)
        # Layer1 block0
        x_copy = x
        x = self.layer1_block0_conv1(x)
        x = self.layer1_block0_bn1(x)
        self.update("layer1_block0_relu0", x)
        x = self.relu(x)
        x = self.layer1_block0_conv2(x)
        x = self.layer1_block0_bn2(x)
        x = x + x_copy
        self.update("layer1_block0_relu1", x)
        x = self.relu(x)
        # Layer1 block1
        x_copy = x
        x = self.layer1_block1_conv1(x)
        x = self.layer1_block1_bn1(x)
        self.update("layer1_block1_relu0", x)
        x = self.relu(x)
        x = self.layer1_block1_conv2(x)
        x = self.layer1_block1_bn2(x)
        x = x + x_copy
        self.update("layer1_block1_relu1", x)
        x = self.relu(x)
        # Layer1 block2
        x_copy = x
        x = self.layer1_block2_conv1(x)
        x = self.layer1_block2_bn1(x)
        self.update("layer1_block2_relu0", x)
        x = self.relu(x)
        x = self.layer1_block2_conv2(x)
        x = self.layer1_block2_bn2(x)
        x = x + x_copy
        self.update("layer1_block2_relu1", x)
        x = self.relu(x)
        # Layer2 block0
        x_copy = x
        x = self.layer2_block0_conv1(x)
        x = self.layer2_block0_bn1(x)
        self.update("layer2_block0_relu0", x)
        x = self.relu(x)
        x = self.layer2_block0_conv2(x)
        x = self.layer2_block0_bn2(x)
        x_copy = self.downsample0_conv(x_copy)
        x_copy = self.downsample0_bn(x_copy)
        x += x_copy
        self.update("layer2_block0_relu1", x)
        x = self.relu(x)
        # Layer2 block1
        x_copy = x
        x = self.layer2_block1_conv1(x)
        x = self.layer2_block1_bn1(x)
        self.update("layer2_block1_relu0", x)
        x = self.relu(x)
        x = self.layer2_block1_conv2(x)
        x = self.layer2_block1_bn2(x)
        x += x_copy
        self.update("layer2_block1_relu1", x)
        x = self.relu(x)
        # Layer2 block2
        x_copy = x
        x = self.layer2_block2_conv1(x)
        x = self.layer2_block2_bn1(x)
        self.update("layer2_block2_relu0", x)
        x = self.relu(x)
        x = self.layer2_block2_conv2(x)
        x = self.layer2_block2_bn2(x)
        x += x_copy
        self.update("layer2_block2_relu1", x)
        x = self.relu(x)
        # Layer3 block0
        x_copy = x
        x = self.layer3_block0_conv1(x)
        x = self.layer3_block0_bn1(x)
        self.update("layer3_block0_relu0", x)
        x = self.relu(x)
        x = self.layer3_block0_conv2(x)
        x = self.layer3_block0_bn2(x)
        x_copy = self.downsample1_conv(x_copy)
        x_copy = self.downsample1_bn(x_copy)
        x += x_copy
        self.update("layer3_block0_relu1", x)
        x = self.relu(x)
        # Layer3 block1
        x_copy = x
        x = self.layer3_block1_conv1(x)
        x = self.layer3_block1_bn1(x)
        self.update("layer3_block1_relu0", x)
        x = self.relu(x)
        x = self.layer3_block1_conv2(x)
        x = self.layer3_block1_bn2(x)
        x += x_copy
        self.update("layer3_block1_relu1", x)
        x = self.relu(x)
        # Layer3 block2
        x_copy = x
        x = self.layer3_block2_conv1(x)
        x = self.layer3_block2_bn1(x)
        self.update("layer3_block2_relu0", x)
        x = self.relu(x)
        x = self.layer3_block2_conv2(x)
        x = self.layer3_block2_bn2(x)
        x += x_copy
        self.update("layer3_block2_relu1", x)
        x = self.relu(x)
        # Final layer
        x = self.avgPool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def print(self):
        # Report the observed range and the scale factor 1 / max(|min|, |max|) for each ReLU
        for key, value in self.relu_dict.items():
            abs_max = max(abs(value[0]), abs(value[1]))
            print(key, "  min value: ", value[0], ", max value: ", value[1], ", scale factor: ", 1 / abs_max)

ResNetPlainInstance = ResNetPlain()
ResNetPlainInstance.eval()

# Load CIFAR10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    # Per-channel mean and standard deviation for CIFAR10
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
testSet = torchvision.datasets.CIFAR10(root="./ResNet/data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testSet, batch_size=64, shuffle=False)

# Function for top1 and top5 accuracy
def accuracy(output, target, topk=(1, )):
    with torch.no_grad():
        max_k = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(max_k, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            acc = (correct_k / batch_size).item()
            res.append(acc)
    return res

top1_acc = top5_acc = 0
# Inference
with torch.no_grad():
    for images, labels in tqdm(testloader, desc='Inference Progress', ncols=100):
        outputs = ResNetPlainInstance(images)
        acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
        top1_acc += acc1
        top5_acc += acc5

top1_acc = (top1_acc / len(testloader)) * 100
top5_acc = (top5_acc / len(testloader)) * 100

print(f'Top-1 Accuracy: {top1_acc:.2f}%')
print(f'Top-5 Accuracy: {top5_acc:.2f}%')

ResNetPlainInstance.print()
narger-ef commented 1 month ago

Are you checking the ranges with the test set?

minnow54426 commented 1 month ago

Yes, with all 10,000 test images; when building testSet, I set train=False.

narger-ef commented 1 month ago

Did you check this notebook? This is how we found the intervals.

minnow54426 commented 1 month ago

Yes, it seems we both loop over all the test images, but it's weird that we get different results. I opened this issue because, if the ReLU is approximated by Chebyshev polynomials on [-1, 1] and the scale factor does not fit, the decode function raises an error:

libc++abi: terminating due to uncaught exception of type lbcrypto::OpenFHEException: /Users/lihao/code/c++/openfhe-development/src/pke/lib/encoding/ckkspackedencoding.cpp:l.537:Decode(): The decryption failed because the approximation error is too high. Check the parameters.

I can set the ReLU range to [-2, 2] to suppress this, but that requires an extra level. Another weird thing: in layer1, if I only perform layer1 block0 convbn1 and its ReLU, the answer is correct, but if layer1 block0 convbn2 and another ReLU are appended, the answer is wrong even when the ReLU range is [-2, 2]. Do you have any suggestions to solve this?
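
For anyone hitting the same exception, here is a small numpy sketch (my own illustration, independent of the repository's C++ code) of why an unscaled input breaks the approximation: a Chebyshev fit of ReLU is only accurate on the interval it was fitted on, and a polynomial of this degree diverges rapidly outside it, which CKKS then reports as a too-high approximation error at decode time.

# Illustration only: why ReLU inputs must be scaled into the Chebyshev interval.
# Assumes numpy; this is not the repository's implementation.
import numpy as np
from numpy.polynomial import chebyshev as C

def relu(x):
    return np.maximum(x, 0.0)

# Least-squares fit of a degree-15 Chebyshev series to ReLU on [-1, 1]
xs = np.linspace(-1.0, 1.0, 200)
coeffs = C.chebfit(xs, relu(xs), deg=15)

inside = np.array([-0.8, -0.2, 0.3, 0.9])   # inputs within the fitted interval
outside = np.array([1.5, 2.0, 3.0])         # e.g. ReLU inputs with a wrong scale factor

print("error inside [-1, 1]:", np.abs(C.chebval(inside, coeffs) - relu(inside)))
print("error outside       :", np.abs(C.chebval(outside, coeffs) - relu(outside)))

# The scale factor 1 / max(|min|, |max|) from the table above is exactly what
# maps each ReLU input into [-1, 1] before the polynomial is evaluated.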

minnow54426 commented 1 month ago

Oops, I found that there was a normalization in the transforms applied to my images; after removing this normalization we get the same result, thanks!

minnow54426 commented 1 month ago

Hello, have you noticed that without normalization the performance of this model degrades? In more detail, if normalization is used, the accuracy is

Top-1 Accuracy: 92.61%
Top-5 Accuracy: 99.81%

which matches the result on this page: https://github.com/chenyaofo/pytorch-cifar-models/tree/master. If we remove the normalization, the accuracy becomes

Top-1 Accuracy: 32.67%
Top-5 Accuracy: 76.49%

Run the following code to reproduce these two results, with or without the normalization in the transforms:

# Measure accuracy of ResNet20 on CIFAR10
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

# Load ResNet model
ResNet = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
ResNet.eval()

# Load CIFAR10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    # Per-channel mean and standard deviation for CIFAR10 (commented out to reproduce the degraded accuracy)
    # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
testSet = torchvision.datasets.CIFAR10(root="./ResNet/data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testSet, batch_size=64, shuffle=False)

# Function for top1 and top5 accuracy
def accuracy(output, target, topk=(1, )):
    with torch.no_grad():
        max_k = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(max_k, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            acc = (correct_k / batch_size).item()
            res.append(acc)
    return res

top1_acc = top5_acc = 0
# Inference
with torch.no_grad():
    for images, labels in tqdm(testloader, desc='Inference Progress', ncols=100):
        outputs = ResNet(images)
        acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
        top1_acc += acc1
        top5_acc += acc5

top1_acc = (top1_acc / len(testloader)) * 100
top5_acc = (top5_acc / len(testloader)) * 100

print(f'Top-1 Accuracy: {top1_acc:.2f}%')
print(f'Top-5 Accuracy: {top5_acc:.2f}%')
minnow54426 commented 1 month ago

@narger-ef

narger-ef commented 1 month ago

That's interesting, and expected, since the network was trained with normalization :-) Thank you for the experiment!
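
For completeness: since Normalize is just a per-channel affine map, it can in principle be folded into the first convolution, so raw [0, 1] images reach the same accuracy without a separate normalization step. A minimal sketch of that idea (my quick illustration, not necessarily how the repository handles it; note that zero-padding makes border pixels differ slightly from the true Normalize pipeline):

# Sketch: fold CIFAR10's Normalize(mean, std) into conv1 of the pretrained model.
# Illustration only; exact in the image interior, approximate at zero-padded borders.
import torch

mean = torch.tensor([0.4914, 0.4822, 0.4465])
std = torch.tensor([0.2023, 0.1994, 0.2010])

model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
model.eval()

with torch.no_grad():
    w = model.conv1.weight                  # shape (16, 3, 3, 3); conv1 has no bias
    # conv(W, (x - m) / s) == conv(W / s, x) - sum(W * m / s) per output channel
    folded = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=True)
    folded.weight.copy_(w / std.view(1, 3, 1, 1))
    folded.bias.copy_(-(w * (mean / std).view(1, 3, 1, 1)).sum(dim=(1, 2, 3)))
    model.conv1 = folded

# model(x) on raw [0, 1] images now matches the normalized pipeline (up to border effects).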