zama-ai / concrete-ml

Concrete ML: Privacy Preserving ML framework using Fully Homomorphic Encryption (FHE), built on top of Concrete, with bindings to traditional ML frameworks.

Making a fast satellite image classification use case [was: Some questions about GPU acceleration] #839

Open summer-xrx opened 3 weeks ago

summer-xrx commented 3 weeks ago

Hello, I am very curious about GPU acceleration. I have the following questions:

  1. Why does the concrete library support GPU acceleration, while other mainstream libraries such as google/fully-homomorphic-encryption, SEAL, OpenFHE, TenSEAL (based on SEAL), and secretflow do not support GPU acceleration?
  2. Compared to the aforementioned libraries, what are the advantages of the concrete library?
  3. Why can't GPU acceleration be implemented in those aforementioned libraries now? What are the technical challenges of using GPU to accelerate homomorphic encryption?
  4. Why does the concrete library implement GPU acceleration while concrete-ml does not? Are there any technical challenges in implementing GPU acceleration in homomorphic machine learning?

Looking forward to your answer, thank you!
bcm-at-zama commented 3 weeks ago

Hey,

  1. Why does the concrete library support GPU acceleration, while other mainstream libraries such as google/fully-homomorphic-encryption, SEAL, OpenFHE, TenSEAL (based on SEAL), and secretflow do not support GPU acceleration?

Well, we can only answer for Zama and its libraries. For questions about other products, you should go to their support, or maybe to discord.fhe.org in their respective channels.

Having GPU acceleration is certainly a good thing, since it makes computations faster than on CPU. At Zama, we are only at the beginning of GPU support in Concrete and Concrete ML, and we already see very significant improvements.

  2. Compared to the aforementioned libraries, what are the advantages of the concrete library?

Hard to summarize in a few sentences. To name a few: we have very easy-to-use tools that let developers build privacy into their apps without knowing anything about cryptography. The API is the same as Torch and scikit-learn, which is convenient for users. Computations are exact thanks to TFHE, as opposed to other libraries which use CKKS. Everything is open source, so every claim can be reproduced on your side.

  3. Why can't GPU acceleration be implemented in those aforementioned libraries now? What are the technical challenges of using GPU to accelerate homomorphic encryption?

Again, ask the other companies / maintainers. I'm pretty sure they also work on GPU support, if they are still active (some of the libraries you mentioned are somewhat abandoned or deprecated) and sufficiently staffed to do it.

The challenge is making it faster than CPU: it requires knowledge and expertise in GPU programming, and time.

  4. Why does the concrete library implement GPU acceleration while concrete-ml does not? Are there any technical challenges in implementing GPU acceleration in homomorphic machine learning? Looking forward to your answer, thank you!

Concrete GPU was released in Q2 2024, Concrete ML GPU will be released in Q3 2024, so just a bit of patience here.

Cheers

summer-xrx commented 3 weeks ago

Hi, @bcm-at-zama, thank you very much for your generous reply! May I ask another question? When we run a CNN with the concrete-ml library on our server for image classification, it takes a long time to classify an image when the network is large. When the network is smaller, the inference time decreases, but the accuracy drops significantly. From this, we infer that this library is still of a scientific or experimental nature and has a long way to go before it can be applied in practice. Is our inference correct?

bcm-at-zama commented 3 weeks ago

Hey, could you share your code, maybe? It's hard to say without seeing it: it depends on how big your NN is. Getting good accuracy with a smaller NN also depends quite a lot on the task you want to perform. We can run not-that-small NNs in FHE on non-trivial tasks such as CIFAR in less than a minute (and already in less than 30s on GPU), e.g., https://github.com/zama-ai/concrete-ml/blob/main/use_case_examples/cifar/cifar_brevitas_training/README.md#accuracy-and-performance . And yes, we're making progress on the speed side, quarter after quarter.

summer-xrx commented 3 weeks ago

Hi, @bcm-at-zama, the server's CPU has 64 cores and 128 threads, and the server has 996 GB of memory. The task is satellite image classification with 10 classes. The dataset is NaSC-TG2, which contains 20,000 RGB images of 128×128 pixels: 16,000 for training and 4,000 for testing. In our code, to reduce time consumption, we resized the images to 64×64.

The definition of the "small network" is as follows:

[image: definition of the small network]

The definition of the "big network" is as follows:

[image: definition of the big network]

With the "small network", we reach 71% accuracy in plaintext. With nbits=6, the accuracy in ciphertext drops to 61.2%, and processing takes about 5 minutes per image. With nbits=7, the accuracy in ciphertext is 67%, but processing takes about 27 minutes per image. With the "big network", we reach 79% accuracy in plaintext. However, with nbits=6 the accuracy in ciphertext is around 61% (10 minutes per image). With nbits=7, higher accuracy can be achieved, but processing each image takes even longer. Looking forward to your reply, thank you!

The code of the "small network" is as follows:


```python
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from concrete.ml.torch.compile import compile_torch_model

# Local module (not shared in this thread): NaSC-TG2 loaders, images resized to 64x64
from dataloader import train_loader, test_loader

device = torch.device("cpu")


class CNN(nn.Module):
    """The "small network": 3x64x64 input, 10 classes."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=7, stride=2)   # -> 8x29x29
        self.conv2 = nn.Conv2d(8, 12, kernel_size=3, stride=1)  # -> 12x12x12
        self.conv3 = nn.Conv2d(12, 8, kernel_size=3, stride=1)  # -> 8x3x3
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2)       # shared, has no weights
        self.fc1 = nn.Linear(72, 10)                            # 72 = 8 * 3 * 3
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)  # -> 8x14x14
        x = self.conv2(x)
        x = self.pool(x)  # -> 12x5x5
        x = self.conv3(x)
        x = x.view(-1, 72)
        x = self.fc1(x)
        return x


def test(model, test_loader, device):
    """Evaluate the float (plaintext) model and return its accuracy in %."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Accuracy of the model on the {total} test images: {accuracy:.2f}%")
    return accuracy


def train(model, train_loader, criterion, optimizer, device, epochs=5, log_every=25):
    for epoch in range(epochs):
        model.train()  # test() switches to eval mode, so re-enable training mode every epoch
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            loss = criterion(model(inputs), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % log_every == log_every - 1:
                # Average loss over the `log_every` steps since the last print
                print(f"Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(train_loader)}], "
                      f"Loss: {running_loss / log_every:.4f}")
                running_loss = 0.0
        accuracy_rate = test(model, test_loader, device)
        torch.save(model.state_dict(), f"./models_new/model_epoch{epoch}_{accuracy_rate}.pth")


def test_with_concrete(quantized_module, test_loader, use_sim):
    """Accuracy (fraction in [0, 1]) of the compiled module, in FHE simulation or real FHE."""
    fhe_mode = "simulate" if use_sim else "execute"
    all_y_pred = np.zeros(len(test_loader.dataset), dtype=np.int64)
    all_targets = np.zeros(len(test_loader.dataset), dtype=np.int64)
    idx = 0
    for data, target in tqdm(test_loader):
        data, target = data.numpy(), target.numpy()
        y_pred = quantized_module.forward(data, fhe=fhe_mode)
        end_idx = idx + target.shape[0]
        all_targets[idx:end_idx] = target
        all_y_pred[idx:end_idx] = np.argmax(y_pred, axis=1)
        idx = end_idx
    return np.mean(all_targets == all_y_pred)


model = CNN().to(device)
model.load_state_dict(torch.load("models_new64/model_epoch84_70.5.pth", map_location=device))
test(model, test_loader, device)

# Training (already done; the checkpoint loaded above is its result):
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.0005)
# train(model, train_loader, criterion, optimizer, device, epochs=100)

# The whole training set is used as calibration data for quantization
x_train = torch.cat([inputs for inputs, _ in train_loader]).to(device)
x_test = torch.cat([inputs for inputs, _ in test_loader]).to(device)
y_test = torch.cat([labels for _, labels in test_loader]).to(device)

n_bits = 6

print("=================== Start Compile ========================")
q_module = compile_torch_model(
    model,
    x_train,
    n_bits=n_bits,
    rounding_threshold_bits={"n_bits": n_bits + 1, "method": "approximate"},
)
print(q_module.fhe_circuit.statistics)

# Accuracy on the full test set in FHE simulation (fast, no encryption)
start_time = time.time()
accs = test_with_concrete(q_module, test_loader, use_sim=True)
sim_time = time.time() - start_time
print(f"Simulated FHE execution for {n_bits} bit network accuracy: {100 * accs:.2f}% ({sim_time:.2f}s)")

# Generate keys first
t = time.time()
q_module.fhe_circuit.keygen()
print(f"Keygen time: {time.time() - t:.2f}s")

# Run inference in FHE on a single encrypted example
mini_test_dataset = TensorDataset(x_test[:1], y_test[:1])
mini_test_dataloader = DataLoader(mini_test_dataset)

t = time.time()
accuracy_test = test_with_concrete(q_module, mini_test_dataloader, use_sim=False)
time_per_inference = (time.time() - t) / len(mini_test_dataset)
print(f"Time per inference in FHE: {time_per_inference:.2f}s with {100 * accuracy_test:.2f}% accuracy")
```

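A quick check of where the `72` in `x.view(-1, 72)` and `nn.Linear(72, 10)` comes from: for unpadded `Conv2d` and `AvgPool2d` layers, PyTorch computes each output side as `(size - kernel) // stride + 1`. The pure-Python sketch below traces a square input through the layers of the "small network" above; `conv_out`, `trace_shapes` and `flattened_features` are illustrative helpers, not part of PyTorch or Concrete ML.

```python
def conv_out(size: int, kernel: int, stride: int) -> int:
    """Output side length of an unpadded Conv2d / AvgPool2d layer (floor rule)."""
    return (size - kernel) // stride + 1


# (layer name, output channels or None if unchanged, kernel size, stride)
SMALL_NETWORK = [
    ("conv1", 8, 7, 2),
    ("pool", None, 3, 2),
    ("conv2", 12, 3, 1),
    ("pool", None, 3, 2),
    ("conv3", 8, 3, 1),
]


def trace_shapes(side: int, channels: int = 3):
    """Return (name, channels, side) after each layer for a side x side input."""
    trace = []
    for name, out_channels, kernel, stride in SMALL_NETWORK:
        side = conv_out(side, kernel, stride)
        if out_channels is not None:
            channels = out_channels
        trace.append((name, channels, side))
    return trace


def flattened_features(side: int) -> int:
    """Number of features fed to fc1 for a side x side input."""
    _, channels, last_side = trace_shapes(side)[-1]
    return channels * last_side * last_side


for name, channels, side in trace_shapes(64):
    print(f"{name:5s} -> {channels} x {side} x {side}")
print("features into fc1:", flattened_features(64))  # 72, matching nn.Linear(72, 10)
```

With the original 128×128 images, the same layers would end at 8×11×11, i.e. 968 features, so the resize to 64×64 and the input size of `fc1` have to stay consistent.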
bcm-at-zama commented 3 weeks ago

Thanks a lot for this information, it's very interesting. I need to discuss it with the team to analyze it. As it's summer here, it may take a bit of time, but we'll get back to you.

How about a Zoom call to discuss your use case? If you're interested, could you send an email to hello@zama.ai saying you want to speak with Benoit, please? We can certainly help you make your use case even better. Cheers

bcm-at-zama commented 3 weeks ago

Also, @summer-xrx, could you send a self-contained piece of code, please? For example, your code in https://github.com/zama-ai/concrete-ml/issues/839#issuecomment-2296284825 does not run as-is: it has no reference to its dataset.

bcm-at-zama commented 3 weeks ago

Also, one thing that would help would be to compile with show_mlir=True and attach the MLIR. With it, we can see the number of PBS (programmable bootstrapping operations), which are the expensive operations.

summer-xrx commented 3 weeks ago

OK, @bcm-at-zama, thank you for your help! The results for the "small network" with nbits=6 are as follows. From the screenshot, we can see that programmable_bootstrap_count=10396, and "simulate" mode is very fast but "execute" mode is slow.

[image: results for the small network, nbits=6]

The results for the "small network" with nbits=7 are as follows. We can see that programmable_bootstrap_count=10396, and "simulate" mode is fast but "execute" mode is very slow (56 min/image).

[image: results for the small network, nbits=7]

The results for the "big network" with nbits=6 are as follows. We can see that programmable_bootstrap_count=20792, and "simulate" mode is fast but "execute" mode is very slow (10 min/image).

[image: results for the big network, nbits=6]

When running the code, CPU utilization is nearly 100%:

[image: CPU utilization]

Our goal is to achieve high accuracy (>= 75%) with low inference time (<= 5 minutes per image). The complete code will be sent to you by email later. Looking forward to your reply, thank you!
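The reported PBS count for the "small network" can be reproduced with simple arithmetic, under an assumption we infer from the numbers rather than from Concrete ML's internals: one PBS (table lookup) per output element of every layer except the final `fc1`, covering the ReLU and the requantization of each accumulator. The layer output shapes follow from the "small network" definition with a 3×64×64 input, using PyTorch's `(size - kernel) // stride + 1` rule.

```python
from math import prod

# Output shape of each layer of the "small network" for a 3x64x64 input
LAYER_OUTPUTS = [
    ("conv1 + ReLU", (8, 29, 29)),
    ("pool", (8, 14, 14)),
    ("conv2", (12, 12, 12)),
    ("pool", (12, 5, 5)),
    ("conv3", (8, 3, 3)),
    ("fc1", (10,)),
]


def estimate_pbs(layer_outputs):
    """Assumption: one PBS per output element of every layer except the last one."""
    return [(name, prod(shape)) for name, shape in layer_outputs[:-1]]


breakdown = estimate_pbs(LAYER_OUTPUTS)
total = sum(count for _, count in breakdown)
for name, count in breakdown:
    print(f"{name:13s} {count:6d} PBS ({100 * count / total:.1f}%)")
print("total:", total)  # 10396, the programmable_bootstrap_count reported above
```

Under that assumption, the ReLU after `conv1` accounts for 6,728 of the 10,396 PBS (about 65%), so the first layer's output size (channels and stride) is the main lever on the PBS count. The count does not depend on nbits, which matches the identical 10,396 reported for nbits=6 and nbits=7: what changes is presumably the cost of each PBS, which grows with the number of bits it operates on.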

bcm-at-zama commented 3 weeks ago

Thanks for the extra info! In particular, your observation that CPU utilization is nearly 100% when running the code is worth looking into. We'll have a look; it's some work, so please be patient.

Regarding your 10k or 20k PBS: we have about 150k PBS in our CIFAR example, with nbits=6 if I'm not mistaken, and it runs in ~40s on CPU and ~20s on GPU. So, at first sight, I think running 10k or 20k PBS in less than 5 minutes is really achievable. The tech team will have a look and we'll tell you more.

Yes, please send the code: we need it, we can't move forward without it.
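The back-of-the-envelope comparison behind this estimate, as a rough sketch: it assumes inference time scales linearly with the PBS count and that every PBS costs the same as in the CIFAR example. Neither is exactly true (PBS cost depends on the bit-width of the lookup and on the cryptographic parameters Concrete selects), so treat the result as an order of magnitude only. All figures are the approximate ones quoted in this thread; `naive_estimate_seconds` is an illustrative helper.

```python
# Approximate figures quoted in this thread
CIFAR_PBS = 150_000       # PBS in the CIFAR example
CIFAR_CPU_SECONDS = 40.0  # its inference time on CPU


def naive_estimate_seconds(pbs_count: int) -> float:
    """Linear scaling from the CIFAR example: assumes every PBS costs the same."""
    return pbs_count * CIFAR_CPU_SECONDS / CIFAR_PBS


# (configuration, PBS count, observed minutes per image) as reported above
reported = [
    ("small network, nbits=6", 10_396, 5),
    ("big network, nbits=6", 20_792, 10),
]
for label, pbs, minutes in reported:
    estimate = naive_estimate_seconds(pbs)
    observed = minutes * 60
    print(f"{label}: ~{estimate:.1f}s estimated vs ~{observed}s observed "
          f"(~{observed / estimate:.0f}x gap)")
```

Even with generous error bars on those assumptions, a gap of about two orders of magnitude suggests the PBS count alone does not explain the observed times, which is why the CPU-utilization remark is worth investigating.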

bcm-at-zama commented 3 weeks ago

CC @andrei-stoian-zama @jfrery

bcm-at-zama commented 3 weeks ago

@summer-xrx: please send an email to hello@zama.ai for Benoit, and let's meet over Zoom to discuss.

summer-xrx commented 3 weeks ago

Hello, @bcm-at-zama, the complete code has been sent to hello@zama.ai. It is the "small network" with nbits=6. I'm sorry, but I'm not in a position to have a Zoom meeting with you on my end.

Looking forward to your reply, thank you!

bcm-at-zama commented 3 weeks ago

Let me have a look at your email, thanks.

bcm-at-zama commented 3 weeks ago

Why can't you use Zoom, @summer-xrx? Or maybe tell me what kind of call we could have, I'm pretty open (e.g., Google Meet). I would like to discuss the use case with you in more detail. Cheers

bcm-at-zama commented 3 weeks ago

@summer-xrx: could you resend your email, please?

bcm-at-zama commented 3 weeks ago

And when I read "78.83M", I am a bit worried: we won't be able to audit that much code. Hopefully it's just data that we don't really need, and that you could replace with random inputs / datasets?

summer-xrx commented 3 weeks ago

Hello, @bcm-at-zama, don't worry. The code is only 9 KB in total. The remaining files are datasets and our trained models, which I thought would help. In the new zip file, I have deleted the training dataset. Thanks.

[edit by @bcm-at-zama : removed the zip file]

summer-xrx commented 3 weeks ago

My email is 328474049@qq.com. As for why I can't use Zoom, there are three reasons:

  1. Because of the nature of my work, I cannot hold video meetings with foreigners.
  2. My network here is not very stable, I'm worried it will affect the quality of the meeting.
  3. My spoken English is not good.

I hope you can understand my difficulties, thank you!

bcm-at-zama commented 3 weeks ago

@summer-xrx: we can only offer support if users paste their code as text in an email, or create a private/public GitHub repo, as we cannot open external files. Thus, I deleted the zip file from your previous message.

summer-xrx commented 2 weeks ago

Hello, @bcm-at-zama, I have created a public GitHub repo named "satellite-image-classification" (https://github.com/summer-xrx/satellite-image-classification). I'm sorry for keeping you waiting so long; I was busy with work.

Looking forward to your reply, thank you!

bcm-at-zama commented 2 weeks ago

Thanks a lot, we'll have a look! It may take a bit of time, we'll keep you updated.

summer-xrx commented 2 weeks ago

Hello, may I ask if there is any new information?

bcm-at-zama commented 2 weeks ago

No, there isn't: we'll tell you when we have updates, but it may not be that soon. We have some priorities to take care of, and some people are on holiday. We'll keep you updated.

andrei-stoian-zama commented 2 weeks ago

Could you post your model's MLIR? See here how to get it: https://docs.zama.ai/concrete-ml/deep-learning/fhe_assistant#complexity-analysis

You should also check out these performance tips: https://docs.zama.ai/concrete-ml/deep-learning/optimizing_inference
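One of those tips, already used in the code above through `rounding_threshold_bits`, is to round accumulators to fewer bits before the table lookups, since a PBS over fewer bits is cheaper. The NumPy sketch below is only a conceptual model of that rounding: it keeps the `keep_bits` most significant bits of a signed `acc_bits`-bit accumulator with round-to-nearest. It is not Concrete's implementation (the `"approximate"` method used above is, as we understand the docs, a cheaper variant that tolerates slightly larger errors), and `round_to_msbs` is a hypothetical name.

```python
import numpy as np


def round_to_msbs(acc: np.ndarray, acc_bits: int, keep_bits: int) -> np.ndarray:
    """Round signed `acc_bits`-bit integers to a multiple of 2**(acc_bits - keep_bits).

    Conceptual model: after rounding, only about `keep_bits` bits of information
    remain, so a lookup table indexed by the result can be smaller.
    """
    shift = acc_bits - keep_bits
    if shift <= 0:
        return acc.copy()
    half = 1 << (shift - 1)
    # Right shift on int64 floors, so adding `half` first rounds to nearest (ties up)
    return ((acc + half) >> shift) << shift


rng = np.random.default_rng(0)
acc = rng.integers(-(2**11), 2**11, size=10_000, dtype=np.int64)  # 12-bit signed accumulators

for keep_bits in (8, 7, 6):
    rounded = round_to_msbs(acc, acc_bits=12, keep_bits=keep_bits)
    max_error = np.abs(rounded - acc).max()
    print(f"keep {keep_bits} bits: {len(np.unique(rounded))} distinct values, max error {max_error}")
```

Each bit dropped roughly halves the number of distinct values reaching the lookup and doubles the worst-case rounding error, which is the accuracy/speed trade-off visible in the nbits=6 vs nbits=7 results reported above.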