data61 / MP-SPDZ

Versatile framework for multi-party computation
Other
899 stars 278 forks source link

How to implement ResNet? #1076

Closed voidwings closed 2 months ago

voidwings commented 1 year ago

I tried to write ResNet-18 as follows:

program.options_from_args()

from Compiler import ml

# Optional program argument at index 2: number of threads for ml.
try:
    n_threads = int(program.args[2])
except (IndexError, ValueError):
    pass  # argument missing or not an integer: keep ml's default
else:
    ml.set_n_threads(n_threads)

import numpy
import torchvision


def get_data(train, transform=None):
    """Return the CIFAR-10 training (train=True) or test split, downloading it to /tmp."""
    return torchvision.datasets.CIFAR10(
        root='/tmp', train=train, download=True, transform=transform)


# Input both splits via party 0: samples as sfix, labels as one-hot sint.
data = []
for train in (True, False):
    ds = get_data(train)
    # normalize to [-1,1] before input
    samples = sfix.input_tensor_via(0, ds.data / 255 * 2 - 1, binary=True)
    labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
    data.append((labels, samples))

(training_labels, training_samples), (test_labels, test_samples) = data

import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms

class RestNetBasicBlock(nn.Module):
    """ResNet basic block with an identity shortcut.

    Two 3x3 convolutions, each followed by batch normalization, with a ReLU
    after the first one and another after adding the input back in.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions. Must equal
            ``in_channels`` because the shortcut is the plain addition
            ``x + output``.
        stride: stride of the first convolution. The identity shortcut only
            lines up spatially for stride 1; use RestNetDownBlock to
            downsample.
    """

    def __init__(self, in_channels, out_channels, stride):
        super(RestNetBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Only the first convolution is strided, as in the standard ResNet
        # basic block; repeating the stride here would downsample twice.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        output = F.relu(self.bn1(self.conv1(x)))
        output = self.bn2(self.conv2(output))
        return F.relu(x + output)

class RestNetDownBlock(nn.Module):
    """ResNet block that changes the channel count and/or spatial resolution.

    Main path: conv3x3(stride[0]) -> BN -> ReLU -> conv3x3(stride[1]) -> BN.
    Shortcut (``self.extra``): conv1x1(stride[0]) -> BN, so both paths have
    matching shapes when ``stride[1] == 1``. The sum is passed through ReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        stride: indexable pair; element 0 is the stride of the first
            convolution and of the shortcut, element 1 that of the second
            convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super(RestNetDownBlock, self).__init__()
        first_stride = stride[0]
        second_stride = stride[1]
        # Submodules are created in a fixed order (conv1, bn1, conv2, bn2,
        # shortcut) so weight initialization and state_dict keys stay stable.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=first_stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=second_stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        shortcut_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                  stride=first_stride, padding=0)
        shortcut_norm = nn.BatchNorm2d(out_channels)
        self.extra = nn.Sequential(shortcut_conv, shortcut_norm)

    def forward(self, x):
        shortcut = self.extra(x)
        main = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(shortcut + main)

class ResNet18(nn.Module):
    """ResNet-18 classifier built from RestNetBasicBlock / RestNetDownBlock.

    Stem: 7x7 stride-2 convolution -> BN -> ReLU -> 3x3 stride-2 max pool.
    Then four stages of two blocks each (64, 128, 256, 512 channels; stages
    2-4 start with a downsampling block), global average pooling and a
    linear classifier.

    Args:
        num_classes: number of output logits (default 10, as for CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1),
                                    RestNetBasicBlock(64, 64, 1))

        self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]),
                                    RestNetBasicBlock(128, 128, 1))

        self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]),
                                    RestNetBasicBlock(256, 256, 1))

        self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2, 1]),
                                    RestNetBasicBlock(512, 512, 1))

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        # Stem: conv -> BN -> ReLU -> max pool (bn1 and maxpool are part of
        # the model and must be applied, otherwise their parameters are dead).
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        # Collapse (N, 512, 1, 1) to (N, 512) for the classifier.
        out = out.reshape(x.shape[0], -1)
        return self.fc(out)

net = ResNet18()

