turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

optimization: reduce GPU transfers #456

Closed · kallewoof closed this 1 month ago

kallewoof commented 1 month ago

This PR adds a ParamCache to the qparams file, which manages a cache of nn.Parameter(option[idx].weight.cuda()) elements.
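
A minimal sketch of the caching idea, assuming the cache is keyed per option and index; everything here other than the nn.Parameter(option[idx].weight.cuda()) expression from the description above is an illustrative stand-in, not the PR's actual API:

```python
import torch
import torch.nn as nn

class ParamCache:
    def __init__(self):
        self._cache = {}

    def get(self, option, idx):
        # Reuse the GPU-resident copy if this weight was already uploaded,
        # instead of transferring it from system RAM again.
        key = (id(option), idx)
        if key not in self._cache:
            self._cache[key] = nn.Parameter(option[idx].weight.cuda())
        return self._cache[key]

    def clear(self):
        # Drop cached parameters to release VRAM between modules.
        self._cache.clear()
        torch.cuda.empty_cache()
```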

It also adds a new batch_test_error function which takes a model factory and factory parameters and processes each layer in turn for all model variations, i.e. it inverts the current loop order (sketched below).
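
A rough, hypothetical sketch of that loop inversion; the helper name measure_variant_error and the signatures are made up for illustration and are not the PR's actual code:

```python
def measure_variant_error(variant, layer):
    # Stand-in for the real per-layer error measurement.
    return 0.0

def batch_test_error(model_factory, factory_params, layers):
    # Outer loop over layers, inner loop over model variations,
    # so per-layer state only needs to be moved to the GPU once.
    errors = {params: [] for params in factory_params}
    for layer in layers:
        for params in factory_params:
            variant = model_factory(params)
            errors[params].append(measure_variant_error(variant, layer))
    return errors
```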

This gives a 30% increase in speed for the attention layer, at a cost of some additional VRAM usage.

However, applying this method to the measure_mlp function does not give significant speed-ups, so the VRAM bump is not worth it. I'm still investigating why this is the case.

Supersedes (and incorporates) #455.

This PR primarily targets the calibration feature, but the .item() tweak to test_error snuck its way into the quantization part. I can remove that if needed. I do intend to address quantization optimizations as well.

==================================================================

                               attn

==================================================================

BASE MASTER:

VRAM PEAK = 2715 MB

--------------------------------------------
| Measured: model.layers.0 (Attention)     |
| Duration: 96.49 seconds                  |
| Completed step: 1/163                    |
| Avg time / step (rolling): 96.49 seconds |
| Estimated remaining time: 260min 31sec   |
| Last checkpoint layer: None              |
--------------------------------------------

------------------------------------------------------------------

MASTER + PARAMCACHE + BATCHING

VRAM PEAK = 6200 MB

--------------------------------------------
| Measured: model.layers.0 (Attention)     |
| Duration: 60.14 seconds                  |
| Completed step: 1/163                    |
| Avg time / step (rolling): 60.14 seconds |
| Estimated remaining time: 162min 23sec   |
| Last checkpoint layer: None              |
--------------------------------------------
turboderp commented 1 month ago

I did at one point keep the entire calibration state in VRAM, which was significantly faster, but it also made it impossible to quantize large models. Then I made it switchable, causing great confusion, and finally I arrived at always swapping as the best compromise between reliability and performance.

This is also in part because, as you've noted, PCIe traffic matters much less when processing the MLP modules, and they account for most of the time spent in measurement. It's a real struggle to get a large enough chunk of the full-precision model to fit sometimes.

One thing I've considered is a better pipeline for swapping. Currently all data is swapped synchronously, but it could be asynchronous via a pinned buffer. It could maybe work in conjunction with this caching mechanism.
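
For context, the generic PyTorch pattern for that kind of asynchronous swap looks roughly like this; it is a sketch of the idea using a pinned staging buffer and a side stream, not exllamav2's code:

```python
import torch

copy_stream = torch.cuda.Stream()

def make_pinned_buffer(like):
    # Page-locked host memory is required for non_blocking copies to overlap with compute.
    return torch.empty(like.shape, dtype=like.dtype, pin_memory=True)

def prefetch(cpu_tensor, pinned_buf):
    # Stage the data in pinned memory, then launch the host-to-device copy on a side stream.
    pinned_buf.copy_(cpu_tensor)
    with torch.cuda.stream(copy_stream):
        return pinned_buf.to("cuda", non_blocking=True)

# Before the default stream uses the prefetched tensor:
#     torch.cuda.current_stream().wait_stream(copy_stream)
```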

I'm currently a little preoccupied with dynamic batching, but I'll give it all some attention real soon.

kallewoof commented 1 month ago

> I did at one point keep the entire calibration state in VRAM, which was significantly faster, but it also made it impossible to quantize large models. Then I made it switchable, causing great confusion, and finally I arrived at always swapping as the best compromise between reliability and performance.

I think this modification results in zero increase in peak VRAM (since the MLP layer processing already uses more VRAM even without the batching), and it gives a pretty significant boost to attn layer processing speed.

> This is also in part because, as you've noted, PCIe traffic matters much less when processing the MLP modules, and they account for most of the time spent in measurement. It's a real struggle to get a large enough chunk of the full-precision model to fit sometimes.

I see! Thanks for confirming that.

> One thing I've considered is a better pipeline for swapping. Currently all data is swapped synchronously, but it could be asynchronous via a pinned buffer. It could maybe work in conjunction with this caching mechanism.

I'm not sure how that works, but I'll look around.

> I'm currently a little preoccupied with dynamic batching, but I'll give it all some attention real soon.

Thanks for the heads up, I'll be patient. :)

kallewoof commented 1 month ago

I'm closing this, as I'm not convinced it's worth the added technical debt. A simpler "cache hidden/training states up to a given amount" approach would probably still be a good win, so I'm going to take a stab at that.
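
As a rough illustration of the "cache up to a given amount" idea only; the class name, budget value, and methods are made up and not a proposal for the actual API:

```python
class BudgetedStateCache:
    def __init__(self, budget_bytes=2 * 1024**3):   # e.g. a 2 GB budget (assumed)
        self.budget = budget_bytes
        self.used = 0

    def store(self, state):
        # Keep the hidden state in VRAM while the budget allows, otherwise
        # spill it to system RAM as the current code always does.
        size = state.element_size() * state.nelement()
        if self.used + size <= self.budget:
            self.used += size
            return state
        return state.to("cpu")

    def release(self, state):
        # Hand back budget once a GPU-resident state has been consumed.
        if state.is_cuda:
            self.used -= state.element_size() * state.nelement()
```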