Thanks for the info
I wonder if it could be made faster by making sure the model is in RAM?
Maybe see if subsequent runs are faster once the model is cached?
Thanks for posting this.
Just as a heads up, the RK3588 does have NPU units on it but these are not leveraged with the llama.cpp codebase (at time of writing). If other devs are interested, the NPU API for this can be found in this file: https://github.com/rockchip-linux/rknpu2/blob/master/runtime/RK3588/Linux/librknn_api/include/rknn_api.h
Note: I'm sure I've read somewhere that INT4 tensors should be supported, but I cannot see them in that API. Also, I believe the model might have to be converted to a specific RK3588 format (toolkit link in the root README.md)?
I did actually expect far better performance even with the CPUs only though with a 7B model. I notice this is an 8GB RK3588, so maybe there was a lot of memory swapping happening that slowed it down.
I don't have any chips with RK3588 yet, but if I manage to get one, I'll try to do some testing on my side. Might make great little units for running a dedicated assistant on if it can be optimized well.
If there is a specific test you want me to run, let me know!
I don't have any swap configured, regrettably. But what could easily have happened is that, because this was running literally alongside my homeserver stuff, memory management on the kernel side got quite hectic. :)
llama.cpp has also improved a lot since last time, so I might just rerun the test to see what happens. Vicuna and StableLM are a thing now too - might as well give them a shot... that said, I'd have to think of a good way to gather the output into a nice table structure, because I don't want to flood this ticket, or anyone else, with a crapton of redundant output. xD
That all said, there is one more thing:
# dmesg -l err | grep -i npu
[ 3.702909] RKNPU fdab0000.npu: can't request region for resource [mem 0xfdab0000-0xfdabffff]
[ 3.702953] RKNPU fdab0000.npu: can't request region for resource [mem 0xfdac0000-0xfdacffff]
[ 3.702978] RKNPU fdab0000.npu: can't request region for resource [mem 0xfdad0000-0xfdadffff]
[ 3.707178] debugfs: Directory 'fdab0000.npu-rknpu' with parent 'vdd_npu_s0' already present!
[ 3.729270] RKNPU fdab0000.npu: failed to find power_model node
[ 3.729289] RKNPU fdab0000.npu: RKNPU: failed to initialize power model
[ 3.729297] RKNPU fdab0000.npu: RKNPU: failed to get dynamic-coefficient
Thanks to RockChip's - at least in my experience - rather spotty documentation, I couldn't figure out whether these messages were relevant or not. It'd actually be interesting to see INT4 on this, though.
I did a quick test with this on an Orange Pi 5 16GB using a 7B Q5_1 model. My setup is a bit clunky, so I don't have a proper benchmark (will re-run and edit in next week when I'm set up better), but I'd estimate performance at almost 1 token/sec. This was using 7 threads. The heatsink became pretty hot to touch - I suspect the slower performance above might've been due to either a) memory constraints or b) thermal throttling.
Would love to see how well this could run if leveraging the NPU, but I don't think the RK SDK supports INT4 quantization yet. Basically, the RK process is that the models have to be converted into an RK-compatible format using their SDKs, so the quantization probably won't be great using that approach.
I haven't looked into whether the RK API is low-level enough that it might be able to support running GGML models yet, but that'd probably work better than using whatever quantization process the RK SDK may eventually support.
I tinkered around a bit more with this last night.
I was able to get around 500ms/token using 4 threads on a 7B Q5_1.
I also played around with the new OpenCL implementation (using CLBlast), but this was significantly slower if I transferred all layers to the GPU (> 1s/token). I don't have time to thoroughly investigate, but looking at the GGML OpenCL implementation, I suspect a lot of the slowdown might be how memory is handled.
In the OpenCL implementation, it looks like the tensors might be copied to the GPU as opposed to using a pointer to the host memory (I noticed some loops in there that do this). This makes sense for non-iGPUs (as they have their own VRAM), but probably results in unnecessary copy ops for devices with shared RAM/VRAM like the RK3588 (and AMD APUs for that matter). I believe there are flags that can be used to simply point OpenCL to host memory, but I'm unsure whether it would be compatible with the GGML tensor format. Might be a worthy optimization to consider though if it would speed up inference on AMD APUs also.
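For what it's worth, here is a minimal sketch of that idea (plain OpenCL C API, not the actual ggml-opencl code): on a shared-memory device, CL_MEM_USE_HOST_PTR lets the driver wrap existing host memory instead of allocating a separate device buffer and copying into it. Whether it truly avoids the copy depends on the driver and on the buffer's alignment, so treat this as an assumption to verify.

#include <CL/cl.h>

// Sketch only: wrap the tensor's host memory for iGPUs (zero-copy, driver
// permitting), or allocate + copy for discrete GPUs with their own VRAM.
cl_mem upload_tensor(cl_context ctx, cl_command_queue queue,
                     void *host_data, size_t nbytes, int shared_memory)
{
    cl_int err;
    if (shared_memory) {
        // Zero-copy path: the cl_mem aliases host_data instead of owning a copy.
        return clCreateBuffer(ctx, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR,
                              nbytes, host_data, &err);
    }
    // Discrete-GPU path: device allocation followed by an explicit copy.
    cl_mem buf = clCreateBuffer(ctx, CL_MEM_READ_ONLY, nbytes, NULL, &err);
    clEnqueueWriteBuffer(queue, buf, CL_TRUE, 0, nbytes, host_data,
                         0, NULL, NULL);
    return buf;
}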
Side note: I have tiny heatsinks on my Orange Pi 5. These get quite hot and I notice inference time slows down quite a bit as they heat up, so I assume the device gets underclocked to maintain safe temperatures.
If I had a 3588 I'd totally be down to fuck around with this - can anyone point me to a relatively cheap 3588 dev board?
edit: 8GB if possible
I probably can't recommend a specific board, sorry - I haven't priced them out.
Just want to add to this though - the guy that's been doing a lot of the work on the llama.cpp GPU implementations isn't sure if optimizations to the OpenCL code will yield that much benefit for boards like this. He posted the following graph yesterday indicating that the big bottleneck appears to be memory.
@spv420 here are some links - I have the Orange Pi 5 & 5B and plan on purchasing the NanoPC-T6 & Orange Pi 5 Plus as well
NanoPC T6 https://www.friendlyelec.com/index.php?route=product/product&path=69&product_id=292
Not sure if this helps the discussion. I made a fork that supports the RK3588 NPU via the matrix multiplication API. Unfortunately it is not faster than just using the CPU and generates questionable output due to running in int8 mode (FP16 is too slow).
Feel free to contribute, and see if anyone can work around the accuracy issue. I have a prototype that gets up to 10% faster by chunking operations, but it's complicated and I feel it's not worth the work if all I'm able to get is hallucinating output.
I'd love to upstream the code. Please contribute if you are also interested in the subject.
Thanks for this! I looked into it at one point too, but I think the bottleneck will be the RAM speed on the Pi 5? This approach might still be able to speed up prompt ingestion substantially though.
Do you know if using the NPU reduces power consumption? I'm an idiot and installed a tiny heatsink on my Pi 5, so it throttles very quickly.
Will try and give your fork a go next week when I get some time.
but I think the bottleneck will be the RAM speed on the Pi 5?
No, the NPU on the RK3588 is really, really bad at matrix multiplication. It's designed for vision models and is thus focused on convolution. It has pretty low FLOPS when doing matrix multiplication.
This approach might still be able to speed up prompt ingestion substantially though.
Maybe, but the inaccuracy is quite significant. I am not sure what'll happen.
Do you know if using the NPU reduces power consumption? I'm an idiot and installed a tiny heatsink on my Pi 5, so it throttles very quickly.
I think it can. But not with my backend in its current state. My backend only uses 1 thread out of all the threads given by GGML, and GGML will spin the non-working threads. It's a design flaw in GGML itself and needs a major refactor. Can't just use 1 thread either: some matrices are too large to fit on the NPU. It's possible to split the work and distribute it to different NPU cores, but it's too much work for little gain (as the model is hallucinating constantly).
To compile and run my fork: I don't recommend running more than 13 layers of a 7B model on the NPU - it starts going crazy afterwards. I develop with 10.
cmake .. -DLLAMA_RKNPU2=ON
make -j
./bin/main .... -ngl 10
Also you need a Q8_0 model. It's kinda moot for lower-bit quants since the minimum supported by the NPU is 8.
Great thread! I have a Firefly RK3588S board, so it would be great to try this out. I don't have much hope for the NPU, but am wondering if offloading matrix multiplications to the Arm Mali GPU via the Arm Compute Library might be worthwhile? Any thoughts?
@prusnak I tried something similar with GGML's OpenCL backend way back. I modified it enough to get RWKV (not llama) running on the Mali GPU. It has many problems. Mainly:
1. ARM's OpenCL implementation is buggy and doesn't play nice with GGML
2. For some reason, the OpenCL latency is very high on my OrangePi 5
3. Decompressing k-quants requires a lot of integer operations. But the Mali GPU has 1/4 the integer capacity compared to floating point.
ACL can work. But I question if it'll be helpful. GGML pre-transposes matrix B in matmul(A, B), thus the access pattern is already as efficient as it can be. IMO most OpenCL compilers can easily optimize that (need to confirm by decompiling though). After getting GGML working on Mali you'll have to choose: either don't support k-quants and run into the same accuracy vs bandwidth tradeoff as I do with the NPU, or support k-quants and make k-quant decompression fast on the Mali.
Good luck. I'd love to see more LLMs on the edge.
====
For anyone interested, a progress update on my side. With RKNPU2 1.6.0 it almost makes sense to use the NPU. I'm less than 10% off from being faster than the CPU in INT8 mode with just 1 NPU core. Next step is to debug non-square matrix multiplication. Something somewhere is wrong.
I won't update every step here. Please either follow my fork or check my blog from time to time. Latest progress: https://clehaxze.tw/gemlog/2023/12-17-update-on-ggml-rknpu2-backend-and-rknpu2-1_6_0.gmi
@marty1885 Your work is very interesting. Have you considered running Whisper models on the NPU? Could be better suited as the models are much smaller compared to 7B LLMs and would immediately have various real-world applications.
@ggerganov Thanks - that's already been done by other people. https://github.com/usefulsensors/useful-transformers runs Whisper on the NPU. They were able to do much more extensive optimizations compared to GGML though. The NPU demands a custom matrix layout for maximal performance, and they were able to eliminate the majority of layout conversions by abstracting them away.
Actually, good idea. I can try targeting my work against whisper.cpp. Do you know any use cases for it? And what would be the process to upstream an entire new backend?
From a quick look at this repo, it looks like they use the NPU just for the matrix multiplications. All other operations, such as convolutions, softmax, layernorm, etc. are on the CPU. Does the NPU API allow to implement all other ops or is it limited just to matrix multiplications?
The reason I'm wondering is that ggml currently does not provide an efficient way to run both CPU and NPU ops in a single compute graph, because the CPU threads must remain spinning while the NPU is doing stuff. So it would be much better if we could offload the entire compute onto the NPU and leave the CPU idle. Starting and stopping threads can become quite expensive, especially for smaller models, so that's why it should be avoided.
Still, if it is not possible for the NPU to do general computations, then we can perform just the heavy matrix operations in the Whisper Encoder in a similar way as we currently use BLAS. I think you've already prototyped this to a good extent in your fork. Some of the smaller matrix multiplications should probably remain on the CPU - needs experimentation.
I don't see a way around reshuffling the tensor data to fit the NPU layout. This will be some overhead that the NPU backend implementation would have to perform on the input and output data.
As long as the changes are contained as much as possible in ggml-npu.h/ggml-npu.c, it should be easy to upstream. We would need to make some basic CI, and I don't see a problem with having the backend merged, given that we see performance / energy gains.
Does the NPU API allow to implement all other ops or is it limited just to matrix multiplications?
For now it is limited to only matrix multiplications. Softmax, convolution, etc. are locked behind their ONNX compiler and are not open source.
Yeah, reordering is a major performance bottleneck right now. I hope the vendor can solve this, or at least largely mitigate it. I hope future chip designers can make data layout easy and expose more low-level API.
I'll submit a PR if I make it useful or a new SDK solves the current problems.
@marty1885 I'm in the midst of trying to reverse engineer parts of the RK3588 NPU, as I am keen to understand how the matrix multiplication is handled by the NPU to see if it could be optimised/open sourced. From your testing for fp16, do you have any insight into how large the matrices get for llama 7b? I'm assuming they can't be larger than [512x512] x [512x512] as that would already require 0.5MB of memory for the output of a single operation.
@mtx512 There are 2 kinds of matrix multiplications in llama. One for the dot-product attention. Another for token processing. I never saw the [N x N x N] matrix multiplication hit my backend. I assume either llama.cpp has a special code path to handle it, or it failed the NPU compatibility check since I only implemented matrix relayout during initialization. More likely it is the latter. I never tried to debug this since relayout is very slow and simply not worth doing on the fly.
The regular matrix multiplications on encoder/decoder weights are more like GEMV instead of GEMM. They basically have the following shape (note that in GGML's source code, src0 is matrix B for the RKNN API, and src1 is A).
batch here is the number of tokens in process. During prompt processing this is the number of tokens, up to some parameter that can be controlled by the CLI (IIRC the default max is 512), and it is 1 during text generation. Instead of optimizing for matrix multiplication, I think it'll be much more beneficial to optimize for matrix-vector multiplication if possible, since that's where the vast majority of time is spent during generation. It would also be nice if we could offload softmax from the CPU.
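To illustrate the batch point above, a toy sketch in plain C (not llama.cpp's actual kernels; names are mine): with batch == 1 the same loop degenerates into a matrix-vector product, which is why generation-time performance is dominated by GEMV rather than GEMM.

#include <stddef.h>

// y[batch x n] = x[batch x k] * W[k x n]
static void matmul_f32(const float *x, const float *W, float *y,
                       size_t batch, size_t k, size_t n)
{
    for (size_t b = 0; b < batch; ++b)        // batch == 1 => plain GEMV
        for (size_t j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (size_t i = 0; i < k; ++i)
                acc += x[b * k + i] * W[i * n + j];
            y[b * n + j] = acc;
        }
}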
Good luck! Hope you find success.
I doubt the NPU can actually run MatMul "natively" with matrix size >= 256x256. (for ONNX models, MatMul with size equal or larger than 256x256 cannot run on NPU!)
TVM has better support for the Mali GPU with OpenCL - see the MLC-LLM project. Also, I have tried running some other small models (ones that cannot run effectively on the NPU) on the GPU, and it performs pretty well.
The RKNPU2 memory allocation size limit issue has been resolved in my fork by https://github.com/happyme531/llama.cpp/commit/eaf7a1584c180a4303e69b963b6d3293c78b5b60. But after testing there are still output quality issues, even in fp16 precision. Don't know why.
@happyme531 Looks like you are right. The 1.6.0 SDK does state that the product between channels cannot be >= 65532. Maybe this is the reason - they forgot to document this limitation for the matmul API?
(For the people in this thread who can't read Chinese, trust me.) It's on page 57 of 05_RKNN_Compiler_Support_Operator_List_v1.6.0.pdf
I've merged your fix into my fork.
RKNPU2 memory allocation size limit issue have been resolved in my fork by happyme531@eaf7a15 But after testing there are still output quality issues even in fp16 precision. Don't know why.
RK3588 NPU data pointers are limited to 31:0 bits (based on TRM) hence the 4GB limit. Curious why you think it can be larger?
RK3588 NPU data pointers are limited to 31:0 bits (based on TRM) hence the 4GB limit. Curious why you think it can be larger?
Honestly, I did not know about this limit when writing the fix. No document ever mentioned it. And the resulting code runs smoothly without a single error (except the output quality issue, which has many potential causes). (Perhaps there actually isn't an issue - maybe some sort of workaround for using >4GB memory is present inside the rknn library?)
Honestly, I did not know about this limit when writing the fix. No document ever mentioned it. And the resulting code runs smoothly without a single error (except the output quality issue, which has many potential causes). (Perhaps there actually isn't an issue - maybe some sort of workaround for using >4GB memory is present inside the rknn library?)
The RKNN docs mention zero-copy APIs; for these the memory has to be compatible with the NPU, so for the RK3588 this would be a 32-bit address in physical memory. If you're providing a physical address over 4GB I'd suspect it is just truncating it to 32 bits, so using a random location. If you provide a virtual address then it has to copy the data to a physical location in the 32-bit range, hence the performance drop.
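A tiny illustration of that suspected truncation (the address value below is made up): a buffer above the 4GB line would silently alias a completely different physical page instead of producing a hard error, which would fit the "garbage output, no crash" symptom.

#include <stdint.h>
#include <stdio.h>

int main(void)
{
    uint64_t phys = 0x123456000ULL;           // a hypothetical address above 4GB
    uint32_t as_seen_by_npu = (uint32_t)phys; // what a 32-bit register would keep
    printf("cpu: %#llx  npu: %#x\n",
           (unsigned long long)phys, (unsigned)as_seen_by_npu);
    // Prints: cpu: 0x123456000  npu: 0x23456000 - a different page entirely.
    return 0;
}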
Are we certain there is a constraint on a 32-bit PHYSICAL memory address? Looking at the RK NPU API here: ... the physical address is defined as a uint64_t.
Also, regarding the FP16 constraint, is this a hardware limitation? In theory, it looks like it should be able to support 8bit.
I've yet to play with any of this though, so take the above with a grain of salt.
EDIT: Looking at that structure a bit deeper, it looks like there is a 32-bit constraint on the tensors themselves. But if these do not have to sit in (or be copied to) the first 4GB of physical memory, might it be possible - given that memory is shared - to take an approach where we process with the NPU a layer at a time?
Also, regarding the FP16 constraint, is this a hardware limitation? In theory, it looks like it should be able to support 8bit.
It's both. GGML doesn't natively do quantized inference. "Quantization" to GGML means compressing the weights, decompressing them on the fly and keeping them in cache. The decompressed result is still floating point and GGML does all its math in floating point (FP32 on CPU and optionally FP16 on GPU).
Meanwhile the NPU expects both matrices to be the same type - both FP16 or both INT8. I tried converting both the weights and the input into fixed point (INT8). It seems the network needs more accuracy than 8 bits, or else it goes crazy if too many layers are run at this very limited accuracy.
It would be perfect if RKNN could support weights in INT8/INT4 fixed point but keep inputs in FP16. But I doubt that, since the NPU is more like a fixed-pipeline GPU from the old days.
... the physical address is defined as a uint64_t.
That is just a variable to hold the value; as I mentioned, the NPU registers in the TRM dictate 32-bit data pointers, and my testing verifies that if I directly call the NPU. Note, there are a few peripherals on the RK3588, like pcie & rga2, which have the same constraints for memory (dma) buffers.
The decompressed result is still floating point and GGML does all its math in floating point (FP32 on CPU and optionally FP16 on GPU)
Not entirely accurate - on the CPU, the activations are quantized to 8 bits and then the dot products within the quantum blocks are carried out with integer arithmetic utilizing the available SIMD intrinsics. The results are then scaled to F32 and accumulated across the blocks.
Huh... for clarification: On CPU, GGML casts activation to INT8 and dots that with the compressed weight. Then back into FP16?
If that's the case then there should be no reason INT8 on the NPU would cause accuracy issues... something fishy in my code then.
something fishy in my code then.
Not sure if this has been emphasized elsewhere in RK's docs, but if you're using zero-copy, might want to double check that you have it formatted as NHWC (Number:Height:Width:Channels).
On CPU, GGML casts activation to INT8 and dots that with the compressed weight. Then back into FP16?
It's not just a cast to INT8, or fixed point - the activations are quantized to Q8_0 or another similar format, which includes a scaling factor per group of 32 weights. The dot product is performed with integer arithmetic, but converting back to FP32 requires scaling the result with the scales of the block.
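A simplified sketch of that scheme (the layout here is illustrative, not ggml's exact block_q8_0 definition, which stores the scale in fp16): each block holds 32 int8 quants plus one scale, the dot product within a block is pure integer math, and the per-block scales bring the result back to float.

#include <stdint.h>

#define QK 32

typedef struct {
    float  d;        // per-block scale
    int8_t qs[QK];   // quantized values, value ~= d * qs[i]
} block_q8;

static float dot_q8(const block_q8 *a, const block_q8 *b, int nblocks)
{
    float sum = 0.0f;
    for (int i = 0; i < nblocks; ++i) {
        int32_t isum = 0;                     // integer dot within one block
        for (int j = 0; j < QK; ++j)
            isum += (int32_t)a[i].qs[j] * (int32_t)b[i].qs[j];
        sum += a[i].d * b[i].d * (float)isum; // rescale and accumulate in F32
    }
    return sum;
}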
Also wanted to follow-up on a previous comment:
My backend only uses 1 thread out of all the threads given by GGML, and GGML will spin the non-working threads. It's a design flaw in GGML itself and needs a major refactor.
With the new ggml-backend implementation, this is no longer a limitation. The CPU parts of the graph will now be executed in separate graph splits that can utilize all CPU threads available. When a CPU split is computed, the threads are joined, and this way the next GPU or NPU split will run without any threads spinning in parallel.
The functionality is already implemented in llama.cpp and there is also a simpler GPT-2 example in the ggml repo to demonstrate that.
Exciting stuff seems to be happening at Rockchip. Their RKNPU2 1.6.1b10 contains the exact mixed-type operation we need to work around the low precision. I've tried using it, but the SDK just reports it as not implemented yet. We'll see when they formally release the next version. Hope they can also fix the low FLOPS.
Files (Password: rknn):
https://console.zbox.filez.com/l/I00fc3
Slightly off-topic: @marty1885 Would this new rknpu release also help with correctly utilizing the npu and speeding up whisper.cpp?
Generically speaking yes, getting an RKNPU backend into GGML will accelerate everything that uses GGML. But you should really use useful-transformers - they already have Whisper on the NPU. They did some deep abstraction to work around issues in RKNPU, which is impossible in GGML.
Yeah I've been trying out useful-transformers, it's pretty impressive. But the project isn't as fully-featured and as fast moving as whisper.cpp (no offense intended to them; I'm grateful for the work they've done).
@mtx512 Nice work! Wow, it's really cool.
Can you share more about the registers and their fields? I'm very interested in building a more powerful RK3588 NPU backend. Support for convolution would enable a larger (and maybe more practical, due to less overhead?) set of use cases.
Eventually I want to try and build a small compiler in GGML. Probably not as good as RKNN's, but good enough to merge common operations and enable cool networks to run. I need to build a low-level API before that.
==========
Update on my part (not to annoy everyone with a notification). I did some benchmarking of the Matmul API. As expected, the NPU is very slow at GEMV, but it can reach 1 TFLOPS on a 128x1024x8192 matrix multiplication. However this is not super useful for llama: the more relevant 1x4096x4096 case is bounded at 11 GFLOPS, and even with aggressive batching the maximal throughput is 77 GFLOPS. Likely bounded by memory.
More details on my blog. https://clehaxze.tw/gemlog/2024/02-14-benchmarking-rk3588-npu-matrix-multiplcation-performance-ep2.gmi
@marty1885 I'm still working out some of the register values - they are a bit trickier than I originally thought. Once I'm done I'll document them.
Interesting results from the benchmarks; a few comments:
@mtx512 To answer your question
I suspect the 3 cores aren't being utilised correctly, run the test against 1 core to see the difference
You are correct. RKNPU2's Matmul API only uses 1 NPU core. The driver/SDK automatically selects an idle core to use. There are options to bind the operation to a certain core, but I never found them helpful by the looks of it.
Change your test to send in parallel to the 3 cores
Certainly! I've updated the benchmark to use 2 and 3 threads, each calling the matrix multiplication API. Seems the FLOPS doesn't scale linearly; there are diminishing returns.
2 NPU cores:
3 NPU cores:
How about trying 4096x4096x1
I can't. The SDK requires the last dimension to be a multiple of 8 for FP16 and 32 for INT8. Only the 1st dimension is totally free.
/*
  matmul information struct
*/
typedef struct rknn_matmul_info_t
{
    int32_t M;
    int32_t K; // limit: RK3566/3568: int8 type must be aligned with 32byte, float16 type must be aligned with 16byte;
               //        RK3562: int8 type must be aligned with 32byte, float16 type must be aligned with 32byte;
               //        RK3588: int8 type must be aligned with 32byte, float16 type must be aligned with 32byte,
               //                int4 type must be aligned with 32byte;
    int32_t N; // limit: RK3566/3568: int8 type must be aligned with 16byte, float16 type must be aligned with 8byte;
               //        RK3562: int8 type must be aligned with 16byte, float16 type must be aligned with 8byte;
               //        RK3588: int8 type must be aligned with 32byte, float16 type must be aligned with 16byte,
               //                int4 type must be aligned with 64byte;
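Reading those comments, a small helper like the following could compute the padded K/N a caller would need on RK3588 (my sketch; the byte-alignment constants are my interpretation of the limits quoted above, not values taken from the SDK headers):

#include <stdint.h>

static inline int32_t round_up(int32_t x, int32_t multiple)
{
    return ((x + multiple - 1) / multiple) * multiple;
}

// elem_size: 1 for int8, 2 for fp16
static inline int32_t pad_k_rk3588(int32_t k, int32_t elem_size)
{
    return round_up(k, 32 / elem_size);   // K aligned to a 32-byte boundary
}

static inline int32_t pad_n_rk3588(int32_t n, int32_t elem_size)
{
    int32_t align_bytes = (elem_size == 1) ? 32 : 16;  // int8: 32B, fp16: 16B
    return round_up(n, align_bytes / elem_size);       // fp16 => multiple of 8
}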
====
Edit: I just realized that the NPU driver is surprisingly readable. But it contains next to nothing about how to control the NPU: https://github.com/rockchip-linux/kernel/tree/develop-5.10/drivers/rknpu
@marty1885 Thank you for testing. Given that I have managed to decode more of the NPU behaviour, I think I can partly explain why you're seeing this. The convolution core relies on a cache (CBuf) to temporarily store feature + weight data while performing a convolution. For your test data, let's assume the maximum weight cache size is configured at 360KB (roughly). As you increase N and K the number of weight entries increases, resulting in more hits to repopulate the cache from RAM to complete the convolution. The table below (for fp16) demonstrates how the number of CBuf refreshes increases as N and K go beyond the cache limit.
N     K     time (ms)  No. of CBuf refreshes
256   256   0.09        1
512   512   0.09        2
1024  1024  0.27        6
2048  2048  0.83       24
4096  4096  3.08       94
There are a few other variables that influence the time taken, for example after completing convolution for a single CBuf the resulting output needs to be written back to RAM.
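A rough model of that table (my approximation, treating the refresh count as the FP16 weight size divided by the ~360KB figure above, rounded up): it reproduces the small cases exactly and lands within a few refreshes for the larger ones, where the extra bookkeeping traffic mentioned above presumably accounts for the difference.

#include <math.h>
#include <stdio.h>

int main(void)
{
    const double cbuf_bytes = 360.0 * 1024;          // approximate weight cache size
    const int dims[] = {256, 512, 1024, 2048, 4096}; // N == K, as in the table
    for (int i = 0; i < 5; ++i) {
        double weight_bytes = (double)dims[i] * dims[i] * 2;  // fp16 weights
        printf("N=K=%4d -> ~%.0f refreshes\n",
               dims[i], ceil(weight_bytes / cbuf_bytes));
    }
    // Prints roughly 1, 2, 6, 23, 92 vs the measured 1, 2, 6, 24, 94.
    return 0;
}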
@mtx512 Thanks for the explanation. Things make much more sense now. I've been toying with ideas for more efficient matrix-vector multiplication, but it's all very hand-wavy as I don't understand how the NPU works at a low level. I have some questions:
- What is causing the low FLOPS when m == 1?

Someone just told me there's a Mesa fork that is also attempting to add RK3588 support. There's not much actual code there yet. But worth keeping an eye on.
https://gitlab.freedesktop.org/tomeu/mesa/-/tree/rknpu?ref_type=heads
I'm in touch with Tomeu as he wanted me to help test/develop the driver as part of the reverse engineering effort.
- What is causing the low FLOPS when m == 1?
Not sure what you mean - m == 1 should be quicker than m == 128? Regardless, CBuf still comes into the picture, and increasing M also affects the number of CBuf refreshes, and potentially more tasks are needed, e.g.:
M    N     K     time (ms)
1    1024  1024  0.27
128  1024  1024  0.45
1    1024  8192  1.58
128  1024  8192  5.95
The 2 TOPS figure isn't particularly useful as it just quotes the number of MAC operations the convolution engine can complete based on the clock. As you're now aware, there are other operations that need to be performed, i.e. reading feature/weight data + writing the output result, so you can't expect the NPU to perform at 2 TOPS.
(Cont.) Do you see a way to optimize this specific case? Or is Rockchip's implementation already optimal? I've some toy ideas:
- E.g. we reorder the matrix in some way and create a layout that is easy for the NPU to compute and that generates a partial sum, then use the CPU to merge the partial sums
- In a 1x4096x4096 matrix multiplication, we load the [1x4096] vector into CBuf, then continuously stream the weights into the NPU, but never write the results back into DRAM before we are finished
- etc.
Not much you can do, it's all implemented in the NPU silicon. For 1x4096x4096, the feature data 1x4096 is always kept in the CBuf and it is just the weight data that is refreshed. The NPU should be using DMA for reads/writes, so it is attempting to be as efficient as possible, however there is still a cost to performing reads/writes. There is 1MB of SRAM on the RK3588 which should be slightly faster than RAM, however it's not enough to hold the weight data. For this use case the CBuf would need to be 20MB (ideally greater) and directly accessible to the application for reads/writes.
To speed things up, potential options:
- Is CBuf a scratchpad or cache?
It's a scratchpad (aka Convolution buffer) memory which is part of the NPU hardware and inaccessible to the application.
- Does the content of CBuf persist across tasks?
No, it's populated on a convolution operation.
!! Rockchip just released its official LLM inference library: RKLLM https://github.com/airockchip/rknn-llm
I have done a quick test on my RK3588 dev board (with LPDDR4X-4266 64-bit DRAM); the result is:
Model: qwen1.5-4b-chat
dtype: w8a8 (the only quant dtype currently supported with rk3588) - model size: 4.1GB
max_context_len=512 (seems can be set higher)
3x NPU cores working together
Performance: prefill: 55 tokens/s, decode: 3.9 tokens/s
CPU load is very low. GPU load is zero. So you can still run other applications while doing LLM inference.
NPU load is 3x ~47%
DMC (DDR controller) load is ~48%. (I don't know exactly what this load means - probably 16GB/s out of a 32GB/s max?)
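A quick sanity check of that guess (my arithmetic, assuming every decoded token streams the full w8a8 weight set from DRAM):

#include <stdio.h>

int main(void)
{
    double tokens_per_s = 3.9;   // decode speed reported above
    double model_gb     = 4.1;   // w8a8 model size reported above
    // ~16 GB/s of weight traffic, which lines up with the ~48% DMC load
    // against a ~32 GB/s LPDDR4X peak guessed in the post.
    printf("~%.0f GB/s of weight traffic\n", tokens_per_s * model_gb);
    return 0;
}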
Some notes:
diff --git a/include/linux/mm.h b/include/linux/mm.h
index dfefcfa1d6a4..c48e937b5d5d 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -3404,5 +3404,20 @@ static inline int seal_check_future_write(int seals, struct vm_area_struct *vma)
return 0;
}
+static inline void vm_flags_set(struct vm_area_struct *vma,
+static inline void vm_flags_clear(struct vm_area_struct *vma,
@happyme531 Cool! How is the text synthesis quality?
Just did a very simple run with llama-7b-4bit. It... took a while. Had it run in a screen. But, it worked!
Model was loaded from external microSD via internal bus.
I'm quite amazed this worked at all, honestly.
CPU Info in detail:
(/proc/cpuinfo doesn't give any more useful details here, sadly.)
Hardware is a FriendlyElec NanoPi R6s.