ggerganov / llama.cpp

LLM inference in C/C++
MIT License

[Enhancement Proposal] Single Model Speculative Sampling (+ possible future dynamic bit-depth models) #3269

Closed. tysam-code closed this issue 7 months ago.

tysam-code commented 1 year ago

Hi,

I'm currently developing a different feature that I plan on submitting at some point. In the process of that, I took some time to work on fast sampling. As I'm out of bandwidth at the moment (have a ton to juggle), I'm putting this out here as an enhancement proposal and would like to see if I could get someone interested in working on it. I could use your help! <3 :') I think it is worthwhile not just because of the present value it brings, but also because of certain future optimizations that it enables.

This may or may not be the right time for this feature, because I believe it depends upon two limiting factors: (A) a good-enough uniform quantization scheme, and (B) efficient utilization of the reduced compute from lower bit depths.

TL;DR: For a sufficient non-dynamic quantization scheme, dropping the N least significant bits of a model for the initial speculative-sampling generation should be a less biased estimator of that model's predictions than an outside model, meaning we only need to keep one model in memory. It also allows for other optimization methods to be implemented (dynamic bit-dropping per layer per forward pass, etc.).

Introduction

Llama.cpp models currently use integer-quantized weights for the most part, with some recent developments in dynamic compression via K-Quants. A recent method, Speculative Sampling (https://arxiv.org/abs/2302.01318), was proposed by researchers at DeepMind. I'm proposing a simpler successor to that method -- Single Model Speculative Sampling -- which uses only a single large model. I share some details about its benefits, as well as directions for 'fewest cuts' and 'max performance' implementations.

Speculative Sampling uses a smaller, noisier model to create candidate tokens, which are all validated at once by a larger, more accurate model. My best understanding is that this is analogous to branch prediction in computing. However, this requires two different models, which adds not only complexity, but memory cost, and oddly enough, a potential performance hit.

I'm proposing Single Model Speculative Sampling, which uses the same overall methodology as Speculative Sampling (forward-pass sampling, token acceptance/rejection, etc.) and should be faster and more accurate, with less memory when ideally implemented. I do not believe this method is naively compatible with K-Quants at first, but I also believe that can be accommodated.

Proposal

Here is how it works. For this method, we are assuming direct integer multiplication. Let's assume our model has weights stored in num_bits bits, and we define a num_dropped parameter for how many bits we want to drop.

For a model already held in memory, we first run the necessary forward passes by _truncating the model weights by dropping the num_dropped least significant bits_. This does not need to be applied to all layers, and we can simply pick and choose which layers offer the best speed/truncation advantage at runtime. This offers the benefit of potentially caching values, like the initial embedding layer, if we choose not to drop bits there (so we can reuse it easily for the larger model).

We also then need to shift the corresponding scale value left by num_dropped bits at some point (where the most efficient place in the code for this is, I can't say).
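As a minimal sketch of that step (using a made-up, simplified block layout rather than ggml's actual quantization structs), dropping num_dropped low bits and compensating in the scale might look like this:

```c
#include <stdint.h>

/* Hypothetical, simplified block: one float scale plus 32 8-bit quants.
 * This is NOT the real ggml block format; it only illustrates the idea. */
typedef struct {
    float   scale;
    uint8_t q[32];
} block_q8_demo;

/* Drop the num_dropped least significant bits of each weight and fold the
 * change into the scale, so dequantization (w ~= scale * q) stays aligned. */
static void truncate_block(block_q8_demo *b, int num_dropped) {
    for (int i = 0; i < 32; ++i) {
        b->q[i] >>= num_dropped;               /* keep only the high bits    */
    }
    b->scale *= (float)(1 << num_dropped);     /* scale 'shifted left' by n  */
}
```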

Here's a figure to demonstrate a 3 -> 2 bit drop showing the higher bits binning strategy.

[Figure: lsb_down_demo]

This might seem a bit simplistic... and it is! It only affects the actual precision of the multiplication where it counts. If we have kernels/functions that can take advantage of this bit width, then that's a win for us. Depending upon quantization implementation details, there might be a mild accuracy hit for the smaller model compared to a version natively quantized at that bit depth.

That said, because a reduced-bit quantization version of a larger model is a less-biased (and potentially unbiased) estimator of the larger model's predictions than a smaller model, we should have fewer branch mispredicts.

Implementation

Two methods of implementing this come to mind.

  1. Trading extra memory for simplicity -- just generating a lower-bit-depth copy of some of the weights of the model in memory, and using that with the correct function for that bit depth from there on out.
  2. Rewriting the functions to either take a stride and a variable number of bits and read only the top N bits, or to do something that dynamically operates on the data structure (I fear this will interfere with the packing/unpacking schemes currently used in the data structures; it's hard to understand everything that's happening due to all of the duplicated code between the per-bit-depth functions :'/)

#1 does not fully fulfill the promise of 'no extra memory', but is much simpler. #2 is much more complex, and I feel like it might require (please don't hate me!) a low-boilerplate, well-templated rewrite of the existing methods to be compatible.

I prefer #2, though I'm not nearly skilled enough in the details of AVX2/AVX-512/NEON/metal etc to make a good guess about how worthwhile this route is. :'))))
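Even though I prefer #2, option #1 is the easier one to sketch. Here is a rough, hedged example of a load-time repack (again with made-up flat formats, not the real llama.cpp block types): take an 8-bit quantized tensor and emit a 4-bit draft copy that keeps only the high nibble of each weight, two weights per byte, with the scale bumped to compensate.

```c
#include <stdint.h>
#include <stddef.h>

/* Hypothetical repack for option #1: build a low-bit draft copy of a tensor.
 * q8/scale8 describe the original 8-bit quants; q4_out/scale4_out receive the
 * 4-bit draft copy (two packed weights per byte). */
void make_draft_q4(const uint8_t *q8, float scale8,
                   uint8_t *q4_out, float *scale4_out, size_t n) {
    for (size_t i = 0; i < n; i += 2) {
        uint8_t hi0 = q8[i] >> 4;                        /* high nibble of weight i */
        uint8_t hi1 = (i + 1 < n) ? (q8[i + 1] >> 4) : 0;
        q4_out[i / 2] = (uint8_t)((hi1 << 4) | hi0);     /* pack two 4-bit weights  */
    }
    *scale4_out = scale8 * 16.0f;                        /* compensate for the 4 dropped bits */
}
```

The draft copy could then be fed to whatever 4-bit kernel already exists, which is exactly why this version is the smaller delta.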

Conclusion

All in all -- this method feels very much like 'the right way of doing things' for this particular problem. It's neat, resource-efficient, and will scale as our ability to quantize models and run lower-precision compute kernels becomes more and more efficient. However, it does make some assumptions about ideals (larger models quantized more heavily are better than smaller models quantized less, lower-bit computation scales well, etc.) that may not yet be well enough established in the real world. I don't believe that's necessarily a blocker, and I think it wins on the convenience factor alone. With some good heuristics, I could see a pretty sizeable boost in generation speed.

Additionally, in the future, this allows us to choose what bit depth to use on the fly, based on some criterion (during each forward pass of the 'light' model, or beforehand). For example, in a 5-bit model run AOT in 3 bits, one might find that a small delta in the first residual correlates with resilience to lower bit depths later on, so the bit precision is dropped to 2 until a residual has a high magnitude (in which case it could possibly be re-run in 3-bit precision and still have a speed advantage, depending upon frequency). This is just an example scenario; what the actual 'tells' of a model's resilience to dynamic runtime precision are, I can't say, though I'd be surprised if there wasn't at least some research about it.
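As a toy illustration of the kind of decision this would enable (the residual-magnitude heuristic and the thresholds below are entirely hypothetical):

```c
/* Hypothetical per-pass choice of draft bit depth based on some cheap
 * 'tell', here the norm of an early residual. Thresholds are made up. */
int pick_draft_bits(float residual_norm, float low_thresh, float high_thresh) {
    if (residual_norm < low_thresh)  return 2;  /* looks resilient: drop more bits     */
    if (residual_norm > high_thresh) return 4;  /* large residual: keep more precision */
    return 3;                                   /* default draft precision             */
}
```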

To anyone reading this -- I know this was quite a lot of text; I hope you find it interesting and compelling enough to take a crack at! I'm working on my own project and hoping to bring it to fruition, and hopefully this is a useful feature that is compelling enough for some people! :) <3

nonnull-ca commented 1 year ago

Unfortunately, this approach results in the same memory bandwidth usage for the weights as running the full model - and as far as I can tell we're typically memory-(bandwidth-)bound already on common setups. I can see how this could help on setups that have extremely little processing compared to their memory bandwidth.

An alternative that would help for memory bandwidth, although would be substantially more complex in other ways, is to split the model bitwise - instead of storing weight 0 bit 0 / weight 0 bit 1 / weight 0 bit 2 / weight 1 bit 0, ... you'd instead store weight 0 bit 0 / weight 1 bit 0 / ... weight 0 bit 1 / weight 1 bit 1 ... weight 0 bit 2 / weight 1 bit 2 .... To support dynamic num_dropped you'd want to split it fully bitwise; for a static num_dropped you could instead just split it into 2 concatenated arrays (one storing the 'base' bits, the second storing the 'extended' bit(s)).

You can do this transformation in place from the normal packed format at load time -- though whether you can do it efficiently is another matter.
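For the static num_dropped case with 8-bit weights split down the middle (again a simplified flat layout, not the existing packed formats), the load-time transform might look roughly like this:

```c
#include <stdint.h>
#include <stddef.h>

/* Split n 8-bit weights into a 'base' array (high nibbles, two per byte) and
 * an 'ext' array (low nibbles, two per byte). A speculative pass streams only
 * 'base'; the full-precision pass streams both and recombines. Simplified
 * illustration only. */
void split_bitplanes(const uint8_t *q8, uint8_t *base, uint8_t *ext, size_t n) {
    for (size_t i = 0; i < n; i += 2) {
        uint8_t a = q8[i];
        uint8_t b = (i + 1 < n) ? q8[i + 1] : 0;
        base[i / 2] = (uint8_t)((b & 0xF0)        | (a >> 4));   /* high nibbles */
        ext [i / 2] = (uint8_t)(((b & 0x0F) << 4) | (a & 0x0F)); /* low nibbles  */
    }
}

/* Recombine one weight for the full-precision pass. */
static inline uint8_t recombine(uint8_t base_nibble, uint8_t ext_nibble) {
    return (uint8_t)((base_nibble << 4) | ext_nibble);
}
```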

tysam-code commented 1 year ago

Hi, @nonnull-ca , I'm not sure if I completely understand what you're saying here. This proposal is about Speculative Sampling, as proposed in https://arxiv.org/abs/2302.01318, which uses two models -- both the full-sized model and the smaller model.

This proposal notes that the separate smaller model in the original paper is a biased estimator of the larger model, and under some constraints, a truncated version of the larger model is a less-biased estimator of its own outputs, not only reducing the space taken by the smaller model, but also potentially improving performance.

w.r.t. bit extraction, there's already high- and low-bit-range extraction with bit masking, IIRC, and some code branches just load values one by one. I would personally prefer contiguous storage since masking is a lot easier, but whichever way one 'slices' it, the model will likely have to stay in main memory. This can actually reduce the overall memory bandwidth load compared to the original method, I believe.

Ultimately, a single pass of the full model will almost always be the main bottleneck.

Hope this helps clarify things a bit.

nonnull-ca commented 1 year ago

Ah. I think there's two different ideas here, hence the confusion:

Idea 1: use a duplicated copy of the original model at a lower bit depth as the speculative model.
Idea 2: use a subset of the bits of the original model directly as the speculative model, without duplicating said weight bits in memory.

1) is interesting, but does not appear to save space versus using a different model for the speculative model. 2) is what I thought this issue was going for - hence the "meaning we only need to keep one model in memory" mention in the first comment.

tysam-code commented 1 year ago

Yes, I did touch on those both in different parts of the post.

You'll see Idea 1 reflected in the implementation section. This particular way of implementing things is a stent (stopgap) method that offers a good halfway point in development instead of one giant leap to a rework of a low-level kernel. It's an instantiation of the software-dev practice of keeping deltas small while marching towards the goal. Especially in an individual-contributor culture, little stopgap methods like this are crucial, so this is one way to break the problem into multiple pieces so we can get a hands-on feel for how it works (raw performance statistics, for example). There are ways to break the problem down even smaller, but those will likely need to wait until the time comes for their own discussion.

Idea 2 is what I'm primarily talking about in the proposal, and what much of the technical discussion (as well as the image) is about. It is specifically the higher bits, and the properties of those higher bits under certain compression schemes, that allow direct truncation while remaining a nearly-ideal quantized estimator of the full-bit-depth values.

What makes this method different is that it takes a method that necessarily uses two separate models and modifies it so it only requires one. There are some technical pros and cons to different ways of doing things, which is important to keep in mind as there is quite a lot of copy-pasted, hardware-specific kernel code, making the maintenance burden high.

You might want to take a snoop through the original post again; most of the points you're making are things I covered in there. I'm open to answering any specific questions beyond that, however.

nonnull-ca commented 1 year ago

> Idea 2 is what I'm primarily talking about in the proposal, and what much of the technical discussion (as well as the image) is about.

Right, so I did indeed read the original proposal correctly in that it was focused on idea 2. My point is that idea 2 is going to be substantially slower than you think, because from a memory bandwidth perspective it is not a smaller model, because GPUs can't load a subset of the bits from a cache line (or rather, they can, but not in a way that helps effective bandwidth.)

Let me elaborate:

(Note: numbers chosen to make the math somewhat neater.)

Normally, if you have, say, an 80B 8bpw model, you need to load at minimum 80GB/token from GPU RAM when doing inference, full stop.

The idea behind speculative sampling is that you can use, say, a 40B 4bpw model, with occasional batch processing using the full 80B 8bpw model. Let's say, for instance, that you try to speculate 8 tokens and on average 5 of them are accepted. Then you need on average (8*20GB + 80GB) / 5 tokens, or only 48GB/token loaded from GPU RAM.

This is a significant savings when you're GPU-memory-bandwidth bound. So far so good.

Now, let's say that instead you decide that you'll use the high 2 bits of your 80B 8bpw model directly in place for your speculative model, merely running kernels that ignore the low bits of the model. Naively, this is the same size as our previous speculative model - 40B 4b weights, i.e. 20GB of weights in the one; 80B 2b weights, i.e. 20GB of weights in the other. So far so good.

But how much memory bandwidth is required? Well, you need to pull in every cache line used from GPU RAM into the GPU for processing. You're dropping said low bits and not processing them - but the cache system still needs to pass them around. So your speculative pass, despite being 80B "2b" weights, still needs to load all 80GB of data from GPU RAM before dropping 6/8ths of the bits loaded. So then. How much do you need per token? Well, it's (8*80GB + 80GB) / 5 tokens... or 144GB/token. This is actively worse than no speculative sampling.
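To make that arithmetic easy to poke at, here is a small sketch; all sizes and acceptance counts are the assumed example values from this thread, not measurements:

```c
#include <stdio.h>

/* GB of weights streamed from GPU RAM per accepted token: one draft pass per
 * speculated token plus one full-model pass, amortized over accepted tokens. */
static double gb_per_token(double draft_gb, double full_gb,
                           int speculated, double accepted) {
    return (speculated * draft_gb + full_gb) / accepted;
}

int main(void) {
    printf("separate 40B 4bpw draft : %.0f GB/token\n", gb_per_token(20.0, 80.0, 8, 5.0)); /* 48  */
    printf("in-place 2-bit view     : %.0f GB/token\n", gb_per_token(80.0, 80.0, 8, 5.0)); /* 144 */
    printf("no speculation          : %.0f GB/token\n", 80.0);
    return 0;
}
```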

As I said originally: I can see how this helps in cases where you are sufficiently compute-bound that this is still faster; I don't see how it helps in memory-(bandwidth-)bound cases as-is - and as far as I can tell we typically are indeed memory bandwidth bound in most cases.

tysam-code commented 1 year ago

Well, to try to ground things a bit -- like everything in CS -- it depends! :) :D :') :') For memory-bandwidth-limited kernels, it can be slower to load weights from memory, but that depends upon a plethora of factors -- cache-line structure, how long a weight is held for a set of operations before being released, etc. A lot of those details I consider 'boring basic' implementation details that I leave assumed, but I think it's important to bring them up here for completeness.

A lot of asymmetric operations aren't necessarily memory-bound on the weights side -- the bound would instead be on the precision of the inputs. Combining clever cache-line strategies with strategies that maximize weight reuse in-loop would make the relative load times negligible in comparison, for example, though the latter should already be a standard part of most dot-product functions.

Another potential point of contention is device type: the performance characteristics of CPUs, GPUs, etc. are quite different. There are also a few other special-case scenarios that I think this particular mode of operation would be well-suited for, even more so than CPUs and GPUs, but I feel the conversation here is a bit too vigorous to open that topic for now.

I do not want to offer discouragement, but let's maybe go back to the big picture a bit and bring the concept explanations down a notch or two. You might be misunderstanding my level of experience on some of these things. That said, I appreciate that you are contributing to the conversation.

Unfortunately, I think the hardest part of implementation will actually be the legacy code as it exists here and the KL divergence (as it were) between that and an efficient technical solution. There are a few clear options (depending on CPU, GPU, etc., I think), but how it would work in the codebase with backwards compatibility is my main thought for discussion. Unfortunately boring and not-fun-at-all, but ultimately I think it's the most important side of things, since it is what would bring this from 'well-grounded concept' -> 'initial slow prototype' -> 'performant solution ready for porting to other architectures + PR'.

That's how I currently see it going, at least.

github-actions[bot] commented 7 months ago

This issue was closed because it has been inactive for 14 days since being marked as stale.