yhhhli / APoT_Quantization

PyTorch implementation for the APoT quantization (ICLR 2020)

The MUL unit of APOT #19

Open clevercool opened 3 years ago

clevercool commented 3 years ago

Hi,

Do you have the specific design of the MUL (Multiplication) unit for APOT quantization?

We know that uniform (integer) quantization and POT quantization are both hardware-friendly.

Assume that:

R = real number
S = scale number
T = quantized number
R1 = S1 * T1
R2 = S2 * T2

Uniform quantization simply adopts the INT MUL unit:

T1 = m
T2 = n

So, we have:

R1 * R2 = (S1 * S2) * (m * n) 

For POT:

T1 = 2^m
T2 = 2^n

So, we have:

R1 * R2 = (S1 * S2) * (2^m * 2^n) 
             =  (S1 * S2) * 2^(m + n)

A POT multiply is similar to a floating-point multiply that has only an exponent field.
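To make the two cases above concrete, here is a minimal Python sketch. It is only an illustration of the arithmetic, not the design of any real MUL unit, and the function names `uniform_mul`, `pot_mul`, and `pot_value` are made up for this example.

```python
from fractions import Fraction


def uniform_mul(m: int, n: int) -> int:
    """Uniform quantization: the quantized product is a plain integer multiply."""
    return m * n


def pot_mul(m: int, n: int) -> int:
    """POT quantization: 2^m * 2^n = 2^(m + n), so only the exponents are added."""
    return m + n


def pot_value(exp: int) -> Fraction:
    """Exact value of the POT level 2^exp (exp may be negative)."""
    return Fraction(2) ** exp


# The exponent returned by pot_mul matches the real product of the two levels.
assert pot_value(pot_mul(-1, -3)) == pot_value(-1) * pot_value(-3)
```

In both cases the scale product S1 * S2 is handled separately from the quantized part.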

However, for APOT, I have two questions about the MUL design. Each APOT value is a sum of POT terms. Assume a 4-bit APOT. The decoder table for the first two bits:

00 01 10 11
2^0 2^-1 2^-3 2^-5

And the last two bits:

00 01 10 11
2^0 2^-2 2^-4 2^-6

For the first two bits, the exponents in the decoder table are not contiguous: 0, -1, -3, -5.

Q1: How do you efficiently decode the binary code to the APOT, especially in the MUL unit?
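For reference, the most direct software reading of the two tables above is a pair of 4-entry lookup tables indexed by the upper and lower bit pairs. This is only a sketch of that reading: the names `HIGH_EXP`, `LOW_EXP`, and `decode_apot` are hypothetical, and the tables are the ones assumed in this post, not necessarily the level set used in the repository.

```python
from typing import Tuple

# Exponent tables copied from the post: index = 2-bit field value.
HIGH_EXP = (0, -1, -3, -5)  # first two bits
LOW_EXP = (0, -2, -4, -6)   # last two bits


def decode_apot(code: int) -> Tuple[int, int]:
    """Decode a 4-bit code into the exponents of its two additive POT terms."""
    if not 0 <= code <= 0b1111:
        raise ValueError("code must fit in 4 bits")
    high_field = code >> 2    # upper two bits select the first term
    low_field = code & 0b11   # lower two bits select the second term
    return HIGH_EXP[high_field], LOW_EXP[low_field]


# 0101 decodes to 2^-1 + 2^-2, and 1010 decodes to 2^-3 + 2^-4.
assert decode_apot(0b0101) == (-1, -2)
assert decode_apot(0b1010) == (-3, -4)
```

Whether such a lookup is cheap enough inside a hardware MUL unit is exactly what Q1 asks; the sketch only pins down what the decode has to produce.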

Assume two numbers in APOT:

0101: T1 = 2^-1 + 2^-2
1010: T2 = 2^-3 + 2^-4
T1 * T2 = (2^-1 + 2^-2) * (2^-3 + 2^-4) 
             = (2^-1 * 2^-3) + (2^-1 * 2^-4) + (2^-2 * 2^-3) + (2^-2 * 2^-4)
             = 2^-4 + 2^-5 + 2^-5 + 2^-6

Obviously, this calculation needs 4x (9x) as many add operations as POT in the 4-bit (6-bit) case. The result also violates the definition of APOT, in which the same additive term (here 2^-5) cannot appear twice in one number.
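The expansion above can be reproduced with a short Python sketch. It lists one exponent addition per pair of terms and then sums the partial products into a wider fixed-point accumulator by shifting. The accumulator is an assumption made for this illustration, not something taken from the paper or the repository, and `apot_mul_exponents`, `accumulate`, and `frac_bits` are hypothetical names.

```python
from itertools import product
from typing import List, Sequence


def apot_mul_exponents(t1: Sequence[int], t2: Sequence[int]) -> List[int]:
    """Exponents of all partial products: one exponent add per pair of terms."""
    return [a + b for a, b in product(t1, t2)]


def accumulate(exponents: Sequence[int], frac_bits: int = 12) -> int:
    """Sum 2^e for each exponent into an integer with frac_bits fractional bits."""
    acc = 0
    for e in exponents:
        if frac_bits + e < 0:
            raise ValueError("frac_bits is too small to represent 2^%d" % e)
        acc += 1 << (frac_bits + e)  # 2^e scaled by 2^frac_bits
    return acc


t1 = (-1, -2)  # 0101: 2^-1 + 2^-2
t2 = (-3, -4)  # 1010: 2^-3 + 2^-4

partials = apot_mul_exponents(t1, t2)
assert partials == [-4, -5, -5, -6]  # 2^-5 appears twice

# 576 / 2^12 = 9/64, which equals (2^-1 + 2^-2) * (2^-3 + 2^-4).
assert accumulate(partials) == 576
```

With two terms per operand there are 2 x 2 = 4 partial products, and with three terms per operand (the 6-bit case) there are 3 x 3 = 9, which matches the 4x (9x) count above.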

Q2: How do you deal with the complex computation and the subnormal number for APOT?

One direct solution is to convert to float with fake quantization. But does that violate the principle of quantization?

dzdang commented 2 years ago

I have the same question about APoT.

Also, when you wrote (for regular PoT)

R1 * R2 = (S1 * S2) * (2^m * 2^n) 
             =  (S1 * S2) * 2^(m + n)

Are S1, S2 integers? If not, S1 * S2 is a floating-point value, and I'm not sure how (S1 * S2) * 2^(m + n) can be done with bit-shift operations. This was one of the confusing parts of the paper for me (Eq. 4): it seems to assume the activations are uniformly quantized to integers; otherwise I'm not sure how a bit shift can be used.

Edit: I missed the part about "fixed point representation." That might enable bit shifting for this use case.
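A minimal Python sketch of how a fixed-point representation could make the shift possible. This is only my reading of the idea, not code from the paper or the repository, and `to_fixed`, `shift_by_exponent`, and the 16 fractional bits are assumptions chosen for the example.

```python
def to_fixed(x: float, frac_bits: int = 16) -> int:
    """Round a real value to an integer with frac_bits fractional bits."""
    return round(x * (1 << frac_bits))


def shift_by_exponent(value: int, exp: int) -> int:
    """Multiply a fixed-point integer by 2^exp using only a bit shift."""
    if exp >= 0:
        return value << exp
    return value >> -exp  # right shift discards the low-order bits


# Suppose S1 * S2 = 0.75 and m + n = -4.
scale = to_fixed(0.75)                 # 49152 with 16 fractional bits
result = shift_by_exponent(scale, -4)  # 0.75 * 2^-4 in fixed point

assert result == 3072
assert result / (1 << 16) == 0.046875  # equals 0.75 / 16
```

Once the scale product is stored as an integer, multiplying by 2^(m + n) no longer needs a floating-point multiplier. The cost is that a right shift can lose low-order bits.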