narger-ef / LowMemoryFHEResNet20

Source code for the paper "Encrypted Image Classification with Low Memory Footprint using Fully Homomorphic Encryption"
https://eprint.iacr.org/2024/460
MIT License

How to understand downsampling #20

Status: Open. Sinking-Stone opened this issue 1 week ago

Sinking-Stone commented 1 week ago

Hello, I'm very sorry to bother you again. While reading your source code, I couldn't quite understand how the downsampling is computed. Could you explain it to me? The paper only gives a brief description of it.

narger-ef commented 1 week ago

Essentially there are three phases. I'll take Ctxt FHEController::downsample1024to256(const Ctxt &c1, const Ctxt &c2) as an example:

(notice the line numbers in the image)

[Screenshot 2024-06-24 at 10.08.05]

This is needed because some slots will be empty: the previous convolution leaves blank values, since its stride is 2.
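A plaintext sketch of why a stride-2 convolution leaves gaps: a 3x3 convolution with stride {2, 2} and padding 1 gives exactly the stride-{1, 1} output sampled at even rows and even columns. So if the HE routine computes the stride-1 result, only one slot in four carries a wanted value, and a 32x32 channel (1024 slots) shrinks to 16x16 (256 slots), which is consistent with the name downsample1024to256. The helper conv2d and the 3x3/padding-1 kernel are assumptions for illustration, not code from the repository.

```python
import random


def conv2d(img, kernel, stride):
    """Plain 2D cross-correlation (what CNN layers call convolution) on a
    square input, with zero padding len(kernel) // 2."""
    n, k = len(img), len(kernel)
    pad = k // 2
    out = []
    for i in range(0, n, stride):
        row = []
        for j in range(0, n, stride):
            acc = 0.0
            for di in range(k):
                for dj in range(k):
                    r, c = i + di - pad, j + dj - pad
                    if 0 <= r < n and 0 <= c < n:
                        acc += img[r][c] * kernel[di][dj]
            row.append(acc)
        out.append(row)
    return out


random.seed(0)
N = 32
img = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(N)]
ker = [[random.uniform(-1, 1) for _ in range(3)] for _ in range(3)]

full = conv2d(img, ker, stride=1)     # stride-1 result, as an HE convolution computes it
strided = conv2d(img, ker, stride=2)  # stride-2 result, as the network needs it

# The stride-2 output is the stride-1 output at even (row, col) positions only.
kept = [[full[2 * a][2 * b] for b in range(N // 2)] for a in range(N // 2)]
assert kept == strided
print(N * N, "->", len(strided) * len(strided[0]))  # 1024 -> 256
```

The other three quarters of the stride-1 slots are exactly the "blank values" that the downsampling has to discard and close up.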

Sinking-Stone commented 6 days ago

Can you explain what line 346 in main.cpp does?
vector<Ctxt> res1sx = controller.convbn1632sx(boot_in, 4, 1, scaleSx, timing); // This is slow
vector<Ctxt> res1dx = controller.convbn1632dx(boot_in, 4, 1, scaleDx, timing); // This is slow
I can't work out what this does from this part of the code. What are sx and dx, and how can I understand them better?

narger-ef commented 6 days ago

I refer you to this, especially the image

Sinking-Stone commented 6 days ago

OK, I understand sx and dx now, but I still can't connect the downsampling picture you gave me with the code. Could you give me an example? Thank you very much.

narger-ef commented 5 days ago

The fourth block in the figure has two branches, sx and dx (Italian shorthand for sinistra and destra, i.e. left and right).

[Screenshot 2024-06-25 at 13.46.10]

Both branches perform a downsampling: the sx branch immediately, the dx branch after its first convolution.
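Why both branches need it: the two outputs are added at the end of the residual block, so they must end up with the same 16x16 spatial shape. A shape-only sketch using the standard output-size formula of a convolution; the kernel sizes and paddings (3x3 with padding 1 on dx, 1x1 with padding 0 on sx) are assumptions based on the usual ResNet-20 design, not read from the repository.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a 2D convolution (standard floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


h = 32  # input side length entering the 16 -> 32 channel block

# Hypothetical dx (main path): 3x3 stride 2, then 3x3 stride 1.
dx = conv_out(conv_out(h, 3, 2, 1), 3, 1, 1)
# Hypothetical sx (shortcut): a single 1x1 convolution with stride 2.
sx = conv_out(h, 1, 2, 0)

print(dx, sx)  # 16 16
```

Whichever kernel the shortcut actually uses, the stride of 2 is what halves the side length, so both branches go from 1024 to 256 slots per channel before the addition.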

Sinking-Stone commented 5 days ago

Yes, I already know this. What I don't understand is which lines of code correspond to the figure below.

[Image: the "re-arranging" step of the downsampling]
narger-ef commented 5 days ago

Ah, OK. The "re-arranging" figure shows the procedure performed inside the Conv2D blocks that have a stride of {2, 2}. With that stride the kernel window moves by 2 positions (not by 1, as in the previous convolutions). Since the HE convolution assumes a stride of {1, 1}, this leaves some blocks empty in our ciphertexts, so we have to re-arrange the values to fill the empty slots.
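A plaintext sketch of one way to do such a re-arranging using only operations available on a CKKS ciphertext (cyclic rotation, addition, multiplication by a 0/1 plaintext mask), for a single 32x32 channel stored row-major in 1024 slots. This is a generic log-step packing technique under stated assumptions (one channel per vector, and the slots outside the stride-2 grid are masked to zero first); it is not a transcription of downsample1024to256, whose phases and masks may differ. The names rotate, pack_even_blocks and downsample_channel are hypothetical.

```python
def rotate(vec, k):
    """Cyclic left rotation: result[i] == vec[(i + k) % len(vec)]."""
    k %= len(vec)
    return vec[k:] + vec[:k]


def pack_even_blocks(vec, unit, steps):
    """Move the `unit`-slot blocks at even block positions next to each other.

    Blocks at odd positions must already be zero. Every step costs one
    rotation, one addition and one multiplication by a 0/1 plaintext mask.
    """
    for j in range(steps):
        shift = unit << j          # rotate by unit * 2**j slots
        period = 4 * shift         # mask: 2*shift ones, then 2*shift zeros
        rotated = rotate(vec, shift)
        vec = [a + b if i % period < 2 * shift else 0
               for i, (a, b) in enumerate(zip(vec, rotated))]
    return vec


def downsample_channel(vec, side=32):
    """side x side channel (row-major) -> (side/2) x (side/2) in the first slots."""
    # Phase 0: keep only the (even row, even column) outputs of the stride-1 conv.
    vec = [v if (i // side) % 2 == 0 and (i % side) % 2 == 0 else 0
           for i, v in enumerate(vec)]
    steps = (side // 2).bit_length() - 1          # log2(16) == 4 for side == 32
    vec = pack_even_blocks(vec, 1, steps)          # even columns -> left half of each row
    vec = pack_even_blocks(vec, side, steps)       # even rows -> first half of the channel
    vec = pack_even_blocks(vec, side // 2, steps)  # drop the empty right half of each row
    return vec


side = 32
channel = [float(i) for i in range(side * side)]  # slot value = its own position
out = downsample_channel(channel, side)
print(out[:4], out[16:18])  # [0.0, 2.0, 4.0, 6.0] [64.0, 66.0]
```

Each step adds a value to a zero, so nothing is mixed; the mask only clears the copies the rotation drags along. This sketch spends 4 rotations per phase (12 per channel) and makes no attempt to merge phases or to handle several channels packed in the same ciphertext.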

Sinking-Stone commented 4 days ago

Here is my simulation of lines 1191-1194 of FHEController:

def rot(nums, index):
    # Cyclic left rotation: result[i] == nums[(i + index) % len(nums)]
    index %= len(nums)
    return nums[index:] + nums[:index]

def create_vec(n):
    # Slot i (0-based) starts out holding its own 1-based index [i + 1]
    return [[i] for i in range(1, n + 1)]

def vec_add(vec1, vec2):
    # Slot-wise addition on index lists; [0] marks an empty slot.
    # Always builds new lists: the earlier version extended lists in place, and
    # because rot() returns the same list objects, the last slot of the first
    # step became [n, 1, 2] instead of [n, 1].
    res = []
    for a, b in zip(vec1, vec2):
        if b == [0]:
            res.append(list(a))
        elif a == [0]:
            res.append(list(b))
        else:
            res.append(a + b)
    return res

def vec_mult(vec, mask):
    # Multiplication by a 0/1 plaintext mask: masked-out slots become empty ([0])
    return [v if m else [0] for v, m in zip(vec, mask)]

def gen_mask(n, size=32*32*32):
    # n ones followed by n zeros, repeated over `size` slots
    return [1 if i % (2 * n) < n else 0 for i in range(size)]

if __name__ == '__main__':

    with open('test.txt', 'w') as f:
        fullpacke = create_vec(32*32*32)
        fullpacke = vec_mult(vec_add(fullpacke, rot(fullpacke, 1)), gen_mask(2))
        f.write(str(fullpacke))
        f.write('\n------------------------------------------------------------------------------------------\n')
        fullpacke = vec_mult(vec_add(fullpacke, rot(fullpacke, 2)), gen_mask(4))
        f.write(str(fullpacke))
        f.write('\n------------------------------------------------------------------------------------------\n')
        fullpacke = vec_mult(vec_add(fullpacke, rot(fullpacke, 4)), gen_mask(8))
        f.write(str(fullpacke))
        f.write('\n------------------------------------------------------------------------------------------\n')
        fullpacke = vec_add(fullpacke, rot(fullpacke, 8))
        f.write(str(fullpacke))

Sorry for taking so much of your time. I have been simulating your program; the numbers in my code represent slot indices. Maybe I wrote something wrong, because my result is different from the diagram you gave. Could you give me an example of how the downsampling code works? Thank you very much.
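One way to read this simulation: each slot's list also contains the indices of slots that the stride-2 convolution left blank, and those contribute zero. A sketch under the assumption that, within a 32-slot row, the blank slots are the ones with an even 1-based index. It repeats the same rotations and masks on a single row, adds an extra gen_mask(16) after the last rotation to clear the right half, and then drops the blank indices. The helpers are re-implemented here (step is hypothetical), using [] instead of [0] for an empty slot.

```python
def rot(nums, k):
    """Cyclic left rotation: result[i] == nums[(i + k) % len(nums)]."""
    k %= len(nums)
    return nums[k:] + nums[:k]


def gen_mask(n, size):
    """n ones followed by n zeros, repeated over `size` slots."""
    return [1 if i % (2 * n) < n else 0 for i in range(size)]


def step(vec, k, mask):
    """vec + rot(vec, k) on index lists, then the 0/1 mask; builds new lists."""
    summed = [a + b for a, b in zip(vec, rot(vec, k))]
    return [v if m else [] for v, m in zip(summed, mask)]


size = 32                                # one row is enough to see the pattern
vec = [[i] for i in range(1, size + 1)]  # 1-based slot indices, as in the simulation
vec = step(vec, 1, gen_mask(2, size))
vec = step(vec, 2, gen_mask(4, size))
vec = step(vec, 4, gen_mask(8, size))
vec = step(vec, 8, gen_mask(16, size))   # extra mask: clears the right half

# Assumption: slots with an even 1-based index are blank after the stride-2 conv.
useful = [[i for i in v if i % 2 == 1] for v in vec]
print(useful[:16])
```

Under that assumption each of the first 16 slots ends up with exactly one useful index (1, 3, 5, ..., 31), i.e. the even columns of the row packed to the left, and the remaining 16 slots are empty. That covers only the column part of the re-arranging; in a 32x32 channel the even rows still have to be packed together in a similar way.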