# Per-channel CIFAR-10 normalization shared by training and evaluation.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

# train for a bit: one cleartext pass over the training set with
# random-crop / horizontal-flip augmentation
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])
# Evaluation must be deterministic, so no random augmentation here.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

ds = get_data(train=True, transform=transform)
optimizer = torch.optim.Adam(net.parameters(), amsgrad=True)
criterion = nn.CrossEntropyLoss()

net.train()
for inputs, labels in torch.utils.data.DataLoader(ds, batch_size=128, shuffle=True):
    optimizer.zero_grad()
    loss = criterion(net(inputs), labels)
    loss.backward()
    optimizer.step()

# Use BatchNorm running statistics (inference mode) for evaluation.
net.eval()
with torch.no_grad():
    ds = get_data(False, test_transform)
    total = correct_classified = 0
    for inputs, labels in torch.utils.data.DataLoader(ds, batch_size=128):
        outputs = net(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct_classified += (predicted == labels).sum().item()
    test_acc = (100 * correct_classified / total)
    print('Cleartext test accuracy of the network: %.2f %%' % test_acc)

# NOTE: in MP-SPDZ releases before 0.3.9, ml.layers_from_torch only converts
# sequential networks; ResNet18's residual additions are not sequential,
# which is what raises "unknown PyTorch module: ResNet18" there.
# NOTE(review): the secret samples are ds.data scaled to [-1, 1]
# (presumably HWC layout), whereas the PyTorch model was trained on
# ToTensor/Normalize output (CHW, per-channel mean/std) — confirm the two
# input conventions match before reusing these weights.
layers = ml.layers_from_torch(net, training_samples.shape, 128, input_via=0)

optimizer = ml.SGD(layers)

optimizer.fit(
    training_samples,
    training_labels,
    epochs=1,
    batch_size=128,
    validation_data=(test_samples, test_labels),
    program=program,
    reset=False,
)

The error is `CompilerError: unknown PyTorch module: ResNet18`. It seems I can't pass a custom (self-defined) module to the compiler. Is there an example of ResNet-18 inference in MP-SPDZ?

mkskeller commented 1 year ago

The PyTorch interface only supports sequential networks, but ResNet contains additions (its residual connections) and is therefore not sequential. We have implemented ResNet-50 inference, which you can run as follows from the MP-SPDZ root directory:

# Fetch the EzPC fork that contains the ResNet network code.
git clone https://github.com/mkskeller/EzPC
cd EzPC/Athos/Networks/ResNet
# Download the pretrained TensorFlow ResNet v2 SavedModel (fp32, NHWC) into
# ./PreTrainedModel; axel is a download accelerator — presumably wget or curl
# would work as well.
axel -a -n 5 -c --output ./PreTrainedModel http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NHWC.tar.gz
cd PreTrainedModel && tar -xvzf resnet_v2_fp32_savedmodel_NHWC.tar.gz && cd ..
# Run the cleartext prediction and save the image and weight data
# (--scalingFac 12 presumably selects the fixed-point scaling — confirm).
python3 ResNet_main.py --runPrediction True --scalingFac 12 --saveImgAndWtData True
# Go back up four levels to the MP-SPDZ root directory.
cd ../../../..
# Convert the saved input file for MP-SPDZ (fixed-point to float, judging
# by the script name — confirm).
Scripts/fixed-rep-to-float.py EzPC/Athos/Networks/ResNet/ResNet_img_input.inp
# Compile the tf program for the exported graph and run it in emulation;
# the trailing 8 is passed as a program argument (meaning not shown here —
# confirm).
Scripts/compile-emulate.py tf EzPC/Athos/Networks/ResNet/graphDef.bin 8

To run with an actual protocol instead of emulation, replace `Scripts/compile-emulate.py` in the last line with `Scripts/compile-run.py -E <protocol>`.

AliceNCsyuk commented 1 year ago

I have implemented simple training code for residual blocks in ml.py, and I hope it provides some inspiration. If anyone has implemented complete ResNet training, I would really like to see it open-sourced.

# Experimental residual variant of the dense layer in MP-SPDZ's ml.py,
# derived from DenseBase: output = activation(X * W + b + X).
# NOTE(review): as pasted, the methods below are not indented under the
# class header; they need one extra level of indentation to form the class
# body.
# NOTE(review): adding X to the d_out-wide output requires d_in == d_out,
# which is never checked.
class SimpleRes_Linear(DenseBase):

# Allocate buffers for a batch of N samples, each a d x d_in input mapped to
# a d x d_out output.
#   activation: 'id' (no activation layer), 'relu' or 'square'; any other
#       value raises CompilerError.
#   debug: falsy to disable; otherwise _forward uses it as a threshold and
#       prints outputs of Y that exceed it.
# NOTE(review): the CompilerError receives `activation` as a separate
# argument instead of %-formatting it, so the '%s' is never filled in.
def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):

    if activation == 'id':
        self.activation_layer = None
    elif activation == 'relu':
        self.activation_layer = Relu([N, d, d_out])
    elif activation == 'square':
        self.activation_layer = Square([N, d, d_out])
    else:
        raise CompilerError('activation not supported: %s', activation)

    self.N = N
    self.d_in = d_in
    self.d_out = d_out
    self.d = d
    self.activation = activation
    self.X = MultiArray([N, d, d_in], sfix)
    self.Y = MultiArray([N, d, d_out], sfix)
    self.W = Tensor([d_in, d_out], sfix)
    self.b = sfix.Array(d_out)
    # Gradient buffers hold at most back_batch_size samples
    # (back_batch_size is presumably inherited from DenseBase/Layer).
    back_N = min(N, self.back_batch_size)
    self.nabla_Y = MultiArray([back_N, d, d_out], sfix)
    self.nabla_X = MultiArray([back_N, d, d_in], sfix)
    self.nabla_W = sfix.Matrix(d_in, d_out)
    self.nabla_b = sfix.Array(d_out)
    self.debug = debug
    l = self.activation_layer

    # Chain the activation layer: compute_f_input writes into the
    # activation's input X, the activation writes into self.Y, and both share
    # nabla_Y. Without an activation, the result goes straight into Y.
    if l:
        self.f_input = l.X
        l.Y = self.Y
        l.nabla_Y = self.nabla_Y
    else:
        self.f_input = self.Y

