kimstik opened this issue 4 months ago
Quite interesting! Thank you for the pointer! 👍
It seems that this mainly pays off when the activations are quantized the same way as well. Right now I usually keep the activations large (e.g. 16 bit) and only quantize the weights.
I need to review this again later; right now I am working on optimizing the quantization clipping.
If you want to implement optimized inference code based on microkernel multiplications, just go ahead. This would already work with 4-bit quantization.
According to my observations, 4 bit is already pretty good if the clipping is optimized.
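One simple way to optimize the clipping is a grid search over the clip threshold that minimizes the mean squared quantization error. The sketch below assumes a symmetric 4-bit grid with levels -7..7; the MSE criterion, the grid and the names `qdq4`, `mse_at_clip` and `best_clip` are illustrative assumptions, not necessarily what the project does.

```c
#include <assert.h>
#include <float.h>
#include <stddef.h>

/* Round half away from zero without pulling in libm. */
static float round_away(float x) {
    return (float)(int)(x >= 0.0f ? x + 0.5f : x - 0.5f);
}

/* Quantize-dequantize x on a symmetric 4-bit grid (levels -7..7) whose
   outermost level sits at +/-clip; values beyond the clip saturate. */
static float qdq4(float x, float clip) {
    float scale = clip / 7.0f;
    float q = round_away(x / scale);
    if (q > 7.0f) q = 7.0f;
    if (q < -7.0f) q = -7.0f;
    return q * scale;
}

/* Mean squared quantization error of n weights for one clip value. */
static double mse_at_clip(const float *w, size_t n, float clip) {
    double err = 0.0;
    for (size_t i = 0; i < n; i++) {
        double d = (double)w[i] - (double)qdq4(w[i], clip);
        err += d * d;
    }
    return err / (double)n;
}

/* Try clip = max|w| * k/steps for k = 1..steps and keep the lowest MSE.
   k == steps is plain max-abs scaling, so the result is never worse. */
static float best_clip(const float *w, size_t n, int steps) {
    assert(n > 0 && steps > 0);
    float maxabs = 0.0f;
    for (size_t i = 0; i < n; i++) {
        float a = w[i] < 0.0f ? -w[i] : w[i];
        if (a > maxabs) maxabs = a;
    }
    if (maxabs == 0.0f) return 0.0f;  /* all-zero weights: nothing to scale */
    float best = maxabs;
    double best_err = DBL_MAX;
    for (int k = 1; k <= steps; k++) {
        float clip = maxabs * ((float)k / (float)steps);
        double err = mse_at_clip(w, n, clip);
        if (err < best_err) { best_err = err; best = clip; }
    }
    return best;
}
```

With a few large outliers the search typically settles below max|w|, saturating the outliers in exchange for finer steps for the bulk of the weights.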
A little prototype for (31, 17):
```c
#include <stdint.h>

// 32 rows x 16 columns: the row is selected by the top 5 bits of the
// activation byte, the column by the 4-bit weight code.
static const int8_t mtbl5_4[512] = {
    127, 110, 93, 76, 59, 42, 25, 8, -8, -25, -42, -59, -76, -93, -110, -127,
    119, 103, 87, 71, 56, 40, 24, 8, -8, -24, -40, -56, -71, -87, -103, -119,
    111, 96, 81, 67, 52, 37, 22, 7, -7, -22, -37, -52, -67, -81, -96, -111,
    103, 89, 76, 62, 48, 34, 21, 7, -7, -21, -34, -48, -62, -76, -89, -103,
    95, 83, 70, 57, 44, 32, 19, 6, -6, -19, -32, -44, -57, -70, -83, -95,
    87, 76, 64, 52, 41, 29, 17, 6, -6, -17, -29, -41, -52, -64, -76, -87,
    79, 69, 58, 48, 37, 26, 16, 5, -5, -16, -26, -37, -48, -58, -69, -79,
    71, 62, 52, 43, 33, 24, 14, 5, -5, -14, -24, -33, -43, -52, -62, -71,
    64, 55, 47, 38, 30, 21, 13, 4, -4, -13, -21, -30, -38, -47, -55, -64,
    56, 48, 41, 33, 26, 19, 11, 4, -4, -11, -19, -26, -33, -41, -48, -56,
    48, 41, 35, 29, 22, 16, 10, 3, -3, -10, -16, -22, -29, -35, -41, -48,
    40, 34, 29, 24, 19, 13, 8, 3, -3, -8, -13, -19, -24, -29, -34, -40,
    32, 28, 23, 19, 15, 11, 6, 2, -2, -6, -11, -15, -19, -23, -28, -32,
    24, 21, 17, 14, 11, 8, 5, 2, -2, -5, -8, -11, -14, -17, -21, -24,
    16, 14, 12, 10, 7, 5, 3, 1, -1, -3, -5, -7, -10, -12, -14, -16,
    8, 7, 6, 5, 4, 3, 2, 1, -1, -2, -3, -4, -5, -6, -7, -8,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    -8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 6, 7, 8,
    -16, -14, -12, -10, -7, -5, -3, -1, 1, 3, 5, 7, 10, 12, 14, 16,
    -24, -21, -17, -14, -11, -8, -5, -2, 2, 5, 8, 11, 14, 17, 21, 24,
    -32, -28, -23, -19, -15, -11, -6, -2, 2, 6, 11, 15, 19, 23, 28, 32,
    -40, -34, -29, -24, -19, -13, -8, -3, 3, 8, 13, 19, 24, 29, 34, 40,
    -48, -41, -35, -29, -22, -16, -10, -3, 3, 10, 16, 22, 29, 35, 41, 48,
    -56, -48, -41, -33, -26, -19, -11, -4, 4, 11, 19, 26, 33, 41, 48, 56,
    -64, -55, -47, -38, -30, -21, -13, -4, 4, 13, 21, 30, 38, 47, 55, 64,
    -71, -62, -52, -43, -33, -24, -14, -5, 5, 14, 24, 33, 43, 52, 62, 71,
    -79, -69, -58, -48, -37, -26, -16, -5, 5, 16, 26, 37, 48, 58, 69, 79,
    -87, -76, -64, -52, -41, -29, -17, -6, 6, 17, 29, 41, 52, 64, 76, 87,
    -95, -83, -70, -57, -44, -32, -19, -6, 6, 19, 32, 44, 57, 70, 83, 95,
    -103, -89, -76, -62, -48, -34, -21, -7, 7, 21, 34, 48, 62, 76, 89, 103,
    -111, -96, -81, -67, -52, -37, -22, -7, 7, 22, 37, 52, 67, 81, 96, 111,
    -119, -103, -87, -71, -56, -40, -24, -8, 8, 24, 40, 56, 71, 87, 103, 119,
};

// Inner loop (excerpt): sum, weightChunk and activations_idx belong to the
// enclosing layer function; weightChunk holds 8 weights as nibbles, MSB first.
for (uint32_t j = 0; j < 8; j++) {
    int in = *activations_idx++;
    if (in != 0) { // Skip zero activations to speed up inference in layers after the first layer
        unsigned idx = (in & 0xf8) << 1;          // top 5 bits of the activation -> row offset (row * 16)
        idx |= ((weightChunk >> (32 - 4)) & 0xf); // current weight nibble -> column
        sum += mtbl5_4[idx & 0x1ff];
    }
    weightChunk <<= 4;                            // bring the next weight nibble to the top
}
```
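The table does not have to be typed in by hand. The entries above appear to follow entry(r, c) = round(127 * (16 - r)/16 * (15 - 2c)/15), rounded half away from zero, where r is the row and c the weight nibble; the formula is inferred from the listed values, and `build_mtbl5_4` is a hypothetical name. Which activation byte values should select which row depends on how the prototype encodes activations; the sketch only reproduces the stored numbers, using integer arithmetic.

```c
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

/* Row r (0..31) holds activation level (16 - r)/16, column c (0..15) holds
   weight level (15 - 2c)/15. n = 127 * (16 - r) * (15 - 2c) is exactly
   240 times the unrounded entry, so rounding needs no floating point. */
static void build_mtbl5_4(int8_t tbl[512]) {
    for (int r = 0; r < 32; r++) {
        for (int c = 0; c < 16; c++) {
            int n = 127 * (16 - r) * (15 - 2 * c);
            int q = (abs(n) + 120) / 240;   /* |n| / 240, rounded half up */
            assert(q <= 127);               /* every entry fits in int8_t */
            tbl[r * 16 + c] = (int8_t)(n < 0 ? -q : q);
        }
    }
}
```

Generating the table this way also makes it easy to try other level spacings without editing 512 literals.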
which gives nice unrolled code with 10 instructions per weight:
```asm
...
.L48:
    lb t2,1(a4)
    beq t2,zero,.L49
    slli t2,t2,1
    slli t0,a0,4
    andi t2,t2,496
    srli t0,t0,28
    or t0,t0,t2
    add t0,s0,t0
    lbu t0,0(t0)
    add a5,a5,t0
.L49:
    lb t2,2(a4)
    beq t2,zero,.L50
    slli t2,t2,1
    slli t0,a0,8
    andi t2,t2,496
    srli t0,t0,28
    or t0,t0,t2
    add t0,s0,t0
    lbu t0,0(t0)
    add a5,a5,t0
.L50:
...
```
I suspect it could be trimmed to 9 instructions by pre-shifting and masking the activations in ReLUNorm, but that would require widening the activations to int16_t. ...to be tested ;)
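A sketch of the pre-shifting idea, assuming ReLUNorm can write int16_t activations: the `(in & 0xf8) << 1` step moves out of the inner loop, so apart from loading the activation only the nibble extraction, the OR, the table load and the add remain per weight. `preshift_activation` and `dot8_preshifted` are hypothetical names, and the sketch makes no claim about the exact instruction count the compiler produces.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical ReLUNorm-side step: compute the LUT row offset once per
   activation, exactly as (in & 0xf8) << 1 does in the original inner loop. */
static int16_t preshift_activation(int8_t in) {
    return (int16_t)(((uint8_t)in & 0xf8u) << 1);
}

/* Inner loop over 8 weights packed as nibbles in weightChunk (MSB first),
   using pre-shifted activations. No zero-skip here (see note below). */
static int32_t dot8_preshifted(const int16_t *act_pre, uint32_t weightChunk,
                               const int8_t *tbl) {
    int32_t sum = 0;
    for (uint32_t j = 0; j < 8; j++) {
        uint32_t idx = (uint32_t)act_pre[j] | (weightChunk >> 28);
        assert(idx < 512);
        sum += tbl[idx];
        weightChunk <<= 4;   /* next weight nibble to the top */
    }
    return sum;
}
```

The zero-activation skip is left out on purpose: after pre-shifting, activations 1..7 also become 0, so testing the pre-shifted value would no longer be equivalent to `in != 0`. Keeping the skip would need a separate flag or a different encoding.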
What's great about the LUT approach is that the weight levels no longer have to be evenly spaced, so they can be spread to match the weight statistics using techniques such as gamma.
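Because the LUT decouples the stored 4-bit code from the value it represents, the weight columns can hold any 16 levels. The sketch below builds the same 32x16 layout from an arbitrary codebook in [-1, 1], keeping the activation rows of `mtbl5_4` (level (16 - r)/16). A gamma-style spread would be, for example, codebook[c] = sign(u) * |u|^gamma with u = (15 - 2c)/15; `build_codebook_table` and the codebook parameter are illustrative assumptions.

```c
#include <assert.h>
#include <stdint.h>

/* Round half away from zero to int8 (inputs stay within [-127, 127]). */
static int8_t round_i8(float x) {
    return (int8_t)(int)(x >= 0.0f ? x + 0.5f : x - 0.5f);
}

/* 32x16 product table for any 16-entry weight codebook in [-1, 1]:
   row r keeps the activation level (16 - r)/16, column c holds codebook[c].
   The inner loop does not change when the codebook is non-uniform. */
static void build_codebook_table(int8_t tbl[512], const float codebook[16]) {
    for (int c = 0; c < 16; c++) {
        assert(codebook[c] >= -1.0f && codebook[c] <= 1.0f);
        for (int r = 0; r < 32; r++) {
            float a = (float)(16 - r) / 16.0f;
            tbl[r * 16 + c] = round_i8(127.0f * a * codebook[c]);
        }
    }
}
```

Passing the linear levels u reproduces the values of `mtbl5_4`; with gamma = 2 the levels crowd toward zero, which is typically where most weights lie.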
By the way, are you planning to use the 2K memory of the bootloader? Is it as fast as the main memory?
Nice!
Yeah, using a LUT would allow an optimized numerical representation, like NF4 from this paper: https://arxiv.org/pdf/2305.14314
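Encoding weights for such a LUT only needs a nearest-level search. The sketch below uses the 16 NF4 levels as published for QLoRA (rounded to four decimals) with per-block absmax scaling; `nf4_encode` is a hypothetical helper, and the returned code would index the weight columns of a table built from these levels.

```c
#include <assert.h>
#include <stdint.h>

/* NF4 levels from the QLoRA paper, rounded to four decimals. */
static const float NF4[16] = {
    -1.0000f, -0.6962f, -0.5251f, -0.3949f, -0.2844f, -0.1848f, -0.0911f, 0.0000f,
     0.0796f,  0.1609f,  0.2461f,  0.3379f,  0.4407f,  0.5626f,  0.7230f, 1.0000f};

/* Returns the 4-bit code of the NF4 level nearest to w / absmax. */
static uint8_t nf4_encode(float w, float absmax) {
    assert(absmax > 0.0f);
    float x = w / absmax;
    uint8_t best = 0;
    float best_d = 3.0f;   /* larger than any possible distance */
    for (uint8_t k = 0; k < 16; k++) {
        float d = x - NF4[k];
        if (d < 0.0f) d = -d;
        if (d < best_d) { best_d = d; best = k; }
    }
    return best;
}
```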
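An alternative idea is to use exp/log: a*b = exp(log(a) + log(b)) for a, b > 0, with the signs handled separately.
The activations can be converted to the log domain during the ReLU, and the weights are stored as logarithms. Only a table lookup for the exp is then needed.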
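The exp/log idea can be sketched with a hypothetical log format: magnitude 2^(code/2) with code 0..7, plus a sign and a zero flag, since log(0) does not exist. Multiplication then reduces to adding two codes and one table lookup; the half-octave step, the Q4 output scale (16 == 1.0) and all names are assumptions chosen for illustration.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical log-domain value: magnitude = 2^(code/2), code 0..7. */
typedef struct { uint8_t zero, neg, code; } logval;

/* EXP_Q4[s] = round(16 * 2^(s/2)) for s = code_a + code_b in 0..14. */
static const int16_t EXP_Q4[15] = {
    16, 23, 32, 45, 64, 91, 128, 181, 256, 362, 512, 724, 1024, 1448, 2048};

/* a*b = exp(log a + log b): add the codes, look up the sum, fix the sign.
   The result is in Q4 (16 == 1.0). */
static int32_t log_mul_q4(logval a, logval b) {
    if (a.zero || b.zero) return 0;
    assert(a.code < 8 && b.code < 8);
    int32_t mag = EXP_Q4[a.code + b.code];
    return (a.neg ^ b.neg) ? -mag : mag;
}

/* What the ReLU step could do: map a Q4 value to the nearest code
   (nearest in the linear domain); tiny values become zero. */
static logval log_encode_q4(int32_t x) {
    logval v = {0, 0, 0};
    if (x < 0) { v.neg = 1; x = -x; }
    if (x < 8) { v.zero = 1; return v; }
    int best = 0;
    int32_t best_d = INT32_MAX;
    for (int c = 0; c < 8; c++) {
        int32_t d = x - EXP_Q4[c];
        if (d < 0) d = -d;
        if (d < best_d) { best_d = d; best = c; }
    }
    v.code = (uint8_t)best;
    return v;
}
```

Products whose log codes add up exactly (e.g. sqrt(2) * sqrt(2)) come out exact; the price is the coarse, geometric spacing of the representable values.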
I have not considered using the bootloader memory yet. The flash is generally not as fast. Ideally the LUT should also reside in SRAM to reduce wait states, but I am not sure it would fit.
I have now added NF4 quantization. There is some benefit, but not much.
If you have plans to develop this project further, I would like to suggest a 4.6-bit scheme: https://www.mdpi.com/2227-7390/12/5/651 I think it is an interesting scheme that fits tiny devices very well. The (255, 3) and (31, 17) encodings look particularly interesting. Fast multiplication can be done simply with a tiny 1k LUT.
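If (A, W) is read as the number of symmetric integer levels for activations and weights (an interpretation suggested by the level counts rather than taken from the paper), both encodings keep every product inside int8: 127 * 1 for (255, 3) and 15 * 8 = 120 for (31, 17). The sketch computes that bound and how many worst-case products an int16_t accumulator can absorb; the function names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Largest |a * w| for activations in -(A-1)/2..(A-1)/2 and weights in
   -(W-1)/2..(W-1)/2 (A and W odd). */
static int32_t max_abs_product(int32_t a_levels, int32_t w_levels) {
    assert(a_levels >= 3 && w_levels >= 3);
    assert(a_levels % 2 == 1 && w_levels % 2 == 1);
    return ((a_levels - 1) / 2) * ((w_levels - 1) / 2);
}

/* Number of worst-case products an int16_t accumulator can sum safely. */
static int32_t int16_safe_terms(int32_t a_levels, int32_t w_levels) {
    return INT16_MAX / max_abs_product(a_levels, w_levels);
}
```

Under this reading, (31, 17) allows 273 worst-case products in int16 and (255, 3) allows 258 before a wider accumulator is needed.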