Hi, I'm working with @acrrd. In general it works: all signed and unsigned byte matrix multiplications can be expressed as a signed byte matrix multiplication plus some correction terms. The main question will probably be how much work it is to support those additional terms in the assembly code. Please have a look at the following explanation.
Let the matrices be of the following shapes and data types (I'll use some pseudo notation due to the lack of LaTeX in markdown; if it helps I can also attach an image of the rendered formulas):

* Ai = (ai)_mk of shape (M, K) with signed elements ai of type i8
* Bi = (bi)_kn of shape (K, N) with signed elements bi of type i8
* Au = (au)_mk of shape (M, K) with unsigned elements au of type u8
* Bu = (bu)_kn of shape (K, N) with unsigned elements bu of type u8
We have to consider three different cases:

1) Ai * Bi is the already implemented signed byte matrix multiplication with results in i32. The i32 accumulator doesn't overflow for K < 2^17, and the quantization scale and zero-point are as usual.
2) Ai * Bu and Au * Bi are the mixed signed and unsigned byte matrix multiplications. Both can be transformed into each other via matrix transpose, as Au * Bi = ((Au * Bi)^t)^t = (Bi^t * Au^t)^t, but for the sake of completeness let's look at them individually.

2.i) Let bi := bu - 128 and, to simplify notation, only consider the K index in the following sums. Then it holds for any element of the product Ai * Bu that

sum_k ai_k * bu_k = sum_k ai_k * (bu_k - 128 + 128) = sum_k ai_k * bi_k + 128 * sum_k ai_k
where the first term is the usual signed byte matrix multiplication and the second term can be precomputed once for all rows of Ai. The i32 accumulator doesn't overflow for K < 2^16. The zero points z_Bu of Bu and z_Bi of Bi are related by the same offset, i.e. z_Bi := z_Bu - 128. With scales s_Ai of Ai and s_Bu of Bu and the usual mappings v = s_Ai * (ai - z_Ai) and w = s_Bu * (bu - z_Bu), the product can be expressed in terms of the original parameters as

sum_k ai_k * bu_k = sum_k (1/s_Ai * v_k + z_Ai) * (1/s_Bu * w_k + z_Bi) + 128 * sum_k (1/s_Ai * v_k + z_Ai)
                  = 1/(s_Ai * s_Bu) * sum_k v_k * w_k + z_Bu/s_Ai * sum_k v_k + z_Ai/s_Bu * sum_k w_k + z_Ai * z_Bu * K
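For reference, a minimal sketch (plain Rust, not tract code) that numerically checks the core 2.i identity, sum_k ai_k * bu_k = sum_k ai_k * bi_k + 128 * sum_k ai_k, on a single row/column pair:

```rust
fn main() {
    let ai: Vec<i8> = vec![-128, -1, 0, 17, 127, -64, 5, 99];
    let bu: Vec<u8> = vec![0, 255, 128, 1, 200, 77, 128, 3];

    // Left-hand side: mixed i8 x u8 product, accumulated in i32.
    let lhs: i32 = ai.iter().zip(&bu).map(|(&a, &b)| a as i32 * b as i32).sum();

    // Right-hand side: i8 x i8 product plus the precomputable correction term.
    let bi: Vec<i8> = bu.iter().map(|&b| b.wrapping_sub(128) as i8).collect();
    let signed: i32 = ai.iter().zip(&bi).map(|(&a, &b)| a as i32 * b as i32).sum();
    let row_sum: i32 = ai.iter().map(|&a| a as i32).sum();

    assert_eq!(lhs, signed + 128 * row_sum);
    println!("2.i identity holds: {} == {} + 128 * {}", lhs, signed, row_sum);
}
```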
2.ii) Similarly, let ai := au - 128. Then it holds for any element of the product Au * Bi that

sum_k au_k * bi_k = sum_k (au_k - 128 + 128) * bi_k = sum_k ai_k * bi_k + 128 * sum_k bi_k

where the first term is the usual signed byte matrix multiplication and the second term can be precomputed once for all columns of Bi. The i32 accumulator doesn't overflow for K < 2^16. The zero points z_Au of Au and z_Ai of Ai are related by the same offset, i.e. z_Ai := z_Au - 128. With scales s_Au of Au and s_Bi of Bi and the usual mappings v = s_Au * (au - z_Au) and w = s_Bi * (bi - z_Bi), the product can be expressed in terms of the original parameters as

sum_k au_k * bi_k = sum_k (1/s_Au * v_k + z_Ai) * (1/s_Bi * w_k + z_Bi) + 128 * sum_k (1/s_Bi * w_k + z_Bi)
                  = 1/(s_Au * s_Bi) * sum_k v_k * w_k + z_Bi/s_Au * sum_k v_k + z_Au/s_Bi * sum_k w_k + z_Au * z_Bi * K
3) Au * Bu is the unsigned byte matrix multiplication. Analogously, it holds for any element of the product Au * Bu that

sum_k au_k * bu_k = sum_k (au_k - 128 + 128) * (bu_k - 128 + 128)
                  = sum_k ai_k * bi_k + 128 * sum_k ai_k + 128 * sum_k bi_k + 128^2 * K
                  = sum_k ai_k * bi_k + 128 * sum_k au_k + 128 * sum_k bu_k - 128^2 * K

Depending on what is more advantageous, the second-to-last or last representation might be chosen, where the first term is the usual signed byte matrix multiplication, the second term can be precomputed once for all rows of Ai resp. Au, the third term can be precomputed once for all columns of Bi resp. Bu, and the fourth term can be precomputed once. The i32 accumulator doesn't overflow for K < 2^15. With scales s_Au of Au and s_Bu of Bu, the usual mappings v = s_Au * (au - z_Au) and w = s_Bu * (bu - z_Bu), and the shifted zero points z_Ai := z_Au - 128 and z_Bi := z_Bu - 128, the product can be expressed in terms of the original parameters as

sum_k au_k * bu_k = sum_k (1/s_Au * v_k + z_Ai) * (1/s_Bu * w_k + z_Bi) + 128 * sum_k (1/s_Au * v_k + z_Au) + 128 * sum_k (1/s_Bu * w_k + z_Bu) - 128^2 * K
                  = 1/(s_Au * s_Bu) * sum_k v_k * w_k + z_Bu/s_Au * sum_k v_k + z_Au/s_Bu * sum_k w_k + z_Au * z_Bu * K
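Again for reference, a minimal sketch (plain Rust, not tract code) checking both decompositions of the fully unsigned case 3 on a single row/column pair:

```rust
fn main() {
    let au: Vec<u8> = vec![0, 255, 128, 1, 200, 77, 128, 3];
    let bu: Vec<u8> = vec![10, 0, 255, 128, 64, 5, 99, 201];
    let k = au.len() as i32;

    // Left-hand side: u8 x u8 product, accumulated in i32.
    let lhs: i32 = au.iter().zip(&bu).map(|(&a, &b)| a as i32 * b as i32).sum();

    // Signed core product and the precomputable sums.
    let ai: Vec<i8> = au.iter().map(|&a| a.wrapping_sub(128) as i8).collect();
    let bi: Vec<i8> = bu.iter().map(|&b| b.wrapping_sub(128) as i8).collect();
    let signed: i32 = ai.iter().zip(&bi).map(|(&a, &b)| a as i32 * b as i32).sum();
    let sum_ai: i32 = ai.iter().map(|&a| a as i32).sum();
    let sum_bi: i32 = bi.iter().map(|&b| b as i32).sum();
    let sum_au: i32 = au.iter().map(|&a| a as i32).sum();
    let sum_bu: i32 = bu.iter().map(|&b| b as i32).sum();

    // Second-to-last representation (signed sums) and last one (unsigned sums).
    assert_eq!(lhs, signed + 128 * sum_ai + 128 * sum_bi + 128 * 128 * k);
    assert_eq!(lhs, signed + 128 * sum_au + 128 * sum_bu - 128 * 128 * k);
    println!("case 3 identities hold: {}", lhs);
}
```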
Yes. Thanks for taking the time to write this down. This is indeed what I had in mind for the arithmetic.
As for the assembly, it should not be (too much of) an issue. I am currently refactoring the quantization so that the logic lives in "tract-core" while "tract-linalg" just needs to support a few elementary operations (in assembly) allowing to configure a "post-multiplication" pipeline, in the spirit of gemmlowp. I already had this split inside linalg, between 1/ the assembly kernels' output pipeline and 2/ the "higher-level" quantization stuff in linalg/frame. This second part is what I am migrating to core, expressing it in regular operations to make it easier to play with the quantization scheme and take advantage of the optimiser.
This refactoring should hopefully land soon (this week or the next), once I'm happy with the performance for some in-house critical use cases. It is here for now: https://github.com/sonos/tract/tree/reexpress-quant-in-core .
Once this is done, it should be straightforward to extend quantization to support the mixed types products on top of it: it should just be a matter of translating the arithmetic to operations, as I already did for zeropoint and scale (minus some optims that are coming back this week). tract-core optimizer and linalg should do the rest. We'll nudge them a bit if we need to, of course.
How does that sound ?
Sounds great, looking forward to the finalized refactoring!
(I was pretty sure I already wrote this message, but must have gotten confused somehow.)
So the big chunk of the refactoring is done. Phewww. Still a few small performance regressions I want to address, but the branch is merged.
So as I said, it introduced a new op, QMatMul, which can have its quantization parameters static or dynamic. It re-wires itself as a MatMul plus a dozen smaller operators. If some parameters are trivial (scale = 1 or zero-point = 0), the optimizer will strip the useless stuff during declutter. During codegen, most of these extra operators amount to ops that follow the matmul and can be "fused" into the linalg matmul kernel, just as they were when the quantization logic was in linalg, so in most cases the actual tasks the linalg assembly kernels perform are identical to what they were.
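For intuition, here is a minimal scalar sketch (not the actual tract operator graph, just the arithmetic such a rewiring amounts to) of one quantized output element, with hypothetical zero points z_a, z_b and scales s_a, s_b:

```rust
// real result ≈ s_a * s_b * sum_k (a_k - z_a) * (b_k - z_b), expanded so that the
// inner sum is a plain i8 x i8 product plus correction terms that can be
// precomputed or fused after the kernel.
fn quantized_dot(a: &[i8], b: &[i8], z_a: i32, z_b: i32, s_a: f32, s_b: f32) -> f32 {
    let k = a.len() as i32;
    let acc: i32 = a.iter().zip(b).map(|(&x, &y)| x as i32 * y as i32).sum();
    let sum_a: i32 = a.iter().map(|&x| x as i32).sum();
    let sum_b: i32 = b.iter().map(|&y| y as i32).sum();
    // sum_k (a - z_a)(b - z_b) = acc - z_b * sum_a - z_a * sum_b + k * z_a * z_b
    let corrected = acc - z_b * sum_a - z_a * sum_b + k * z_a * z_b;
    s_a * s_b * corrected as f32
}

fn main() {
    let a: Vec<i8> = vec![3, -7, 12, 0];
    let b: Vec<i8> = vec![-1, 5, 9, -128];
    println!("{}", quantized_dot(&a, &b, 2, -3, 0.05, 0.1));
}
```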
Now, we have a relatively easy path for deriving the u8*i8, etc. products from the i8*i8 one. What you wrote above is, at its core, just a specific case of zero-point handling, and the current code has support for zero points, so we can piggyback on it.
Here is a possible way to implement these products:

* offset the u8 operand to i8 with |x: u8| -> x.wrapping_sub(128) as i8
* compensate for the 128 offset on the zero-point side (an add::unary() op)
* QMatMul::clamp_and_cast_to needs to be aware of the new type. Maybe there will be some adjustments in the "generic" multiplier implementation. Not too sure.

Then we may be able to ditch all the new multiplier combinations that you had to add in the linalg/core interface to plug into all the instantiations of the generic multiplier.
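A quick standalone check (not tract code) that the wrapping-sub closure really implements the bu - 128 offset assumed by the correction terms above, over the whole u8 range:

```rust
fn main() {
    // The proposed offset op: subtract 128 (wrapping) and reinterpret as i8.
    let offset = |x: u8| -> i8 { x.wrapping_sub(128) as i8 };

    for x in 0u8..=255 {
        // Value-wise this is exactly bu - 128.
        assert_eq!(offset(x) as i32, x as i32 - 128);
    }
    println!("u8 -> i8 offset matches bu - 128 for all 256 values");
}
```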
What do you think ? Feel like giving it a shot ?
Apologies for the late reply! We would definitely like to tackle this and will take it into our planning. We might just not be able to do it right away, but have to align it with our other tasks.
Good ! Feel free to plan this as you prefer, I am not aware of any other group having this specific need. Just give me a heads-up when you know more about timing so I will try to keep these sections of the code as stable as practical when you need them to be.
> Just give me a heads-up when you know more about timing
@kali we are finally starting to work on this topic
Good ! It's pretty good timing actually. We have just finished one more refactoring pass on the quantization code. Let me know how it goes and if I can help.
I just opened #522, which fixes an indexing bug we encountered during the tests for this issue. The implementation for this issue is also making good progress; we are currently running some more tests and will soon open a PR as well :)
Thanks for #527. I'm curious to know if you get the expected speed improvement.
Don't want to close yet :) we will add the quantized inception nnef example on our side... but is there a net using these features you would be willing to contribute so we can run it in the CI ? These sections have very low coverage...
Once this is done, I still want to cleanup linalg/core and remove the i8xu8 u8xi8 that we don't need anymore.
> I'm curious to know if you get the expected speed improvement.
We'll do some benchmarking next; I can give an update later on.
> but is there a net using these features you would be willing to contribute so we can run it in the CI ?
I have to check if and what model we can provide.
You can download two sets of models from this bucket.
Unfortunately, the benchmarking showed quite bad results: slower execution time for the quantized bert model instead of faster (this varied wildly with the target architecture; on my mac (intel) it was about 3 times slower, on wasm (chrome) about 2 times, on a phone (android) the same speed). We compared it to execution via the onnxruntime, which was also slower for the quantized one (although comparatively faster due to multithreading).

For further investigation we also created some dummy models with a simple matrix multiplication, also to check the overhead from dynamic quantization. In this case tract runs the quantized matmul slower than the non-quantized one, but the overhead from dynamic quantization is minor. onnxruntime runs the quantized matmul faster than the non-quantized one, but the overhead from dynamic quantization is major, leading to an overall slower result for the quantized one.
If you want to use the models in your tests, these are the inputs and outputs:

smbert / smbert-quant (onnx opset v11):
* input_ids of int64[batch,sequence]
* attention_mask of int64[batch,sequence]
* token_type_ids of int64[batch,sequence]
* output_0 of float32[batch,sequence,128]
* output_1 of float32[batch,128]

matmul, matmul-dynquant, dynquant-overhead (onnx opset v13):
* A of float32[10,128]
* C of float32[10,128]
Overall, the slow quantized execution is quite weird; we're not sure yet why. The good news is that the i8/u8 matmul offsetting actually works, though.
Interesting, I'll have a look...
Would you happen to have a working input for smbert ? I get an overrun in a gather op if I use random input, so I assume the ids have to be below some threshold that I can't guess. Ideally in .npz / .npy form... If the npy names match the input names, a bench + profile can be run with the following command, and I can take it from there...
cargo run --release -- smbert-quant.onnx --input-bundle io.npz -O dump --profile --cost
Sorry, I should have mentioned the restrictions on the values: input_ids can contain any integer from the interval [0, 100000], and attention_mask and token_type_ids can contain 0s and 1s. I pushed a valid, named, uncompressed io.npz to the bucket (it represents the tokenization of "This is a sequence.").
All right, had a look at a few things. I am assuming you're after performance on Intel (i9 9900K), but I will also use an M1 for comparison.
I9 network performance: 2.224 ms/i
M1 network performance: 1.156 ms/i
For floating point networks these two machines are often in the same ballpark.
First, operation class breakdown.
If you're curious, you can get this with:
cargo run --release -p tract -- smbert-quant.onnx --input-bundle io.npz -O dump --const --profile --cost
On an Intel i9 9900K:
* LirMatMulUnary 7 nodes: 1.093 ms/i 49.9%
* DynamicQuantizeLinearU8 8 nodes: 0.410 ms/i 18.7%
* Reduce<Sum> 11 nodes: 0.147 ms/i 6.7%
* MatMul 2 nodes: 0.118 ms/i 5.4%
* Mul 46 nodes: 0.104 ms/i 4.8%
* Sub 15 nodes: 0.067 ms/i 3.1%
* Add 20 nodes: 0.057 ms/i 2.6%
* MatMatMulPack 7 nodes: 0.037 ms/i 1.7%
* onnx.Erf 1 nodes: 0.032 ms/i 1.5%
* Reduce<Max> 1 nodes: 0.022 ms/i 1.0%
* Gather 2 nodes: 0.022 ms/i 1.0%
* Exp 1 nodes: 0.021 ms/i 1.0%
So we have roughly 58% used by the matrix product proper (LirMatMulUnary, MatMul, MatMatMulPack), whereas it's usually above 80% for NN. So we probably need to have a look at the "rest" first.
DynamicQuantizeLinearU8 seems like a good place to start; there are a few things that can be tried there, I think.

Lots of "fusable" operations in the LirMatMulUnary are not merged because a Sub gets in the way. We could implement a fused sub the same way we have a fused add (we can bench the impact by replacing the sub with an add, even if the computation will be wrong).
On an Apple M1 laptop:
* LirMatMulUnary 7 nodes: 0.310 ms/i 27.3%
* DynamicQuantizeLinearU8 8 nodes: 0.186 ms/i 16.4%
* Mul 46 nodes: 0.153 ms/i 13.5%
* Reduce<Sum> 11 nodes: 0.120 ms/i 10.6%
* Add 20 nodes: 0.082 ms/i 7.2%
* Sub 15 nodes: 0.071 ms/i 6.3%
* MatMul 2 nodes: 0.044 ms/i 3.8%
* MatMatMulPack 7 nodes: 0.032 ms/i 2.8%
* onnx.Erf 1 nodes: 0.024 ms/i 2.1%
* Reduce<Max> 1 nodes: 0.023 ms/i 2.0%
* Cast 23 nodes: 0.022 ms/i 2.0%
* Gather 2 nodes: 0.019 ms/i 1.6%
* Exp 1 nodes: 0.014 ms/i 1.2%
* MoveAxis 4 nodes: 0.012 ms/i 1.1%
The bench on the M1 (where the quantized operations are easier to optimise, because Intel is not great at them) also suggests looking at the "rest" first. 40% in products is a very atypical load for a NN.
@janpetschexain One final remark: the input you gave me "feels" short. Is it typical ? If it's not, it would be worth trying again with a more realistic example.
tract's Intel AVX2 QMatMul could probably be improved. AVX2 does not offer integer fused multiply-and-add (arm64 does). The inner loop uses vpmullw (so we promote the bytes to i16 first, then perform the products and additions as two separate ops). There is no vpmulwb (so we cannot perform the products on the bytes, then promote the words to i32 and add them).
I had a look at what ONNX Runtime does: something that may help is https://www.felixcloutier.com/x86/pmaddwd , but to make the best use of it, we need to change the packing, pairing values for two consecutive "k" together. We could go as far as https://www.felixcloutier.com/x86/pmaddubsw actually, and multiply the bytes themselves... provided they are i8 and u8 mixed, whereas we tried to standardize around i8xi8.
Generally, I suspect that on intel, with AVX2 and FMA (which is what most PCs have these days), integer operations are too partial to beat the float fma kernels... Of course, I would be happy to be proven wrong.
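To make the pmaddubsw/pmaddwd idea above concrete, here is a minimal sketch using std::arch::x86_64 intrinsics (not tract's actual kernel; the function name and packing assumptions are mine) of one accumulation step for a mixed u8 x i8 packing with consecutive "k" values paired:

```rust
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

// 32 A bytes (u8) times 32 B bytes (i8): vpmaddubsw multiplies and sums adjacent
// pairs into i16 (note: with saturation), then vpmaddwd against ones widens the
// pairs into 8 i32 lanes that get added to the accumulator.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn madd_u8_i8(acc: __m256i, a_u8: *const u8, b_i8: *const i8) -> __m256i {
    let a = _mm256_loadu_si256(a_u8 as *const __m256i);
    let b = _mm256_loadu_si256(b_i8 as *const __m256i);
    let pairs_i16 = _mm256_maddubs_epi16(a, b);
    let quads_i32 = _mm256_madd_epi16(pairs_i16, _mm256_set1_epi16(1));
    _mm256_add_epi32(acc, quads_i32)
}

fn main() {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            let a = [1u8; 32];
            let b = [-2i8; 32];
            let acc = unsafe {
                let zero = _mm256_setzero_si256();
                madd_u8_i8(zero, a.as_ptr(), b.as_ptr())
            };
            // Each of the 8 i32 lanes now holds 4 * (1 * -2) = -8.
            let lanes: [i32; 8] = unsafe { std::mem::transmute(acc) };
            assert_eq!(lanes, [-8; 8]);
        }
    }
}
```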
Thanks a lot for the detailed analysis 👍 We'll have to take this into our planning to decide how to proceed on this topic.
> One final remark: the input you gave me "feels" short. Is it typical ? If it's not, it would be worth trying again with a more realistic example.
The given sequence itself of course doesn't have much meaning, but the resulting tokenization is somewhat representative of our use cases, because the tokenization of the sequence is padded/truncated to a fixed length (64 tokens in the given io.npz) independent of the word count of the sequence. We currently use 52 tokens for one use case and 90 tokens for another. If it helps I can provide more meaningful input sequences for those sizes, but I don't expect the analysis to change much, due to the way the tokenization works and because the hidden and final states of the model are vectors of length 128 anyway.
A short update: while I was on vacation my colleagues did some profiling on mobile (android aarch64; mobile and wasm are our main targets) to check the difference from replacing the Sub operation. There are still some Sub operations left, but it should be possible to replace the Sub in the offsetting to allow for fusing operations in LirMatMulUnary again.
without replacement:
Most time consuming operations
* LirMatMulUnary 7 nodes: 2.356 ms/i 31.4%
* Mul 46 nodes: 0.871 ms/i 11.6%
* DynamicQuantizeLinearU8 8 nodes: 0.824 ms/i 11.0%
* Sub 15 nodes: 0.552 ms/i 7.4%
* Reduce<Sum> 11 nodes: 0.541 ms/i 7.2%
* Add 20 nodes: 0.537 ms/i 7.2%
* Cast 23 nodes: 0.326 ms/i 4.3%
* MatMul 2 nodes: 0.314 ms/i 4.2%
* MoveAxis 4 nodes: 0.263 ms/i 3.5%
* onnx.Erf 1 nodes: 0.223 ms/i 3.0%
* MatMatMulPack 7 nodes: 0.185 ms/i 2.5%
* Reduce<Max> 1 nodes: 0.097 ms/i 1.3%
* Gather 2 nodes: 0.083 ms/i 1.1%
* OffsetU8asI8 20 nodes: 0.069 ms/i 0.9%
* Exp 1 nodes: 0.059 ms/i 0.8%
* DequantizeLinearF32 2 nodes: 0.054 ms/i 0.7%
* AddAxis 74 nodes: 0.054 ms/i 0.7%
* Square 3 nodes: 0.051 ms/i 0.7%
* Slice 1 nodes: 0.010 ms/i 0.1%
* Reshape 4 nodes: 0.009 ms/i 0.1%
* FlippedShiftLeft 2 nodes: 0.004 ms/i 0.1%
* Tanh 1 nodes: 0.004 ms/i 0.1%
* RmAxis 5 nodes: 0.004 ms/i 0.0%
* Rsqrt 3 nodes: 0.003 ms/i 0.0%
* Source 3 nodes: 0.002 ms/i 0.0%
* Recip 1 nodes: 0.002 ms/i 0.0%
* Const 2 nodes: 0.002 ms/i 0.0%
By prefix
0.292 ms/i 3.9% 162_QuantizeLinear
0.906 ms/i 12.1% MatMul_109_quant
0.810 ms/i 10.8% MatMul_109_quant.matmul
0.312 ms/i 4.2% MatMul_29_quant
0.226 ms/i 3.0% MatMul_29_quant.matmul
0.310 ms/i 4.1% MatMul_31_quant
0.226 ms/i 3.0% MatMul_31_quant.matmul
0.308 ms/i 4.1% MatMul_33_quant
0.227 ms/i 3.0% MatMul_33_quant.matmul
0.594 ms/i 7.9% MatMul_68_quant
0.581 ms/i 7.7% MatMul_73_quant
0.308 ms/i 4.1% MatMul_85_quant
0.225 ms/i 3.0% MatMul_85_quant.matmul
1.017 ms/i 13.6% MatMul_99_quant
0.807 ms/i 10.8% MatMul_99_quant.matmul
0.111 ms/i 1.5% MatMul_99_quant_output_scale_mul
0.265 ms/i 3.5% Softmax_72
with replacement:
Most time consuming operations
* LirMatMulUnary 7 nodes: 2.364 ms/i 31.2%
* Mul 46 nodes: 0.895 ms/i 11.8%
* DynamicQuantizeLinearU8 8 nodes: 0.830 ms/i 11.0%
* Add 24 nodes: 0.671 ms/i 8.9%
* Reduce<Sum> 11 nodes: 0.546 ms/i 7.2%
* Sub 11 nodes: 0.433 ms/i 5.7%
* Cast 23 nodes: 0.331 ms/i 4.4%
* MatMul 2 nodes: 0.305 ms/i 4.0%
* MoveAxis 4 nodes: 0.263 ms/i 3.5%
* onnx.Erf 1 nodes: 0.218 ms/i 2.9%
* MatMatMulPack 7 nodes: 0.187 ms/i 2.5%
* Reduce<Max> 1 nodes: 0.098 ms/i 1.3%
* Gather 2 nodes: 0.084 ms/i 1.1%
* OffsetU8asI8 20 nodes: 0.069 ms/i 0.9%
* Exp 1 nodes: 0.066 ms/i 0.9%
* AddAxis 74 nodes: 0.057 ms/i 0.8%
* DequantizeLinearF32 2 nodes: 0.054 ms/i 0.7%
* Square 3 nodes: 0.052 ms/i 0.7%
* Slice 1 nodes: 0.010 ms/i 0.1%
* Reshape 4 nodes: 0.009 ms/i 0.1%
* FlippedShiftLeft 2 nodes: 0.005 ms/i 0.1%
* Tanh 1 nodes: 0.005 ms/i 0.1%
* RmAxis 5 nodes: 0.004 ms/i 0.1%
* Rsqrt 3 nodes: 0.003 ms/i 0.0%
* Source 3 nodes: 0.003 ms/i 0.0%
* Recip 1 nodes: 0.002 ms/i 0.0%
* Const 2 nodes: 0.002 ms/i 0.0%
By prefix
0.294 ms/i 3.9% 162_QuantizeLinear
0.911 ms/i 12.0% MatMul_109_quant
0.812 ms/i 10.7% MatMul_109_quant.matmul
0.316 ms/i 4.2% MatMul_29_quant
0.228 ms/i 3.0% MatMul_29_quant.matmul
0.311 ms/i 4.1% MatMul_31_quant
0.227 ms/i 3.0% MatMul_31_quant.matmul
0.310 ms/i 4.1% MatMul_33_quant
0.227 ms/i 3.0% MatMul_33_quant.matmul
0.593 ms/i 7.8% MatMul_68_quant
0.582 ms/i 7.7% MatMul_73_quant
0.309 ms/i 4.1% MatMul_85_quant
0.226 ms/i 3.0% MatMul_85_quant.matmul
1.022 ms/i 13.5% MatMul_99_quant
0.808 ms/i 10.7% MatMul_99_quant.matmul
0.112 ms/i 1.5% MatMul_99_quant_output_scale_mul
0.276 ms/i 3.7% Softmax_72
Not accounted by ops: 0.393 ms/i 4.9%
Entire network performance: 7.958 ms/i
Good thing that native intel is not a target, because as I was saying before, it's really hard to optimize the quantized ops for it. Glad that we don't have to consider the pmaddubsw right away.
I would be curious to know which kernel selection happens on aarch64/android. Do you get the generic one or an optimized one ? We may need to tweak the kernel selection code, as it relies on linux /proc/cpuinfo & co in the current state (not for aarch64 detection itself, but for the CPU variant; we have specific stuff for the A53). It should be obvious from the logs if you already capture tract logs.
We will do some work on optimising quantized models, and aarch64 is a critical target for us, so it is likely that there will be some improvements there. I'm actually working on generalizing the handful of fusable operations; on some i8/i32 models I've played with, I think it should yield about 10% of performance.
From the logs I think that we are using the optimized kernel on both android and ios because of the presence of FMA, but I'm not sure that is what I'm supposed to look at. I attached both logs.
Do you think we can do something to improve the performance on wasm? Maybe with the new simd instructions it will be possible to write an optimized kernel for it? Unfortunately we are not yet able to capture logs from the browser.
You can see which kernel is used by dumping with dump --info: on intel for f32, for instance, it shows * Mult: (fma 16x6) in the LirMatMulUnary blocks.
I'm not sure what the current status of simd on wasm is. Last time I checked (months, maybe years ago) I came to the conclusion that it was not ready yet. But it would definitely help. Basically, what would have to be done is to implement something equivalent to what lives in linalg/src/generic/ in a faster form.
If wasm simd intrinsics are available in Rust, it can be a way to go, and I have little doubt it will improve the performance if we do it right. If it cannot be done this way, assembly comes next, just like I did for arm and intel, in webasm+simd form. As a matter of fact, for a matrix multiplier, I'm convinced it is worth going the second way (straight to assembly) and not using intrinsics (because you get explicit control of register placement).
Checking this... https://v8.dev/features/simd .
Have you tried just activating simd on wasm ? With RUSTFLAGS="-C target-feature=+simd128" cargo build ? It may be enough for the autovectorizer to make a dent.
Also, going "full assembly" may not be the best compromise here, as, as far as i can tell, webasm does not deal directly with register alocation anyway. So if you can work in nightly rust, going the rust-with-intrinsics way looks pretty good (and i expect them to be stabilized relatively quickly in Rust anyway)
On both android and ios it seems to use arm64simd (generic); the logs are attached.
I have to rerun it to get the numbers, but we tried with wasm+simd and it was faster with it, but still slower than the original model.
The current head may be a bit better at fusing operations if you care to give it a shot. But don't expect miracles.
I'm feeling ready to close this issue. There will be more work on the i32 kernels for arm64 in a few months, but that's a different story. Any opinion ?
Can we just use the i8xi8 path by applying an offset and playing with zeropoint ?