cpldcpu / BitNetMCU

Neural networks with low-bit weights on low-end 32-bit microcontrollers such as the CH32V003 RISC-V microcontroller and others
GNU General Public License v3.0
221 stars, 20 forks

4.6 quantization scheme #4

Status: Open. kimstik opened this issue 1 month ago.

kimstik commented 1 month ago

If you plan to develop this project further, I would like to suggest a 4.6-bit scheme (https://www.mdpi.com/2227-7390/12/5/651). I think it is an interesting scheme that fits very well on tiny devices. The (255, 3) and (31, 17) encodings look particularly interesting. Fast multiplication can then be done with just a tiny 1k LUT.

[Image: 4.6 quantization]
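To make the 1k LUT remark concrete, here is a minimal sketch. It assumes that in (31, 17) the first number counts activation levels and the second counts weight levels, both placed symmetrically around zero (-15..15 and -8..8); the paper's exact level placement may differ. With these choices every product fits in an `int8_t` and the full product table has 31 × 17 = 527 entries, so a multiplication becomes a single table lookup, which is attractive on cores without a hardware multiplier. `prod_lut`, `build_prod_lut` and `lut_mul` are hypothetical names.

```c
#include <assert.h>
#include <stdint.h>

// Assumed level sets for the (31, 17) pairing: activations -15..15, weights -8..8.
#define ACT_MAX    15
#define W_MAX      8
#define ACT_LEVELS (2 * ACT_MAX + 1)   // 31
#define W_LEVELS   (2 * W_MAX + 1)     // 17

// Every product lies in -120..120, so one int8_t per entry suffices:
// 31 * 17 = 527 bytes in total.
static int8_t prod_lut[ACT_LEVELS * W_LEVELS];

static void build_prod_lut(void) {
    for (int a = -ACT_MAX; a <= ACT_MAX; a++)
        for (int w = -W_MAX; w <= W_MAX; w++)
            prod_lut[(a + ACT_MAX) * W_LEVELS + (w + W_MAX)] = (int8_t)(a * w);
}

// Multiplication by table lookup. Operands are stored as unsigned level indices
// (a_idx = a + 15, w_idx = w + 8), so no sign handling is needed at run time.
static int8_t lut_mul(uint8_t a_idx, uint8_t w_idx) {
    static int built = 0;
    if (!built) { build_prod_lut(); built = 1; }
    assert(a_idx < ACT_LEVELS && w_idx < W_LEVELS);
    return prod_lut[a_idx * W_LEVELS + w_idx];
}
```

The (255, 3) pairing works the same way: activations -127..127 times ternary weights -1, 0, 1 give 255 × 3 = 765 one-byte entries, again with every product inside the `int8_t` range.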

cpldcpu commented 1 month ago

Quite interesting! Thank you for the pointer! 👍

It seems that this mainly pays off when the activations are quantized in the same way. Right now I usually keep the activations wide (e.g., 16 bit) and quantize only the weights.

I need to review this again later. Right now I am working on optimizing the quantization clipping.
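Quantizing the activations "in the same way" means reducing each wide activation to one of the few activation levels before the table lookup. A minimal sketch, assuming a power-of-two scale (a plain right shift with rounding) and the 31-level set -15..15; `requant31` and its `shift` parameter are hypothetical, and a real layer would derive the shift from the activation statistics.

```c
#include <assert.h>
#include <stdint.h>

// Hypothetical helper: reduce a wide (e.g. 16-bit) activation to the 31 levels -15..15
// by a right shift with round-half-away-from-zero, then clamp to the level range.
static int8_t requant31(int16_t x, unsigned shift) {
    int32_t v = x;
    int32_t half = shift ? (1 << (shift - 1)) : 0;
    // Work on the magnitude so rounding is symmetric and no negative value is shifted
    int32_t q = v >= 0 ? (v + half) >> shift : -((-v + half) >> shift);
    if (q > 15) q = 15;
    if (q < -15) q = -15;
    return (int8_t)q;
}
```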

cpldcpu commented 1 month ago

If you want to implement optimized inference code based on microkernel multiplications, just go ahead. This would already work with 4-bit quantization.

According to my observations, 4-bit quantization already works quite well if the clipping is optimized.
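A sketch of what optimizing the clipping can look like for 4-bit weights. It assumes 16 symmetric odd levels (±0.5, ±1.5, ..., ±7.5 times a step size, no zero level) and picks the clipping value that minimizes the mean squared quantization error with a simple grid search. This is one possible criterion, not necessarily the one BitNetMCU uses; `quant4`, `mse4` and `best_clip` are hypothetical names.

```c
#include <assert.h>
#include <stddef.h>

// Quantize one weight to 16 symmetric odd levels (+-0.5 ... +-7.5) * step,
// with step = clip / 8. Values beyond the clipping range map to +-7.5 * step.
static float quant4(float x, float clip) {
    float step = clip / 8.0f;
    float mag = x < 0.0f ? -x : x;
    float t = mag / step;
    int m = t >= 7.0f ? 7 : (int)t;     // clipping; truncation equals floor for t >= 0
    float q = ((float)m + 0.5f) * step;
    return x < 0.0f ? -q : q;
}

// Mean squared quantization error of n weights for a given clipping value.
static float mse4(const float *w, size_t n, float clip) {
    float err = 0.0f;
    for (size_t i = 0; i < n; i++) {
        float d = w[i] - quant4(w[i], clip);
        err += d * d;
    }
    return err / (float)n;
}

// Grid search: try clip = max_abs * k / steps for k = 1..steps and keep the best.
static float best_clip(const float *w, size_t n, float max_abs, int steps) {
    float best = max_abs, best_err = mse4(w, n, max_abs);
    for (int k = 1; k <= steps; k++) {
        float c = max_abs * ((float)k / (float)steps);
        float e = mse4(w, n, c);
        if (e < best_err) { best_err = e; best = c; }
    }
    return best;
}
```

With a single outlier among small weights, the search typically settles on a clip well below the maximum magnitude: the outlier is clipped, and the remaining weights get a finer step.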

kimstik commented 1 month ago

A little prototype for (31, 17):

```c
#include <stdint.h>

// 32 rows (5-bit activation) x 16 columns (4-bit weight) of precomputed products;
// row 16 is all zeros.
static const int8_t mtbl5_4[512] = {
 127,  110,   93,   76,   59,   42,   25,    8,   -8,  -25,  -42,  -59,  -76,  -93, -110, -127, 
 119,  103,   87,   71,   56,   40,   24,    8,   -8,  -24,  -40,  -56,  -71,  -87, -103, -119, 
 111,   96,   81,   67,   52,   37,   22,    7,   -7,  -22,  -37,  -52,  -67,  -81,  -96, -111, 
 103,   89,   76,   62,   48,   34,   21,    7,   -7,  -21,  -34,  -48,  -62,  -76,  -89, -103, 
  95,   83,   70,   57,   44,   32,   19,    6,   -6,  -19,  -32,  -44,  -57,  -70,  -83,  -95, 
  87,   76,   64,   52,   41,   29,   17,    6,   -6,  -17,  -29,  -41,  -52,  -64,  -76,  -87, 
  79,   69,   58,   48,   37,   26,   16,    5,   -5,  -16,  -26,  -37,  -48,  -58,  -69,  -79, 
  71,   62,   52,   43,   33,   24,   14,    5,   -5,  -14,  -24,  -33,  -43,  -52,  -62,  -71, 
  64,   55,   47,   38,   30,   21,   13,    4,   -4,  -13,  -21,  -30,  -38,  -47,  -55,  -64, 
  56,   48,   41,   33,   26,   19,   11,    4,   -4,  -11,  -19,  -26,  -33,  -41,  -48,  -56, 
  48,   41,   35,   29,   22,   16,   10,    3,   -3,  -10,  -16,  -22,  -29,  -35,  -41,  -48, 
  40,   34,   29,   24,   19,   13,    8,    3,   -3,   -8,  -13,  -19,  -24,  -29,  -34,  -40, 
  32,   28,   23,   19,   15,   11,    6,    2,   -2,   -6,  -11,  -15,  -19,  -23,  -28,  -32, 
  24,   21,   17,   14,   11,    8,    5,    2,   -2,   -5,   -8,  -11,  -14,  -17,  -21,  -24, 
  16,   14,   12,   10,    7,    5,    3,    1,   -1,   -3,   -5,   -7,  -10,  -12,  -14,  -16, 
   8,    7,    6,    5,    4,    3,    2,    1,   -1,   -2,   -3,   -4,   -5,   -6,   -7,   -8, 
   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0, 
  -8,   -7,   -6,   -5,   -4,   -3,   -2,   -1,    1,    2,    3,    4,    5,    6,    7,    8, 
 -16,  -14,  -12,  -10,   -7,   -5,   -3,   -1,    1,    3,    5,    7,   10,   12,   14,   16, 
 -24,  -21,  -17,  -14,  -11,   -8,   -5,   -2,    2,    5,    8,   11,   14,   17,   21,   24, 
 -32,  -28,  -23,  -19,  -15,  -11,   -6,   -2,    2,    6,   11,   15,   19,   23,   28,   32, 
 -40,  -34,  -29,  -24,  -19,  -13,   -8,   -3,    3,    8,   13,   19,   24,   29,   34,   40, 
 -48,  -41,  -35,  -29,  -22,  -16,  -10,   -3,    3,   10,   16,   22,   29,   35,   41,   48, 
 -56,  -48,  -41,  -33,  -26,  -19,  -11,   -4,    4,   11,   19,   26,   33,   41,   48,   56, 
 -64,  -55,  -47,  -38,  -30,  -21,  -13,   -4,    4,   13,   21,   30,   38,   47,   55,   64, 
 -71,  -62,  -52,  -43,  -33,  -24,  -14,   -5,    5,   14,   24,   33,   43,   52,   62,   71, 
 -79,  -69,  -58,  -48,  -37,  -26,  -16,   -5,    5,   16,   26,   37,   48,   58,   69,   79, 
 -87,  -76,  -64,  -52,  -41,  -29,  -17,   -6,    6,   17,   29,   41,   52,   64,   76,   87, 
 -95,  -83,  -70,  -57,  -44,  -32,  -19,   -6,    6,   19,   32,   44,   57,   70,   83,   95, 
-103,  -89,  -76,  -62,  -48,  -34,  -21,   -7,    7,   21,   34,   48,   62,   76,   89,  103, 
-111,  -96,  -81,  -67,  -52,  -37,  -22,   -7,    7,   22,   37,   52,   67,   81,   96,  111, 
-119, -103,  -87,  -71,  -56,  -40,  -24,   -8,    8,   24,   40,   56,   71,   87,  103,  119, 
};

// Inner loop (excerpt of the layer function): weightChunk is a uint32_t holding 8 packed
// 4-bit weights (most significant nibble first), activations_idx points to the signed
// 8-bit activations, and sum is the accumulator.
for (uint32_t j = 0; j < 8; j++) {
    int in = *activations_idx++;
    if (in != 0) { // Skip zero activations to speed up inference in layers after the first layer
        unsigned idx = (in & 0xf8) << 1;            // scale 8 bits to 5 bits: row = top 5 bits, times 16
        idx |= ((weightChunk >> (32 - 4)) & 0xf);   // the top 4-bit weight selects the column
        // Caution: row 16 is the zero row, but this index puts small positive activations
        // (1..7) in row 0, which holds the largest entries. Signed activations probably
        // need their sign bit flipped first, e.g. ((in ^ 0x80) & 0xf8) << 1.
        sum += mtbl5_4[idx & 0x1ff];
    }
    weightChunk <<= 4;                              // bring the next weight to the top
}
```

which compiles to nice unrolled code with 10 instructions per weight:

```asm
...
.L48:
    lb  t2,1(a4)
    beq t2,zero,.L49
    slli    t2,t2,1
    slli    t0,a0,4
    andi    t2,t2,496
    srli    t0,t0,28
    or  t0,t0,t2
    add t0,s0,t0
    lbu t0,0(t0)
    add a5,a5,t0
.L49:
    lb  t2,2(a4)
    beq t2,zero,.L50
    slli    t2,t2,1
    slli    t0,a0,8
    andi    t2,t2,496
    srli    t0,t0,28
    or  t0,t0,t2
    add t0,s0,t0
    lbu t0,0(t0)
    add a5,a5,t0
.L50:
...
```

I suspect it could be trimmed to 9 instructions by pre-shifting and masking the activations in ReLUNorm, but this would require widening the activations to int16_t, since the pre-shifted index (up to 0x1f0) no longer fits in 8 bits. To be tested ;)
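A sketch of the pre-shifting idea, under two assumptions that are not confirmed in the thread: the entries of mtbl5_4 follow the closed form round((16 - r) · (15 - 2c) · 127 / 240) for row r and column c (inferred from the posted values), and activations are indexed in offset binary (sign bit flipped) so that activation 0 selects the all-zero row 16. `build_tbl`, `preshift` and `dot8` are hypothetical names. Generating the table from a formula also makes it easy to try other level sets.

```c
#include <assert.h>
#include <stdint.h>

// Assumed closed form for a table laid out like mtbl5_4:
// entry(r, c) = round((16 - r) * (15 - 2c) * 127 / 240), rows 0..31, columns 0..15.
static int8_t tbl[512];

static void build_tbl(void) {
    for (int r = 0; r < 32; r++)
        for (int c = 0; c < 16; c++) {
            int n = (16 - r) * (15 - 2 * c) * 127;
            // Integer round-half-away-from-zero of n / 240; results stay within -127..127
            tbl[r * 16 + c] = (int8_t)(n >= 0 ? (n + 120) / 240 : -((-n + 120) / 240));
        }
}

// Done once per activation (e.g. in ReLUNorm): flip the sign bit (offset binary, so 0..7
// land on the all-zero row 16), keep the top 5 bits and pre-shift them into row position.
// The result needs 9 bits, hence int16_t.
static int16_t preshift(int8_t in) {
    return (int16_t)((((uint8_t)in ^ 0x80u) & 0xf8u) << 1);
}

#define ZERO_ROW_IDX 0x100   // row 16 * 16 columns

// 8 weights per 32-bit chunk, most significant nibble first. With offset-binary rows the
// table yields activation level (r - 16) times weight level (2c - 15), scaled by 127/240.
static int32_t dot8(const int16_t *pre, uint32_t weightChunk) {
    static int built = 0;
    if (!built) { build_tbl(); built = 1; }
    int32_t sum = 0;
    for (int j = 0; j < 8; j++) {
        int16_t idx = pre[j];
        if (idx != ZERO_ROW_IDX)                          // skip activations that quantize to 0
            sum += tbl[idx | (int)(weightChunk >> 28)];   // no shift/mask of the activation here
        weightChunk <<= 4;
    }
    return sum;
}
```

Compared with the posted loop, the activation no longer needs a shift and mask per weight; whether this actually saves an instruction depends on the compiler.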

kimstik commented 1 month ago

What's great about the LUT approach is that the levels do not have to be uniform, so the weight statistics can be spread out using techniques such as gamma.
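Because the multiplier is a table, the weight levels can be placed freely. One possible reading of "gamma" is a power curve applied to uniform levels, as in image gamma correction. The sketch below assumes gamma = 2, which packs the levels more densely near zero, where most weights typically sit, and keeps everything in integer arithmetic; other exponents would need `powf()`. `gamma2_level` and `gamma2_lut_entry` are hypothetical names, and the inference loop itself would stay unchanged since only the table contents differ.

```c
#include <assert.h>
#include <stdint.h>

// Hypothetical gamma-spread weight levels: start from the uniform odd levels
// u = (15 - 2c) / 15 for c = 0..15 and map them to sign(u) * |u|^2 * 127.
static int8_t gamma2_level(int c) {
    int u = 15 - 2 * c;                   // -15..15, odd
    int n = u * u * 127;                  // |u|^2 * 127, still to be divided by 15^2 = 225
    int mag = (n + 112) / 225;            // round to nearest
    return (int8_t)(u < 0 ? -mag : mag);
}

// LUT entry for activation level a (-16..15) and weight column c, scaled so |entry| <= 127.
static int8_t gamma2_lut_entry(int a, int c) {
    int n = a * gamma2_level(c);          // |n| <= 16 * 127
    return (int8_t)(n >= 0 ? (n + 8) / 16 : -((-n + 8) / 16));
}
```

With gamma = 2 the smallest nonzero level is ±1 (out of 127) instead of about ±8 for the uniform odd levels.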

By the way, are you planning to use the 2K bootloader memory? Is it as fast as the main memory?

cpldcpu commented 1 month ago

Nice!

Yes, using a LUT would allow an optimized numerical representation, such as NF4 from this paper: https://arxiv.org/pdf/2305.14314
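For reference, NF4 places its 16 levels at quantiles of a normal distribution, normalized to [-1, 1], with an exact zero. The values below are the NF4 code book from the QLoRA paper as shipped in bitsandbytes, rounded to four decimals. The absmax scaling and the int8 LUT scaling (127/16 for activation levels -16..15) are assumptions for this sketch, and `nf4_encode` and `nf4_lut_entry` are hypothetical names.

```c
#include <assert.h>
#include <stdint.h>

// NF4 code book (QLoRA / bitsandbytes), rounded to 4 decimals.
static const float nf4_levels[16] = {
    -1.0f, -0.6962f, -0.5251f, -0.3949f, -0.2844f, -0.1848f, -0.0911f, 0.0f,
    0.0796f, 0.1609f, 0.2461f, 0.3379f, 0.4407f, 0.5626f, 0.7230f, 1.0f};

// Offline: map a weight to the index of the nearest NF4 level after absmax scaling.
static uint8_t nf4_encode(float w, float absmax) {
    float x = w / absmax;
    uint8_t best = 0;
    float best_d = 1e30f;
    for (uint8_t i = 0; i < 16; i++) {
        float d = x - nf4_levels[i];
        if (d < 0.0f) d = -d;
        if (d < best_d) { best_d = d; best = i; }
    }
    return best;
}

// LUT contents for activation level a (-16..15) and NF4 index c; |entry| <= 127.
static int8_t nf4_lut_entry(int a, int c) {
    float v = (float)a * nf4_levels[c] * 127.0f / 16.0f;
    return (int8_t)(v >= 0.0f ? (int)(v + 0.5f) : -(int)(-v + 0.5f));
}
```

At inference time only the precomputed int8 entries are needed, so the non-uniform NF4 levels cost nothing extra in the inner loop.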

An alternative idea is to use exp/log, since a*b = exp(log(a) + log(b)) for positive a and b (the signs are handled separately).

Activations can be converted to logarithms during the ReLU, and weights are stored as logarithms. A table lookup is then needed only for the exp.
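A minimal integer sketch of the exp/log idea. It assumes a base-2 log code with two fractional bits (a magnitude m is stored as the code k whose value 2^(k/4) is closest to m), a separate sign, and an exp table indexed by the sum of two codes. Products of exact powers of two come out exact; all other products are coarse approximations. `build_exp_lut`, `log_encode` and `log_mul` are hypothetical names.

```c
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

// Hypothetical log-domain code: magnitude ~= 2^(code/4). Codes 0..31 cover 1 .. about 216.
#define LOG_CODES 32

// 2^(f/4) in Q8 fixed point for f = 0..3 (1.000, 1.189, 1.414, 1.682).
static const uint32_t frac_q8[4] = {256, 304, 362, 431};

// exp table indexed by the SUM of two codes, so it needs 2 * LOG_CODES entries.
static int32_t exp_lut[2 * LOG_CODES];

static void build_exp_lut(void) {
    for (int k = 0; k < 2 * LOG_CODES; k++)
        exp_lut[k] = (int32_t)(((frac_q8[k & 3] << (k >> 2)) + 128) >> 8);  // round(2^(k/4))
}

// "log": in the real scheme this runs once per value (activations during ReLU,
// weights offline), not per multiplication. Requires exp_lut to be built.
static int log_encode(int32_t mag) {
    int best = 0;
    for (int k = 1; k < LOG_CODES; k++)
        if (abs((int)(exp_lut[k] - mag)) < abs((int)(exp_lut[best] - mag)))
            best = k;
    return best;
}

// a * b ~= sign * exp(log|a| + log|b|): one addition and one table lookup, no multiply.
// (This demo encodes both operands on the fly for convenience.)
static int32_t log_mul(int32_t a, int32_t b) {
    static int built = 0;
    if (!built) { build_exp_lut(); built = 1; }
    if (a == 0 || b == 0) return 0;       // log(0) is undefined, so zero is special-cased
    int k = log_encode(a < 0 ? -a : a) + log_encode(b < 0 ? -b : b);
    int32_t r = exp_lut[k];
    return ((a < 0) != (b < 0)) ? -r : r;
}
```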

cpldcpu commented 1 month ago

I have not considered using the bootloader memory yet. Flash is generally not as fast. Ideally, the LUT should also reside in SRAM to avoid flash wait states, but I am not sure it would fit.
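Moving the table into SRAM needs no special tooling. A minimal sketch, assuming a typical GCC embedded linker script where `const` data stays in flash (.rodata) and writable statics live in RAM; `lut_flash`, `lut_sram` and `lut_to_sram` are hypothetical names, and the flash table is only a short stand-in for mtbl5_4.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define LUT_SIZE 512

// Stand-in for a flash-resident LUT such as mtbl5_4 (only the first entries given,
// the rest are zero). `const` lets the toolchain keep it in flash.
static const int8_t lut_flash[LUT_SIZE] = {127, 110, 93, 76};

// A non-const static array is placed in RAM (.bss) instead.
static int8_t lut_sram[LUT_SIZE];

// Copy once at start-up; the inner loop then indexes the returned SRAM table.
static const int8_t *lut_to_sram(void) {
    memcpy(lut_sram, lut_flash, sizeof lut_sram);
    return lut_sram;
}
```

Simply dropping `const` from an initialized table has the same effect via .data, because the start-up code copies initial values from flash to RAM. On the CH32V003, with its 2 KB of SRAM, a 512-byte table takes a quarter of the available RAM.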

cpldcpu commented 1 month ago

I have now added NF4 quantization. There is some benefit, but not much.

https://github.com/cpldcpu/BitNetMCU/blob/main/docs/documentation.md#july-26-2024-normalfloat4-nf4-quantization