ftramer / slalom

Fast, Verifiable and Private Execution of Neural Networks in Trusted Hardware
MIT License

Confusion about "Quantization" #8

Closed. gutouyu closed this issue 6 years ago

gutouyu commented 6 years ago

In the paper Slalom: Fast, Verifiable and Private Execution of Neural Networks in Trusted Hardware, Section 3.1 (Quantization):

First, it says

embed these integers in a field Zp of integers modulo a prime p.

My questions are:

  1. What does it mean to embed these integers in a field Zp, specifically? Any examples would be much appreciated.
  2. Why do we need to reduce modulo a prime p? I do not understand.

Second, about computing xw + b using FP (fixed-point):

It says we need to scale the output by 2^(-k). Why is it -k? I think it should be -2k: FP(x) * FP(w) + FP(b) = nearest-integer(2^k * x * 2^k * w + 2^(2k) * b) = nearest-integer(2^(2k) * (xw + b)).

So I think we should scale the output by 2^(-2k) to get the original output. Am I missing something?

Third, it says:

For efficiency reasons, we represent quantized inputs and weights as floating point values rather than integers.

  1. Are we only changing the type here, from integer to float? For example, for the integer 6, we use the float 6.0 to represent it?

We thus need all quantized values computed during a DNN evaluation to be bounded by 2^24

  1. What does bounded by 2^24 mean? Are all numbers smaller than 2^24? And why is it 2^24 rather than 2^32, 2^64, or something like that?

I just could not understand this concept. Still working through this awesome project. Looking forward to hearing from you. Thanks a lot.

ftramer commented 6 years ago

Great questions!

1) To apply Freivalds' algorithm to verify matrix multiplications, you have to operate over an algebraic field (https://en.wikipedia.org/wiki/Field_(mathematics)). The simplest example of a field is the "integers modulo a prime". E.g., for p=11, you work over the set of integers Z_11 = {0,1,2,...,10}, and map an arbitrary integer N to N mod p. So for instance 11 becomes 0, 13 becomes 2 (i.e., 13 = 1*11 + 2), 37 becomes 4 (i.e., 37 = 3*11 + 4), etc.

So to verify DNN computations, we need to evaluate the DNN over Z_p for some prime p. The problem is that we also need to ensure that computing the DNN over Z_p gives us the same results as computing the DNN over "regular" integers (otherwise we'd be destroying our model's accuracy). We thus choose a prime p large enough so that we never actually need to compute a modulo (i.e., all intermediate values computed by the DNN are in the interval [-p/2, p/2]). There's a toy sketch of Freivalds' check at the end of this comment.

2) We first convert all numbers in our neural network (weights and inputs) to integers, by scaling them by 2^k and then rounding. You're right that to get back to the "original" representation, you should scale the output of a linear layer by 2^{-2k}. But there are usually more layers after that, which again need their inputs to be of the form round(2^k * x). So we just rescale by 2^{-k} to keep the outputs of one layer in the format that the next layer expects for its inputs (see the second sketch at the end of this comment).

3.1) Yes, correct.

3.2) As described above, we choose a prime p so that all intermediate values in a DNN evaluation are in the interval [-p/2, p/2]. We choose p to be close to 2^24, because that is (roughly) the largest integer representable in a float, without loss of precision. In the end, we select quantization parameters (i.e., k) so that we can do all integer computations in our DNN using floats, and we never end up with a value larger than 2^24.
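For 1), here is a toy numpy sketch of Freivalds' check over Z_p with a small prime (illustrative only, not the code from this repo):

    # Freivalds' check over Z_p with a toy prime (illustrative only, not Slalom's code).
    import numpy as np

    p = 11                                 # small prime for illustration; Slalom uses a ~24-bit prime
    rng = np.random.default_rng(0)

    X = rng.integers(0, p, size=(4, 3))
    W = rng.integers(0, p, size=(3, 5))
    Y = (X @ W) % p                        # the (claimed) product we want to verify

    # Pick a random vector r and compare X @ (W @ r) with Y @ r, everything mod p.
    # This costs two matrix-vector products instead of recomputing the full product.
    r = rng.integers(0, p, size=(5,))
    lhs = (X @ ((W @ r) % p)) % p
    rhs = (Y @ r) % p
    assert np.array_equal(lhs, rhs)        # an incorrect Y passes with probability at most 1/p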
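For 2), a toy sketch of the fixed-point convention (again illustrative, not the repo's code):

    # Quantize inputs and weights by 2^k, compute a linear layer over the integers,
    # then rescale by 2^{-k} so the output is again of the form round(2^k * value).
    import numpy as np

    k = 8
    scale = 2.0 ** k
    rng = np.random.default_rng(0)

    x = rng.standard_normal(3)
    W = rng.standard_normal((3, 2))
    b = rng.standard_normal(2)

    x_q = np.round(scale * x)            # round(2^k * x)
    W_q = np.round(scale * W)            # round(2^k * W)
    b_q = np.round(scale * scale * b)    # the bias is scaled by 2^(2k) to match x_q @ W_q

    y_q = x_q @ W_q + b_q                # integer-valued result, scaled by roughly 2^(2k)
    y_next = np.round(y_q / scale)       # rescale by 2^{-k}: the round(2^k * y) format the next layer expects

    assert np.allclose(y_next / scale, x @ W + b, atol=0.1)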

gutouyu commented 6 years ago

@ftramer Your explanation is very clear. Thanks a lot. I now understand why we are doing this quantization: we want a finite field for Freivalds' algorithm. In detail, I still have a few questions, described below:

  1. "Intermediate values in a DNN evaluation are in the interval [-p/2, p/2]": we found a large prime P and take every number mod P, which gives values in [0, P). So why is it [-p/2, p/2]? Could you please explain in detail where -p/2 and p/2 come from? Thanks.
  2. "We choose p to be close to 2^24. It's roughly the largest integer representable in a float." As far as I know, float32 can represent values much larger than 2^24, up to around 10^38. I still don't understand why it's 2^24.
  3. Some confusion about the code in python/slalom/quant_layers.py:
    P = 2**23 + 2**21 + 7
    INV_P = 1.0 / P
    MID = P // 2
    assert(P + MID < 2**24)
    q = float(round(np.sqrt(MID))) + 1
    inv_q = 1.0 / q

    So this P is the large prime we choose, right? How did you arrive at 2**23 + 2**21 + 7 exactly? Also, why do we still need q and MID? What do they mean and what are they used for? And why must we guarantee P + MID < 2^24?

I really want to understand everything in this paper and the code. Looking forward to your reply. Best wishes. 😆

ftramer commented 6 years ago

1. You can view the output of a modulo as lying in the interval [0, p) or [-p/2, p/2]. They are equivalent; it is just a choice of convention. E.g., 5 mod 7 = -2 is perfectly correct. We prefer using [-p/2, p/2] as the values in the DNN can be positive or negative.

2. Yes, floats can go as high as 10^38, but with loss of precision. E.g., if you do operations with integers bigger than 2^24, the results will not necessarily be exact (there's a quick numerical check of this at the end of this comment).

3. p = 2**23 + 2**21 + 7 is the largest prime such that p + p/2 is smaller than 2^24. This is a small (and not particularly important) optimization for the privacy part. There, we end up performing operations of the form x + r mod p, where x is in [0, p) and r is in [-p/2, p/2]. So x + r can be as high as p + p/2, and we make sure this never overflows 2^24.

The q parameter is another optimization quirk. When evaluating convolutions on encrypted data on the GPU, we need to use double precision to avoid precision loss. But CUDA is extremely bad at doing double convolutions. So instead, we take the input x in [-p/2, p/2] and write it as x = x1 + q*x2, where |x1| and |x2| are smaller than sqrt(p). Then, we can compute Conv(x) = Conv(x1 + q*x2) = Conv(x1) + q*Conv(x2). So we compute two floating-point convolutions on smaller inputs rather than one double convolution on the full input (see the sketch at the end of this comment).
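To make point 2 concrete, a quick numpy check (not from the repo):

    # float32 has a 24-bit significand, so integers above 2^24 are no longer represented exactly.
    import numpy as np

    print(np.float32(2**23 + 1) == np.float32(2**23))   # False: 2^23 + 1 is still exact
    print(np.float32(2**24 + 1) == np.float32(2**24))   # True:  2^24 + 1 rounds down to 2^24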
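And for point 3, a toy sketch of the x = x1 + q*x2 split. It is illustrative only (not the repo's code): a plain matrix product stands in for the convolution, since both are linear, and the constants match quant_layers.py.

    # Split x = x1 + q*x2 so that both parts are small, then use linearity:
    # Conv(x1 + q*x2) = Conv(x1) + q*Conv(x2). A matrix product stands in for Conv here.
    import numpy as np

    P = 2**23 + 2**21 + 7
    MID = P // 2
    q = float(round(np.sqrt(MID))) + 1     # same constants as in quant_layers.py

    rng = np.random.default_rng(0)
    x = rng.integers(-MID, MID + 1, size=(4,)).astype(np.float64)
    W = rng.integers(0, 256, size=(4, 3)).astype(np.float64)

    x2 = np.floor(x / q)                   # "high" part
    x1 = x - q * x2                        # "low" part; both parts have magnitude < q
    assert np.all(np.abs(x1) < q) and np.all(np.abs(x2) < q)

    full = x @ W                           # one product on the full-range input
    split = (x1 @ W) + q * (x2 @ W)        # two products on small inputs
    assert np.allclose(full, split)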

gutouyu commented 6 years ago

Wow! It is so clear, thanks for your reply.