# Short description, e.g. "SimpleRes_Linear(N, d_in, d_out, activation='relu')".
def __repr__(self):
    return '%s(%s, %s, %s, activation=%s)' % \
        (type(self).__name__, self.N, self.d_in,
         self.d_out, repr(self.activation))

# Randomize W in [-r, r] with r = sqrt(6 / (d_in + d_out)) (the
# Glorot/Xavier bound) and zero the bias.
def reset(self):
    d_in = self.d_in
    d_out = self.d_out
    r = math.sqrt(6.0 / (d_in + d_out))
    print('Initializing dense weights in [%f,%f]' % (-r, r))
    self.W.randomize(-r, r)
    self.b.assign_all(0)

# Read W, and b when self.input_bias is set, as input from `player`.
# input_bias is not set in this class; presumably provided by DenseBase.
def input_from(self, player, raw=False):
    self.W.input_from(player, raw=raw)
    if self.input_bias:
        self.b.input_from(player, raw=raw)

# Compute f_input = X[batch] * W, plus b and the residual X when input_bias
# is set. Only d == 1 is supported (asserted).
# NOTE(review): without input_bias, prod is f_input itself and the residual
# X is never added, so the layer degenerates to a plain dense layer.
def compute_f_input(self, batch):
    N = len(batch)
    assert self.d == 1
    if self.input_bias:
        prod = MultiArray([N, self.d, self.d_out], sfix)
    else:
        prod = self.f_input
    # Chunk size per thread, bounded by the compiler budget and d_out.
    max_size = program.Program.prog.budget // self.d_out

    # Multiply the rows of X selected by `batch` with W, chunk by chunk.
    @multithread(self.n_threads, N, max_size)
    def _(base, size):
        X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
        prod.assign_part_vector(
            X_sub.direct_mul(self.W, indices=(
                batch.get_vector(base, size), regint.inc(self.d_in),
                regint.inc(self.d_in), regint.inc(self.d_out))), base)

    if self.input_bias:
        if self.d_out == 1:
            # NOTE(review): X.expand_to_vector(0, size) appears to repeat a
            # single element of X (like the bias broadcast next to it)
            # instead of taking each sample's own input — verify.
            @multithread(self.n_threads, N)
            def _(base, size):
                v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)+self.X.expand_to_vector(0, size)
                self.f_input.assign_vector(v, base)
        else:
            # NOTE(review): self.X.get_vector() appears to take the whole X
            # buffer rather than the row of sample i (or batch[i]); its
            # length only matches prod[i] in special cases — verify.
            @for_range_multithread(self.n_threads, 100, N)
            def _(i):
                v = prod[i].get_vector() + self.b.get_vector() + self.X.get_vector()
                self.f_input[i].assign_vector(v)
    progress('f input')

