Closed: saharNooby closed this issue 11 months ago.
> On my machine, offloading the head of a 14B model in addition to offloading all layers gives 60 ms per-token latency vs 70 ms without head offloading.
You're right, there is a significant speedup here. I just tested this on my machine and I get the same sort of speedup: processing 128 tokens in sequence mode goes from, funnily enough, 70 ms to 60 ms if I offload the model head.
I just slapped this

```c
if (i == 0) {
    // Also offload the head matrix once the first layer is offloaded.
    offload(ctx->instance->model.head);
}
```

into the `rwkv_gpu_offload_layers` loop in order to test. It does literally shave 10 ms. It also shaves nearly 50% off evaluating the same sequence in serial mode.
This ends up offloading the head automatically once the first layer is offloaded, which might be the best approach with the current API.
This makes me want to introduce an optimization where passing NULL logits to `rwkv_eval` does not evaluate logits at all. This creates further speedups (by 20% when evaluating 128 tokens using serial mode).
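For illustration, here is a minimal caller-side sketch of what that optimization would enable, assuming the current `rwkv_eval(ctx, token, state_in, state_out, logits_out)` signature and that a NULL logits pointer would simply skip the head matmul; the buffer-length helper names are assumptions and vary between rwkv.cpp versions.

```c
// Sketch: feed a prompt token by token, asking for logits only on the last token.
// Assumes rwkv_eval() would treat a NULL logits pointer as "skip logits evaluation".
float * state  = malloc(rwkv_get_state_len(ctx)  * sizeof(float)); // helper name assumed
float * logits = malloc(rwkv_get_logits_len(ctx) * sizeof(float)); // helper name assumed

for (size_t i = 0; i < n_prompt_tokens; i++) {
    const bool is_last = i + 1 == n_prompt_tokens;

    rwkv_eval(
        ctx,
        prompt_tokens[i],
        i == 0 ? NULL : state,   // NULL state_in = start from the default initial state
        state,
        is_last ? logits : NULL  // proposed: NULL logits_out skips the head entirely
    );
}
```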
> This ends up offloading the head automatically once the first layer is offloaded, which might be the best approach with the current API.
I implemented it a little differently -- if the last layer was just offloaded, then `head` is offloaded too. Or, we can treat `head` as a separate "layer" (AFAIK `llama.cpp` treats `embedding`/`head` like this) -- if a model has 24 actual layers, and we offload 25, then `head` is offloaded too.
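For concreteness, a rough sketch of the "head as an extra layer" convention (not the actual rwkv.cpp implementation; the `offload()`/`offload_layer()` helpers and the `n_layer` field location are assumptions for illustration):

```c
// Sketch only: interpret "offload n_layer + 1 layers" as "also offload the head".
bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) {
    const uint32_t n_actual = ctx->instance->model.header.n_layer; // field name assumed

    for (uint32_t i = 0; i < n_actual && i < n_layers; i++) {
        offload_layer(ctx, i); // offload this layer's matrices to the GPU
    }

    // Requesting one more layer than the model has offloads the head matrix too.
    if (n_layers > n_actual) {
        offload(ctx->instance->model.head);
    }

    return true;
}
```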
> This makes me want to introduce an optimization where passing NULL logits to `rwkv_eval` does not evaluate logits at all. This creates further speedups (by 20% when evaluating 128 tokens using serial mode).
Sounds great, I support it!
Currently, only the matrices of the layers are offloaded to the GPU. The head, the biggest matrix in the model, stays on the CPU and is evaluated there.

On my machine, offloading the head of a 14B model in addition to offloading all layers gives 60 ms per-token latency vs 70 ms without head offloading.

As always, the hardest question here is API design -- we need to preserve compatibility and not inflate the API with new small functions.
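Whichever variant lands, the public API can stay as-is. Hypothetically (the layer-count getter name is an assumption), offloading everything including the head would look like:

```c
// Existing call sites keep working: no new API functions are needed.
// Under the "head as an extra layer" convention, requesting one layer more
// than the model has would offload the head as well.
const uint32_t n_layer = rwkv_get_n_layer(ctx); // getter name is an assumption
rwkv_gpu_offload_layers(ctx, n_layer + 1);
```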