turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

EXL2 Proposal: Llamacpp Allocation Heuristics and Specialisation Degree Feature #174

Closed aljungberg closed 3 months ago

aljungberg commented 9 months ago

In this recent EXL2 vs GGUF discussion, one stand-out comment was this one by llama_in_sunglasses:

> I wanted a real answer about what is getting quantized more vs. less, so I went digging through the llama.cpp code. What happens is that some tensors get the quant level bumped up one or more notches (sometimes only at lower quant levels), and some other tensors get extra bits under certain conditions (if the current layer is in the first block of (num_layers / 8), if it's in the last block of (num_layers / 8), or if it's every other layer). Output is always q6_k unless the quant type is q8, and there are a few special cases for Falcon and just one special case for 70B. It's not bad to read in code, but it's a pain to describe in language. Here's the attn_v tensor portion (this is the most complex one).

```cpp
else if (name.find("attn_v.weight") != std::string::npos) {
    if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K)
        new_type = GGML_TYPE_Q3_K;
    else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M)
        new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
    else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
    else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
             use_more_bits(qs.i_attention_wv, qs.n_attention_wv))
        new_type = GGML_TYPE_Q6_K;
    else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4)
        new_type = GGML_TYPE_Q5_K;
    else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
             (qs.i_attention_wv < qs.n_attention_wv/8 || qs.i_attention_wv >= 7*qs.n_attention_wv/8))
        new_type = GGML_TYPE_Q6_K;
    // 70B special case: never leave attn_v at Q3_K or Q4_K
    if (qs.model.type == MODEL_70B) {
        if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K)
            new_type = GGML_TYPE_Q5_K;
    }
    ++qs.i_attention_wv;  // index of the next attn_v tensor
}
```
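To make that branch easier to follow, here is a minimal Python re-expression of the same attn_v rule. It is a sketch, not llama.cpp code: quant types are plain strings, `base_type` stands in for whatever type the file-level ftype would otherwise assign, and `use_more_bits` isn't shown in the excerpt, so its body here is my recollection of the llama.cpp helper of that era (first eighth, last eighth, and every third layer in between). If that recollection is right, the "in between" layers are every third rather than every other; verify against the actual source.

```python
def use_more_bits(i_layer: int, num_layers: int) -> bool:
    # Assumed body (recollection of llama.cpp, verify against the source):
    # first eighth of the layers, last eighth, and every third layer in between.
    return (i_layer < num_layers // 8
            or i_layer >= 7 * num_layers // 8
            or (i_layer - num_layers // 8) % 3 == 2)


def attn_v_type(ftype: str, base_type: str, i: int, n: int,
                is_70b: bool = False, qk_k: int = 256) -> str:
    """Quant type for the i-th of n attn_v tensors, mirroring the C++ branch above."""
    new_type = base_type
    if ftype == "Q2_K":
        new_type = "Q3_K"
    elif ftype == "Q3_K_M":
        new_type = "Q5_K" if i < 2 else "Q4_K"
    elif ftype == "Q3_K_L":
        new_type = "Q5_K"
    elif ftype in ("Q4_K_M", "Q5_K_M") and use_more_bits(i, n):
        new_type = "Q6_K"
    elif ftype == "Q4_K_S" and i < 4:
        new_type = "Q5_K"
    elif qk_k == 64 and ftype in ("Q4_K_S", "Q3_K_S") and (i < n // 8 or i >= 7 * n // 8):
        new_type = "Q6_K"
    # 70B special case: never leave attn_v at Q3_K or Q4_K
    if is_70b and new_type in ("Q3_K", "Q4_K"):
        new_type = "Q5_K"
    return new_type


# Which of 80 attn_v tensors get Q6_K under Q4_K_M (non-70B path)?
bumped = [i for i in range(80) if attn_v_type("Q4_K_M", "Q4_K", i, 80) == "Q6_K"]
```

Under these assumptions, an 80-layer model gets 40 of its 80 attn_v tensors bumped to Q6_K: layers 0 to 9, layers 70 to 79, and 20 layers in between (12, 15, 18, ...). With `is_70b=True`, the remaining Q4_K layers become Q5_K.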

tl;dr: llama.cpp has a handwritten heuristic that allocates extra bits to certain layers, presumably over and above what the actual measurements would justify.

I don't know what research or experiments led to this, but subjectively llama.cpp does quantise "better". Perhaps EXL2 and convert.py are "too good" at specialising the model towards the calibration set, leading to a kind of mode collapse on generality, and llama.cpp's "regardless of the data, just spend more bits on these layers unconditionally" approach helps counteract that. It's as if the learning rate were too high when you single-mindedly optimise for PPL per bit on the calibration set.

I'd love to do some experiments on this and become world famous, but as usual there's so little time in the day. I think replicating llama.cpp's heuristic, allowing some deviation from the calibration-optimal allocation, is definitely worth testing.

Bonus Idea: Configurable Specialisation Degree

Hardcoded rules for bumping some layers bother me. Who knows if that's the right choice every time, for every model.

Maybe we could directly target the loss of generality problem in a, no pun intended, general way.

I propose mixing the input calibration samples with "model average samples". Let's call this the secondary dataset, or the "control group", with the calibration set as the primary dataset. Let the user choose the ratio of primary versus secondary samples.

The model average samples are, as the name implies, model-specific. We'd generate them on the fly from the full FP16 version of the model. Since the full model can be heavy, one might do this on CPU (perhaps even with llama.cpp; let's not go all not-invented-here!). Even if slow, it only needs to be done once per model.

To get the control samples, we let the OG model generate freely from a BOS token (or maybe from a random starting token, since not all models are trained with BOS), sampling at a low temperature. This should naturally lead to maximally generic, median output for the model. The lower the temperature, the lower the PPL of the resulting samples, trending towards 0 as temperature drops. With greedy sampling the PPL would be 0, since the model is literally scored on its ability to predict its own predictions. We want more variety than that, since the goal is generality, so low temp or min-p, aiming to create a corpus with high diversity yet low average PPL.
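The sampling side of this is cheap to express. Below is a minimal sketch of temperature plus min-p filtering for a single step; the logits would come from the FP16 model (not shown), and the function names and default values are illustrative assumptions, not anything exllamav2 provides.

```python
import math
import random


def softmax(logits, temperature):
    """Temperature-scaled softmax; lower temperature sharpens the distribution."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def min_p_filter(probs, min_p):
    """Keep tokens whose probability is at least min_p times the top token's, then renormalise."""
    threshold = min_p * max(probs)
    kept = [p if p >= threshold else 0.0 for p in probs]
    total = sum(kept)
    return [p / total for p in kept]


def sample(logits, temperature=0.7, min_p=0.1, rng=random):
    """Draw one token index from the filtered distribution."""
    probs = min_p_filter(softmax(logits, temperature), min_p)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

Min-p keeps some diversity (several plausible tokens survive) while cutting off the low-probability tail that would push the samples' PPL up, which is the balance described above.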

So the hypothesis, then, is that the primary samples from the calibration dataset allow us to quantise layers aggressively in a use-case-specific way. Maybe some layers really affect mathematical ability, but that's not tested by the calibration set, so it's justified to drop precision in those layers: that's sort of what the user is asking us to do. But at the same time we don't want to be too aggressive in specialising the model, since no dataset is ever complete. An LLM's ability to generalise only appears with massive training datasets, so it's probably quite easy to trample that ability when quantising. We counter that with these "normalisation" samples, which ensure we don't stray too far from the baseline, and the ratio between primary and secondary samples then becomes a control knob for how specialised we want the quant to be: the specialisation degree.
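A minimal sketch of that knob, assuming rows are already tokenized and interchangeable: `mix_calibration_rows` and `secondary_fraction` are hypothetical names, and a secondary fraction of 0 reproduces today's behaviour (calibration data only).

```python
import random


def mix_calibration_rows(primary, secondary, secondary_fraction, n_rows, seed=0):
    """Build n_rows calibration rows, about secondary_fraction of them drawn from the
    model-generated control set and the rest from the user's calibration set."""
    if not 0.0 <= secondary_fraction <= 1.0:
        raise ValueError("secondary_fraction must be in [0, 1]")
    rng = random.Random(seed)
    n_secondary = round(n_rows * secondary_fraction)
    n_primary = n_rows - n_secondary
    # sample() draws without replacement, so each pool must be large enough
    rows = rng.sample(primary, n_primary) + rng.sample(secondary, n_secondary)
    rng.shuffle(rows)  # interleave so measurement slices see both kinds
    return rows


primary = [f"p{i}" for i in range(100)]    # stand-ins for calibration rows
secondary = [f"s{i}" for i in range(100)]  # stand-ins for model-generated rows
rows = mix_calibration_rows(primary, secondary, 0.25, 40, seed=1)
```

Here 10 of the 40 rows come from the control set. The shuffle matters if, as described later in the thread, measurement only uses the first slice of rows.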

turboderp commented 9 months ago

Well, I do have some issues with that benchmark.

The results are suspect since they're made using two different sampling pipelines. Specifically, note that the 2.4bpw version is listed as "broken", despite the fact that it works just fine, even on German questions. I suspect the difference lies in how text-generation-webui tokenizes prompts, but there isn't a lot to go by in the post. So whether this problem applies to the higher-bitrate models as well would need to be investigated. Perhaps they simply recover more easily from seeing a BOS token that they weren't exposed to in finetuning? It's hard to say, but in any case the presence or absence of the BOS token does make a big impact on some specific models at extremely low bitrates.

I've done quite a number of tests, after all, looking at perplexity as well as MMLU scores and subjective tests, so it's hard for me to put much stock in a comparison like the one you linked, when the methodology is questionable and the results are so contrary to previous tests by myself and others.

Creating a better calibration dataset does make sense, but I don't really follow the model averaging idea? If the model made only 100% certain predictions, the perplexity would approach 1, but greedy sampling would invariably select tokens with a probability anywhere from close-to 100% down to maybe 5% or so. And with only greedy sampling and no repetition penalty or randomness added, especially smaller models have a tendency to get stuck in simple patterns like short, repeating sentences, and I'm not sure that makes for a good target..? Idk.
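For reference, perplexity is the exponential of the mean negative log-likelihood of the observed tokens (equivalently, one over their geometric-mean probability), so its floor is 1, reached only when every token gets probability 1. A small sketch with made-up probabilities:

```python
import math


def perplexity(token_probs):
    """exp(mean negative log-likelihood) of the probabilities assigned to the observed tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))


certain = perplexity([1.0] * 10)              # every token predicted with certainty
coin_flips = perplexity([0.5] * 4)            # every token a 50/50 call
greedy_like = perplexity([0.9, 0.6, 0.3, 0.05])  # argmax picks, but not all confident
```

`certain` is exactly 1.0 and `coin_flips` is 2.0. The `greedy_like` case shows the point about greedy sampling: picking the argmax at every step still yields tokens whose probability is well below 1, so the PPL of greedy output stays above 1.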

Using outputs from the model itself does sound interesting, though. Or, at least augmenting the calibration set with data that captures where the model "would like the inference to go", as opposed to wherever the chosen dataset goes. That could be very valid. The problem is running inference on the model before it's quantized. For a 70B model on consumer GPUs, that's an infeasible amount of swapping in and out, essentially loading the entire model once per token. And with CPU inference that could still be days of processing to produce, say, 100,000 tokens of calibration data. Quantization is already a lengthy process as it is.

As for the calibration data, I would note two things:

First, it's used for error correction in the same way that GPTQ-for-Llama has always used some fixed chunk of either wikitext, C4 or PTB. The original authors of the GPTQ paper found it sufficient to just use 128x2048 tokens of C4 and just not worry too much about it. TheBloke has quantized hundreds of GPTQ models using wikitext-test and it hasn't caused much of a stir. It also seems fair enough to me, since in my own tests the calibration generalizes well across datasets. Here is the perplexity of a 7B model calibrated with wikitext (normalized to the FP16 result), showing no affinity to that dataset over others, except at 3 bpw and below, where it starts to do better on the exact text it was calibrated with (as expected, since that's where it would start to lean most heavily into the calibration):

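[Figure: llama2_7b_ppl_rel, perplexity of the wikitext-calibrated 7B model relative to FP16, by evaluation dataset and bitrate]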

Secondly of course, the calibration is also used during measurement. Here, each layer is completely quantized and optimized at a number of different settings. Only a portion of the calibration dataset is used for this, to speed up the process since it has to be done so many times over.

The result of the measurement is the difference between the hidden state produced by the quantized linear layer and the one produced by the original. Specifically, it's the Frobenius norm of the error divided by the Frobenius norm of the target. That's definitely the part I struggled with the most: deciding on a robust, scale-invariant error measure. And if any part of the algorithm needs attention, that's probably it, since a number of "intuitive" assumptions were made in that regard. And while I did spend many weeks verifying those assumptions empirically, there are no doubt still unexplored possibilities there.
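A minimal sketch of that error measure on nested lists, assuming only what is described above (relative Frobenius norm of the output difference); `relative_frobenius_error` is an illustrative name, not the script's function, and the matrices are made up.

```python
import math


def matmul(a, b):
    """Plain nested-list matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def frobenius(m):
    return math.sqrt(sum(x * x for row in m for x in row))


def relative_frobenius_error(x, w, w_q):
    """||x @ w_q - x @ w||_F / ||x @ w||_F: error of the quantized layer's output
    relative to the original layer's output on the same input x."""
    target = matmul(x, w)
    approx = matmul(x, w_q)
    diff = [[a - t for a, t in zip(ra, rt)] for ra, rt in zip(approx, target)]
    return frobenius(diff) / frobenius(target)


x = [[1.0, 2.0], [0.5, -1.0]]          # two rows of input hidden state
w = [[0.30, -0.70], [0.20, 0.90]]      # original weights
w_q = [[0.25, -0.75], [0.25, 0.875]]   # the same weights snapped to a coarser grid
err = relative_frobenius_error(x, w, w_q)
```

Dividing by the target's norm is what makes the measure scale-invariant: multiplying the input (or both weight matrices) by a constant leaves the result unchanged, so errors from layers with very different activation magnitudes can be compared.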

Either way, this is where the script makes the same observation that the llama.cpp devs seem to have made, that some layers are more important than others. Only the behavior isn't hardcoded. So while the attention V projection is usually allocated more bits (because it is higher effective rank, is harder to align to a coarser quantization grid, doesn't arrange as neatly into an "activation order", whatever it may be), the overall process still has a budget so it may choose to prioritize other parts of the model anyway.
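To illustrate how a fixed budget can override a layer's usual preference, here is a toy allocator: exhaustive search over one setting per layer, minimising summed error subject to an average-bpw cap. The numbers are made up, layers are assumed equal in size, and this is not a description of how exllamav2's optimizer actually searches.

```python
from itertools import product

# Made-up measurements: (bits per weight, measured relative error) per candidate setting.
OPTIONS = {
    "q_proj": [(2.2, 0.020), (3.2, 0.010), (4.2, 0.005)],
    "v_proj": [(2.2, 0.060), (3.2, 0.025), (4.2, 0.010)],
    "o_proj": [(2.2, 0.015), (3.2, 0.008), (4.2, 0.004)],
}


def allocate(options, target_bpw):
    """Return (total_error, {layer: bpw}) for the best combination within budget."""
    names = list(options)
    best = None
    for choice in product(*(options[n] for n in names)):
        bpw = sum(c[0] for c in choice) / len(choice)  # equal-sized layers assumed
        if bpw > target_bpw + 1e-9:
            continue
        err = sum(c[1] for c in choice)
        if best is None or err < best[0]:
            best = (err, dict(zip(names, (c[0] for c in choice))))
    return best


best = allocate(OPTIONS, 3.2)
```

With a 3.2 bpw average, the hard-to-quantize v_proj gets 4.2 bpw, paid for by dropping o_proj to 2.2: nobody hardcoded that, it falls out of the measurements and the budget.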

The main concern with measurement, I'd say, would be if there isn't enough diversity in that part of the dataset, so the model ends up thinking a layer is easier to quantize than it ends up being. Then again, I have done a lot of tests, and the measured error doesn't move around that much as you add more rows to the test, and the final error from quantizing on typically 100 rows is always very close to the measured error from the first 16 rows.

Now, I have been meaning to return to the quantization algorithm. One of the first things on my wish list is a "standard" calibration dataset that covers a wide range of scenarios like code, dialogue, reasoning and at least some other languages besides English. I'd also like to make sure it includes, if not one of each token, then at least enough tokens to reasonably cover the embedding space.

Including elements generated by the model itself (one way or another) is definitely worth considering, too.

aljungberg commented 9 months ago

> I've done quite a number of tests, after all, looking at perplexity as well as MMLU scores and subjective tests, so it's hard for me to put much stock in a comparison like the one you linked, when the methodology is questionable and the results are so contrary to previous tests by myself and others.

Agreed. There were clearly methodology issues with that particular test. I think the core idea is sound: testing the model's ability on novel (not trained on) Q&A. It's easy to score and avoids the pitfalls of testing on the usual suspects (where there's an ever-growing risk they've accidentally made it into the model's training data). Another concern is that I couldn't find which calibration set the quant (by LoneStriker) used. It could very well have been in English, for example, whereas WolframRavenWolf tested in German. That would penalise an aggressive quantiser more. I'm not sure what data GGUF measures its error on, or how.

Problems with that examination aside, and this is anecdotal, in my testing I seem to find unexpected fragility and failure modes in some EXL2 quantisations. I mean, I haven't done as much testing as you have, but I have a feeling it's easy to cut too deep with the current algo.

I tried quantising the model in question, lzlv_70b_fp16_hf, myself, to 7.0 bpw. The quant was very aggressive on the very first layer:

 -- Linear: model.layers.0.self_attn.q_proj -> 0.05:3b/0.95:2b 32g s4, 2.17 bpw
 -- Linear: model.layers.0.self_attn.k_proj -> 0.05:3b/0.95:2b 32g s4, 2.17 bpw
 -- Linear: model.layers.0.self_attn.v_proj -> 0.1:6b/0.9:4b 32g s4, 4.31 bpw
 -- Linear: model.layers.0.self_attn.o_proj -> 0.25:3b/0.75:2b 32g s4, 2.38 bpw
 -- Layer rfn_error: 0.007747

Clearly the reconstruction error of 0.007747 is quite low, so the optimiser's decision is technically justified. But going as low as 2.17 bpw on the very first layer with a 7 bpw budget is really very aggressive, and one wonders how that affects generality. Maybe it causes just enough rounding error on that leading BOS token, for example, which may not have been seen in the calibration set.

> Creating a better calibration dataset does make sense, but I don't really follow the model averaging idea? If the model made only 100% certain predictions, the perplexity would approach 1, but greedy sampling would invariably select tokens with a probability anywhere from close-to 100% down to maybe 5% or so. And with only greedy sampling and no repetition penalty or randomness added, especially smaller models have a tendency to get stuck in simple patterns like short, repeating sentences, and I'm not sure that makes for a good target..? Idk.

Yes, sorry if I didn't express that clearly. We wouldn't use greedy sampling, because we need variety. That's kind of the point: to get a broader set to push back against over-specialisation. But while we want variety, we don't want high-PPL samples at the same time. By definition those wouldn't represent the original model well. So we need to strike a balance between variety and verisimilitude.

Repetition penalty and so on will likely be necessary too, although to some extent if the model gets stuck repeating something, it might not matter. We're not making samples to be pleasing, but to be representative. We want to capture the natural tendencies of the model, as a yardstick, so we can measure how far our quant has taken us from the original in the most general sense.

(Actually I guess what we could use is mirostat, targeting a certain low PPL.)

I honestly have no idea how easy or hard it is to generate "average" samples, but it's not without precedent. For image generation this is a thing people do, if I remember right. When you let the model generate with no guidance, just on noise, it supposedly generates images tending towards the "median" images of the training set.

> Using outputs from the model itself does sound interesting, though. Or, at least augmenting the calibration set with data that captures where the model "would like the inference to go", as opposed to wherever the chosen dataset goes. That could be very valid. The problem is running inference on the model before it's quantized. For a 70B model on consumer GPUs, that's an infeasible amount of swapping in and out, essentially loading the entire model once per token. And with CPU inference that could still be days of processing to produce, say, 100,000 tokens of calibration data. Quantization is already a lengthy process as it is.

Yeah it wouldn't be fast. On CPU only, a 70B model would take like... 27 hours to generate 100k control tokens, back of the napkin. I remember 1 token/s being a reasonable expectation.

You could batch the generation on the GPU, running 25 inference streams at once. Sure, you still have to send the weights to the device 4096 times if we aim for a full context window, but 25x4096 tokens is the 100k desired, isn't it? Although I guess in practice on a 24 GB GPU you'd only be able to squeeze in about 12 sequences per batch, even when working layer by layer, because the KV cache has to fit too. And these will all by design be completely divergent sequences, so we can't share the cache... Anyhow, this should still be a ton faster than naive CPU generation or single-sequence GPU generation with streaming weights.
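A back-of-the-napkin check of these numbers, assuming Llama-2-70B's published shape (80 layers, grouped-query attention with 8 KV heads of dimension 128) and an FP16 cache kept entirely on the GPU; weights and activations are ignored, so the real headroom is smaller.

```python
# Assumed Llama-2-70B attention shape and FP16 cache entries.
N_LAYERS, N_KV_HEADS, HEAD_DIM, BYTES_PER_VALUE = 80, 8, 128, 2


def kv_cache_bytes(seq_len, n_seqs):
    """K and V for every layer, KV head, position and sequence."""
    return 2 * N_LAYERS * N_KV_HEADS * HEAD_DIM * BYTES_PER_VALUE * seq_len * n_seqs


def cpu_hours(n_tokens, tokens_per_second=1.0):
    """Wall-clock hours for sequential generation at a fixed rate."""
    return n_tokens / tokens_per_second / 3600


GIB = 2 ** 30
per_sequence = kv_cache_bytes(4096, 1) / GIB   # cache for one full-context sequence
twelve = kv_cache_bytes(4096, 12) / GIB
twenty_five = kv_cache_bytes(4096, 25) / GIB
batch_tokens = 25 * 4096
hours = cpu_hours(100_000)
```

That comes to 1.25 GiB of cache per 4096-token sequence: about 15 GiB for 12 sequences and over 31 GiB for 25, which is consistent with roughly 12 fitting on a 24 GB card once a layer's weights are added. 25 x 4096 is 102,400 tokens, and 100,000 tokens at 1 token/s is about 27.8 hours.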

One thing that helps is that it needs to be done just once per model, just like the measurement pass itself. And maybe one can get away with far fewer than 100k control tokens. It depends on the specialisation degree wanted.

> First, it's used for error correction in the same way that GPTQ-for-Llama has always used some fixed chunk of either wikitext, C4 or PTB. The original authors of the GPTQ paper found it sufficient to just use 128x2048 tokens of C4 and just not worry too much about it. TheBloke has quantized hundreds of GPTQ models using wikitext-test and it hasn't caused much of a stir. It also seems fair enough to me, since in my own tests the calibration generalizes well across datasets. Here is the perplexity of a 7B model calibrated with wikitext (normalized to the FP16 result), showing no affinity to that dataset over others, except at 3 bpw and below, where it starts to do better on the exact text it was calibrated with (as expected, since that's where it would start to lean most heavily into the calibration):

That's a pretty convincing graph. To be honest, this is what I believed would be the case initially! We're not retraining the model; we're still keeping all the original weights, just with lower precision. And GPTQ specifically compensates for bias, if I remember right, so the norm remains the same. But seeing some of these quirky degenerations, not to mention this surprising BOS allergy, is what made me wonder.

> The result of the measurement is the difference between the hidden state produced by the quantized linear layer and the one produced by the original. Specifically, it's the Frobenius norm of the error divided by the Frobenius norm of the target. That's definitely the part I struggled with the most: deciding on a robust, scale-invariant error measure. And if any part of the algorithm needs attention, that's probably it, since a number of "intuitive" assumptions were made in that regard. And while I did spend many weeks verifying those assumptions empirically, there are no doubt still unexplored possibilities there.

Thanks for this, I didn't dig that deep into the specifics of the error measurement (sorry, just a guy doing a drive-by suggestion). Does the quantised hidden state propagate all the way through, building up more and more error as we get to later layers, or do we start from the baseline FP16 hidden state each layer?

If the quant state is carried through, don't the non-linearities cause a kind of exponentially growing error? Errors building on errors as we get to later layers.

If it doesn't (we use the FP16 hidden state as the input to each measurement), do we have to worry about the opposite effect, that our error measurement is more significant in early layers? An error of 0.001 in the first layer might be much worse than an error of 0.001 in the last layer, as the earlier the error the more knock-on effects it can have downstream.

In that case we might need some empirical measurement of how closely the rfn_err on a per layer basis correlates with average PPL loss. And then spend more bits on earlier layers as needed. Like, some errors are worse than others and the amount of 'worseness' can be measured and compensated for.

> Either way, this is where the script makes the same observation that the llama.cpp devs seem to have made, that some layers are more important than others. Only the behavior isn't hardcoded. So while the attention V projection is usually allocated more bits (because it is higher effective rank, is harder to align to a coarser quantization grid, doesn't arrange as neatly into an "activation order", whatever it may be), the overall process still has a budget so it may choose to prioritize other parts of the model anyway.

Hey that's what I was just trying to describe but you did it better! Some layers are more important than others. That's equivalent to the error mattering more in some places.

> The main concern with measurement, I'd say, would be if there isn't enough diversity in that part of the dataset, so the model ends up thinking a layer is easier to quantize than it ends up being. Then again, I have done a lot of tests, and the measured error doesn't move around that much as you add more rows to the test, and the final error from quantizing on typically 100 rows is always very close to the measured error from the first 16 rows.

Yeah this is precisely my point with a specialisation factor. A useful exercise is often to think about things at the limit. If we had a calibration set that only contained children's songs, and the higher layers in big models are dedicated to more advanced reasoning and recall, an aggressive quantiser would rightly totally lobotomise those layers and spend all the bits on early basic layers. This would hold true whether you used 16 or 1000 rows because all your data has this sameness to it. It's a kind of sampling bias.

Sometimes aggression is right, but there needs to be balance and this might best be controlled by a user knob.

> Now, I have been meaning to return to the quantization algorithm. One of the first things on my wish list is a "standard" calibration dataset that covers a wide range of scenarios like code, dialogue, reasoning and at least some other languages besides English. I'd also like to make sure it includes, if not one of each token, then at least enough tokens to reasonably cover the embedding space.

Yep this would probably work too. Like the easy method is to just randomly sample the original training data of the model, if known. But you might need a fairly large sample to ensure it's truly representative.

Of course easier said than done since we don't have the Llama training dataset.

> Including elements generated by the model itself (one way or another) is definitely worth considering, too.

The main benefit being that the model we do have, and it's its own domain expert on what it thinks is "normal".

turboderp commented 9 months ago

> Problems with that examination aside, and this is anecdotal, in my testing I seem to find unexpected fragility and failure modes in some EXL2 quantisations. I mean, I haven't done as much testing as you have, but I have a feeling it's easy to cut too deep with the current algo.

I've seen some instability too, don't get me wrong, but I've seen this on GPTQ as well. So I'm not sure if it has to do with the bitrate as much as just overfitting in general. It could also be an artifact of models being merged or finetuned in BF16. Since most of the inference logic, and some of the quantization logic is done in FP16, the difference in dynamic range could factor into it somehow.

There are some easy places you can tweak it if you want to experiment and have an idle GPU anyway, or you're just feeling cold or whatever:

> Thanks for this, I didn't dig that deep into the specifics of the error measurement (sorry, just a guy doing a drive-by suggestion). Does the quantised hidden state propagate all the way through, building up more and more error as we get to later layers, or do we start from the baseline FP16 hidden state each layer?

During measurement it's the full-precision hidden state that carries through, because there are 20 different versions of the hidden state from 20 different quantization settings. So effectively it measures each layer as if all the previous layers were unquantized.

During the actual conversion, for each layer it first computes the hidden state with the full-precision model, then it quantizes the layer and finally recomputes the hidden state with the full quantized module. Then the quantized state is carried forward.
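A toy version of the two modes just described, assuming nothing about exllamav2's internals beyond that description: each "layer" is a tiny elementwise function, and quantization is simulated by perturbing a weight. All names here are illustrative.

```python
import math


def make_layer(weight, bias):
    # Toy "layer": elementwise affine map followed by tanh.
    return lambda h: [math.tanh(weight * x + bias) for x in h]


def rel_err(out, ref):
    num = math.sqrt(sum((a - b) ** 2 for a, b in zip(out, ref)))
    return num / math.sqrt(sum(b * b for b in ref))


def measure(fp_layers, q_layers, h0):
    """Measurement mode: every quantized layer is fed the full-precision state."""
    h, errs = h0, []
    for f, q in zip(fp_layers, q_layers):
        errs.append(rel_err(q(h), f(h)))
        h = f(h)  # full-precision state carries forward
    return errs


def convert(fp_layers, q_layers, h0):
    """Conversion mode: the target is computed from the already-quantized state,
    and the quantized layer's output is what carries forward."""
    h, errs = h0, []
    for f, q in zip(fp_layers, q_layers):
        target = f(h)
        out = q(h)
        errs.append(rel_err(out, target))
        h = out  # quantized state carries forward
    return h, errs


fp_layers = [make_layer(1.5, 0.1 * k) for k in range(4)]
q_layers = [make_layer(1.5 * 1.02, 0.1 * k) for k in range(4)]  # 2% weight error
h0 = [0.5, -0.25, 1.0]

measured = measure(fp_layers, q_layers, h0)
final_q, converted = convert(fp_layers, q_layers, h0)

h_fp = h0
for f in fp_layers:
    h_fp = f(h_fp)
end_to_end = rel_err(final_q, h_fp)  # drift of the fully quantized stack
```

Both modes agree on the first layer, since it sees the same input either way; from then on, conversion scores each layer on inputs that already carry upstream error, while measurement scores it as if everything before it were unquantized.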

I have considered keeping both states and using the difference between them to try to adjust subsequent layers. The idea would be that between each module (though after each of the RMS norms), compute the error E as the difference between the quantized state and the original state, multiply the inputs of the next layer (QKV matrices or MLP up/gate) by the outer product of E, and... hope that it generalizes, I guess? It's worth a shot at least.

> Sometimes aggression is right, but there needs to be balance and this might best be controlled by a user knob.

I agree in principle, but I'm also scared of giving users too many knobs. I had a lot of fun implementing Mirostat sampling recently, a process in which I discovered that many most (?) users are using it as a placebo. They crank up the tau parameter so high that the algorithm almost never reaches its "surprise" target, and so the effective top-K just shoots up into the thousands and what they attribute to Mirostat is just more or less unconstrained sampling.
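For anyone unfamiliar with the failure mode, here is a minimal sketch of Mirostat v2 under its usual formulation: drop tokens whose surprise (-log2 p) exceeds mu, sample from the rest, then move mu by eta times the gap between observed surprise and tau, with mu starting at 2 * tau. The Zipf-shaped distribution is an assumption standing in for real model output, and the helper names are mine.

```python
import math
import random


def zipf_probs(vocab_size):
    """Zipf-like next-token distribution (assumed stand-in for real logits)."""
    weights = [1.0 / rank for rank in range(1, vocab_size + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def mirostat_v2_step(probs, mu, tau, eta, rng):
    """Keep tokens with surprise <= mu, sample among them, then nudge mu toward tau."""
    kept = [(i, p) for i, p in enumerate(probs) if -math.log2(p) <= mu]
    if not kept:  # always keep at least the most likely token
        kept = [max(enumerate(probs), key=lambda ip: ip[1])]
    total = sum(p for _, p in kept)
    token = rng.choices([i for i, _ in kept], weights=[p for _, p in kept], k=1)[0]
    observed = -math.log2(probs[token] / total)  # surprise after renormalising
    mu -= eta * (observed - tau)
    return token, mu, len(kept)


def effective_top_k(probs, tau, eta=0.1, steps=10, seed=0):
    """Number of candidate tokens kept at each step, starting from mu = 2 * tau."""
    rng = random.Random(seed)
    mu = 2.0 * tau
    sizes = []
    for _ in range(steps):
        _, mu, k = mirostat_v2_step(probs, mu, tau, eta, rng)
        sizes.append(k)
    return sizes


probs = zipf_probs(32_000)
low_tau = effective_top_k(probs, tau=3.0)
high_tau = effective_top_k(probs, tau=10.0)
```

With tau = 3, the first step keeps only a handful of tokens. With tau = 10, mu starts at 20 bits, every token in this 32,000-token vocabulary has lower surprise than that, and the truncation step does nothing at all: effectively unconstrained sampling.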

That doesn't mean a specialization knob wouldn't be valuable, especially for testing and experimentation. But I also much prefer if users have some idea of what they're getting when they download a model. So idk.

> The main benefit being that the model we do have, and it's its own domain expert on what it thinks is "normal".

This is true. I think it's definitely worth trying, at least as an experiment. And it's easier to start with small models anyway, where running inference on the FP16 weights isn't an issue, and supposedly quantization errors are more pronounced anyway. Then if it turns out to be worth it, figure out a solution for 70B models. And 120B. Apparently that's a thing now.

aljungberg commented 9 months ago

> It could also be an artifact of models being merged or finetuned in BF16. Since most of the inference logic, and some of the quantization logic is done in FP16, the difference in dynamic range could factor into it somehow.

Interesting point. So first we get BF16's lesser precision during merge/fine-tune, then FP16's smaller range during quantisation. Bit of a double whammy. Hmm, no wait, it's a single whammy depending on if it's a fine-tune or a merge.

> qparams_options in conversion/qparams.py is the only place the quantization options are enumerated, and they're fairly self-explanatory. QParams(32, [3, 2], [0.05, 0.95], 4) specifies group size 32, with 5% of the groups using 3 bits per weight and 95% using 2 bits. The last number is the number of bits for the group scale, and 4 is the only option currently.

I take it you hand-picked the current array of choices in qparams_options based on your empirical testing? I'm noticing we don't have more extreme mixtures like [8, 2] or [8, 3] to tackle layers with particularly gnarly outliers. Is that not possible or just not useful? I'm thinking about Tim Dettmers' observations about outliers being few but significant.

One could imagine some attention heads basically being quite particular and having ~zero weights in the input projections that apply to certain parts of the embedding space, while at the same time being extremely sensitive and demanding of precision in other areas. So one would want to use low precision weights for some blocks in those matrices and high precision blocks in the others. Which I guess is the whole point of varying the bit mixture to begin with, but one could turn that up to 11.
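A rough sketch of what such a mixture costs, using only the QParams description quoted above: the weighted mean of the per-group bit widths plus the scale overhead (scale bits divided by group size). `QParamsSketch` is a hypothetical stand-in, not the class in conversion/qparams.py.

```python
from dataclasses import dataclass


@dataclass
class QParamsSketch:
    """Hypothetical stand-in for QParams(group_size, bits, bits_prop, scale_bits)."""
    group_size: int
    bits: list
    bits_prop: list
    scale_bits: int

    def nominal_bpw(self) -> float:
        if abs(sum(self.bits_prop) - 1.0) > 1e-9:
            raise ValueError("bits_prop must sum to 1")
        # Weighted average of the per-group bit widths...
        weight_bits = sum(b * p for b, p in zip(self.bits, self.bits_prop))
        # ...plus one scale of scale_bits per group of group_size weights.
        return weight_bits + self.scale_bits / self.group_size


q_k = QParamsSketch(32, [3, 2], [0.05, 0.95], 4).nominal_bpw()  # q_proj/k_proj setting
v = QParamsSketch(32, [6, 4], [0.1, 0.9], 4).nominal_bpw()      # v_proj setting
o = QParamsSketch(32, [3, 2], [0.25, 0.75], 4).nominal_bpw()    # o_proj setting
extreme = QParamsSketch(32, [8, 2], [0.05, 0.95], 4).nominal_bpw()
```

This gives nominal figures of 2.175, 4.325 and 2.375 bpw, close to but not exactly the 2.17, 4.31 and 2.38 in the log earlier in this comment, so the script's real accounting presumably includes details this sketch leaves out. An [8, 2] mixture at 5%/95% would nominally cost about 2.43 bpw.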

> I have considered keeping both states and using the difference between them to try to adjust subsequent layers. The idea would be that between each module (though after each of the RMS norms), compute the error E as the difference between the quantized state and the original state, multiply the inputs of the next layer (QKV matrices or MLP up/gate) by the outer product of E, and... hope that it generalizes, I guess? It's worth a shot at least.

Yep, that would get at the cascading effect.

It's possible it's not a thing, as well. I mean if it were a big problem, precision error compounding over layers, GPTQ wouldn't work at all since it's layer by layer and one-shot. It's possible that each layer in a transformer model does stand alone. Transformers can generalise and to do that they have to accept inputs that are "close enough". And that forgiveness might be between layers too, not merely for the model as a whole.

In that case it may be that the current method is absolutely fine, errors within a layer are the real source of quality loss when quantising, whereas between layers you have some latitude to be off.

It's a hypothesis we can argue for and against. Have to actually test it in the end to find out.

If error compounding/butterfly effect is a problem, I think we can show that simply: just apply an "effective error" multiplier that scales from X to 1 linearly with the layer number. I got interested so I did just that in a hacked up version of optimize.py. I'll have some results soon.
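One way to express that ramp, as a sketch: the hacked optimize.py isn't shown here, so `x_first` (the X above) and the function names are made up for illustration.

```python
def error_multiplier(layer_idx, n_layers, x_first):
    """Linear ramp from x_first at layer 0 down to 1.0 at the last layer."""
    if n_layers == 1:
        return 1.0
    t = layer_idx / (n_layers - 1)
    return x_first + (1.0 - x_first) * t


def effective_error(measured_err, layer_idx, n_layers, x_first=2.0):
    """Measured error scaled up for early layers, whose mistakes have more
    downstream layers to propagate through."""
    return measured_err * error_multiplier(layer_idx, n_layers, x_first)
```

Feeding `effective_error` rather than the raw measurement into the allocator would make early layers look harder to quantize, so they get more bits; if compounding is real, a sweep over `x_first` should show PPL improving for some value above 1.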

> I agree in principle, but I'm also scared of giving users too many knobs. I had a lot of fun implementing Mirostat sampling recently, a process in which I discovered that many most (?) users are using it as a placebo. They crank up the tau parameter so high that the algorithm almost never reaches its "surprise" target, and so the effective top-K just shoots up into the thousands and what they attribute to Mirostat is just more or less unconstrained sampling.

Hah! Not surprised. I think mirostat is hard to control. Like different texts have different perplexities for different models, right? So how do you know what "average surprise value" to target? The scale of it will vary not just per model but per kind of prompt! Some kinds of texts will be much more surprising for some models than others, so there's no one size fits all answer. I feel like we need some running stats as humans to be able to set the right mirostat settings. At least what the current perplexity is and yes you're right, simply showing the top-k chosen per token would let you know when you're way off.

Anyhow, going back to knobs: yeah, the ideal is to have as few settings as possible and good heuristics built in. And good point about user expectations when downloading an EXL2 model. Still, one could default the knob to a low amount of specialisation, right? Because an ultimate quantiser with a high specialisation degree would actually cut all the fat we don't need for some specific use case, at the extreme. Like, dropping entire languages the LLM can speak makes sense if you know you're only going to work in English. At scale you'd be saving a lot on your compute bill. So it's not a quant you'd put on Hugging Face, but you might make one like that for personal/internal use.

Of course having a knob only matters if it does anything. Are we specialising much, presently?

My testing so far indicates that we're not. I tried using 100 or 200 calibration rows in the quantisation stage and checking the PPL on a different dataset from the calibration one. Correct me if I'm wrong, but with a high degree of specialisation, increasing the sample count should make the quant worse on datasets it wasn't calibrated on, as it homes in more tightly on the numerical distributions relevant to the calibration set. But that didn't happen in my experiment: doubling the number of samples made the quant's PPL very slightly lower on both sets. In other words, more calibration samples were always better. It made a better generic quantisation, not a more specialised one. (I used WizardLM_evol_instruct_70k for calibration and wikitext-2 as the control, so one is instructions and responses while the other is Wikipedia articles.)

> This is true. I think it's definitely worth trying, at least as an experiment. And it's easier to start with small models anyway, where running inference on the FP16 weights isn't an issue, and supposedly quantization errors are more pronounced anyway. Then if it turns out to be worth it, figure out a solution for 70B models. And 120B. Apparently that's a thing now.

Being able to generate characteristic samples might matter for more than just a specialisation knob... Like if you had the ability to generate like "noise" that was close to the typical inputs, representative noise so to speak, you could possibly quantise faster and better. Just keep computationally generating more inputs until the variability of your error prediction diminishes to some threshold, signalling you've sampled "enough". That'd remove the --dataset_rows setting, simplifying things, and save host-to-device/ease the memory bottleneck since you generate on the fly, on device.
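A minimal sketch of that stopping rule: keep drawing until the standard error of the mean error estimate drops below a relative tolerance. `draw_error` is a hypothetical callback standing in for "generate one input, quantize-and-measure on it"; nothing like it exists in the current scripts.

```python
import math
import random
import statistics


def estimate_until_stable(draw_error, rel_tol=0.01, min_samples=16, max_samples=10_000):
    """Draw error samples until the standard error of their mean is at most
    rel_tol * |mean|, or max_samples is reached. Returns (mean, samples used)."""
    samples = []
    while len(samples) < max_samples:
        samples.append(draw_error())
        if len(samples) >= min_samples:
            mean = statistics.fmean(samples)
            sem = statistics.stdev(samples) / math.sqrt(len(samples))
            if sem <= rel_tol * abs(mean):
                break
    return statistics.fmean(samples), len(samples)


rng = random.Random(0)
mean, n_used = estimate_until_stable(lambda: rng.gauss(0.01, 0.002))
```

A layer whose measured error barely varies from input to input stops at `min_samples`, while a noisy one keeps drawing, which is the sense in which this could replace a fixed --dataset_rows.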

You could sample tons of real inputs and build some statistical model out of that. Doesn't need to be perfect, the noise model, just unbiased and "closeish" to real inputs. But I digress. That's an entirely different cup of tea.
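A minimal sketch of the "representative noise" idea, assuming the noise model is simply a Gaussian fitted to real hidden states and that some per-sample error function is available. `fit_gaussian_noise_model`, `estimate_until_stable` and `error_fn` are hypothetical names, not part of exllamav2, and this runs on the CPU with NumPy rather than generating on device. Sampling stops once the standard error of the mean error estimate drops below `tol`, which is the role a fixed `--dataset_rows` count plays now.

```python
import numpy as np

def fit_gaussian_noise_model(real_states: np.ndarray):
    """Fit a mean vector and covariance matrix to real hidden states (one row per sample)."""
    mu = real_states.mean(axis=0)
    cov = np.cov(real_states, rowvar=False)
    return mu, cov

def estimate_until_stable(error_fn, mu, cov, tol=0.01, batch=16,
                          min_samples=32, max_samples=100_000, seed=0):
    """Draw synthetic inputs from N(mu, cov) in batches until the standard error
    of the mean of error_fn(x) falls below tol. Returns (mean error, samples used)."""
    rng = np.random.default_rng(seed)
    errors = []
    while len(errors) < max_samples:
        xs = rng.multivariate_normal(mu, cov, size=batch)
        errors.extend(float(error_fn(x)) for x in xs)
        if len(errors) >= min_samples:
            e = np.asarray(errors)
            # Stop once the estimate's sampling variability is small enough
            if e.std(ddof=1) / np.sqrt(len(e)) < tol:
                break
    e = np.asarray(errors)
    return float(e.mean()), len(e)

# Toy usage: "real" states are standard normal, the stand-in error is the squared norm
rng = np.random.default_rng(1)
real = rng.standard_normal((500, 4))
mu, cov = fit_gaussian_noise_model(real)
mean_err, n_used = estimate_until_stable(lambda x: np.sum(x ** 2), mu, cov, tol=0.1)
```

The number of samples is now an output rather than a setting: a noisier error function simply consumes more synthetic inputs before the stopping rule fires.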

I'll post some numbers on the error by layer ramping tests.

turboderp commented 9 months ago

I take it you hand-picked the current array of choices in qparams_options based on your empirical testing? I'm noticing we don't have more extreme mixtures like [8, 2] or [8, 3] to tackle layers with particularly gnarly outliers. Is that not possible or just not useful? I'm thinking about Tim Dettmers' observations about outliers being few but significant

As I understand it, it's not necessarily that some weights need super high precision, but that they are just larger.

If you quantize the entire matrix to a single, regular quantization grid, aligning that grid so as to minimize the MSE of either the matrix itself or the activations, the small weights are going to dominate and you'll end up with a small grid. The outliers will be clamped, and while the mean error is small you could still lose important features.

With smaller group sizes you remedy the problem somewhat, since each group of e.g. 32 weights will have its own grid. Activation order helps a bit more, since rows (in the matrix view, or columns in the linear layer view) are sorted in a way that tends (somewhat at least) to group weights together based on their magnitude.
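A small NumPy illustration of why smaller groups help, using a made-up weight vector with one outlier. For simplicity this sketch sets each grid's scale from the absolute maximum (an assumption; it does not search for an MSE-optimal alignment as described above), so the outlier is never clamped. Instead it stretches the grid until the small weights round to zero, which is the other side of the same trade-off. `quantize_absmax` and `quantize_grouped` are hypothetical helpers.

```python
import numpy as np

def quantize_absmax(w: np.ndarray, bits: int = 4) -> np.ndarray:
    """Round w onto a symmetric uniform grid whose scale is set by max |w| (no clamping)."""
    qmax = 2 ** (bits - 1) - 1
    scale = np.max(np.abs(w)) / qmax
    if scale == 0:
        return np.zeros_like(w)
    return np.round(w / scale) * scale

def quantize_grouped(w: np.ndarray, group_size: int, bits: int = 4) -> np.ndarray:
    """Quantize each contiguous group of group_size weights with its own grid."""
    out = np.empty_like(w)
    for start in range(0, len(w), group_size):
        out[start:start + group_size] = quantize_absmax(w[start:start + group_size], bits)
    return out

rng = np.random.default_rng(0)
w = 0.01 * rng.standard_normal(256)
w[5] = 1.0  # a single large outlier

err_tensor = np.sum((w - quantize_absmax(w)) ** 2)         # one grid for all 256 weights
err_grouped = np.sum((w - quantize_grouped(w, 32)) ** 2)  # one grid per 32 weights
```

With one grid for the whole vector, every small weight is lost; with groups of 32, only the group containing the outlier pays that price.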

This isn't enough, of course, to guarantee that you don't end up with groups that are hard to fit onto a grid, giving you either poor precision for the small weights or clamping of large weights. The NormalFloat formats are an interesting approach, but they're really expensive to do arithmetic on since they require lookup tables.

I do have some experiments brewing inspired by how well the FP8 cache has worked out. Surprisingly, inference works just fine if you simply drop the last 8 bits of every FP16 value in the cache, giving you the equivalent FP8 (e5m2) number. I think if there was an efficient way to quantize to a logarithmic grid instead of a linear one, that would make all the difference.
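The FP16 to FP8 (e5m2) trick can be reproduced bit-for-bit in NumPy. FP16 is laid out as 1 sign, 5 exponent and 10 mantissa bits, so keeping the high byte leaves 1-5-2, i.e. e5m2. This is a sketch of the idea, not exllamav2's CUDA cache code. Note that it truncates (rounds toward zero) rather than rounding to nearest, matching "simply drop the last 8 bits". Because it is a floating-point format, the resulting grid's spacing grows with magnitude, which makes it approximately logarithmic.

```python
import numpy as np

def fp16_to_e5m2(x: np.ndarray) -> np.ndarray:
    """Keep only the high byte of each FP16 value: 1 sign, 5 exponent, 2 mantissa bits."""
    bits = np.ascontiguousarray(x, dtype=np.float16).view(np.uint16)
    return (bits >> 8).astype(np.uint8)

def e5m2_to_fp16(b: np.ndarray) -> np.ndarray:
    """Widen back to FP16 by zero-filling the 8 dropped mantissa bits."""
    return (b.astype(np.uint16) << 8).view(np.float16)

x = np.array([1.0, 1.5, 1.1, -2.0, 3.0], dtype=np.float16)
packed = fp16_to_e5m2(x)        # one byte per value
restored = e5m2_to_fp16(packed)
```

Values that fit in 2 mantissa bits (1.0, 1.5, -2.0, 3.0) survive exactly, while 1.1 is truncated down to 1.0.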

Correct me if I'm wrong, but with a high degree of specialisation, increasing the sample count should make the quant worse on datasets it wasn't calibrated on as it hones in more tightly on the numerical distributions relevant to the calibration set.

I think this is backwards. The smaller the calibration set, the easier it is to fit the quantized model to it. And even if you show the model more of the same "kind" of text in a larger dataset, it's not like giving it more iterations of finetuning. More data would only ever reduce the number of correlations the quantizer can exploit for error correction, making it more general.

Being able to generate characteristic samples might matter for more than just a specialisation knob... Like if you had the ability to generate like "noise" that was close to the typical inputs, representative noise so to speak, you could possibly quantise faster and better. Just keep computationally generating more inputs until the variability of your error prediction diminishes to some threshold, signalling you've sampled "enough". That'd remove the --dataset_rows setting, simplifying things, and save host-to-device/ease the memory bottleneck since you generate on the fly, on device.

The tokenized samples are tiny compared to the hidden state, so host-to-device communication wouldn't improve much. But I do like the idea of sampling calibration data from the model itself. Specialization can be controlled either by turning the percdamp knob, or by adding a variable amount of noise to the input data.

aljungberg commented 9 months ago

I think this is backwards. The smaller the calibration set, the easier it is to fit the quantized model to it. And even if you show the model more of the same "kind" of text in a larger dataset, it's not like giving it more iterations of finetuning.

Ah. I thought there was a different effect going on, but I didn't actually read that code, so thanks for clarifying. Then we can disregard the result of this particular experiment of mine.

My mistake on what's actually happening aside, let me try to justify my thinking. I think it has bearing on the discussion. The non-finetuning mechanism I imagined, whereby more calibration rows would lead to higher specialisation, was one of gaining more evidence to group weights differently in terms of their precision allocation. The more data we have, the better we can justify skewing our precision distribution.

The essence of quantisation really boils down to judiciously assigning precision where it matters most, where what matters is estimated through sampling on typical inputs, right?

So as a thought experiment, if you gave me a vector of FP32 weights and told me to fit them into half as many bits of arbitrary precision storage while preserving maximum fidelity as measured by output delta in a neural network and you also gave me...

The more proof we get, the more lopsided we can be in our allocation, if that makes sense. At the limit maybe some parts of the matrices can be left as pure noise because they represent the model's understanding of things that just don't exist in the calibration set. Just need a lot of sampling to discover that and be sure.

More data would only ever reduce the number of correlations the quantizer can exploit for error correction, making it more general.

Okay, that's a pity. Still, doesn't it feel like it doesn't have to be like that? With overwhelming evidence that the precision of some weights contributes nothing to the model output wrt the calibration set, then even if we can't stick them in their own 0-bit group as a practical matter because they're not clumped up, there are that many fewer constraints when aligning the quantisation grid for that block. I guess I'm not entirely understanding how the calibration rows are used after the qparams selection stage.

turboderp commented 9 months ago

Aligning the quantization grid is only part of it.

The algorithm is trying to solve a reconstruction problem, as the GPTQ paper puts it, where for some matrix W and some collection of input vectors X_l, what we want is the quantized matrix W' that minimizes the (mean squared) difference between WX and W'X for all X in X_l.

If you reduce the number of input vectors, you'll find a more precise solution to a narrower problem. If instead you increase the number of inputs, the solution will tend towards minimizing the difference between W and W' instead, i.e. it will tend towards RTN quantization.
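Written out, with the calibration inputs stacked as the columns of a matrix $\mathbf{X}$ (notation introduced here, not from the thread) and $W'$ restricted to the quantization grid, the reconstruction objective is

$$
W'^{\star} = \arg\min_{W'} \sum_{x \in X_l} \lVert W x - W' x \rVert_2^2
          = \arg\min_{W'} \operatorname{tr}\!\left[ (W - W')\, \mathbf{X}\mathbf{X}^\top (W - W')^\top \right].
$$

In the limiting case where the inputs are uncorrelated and carry equal energy in every direction, $\mathbf{X}\mathbf{X}^\top = c I$ for some $c > 0$, and the objective collapses to

$$
\operatorname{tr}\!\left[ (W - W')\, c I\, (W - W')^\top \right] = c \,\lVert W - W' \rVert_F^2 ,
$$

which is minimized by rounding each weight to its nearest grid point independently, i.e. RTN. One way to read the "tends towards RTN" remark is that a more diverse calibration set pushes $\mathbf{X}\mathbf{X}^\top$ closer to this isotropic case.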

But the idea is that not all possible inputs need to be considered. For instance the Q, K and V matrices of the first block of a regular Llama model will only ever see 32000 distinct inputs (the possible embedding vectors), and those inputs are going to be highly correlated since they encode the model's idea of what each token "means". Hello, hello, hello, hi etc. all mean roughly the same thing, so they will likely cluster together in the embedding space. Then, the first O matrix will only need to consider those states that can be reached after position embeddings and attention on the Q, K and V projections from those 32000 initial embeddings. And so on.

So it's not really phrased in terms of confidence in the importance of weights. But if you do think of it that way, then more inputs necessarily means less confidence. (The way the algorithm works, at least. Whether it should be that way is an open question I guess.)

Do keep in mind though that even though it is a "reconstruction" problem, the algorithm still uses the original weights as a starting point: It quantizes a column, then it distributes the rounding error over the remaining weights so that they compensate for it as much as possible (given whatever correlations may exist as "revealed" by the calibration data), then it proceeds to the next column.
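The quantize-a-column, redistribute-the-error loop can be sketched in NumPy. This follows the GPTQ paper's Cholesky formulation but omits its lazy-batch blocking, and it is not exllamav2's implementation. It assumes an unbounded uniform grid with a single `scale` (no clamping, groups or per-channel parameters). `round_to_grid`, `gptq_like_quantize` and `output_error` are hypothetical names.

```python
import numpy as np

def round_to_grid(w: np.ndarray, scale: float) -> np.ndarray:
    """Plain round-to-nearest (RTN) onto an unbounded uniform grid with step `scale`."""
    return np.round(w / scale) * scale

def gptq_like_quantize(W: np.ndarray, X: np.ndarray, scale: float,
                       percdamp: float = 0.01) -> np.ndarray:
    """Quantize W (d_out, d_in) one input column at a time, spreading each column's
    rounding error over the not-yet-quantized columns, guided by inputs X (d_in, n)."""
    W = W.astype(np.float64).copy()
    H = 2.0 * X @ X.T
    H += percdamp * np.mean(np.diag(H)) * np.eye(H.shape[0])
    # Upper Cholesky factor U of H^-1, so that H^-1 = U^T U
    U = np.linalg.cholesky(np.linalg.inv(H)).T
    Q = np.zeros_like(W)
    for i in range(W.shape[1]):
        q = round_to_grid(W[:, i], scale)
        Q[:, i] = q
        err = (W[:, i] - q) / U[i, i]
        # Column i becomes exactly q; the remaining columns absorb the error
        W[:, i:] -= np.outer(err, U[i, i:])
    return Q

def output_error(W: np.ndarray, Q: np.ndarray, X: np.ndarray) -> float:
    """Frobenius norm of the output difference (W - Q) X on the calibration inputs."""
    return float(np.linalg.norm((W - Q) @ X))

rng = np.random.default_rng(0)
# Correlated calibration inputs: 16 features driven by 4 latent factors plus a little noise
X = rng.standard_normal((16, 4)) @ rng.standard_normal((4, 256)) + 0.1 * rng.standard_normal((16, 256))
W = rng.standard_normal((8, 16))
scale = 0.5
err_rtn = output_error(W, round_to_grid(W, scale), X)
err_gptq = output_error(W, gptq_like_quantize(W, X, scale), X)

# With uncorrelated, equal-energy inputs, H is a multiple of I and no error is redistributed
Q_iso = gptq_like_quantize(W, np.eye(16), scale)
```

When the inputs are correlated, the compensation step lowers the output error relative to RTN. With identity inputs, the off-diagonal entries of U are zero, so the loop degenerates to RTN exactly.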

jukofyork commented 2 days ago

Sorry to bump an old thread - I was searching for "control vectors" and for some reason this showed up...

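import torch

def test_error(module, hidden_states, target_states, cache, attn_params):
    # Mean relative Frobenius-norm (RFN) error of the quantized module's output
    # against the reference (unquantized) output, averaged over calibration samples
    rfn_sum = torch.tensor(0.0).cuda()
    rfn_count = 0
    for x, xref in zip(hidden_states, target_states):
        x = x.cuda()
        xref = xref.cuda()
        xtest = module.forward(x, cache, attn_params)
        xtest = xtest[0].float()
        xref = xref[0].float()
        # Per-sample ratio ||xtest - xref||_F / ||xref||_F
        rfn_sum += torch.linalg.norm(xtest - xref, 'fro') / torch.linalg.norm(xref, 'fro')
        rfn_count += 1

    # Reported as an "accuracy" score: 1 - mean RFN, floored at 1e-6
    return max(1e-6, 1 - (rfn_sum.item() / rfn_count))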

The torch.linalg.norm(xtest - xref, 'fro') / torch.linalg.norm(xref, 'fro') is likely the source of the overfitting.

The individual norms can be regularised, eg:

https://arxiv.org/abs/2406.06137

But the ratio itself is likely to be the biggest problem. If we (unrealistically) assume it is approximately the ratio of two i.i.d. zero-mean Normals, then it follows a Cauchy distribution:

https://en.m.wikipedia.org/wiki/Cauchy_distribution

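which has horrible sample statistics: its mean and variance are undefined, so the sample mean never settles down no matter how many samples are drawn.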

The most practical method would likely be to use resampling to compute an empirical CDF, trim the problematic tail, and then recalculate the expected value to use as your estimate.
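One possible reading of that suggestion, as a NumPy sketch. Per-sample RFN ratios are non-negative, so the troublesome tail is the upper one. The estimator drops the largest `trim` fraction of each bootstrap resample and averages what remains. The 5% trim and the helper names are illustrative assumptions, not anything exllamav2 does today. Inside `test_error` above, the per-sample ratios would be collected in a list instead of summed and then passed to one of these functions.

```python
import numpy as np

def upper_trimmed_mean(values, trim: float = 0.05) -> float:
    """Mean after discarding the largest `trim` fraction of values (the heavy right tail)."""
    v = np.sort(np.asarray(values, dtype=np.float64))
    keep = len(v) - int(len(v) * trim)
    return float(v[:max(keep, 1)].mean())

def bootstrap_trimmed_mean(values, trim: float = 0.05, n_boot: int = 200, seed: int = 0) -> float:
    """Average the upper-trimmed mean over bootstrap resamples of the per-sample errors."""
    rng = np.random.default_rng(seed)
    v = np.asarray(values, dtype=np.float64)
    estimates = [upper_trimmed_mean(rng.choice(v, size=len(v), replace=True), trim)
                 for _ in range(n_boot)]
    return float(np.mean(estimates))

# Per-sample relative errors: 99 well-behaved values plus one blow-up
rfn_values = [0.1] * 99 + [1000.0]
plain = float(np.mean(rfn_values))
robust = upper_trimmed_mean(rfn_values)
robust_boot = bootstrap_trimmed_mean(rfn_values)
```

A single blow-up drags the plain mean up by two orders of magnitude, while both trimmed estimates stay near the typical value.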

turboderp commented 2 days ago

I'm entirely open to the idea of using another estimate. I'm not convinced that overfitting is a problem at the moment, and the Frobenius norm does align with the quantization objective (minimize MSE on the output of each linear layer), but in principle any function could work as long as it meaningfully captures the error.

I don't know how much time I have to revisit measurement right now, but I would certainly entertain a proof of concept.

jukofyork commented 2 days ago

I'm entirely open to the idea of using another estimate. I'm not convinced that overfitting is a problem at the moment, and the Frobenius norm does align with the quantization objective (minimize MSE on the output of each linear layer), but in principle any function could work as long as it meaningfully captures the error.

No, I think this is the correct metric to use (re: rate–distortion theory); it's just that regularising the estimate would help.

I tried to do the same for the llama.cpp 'Importance Matrix' calculations as the:

Importance matrix calculations work best on near-random data

observation is actually related to Tikhonov regularization, as I tried to explain here:

https://github.com/ggerganov/llama.cpp/discussions/5263#discussioncomment-9163642

(IIRC, there is an even better explanation of this in Chris Bishop's older "Neural Networks for Pattern Recognition" book)
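The textbook form of that connection is that least squares trained on noise-corrupted inputs matches Tikhonov (ridge) regression in expectation. With N(0, σ²) noise on the inputs of an n-row problem, the expected loss is ‖Aw − y‖² + nσ²‖w‖², so σ² = λ/n. The NumPy sketch below (toy data, hypothetical helper names) checks this numerically. It illustrates the general principle only and does not model llama.cpp's imatrix code.

```python
import numpy as np

def ridge(A, y, lam):
    """Tikhonov-regularised least squares: argmin_w ||A w - y||^2 + lam * ||w||^2."""
    return np.linalg.solve(A.T @ A + lam * np.eye(A.shape[1]), A.T @ y)

def lstsq_on_noisy_copies(A, y, sigma, copies, seed=0):
    """Plain least squares on `copies` replicas of (A, y), with N(0, sigma^2) noise
    added to the inputs only (targets are left clean)."""
    rng = np.random.default_rng(seed)
    A_rep = np.tile(A, (copies, 1))
    y_rep = np.tile(y, copies)
    A_rep = A_rep + sigma * rng.standard_normal(A_rep.shape)
    w, *_ = np.linalg.lstsq(A_rep, y_rep, rcond=None)
    return w

rng = np.random.default_rng(0)
A = rng.standard_normal((50, 5))
y = rng.standard_normal(50)
lam = 25.0
sigma = np.sqrt(lam / A.shape[0])  # expected noise penalty n * sigma^2 equals lam

w_ridge = ridge(A, y, lam)
w_noise = lstsq_on_noisy_copies(A, y, sigma, copies=2000)
w_plain = ridge(A, y, 0.0)  # unregularised least squares, for comparison
```

The noisy-input solution lands on the ridge solution, not on the unregularised one. Calibrating on "near-random" data may buy a similar shrinkage toward a less data-specific solution.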

As for what is happening in the llama.cpp code:

https://github.com/ggerganov/llama.cpp/discussions/5263#discussioncomment-9323017

this still remains a mystery to me...

It makes no sense that the calculation when using or not using the Importance Matrix should be so different, and there is no obvious way to regularise from the "with Importance Matrix" calculation towards the "without Importance Matrix" calculation...

I'm not sure if the original author of that code has left now too :( Whatever this code is doing doesn't make any sense from a statistical learning perspective, and even asking all the large LLMs whether they can work out where this method comes from doesn't really help: they just suggest it was written by somebody with a "signal processing background" and give some analogies of what it could be trying to achieve.

The other big problem with the llama.cpp 'Importance Matrix' "improvement" measurements is that they seem to use distributionally similar text from some wiki dataset that has all sorts of odd formatting and spacing (at one time for both the creation of the imatrix and the testing of imatrix, but recently it looks like they are using more diverse datasets). The result of this is that the "improvement" measurements are very sketchy at best and end up with things like Q6_K having lower perplexity than FP16 as a result... :/

turboderp commented 2 days ago

I did some fun experiments a while back measuring the RFN error for Mixtral at various quantization levels throughout the forward pass.

[Image: RFN error for Mixtral at various quantization levels throughout the forward pass]

A couple of observations:

Anyway, I think it's hard to say that perplexity shouldn't ever be lower for a quantized model than it is for FP16. It always comes down to making the best rounding choices you can, exploiting redundancy in the weights and all that, informed by some sample data that's hopefully not too narrow. You can end up losing something in the process, but you'd deliberately do that during sampling anyway, since presumably there's always a noise floor regardless of whether the model is quantized or not.

Of course you can also just straight-up overfit in the process. I ran into possibly a quite extreme case of this with EQAT:

[Image: EQAT results]

This is essentially finetuning under a quantization constraint. It results in much lower perplexity for the instruct version of L3-8B and, correspondingly, much worse performance on HumanEval, tracking nicely with the fact that L3-8B scores very poorly on HumanEval without its instruct tuning, which it seems to be forgetting with this QAT method.

In the end I think it's about striking some optimal balance. You also don't want to simply ascribe equal importance to all weights, because then the solution is trivially just RTN quantization, which is demonstrably worse on all benchmarks.

aljungberg commented 1 day ago

Importance matrix calculations work best on near-random data

Oh yes, @jukofyork. It would have been interesting to see it with this idea of "model average samples" I suggested above, which would sit between the two extremes: more varied than wikitext2, and less (apparently) unhelpful than totally out-of-distribution stuff from random noise.

Although @turboderp's chart suggests why even random noise should work: the model is inherently error correcting layer by layer. So random calibration samples are only truly out of distribution for the first few layers. After that the model has chewed on them enough to turn them into something more familiar. I mean, that's what an LLM is, a "what looks reasonable" detector. Every layer is geared towards steering the hidden state towards an intermediate state that would be plausible in the training data.

I guess we'd be reapproaching the "model average samples" idea if we were to quantise with "normal" calibration data on the first few layers, then switch to "near-random data" when quantising later layers (still passing these random samples through the earlier layers, just not measuring the error for quant purposes in that part of the model). We would be using the actual model to partially reshape those inputs into something less bizarre.
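A toy sketch of that hybrid schedule, assuming layers can be treated as plain callables on hidden states. `calibration_states_per_layer` and the affine stand-in layers are hypothetical. A real implementation would feed these states to the per-layer quantizer instead of returning them.

```python
import numpy as np

def calibration_states_per_layer(layers, real_inputs, noise_inputs, switch_layer):
    """For each layer index, return the hidden states that layer would be calibrated on.

    Layers before `switch_layer` are calibrated on states derived from real text.
    Layers from `switch_layer` onward are calibrated on states that started as noise
    but were pushed through every earlier layer; the noise's error in those earlier
    layers is simply never measured."""
    states = []
    h_real, h_noise = real_inputs, noise_inputs
    for i, layer in enumerate(layers):
        states.append(h_real if i < switch_layer else h_noise)
        if i + 1 < switch_layer:
            h_real = layer(h_real)  # real states are only needed up to the switch point
        h_noise = layer(h_noise)    # noise always flows through, so later layers see reshaped noise
    return states

# Stand-in "layers": a simple contracting affine map applied four times
layers = [lambda h: 0.5 * h + 1.0] * 4
states = calibration_states_per_layer(layers, np.zeros(3), np.full(3, 10.0), switch_layer=2)
```

Here layers 0 and 1 see states derived from the real inputs, while layers 2 and 3 see the noise after it has already been transformed by the layers before them.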

Last year I ran a number of experiments on different ways to let the calibration data guide the precision allocation. One of the methods I tried was related to what you're showing in your useful chart, @turboderp, and which I think we talked about at the time. Is there a butterfly effect such that we want to generally blindly throw more bits at the early layers? I tried a gradient such that the first layer had 25%, 50% or 75% amplification on its error measurement for bit allocation purposes, linearly dropping for each subsequent layer, going to 0% for the last layer. The results were, unfortunately, inconclusive. A 50% gradient, for example, seemed to create a better 5.0 bpw quant of lzlv-70B when calibrating with WizardLM's dataset and using the same dataset (but a different selection) for PPL measurement. But on wikitext-2, PPL was higher than using no gradient. I played around with different numbers of calibration samples and various bit allocation heuristics to run on top, but I only ever saw success on the WizardLM PPL metric.
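The amplification gradient described above fits in a few lines, assuming it multiplies each layer's measured error before the bit-allocation search. That search is not shown, and `error_amplification` and `weighted_errors` are hypothetical names.

```python
import numpy as np

def error_amplification(num_layers: int, gradient: float) -> np.ndarray:
    """Per-layer multiplier: 1 + gradient at the first layer, falling linearly to 1 at the last."""
    if num_layers == 1:
        return np.array([1.0 + gradient])
    pos = np.arange(num_layers) / (num_layers - 1)  # 0 at the first layer, 1 at the last
    return 1.0 + gradient * (1.0 - pos)

def weighted_errors(measured_errors, gradient: float) -> np.ndarray:
    """Scale measured per-layer errors so early layers look costlier to the allocator."""
    e = np.asarray(measured_errors, dtype=np.float64)
    return e * error_amplification(len(e), gradient)

# A 50% gradient over an 80-layer model, with identical measured errors everywhere
amp = error_amplification(80, 0.5)
weighted = weighted_errors([0.02] * 80, 0.5)
```

With equal raw errors, the allocator would now see the first layer as 50% more error-prone than the last, nudging extra bits toward the start of the model.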

My conclusion was that PPL is way too coarse a measurement, at least with my compute budget.