# Forward pass for `batch` (a regint.Array of sample indices; defaults to
# all N samples). Prints intermediate values when debug_output is set; with
# self.debug set, prints entries of Y that exceed that threshold.
def _forward(self, batch=None):
    if batch is None:
        batch = regint.Array(self.N)
        batch.assign(regint.inc(self.N))
    self.compute_f_input(batch=batch)
    if self.activation_layer:
        self.activation_layer.forward(batch)
    if self.debug_output:
        print_ln('dense X %s', self.X.reveal_nested())
        print_ln('dense W %s', self.W.reveal_nested())
        print_ln('dense b %s', self.b.reveal_nested())
        print_ln('dense Y %s', self.Y.reveal_nested())
    if self.debug:
        limit = self.debug
        @for_range_opt(len(batch))
        def _(i):
            @for_range_opt(self.d_out)
            def _(j):
                to_check = self.Y[i][0][j].reveal()
                check = to_check > limit

                @if_(check)
                def _():
                    print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check)
                    print_ln('X %s', self.X[i].reveal_nested())
                    print_ln('W %s',
                             [self.W[k][j].reveal() for k in range(self.d_in)])

# Backward pass:
#   1. back-propagate through the activation (if any);
#   2. when compute_nabla_X is set, compute nabla_X = f_schur_Y * W^T;
#   3. compute parameter gradients via backward_params (presumably from
#      DenseBase).
# NOTE(review): batch defaults to None but len(batch) is taken immediately,
# so callers must always pass it.
def backward(self, compute_nabla_X=True, batch=None):
    N = len(batch)
    d = self.d
    d_out = self.d_out
    X = self.X
    Y = self.Y
    W = self.W
    b = self.b
    nabla_X = self.nabla_X
    nabla_Y = self.nabla_Y
    nabla_W = self.nabla_W
    nabla_b = self.nabla_b

    # Gradient w.r.t. the pre-activation value.
    if self.activation_layer:
        self.activation_layer.backward(batch)
        f_schur_Y = self.activation_layer.nabla_X
    else:
        f_schur_Y = nabla_Y

    if compute_nabla_X:
        @multithread(self.n_threads, N)
        def _(base, size):
            B = sfix.Matrix(N, d_out, address=f_schur_Y.address)
            nabla_X.assign_part_vector(
                B.direct_mul_trans(W, indices=(regint.inc(size, base),
                                               regint.inc(self.d_out),
                                               regint.inc(self.d_out),
                                               regint.inc(self.d_in))),
                base)
            # NOTE(review): this adds the constant 1 to every entry of
            # nabla_X. For the forward term "+ X" the skip-connection
            # gradient is the upstream gradient f_schur_Y, not 1. The line
            # also sits inside the per-chunk body, so it may be applied once
            # per chunk — verify.
            nabla_X[:]+=sfix.from_sint(1)
            # NOTE(review): leftover debug print.
            print('res')

        if self.print_random_update:
            print_ln('backward %s', self)
            index = regint.get_random(64) % self.nabla_X.total_size()
            print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
                     index, self.nabla_X.to_array()[index].reveal())

        progress('nabla X')

    self.backward_params(f_schur_Y, batch=batch)
mkskeller commented 2 months ago

Version 0.3.9 now supports non-sequential PyTorch networks; see this example: https://github.com/data61/MP-SPDZ/blob/master/Programs/Source/torch_resnet.py