zama-ai / concrete-ml

Concrete ML: Privacy Preserving ML framework using Fully Homomorphic Encryption (FHE), built on top of Concrete, with bindings to traditional ML frameworks.

Feature request: Support Unfold torch operator #799

Open summer-xrx opened 1 month ago

summer-xrx commented 1 month ago

Hello! When running the concrete-ml library, I encountered the following error: "torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::col2im' to ONNX opset version 14 is not supported. Support for this operator was added in version 18, try exporting with this version". How can I solve this problem? Thank you for your help!

jfrery commented 1 month ago

Hi @summer-xrx,

Could you share the code that produced this error, e.g. the torch model in question?

The operator seems to be a reshaping operation. If you can rewrite it differently (e.g. using reshape), that could fix your problem.

Otherwise, once we know what operation you are trying to do, we can create a corresponding issue to support it in concrete-ml.
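For reference, the `aten::col2im` in the error comes from `F.fold`. When fold is called with a 1x1 kernel, stride 1 and no padding (as `CustomMaxPool.forward` does in the code posted in the next comment), each column lands on exactly one output pixel, so fold only rearranges values and a plain reshape produces the same tensor. A minimal check in plain PyTorch, with arbitrary small shapes chosen for illustration:

```python
import torch
import torch.nn.functional as F

# Arbitrary small shapes, for illustration only.
n, c, h, w = 2, 3, 4, 5
cols = torch.randn(n, c, h * w)  # (N, C, L): one value per channel and output position

# fold with a 1x1 kernel, stride 1 and no padding:
# no two columns overlap, so nothing is summed, values are only moved.
folded = F.fold(cols, output_size=(h, w), kernel_size=(1, 1), stride=1, padding=0)

# A plain reshape gives the same tensor and exports to ONNX as Reshape, not Col2Im.
reshaped = cols.reshape(n, c, h, w)

print(torch.equal(folded, reshaped))  # True
```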

summer-xrx commented 1 month ago

Hi, @jfrery, the specific code is as follows. Thank you for your help!

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
from dataloader import train_loader, test_loader,test_dataset
import torch.nn.functional as F
from concrete.ml.torch.compile import compile_torch_model
import numpy as np
import time
from tqdm import tqdm

device = torch.device('cpu')

class CustomMaxPool(nn.Module):
    def __init__(self, kernel_size, stride=(1,1), padding=(0,0), dilation=(1,1)):
        super(CustomMaxPool, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
    def f1(self, x):
        return 10.8541842577442*x-62.2833925211098*x**3+114.369227820443*x**5-62.8023496973074*x**7
    def f2(self,x):
        return 4.13976170985111*x-5.84997640211679*x**3+2.94376255659280*x**5-0.454530437460152*x**7
    def f3(self,x):
        return 3.29956739043733*x-7.84227260291355*x**3+12.8907764115564*x**5-12.4917112584486*x**7+6.94167991428074*x**9-2.04298067399942*x**11+0.246407138926031*x**13
    def sign(self,x):
        return self.f3(self.f2(self.f1(x)))
    def custom_relu(self,x):
        return 0.125 * x**2 + 0.5 * x + 0.25

    def forward(self, x):
        unfolded = F.unfold(x, self.kernel_size, stride=self.stride, padding=self.padding)
        k = self.kernel_size[0] * self.kernel_size[1]
        unfolded = unfolded.view(x.size(0), x.size(1), k, -1)
        a, b = unfolded[:, :, ::2, :], unfolded[:, :, 1::2, :]
        tmp = b + self.custom_relu(a[:,:,:4,:] - b)
        tmp1, tmp2 = tmp[:, :, ::2, :], tmp[:, :, 1::2, :]
        tmp = tmp1 + self.custom_relu(tmp2-tmp1)
        tmp1, tmp2 = tmp[:, :, ::2, :], tmp[:, :, 1::2, :]
        tmp = tmp1 + self.custom_relu(tmp2-tmp1)
        custom_max = tmp + self.custom_relu(a[:,:,4:,:]-tmp)
        custom_max = custom_max.view(x.size(0), x.size(1), -1)
        output_size = ((x.size(2) - self.dilation[0]*(self.kernel_size[0]-1) + 2 * self.padding[0]-1) // self.stride[0] + 1, 
                       (x.size(3) - self.dilation[1]*(self.kernel_size[1]-1) + 2 * self.padding[1]-1) // self.stride[1] + 1)
        output = F.fold(custom_max, output_size, (1, 1), stride=1, padding=(0,0))
        return output

class ApproxSigmoid(nn.Module):
    def __init__(self):
        super(ApproxSigmoid,self).__init__()
    def forward(self,x):
        return 0.5+0.25*x-0.02083*x**3+0.00208*x**5-0.00019758*x**7+2.1356922398589065e-05*x**9

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 128, kernel_size=7, stride=2)
        self.conv2 = nn.Conv2d(128, 192, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(192, 128, kernel_size=3, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.fc1 = nn.Linear(2048,1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512,10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.conv1(x))
        x = self.pool(self.conv2(x))
        x = self.conv3(x)
        x = self.relu(x)
        x = x.view(-1, 1024*2)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

def test(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            print(0)
            outputs = model(inputs)
            print(1)
            _, predicted = torch.max(outputs.data, 1)
            print(2)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy of the model on the {total} test images: {100 * correct / total:.2f}%')

def train(model, train_loader, criterion, optimizer, scheduler, device, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 25 == 24:  
                print(f'Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {running_loss / 25:.4f}')
                running_loss = 0.0
        torch.save(model.state_dict(),f"./models/model_epoch{epoch}.pht")
        test(model, test_loader, device)
        scheduler.step()

model = CNN().to(device)
model.load_state_dict(torch.load('models/model_epoch3.pht'),strict=False)

test(model, test_loader, device)

train_features = []
train_labels = []

for inputs, labels in train_loader:
    train_features.append(inputs)
    train_labels.append(labels)

train_features = torch.cat(train_features) 
train_labels = torch.cat(train_labels)

x_train = train_features.to(device)
y_train = train_labels.to(device)

n_bits = 6

test_features = []
test_labels = []

for inputs, labels in test_loader:
    test_features.append(inputs)
    test_labels.append(labels)

test_features = torch.cat(test_features) 
test_labels = torch.cat(test_labels)

x_test = test_features.to(device)
y_test = test_labels.to(device)

q_module = compile_torch_model(model, x_train[:500, :], n_bits=n_bits,rounding_threshold_bits={"n_bits": n_bits, "method": "approximate"})

start_time = time.time()
accs = test_with_concrete(
    q_module,
    test_loader,
    use_sim=True,
)
sim_time = time.time() - start_time

print(f"Simulated FHE execution for {n_bits} bit network accuracy: {(100*accs):.2f}%")

t = time.time()
q_module.fhe_circuit.keygen()
print(f"Keygen time: {time.time()-t:.2f}s")

mini_test_dataset = TensorDataset(torch.Tensor(x_test[:100, :]), torch.Tensor(y_test[:100]))
mini_test_dataloader = DataLoader(mini_test_dataset)

t = time.time()
accuracy_test = test_with_concrete(
    q_module,
    mini_test_dataloader,
    use_sim=False,
)
elapsed_time = time.time() - t
time_per_inference = elapsed_time / len(mini_test_dataset)
accuracy_percentage = 100 * accuracy_test

print(
    f"Time per inference in FHE: {time_per_inference:.2f} "
    f"with {accuracy_percentage:.2f}% accuracy")
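The script calls `test_with_concrete`, which is not defined in the post (it looks adapted from a helper in Concrete ML's CNN example). A minimal sketch of such a helper, assuming the compiled module exposes `forward(x, fhe=...)` with the modes `"simulate"` and `"execute"`, as in recent Concrete ML releases. The name `evaluate_with_concrete` is hypothetical:

```python
import numpy as np


def evaluate_with_concrete(quantized_module, data_loader, use_sim=True):
    """Accuracy of a compiled Concrete ML module over (inputs, labels) batches.

    Assumes quantized_module.forward(x, fhe=...) accepts NumPy inputs and returns
    class scores of shape (batch, n_classes).
    """
    fhe_mode = "simulate" if use_sim else "execute"
    n_correct, n_total = 0, 0
    for inputs, labels in data_loader:
        # np.asarray also accepts CPU torch tensors that do not require grad
        scores = quantized_module.forward(np.asarray(inputs), fhe=fhe_mode)
        labels = np.asarray(labels)
        n_correct += int((np.argmax(scores, axis=1) == labels).sum())
        n_total += labels.shape[0]
    return n_correct / n_total
```

It returns a fraction, which matches the `100*accs` and `100 * accuracy_test` used in the print statements. With `use_sim=False`, keys must exist first, which the script handles with `q_module.fhe_circuit.keygen()`.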
summer-xrx commented 1 month ago

By the way, we made some modifications to this code, and the error message changed.

Our current problem is that the "unfold" operation depends on the "Range" operation, which the concrete-ml library does not support. How can we solve this problem? Thank you very much!

The neural network is as follows:

class CustomMaxPool(nn.Module):
    #def __init__(self, kernel_size, stride=None, padding=0):
    #    super(CustomMaxPool, self).__init__()
    #    self.kernel_size = kernel_size
    #    self.stride = stride
    #    self.padding = padding
    def __init__(self, kernel_size, stride=(1,1), padding=(0,0), dilation=(1,1)):
        super(CustomMaxPool, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

    def f1(self, x):
        return 10.8541842577442*x-62.2833925211098*x**3+114.369227820443*x**5-62.8023496973074*x**7
    def f2(self,x):
        return 4.13976170985111*x-5.84997640211679*x**3+2.94376255659280*x**5-0.454530437460152*x**7
    def f3(self,x):
        return 3.29956739043733*x-7.84227260291355*x**3+12.8907764115564*x**5-12.4917112584486*x**7+6.94167991428074*x**9-2.04298067399942*x**11+0.246407138926031*x**13
    def sign(self,x):
        return self.f3(self.f2(self.f1(x)))
    def custom_relu(self,x):
        #return F.relu(x)
        #return 0.125 * x**2 + 0.5 * x + 0.25
        #return x**2
        return (x+x*self.sign(x/(x**2+1)))/2

    def forward(self, x):
        #print(x.shape)
        # Unfold the input to get all sliding windows
        unfolded = F.unfold(x, self.kernel_size, stride=self.stride, padding=self.padding)
        # Reshape to get pairs of elements
        k = self.kernel_size[0] * self.kernel_size[1]
        unfolded = unfolded.view(x.size(0), x.size(1), k, -1)
        # Get a and b, assuming pairs of elements for simplicity
        a, b = unfolded[:, :, ::2, :], unfolded[:, :, 1::2, :]
        # Apply custom max operation
        tmp = b + self.custom_relu(a[:,:,:4,:] - b)
        tmp1, tmp2 = tmp[:, :, ::2, :], tmp[:, :, 1::2, :]
        tmp = tmp1 + self.custom_relu(tmp2-tmp1)
        tmp1, tmp2 = tmp[:, :, ::2, :], tmp[:, :, 1::2, :]
        tmp = tmp1 + self.custom_relu(tmp2-tmp1)
        #print("shape1:",a[:,:,4:,:].shape)
        #print("shape2:",tmp.shape)
        custom_max = tmp + self.custom_relu(a[:,:,4:,:]-tmp)
        #print(custom_max.shape)
        # Fold back to the original shape
        custom_max = custom_max.view(x.size(0), x.size(1), -1)
        output_size = ((x.size(2) - self.dilation[0]*(self.kernel_size[0]-1) + 2 * self.padding[0]-1) // self.stride[0] + 1, 
                       (x.size(3) - self.dilation[1]*(self.kernel_size[1]-1) + 2 * self.padding[1]-1) // self.stride[1] + 1)
        output = custom_max.view(x.size(0),x.size(1),output_size[0],output_size[1])
        return output

class ApproxReLU(nn.Module):
    def __init__(self):
        super(ApproxReLU,self).__init__()
    def f1(self, x):
        return 10.8541842577442*x-62.2833925211098*x**3+114.369227820443*x**5-62.8023496973074*x**7
    def f2(self,x):
        return 4.13976170985111*x-5.84997640211679*x**3+2.94376255659280*x**5-0.454530437460152*x**7
    def f3(self,x):
        return 3.29956739043733*x-7.84227260291355*x**3+12.8907764115564*x**5-12.4917112584486*x**7+6.94167991428074*x**9-2.04298067399942*x**11+0.246407138926031*x**13
    def sign(self,x):
        return self.f3(self.f2(self.f1(x)))
    def forward(self,x):
        #return F.relu(x)
        #return 0.125 * x**2 + 0.5 * x + 0.25
        #return x**2
        return (x+x*self.sign(x/(x**2+1)))/2

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        #self.conv1 = nn.Conv2d(1, 128, kernel_size=7, stride=2, padding=3)
        #self.conv2 = nn.Conv2d(128, 192, kernel_size=3, stride=2, padding=1)
        #self.conv3 = nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1)
        #self.pool = CustomMaxPool(kernel_size=(3,3), stride=(2,2), padding=(1, 1))
        #self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(1, 128, kernel_size=7, stride=2)
        self.conv2 = nn.Conv2d(128, 192, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(192, 128, kernel_size=3, stride=1)
        self.pool = CustomMaxPool(kernel_size=(3,3), stride=(2,2))
        #self.pool = nn.MaxPool2d(kernel_size=3, stride=2)
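        # Note: this reassignment replaces the CustomMaxPool assigned above,
        # so the model actually uses average pooling.
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2)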
        #self.fc1 = nn.Linear(2048*4,1024)
        #self.fc1 = nn.Linear(2048,1024)
        self.fc1 = nn.Linear(2048,10)
        #self.fc2 = nn.Linear(1024, 512)
        #self.fc3 = nn.Linear(512,10)
        #self.sigmoid = ApproxSigmoid()
        #self.sigmoid = nn.Sigmoid()
        #self.relu = nn.ReLU()
        self.relu=ApproxReLU()

    def forward(self, x):
        x = self.pool(self.conv1(x))
        x = self.pool(self.conv2(x))
        x = self.conv3(x)
        x = self.relu(x/(x**2+1))
        #x = x.view(-1, 1024*8)
        x = x.view(-1, 1024*2)
        x = self.fc1(x)
        #x = self.relu(x/(x**2+1))
        #x = self.fc2(x)
        #x = self.fc3(x)
        return x
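The CustomMaxPool above builds each 3x3 window maximum from the identity max(p, q) = q + relu(p - q), reduced pairwise over the 9 window values. The sketch below checks that reduction tree with the exact `torch.relu` (not the polynomial `custom_relu`) against `F.max_pool2d` on random data. `F.unfold` is used here only to run the comparison in plain PyTorch; it would still hit the export problem in Concrete ML. The function names are illustrative:

```python
import torch
import torch.nn.functional as F


def relu_max(p, q):
    # max(p, q) written with a single ReLU: q + relu(p - q)
    return q + torch.relu(p - q)


def max_pool_3x3_via_relu(x, stride=2):
    """3x3 max pooling (no padding) using the same pairwise tree as CustomMaxPool."""
    n, c, h, w = x.shape
    out_h, out_w = (h - 3) // stride + 1, (w - 3) // stride + 1
    cols = F.unfold(x, 3, stride=stride).reshape(n, c, 9, -1)  # (N, C, 9, L)
    a, b = cols[:, :, ::2, :], cols[:, :, 1::2, :]   # 5 even rows, 4 odd rows
    t = relu_max(a[:, :, :4, :], b)                  # 4 pairwise maxima
    t = relu_max(t[:, :, 1::2, :], t[:, :, ::2, :])  # 2 maxima
    t = relu_max(t[:, :, 1::2, :], t[:, :, ::2, :])  # 1 maximum
    t = relu_max(a[:, :, 4:, :], t)                  # include the 9th value
    return t.reshape(n, c, out_h, out_w)


x = torch.randn(2, 3, 9, 9)
print(torch.allclose(max_pool_3x3_via_relu(x), F.max_pool2d(x, kernel_size=3, stride=2), atol=1e-5))
```

The tolerance is needed because q + (p - q) is not always bit-identical to p in floating point. Once the exact ReLU is replaced by a polynomial approximation, each of the four levels adds its own approximation error.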
summer-xrx commented 1 month ago

Hi, @jfrery, we found that when compiling torch models, the concrete-ml library first converts the torch model to an ONNX model and then converts the ONNX operators to concrete-ml operators. One of these conversions maps the "Unfold" operator to "numpy_unfold", but ONNX itself does not define an "Unfold" operator. We ran an experiment and found that "torch.onnx.export" does not convert the "Unfold" function into an "Unfold" ONNX operator, but into a subgraph containing a "Range" operator, which concrete-ml does not recognize as a valid operator. How should we solve this problem? Thank you very much!

The experiment code is as follows:

import torch
import torch.nn as nn
import onnx
from onnx import helper

class Test(nn.Module):
    def __init__(self):
        super(Test,self).__init__()
        self.ss = nn.Unfold(kernel_size=3, stride=2)

    def forward(self, x):
        return self.ss(x)

model = Test()
torch.onnx.export(model,torch.randn(32,128,32,32),"model.onnx")

model = onnx.load("model.onnx")
print(onnx.helper.printable_graph(model.graph))

The output of this code is as follows. It shows that ONNX has no "Unfold" operator, so the export decomposes unfold into other operators, including "Range". However, concrete-ml does not support the "Range" operator.

graph main_graph (
  %input[FLOAT, 32x128x32x32]
) {
  %/ss/Shape_output_0 = Shape(%input)
  %/ss/Constant_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Gather_output_0 = Gather[axis = 0](%/ss/Shape_output_0, %/ss/Constant_output_0)
  %/ss/Shape_1_output_0 = Shape(%input)
  %/ss/Constant_1_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Gather_1_output_0 = Gather[axis = 0](%/ss/Shape_1_output_0, %/ss/Constant_1_output_0)
  %/ss/Constant_2_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Add_output_0 = Add(%/ss/Gather_output_0, %/ss/Constant_2_output_0)
  %/ss/Constant_3_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Sub_output_0 = Sub(%/ss/Add_output_0, %/ss/Constant_3_output_0)
  %/ss/Constant_4_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Constant_5_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Range_output_0 = Range(%/ss/Constant_4_output_0, %/ss/Sub_output_0, %/ss/Constant_5_output_0)
  %/ss/Constant_6_output_0 = Constant[value = <Tensor>]()
  %/ss/Unsqueeze_output_0 = Unsqueeze(%/ss/Range_output_0, %/ss/Constant_6_output_0)
  %/ss/Constant_7_output_0 = Constant[value = <Tensor>]()
  %/ss/Add_1_output_0 = Add(%/ss/Unsqueeze_output_0, %/ss/Constant_7_output_0)
  %/ss/Constant_8_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Add_2_output_0 = Add(%/ss/Gather_1_output_0, %/ss/Constant_8_output_0)
  %/ss/Constant_9_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Sub_1_output_0 = Sub(%/ss/Add_2_output_0, %/ss/Constant_9_output_0)
  %/ss/Constant_10_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Constant_11_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Range_1_output_0 = Range(%/ss/Constant_10_output_0, %/ss/Sub_1_output_0, %/ss/Constant_11_output_0)
  %/ss/Constant_12_output_0 = Constant[value = <Tensor>]()
  %/ss/Unsqueeze_1_output_0 = Unsqueeze(%/ss/Range_1_output_0, %/ss/Constant_12_output_0)
  %/ss/Constant_13_output_0 = Constant[value = <Tensor>]()
  %/ss/Add_3_output_0 = Add(%/ss/Unsqueeze_1_output_0, %/ss/Constant_13_output_0)
  %/ss/Shape_2_output_0 = Shape(%input)
  %/ss/Constant_14_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Gather_2_output_0 = Gather[axis = 0](%/ss/Shape_2_output_0, %/ss/Constant_14_output_0)
  %/ss/Shape_3_output_0 = Shape(%input)
  %/ss/Constant_15_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Gather_3_output_0 = Gather[axis = 0](%/ss/Shape_3_output_0, %/ss/Constant_15_output_0)
  %/ss/Constant_16_output_0 = Constant[value = <Scalar Tensor []>]()
  %/ss/Mul_output_0 = Mul(%/ss/Gather_3_output_0, %/ss/Constant_16_output_0)
  %/ss/Constant_17_output_0 = Constant[value = <Tensor>]()
  %/ss/Unsqueeze_2_output_0 = Unsqueeze(%/ss/Gather_2_output_0, %/ss/Constant_17_output_0)
  %/ss/Constant_18_output_0 = Constant[value = <Tensor>]()
  %/ss/Unsqueeze_3_output_0 = Unsqueeze(%/ss/Mul_output_0, %/ss/Constant_18_output_0)
  %/ss/Constant_19_output_0 = Constant[value = <Tensor>]()
  %/ss/Concat_output_0 = Concat[axis = 0](%/ss/Unsqueeze_2_output_0, %/ss/Unsqueeze_3_output_0, %/ss/Constant_19_output_0)
  %/ss/Constant_20_output_0 = Constant[value = <Tensor>]()
  %/ss/Pad_output_0 = Pad(%input, %/ss/Constant_20_output_0)
  %/ss/Gather_4_output_0 = Gather[axis = 2](%/ss/Pad_output_0, %/ss/Add_1_output_0)
  %/ss/Gather_5_output_0 = Gather[axis = 4](%/ss/Gather_4_output_0, %/ss/Add_3_output_0)
  %/ss/Transpose_output_0 = Transpose[perm = [0, 1, 2, 4, 3, 5]](%/ss/Gather_5_output_0)
  %52 = Reshape[allowzero = 0](%/ss/Transpose_output_0, %/ss/Concat_output_0)
  return %52
}
jfrery commented 1 month ago

Yes, you are right. For now, concrete-ml does not support the unfold operator. Let's convert your issue into a feature request for unfold support.

A possible workaround is to replace unfold with a manual shape transformation:

import torch
import torch.nn as nn
import onnx
from onnx import helper

class Test(nn.Module):
    def __init__(self):
        super(Test,self).__init__()
        self.kernel_size = 3
        self.stride = 2

    def unfold(self, x):
        batch_size, channels, height, width = x.shape
        kernel_size = self.kernel_size
        stride = self.stride

        # Calculate output dimensions
        out_height = (height - kernel_size) // stride + 1
        out_width = (width - kernel_size) // stride + 1

        # Create a list to store patches
        patches = []

        # Use loops to extract patches
        for i in range(out_height):
            for j in range(out_width):
                h_start = i * stride
                w_start = j * stride
                patch = x[:, :, h_start:h_start+kernel_size, w_start:w_start+kernel_size]
                patches.append(patch.reshape(batch_size, -1))

        # Stack patches along the last dimension
        output = torch.stack(patches, dim=-1)

        return output

    def forward(self, x):
        return self.unfold(x)

model = Test()
torch.onnx.export(model,torch.randn(32,128,32,32),"model.onnx")

from concrete.ml.torch.compile import compile_torch_model

torch_inputset = torch.randn(32,128,32,32)
q_module = compile_torch_model(model, torch_inputset=torch_inputset, n_bits=6, rounding_threshold_bits=6)
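Before compiling, the loop-based unfold above can be checked against `F.unfold` in plain PyTorch. The sketch reproduces the same slicing as a standalone function (the name `unfold_loops` is illustrative) and compares the two on arbitrary shapes:

```python
import torch
import torch.nn.functional as F


def unfold_loops(x, kernel_size=3, stride=2):
    """Same slicing as Test.unfold above: one slice per output position."""
    batch_size = x.shape[0]
    height, width = x.shape[-2:]
    out_h = (height - kernel_size) // stride + 1
    out_w = (width - kernel_size) // stride + 1
    patches = []
    for i in range(out_h):
        for j in range(out_w):
            h0, w0 = i * stride, j * stride
            patch = x[:, :, h0:h0 + kernel_size, w0:w0 + kernel_size]
            patches.append(patch.reshape(batch_size, -1))  # (N, C*k*k)
    return torch.stack(patches, dim=-1)  # (N, C*k*k, out_h*out_w)


x = torch.randn(2, 4, 9, 11)
print(torch.equal(unfold_loops(x), F.unfold(x, kernel_size=3, stride=2)))  # True
```

The match is exact because both only move values: channels come first in the flattened patch, then kernel rows and columns, and output positions are ordered row by row.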
summer-xrx commented 1 month ago

Hello, @jfrery, thank you very much for your generous help! Your method has indeed solved the problem, and the error no longer occurs. But now we face another problem: because the "unfold" function you wrote uses many "for" loops, compiling our model with the "compile_torch_model" function is very slow, and we could not get it to finish in a reasonable time. How can we solve this problem?

S1eepeng commented 1 month ago

Hello, @jfrery, we've defined an ApproxReLU() in order to reduce the PBC through a linear approximation. As a result, the PBC of q_module is 0 when n_bits=6. Our classification model handles a 10-class problem, and we've verified that, in plain text, using ApproxReLU() instead of nn.ReLU() does not reduce accuracy much. But the accuracy of q_module is only 10%, i.e. chance level for 10 classes. We suspect the compiled model has some problem when processing the ApproxReLU() function. The ApproxReLU() function is as below:

class ApproxReLU(nn.Module):
    def __init__(self):
        super(ApproxReLU,self).__init__()
    def f1(self, x):
        return 10.8541842577442*x-62.2833925211098*x**3+114.369227820443*x**5-62.8023496973074*x**7
    def f2(self,x):
        return 4.13976170985111*x-5.84997640211679*x**3+2.94376255659280*x**5-0.454530437460152*x**7
    def f3(self,x):
        return 3.29956739043733*x-7.84227260291355*x**3+12.8907764115564*x**5-12.4917112584486*x**7+6.94167991428074*x**9-2.04298067399942*x**11+0.246407138926031*x**13
    def sign(self,x):
        return self.f3(self.f2(self.f1(x)))
    def forward(self,x):
        #return F.relu(x)
        #return 0.125 * x**2 + 0.5 * x + 0.25
        #return x**2
        return (x+x*self.sign(x/(x**2+1)))/2
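Evaluating ApproxReLU in plain NumPy can help locate where it departs from ReLU before compiling. The coefficients are copied from the post; the evaluation range [-8, 8] and the grid are arbitrary choices, and the function names are illustrative:

```python
import numpy as np


# Coefficients copied from the ApproxReLU module above (odd polynomials).
def f1(x):
    return 10.8541842577442*x - 62.2833925211098*x**3 + 114.369227820443*x**5 - 62.8023496973074*x**7


def f2(x):
    return 4.13976170985111*x - 5.84997640211679*x**3 + 2.94376255659280*x**5 - 0.454530437460152*x**7


def f3(x):
    return (3.29956739043733*x - 7.84227260291355*x**3 + 12.8907764115564*x**5
            - 12.4917112584486*x**7 + 6.94167991428074*x**9
            - 2.04298067399942*x**11 + 0.246407138926031*x**13)


def approx_relu(x):
    # Same formula as ApproxReLU.forward: the sign approximation sees x / (x**2 + 1)
    return (x + x * f3(f2(f1(x / (x**2 + 1))))) / 2


x = np.linspace(-8.0, 8.0, 2001)
err = np.abs(approx_relu(x) - np.maximum(x, 0.0))
print(f"max |ApproxReLU - ReLU| on [-8, 8]: {err.max():.4f} at x = {x[err.argmax()]:.3f}")
```

Since f1, f2, f3 and x/(x²+1) are all odd, x·sign(·) is even, so ApproxReLU(x) - ApproxReLU(-x) = x, exactly as for ReLU, and the error at x equals the error at -x.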
jfrery commented 1 month ago

> Due to too many "for" loops in the "unfold" function you wrote, the efficiency in compiling our model using the "compile_torch_model" function is very low [...]

Unfortunately, that's a problem we also face sometimes when there are too many loops, and we don't have a solution for it yet. That being said, compilation is a one-time computation, so it is less of a problem than a long FHE execution.
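One possible way to cut the number of slices, offered as an untested sketch: instead of one slice per output position, take one strided slice per kernel offset. For the 32x32 input with kernel 3 and stride 2 used above, that is 9 slices instead of 15 x 15 = 225. It assumes no padding or dilation, and it assumes Concrete ML can convert strided slicing (ONNX `Slice` with steps); it is not verified that this shortens compilation. The name `unfold_by_offsets` is illustrative:

```python
import torch
import torch.nn.functional as F


def unfold_by_offsets(x, kernel_size=3, stride=2):
    """Unfold (no padding, no dilation) with one strided slice per kernel offset.

    Produces the same (N, C*k*k, L) layout as F.unfold, using k*k slices
    instead of one slice per output position.
    """
    n, c, h, w = x.shape
    k, s = kernel_size, stride
    out_h = (h - k) // s + 1
    out_w = (w - k) // s + 1
    shifted = []
    for di in range(k):
        for dj in range(k):
            # Rows di, di+s, ..., di+s*(out_h-1); same pattern for columns.
            # Assumption: strided slices convert in Concrete ML (not verified).
            shifted.append(x[:, :, di:di + s * (out_h - 1) + 1:s, dj:dj + s * (out_w - 1) + 1:s])
    # (N, C, k*k, out_h, out_w) -> (N, C*k*k, out_h*out_w)
    return torch.stack(shifted, dim=2).reshape(n, c * k * k, out_h * out_w)


x = torch.randn(1, 2, 32, 32)
print(torch.equal(unfold_by_offsets(x), F.unfold(x, kernel_size=3, stride=2)))  # True
```

Stacking along dim 2 and then reshaping keeps the channel-major, then kernel-row, then kernel-column order that F.unfold uses, so the result can replace the loop version without changing the rest of the model.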

jfrery commented 1 month ago

> Hello, @jfrery, We 've defined a ApproRelu() in order to reduce the PBC through Linear approximation. As a result,the PBC of q_module is 0 when n_bits=6. [...] But the accuracy of q_module is only 10% [...]

What do you call a PBC? Do you mean PBS? If so, approximating the ReLU with polynomials won't help. This is because x^N requires a PBS, so you are actually doing more PBS than with nn.ReLU(), which should cost a single PBS per value.

As for the accuracy drop, I am not too sure, but one problem I see is that you add x^7 to x. I doubt x retains any information once quantized, since x^7 must be pretty large. Representing both x and x^7 on 2^6 values is probably not possible.
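To illustrate the range problem with assumed numbers: suppose x spans [-2, 2] (an arbitrary range chosen for illustration) and the tensor x + x^7 is quantized with a simple min-max uniform quantizer on 6 bits. This is a simplified stand-in, not Concrete ML's exact quantization scheme, and `min_max_quantize` is a hypothetical helper:

```python
import numpy as np


def min_max_quantize(values, n_bits=6):
    """Uniform min-max quantization to 2**n_bits levels; returns (dequantized values, step)."""
    lo, hi = float(values.min()), float(values.max())
    step = (hi - lo) / (2 ** n_bits - 1)
    codes = np.round((values - lo) / step)
    return codes * step + lo, step


x = np.linspace(-2.0, 2.0, 1001)  # assumed input range, for illustration only
y = x + x ** 7                    # y spans [-130, 130], dominated by x**7

y_q, step = min_max_quantize(y, n_bits=6)
print(f"quantization step: {step:.3f}")              # quantization step: 4.127
print(f"largest |x| term: {np.abs(x).max():.3f}")    # largest |x| term: 2.000
```

One quantization step (about 4.13) is larger than anything the x term can contribute (at most 2), so in this setting the linear term sits below the resolution of the quantized sum.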

If you mean something else by PBC let me know!

S1eepeng commented 1 month ago

Thanks for your reply. Yes, it should be PBS, just a typo.