ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Investigate alternative approach for Q4 quantization #397

Closed: ggerganov closed this issue 1 year ago

ggerganov commented 1 year ago

Currently, in Q4_0 quantization we choose the scaling factor for each group of 32 weights as max(abs(x_i))/7. It is easy to see that this is suboptimal.

Consider quantization of the following 4 numbers:

0.1 0.2 0.3 0.6

Currently, we would determine a scaling factor of 0.6 / 7 ~= 0.0857 and the dequantized numbers will be:

0.0857 0.1714 0.3428 0.6

So the RMS error between the dequantized and original values will be non-zero:

sqrt(((0.1 - 0.0857)^2 + (0.2 - 0.1714)^2 + (0.3 - 0.3428)^2 + (0.6 - 0.6)^2) / 4) > 0.0

However, if we choose the scaling factor to be 0.1 instead, then it is easy to see that the original numbers will be quantized perfectly.
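The comparison can be checked numerically with a few lines of Python. This is a minimal sketch: `quantize_q4_0` and `rms` are illustrative helpers (not ggml functions) that mimic Q4_0 rounding with codes clamped to 0..15. Python's round() rounds ties to even, unlike C's round(); the third value sits right at 3.5 steps, so its code depends on floating-point details, but the RMS for amax/7 is clearly non-zero either way.

```python
import math

def quantize_q4_0(xs, d):
    # Q4_0-style: code = clamp(round(x / d) + 8, 0, 15), dequantized value = (code - 8) * d
    codes = [min(max(round(x / d) + 8, 0), 15) for x in xs]
    return [(c - 8) * d for c in codes]

def rms(xs, ys):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs))

xs = [0.1, 0.2, 0.3, 0.6]
d_amax = max(abs(x) for x in xs) / 7   # current Q4_0 choice: 0.6 / 7
d_alt = 0.1                            # the alternative scale from the text

err_amax = rms(xs, quantize_q4_0(xs, d_amax))
err_alt = rms(xs, quantize_q4_0(xs, d_alt))
print(f"amax/7: RMS = {err_amax:.4f}, 0.1: RMS = {err_alt:.1e}")
```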

So it is better to choose the scaling factor that minimises some error (e.g. RMS, or whatever is more meaningful and easy to compute). Doing that, we will certainly achieve better accuracy than with the existing approach. The question is: how much better?

The goal of this task is to implement the quantization described above and evaluate the perplexity using the new approach. In simple terms, the approach boils down to a linear regression of the data with a fixed zero point. This new quantization might be a bit heavier to compute than Q4_0, so for start we can do it just on the model tensors. The intermediate tensors during the evaluation can remain quantized using the existing approach, so that the evaluation is efficient. If the results look promising, we can put effort into optimising the new approach and completely replacing Q4_0 with it.
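One way to do the "linear regression with a fixed zero point" is to alternate two steps: round each value to the nearest 4-bit level for the current scale, then refit the scale by least squares through the origin, d = sum(x_i q_i) / sum(q_i^2), with those codes held fixed. Each step can only lower the squared error, but the alternation only finds a local optimum, so this sketch restarts it from a few scales amax/k. The helper names and the set of divisors are assumptions for illustration, not part of ggml.

```python
def nearest_codes(xs, d):
    # nearest signed 4-bit level in [-8, 7] for each value at scale d (d > 0)
    return [min(max(round(x / d), -8), 7) for x in xs]

def sq_error(xs, qs, d):
    return sum((x - q * d) ** 2 for x, q in zip(xs, qs))

def refine_scale(xs, d, max_iter=20):
    """Alternate: round with d fixed, then least-squares d through the origin with codes fixed."""
    qs = nearest_codes(xs, d)
    for _ in range(max_iter):
        denom = sum(q * q for q in qs)
        if denom == 0:
            break
        d = sum(x * q for x, q in zip(xs, qs)) / denom   # argmin_d sum_i (x_i - q_i * d)^2
        new_qs = nearest_codes(xs, d)
        if new_qs == qs:
            break
        qs = new_qs
    return d, qs

def fit_scale(xs, divisors=(7.0, 6.5, 6.0, 5.5, 5.0)):
    """Refine from several starting scales amax/k and keep the lowest squared error."""
    amax = max(abs(x) for x in xs)
    if amax == 0.0:
        return 0.0, [0] * len(xs)
    best = None
    for k in divisors:
        d, qs = refine_scale(xs, amax / k)
        err = sq_error(xs, qs, d)
        if best is None or err < best[0]:
            best = (err, d, qs)
    return best[1], best[2]
```

For the four numbers above, `fit_scale([0.1, 0.2, 0.3, 0.6])` should return a scale of about 0.1 with codes [1, 2, 3, 6], whereas refining from amax/7 alone stops in a different local optimum.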

Whoever demonstrates the results of this quantization will get the chance to give it a name and publish a paper (just kidding 😆 )

A similar strategy for determining the scale factor and offset factor can be applied to Q4_1.
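For Q4_1 the same alternation carries over: with the codes q_i in 0..15 held fixed, the best offset m and scale d in x_i ≈ m + d*q_i come from an ordinary least-squares line fit. A minimal sketch, starting from the current min/max choice; the function names are illustrative only.

```python
def q4_1_codes(xs, m, d):
    # nearest level m + q*d with q clamped to 0..15 (d > 0)
    return [min(max(round((x - m) / d), 0), 15) for x in xs]

def q4_1_sq_error(xs, m, d, qs):
    return sum((x - (m + q * d)) ** 2 for x, q in zip(xs, qs))

def refine_q4_1(xs, max_iter=20):
    """Alternate nearest-level rounding with an ordinary least-squares fit of x ~ m + d*q."""
    lo, hi = min(xs), max(xs)
    if hi == lo:
        return lo, 0.0, [0] * len(xs)
    m, d = lo, (hi - lo) / 15               # current Q4_1 choice as the starting point
    qs = q4_1_codes(xs, m, d)
    n = len(xs)
    for _ in range(max_iter):
        mean_q = sum(qs) / n
        mean_x = sum(xs) / n
        var_q = sum((q - mean_q) ** 2 for q in qs)
        if var_q == 0:
            break
        # slope and intercept of the least-squares line through (q_i, x_i)
        d = sum((q - mean_q) * (x - mean_x) for q, x in zip(qs, xs)) / var_q
        m = mean_x - d * mean_q
        new_qs = q4_1_codes(xs, m, d)
        if new_qs == qs:
            break
        qs = new_qs
    return m, d, qs
```

Because the codes are non-decreasing in x, the fitted slope stays positive, and the squared error never exceeds that of the plain min/max choice.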

Andrey36652 commented 1 year ago

@ggerganov "so for start we can do it just on the model tensors. The intermediate tensors during the evaluation can remain quantized using the existing approach, so that the evaluation is efficient." - do you mean that Q4_0 quantizes not only the weights, but the activations too now? I don't quite understand the meaning of "model tensors"...

Andrey36652 commented 1 year ago

Might be worth reading:

A Survey of Quantization Methods for Efficient Neural Network Inference https://arxiv.org/pdf/2103.13630.pdf

sw commented 1 year ago

Low-bit Quantization of Neural Networks for Efficient Inference deals with 4-bit quantization specifically.

As a smaller step, I can think of these optimizations:

prusnak commented 1 year ago

This paper is also relevant: Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference, although it deals with int8.

prusnak commented 1 year ago

> 0.0857 0.1714 0.2571 0.6

This should read 0.0857 0.1714 0.3428 0.6, correct? It seems to me that the code uses round(), not floor(), so the third value will be rounded up.

> However, if we choose the scaling factor to be 0.7 instead,

This should read 0.1, correct?

prusnak commented 1 year ago

I came up with a script that's able to compute RMS for various quantization methods - maybe it will come in handy for experimenting: https://gist.github.com/prusnak/f54f8f33503458ca1aa9883f71897072

prusnak commented 1 year ago

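I was experimenting with grid search to find a better offset and scaling factor, but it does not seem to produce much better results than simply doing Q4_1. This doesn't justify making the whole process much slower (roughly n_search^2 times slower).

Sketch (pure Python):

def q4_1_rms(data, lo, hi):
    # Q4_1 with offset lo and step (hi - lo) / 15; codes are clamped to 0..15
    d = (hi - lo) / 15
    err = 0.0
    for x in data:
        q = min(max(round((x - lo) / d), 0), 15)
        err += (x - (lo + q * d)) ** 2
    return (err / len(data)) ** 0.5

def grid_search_q4_1(data, n_search=30):
    data_min, data_max = min(data), max(data)
    step = (data_max - data_min) / n_search
    if step == 0:
        return 0.0, data_min, data_max        # constant block: nothing to search
    best = (float("inf"), data_min, data_max)
    for a in range(n_search):                  # candidate offset (min_value)
        lo = data_min + a * step
        for b in range(a + 1, n_search + 1):   # candidate max_value, up to data_max inclusive
            hi = data_min + b * step
            rms = q4_1_rms(data, lo, hi)
            if rms < best[0]:                  # keep the best Q4_1 run seen so far
                best = (rms, lo, hi)
    return best  # (rms, offset, max_value) of the best Q4_1 run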

Maybe someone can come up with a better grid search?

prusnak commented 1 year ago

I also found the Lloyd-Max algorithm, but this one creates non-uniform quantization, which is a no-go for our use case, I assume. Is that correct?

Resources:
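For reference, the Lloyd-Max iteration for a scalar quantizer is essentially 1-D k-means: assign each value to its nearest reconstruction level, move each level to the mean of the values assigned to it, and repeat. The resulting 16 levels are non-uniform, so decoding needs a table lookup rather than a multiply. A minimal sketch; the function names and the uniform starting grid are assumptions for illustration.

```python
def nearest_index(x, levels):
    return min(range(len(levels)), key=lambda j: abs(x - levels[j]))

def mse(xs, levels):
    return sum((x - levels[nearest_index(x, levels)]) ** 2 for x in xs) / len(xs)

def lloyd_max(xs, n_levels=16, iters=50):
    lo, hi = min(xs), max(xs)
    # start from a uniform grid over [lo, hi] (bucket centres)
    levels = [lo + (hi - lo) * (k + 0.5) / n_levels for k in range(n_levels)]
    for _ in range(iters):
        buckets = [[] for _ in range(n_levels)]
        for x in xs:                                          # assignment step
            buckets[nearest_index(x, levels)].append(x)
        new_levels = [sum(b) / len(b) if b else levels[k]     # centroid step; empty buckets keep their level
                      for k, b in enumerate(buckets)]
        if new_levels == levels:
            break
        levels = new_levels
    return sorted(levels)
```

Each pass can only lower the mean squared error, so on peaked data the learned levels should do at least as well as the uniform grid they start from.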

blackhole89 commented 1 year ago

I'm playing around with local search for the q4_1 parameters now, with something like the following approximately in place of the inner loop of quantize_row_q4_1:

        round_block(pp, x + i*QK, min, d);
        float err = sq_error(pp, x + i*QK, min, d), err0=err;

        int step_count = 0;
        while(1) {
            ++step_count;

//            const float next_mins[4] = { min*1.001f, min/1.001f, min, min };
//            const float next_ds[4]   = { d, d, d*1.001f, d/1.001f };

            for (int dir=0; dir<16; ++dir) { // must not reuse 'i': x + i*QK below needs the block index
//              const float next_min = next_mins[dir];
//              const float next_d = next_ds[dir];
                const float next_min = min * (0.99f + 0.0002f*(rand()%100)); //next_mins[dir];
                const float next_d = d * (0.99f + 0.0002f*(rand()%100));//next_ds[dir];

                round_block(pp, x + i*QK,  next_min, next_d);
                float next_err = sq_error(pp, x + i*QK, next_min, next_d);
                if (next_err < err) {
                    min = next_min;
                    d = next_d;
                    err = next_err;
                    goto quantize_row_q4_1_opt_next;
                }
            }
            break;
        quantize_row_q4_1_opt_next:;
        }

        static float rer = 0.0f;
        rer = 0.001*(err/err0) + 0.999*rer;
        printf("q: %d steps, err ratio %.3f, running %.3f\n", step_count, err/err0, rer);

        round_block(pp, x + i*QK, min, d);

I found that the square error is indeed reduced by this, in a way that's quite sensitive to the parameter the loop over i is bounded by (16 here; higher = it tries more directions to improve in at each step). At 4, I get an error ratio of about 0.93 (with the random walk being just a little better than the commented-out fixed directions); at 16, this is down to about 0.83, and at 64 it goes all the way down to 0.7. Obviously this is much slower than the preexisting code, so we'll have to wait for a while to see whether the lower quantization error actually translates to lower perplexity or anything.

(Yes, I'm aware I could do better picking a random direction, or even N deterministic ones, than that. I promise I'll make it less silly if it ever makes it into a PR.)

Andrey36652 commented 1 year ago

> I came up with a script that's able to compute RMS for various quantization methods - maybe it will come in handy for experimenting: https://gist.github.com/prusnak/f54f8f33503458ca1aa9883f71897072

@prusnak Do you know which distribution the LLaMA weights have? Why do you use a uniform distribution for tests? I suppose the weights have a normal-ish (or at least non-uniform) distribution, so we could get better results with non-uniform quantization.

Here is my suggestion. Currently, with Q4_1 (QK=32) we have this size reduction: 32*4 + 2*32 = 192 bits compressed (32 4-bit values plus two 32-bit floats) vs. 32*16 = 512 bits uncompressed, so the size reduction is 512/192 = 2.66.

What we could have instead is QK=128 (or even 256?) and 16 independent fp16 values. These fp16 values form a lookup table. Each weight is quantized to 4 bits, and its value is used as a key into the lookup table (I know a lookup table might be implemented using AVX). The 16 fp16 values need to be adjusted to minimize the RMS error. Given that the weights in a block have a non-uniform distribution, such an approach could be more profitable in terms of RMS error. Size reduction = 128*16/(16*16 + 128*4) = 2.66.
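To make the proposed layout concrete, here is a small sketch of one block: a 16-entry table stored as fp16 plus one 4-bit index per weight, packed two per byte with the low nibble first. The packing order, the nearest-entry encoder and all names are assumptions for illustration; how the 16 table values are chosen (e.g. by minimizing RMS) is left out.

```python
import struct

def to_f16(v):
    # round a Python float to the nearest IEEE binary16 value ('e' format)
    return struct.unpack('<e', struct.pack('<e', v))[0]

def encode_block(xs, table):
    """Hypothetical layout: 16 fp16 table entries + one 4-bit index per weight, two per byte."""
    assert len(table) == 16 and len(xs) % 2 == 0
    table = [to_f16(t) for t in table]
    idx = [min(range(16), key=lambda j: abs(x - table[j])) for x in xs]   # nearest table entry
    packed = bytes(idx[i] | (idx[i + 1] << 4) for i in range(0, len(idx), 2))
    return table, packed

def decode_block(table, packed):
    out = []
    for b in packed:
        out.append(table[b & 0x0F])   # low nibble holds the earlier weight
        out.append(table[b >> 4])
    return out

def bits_per_block(qk, entries=16, entry_bits=16, index_bits=4):
    return entries * entry_bits + qk * index_bits
```

For QK=128 this gives 16*16 + 128*4 = 768 bits per block versus 128*16 = 2048 bits in fp16, the same 2.66x as Q4_1.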

ggerganov commented 1 year ago

> I also found the Lloyd-Max algorithm, but this one creates non-uniform quantization, which is a no-go for our use case,

Any quantization method that can be evaluated efficiently works.

> (I know a lookup table might be implemented using AVX)

This is interesting - do you know if it can be implemented with ARM NEON too?

I had some similar ideas - instead of the existing linear mapping Q4_0 of the "normally" distributed weights, make a simple transform of the data so you get more uniform distribution and quantize that instead. The transform has to be able to evaluate efficiently, so some approximation of uniform distribution would work.

Currently, the "outer" quantization bins are much less "utilized" compared to the "inner" (i.e. around the zero). You can clearly see that during the quantize run where we print the histograms for each tensor. With a uniform transform added, I expect the bins will get more evenly "utilized" and probably lead to some benefits in accuracy.

Piezoid commented 1 year ago

> This is interesting - do you know if it can be implemented with ARM NEON too?

I'm not familiar with NEON, but it does have the vtbl/vtbx instructions, which permute bytes like vpshufb does. I've used the latter to do in-register lookup tables.

prusnak commented 1 year ago

If weights have a normal distribution, then I believe this approach is worth trying:

That way our quantisation will cover ~95% of the values for f = 2.0. For other values of f, it covers ~99% for f = 3.0 and ~68% for f = 1.0.

We can try different values of f to see which works best. We could even search for the best f for a given batch of data by iterating over a few values of f (between 1.0 and 3.0, for example) and comparing RMS, but this is probably not necessary; we can likely come up with a single fixed f that is suitable as a constant for every batch.
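The coverage figures follow from the normal CDF: the fraction of a zero-mean normal distribution inside ±f·σ is erf(f/√2). Below is a minimal sketch of that check, plus one hypothetical way to use f: clip to roughly ±f·σ with the Q4_0 scale d = f·σ/7, taking σ as the block RMS and assuming mean 0. The exact scheme proposed in the comment is not shown in the thread, so treat the quantizer part as an assumption.

```python
import math

def coverage(f):
    # fraction of a zero-mean normal distribution that lies inside [-f*sigma, f*sigma]
    return math.erf(f / math.sqrt(2.0))

def quantize_clipped(xs, f):
    """Hypothetical variant: sigma = RMS of the block (mean assumed 0), d = f*sigma/7, Q4_0-style codes."""
    sigma = math.sqrt(sum(x * x for x in xs) / len(xs))
    if sigma == 0.0:
        return list(xs)
    d = f * sigma / 7
    return [min(max(round(x / d), -8), 7) * d for x in xs]   # values beyond ~f*sigma are clipped

def best_f(xs, candidates=(1.0, 1.5, 2.0, 2.5, 3.0)):
    # pick the f with the lowest mean squared error on this block
    def mse(f):
        ys = quantize_clipped(xs, f)
        return sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)
    return min(candidates, key=mse)

for f in (1.0, 2.0, 3.0):
    print(f, round(coverage(f), 4))
```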

prusnak commented 1 year ago

Reading the comments above - yeah, if we can efficiently implement a lookup table int8/int4->float16 using AVX/NEON, then it might be really worth trying the non-uniform approach.

jarcen commented 1 year ago

I'm stuck with an older codebase, so I modified the old quantize method in utils.cpp. Basically, I'm adjusting amax by -0.5 to +0.49 in 100 steps of 0.01 and searching for the one that yields the lowest RMS. I also added an rms_delta output parameter to log the improvements.

size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist, float * rms_delta) {
    const int nb = k / qk;
    const size_t bs = (sizeof(float) + sizeof(uint8_t)*qk/2);
    const size_t row_size = nb*bs;

    assert(k % qk == 0);

    const size_t pp_size = qk / 2;
    uint8_t *pp = static_cast<uint8_t*>(alloca(pp_size));

    char * pdst = (char *) dst;

    float block[qk];
    float rms_improvement = 0;
    int rms_samples = 0;
    for (int j = 0; j < n; j += k) {
        uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
        uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));

        for (int i = 0; i < nb; i++) {
            float amax = 0.0f; // absolute max

            for (int l = 0; l < qk; l++) {
                const float v = src[j + i*qk + l];
                amax = std::max(amax, fabsf(v));
                block[l] = v;
            }

            // search for best quantization factor
            float bestD = 0;
            float bestRms = 1e20;
            float defaultRms = 1e20;
            for(int attempt = 0; attempt < 100; attempt++)
            {
                float tweak = (attempt / 100.0) - 0.5;
                const float d = (amax + tweak) / ((1 << 3) - 1);
                const float id = d ? 1.0f/d : 0.0f;
                float rms = 0;
                for (int l = 0; l < qk; l++) {
                    const float v = block[l];
                    uint8_t vi = std::max((int)0, std::min(((int)round(v*id)) + 8, (int)15));
                    const float v_deq = (vi - 8) * d;
                    const float difference = (v - v_deq);
                    rms += difference * difference;
                }
                if(rms < bestRms) {
                    bestRms = rms;
                    bestD = d;
                }
                if(attempt == 50) // when tweak is 0
                    defaultRms = rms;
            }
            rms_improvement += sqrt(defaultRms) - sqrt(bestRms);
            rms_samples++;

            const float bestId = bestD ? 1.0f/bestD : 0.0f;

            *(float *) pd = bestD;
            pd += bs;

            for (int l = 0; l < qk; l += 2) {
                const float v0 = block[l + 0]*bestId;
                const float v1 = block[l + 1]*bestId;

                const uint8_t vi0 = std::max((int)0, std::min(((int)round(v0)) + 8, (int)15));
                const uint8_t vi1 = std::max((int)0, std::min(((int)round(v1)) + 8, (int)15));

                hist[vi0]++;
                hist[vi1]++;

                pp[l/2] = vi0 | (vi1 << 4);
            }

            memcpy(pb, pp, pp_size);
            pb += bs;
        }
    }

    *rms_delta = rms_improvement / rms_samples;
    return (n/k)*row_size;
}

This works, although the process is about 100 times slower. The output shows very marginal RMS improvements, around 0.0014 on average:

layers.31.attention.wo.weight - [ 4096,  4096], type =    f16 size =    64.00 MB ->    10.00 MB, rms -= 0.0012 | hist: 0.024 0.020 0.025 0.038 0.056 0.078 0.098 0.113 0.118 0.113 0.098 0.078 0.056 0.038 0.025 0.022
layers.31.feed_forward.w1.weight - [ 4096, 11008], type =    f16 size =   172.00 MB ->    26.88 MB, rms -= 0.0014 | hist: 0.026 0.020 0.025 0.038 0.056 0.077 0.098 0.112 0.118 0.112 0.098 0.077 0.056 0.038 0.025 0.022
layers.31.feed_forward.w2.weight - [11008,  4096], type =    f16 size =   172.00 MB ->    26.88 MB, rms -= 0.0013 | hist: 0.026 0.019 0.024 0.036 0.054 0.076 0.099 0.117 0.124 0.117 0.099 0.076 0.054 0.036 0.024 0.021
layers.31.feed_forward.w3.weight - [ 4096, 11008], type =    f16 size =   172.00 MB ->    26.88 MB, rms -= 0.0014 | hist: 0.026 0.020 0.025 0.038 0.056 0.077 0.098 0.113 0.118 0.113 0.098 0.077 0.056 0.038 0.025 0.022

The output of the 7B model is nearly identical. At least it's not broken, but I don't know if it's an improvement. This is just an experiment, so you can decide if it's worth doing.

nonnull-ca commented 1 year ago

> Reading the comments above - yeah, if we can efficiently implement a lookup table int8/int4->float16 using AVX/NEON, then it might be really worth trying the non-uniform approach.

For the case of int4 on AVX, you can (ab)use VPSHUFB for this fairly easily.

Let me come up with an instruction sequence for AVX2...

loretoparisi commented 1 year ago

Might be of interest: https://github.com/TimDettmers/bitsandbytes

ggerganov commented 1 year ago

> I had some similar ideas - instead of the existing linear mapping Q4_0 of the "normally" distributed weights, make a simple transform of the data so you get more uniform distribution and quantize that instead. The transform has to be able to evaluate efficiently, so some approximation of uniform distribution would work.
>
> Currently, the "outer" quantization bins are much less "utilized" compared to the "inner" (i.e. around the zero). You can clearly see that during the quantize run where we print the histograms for each tensor. With a uniform transform added, I expect the bins will get more evenly "utilized" and probably lead to some benefits in accuracy.

The CDF of a normal distribution is given by:

TF=0.5*(1 + erf((x - mean)/(sqrt(2)*sig)))

So if we compute the intermediate weights x' = TF(x) - 0.5 these will be uniformly distributed in [-0.5, 0.5]. I just ran the following patch to verify this:

diff --git a/ggml.c b/ggml.c
index c9a4e86..cc62e49 100644
--- a/ggml.c
+++ b/ggml.c
@@ -449,6 +449,8 @@ static inline __m128i packNibbles( __m256i bytes )
 // blocks of QK elements
 // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)

+#define TF(x, sig) (0.5*(1.0f + erf((x/sig)/sqrtf(2.0f))) - 0.5f)
+
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, void * restrict y, int k) {
     assert(k % QK == 0);
@@ -461,11 +463,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, void * restric

     uint8_t pp[QK/2];

+    double sig = 0.0;
+    for (int i = 0; i < k; i++) {
+        sig += x[i]*x[i];
+    }
+    sig = sqrt(sig/k);
+
     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max

         for (int l = 0; l < QK; l++) {
-            const float v = x[i*QK + l];
+            const float v = TF(x[i*QK + l], sig);
             amax = MAX(amax, fabsf(v));
         }

@@ -476,8 +484,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, void * restric
         pd += bs;

         for (int l = 0; l < QK; l += 2) {
-            const float v0 = x[i*QK + l + 0]*id;
-            const float v1 = x[i*QK + l + 1]*id;
+            const float v0 = TF(x[i*QK + l + 0], sig)*id;
+            const float v1 = TF(x[i*QK + l + 1], sig)*id;

             const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
             const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
./quantize ./models/7B/ggml-model-f16.bin ./models/7B/ggml-model-q4_0.bin 2

                           tok_embeddings.weight - [ 4096, 32000], type =    f16 quantizing .. size =   500.00 MB ->    78.12 MB | hist: 0.000 0.050 0.069 0.069 0.069 0.069 0.069 0.069 0.069 0.069 0.069 0.069 0.069 0.069 0.069 0.050 
                                     norm.weight - [ 4096,     1], type =    f32 size =    0.016 MB
                                   output.weight - [ 4096, 32000], type =    f16 quantizing .. size =   500.00 MB ->    78.12 MB | hist: 0.000 0.050 0.068 0.069 0.069 0.069 0.070 0.070 0.070 0.070 0.070 0.070 0.069 0.069 0.068 0.049 
                    layers.0.attention.wq.weight - [ 4096,  4096], type =    f16 quantizing .. size =    64.00 MB ->    10.00 MB | hist: 0.000 0.042 0.057 0.064 0.069 0.074 0.076 0.078 0.079 0.078 0.076 0.073 0.069 0.064 0.057 0.042 
                    layers.0.attention.wk.weight - [ 4096,  4096], type =    f16 quantizing .. size =    64.00 MB ->    10.00 MB | hist: 0.000 0.044 0.061 0.066 0.070 0.072 0.074 0.075 0.075 0.075 0.074 0.072 0.070 0.066 0.061 0.044 
                    layers.0.attention.wv.weight - [ 4096,  4096], type =    f16 quantizing .. size =    64.00 MB ->    10.00 MB | hist: 0.000 0.050 0.067 0.067 0.068 0.069 0.070 0.072 0.073 0.072 0.070 0.069 0.068 0.067 0.067 0.050 
                    layers.0.attention.wo.weight - [ 4096,  4096], type =    f16 quantizing .. size =    64.00 MB ->    10.00 MB | hist: 0.000 0.052 0.056 0.058 0.064 0.071 0.077 0.081 0.082 0.081 0.077 0.071 0.064 0.058 0.056 0.052 
                 layers.0.feed_forward.w1.weight - [ 4096, 11008], type =    f16 quantizing .. size =   172.00 MB ->    26.88 MB | hist: 0.000 0.050 0.069 0.069 0.069 0.069 0.069 0.070 0.070 0.069 0.069 0.069 0.069 0.069 0.069 0.050 
                 layers.0.feed_forward.w2.weight - [11008,  4096], type =    f16 quantizing .. size =   172.00 MB ->    26.88 MB | hist: 0.000 0.050 0.069 0.069 0.069 0.069 0.070 0.070 0.070 0.070 0.069 0.069 0.069 0.069 0.069 0.050 
                 layers.0.feed_forward.w3.weight - [ 4096, 11008], type =    f16 quantizing .. size =   172.00 MB ->    26.88 MB | hist: 0.000 0.050 0.069 0.069 0.069 0.069 0.070 0.070 0.070 0.070 0.069 0.069 0.069 0.069 0.069 0.050 
                  layers.0.attention_norm.weight - [ 4096,     1], type =    f32 size =    0.016 MB
                        layers.0.ffn_norm.weight - [ 4096,     1], type =    f32 size =    0.016 MB
                    layers.1.attention.wq.weight - [ 4096,  4096], type =    f16 quantizing .. size =    64.00 MB ->    10.00 MB | hist: 0.000 0.049 0.068 0.069 0.069 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.069 0.069 0.068 0.049 
                    layers.1.attention.wk.weight - [ 4096,  4096], type =    f16 quantizing .. size =    64.00 MB ->    10.00 MB | hist: 0.000 0.049 0.067 0.069 0.069 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.069 0.069 0.068 0.049 

The bins are now much more evenly utilized. In this case, we no longer need to store amax, but have to store sig instead. Also, assumed that the mean is 0 which is probably a valid assumption.
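The effect of the transform can be reproduced outside ggml. A minimal sketch using Python's statistics.NormalDist for the CDF and its inverse, with σ estimated as the RMS and the mean assumed to be 0 as in the patch. Unlike the patch, which feeds TF(x) into the existing Q4_0 rounding, this sketch splits the probability range into 16 equal-width bins and dequantizes each code to the value at the centre of its bin; the synthetic Gaussian data (σ = 0.02) stands in for real weights and is an assumption.

```python
import random
from statistics import NormalDist

def tf(x, sig):
    # TF(x) - 0.5 from above: N(0, sig^2) -> approximately uniform on [-0.5, 0.5]
    return NormalDist(0.0, sig).cdf(x) - 0.5

def encode(x, sig):
    return min(int((tf(x, sig) + 0.5) * 16), 15)      # 4-bit code 0..15

def decode(code, sig):
    # inverse CDF at the centre of the code's probability bin (never 0 or 1)
    return NormalDist(0.0, sig).inv_cdf((code + 0.5) / 16)

rng = random.Random(0)
xs = [rng.gauss(0.0, 0.02) for _ in range(32768)]
sig = (sum(x * x for x in xs) / len(xs)) ** 0.5

hist = [0] * 16
for x in xs:
    hist[encode(x, sig)] += 1
print(" ".join(f"{h / len(xs):.3f}" for h in hist))
```

Every bin should come out close to 1/16 = 0.0625, including the outermost ones; the price is that decoding needs an inverse CDF (or a 16-entry table per σ) instead of a single multiply.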

prusnak commented 1 year ago

> The bins are now much more evenly utilized.

I am wondering how we could update the algorithm so that the first bin is also utilized; currently it is unused.

> Also, assumed that the mean is 0 which is probably a valid assumption.

Maybe we can try computing the mean too and store both mean and sigma?

Most probably, storing the mean would allow us to use the first bin easily.

nonnull-ca commented 1 year ago

Basic idea for a int4->f16 lookup table on AVX2:

Low nibbles:

  1. Mask out high nibbles of each input byte, because VPSHUFB uses the high bit to clear bytes to zero, which we don't want.
  2. Do 2 vpshufbs for low/high half of each f16, using the input nibbles as a lookup into a constant vector.
  3. Do a vpunpcklbw and a vpunpckhbw to unshuffle the low/high half of each output f16.

High nibbles:

  1. Shift input right by 4.
  2. Do the same procedure as for low nibbles.

Result:

  1. Calculate low and high nibbles
  2. Do 2x vpunpcklwd and vpunpckhwd to unshuffle f16s back into the original order.

You can get away without the final unshuffle if you preshuffle the input int4s instead.
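The data flow of the sequence above can be modelled in scalar Python: split each byte into its two nibbles, look up the low and high byte of the fp16 table entry with a PSHUFB-like 16-entry byte lookup, join the two bytes into a 16-bit word, and finally interleave the low-nibble and high-nibble results. This is only a sketch of one 16-byte lane, assuming the low nibble holds the earlier weight; it models the values, not the exact register ordering produced by the vpunpck* instructions.

```python
import struct

def pshufb_lane(table, idx):
    # one 16-byte lane of PSHUFB: a set bit 7 in the index zeroes the byte, else the low 4 bits select
    return [0 if b & 0x80 else table[b & 0x0F] for b in idx]

def f16_bits(v):
    return struct.unpack('<H', struct.pack('<e', v))[0]

def f16_value(bits):
    return struct.unpack('<e', struct.pack('<H', bits))[0]

def expand_nibbles(packed, lut_bits):
    """16 packed bytes (32 int4 codes) -> 32 fp16 bit patterns via byte-wise table lookups."""
    lo_tab = [h & 0xFF for h in lut_bits]       # low byte of each fp16 table entry
    hi_tab = [h >> 8 for h in lut_bits]         # high byte of each fp16 table entry
    halves = []
    for nib in ([b & 0x0F for b in packed],     # low nibbles: mask so bit 7 is never set
                [b >> 4 for b in packed]):      # high nibbles: shift right by 4
        lo = pshufb_lane(lo_tab, nib)
        hi = pshufb_lane(hi_tab, nib)
        halves.append([l | (h << 8) for l, h in zip(lo, hi)])   # byte unpack -> 16-bit words
    low_res, high_res = halves
    return [w for pair in zip(low_res, high_res) for w in pair]  # restore the original weight order
```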


I'm familiar with ARM, but not so much NEON. (The perils of embedded programming.) That being said, it looks very similar at a first glance, with the following wrinkles:

  1. NEON has shorter vectors (128b).
  2. NEON has some interesting (de)interleaving loads and stores, which may be useful (though we're expanding input data by 4x, so I suspect this won't really work...)

It looks like you can swap vpunpck* to ZIP1/ZIP2, and vpshufb to TBL (same masking necessary, as out of range values turn to 0).


A quick attempt is at https://godbolt.org/z/G74oYP8nK. Not perfect - shuffles tend to be slow (throughput of 1/2) so I suggest storing the int4 table interleaved, and I haven't actually tested this, just stared at the output assembly for a bit - but may be good enough.

nonnull-ca commented 1 year ago

> The bins are now much more evenly utilized.
>
> I am wondering how we could update the algorithm so that the first bin is also utilized; currently it is unused.
>
> Also, assumed that the mean is 0 which is probably a valid assumption.
>
> Maybe we can try computing the mean too and store both mean and sigma?
>
> Most probably, storing the mean would allow us to use the first bin easily.

The reason why the first bin is not utilized is due to the following. Ignore floating-point rounding for a moment:

...and then something will be assigned to bin 0 if round(val*id) == -8. In other words, if

...but here we have a contradiction. Because amax >= abs(val), and so amax >= -val. And so bin 0 is never used.
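The elided steps can be reconstructed as follows, assuming the reference Q4_0 quantizer with d = amax/7 and C's round(), which rounds halves away from zero:

```latex
\begin{aligned}
d &= \frac{a_{\max}}{7}, \qquad \mathrm{code}(v) = \operatorname{round}\!\left(\frac{v}{d}\right) + 8, \qquad |v| \le a_{\max} \\
\mathrm{code}(v) = 0 &\iff \operatorname{round}\!\left(\frac{7v}{a_{\max}}\right) = -8
  \implies \frac{7v}{a_{\max}} \le -7.5
  \iff -v \ge \tfrac{15}{14}\, a_{\max} > a_{\max}
\end{aligned}
```

For a_max > 0 the last inequality contradicts |v| ≤ a_max, so code 0 is never produced.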

This is because we're dealing with a signed bin, not an unsigned bin, and so the correct number is 15/2, not 7. See below.


For an optimal placement, assuming that the inputs have been transformed into a uniformish distribution with a maximum value of 1, to a first approximation you want the following buckets:

  1. [-16/16..-14/16]
  2. [-14/16..-12/16]
  3. [-12/16..-10/16]
  4. [-10/16..-8/16]
  5. [-8/16..-6/16]
  6. [-6/16..-4/16]
  7. [-4/16..-2/16]
  8. [-2/16..-0/16]
  9. [0/16..2/16]
  10. [2/16..4/16]
  11. [4/16..6/16]
  12. [6/16..8/16]
  13. [8/16..10/16]
  14. [10/16..12/16]
  15. [12/16..14/16]
  16. [14/16..16/16]

TL;DR: try this:

        const float d = amax * (2.0f / 15.0f);
        const float id2 = amax ? 8.0f/amax : 0.0f; // not 1/d!
        [...]

            const float v0 = TF(x[i*QK + l + 0], sig)*id2; // [-amax..amax] -> [-8..=8]
            const float v1 = TF(x[i*QK + l + 1], sig)*id2; // [-amax..amax] -> [-8..=8]

            // Edge case handling: if v0 == 8 (because input value == amax exactly), then we'd end up with +16 as a result.
            // Deal with it by rounding this case down to 7.
            // Ditto, due to rounding abs(v0) can end up slightly larger than 8. Preemptively fix up if so.
            // Any value in [7..<8] works.
            const float BELOW8 = 7.99999952316f; // nextafterf(8.0f, 0.0f); 7.0f would also work
            const float v02 = fminf(fmaxf(v0, -8.0f), BELOW8); // [-8..=8] -> [-8..8]
            const float v12 = fminf(fmaxf(v1, -8.0f), BELOW8); // [-8..=8] -> [-8..8]

            const uint8_t vi0 = ((int8_t) (floor(v02))) + 8; // [-8..8] -> [0..16]
            const uint8_t vi1 = ((int8_t) (floor(v12))) + 8; // [-8..8] -> [0..16]

Note that d != 1/id2. This is deliberate.

(This does end up "always" shrinking the maximum or minimum value by amax/15, which isn't ideal. I don't know which is worse - shrinking average stddev somewhat, or always shrinking the min/max value.)
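A quick numerical check of the argument, in Python rather than C: with the current rounding scheme the smallest code is 1, while the floor-based variant above reaches all 16 codes. This is a sketch on evenly spaced values in [-1, 1] (not real weights) and it omits the TF transform; math.nextafter (Python 3.9+) stands in for nextafterf(8.0f, 0.0f).

```python
import math

BELOW8 = math.nextafter(8.0, 0.0)   # largest double below 8; plays the role of the BELOW8 constant

def codes_round(xs):
    # current reference Q4_0: d = amax/7, code = round(v/d) + 8
    d = max(abs(x) for x in xs) / 7
    return [round(x / d) + 8 for x in xs]

def codes_floor(xs):
    # floor-based variant above: v * 8/amax, clamped to [-8, BELOW8], floored, +8
    id2 = 8.0 / max(abs(x) for x in xs)
    return [math.floor(min(max(x * id2, -8.0), BELOW8)) + 8 for x in xs]

xs = [i / 1000 - 1.0 for i in range(2001)]   # evenly spaced test values in [-1, 1]
print(min(codes_round(xs)), min(codes_floor(xs)))
```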

jarcen commented 1 year ago

In case this detail went unnoticed: the code I shared above that adjusts amax naturally utilizes the first bin as a side effect:

hist: 0.024 0.020 0.025 0.038 0.056 0.078 0.098 0.113 0.118 0.113 0.098 0.078 0.056 0.038 0.025 0.022
hist: 0.026 0.020 0.025 0.038 0.056 0.077 0.098 0.112 0.118 0.112 0.098 0.077 0.056 0.038 0.025 0.022
hist: 0.026 0.019 0.024 0.036 0.054 0.076 0.099 0.117 0.124 0.117 0.099 0.076 0.054 0.036 0.024 0.021
hist: 0.026 0.020 0.025 0.038 0.056 0.077 0.098 0.113 0.118 0.113 0.098 0.077 0.056 0.038 0.025 0.022

The histogram is still a bit skewed to the right, but it's much more symmetric.

blackhole89 commented 1 year ago

I tried to run the previously mentioned Q4_1 quantization method with some number of local relaxation steps to reduce the square error (down to 83% of the naive computation's error on average), but the result did not appear to improve on perplexity for Wikitext, being within 0.01 of naive Q4_1's after 30 steps (which I argued here to be sufficient for a preliminary estimate):

[1]4.5934,[2]5.1141,[3]6.0137,[4]6.5889,[5]6.6549,[6]6.6349,[7]6.8486,[8]6.9613,[9]7.3027,[10]7.5478,[11]7.7765,[12]7.8229,[13]7.7368,[14]7.8268,[15]8.0934,[16]7.6583,[17]7.5161,[18]7.4644,[19]7.0927,[20]7.0639,[21]6.9697,[22]6.7890,[23]6.7553,[24]6.6584,[25]6.6611,[26]6.4884,[27]6.2972,[28]6.1850,[29]6.0983,[30]5.9288,[31]5.8874,[32]5.9094,[33]5.8515

Along with this, I had some lower-quality data suggesting that just throwing out 1 min and max outlier when this improved square error actually made perplexity worse (by about 0.05). My current hypothesis is that perhaps it matters more to accurately represent weights that are further away from 0, as those wind up influencing the final dot product more. I want to try only throwing away the value closest to 0 next.

tacryt-socryp commented 1 year ago

Perhaps Posit arithmetic could be valuable? http://www.johngustafson.net/pdfs/BeatingFloatingPoint.pdf

prusnak commented 1 year ago

> Also, assumed that the mean is 0 which is probably a valid assumption.

I used the following patch to compute histograms of the model:

diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py
index ccf2c57..a17a3c2 100644
--- a/convert-pth-to-ggml.py
+++ b/convert-pth-to-ggml.py
@@ -22,6 +22,9 @@ import struct
 import numpy as np
 import torch

+from matplotlib import pyplot as plt
+idx = 0
+
 from sentencepiece import SentencePieceProcessor

 def parse_args():
@@ -124,6 +127,13 @@ def process_and_write_variables(fout, model, ftype):
         fout.write(sname)

         # data output to file
+        hist, bins = np.histogram(data, bins=100)
+        plt.stairs(hist, bins)
+        global idx
+        plt.savefig(f"hist_{idx:08}.png")
+        plt.clf()
+        idx += 1
+
         data.tofile(fout)

 def main():

From quickly inspecting the results it seems that most of the layers indeed have normal distribution around mean 0.0, but there are also around 20% of layers which have mean != 0.0.

Attaching the zipped histograms: hist.zip

Some examples: hist_00000000-or8 hist_00000001-or8 hist_00000019-or8 hist_00000020-or8

nonnull-ca commented 1 year ago

> Perhaps Posit arithmetic could be valuable? http://www.johngustafson.net/pdfs/BeatingFloatingPoint.pdf

Posits offer more dynamic range, at the expense of less accuracy for large numbers. If the largest weights matter the most, and they fit in an f16 (which we know they do), a posit is exactly the opposite of what we want.


nonnull-ca commented 1 year ago

> From quickly inspecting the results it seems that most of the layers indeed have normal distribution around mean 0.0, but there are also around 20% of layers which have mean != 0.0.

Hm. What do those look like on a log plot? At a quick glance, those look heavier-tailed than a normal distribution.

Also, am I reading those correctly in that the max absolute weight is only ~2.5 or so? If so, a u16 scaling may actually be more accurate than a f16 scaling:

Instead of i4 * f16, try i4 * u16 / 2^17. This works for values between -4 and about +3.5.

Say you're encoding a value near -2.5 as your max absolute. i4 is -8 in either scheme. With f16, your final step size is 2^-9. For u16, your step size is 2^-14.
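The step sizes can be checked directly: Python's struct module supports IEEE binary16 via the 'e' format, so the spacing of fp16 values near a given scale can be read off by incrementing the 16-bit pattern. A minimal sketch (helper names are illustrative) comparing the final step of -8·scale for an fp16 scale against a u16 scale interpreted as u/2^17:

```python
import struct

def f16_round(x):
    return struct.unpack('<e', struct.pack('<e', x))[0]

def f16_ulp(x):
    # gap between x (rounded to fp16) and the next fp16 value further from zero
    bits = struct.unpack('<H', struct.pack('<e', x))[0]
    return abs(struct.unpack('<e', struct.pack('<H', bits + 1))[0] - f16_round(x))

def step_f16(amax, q=-8):
    scale = amax / abs(q)                 # scale chosen so that q * scale = -amax
    return abs(q) * f16_ulp(scale)        # effect of one fp16 ulp of the scale on q * scale

def step_u16(q=-8):
    return abs(q) / 2 ** 17               # one unit of u in q * u / 2^17

print(step_f16(2.5), step_u16())
```

For a scale of 2.5/8 = 0.3125 the fp16 spacing is 2^-12, so the final step is 8 * 2^-12 = 2^-9, versus 8 * 2^-17 = 2^-14 for the u16 scale.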

MarcioPais commented 1 year ago

The quantization process should attempt to optimally preserve the magnitudes of the weights first and foremost, which is easily accomplished by considering the binary representation of the original IEEE 754 half-precision floating-point (binary16) weights and using a significance coding method, which allows for combined quantization and sparsification of the weights to an arbitrary (byte-level limited only) bit-per-weight with sub-bit accuracy.

The main problem is that such an approach isn't (to the best of my knowledge) vectorizable nor fast, so I don't think it will interest you.

A binary16 float has the following format: 1 sign bit, 5 exponent bits in offset-binary, 10 explicitly stored significand precision bits.

A preliminary run through the current block of weights to be encoded would examine the 5 exponent bits only and record the highest value seen. We'd then encode, in 2 bits, the position of the first active bit in this max exponent value, counting down from the MSB. Example:

If the max exponent is 01101, we store the value "1" (01b). This allows us to know that we will need to do (4-1)=3 significance coding rounds (Note: there is no need to do significance coding for the LSB, obviously. A corner case is if the max exponent is 0000x, where we'd spend bits on a useless significance coding round, but this should not be a problem in practice).

To perform a significance coding round, we need to choose a positional encoding strategy. A simple binary partitioning scheme should work well in practice (depending on the block size).

We start with 3 lists, the significant list S (empty), the insignificant list I (contains the indexes for all the weights in this block), along with the refinement list R (empty):

S={}, I={0, 1, 2, 3,.., 31}, R={}

We now recursively divide the I list looking for weights that are significant at this level, i.e., whose 2nd MSB of the exponent bits is set, and encode this partitioning (either with DFS or BFS). Example using BFS:

Assume that the weights 3 and 11 are the only ones that are significant at this level.

We begin by dividing the I list in 2 sub lists, as I0,0={0..15} and I0,1={16..31}.

Since there are significant weights in I0,0, we code a "1" for it, and a "0" for I0,1.

We now divide I0,0 into I1,0={0..7} and I1,1={8..15}, and since both contain significant weights, we code 2 "1"'s.

If we keep proceeding like this, the final output for this significant coding round will be "1011101001010101", and the lists will look like this:

S={3, 11}, I={0, 1, 2, 4,.., 10, 12,.., 31}, R={}
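The partitioning step can be sketched as a breadth-first walk over half-open index ranges: each half gets a 1 if it contains a significant index and a 0 otherwise, and only halves marked 1 are split further. This is a minimal sketch of the example above, not a complete codec: sign bits, refinement passes and the bit budget are omitted, and the helper for reading the binary16 exponent field is an illustrative assumption.

```python
import struct

def f16_exponent(x):
    # 5-bit biased exponent field of x rounded to IEEE binary16
    bits = struct.unpack('<H', struct.pack('<e', x))[0]
    return (bits >> 10) & 0x1F

def partition_bits(n, significant):
    """BFS binary partitioning of indices 0..n-1; emits one bit per visited half."""
    out = []
    queue = [(0, n)]                          # half-open index ranges [start, end)
    while queue:
        nxt = []
        for start, end in queue:
            if end - start < 2:               # single index: nothing left to split
                continue
            mid = (start + end) // 2
            for a, b in ((start, mid), (mid, end)):
                hit = any(a <= i < b for i in significant)
                out.append('1' if hit else '0')
                if hit:
                    nxt.append((a, b))        # only significant halves are refined
        queue = nxt
    return ''.join(out)
```

In a real round, the significant set would come from the exponent field, e.g. the indices whose f16_exponent(w) has the bit for the current round set.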

We now encode the sign bits for the weights in the S list, and since the R list is still empty, we don't need to do a refinement round.

The weights in the S list are now moved to the R list, and we restart the process, this time the threshold for significance being the 3rd MSB of the exponent bits being set.

Since the R list is no longer empty, at the end of this significance coding round we'd perform a refinement round, emitting for each weight in R the next bit of its binary representation (in this case, the 3rd MSB of the exponent).

So after a few rounds, for the highest magnitude weights, we'd already have sent the full exponent and would then start sending the bits for the significand.

We'd proceed in this fashion until we exhaust the bit budget allowed for encoding this block of weights (say 128 bits if we want to encode 32 weights-per-block and use an average of 4 bits-per-weight. Or 80 bits would get us 2.5 bits-per-weight).

So it's possible to get fractional average bits-per-weight, and we get variable weight precision, where the most significant weights get much better precision, and insignificant ones get culled (in essence, achieving quantization and sparsification all in one go).

This is probably more interesting as a highly compressed model storage format for transmission purposes, where the weight block dequantization would only be done once and the in-memory calculations would then use the recovered fp16 weights.

prusnak commented 1 year ago

Hm. What do those look like on a log plot? At a quick glance, those look heavier-tailed than a normal distribution.

Attached are histograms where y axis is on log scale: hist-log.zip

Also, am I reading those correctly in that the max absolute weight is only ~2.5 or so?

Correct! Minimum value in the whole LLaMa-7B model is -2.492, maximum is 2.625.

nonnull-ca commented 1 year ago

@MarcioPais very interesting. That may indeed be vectorizable if you process multiple quantization blocks at once.

What do you do if your bit budget is exhausted in the middle of a round?

encode this partitioning (either with DFS or BFS).

Another way to view this is you're just constructing a bitset of which of the entries in I are above the current threshold. Same effect, but somewhat easier to code.

So after a few rounds, for the highest magnitude weights, we'd already have sent the full exponent and would then start sending the bits for the significand.

This adds complexity, but in an ideal world you'd do this with round-to-nearest on the exponent, not round-towards-zero. If you have e.g. 3.999 (a.k.a. 2^1 with mantissa of 0x3FF), truncated at the end of the exponent, you'd recover a value of 2, whereas 4 is better.
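The rounding point can be made concrete with a small sketch. It uses hypothetical helpers and handles positive inputs only. It contrasts keeping only the exponent (round towards zero) with rounding to the nearest power of two, and relies on `math.frexp` returning `x = m * 2**e` with `0.5 <= m < 1`:

```python
import math


def truncate_pow2(x):
    """Keep only the exponent of x > 0 (round towards zero)."""
    _, e = math.frexp(x)           # x = m * 2**e with 0.5 <= m < 1
    return math.ldexp(1.0, e - 1)  # 2**(e-1) <= x < 2**e


def nearest_pow2(x):
    """Round x > 0 to the nearest power of two (ties go down)."""
    lo = truncate_pow2(x)
    hi = 2.0 * lo
    return hi if (hi - x) < (x - lo) else lo


print(truncate_pow2(3.999), nearest_pow2(3.999))  # 2.0 4.0
```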

This is probably more interesting as a highly compressed model storage format for transmission purposes, where the weight block dequantization would only be done once and the in-memory calculations would then use the recovered fp16 weights.

Something I've been noting is that larger CPU models appear to be fairly heavily memory-bandwidth limited. It's possible that we can get this working 'well enough' to save time overall...

@prusnak - thank you!

Interesting. The layers are all over the place.

Some of it looks normalish:

[Image: log-scale weight histogram, hist_00000041]

(A normal distribution looks parabolic in a log plot.)

...and this one looks pretty much like a logistic distribution:

[Image: log-scale weight histogram, hist_00000012]

(A logistic distribution looks like a ^ in a log plot.)

...and this looks far closer to a Cauchy distribution or somesuch:

[Image: log-scale weight histogram, hist_00000003]

Correct! Minimum value in the whole LLaMa-7B model is -2.492, maximum is 2.625.

We're wasting bits then.

MarcioPais commented 1 year ago

@nonnull-ca

What do you do if your bit budget is exhausted in the middle of a round?

For the significance coding and refinement rounds, you only need to check for end-of-buffer (EOB) at the end of each round. It's only when reading the sign bits that you need to check for EOB individually, and if it has been reached, discard those weights as insignificant as well.

I usually handle this by having my bit-IO routines ignore the requests (when writing) or by returning zeros (when reading) after EOB.
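The following is a rough end-to-end illustration of the scheme in this thread: a small bit-plane encoder/decoder that follows the EOB rules above. The writer ignores bits past the budget and the reader returns zeros. EOB is only checked between rounds, and a weight whose sign bit is lost is dropped. The sketch makes several simplifying assumptions. The weights are plain signed integers whose magnitude bits stand in for the fp16 exponent/significand bits. The significance pass is a simple bitset rather than binary partitioning. `BitWriter`, `BitReader`, `encode` and `decode` are hypothetical names.

```python
class BitWriter:
    """Collects bits and silently ignores writes past the budget (EOB)."""

    def __init__(self, budget):
        self.budget = budget
        self.bits = []

    def put(self, bit):
        if len(self.bits) < self.budget:
            self.bits.append(bit)

    def full(self):
        return len(self.bits) >= self.budget


class BitReader:
    """Returns zeros once the stream is exhausted (EOB)."""

    def __init__(self, bits):
        self.bits = bits
        self.pos = 0

    def eob(self):
        return self.pos >= len(self.bits)

    def get(self):
        bit = self.bits[self.pos] if self.pos < len(self.bits) else 0
        self.pos += 1
        return bit


def encode(values, nbits, budget):
    """Bit-plane code signed integers with |v| < 2**nbits into at most `budget` bits."""
    w = BitWriter(budget)
    insig, refine = list(range(len(values))), []
    for plane in range(nbits - 1, -1, -1):
        if w.full():  # EOB is only checked between rounds
            break
        newly, still = [], []
        for i in insig:  # significance pass, coded as a plain bitset
            bit = (abs(values[i]) >> plane) & 1
            w.put(bit)
            (newly if bit else still).append(i)
        for i in newly:  # sign bits of newly significant weights
            w.put(1 if values[i] < 0 else 0)
        for i in refine:  # refinement pass over previously significant weights
            w.put((abs(values[i]) >> plane) & 1)
        insig, refine = still, refine + newly
    return w.bits


def decode(bits, n, nbits):
    r = BitReader(bits)
    mag, neg = [0] * n, [False] * n
    insig, refine = list(range(n)), []
    for plane in range(nbits - 1, -1, -1):
        if r.eob():
            break
        newly, still = [], []
        for i in insig:
            (newly if r.get() else still).append(i)
        kept = []
        for i in newly:
            if r.eob():  # sign bit lost: treat the weight as insignificant
                continue
            neg[i] = r.get() == 1
            mag[i] |= 1 << plane
            kept.append(i)
        for i in refine:
            mag[i] |= r.get() << plane
        insig, refine = still, refine + kept
    return [-m if s else m for m, s in zip(mag, neg)]


vals = [5, -3, 0, 12, -7, 1, 0, 9]
print(decode(encode(vals, 4, 10_000), len(vals), 4))  # lossless with a large budget
print(decode(encode(vals, 4, 12), len(vals), 4))      # 12 bits: only the top plane survives
```

Because this decoder simply truncates, reconstructed magnitudes are rounded towards zero. A reconstruction offset (for example, adding half of the last received bit plane) would be the natural place to apply the round-to-nearest idea raised earlier.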

This adds complexity, but in an ideal world you'd do this with round-to-nearest on the exponent, not round-towards-zero.

Yes, but that is simple and only affects the performance of the quantization stage, where we'd probably want to do preprocessing anyway.

Further experiments would obviously be needed, but it's easy to think of possible optimizations:

nonnull-ca commented 1 year ago

Some observations about the current output:

The way this block-based quantization scheme works is by picking a scale for groups of, say, 32 weights. A natural question then is "what is the distribution of the scale values?" Well, let's assume that we're picking weights at random from a Gaussian with mean 0 and stddev σ. Your scale is going to be, roughly, 1/8th of the max value in a block[^1]. This is the largest order statistic, which is... complicated in general, but we can fairly easily Monte-Carlo it.

For a stddev of 1 and QK=32, I get the following expected distribution of scales:

[Figure: expected distribution of block scales for stddev 1 and QK=32, log/log axes]

Note log/log scale. This is pretty close to a log-normal distribution with a (log) mean of -1.1668 and stddev of 0.1928, which I plotted in red.
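The Monte-Carlo estimate can be reproduced roughly as follows. This sketch rests on three assumptions: iid N(0, σ²) weights, QK = 32, and a per-block scale approximated by the footnote's `max(v/7 if v>0 else -v/8)` rule rather than the exact q4_0 code path. The log-statistics it prints should land near the figures quoted above, but not necessarily on them exactly.

```python
import math
import random
import statistics


def block_scale(vals):
    # Footnote approximation: positive extremes map to +7, negative ones to -8.
    return max(v / 7 if v > 0 else -v / 8 for v in vals)


def simulate_log_scales(qk=32, trials=20_000, sigma=1.0, seed=0):
    """Return (mean, stddev) of ln(scale) over random Gaussian blocks."""
    rng = random.Random(seed)
    logs = []
    for _ in range(trials):
        block = [rng.gauss(0.0, sigma) for _ in range(qk)]
        logs.append(math.log(block_scale(block)))
    return statistics.fmean(logs), statistics.stdev(logs)


mu, sd = simulate_log_scales()
print(f"ln-scale mean {mu:.4f}, stddev {sd:.4f}")
```

Because the scale is a max over the block, changing σ only shifts the log-mean by ln σ. The shape of this distribution therefore depends on QK but not on σ.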

I grabbed all the scale values (across all layers) of 65B and plotted them on a lin/log plot:

[Figure: histogram of block scale values across all layers, lin/log axes]

You can see that most scales are near ~0.0045 or so, but with a tail to put it mildly. This makes sense - we're summing a bunch of different distributions, some of which have heavy tails.

[^1]: Actual figure is closer to max(v/7 if v>0 else -v/8 for v in vals), although not quite.

nonnull-ca commented 1 year ago

Alright, here's an entirely different approach:

  1. For each layer, compute 16 percentiles: 1/32, 3/32, 5/32, ... 31/32. (These can be computed approximately fairly easily, or you can just brute-force it: sort all values and lerp between the appropriate two elements for each percentile.)
  2. For each weight, pick the minimum-error percentile.
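Here is a rough sketch of this percentile codebook in plain Python. It brute-forces the percentiles by sorting and linearly interpolating. It assumes 16 equal-probability bins with midpoints at (2k+1)/32. `build_codebook` and `quantize` are hypothetical names. A real format would store a per-layer codebook plus a 4-bit index per weight.

```python
import random


def percentile(sorted_vals, p):
    """Linearly interpolate the p-quantile (0 <= p <= 1) of pre-sorted data."""
    pos = p * (len(sorted_vals) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    frac = pos - lo
    return sorted_vals[lo] * (1.0 - frac) + sorted_vals[hi] * frac


def build_codebook(layer, levels=16):
    # Midpoints of `levels` equal-probability bins: (2k+1) / (2*levels).
    s = sorted(layer)
    return [percentile(s, (2 * k + 1) / (2 * levels)) for k in range(levels)]


def quantize(layer, codebook):
    # Index of the minimum-error codebook entry for each weight.
    return [min(range(len(codebook)), key=lambda j: abs(w - codebook[j])) for w in layer]


rng = random.Random(0)
layer = [rng.gauss(0.0, 1.0) for _ in range(4096)]
codebook = build_codebook(layer)
idx = quantize(layer, codebook)
rms = (sum((w - codebook[j]) ** 2 for w, j in zip(layer, idx)) / len(layer)) ** 0.5
print(f"RMS error {rms:.4f}")
```

The nearest-entry search via `min` is O(16) per weight. Since the codebook is sorted, a `bisect` over the midpoints between neighbouring entries would do the same job faster.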

Other things to check:


Something that would be useful for "someone" to do:

Run LLaMA on, say, the Wikipedia text that we're using for perplexity, but including the gradient calculation. Keep track of mean((dPerplexity/dWeight)^2) for each weight across the text, and plot that versus the weight value for each layer.

Why? I don't know if relative or absolute error matters more for weights. (Or even something more complex). There's arguments to be made for both.


Maybe someone can come up with a better grid search?

I've implemented a basic optimal[^1] search locally. Essentially just a simple αβ search over i4s, with a little rearranging to avoid divisions and a few heuristics that seem to help speed. It's still very slow, but appears to have a ~14% improvement to RMS error over the naive approach[^2] when compressing a random sample of 1600 values from a gaussian (mean=0, stddev=0.1):

Histogram of a one-truncated distribution with mean of ~0.86 and stddev of ~0.09 (This is RMS error of optimal search / RMS error of naive approach.)

Unfortunately, I'm only getting ~2000 batches per second (~64k weights / second) with it, so quantizing the entire thing would take... 11 days or so?

This is actually survivable if the thing can be multithreaded[^3] - and indeed this is essentially trivially parallelizable - but someone would need to rewrite the framework to be able to do so. Right now it streams everything sequentially.

(Looking at cachegrind, I also suspect a factor of 2 or so can be shaved off by moving from a recursive to an iterative approach, as there's a lot of redundant work and function call overhead currently, but meh. That transformation is annoying to do when you have a recursive call in the middle of a loop.)

[^1]: Rounding errors on the scale coefficient notwithstanding.
[^2]: Well, really, over the slightly-less naive approach, which flips the sign of everything if abs(min) > abs(max).
[^3]: On 32 cores, this would be 8.5 hours or so.

LostRuins commented 1 year ago

Is there a simple guide or explanation about the differences between q4_0 and q4_1? I understand the latter is slower, but is it actually better and why/by how much?

Edit: Never mind, I found it. https://github.com/ggerganov/ggml/pull/27

prusnak commented 1 year ago

Where do we stand with GPTQ? I see there is now a convert script convert-gptq-to-ggml.py by @comex, but isn't it more effective to just use GPTQ instead of Q4_1 or even Q4_0? Has anyone done some benchmarks?

Did you have a chance to look at GPTQ, @ggerganov?

Resources:

ggerganov commented 1 year ago

@prusnak My latest impression on the topic (I think based on @blackhole89 and @comex comments) is that GPTQ is equivalent to Q4_1 data storage in the sense that it has a floating point offset and scale factors + 4-bit quants. The difference between GPTQ and Q4_1 is the way that one computes the numbers. But the storage of the numbers is the same. Therefore, one can convert GPTQ models into Q4_1 models and ggml will effectively gain support for GPTQ this way.

I haven't had the time to look into the details, but it would be very interesting to see perplexity numbers for GPTQ using ggml.

ahoho commented 1 year ago

Not sure where I should be noting this, but I'm finding that q4_1 is much worse than q4_0 for a one-shot summarization task.

Here is the prompt I've been experimenting with (spoiler alerts for the 1965 novel Stoner; not using for any particular reason except that it's on my coffee table and not overly obscure):

```
Stoner, by John Williams.

# Brief Summary

The
```

Here's the q4_0 generation. There are errors (e.g., it's not written in the first person) but it basically gets the gist. I can't run the 30B model with fp16 locally, but based on what I've seen for 13B, q4_0 doesn't perceptibly degrade things from half-precision (at least for this task).

```
$ ./main -m ./models/llama/llama-30b-hf/consolidated/ggml-model-q4_0.bin --color -f ./prompts/summary.txt --temp 0.1 --top_p 0.75 --top_k 40 --repeat_penalty 1.1 -b 1 -t 11

# omitting preamble for brevity
```

The novel is a first-person narrative of the life of William Stoner, an English professor who teaches at the same Midwestern university from which he graduated. Stoner marries the daughter of a farmer, but finds her aloof and cold. The marriage produces a daughter, Katherine, but Stoner remains distant to both his wife and child.

Stoner's career is undistinguished but not unsuccessful; he teaches the required survey courses year after year, but his colleagues recognize his intellectual dedication by making him head of the English Department toward the end of

Here's q4_1. There is just no comparison.

The book is about a boy named Johnny who lives in the city of New York. He has a dog named Lulu and he goes to school at Public School 101.

## Synopsis

Johnny's father is a policeman, and his mother is a teacher. Johnny's grandfather is a fireman, and his grandmother is a nurse. Johnny's uncle is an electrician, and his aunt is a secretary.

Again, these are for LLaMA 30B, but I'm getting the same results for 13B and Alpaca (w/ lora). Original weights were fp16.

chigkim commented 1 year ago

Does it make a difference whether you go from f16 or f32 to q4_0 or q4_1? There are 4 possible choices!

unbounded commented 1 year ago

With the current block size of 32, q4_0 is in practice a 5-bit quantization, and q4_1 a 6-bit quantization. With that in mind, we could consider a q5_0 encoding similar to q4_0 that would be the same size as (or smaller than) q4_1. I see some promising signs that 5-bits+magnitude with linear scaling, similar to q4_0, could outperform q4_1, but I would like to explore non-linear distributions as well.
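The 5-bit and 6-bit figures follow from amortizing the per-block header over QK weights. The sketch below assumes that q4_0 stores one fp32 scale per block and q4_1 stores an fp32 scale plus an fp32 minimum. The two q5_0 variants are hypothetical layouts included only for comparison.

```python
def bits_per_weight(quant_bits, header_bits, qk=32):
    """Average storage cost: per-weight quant bits plus the block header amortized over qk."""
    return quant_bits + header_bits / qk


# (quant bits per weight, header bits per block); these layouts are assumptions.
LAYOUTS = {
    "q4_0": (4, 32),              # one fp32 scale
    "q4_1": (4, 64),              # fp32 scale + fp32 min
    "q5_0, fp32 scale": (5, 32),  # hypothetical
    "q5_0, fp16 scale": (5, 16),  # hypothetical
}

for name, (qb, hb) in LAYOUTS.items():
    print(f"{name:18s} {bits_per_weight(qb, hb):.2f} bpw")
```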

One interesting thing about q4_1 is that it will sometimes flip the sign of the weight... Intuitively, that might be a worse error than similar-magnitude errors that keep the sign. But in practice q4_1 beats q4_0 in perplexity, so maybe not?

qwopqwop200 commented 1 year ago

With the current block size of 32, q4_0 is in practice a 5-bit quantization, and q4_1 a 6-bit quantization. With that in mind, we could consider a q5_0 encoding similar to q4_0 that would be the same size as (or smaller than) q4_1. I see some promising signs that 5-bits+magnitude with linear scaling, similar to q4_0, could outperform q4_1, but I would like to explore non-linear distributions as well.

One interesting thing about q4_1 is that it will sometimes flip the sign of the weight... Intuitively, that might be a worse error than similar-magnitude errors that keep the sign. But in practice q4_1 beats q4_0 in perplexity, so maybe not?

Additionally, the current recommended setting for GPTQ is 4.16-bit quantization. You can simply count the bits with the following code.

```python
# settings
in_dim = 8192
intermediate_dim = 22016
bit = 4
groupsize = 128

def get_bit(indim, outdim, bit, groupsize):
    # Storage of one GPTQ-quantized linear layer, in bits.
    q_weight = (indim // 32 * bit) * (outdim) * (32)              # packed quantized weights (int32 words)
    q_zeros = (indim // groupsize) * (outdim // 32 * bit) * (32)  # packed zero points, one per group
    scales = (indim // groupsize) * (outdim) * (16)               # fp16 scale per group and output
    g_idx = (indim) * (32)                                        # 32-bit group index per input row
    weight = (indim) * (outdim) * (16)                            # the same layer in fp16
    total_bit = ((q_weight + q_zeros + scales + g_idx) / weight) * 16  # average bits per weight
    return total_bit

total_bit = 0
total_bit += get_bit(in_dim, in_dim, bit, groupsize) * 4            # q, k, v, o
total_bit += get_bit(in_dim, intermediate_dim, bit, groupsize) * 2  # gate, up
total_bit += get_bit(intermediate_dim, in_dim, bit, groupsize) * 1  # down
total_bit /= 7  # unweighted mean over the 7 matrices
print(total_bit)
```

unbounded commented 1 year ago

I have a silly implementation at https://github.com/unbounded/llama.cpp/blob/q4-q-harder/ggml.c#L562-L687 that I believe calculates the actual RMSE-optimal scaling factor for each block, maybe to some limit of numerical precision errors, or whatever bugs I've missed.

It does this by lining up the "interesting" scaling values where quantization changes and finding the local optimal score analytically for each of ~512 resulting configurations.
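Below is a compact Python sketch of the same idea. It is not a port of the linked ggml.c code. Between consecutive breakpoints d = |x_i| / (k + 0.5), every rounded q_i is constant. The sketch probes one scale per interval, fits the least-squares scale d = Σ x_i q_i / Σ q_i² for that configuration, and re-rounds at it. This visits the global RMSE optimum. The sketch assumes the signed q4_0 range [-8, 7] and tries both signs of d. `optimal_scale` and its helpers are hypothetical names.

```python
import random

QMIN, QMAX = -8, 7  # signed 4-bit range used by q4_0


def quantize(x, d, qmin=QMIN, qmax=QMAX):
    return [max(qmin, min(qmax, round(v / d))) for v in x]


def sq_error(x, q, d):
    return sum((v - d * k) ** 2 for v, k in zip(x, q))


def optimal_scale(x, qmin=QMIN, qmax=QMAX):
    """Exhaustive RMSE-optimal scale for one block; slow, reference only."""
    best_d, best_err = 0.0, sum(v * v for v in x)  # d = 0: every q_i is 0
    for sign in (1.0, -1.0):
        y = [sign * v for v in x]  # d > 0 on y is the same as d < 0 on x
        # Scales where round(y_i / d) changes: |y_i| / d = k + 0.5.
        bps = set()
        for v in y:
            if v > 0:
                bps.update(v / (k + 0.5) for k in range(qmax))
            elif v < 0:
                bps.update(-v / (k + 0.5) for k in range(-qmin))
        pts = sorted(bps)
        if not pts:
            continue
        # One probe per interval; q is constant inside each interval.
        probes = [pts[0] / 2] + [(a + b) / 2 for a, b in zip(pts, pts[1:])]
        for probe in probes:
            q = quantize(y, probe, qmin, qmax)
            qq = sum(k * k for k in q)
            if qq == 0:
                continue
            d = sum(v * k for v, k in zip(y, q)) / qq  # least squares for fixed q
            q = quantize(y, d, qmin, qmax)             # re-round at the fitted scale
            err = sq_error(y, q, d)
            if err < best_err:
                best_d, best_err = sign * d, err
    return best_d, best_err


rng = random.Random(0)
block = [rng.gauss(0.0, 1.0) for _ in range(32)]
naive = max(abs(v) for v in block) / 7
print("naive   RMSE", (sq_error(block, quantize(block, naive), naive) / 32) ** 0.5)
print("optimal RMSE", (optimal_scale(block)[1] / 32) ** 0.5)
```

On the 0.1 0.2 0.3 0.6 example from the opening post, this recovers a scale of magnitude 0.1 with (numerically) zero error, which amax/7 does not. A 32-weight block has on the order of 500 probes, each O(QK), so this is only useful as a reference.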

Not suggesting this for actual use, as it is very slow, and we can get 99.9% of the way there with smarter approaches like #835. But consider it a contribution to the problem stated in the OP :)

There seem to be some actual papers on the subject, like https://openaccess.thecvf.com/content/CVPR2021/papers/Idelbayev_Optimal_Quantization_Using_Scaled_Codebook_CVPR_2021_paper.pdf, which makes some similar observations but uses a different approach, maybe better suited to bigger blocks/bit sizes.

sw commented 1 year ago

@unbounded: I have been playing around with your implementation; as it is, it's probably too slow, but it could maybe be made faster.

I think if you ignore the sign of input values, i.e. convert them all to negative numbers, you could then halve the shape table:

```c
const float shape[8] = {-7, -6, -5, -4, -3, -2, -1, 0};
// or
const float shape[9] = {-8, -7, -6, -5, -4, -3, -2, -1, 0};
```

But neither is quite as good, as you're either not using -8 or have the case where you need to clamp a +8 to a +7 when you fix the qis to match the actual sign of x[i].

Also, if you were to sort x[i] before creating the event list, you could maybe take some work out of the sorting required afterwards, but I haven't looked into that.

sw commented 1 year ago

After twiddling some more with #835 and @unbounded's code, I'm starting to doubt this is the right way: optimizing the scaling factor for RMSE, then checking perplexity to see if we haven't regressed.

We may be falling victim to the streetlight effect - optimizing for quantization RMSE just because that's fast and easy to measure.

When really, the better approach might be to search for the best global scaling factor to achieve the lowest perplexity. This may not be 7 as currently or 8 as in #729. It could be a larger value; clipping the maximum may be a worthy trade-off. (being clever with the sign of the scaling value as in #729 is probably a good idea anyway)

fingertap commented 1 year ago

Hey guys, what is the official name for the compression method used for Q4?

MfAl2 commented 1 year ago

Page 7 of this paper has a relatively efficient iterative procedure, based on discrete calculus, to find the scale factor that minimises the RMS quantisation error. They use it for an FPGA-friendly two-step log quantiser, but the math should work for a scaled uniform quantiser too.

https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136710657.pdf

MarcioPais commented 1 year ago

I've finally had the time to implement some of the ideas I mentioned previously, and though it may be of purely academic interest, I'd like to share some results.

I implemented 3 types of significance coding strategies for unary exponent coding:

Until a pre-determined percentage of weights are deemed significant, 4 options (delta, bitset and binary partitioning with 2 different thresholds) are tried and the best one is chosen. Afterwards, the bitset encoding strategy is always used.

Testing on the 7B model, for the 65 small layers with 4096 weights that don't really follow the same weight distribution, an optional preprocessing step that performs variable rounding was also implemented. Results below:

| Bits-per-weight | Method | RMSE | MaxAbsErr | Coverage |
|---|---|---|---|---|
| 6 | q4_1 | 0.001782 | 0.127563 | 100% |
| 6 | q5_1 (ikawrakow) | 0.000761 | 0.052734 | 100% |
| 6 | this, no preprocessing | 0.000554 | 0.062500 | 95.96% |
| 6 | this, preprocessed | 0.000573 | 0.031250 | 95.96% |
| 5 | q4_0 | 0.002218 | 0.142578 | 100% |
| 5 | q4_0 (ikawrakow) | 0.001592 | 0.174804 | 100% |
| 5 | this, no preprocessing | 0.001111 | 0.125000 | 91.96% |
| 5 | this, preprocessed | 0.001126 | 0.062500 | 91.96% |
| 4 | q3_0 (sw) | 0.003526 | 0.378662 | 100% |
| 4 | this, no preprocessing | 0.002248 | 0.250000 | 83.95% |
| 4 | this, preprocessed | 0.002262 | 0.125000 | 83.95% |
| 3 | q2_0 (sw) | 0.007391 | 0.873046 | 100% |
| 3 | this, no preprocessing | 0.004663 | 0.415039 | 68.17% |
| 3 | this, preprocessed | 0.004677 | 0.249023 | 68.17% |
| 2.5 | this, no preprocessing | 0.006663 | 0.499756 | 54.43% |
| 2.5 | this, preprocessed | 0.006684 | 0.437256 | 54.43% |
| 2 | this, no preprocessing | 0.009402 | 1.849609 | 42.87% |
| 2 | this, preprocessed | 0.009431 | 1.748047 | 42.87% |

As we go down in average bits-per-weight, we see that even though as expected RMSE scales almost perfectly, the maximum absolute error explodes rather quickly, influenced by those small layers. If we choose to skip quantizing them, the results are much better:

| Bits-per-weight | Method | RMSE | MaxAbsErr |
|---|---|---|---|
| 2.5 | this, no preprocessing | 0.006648 | 0.246338 |
| 2 | this, no preprocessing | 0.009373 | 0.362305 |

Now, obviously, the most interesting aspect of this approach is not using it in CBR mode, but using it in a VBR mode where the encoder stops whenever a certain metric is achieved. Possible useful metrics to try are RMSE and the maximum and mean absolute errors, all divided by the range for each layer. Assuming a metric that translates well to perplexity degradation is used, that would allow us to get the smallest possible model size that still retains the quality we want.

The obvious elephant in the room here is that this encoding would have to be decoded and stored in memory in a custom sparse-capable format, so in practice it would probably only be useful with a very sparse and heavily quantized 65B model.

For future reference, doing a lossless encoding of the 7B model with this method requires an average of 13.76 bits-per-weight, which is better than the usual general-purpose compression algorithms and only slightly behind a simple context-mixing compression algorithm.

MarcioPais commented 1 year ago

Now that the significance coding method (SCM) provides a good baseline for what should be achievable, I ran a few experiments to see how close to it I could get with a simple q4_1-like encoding.

The range of values per block can be encoded in 10 bits (5 bits for the exponent and 5 bits for the mantissa) and the minimum value per block can be encoded in 12 bits (1 sign bit, 5 exponent bits, 6 mantissa bits).

That leaves us 2 bits, which I'm using to index into a lookup table of quantizer step mappings, so that we can pick the one that provides either the smallest squared error or the smallest absolute error. The first entry in the LUT is the original linear mapping, and then I'm using 3 different logit-like mappings:

At the default block size (QK) of 32, this works out to 4.75 bits-per-weight (bpw). Now for some quick results, when optimizing the mapping choice for RMSE:
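The quoted bits-per-weight can be checked against the layout just described. The header is a 10-bit range, a 12-bit minimum and a 2-bit mapping index, i.e. 24 bits per block on top of 4 bits per weight:

```math
\text{bpw}(QK) = 4 + \frac{10 + 12 + 2}{QK}, \qquad
\text{bpw}(32) = 4.75,\quad \text{bpw}(64) = 4.375,\quad \text{bpw}(128) = 4.1875
```

These values match the 4.75, 4.375 and 4.1875 rows in the table below.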

| Bits-per-weight | Method | RMSE | MaxAbsErr |
|---|---|---|---|
| 5 | q4_0 | 0.002218 | 0.142578 |
| 5 | q4_0 (ikawrakow) | 0.001592 | 0.174804 |
| 4.75 | q4_2 | 0.001577 | 0.118164 |
| 4.75 | significance coding | 0.001354 | 0.062500 |
| 4.375 | q4_2 (QK=64) | 0.001827 | 0.127831 |
| 4.375 | significance coding | 0.001784 | 0.062500 |
| 4.1875 | q4_2 (QK=128) | 0.002036 | 0.139648 |
| 4.1875 | significance coding | 0.002008 | 0.125000 |

At 4.75 bpw, the RMSE is ever so slightly better than the improved q4_0 method with 2 fp16 at 5 bpw, and significantly outperforms it in terms of MAE. However, it is still far from the result of SCM, until we start increasing the block size. At QK = 128, it still performs better than the original Q4_0 method, and is much closer to SCM.

Crucially, when compared to SCM, dequantization is much faster, and the LUT size can be cut in half by making use of the symmetry in the mappings.

ggerganov commented 1 year ago

@MarcioPais

Thank you for the detailed analysis. A few notes:

Testing on the 7B model, for the 65 small layers with 4096 weights that don't really follow the same weight distribution

The 1D normalization layers (i.e. layers.X.ffn_norm.weight, layers.X.attention_norm.weight, norm.weight) can easily remain in F32 format. There's no point in quantizing those.

The first entry in the LUT is the original linear mapping, and then I'm using 3 different logit-like mappings

How often do we end up choosing either of the 4 mappings?

I will need some time to get into all the details provided here (and also from other users), but I like the various ideas that are being generated. Just an update, the short-term plan is to try to implement #995 efficiently. After that, we can try to apply some RMSE optimizing strategy to additionally bring the perplexity down.

MarcioPais commented 1 year ago

@ggerganov

How often do we end up choosing either of the 4 mappings?

| QK | linear | map1 | map2 | map3 |
|---|---|---|---|---|
| 32 | 24.05% | 22.75% | 19.83% | 33.37% |
| 64 | 15.03% | 22.09% | 21.61% | 41.27% |
| 128 | 6.01% | 18.29% | 22.27% | 53.43% |

As the block size increases, we're getting a better approximation of the distribution, and hence the most "skewed" mapping (map3) is increasingly favored, and the "backup" linear mapping is almost never used.

However, it's important to note that very skewed mappings may not always be better, especially if MAE is a main factor for perplexity. Here's a run at QK=128 with the most "conservative" non-linear mapping (map1) replaced with an even more skewed mapping than map3:

| QK | linear | map4 | map2 | map3 | RMSE | MAE |
|---|---|---|---|---|---|---|
| 128 | 4.69% | 26.37% | 28.33% | 40.61% | 0.002017 | 0.158203 |

Here, even if we get a very small improvement to RMSE (0.002036 => 0.002017), the MAE increases by a non-negligible amount (0.139648 => 0.158203).

It should also be considered that until proper perplexity measurements taken in controlled, reproducible runs are available, comparing on RMSE and/or MAE alone might not reflect the characteristics of all the different quantization strategies proposed. For instance, SCM doesn't necessarily encode a representation for every weight, so for all of those, their contribution to RMSE is probably bigger than in the other simpler methods that use a fixed bpw. But it's well known in the literature that for very large models, sparsification to a significant degree is usually possible and at low levels can even improve the results, so the difference in perplexity for different methods at similar RMSE may be higher than expected.

MarcioPais commented 1 year ago

I ran some perplexity tests on SCM and the non-linear mapping quantization method (NLM), here are some results:

LLaMA-7B

| Bits-per-weight | Method | Perplexity |
|---|---|---|
| 16 | FP16 (Reference) | 5.9565 |
| 6+ | q4_3 RMSE-optimized (#1106) + FP16 output tensor | 6.0085 |
| 6 | q4_1 | 6.0936 |
| 6 | q4_3 | 6.0617 |
| 6 | q4_3 RMSE-optimized (#1106) | 6.0344 |
| 6 | SCM | 5.9681 |

SCM @ 6bpw [1]4.2472,[2]4.7509,[3]5.6108,[4]6.2010,[5]6.3296,[6]6.2916,[7]6.4857,[8]6.5790,[9]6.8988,[10]7.1445,[11]7.3559,[12]7.3710,[13]7.2867,[14]7.3399,[15]7.5808,[16]7.2083,[17]7.0967,[18]7.0441,[19]6.6939,[20]6.6843,[21]6.5953,[22]6.4216,[23]6.3900,[24]6.2996,[25]6.3004,[26]6.1424,[27]5.9721,[28]5.8749,[29]5.7887,[30]5.6339,[31]5.6051,[32]5.6264,[33]5.5719,[34]5.6020,[35]5.6253,[36]5.6611,[37]5.6657,[38]5.6747,[39]5.7071,[40]5.7572,[41]5.7669,[42]5.8043,[43]5.7666,[44]5.8232,[45]5.8260,[46]5.8001,[47]5.8202,[48]5.7954,[49]5.7969,[50]5.7581,[51]5.7545,[52]5.7447,[53]5.7894,[54]5.7738,[55]5.7517,[56]5.7805,[57]5.7997,[58]5.8189,[59]5.8355,[60]5.8766,[61]5.8701,[62]5.9277,[63]5.9586,[64]5.9716,[65]6.0135,[66]6.0206,[67]6.0374,[68]6.0511,[69]6.0743,[70]6.1041,[71]6.1249,[72]6.1559,[73]6.2136,[74]6.2178,[75]6.2313,[76]6.2431,[77]6.2540,[78]6.2395,[79]6.2666,[80]6.2598,[81]6.2708,[82]6.2754,[83]6.2259,[84]6.2086,[85]6.1956,[86]6.1747,[87]6.1102,[88]6.0848,[89]6.0654,[90]6.0511,[91]6.0730,[92]6.0675,[93]6.0673,[94]6.0647,[95]6.0922,[96]6.0920,[97]6.0869,[98]6.0810,[99]6.0674,[100]6.0663,[101]6.0895,[102]6.0843,[103]6.1045,[104]6.1118,[105]6.1116,[106]6.1277,[107]6.1267,[108]6.1401,[109]6.1353,[110]6.1319,[111]6.1540,[112]6.1741,[113]6.1762,[114]6.1727,[115]6.1784,[116]6.1699,[117]6.1750,[118]6.2028,[119]6.2240,[120]6.2582,[121]6.2727,[122]6.2969,[123]6.3327,[124]6.3497,[125]6.3408,[126]6.3793,[127]6.4147,[128]6.4441,[129]6.4292,[130]6.4376,[131]6.4341,[132]6.4267,[133]6.4133,[134]6.4230,[135]6.4191,[136]6.4090,[137]6.4018,[138]6.3845,[139]6.3743,[140]6.3706,[141]6.3415,[142]6.3379,[143]6.3086,[144]6.2888,[145]6.2794,[146]6.2676,[147]6.2710,[148]6.2711,[149]6.2658,[150]6.2617,[151]6.2637,[152]6.2540,[153]6.2378,[154]6.2297,[155]6.2363,[156]6.2317,[157]6.2484,[158]6.2528,[159]6.2572,[160]6.2596,[161]6.2711,[162]6.2433,[163]6.2318,[164]6.2089,[165]6.1788,[166]6.1523,[167]6.1161,[168]6.0859,[169]6.0724,[170]6.0619,[171]6.0360,[172]6.0196,[173]6.0035,[174]5.9743,[175]5.95
32,[176]5.9418,[177]5.9224,[178]5.9003,[179]5.8839,[180]5.8748,[181]5.8540,[182]5.8367,[183]5.8233,[184]5.8224,[185]5.8152,[186]5.8159,[187]5.8220,[188]5.8182,[189]5.8352,[190]5.8360,[191]5.8564,[192]5.8723,[193]5.8886,[194]5.8995,[195]5.9201,[196]5.9352,[197]5.9557,[198]5.9703,[199]5.9734,[200]5.9782,[201]5.9731,[202]5.9916,[203]5.9989,[204]5.9977,[205]6.0079,[206]6.0146,[207]6.0108,[208]6.0189,[209]6.0228,[210]6.0280,[211]6.0383,[212]6.0452,[213]6.0556,[214]6.0580,[215]6.0603,[216]6.0741,[217]6.0921,[218]6.1053,[219]6.1050,[220]6.1015,[221]6.0970,[222]6.0947,[223]6.0852,[224]6.0781,[225]6.0744,[226]6.0946,[227]6.1025,[228]6.1076,[229]6.1136,[230]6.1104,[231]6.1266,[232]6.1152,[233]6.0991,[234]6.0848,[235]6.0653,[236]6.0590,[237]6.0497,[238]6.0527,[239]6.0385,[240]6.0285,[241]6.0306,[242]6.0340,[243]6.0325,[244]6.0216,[245]6.0188,[246]6.0081,[247]5.9967,[248]5.9897,[249]5.9874,[250]5.9918,[251]5.9849,[252]5.9814,[253]5.9719,[254]5.9667,[255]5.9558,[256]5.9384,[257]5.9265,[258]5.9187,[259]5.9167,[260]5.9088,[261]5.9047,[262]5.8994,[263]5.8939,[264]5.8717,[265]5.8712,[266]5.8699,[267]5.8634,[268]5.8724,[269]5.8704,[270]5.8714,[271]5.8789,[272]5.8822,[273]5.8823,[274]5.8848,[275]5.8930,[276]5.8988,[277]5.9143,[278]5.9238,[279]5.9330,[280]5.9359,[281]5.9455,[282]5.9513,[283]5.9655,[284]5.9734,[285]5.9819,[286]5.9950,[287]5.9946,[288]6.0004,[289]5.9923,[290]5.9772,[291]5.9627,[292]5.9482,[293]5.9351,[294]5.9372,[295]5.9362,[296]5.9409,[297]5.9397,[298]5.9427,[299]5.9401,[300]5.9297,[301]5.9299,[302]5.9222,[303]5.9137,[304]5.9056,[305]5.9022,[306]5.8899,[307]5.8920,[308]5.8951,[309]5.8798,[310]5.8746,[311]5.8682,[312]5.8704,[313]5.8650,[314]5.8634,[315]5.8483,[316]5.8433,[317]5.8275,[318]5.8076,[319]5.8193,[320]5.8312,[321]5.8358,[322]5.8319,[323]5.8254,[324]5.8227,[325]5.8327,[326]5.8328,[327]5.8347,[328]5.8384,[329]5.8440,[330]5.8464,[331]5.8584,[332]5.8558,[333]5.8625,[334]5.8571,[335]5.8511,[336]5.8547,[337]5.8524,[338]5.8518,[339]5.8468,[340]5.8426,[341]5.8504,[342]
5.8530,[343]5.8577,[344]5.8579,[345]5.8585,[346]5.8561,[347]5.8599,[348]5.8631,[349]5.8655,[350]5.8623,[351]5.8631,[352]5.8632,[353]5.8577,[354]5.8577,[355]5.8627,[356]5.8656,[357]5.8621,[358]5.8711,[359]5.8736,[360]5.8701,[361]5.8698,[362]5.8767,[363]5.8876,[364]5.8936,[365]5.8986,[366]5.8999,[367]5.9081,[368]5.9057,[369]5.9067,[370]5.9082,[371]5.9029,[372]5.9077,[373]5.9121,[374]5.9105,[375]5.9106,[376]5.9171,[377]5.9128,[378]5.9154,[379]5.9212,[380]5.9134,[381]5.9101,[382]5.9051,[383]5.9044,[384]5.9039,[385]5.9030,[386]5.9026,[387]5.9025,[388]5.8990,[389]5.8941,[390]5.8873,[391]5.8799,[392]5.8759,[393]5.8741,[394]5.8768,[395]5.8756,[396]5.8685,[397]5.8756,[398]5.8793,[399]5.8870,[400]5.8872,[401]5.8887,[402]5.8897,[403]5.8917,[404]5.8981,[405]5.8889,[406]5.8857,[407]5.8853,[408]5.8869,[409]5.8982,[410]5.9089,[411]5.9198,[412]5.9353,[413]5.9458,[414]5.9532,[415]5.9587,[416]5.9662,[417]5.9778,[418]5.9814,[419]5.9881,[420]5.9969,[421]6.0083,[422]6.0122,[423]6.0191,[424]6.0296,[425]6.0380,[426]6.0443,[427]6.0487,[428]6.0567,[429]6.0617,[430]6.0697,[431]6.0833,[432]6.0871,[433]6.0864,[434]6.0824,[435]6.0832,[436]6.0857,[437]6.0951,[438]6.1025,[439]6.0994,[440]6.0985,[441]6.0936,[442]6.0922,[443]6.0935,[444]6.0939,[445]6.0921,[446]6.0944,[447]6.0973,[448]6.1012,[449]6.0988,[450]6.0996,[451]6.0957,[452]6.0821,[453]6.0738,[454]6.0682,[455]6.0693,[456]6.0739,[457]6.0758,[458]6.0736,[459]6.0741,[460]6.0826,[461]6.0799,[462]6.0785,[463]6.0824,[464]6.0812,[465]6.0787,[466]6.0710,[467]6.0712,[468]6.0710,[469]6.0730,[470]6.0734,[471]6.0687,[472]6.0730,[473]6.0679,[474]6.0689,[475]6.0629,[476]6.0646,[477]6.0573,[478]6.0562,[479]6.0619,[480]6.0664,[481]6.0683,[482]6.0639,[483]6.0599,[484]6.0619,[485]6.0600,[486]6.0543,[487]6.0540,[488]6.0517,[489]6.0472,[490]6.0448,[491]6.0420,[492]6.0364,[493]6.0338,[494]6.0321,[495]6.0317,[496]6.0279,[497]6.0225,[498]6.0207,[499]6.0165,[500]6.0075,[501]6.0010,[502]6.0012,[503]6.0006,[504]5.9922,[505]5.9943,[506]5.9950,[507]5.9892,[508]5.9853,[
509]5.9847,[510]5.9880,[511]5.9925,[512]5.9960,[513]5.9980,[514]6.0042,[515]5.9989,[516]5.9980,[517]5.9990,[518]5.9988,[519]6.0016,[520]6.0040,[521]6.0054,[522]6.0081,[523]6.0087,[524]6.0145,[525]6.0177,[526]6.0185,[527]6.0204,[528]6.0155,[529]6.0160,[530]6.0110,[531]6.0099,[532]6.0145,[533]6.0168,[534]6.0152,[535]6.0173,[536]6.0120,[537]6.0099,[538]6.0147,[539]6.0157,[540]6.0194,[541]6.0197,[542]6.0208,[543]6.0223,[544]6.0234,[545]6.0215,[546]6.0222,[547]6.0182,[548]6.0135,[549]6.0136,[550]6.0108,[551]6.0075,[552]6.0053,[553]6.0017,[554]5.9998,[555]5.9968,[556]5.9965,[557]5.9987,[558]5.9949,[559]5.9944,[560]5.9943,[561]5.9945,[562]5.9923,[563]5.9920,[564]5.9960,[565]5.9981,[566]5.9980,[567]5.9959,[568]5.9964,[569]5.9951,[570]5.9979,[571]5.9983,[572]5.9993,[573]5.9993,[574]5.9959,[575]5.9954,[576]5.9954,[577]5.9940,[578]5.9921,[579]5.9927,[580]5.9864,[581]5.9828,[582]5.9817,[583]5.9826,[584]5.9829,[585]5.9755,[586]5.9690,[587]5.9695,[588]5.9743,[589]5.9795,[590]5.9825,[591]5.9846,[592]5.9834,[593]5.9802,[594]5.9813,[595]5.9791,[596]5.9823,[597]5.9803,[598]5.9775,[599]5.9797,[600]5.9792,[601]5.9778,[602]5.9787,[603]5.9816,[604]5.9824,[605]5.9858,[606]5.9879,[607]5.9862,[608]5.9831,[609]5.9839,[610]5.9873,[611]5.9856,[612]5.9881,[613]5.9845,[614]5.9797,[615]5.9727,[616]5.9754,[617]5.9695,[618]5.9648,[619]5.9596,[620]5.9463,[621]5.9397,[622]5.9381,[623]5.9397,[624]5.9402,[625]5.9403,[626]5.9392,[627]5.9415,[628]5.9416,[629]5.9412,[630]5.9442,[631]5.9498,[632]5.9554,[633]5.9540,[634]5.9574,[635]5.9580,[636]5.9547,[637]5.9513,[638]5.9538,[639]5.9506,[640]5.9516,[641]5.9518,[642]5.9583,[643]5.9604,[644]5.9617,[645]5.9599,[646]5.9638,[647]5.9599,[648]5.9607,[649]5.9609,[650]5.9646,[651]5.9698,[652]5.9709,[653]5.9748,[654]5.9686,[655]5.9681,

| Bits-per-weight | Method | Perplexity |
|---|---|---|
| 5 | q4_0 | 6.2103 |
| 5 | q4_2 | 6.1698 |
| 5 | SCM | 6.1347 |
| 4.75 | NLM | 6.1014 |

SCM @ 5bpw system_info: n_threads = 12 / 12 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 | perplexity : calculating perplexity over 655 chunks, batch_size=512 37.45 seconds per pass - ETA 6 hours 48 minutes [1]4.4089,[2]4.9282,[3]5.8177,[4]6.4359,[5]6.4976,[6]6.4495,[7]6.6304,[8]6.7209,[9]7.0611,[10]7.2937,[11]7.5017,[12]7.5176,[13]7.4415,[14]7.5052,[15]7.7516,[16]7.3596,[17]7.2417,[18]7.2003,[19]6.8408,[20]6.8334,[21]6.7391,[22]6.5654,[23]6.5302,[24]6.4427,[25]6.4505,[26]6.2916,[27]6.1210,[28]6.0259,[29]5.9395,[30]5.7856,[31]5.7550,[32]5.7753,[33]5.7167,[34]5.7516,[35]5.7764,[36]5.8150,[37]5.8214,[38]5.8368,[39]5.8702,[40]5.9239,[41]5.9370,[42]5.9788,[43]5.9395,[44]5.9962,[45]6.0005,[46]5.9744,[47]5.9941,[48]5.9678,[49]5.9707,[50]5.9317,[51]5.9269,[52]5.9150,[53]5.9577,[54]5.9418,[55]5.9158,[56]5.9467,[57]5.9676,[58]5.9890,[59]6.0037,[60]6.0472,[61]6.0423,[62]6.1007,[63]6.1334,[64]6.1504,[65]6.1970,[66]6.2038,[67]6.2212,[68]6.2351,[69]6.2595,[70]6.2916,[71]6.3126,[72]6.3447,[73]6.4052,[74]6.4100,[75]6.4234,[76]6.4349,[77]6.4455,[78]6.4305,[79]6.4572,[80]6.4486,[81]6.4577,[82]6.4619,[83]6.4100,[84]6.3917,[85]6.3783,[86]6.3558,[87]6.2897,[88]6.2642,[89]6.2441,[90]6.2291,[91]6.2516,[92]6.2469,[93]6.2463,[94]6.2435,[95]6.2722,[96]6.2717,[97]6.2668,[98]6.2610,[99]6.2464,[100]6.2471,[101]6.2707,[102]6.2647,[103]6.2864,[104]6.2948,[105]6.2947,[106]6.3102,[107]6.3076,[108]6.3217,[109]6.3159,[110]6.3121,[111]6.3343,[112]6.3552,[113]6.3557,[114]6.3515,[115]6.3591,[116]6.3504,[117]6.3558,[118]6.3850,[119]6.4055,[120]6.4412,[121]6.4570,[122]6.4815,[123]6.5180,[124]6.5352,[125]6.5259,[126]6.5652,[127]6.6029,[128]6.6336,[129]6.6179,[130]6.6265,[131]6.6226,[132]6.6150,[133]6.6005,[134]6.6102,[135]6.6068,[136]6.5962,[137]6.5881,[138]6.5714,[139]6.5616,[140]6.5571,[141]6.5273,[142]6.5240,[143]6.4959,[144]6.4761,[145]6.4667,[146]6.4541,[147]6.4598,[148
]6.4599,[149]6.4535,[150]6.4498,[151]6.4509,[152]6.4414,[153]6.4239,[154]6.4149,[155]6.4216,[156]6.4162,[157]6.4340,[158]6.4377,[159]6.4422,[160]6.4443,[161]6.4560,[162]6.4264,[163]6.4148,[164]6.3907,[165]6.3593,[166]6.3322,[167]6.2948,[168]6.2627,[169]6.2490,[170]6.2385,[171]6.2111,[172]6.1939,[173]6.1769,[174]6.1460,[175]6.1242,[176]6.1139,[177]6.0933,[178]6.0704,[179]6.0532,[180]6.0443,[181]6.0226,[182]6.0043,[183]5.9905,[184]5.9899,[185]5.9828,[186]5.9840,[187]5.9900,[188]5.9855,[189]6.0029,[190]6.0042,[191]6.0260,[192]6.0426,[193]6.0601,[194]6.0715,[195]6.0928,[196]6.1086,[197]6.1304,[198]6.1453,[199]6.1484,[200]6.1536,[201]6.1482,[202]6.1677,[203]6.1753,[204]6.1736,[205]6.1851,[206]6.1917,[207]6.1871,[208]6.1954,[209]6.1999,[210]6.2046,[211]6.2148,[212]6.2220,[213]6.2331,[214]6.2360,[215]6.2400,[216]6.2551,[217]6.2733,[218]6.2867,[219]6.2867,[220]6.2829,[221]6.2793,[222]6.2763,[223]6.2658,[224]6.2584,[225]6.2547,[226]6.2754,[227]6.2838,[228]6.2892,[229]6.2958,[230]6.2922,[231]6.3085,[232]6.2961,[233]6.2794,[234]6.2647,[235]6.2471,[236]6.2399,[237]6.2300,[238]6.2330,[239]6.2178,[240]6.2074,[241]6.2105,[242]6.2139,[243]6.2121,[244]6.2006,[245]6.1978,[246]6.1862,[247]6.1744,[248]6.1667,[249]6.1643,[250]6.1679,[251]6.1610,[252]6.1573,[253]6.1475,[254]6.1430,[255]6.1306,[256]6.1125,[257]6.1005,[258]6.0922,[259]6.0905,[260]6.0829,[261]6.0788,[262]6.0732,[263]6.0681,[264]6.0457,[265]6.0454,[266]6.0444,[267]6.0378,[268]6.0472,[269]6.0449,[270]6.0462,[271]6.0542,[272]6.0576,[273]6.0573,[274]6.0597,[275]6.0683,[276]6.0743,[277]6.0905,[278]6.1004,[279]6.1095,[280]6.1122,[281]6.1217,[282]6.1278,[283]6.1426,[284]6.1504,[285]6.1590,[286]6.1724,[287]6.1723,[288]6.1784,[289]6.1695,[290]6.1540,[291]6.1386,[292]6.1234,[293]6.1099,[294]6.1119,[295]6.1113,[296]6.1158,[297]6.1142,[298]6.1179,[299]6.1148,[300]6.1039,[301]6.1036,[302]6.0958,[303]6.0870,[304]6.0783,[305]6.0755,[306]6.0627,[307]6.0646,[308]6.0685,[309]6.0523,[310]6.0465,[311]6.0401,[312]6.0425,[313]6.0371,[314]6.0355,
[315]6.0196,[316]6.0146,[317]5.9984,[318]5.9775,[319]5.9897,[320]6.0021,[321]6.0063,[322]6.0020,[323]5.9953,[324]5.9921,[325]6.0020,[326]6.0017,[327]6.0039,[328]6.0078,[329]6.0137,[330]6.0163,[331]6.0284,[332]6.0256,[333]6.0325,[334]6.0265,[335]6.0199,[336]6.0234,[337]6.0204,[338]6.0197,[339]6.0142,[340]6.0101,[341]6.0180,[342]6.0200,[343]6.0254,[344]6.0255,[345]6.0256,[346]6.0228,[347]6.0270,[348]6.0300,[349]6.0322,[350]6.0286,[351]6.0289,[352]6.0295,[353]6.0238,[354]6.0238,[355]6.0290,[356]6.0317,[357]6.0286,[358]6.0375,[359]6.0404,[360]6.0365,[361]6.0363,[362]6.0434,[363]6.0548,[364]6.0611,[365]6.0665,[366]6.0673,[367]6.0760,[368]6.0735,[369]6.0745,[370]6.0759,[371]6.0700,[372]6.0748,[373]6.0799,[374]6.0785,[375]6.0788,[376]6.0859,[377]6.0813,[378]6.0843,[379]6.0906,[380]6.0826,[381]6.0791,[382]6.0740,[383]6.0735,[384]6.0728,[385]6.0722,[386]6.0716,[387]6.0714,[388]6.0674,[389]6.0622,[390]6.0551,[391]6.0471,[392]6.0430,[393]6.0415,[394]6.0442,[395]6.0432,[396]6.0357,[397]6.0432,[398]6.0466,[399]6.0547,[400]6.0545,[401]6.0563,[402]6.0569,[403]6.0589,[404]6.0656,[405]6.0563,[406]6.0528,[407]6.0521,[408]6.0536,[409]6.0655,[410]6.0764,[411]6.0876,[412]6.1035,[413]6.1145,[414]6.1224,[415]6.1277,[416]6.1351,[417]6.1471,[418]6.1511,[419]6.1585,[420]6.1672,[421]6.1791,[422]6.1837,[423]6.1910,[424]6.2023,[425]6.2109,[426]6.2175,[427]6.2219,[428]6.2300,[429]6.2351,[430]6.2433,[431]6.2571,[432]6.2612,[433]6.2606,[434]6.2566,[435]6.2573,[436]6.2593,[437]6.2685,[438]6.2758,[439]6.2729,[440]6.2723,[441]6.2670,[442]6.2656,[443]6.2672,[444]6.2675,[445]6.2657,[446]6.2682,[447]6.2711,[448]6.2751,[449]6.2723,[450]6.2734,[451]6.2693,[452]6.2561,[453]6.2477,[454]6.2419,[455]6.2428,[456]6.2475,[457]6.2493,[458]6.2472,[459]6.2474,[460]6.2559,[461]6.2533,[462]6.2519,[463]6.2563,[464]6.2552,[465]6.2525,[466]6.2444,[467]6.2444,[468]6.2438,[469]6.2457,[470]6.2458,[471]6.2405,[472]6.2447,[473]6.2391,[474]6.2402,[475]6.2341,[476]6.2358,[477]6.2283,[478]6.2275,[479]6.2333,[480]6.2381,[481]6.2
399,[482]6.2353,[483]6.2311,[484]6.2332,[485]6.2316,[486]6.2260,[487]6.2257,[488]6.2237,[489]6.2189,[490]6.2168,[491]6.2136,[492]6.2075,[493]6.2046,[494]6.2029,[495]6.2029,[496]6.1993,[497]6.1937,[498]6.1917,[499]6.1871,[500]6.1776,[501]6.1708,[502]6.1712,[503]6.1706,[504]6.1620,[505]6.1645,[506]6.1654,[507]6.1598,[508]6.1558,[509]6.1552,[510]6.1590,[511]6.1633,[512]6.1670,[513]6.1689,[514]6.1754,[515]6.1699,[516]6.1689,[517]6.1697,[518]6.1696,[519]6.1724,[520]6.1748,[521]6.1766,[522]6.1796,[523]6.1802,[524]6.1860,[525]6.1895,[526]6.1903,[527]6.1920,[528]6.1869,[529]6.1873,[530]6.1823,[531]6.1810,[532]6.1858,[533]6.1882,[534]6.1867,[535]6.1888,[536]6.1833,[537]6.1811,[538]6.1860,[539]6.1871,[540]6.1907,[541]6.1910,[542]6.1919,[543]6.1934,[544]6.1945,[545]6.1924,[546]6.1931,[547]6.1886,[548]6.1837,[549]6.1839,[550]6.1807,[551]6.1775,[552]6.1754,[553]6.1715,[554]6.1693,[555]6.1664,[556]6.1663,[557]6.1682,[558]6.1647,[559]6.1642,[560]6.1638,[561]6.1639,[562]6.1616,[563]6.1615,[564]6.1657,[565]6.1677,[566]6.1675,[567]6.1656,[568]6.1662,[569]6.1647,[570]6.1676,[571]6.1680,[572]6.1691,[573]6.1691,[574]6.1658,[575]6.1655,[576]6.1656,[577]6.1641,[578]6.1622,[579]6.1631,[580]6.1564,[581]6.1527,[582]6.1515,[583]6.1524,[584]6.1527,[585]6.1450,[586]6.1384,[587]6.1390,[588]6.1438,[589]6.1493,[590]6.1523,[591]6.1543,[592]6.1529,[593]6.1496,[594]6.1504,[595]6.1481,[596]6.1515,[597]6.1493,[598]6.1466,[599]6.1488,[600]6.1483,[601]6.1468,[602]6.1479,[603]6.1511,[604]6.1520,[605]6.1553,[606]6.1576,[607]6.1561,[608]6.1529,[609]6.1536,[610]6.1570,[611]6.1552,[612]6.1579,[613]6.1540,[614]6.1490,[615]6.1417,[616]6.1445,[617]6.1384,[618]6.1335,[619]6.1280,[620]6.1141,[621]6.1073,[622]6.1056,[623]6.1072,[624]6.1077,[625]6.1077,[626]6.1065,[627]6.1085,[628]6.1085,[629]6.1081,[630]6.1111,[631]6.1168,[632]6.1223,[633]6.1207,[634]6.1239,[635]6.1244,[636]6.1210,[637]6.1176,[638]6.1204,[639]6.1173,[640]6.1181,[641]6.1183,[642]6.1250,[643]6.1270,[644]6.1282,[645]6.1260,[646]6.1299,[647]6.1263,[648
]6.1269,[649]6.1270,[650]6.1309,[651]6.1364,[652]6.1375,[653]6.1414,[654]6.1353,[655]6.1347,
NLM @ 4.75bpw [1]4.3701,[2]4.8432,[3]5.7421,[4]6.3178,[5]6.4298,[6]6.3957,[7]6.5841,[8]6.6749,[9]7.0199,[10]7.2736,[11]7.5004,[12]7.5294,[13]7.4493,[14]7.4927,[15]7.7488,[16]7.3650,[17]7.2446,[18]7.2010,[19]6.8428,[20]6.8327,[21]6.7444,[22]6.5722,[23]6.5339,[24]6.4387,[25]6.4364,[26]6.2679,[27]6.0888,[28]5.9846,[29]5.8991,[30]5.7407,[31]5.7090,[32]5.7299,[33]5.6773,[34]5.7099,[35]5.7314,[36]5.7620,[37]5.7697,[38]5.7763,[39]5.8055,[40]5.8532,[41]5.8603,[42]5.9043,[43]5.8646,[44]5.9230,[45]5.9263,[46]5.8992,[47]5.9203,[48]5.8953,[49]5.8961,[50]5.8534,[51]5.8493,[52]5.8408,[53]5.8857,[54]5.8698,[55]5.8493,[56]5.8784,[57]5.8976,[58]5.9187,[59]5.9383,[60]5.9786,[61]5.9680,[62]6.0267,[63]6.0577,[64]6.0717,[65]6.1134,[66]6.1234,[67]6.1414,[68]6.1548,[69]6.1806,[70]6.2106,[71]6.2326,[72]6.2656,[73]6.3216,[74]6.3258,[75]6.3399,[76]6.3521,[77]6.3635,[78]6.3497,[79]6.3775,[80]6.3709,[81]6.3816,[82]6.3895,[83]6.3373,[84]6.3192,[85]6.3067,[86]6.2862,[87]6.2247,[88]6.1990,[89]6.1798,[90]6.1658,[91]6.1892,[92]6.1841,[93]6.1842,[94]6.1817,[95]6.2099,[96]6.2100,[97]6.2054,[98]6.1995,[99]6.1855,[100]6.1843,[101]6.2077,[102]6.2035,[103]6.2237,[104]6.2306,[105]6.2317,[106]6.2479,[107]6.2466,[108]6.2576,[109]6.2530,[110]6.2489,[111]6.2716,[112]6.2924,[113]6.2953,[114]6.2916,[115]6.2973,[116]6.2880,[117]6.2925,[118]6.3208,[119]6.3430,[120]6.3780,[121]6.3931,[122]6.4184,[123]6.4530,[124]6.4705,[125]6.4605,[126]6.5001,[127]6.5366,[128]6.5651,[129]6.5515,[130]6.5603,[131]6.5570,[132]6.5489,[133]6.5377,[134]6.5467,[135]6.5432,[136]6.5330,[137]6.5263,[138]6.5092,[139]6.4988,[140]6.4957,[141]6.4666,[142]6.4628,[143]6.4344,[144]6.4148,[145]6.4058,[146]6.3948,[147]6.3976,[148]6.3977,[149]6.3926,[150]6.3888,[151]6.3906,[152]6.3797,[153]6.3642,[154]6.3559,[155]6.3626,[156]6.3575,[157]6.3739,[158]6.3777,[159]6.3822,[160]6.3856,[161]6.3983,[162]6.3699,[163]6.3581,[164]6.3348,[165]6.3042,[166]6.2767,[167]6.2398,[168]6.2103,[169]6.1973,[170]6.1870,[171]6.1604,[172]6.1438,[173]6.1275,[174]6.0972,[175]6
.0757,[176]6.0635,[177]6.0439,[178]6.0216,[179]6.0054,[180]5.9956,[181]5.9743,[182]5.9560,[183]5.9423,[184]5.9426,[185]5.9357,[186]5.9364,[187]5.9428,[188]5.9394,[189]5.9570,[190]5.9584,[191]5.9799,[192]5.9954,[193]6.0123,[194]6.0232,[195]6.0445,[196]6.0606,[197]6.0811,[198]6.0966,[199]6.0995,[200]6.1058,[201]6.1011,[202]6.1193,[203]6.1269,[204]6.1265,[205]6.1366,[206]6.1436,[207]6.1404,[208]6.1486,[209]6.1527,[210]6.1577,[211]6.1685,[212]6.1758,[213]6.1863,[214]6.1888,[215]6.1906,[216]6.2034,[217]6.2218,[218]6.2359,[219]6.2355,[220]6.2315,[221]6.2253,[222]6.2230,[223]6.2130,[224]6.2069,[225]6.2036,[226]6.2237,[227]6.2323,[228]6.2376,[229]6.2428,[230]6.2393,[231]6.2559,[232]6.2449,[233]6.2280,[234]6.2128,[235]6.1927,[236]6.1859,[237]6.1770,[238]6.1795,[239]6.1648,[240]6.1539,[241]6.1554,[242]6.1587,[243]6.1572,[244]6.1464,[245]6.1430,[246]6.1321,[247]6.1210,[248]6.1140,[249]6.1117,[250]6.1164,[251]6.1093,[252]6.1054,[253]6.0959,[254]6.0904,[255]6.0791,[256]6.0611,[257]6.0487,[258]6.0402,[259]6.0373,[260]6.0291,[261]6.0256,[262]6.0200,[263]6.0145,[264]5.9949,[265]5.9940,[266]5.9929,[267]5.9865,[268]5.9954,[269]5.9938,[270]5.9946,[271]6.0022,[272]6.0059,[273]6.0060,[274]6.0081,[275]6.0165,[276]6.0219,[277]6.0374,[278]6.0470,[279]6.0557,[280]6.0590,[281]6.0684,[282]6.0742,[283]6.0891,[284]6.0971,[285]6.1060,[286]6.1194,[287]6.1191,[288]6.1243,[289]6.1157,[290]6.1004,[291]6.0863,[292]6.0716,[293]6.0580,[294]6.0599,[295]6.0582,[296]6.0632,[297]6.0618,[298]6.0647,[299]6.0625,[300]6.0523,[301]6.0525,[302]6.0452,[303]6.0361,[304]6.0273,[305]6.0242,[306]6.0114,[307]6.0139,[308]6.0162,[309]6.0003,[310]5.9943,[311]5.9878,[312]5.9900,[313]5.9846,[314]5.9830,[315]5.9675,[316]5.9626,[317]5.9457,[318]5.9255,[319]5.9373,[320]5.9495,[321]5.9545,[322]5.9506,[323]5.9437,[324]5.9413,[325]5.9520,[326]5.9524,[327]5.9543,[328]5.9581,[329]5.9636,[330]5.9663,[331]5.9784,[332]5.9757,[333]5.9824,[334]5.9770,[335]5.9712,[336]5.9747,[337]5.9722,[338]5.9719,[339]5.9672,[340]5.9634,[341]5.9715,[3
42]5.9741,[343]5.9788,[344]5.9791,[345]5.9795,[346]5.9769,[347]5.9811,[348]5.9842,[349]5.9864,[350]5.9836,[351]5.9844,[352]5.9850,[353]5.9791,[354]5.9798,[355]5.9848,[356]5.9878,[357]5.9844,[358]5.9932,[359]5.9959,[360]5.9922,[361]5.9919,[362]5.9991,[363]6.0101,[364]6.0168,[365]6.0216,[366]6.0226,[367]6.0312,[368]6.0287,[369]6.0294,[370]6.0312,[371]6.0257,[372]6.0301,[373]6.0347,[374]6.0328,[375]6.0325,[376]6.0393,[377]6.0346,[378]6.0370,[379]6.0424,[380]6.0348,[381]6.0318,[382]6.0264,[383]6.0260,[384]6.0256,[385]6.0249,[386]6.0247,[387]6.0246,[388]6.0208,[389]6.0155,[390]6.0088,[391]6.0013,[392]5.9969,[393]5.9952,[394]5.9976,[395]5.9964,[396]5.9893,[397]5.9955,[398]5.9993,[399]6.0071,[400]6.0068,[401]6.0083,[402]6.0098,[403]6.0119,[404]6.0183,[405]6.0094,[406]6.0065,[407]6.0064,[408]6.0081,[409]6.0198,[410]6.0306,[411]6.0418,[412]6.0574,[413]6.0688,[414]6.0766,[415]6.0821,[416]6.0900,[417]6.1021,[418]6.1057,[419]6.1124,[420]6.1211,[421]6.1323,[422]6.1363,[423]6.1435,[424]6.1541,[425]6.1624,[426]6.1689,[427]6.1734,[428]6.1814,[429]6.1864,[430]6.1945,[431]6.2085,[432]6.2122,[433]6.2111,[434]6.2073,[435]6.2083,[436]6.2111,[437]6.2210,[438]6.2285,[439]6.2251,[440]6.2243,[441]6.2193,[442]6.2180,[443]6.2195,[444]6.2199,[445]6.2183,[446]6.2206,[447]6.2237,[448]6.2279,[449]6.2256,[450]6.2262,[451]6.2222,[452]6.2100,[453]6.2017,[454]6.1962,[455]6.1972,[456]6.2022,[457]6.2044,[458]6.2023,[459]6.2027,[460]6.2111,[461]6.2083,[462]6.2071,[463]6.2113,[464]6.2103,[465]6.2076,[466]6.1998,[467]6.2004,[468]6.2005,[469]6.2026,[470]6.2030,[471]6.1987,[472]6.2034,[473]6.1979,[474]6.1992,[475]6.1932,[476]6.1951,[477]6.1885,[478]6.1876,[479]6.1937,[480]6.1986,[481]6.2004,[482]6.1957,[483]6.1916,[484]6.1939,[485]6.1918,[486]6.1860,[487]6.1857,[488]6.1833,[489]6.1784,[490]6.1761,[491]6.1733,[492]6.1679,[493]6.1653,[494]6.1636,[495]6.1633,[496]6.1596,[497]6.1541,[498]6.1525,[499]6.1480,[500]6.1388,[501]6.1325,[502]6.1326,[503]6.1324,[504]6.1239,[505]6.1260,[506]6.1270,[507]6.1215,[508]6.117
5,[509]6.1169,[510]6.1202,[511]6.1252,[512]6.1287,[513]6.1309,[514]6.1376,[515]6.1321,[516]6.1309,[517]6.1319,[518]6.1314,[519]6.1347,[520]6.1371,[521]6.1388,[522]6.1417,[523]6.1425,[524]6.1481,[525]6.1515,[526]6.1524,[527]6.1541,[528]6.1491,[529]6.1499,[530]6.1451,[531]6.1440,[532]6.1489,[533]6.1512,[534]6.1494,[535]6.1519,[536]6.1467,[537]6.1446,[538]6.1496,[539]6.1505,[540]6.1543,[541]6.1545,[542]6.1555,[543]6.1571,[544]6.1582,[545]6.1561,[546]6.1567,[547]6.1525,[548]6.1479,[549]6.1478,[550]6.1451,[551]6.1414,[552]6.1394,[553]6.1356,[554]6.1334,[555]6.1304,[556]6.1302,[557]6.1327,[558]6.1289,[559]6.1289,[560]6.1289,[561]6.1293,[562]6.1267,[563]6.1263,[564]6.1304,[565]6.1326,[566]6.1325,[567]6.1304,[568]6.1309,[569]6.1297,[570]6.1322,[571]6.1325,[572]6.1333,[573]6.1330,[574]6.1293,[575]6.1288,[576]6.1287,[577]6.1273,[578]6.1253,[579]6.1261,[580]6.1198,[581]6.1160,[582]6.1149,[583]6.1156,[584]6.1158,[585]6.1085,[586]6.1018,[587]6.1023,[588]6.1072,[589]6.1126,[590]6.1157,[591]6.1179,[592]6.1167,[593]6.1135,[594]6.1146,[595]6.1123,[596]6.1156,[597]6.1135,[598]6.1104,[599]6.1125,[600]6.1121,[601]6.1109,[602]6.1121,[603]6.1150,[604]6.1156,[605]6.1192,[606]6.1213,[607]6.1197,[608]6.1162,[609]6.1170,[610]6.1204,[611]6.1186,[612]6.1211,[613]6.1175,[614]6.1127,[615]6.1054,[616]6.1081,[617]6.1020,[618]6.0971,[619]6.0919,[620]6.0780,[621]6.0712,[622]6.0696,[623]6.0713,[624]6.0718,[625]6.0718,[626]6.0708,[627]6.0729,[628]6.0730,[629]6.0728,[630]6.0758,[631]6.0814,[632]6.0873,[633]6.0858,[634]6.0894,[635]6.0900,[636]6.0867,[637]6.0834,[638]6.0860,[639]6.0834,[640]6.0841,[641]6.0841,[642]6.0908,[643]6.0928,[644]6.0942,[645]6.0925,[646]6.0966,[647]6.0925,[648]6.0938,[649]6.0941,[650]6.0981,[651]6.1034,[652]6.1044,[653]6.1084,[654]6.1019,[655]6.1014,
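
The logs above print a running perplexity after each chunk as `[N]value,`, so the figure to compare between runs is the last entry. Below is a minimal C++ sketch for pulling those pairs out of a pasted log. `PplPoint`, `parse_ppl_log` and `final_ppl` are hypothetical helper names, not llama.cpp functions. The parser tolerates a line break between the chunk number and `]`. A value split across a wrap (e.g. `6.2` / `399`) still has to be rejoined by hand first.

```cpp
#include <cassert>
#include <cmath>
#include <regex>
#include <string>
#include <vector>

// One entry of the running-perplexity log: chunk index and perplexity so far.
struct PplPoint {
    int chunk;
    double ppl;
};

// Extract every "[N]value" pair from a log such as "[1]4.4089,[2]4.9282,...".
// Text without that shape (e.g. the system_info header) is ignored.
// Whitespace is allowed between the number and ']' and after ']'.
std::vector<PplPoint> parse_ppl_log(const std::string& log) {
    static const std::regex entry(R"(\[(\d+)\s*\]\s*([0-9]+(?:\.[0-9]+)?))");
    std::vector<PplPoint> points;
    for (std::sregex_iterator it(log.begin(), log.end(), entry), end; it != end; ++it) {
        points.push_back({std::stoi((*it)[1].str()), std::stod((*it)[2].str())});
    }
    return points;
}

// The last running value is the perplexity over all evaluated chunks.
double final_ppl(const std::vector<PplPoint>& points) {
    assert(!points.empty());
    return points.back().ppl;
}
```

Applied to the two full logs above, `final_ppl` returns the `[655]` entries: 6.1347 for SCM @ 5bpw and 6.1014 for NLM @ 4.75bpw.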

| Bits-per-weight | Method | Perplexity | Perplexity at chunks 1-3 |
|---|---|---|---|
| 4 | q3_0 (#1004) | ? | [1]4.8166, [2]5.2200, [3]6.1143 |
| 4 | SCM | 6.3334 | [1]4.4865, [2]5.0649, [3]5.9553 |

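
The Bits-per-weight column counts the per-block parameters (the scale, plus an offset where present) as overhead spread over the weights of the block. The following is a small sketch of that accounting. It assumes each block stores `block_size` weights of `bits` bits plus `n_params` parameters of `param_bits` bits each. `bits_per_weight` is a hypothetical helper, and the layouts in the examples are illustrative only; they are not claimed to be the layouts of q3_0, q2_0 or SCM.

```cpp
#include <cassert>

// Effective storage cost per weight of a block-quantized format:
// every weight costs `bits`, and the block's `n_params` parameters
// (scale, offset, ...) of `param_bits` each are amortized over `block_size`.
double bits_per_weight(int bits, int block_size, int n_params, int param_bits) {
    assert(block_size > 0);
    return bits + static_cast<double>(n_params * param_bits) / block_size;
}
```

For example, 32 four-bit weights sharing one fp32 scale cost 5.0 bpw, and storing that scale as fp16 instead brings the cost down to 4.5 bpw.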
SCM @ 4bpw [1]4.4865,[2]5.0649,[3]5.9553,[4]6.5977,[5]6.6739,[6]6.6580,[7]6.8753,[8]6.9782,[9]7.3275,[10]7.5823,[11]7.8070,[12]7.8091,[13]7.7428,[14]7.8138,[15]8.0488,[16]7.6375,[17]7.5148,[18]7.4648,[19]7.0842,[20]7.0715,[21]6.9824,[22]6.7951,[23]6.7563,[24]6.6686,[25]6.6695,[26]6.5019,[27]6.3196,[28]6.2198,[29]6.1306,[30]5.9768,[31]5.9504,[32]5.9709,[33]5.9172,[34]5.9501,[35]5.9790,[36]6.0208,[37]6.0232,[38]6.0346,[39]6.0646,[40]6.1271,[41]6.1429,[42]6.1844,[43]6.1433,[44]6.1992,[45]6.2006,[46]6.1722,[47]6.1918,[48]6.1613,[49]6.1626,[50]6.1177,[51]6.1108,[52]6.0969,[53]6.1424,[54]6.1246,[55]6.0978,[56]6.1274,[57]6.1486,[58]6.1747,[59]6.1911,[60]6.2333,[61]6.2236,[62]6.2859,[63]6.3197,[64]6.3336,[65]6.3800,[66]6.3886,[67]6.4047,[68]6.4183,[69]6.4427,[70]6.4764,[71]6.4994,[72]6.5315,[73]6.5948,[74]6.5980,[75]6.6105,[76]6.6247,[77]6.6375,[78]6.6222,[79]6.6487,[80]6.6393,[81]6.6492,[82]6.6578,[83]6.6023,[84]6.5835,[85]6.5726,[86]6.5491,[87]6.4848,[88]6.4559,[89]6.4371,[90]6.4232,[91]6.4475,[92]6.4427,[93]6.4444,[94]6.4396,[95]6.4721,[96]6.4705,[97]6.4670,[98]6.4608,[99]6.4449,[100]6.4465,[101]6.4708,[102]6.4650,[103]6.4874,[104]6.4948,[105]6.4960,[106]6.5094,[107]6.5056,[108]6.5188,[109]6.5147,[110]6.5108,[111]6.5338,[112]6.5558,[113]6.5567,[114]6.5527,[115]6.5603,[116]6.5519,[117]6.5591,[118]6.5893,[119]6.6121,[120]6.6472,[121]6.6638,[122]6.6879,[123]6.7271,[124]6.7447,[125]6.7346,[126]6.7749,[127]6.8140,[128]6.8449,[129]6.8264,[130]6.8337,[131]6.8300,[132]6.8214,[133]6.8049,[134]6.8161,[135]6.8122,[136]6.8005,[137]6.7928,[138]6.7776,[139]6.7681,[140]6.7639,[141]6.7347,[142]6.7303,[143]6.7020,[144]6.6827,[145]6.6753,[146]6.6635,[147]6.6711,[148]6.6727,[149]6.6658,[150]6.6624,[151]6.6642,[152]6.6546,[153]6.6366,[154]6.6264,[155]6.6329,[156]6.6280,[157]6.6460,[158]6.6501,[159]6.6555,[160]6.6571,[161]6.6693,[162]6.6388,[163]6.6283,[164]6.6030,[165]6.5695,[166]6.5411,[167]6.5028,[168]6.4699,[169]6.4558,[170]6.4442,[171]6.4149,[172]6.3963,[173]6.3787,[174]6.3465,[175]6.32
40,[176]6.3120,[177]6.2899,[178]6.2652,[179]6.2476,[180]6.2379,[181]6.2162,[182]6.1963,[183]6.1824,[184]6.1813,[185]6.1737,[186]6.1750,[187]6.1810,[188]6.1763,[189]6.1954,[190]6.1957,[191]6.2179,[192]6.2345,[193]6.2537,[194]6.2664,[195]6.2863,[196]6.3023,[197]6.3251,[198]6.3400,[199]6.3420,[200]6.3466,[201]6.3417,[202]6.3628,[203]6.3698,[204]6.3706,[205]6.3825,[206]6.3900,[207]6.3856,[208]6.3932,[209]6.3972,[210]6.4021,[211]6.4116,[212]6.4189,[213]6.4300,[214]6.4333,[215]6.4391,[216]6.4542,[217]6.4729,[218]6.4869,[219]6.4874,[220]6.4833,[221]6.4799,[222]6.4768,[223]6.4660,[224]6.4588,[225]6.4549,[226]6.4761,[227]6.4848,[228]6.4900,[229]6.4969,[230]6.4930,[231]6.5098,[232]6.4969,[233]6.4798,[234]6.4643,[235]6.4483,[236]6.4412,[237]6.4313,[238]6.4342,[239]6.4189,[240]6.4082,[241]6.4117,[242]6.4153,[243]6.4133,[244]6.4018,[245]6.3985,[246]6.3859,[247]6.3737,[248]6.3653,[249]6.3635,[250]6.3669,[251]6.3596,[252]6.3565,[253]6.3455,[254]6.3408,[255]6.3287,[256]6.3095,[257]6.2969,[258]6.2874,[259]6.2854,[260]6.2765,[261]6.2724,[262]6.2671,[263]6.2621,[264]6.2421,[265]6.2416,[266]6.2398,[267]6.2325,[268]6.2414,[269]6.2398,[270]6.2404,[271]6.2479,[272]6.2513,[273]6.2512,[274]6.2533,[275]6.2623,[276]6.2689,[277]6.2850,[278]6.2953,[279]6.3048,[280]6.3078,[281]6.3173,[282]6.3227,[283]6.3370,[284]6.3457,[285]6.3542,[286]6.3688,[287]6.3682,[288]6.3751,[289]6.3660,[290]6.3502,[291]6.3347,[292]6.3190,[293]6.3051,[294]6.3072,[295]6.3070,[296]6.3115,[297]6.3108,[298]6.3145,[299]6.3118,[300]6.3008,[301]6.3005,[302]6.2931,[303]6.2848,[304]6.2756,[305]6.2725,[306]6.2592,[307]6.2610,[308]6.2652,[309]6.2488,[310]6.2425,[311]6.2361,[312]6.2387,[313]6.2332,[314]6.2312,[315]6.2145,[316]6.2104,[317]6.1937,[318]6.1723,[319]6.1852,[320]6.1978,[321]6.2019,[322]6.1974,[323]6.1912,[324]6.1881,[325]6.1987,[326]6.1982,[327]6.2005,[328]6.2050,[329]6.2111,[330]6.2140,[331]6.2262,[332]6.2238,[333]6.2312,[334]6.2253,[335]6.2190,[336]6.2227,[337]6.2197,[338]6.2190,[339]6.2133,[340]6.2084,[341]6.2162,[342]
6.2181,[343]6.2239,[344]6.2240,[345]6.2242,[346]6.2212,[347]6.2255,[348]6.2295,[349]6.2320,[350]6.2281,[351]6.2288,[352]6.2294,[353]6.2231,[354]6.2236,[355]6.2290,[356]6.2320,[357]6.2285,[358]6.2377,[359]6.2408,[360]6.2366,[361]6.2362,[362]6.2436,[363]6.2547,[364]6.2605,[365]6.2664,[366]6.2672,[367]6.2766,[368]6.2743,[369]6.2753,[370]6.2764,[371]6.2706,[372]6.2758,[373]6.2812,[374]6.2797,[375]6.2795,[376]6.2873,[377]6.2827,[378]6.2850,[379]6.2916,[380]6.2830,[381]6.2790,[382]6.2740,[383]6.2731,[384]6.2725,[385]6.2722,[386]6.2722,[387]6.2718,[388]6.2669,[389]6.2612,[390]6.2532,[391]6.2447,[392]6.2405,[393]6.2386,[394]6.2409,[395]6.2391,[396]6.2314,[397]6.2392,[398]6.2424,[399]6.2500,[400]6.2496,[401]6.2517,[402]6.2525,[403]6.2542,[404]6.2616,[405]6.2521,[406]6.2487,[407]6.2483,[408]6.2502,[409]6.2621,[410]6.2730,[411]6.2844,[412]6.3009,[413]6.3125,[414]6.3214,[415]6.3272,[416]6.3350,[417]6.3478,[418]6.3518,[419]6.3588,[420]6.3683,[421]6.3818,[422]6.3865,[423]6.3937,[424]6.4058,[425]6.4150,[426]6.4217,[427]6.4262,[428]6.4345,[429]6.4389,[430]6.4481,[431]6.4618,[432]6.4659,[433]6.4649,[434]6.4604,[435]6.4608,[436]6.4631,[437]6.4736,[438]6.4814,[439]6.4782,[440]6.4773,[441]6.4720,[442]6.4707,[443]6.4726,[444]6.4729,[445]6.4709,[446]6.4730,[447]6.4759,[448]6.4804,[449]6.4774,[450]6.4774,[451]6.4731,[452]6.4614,[453]6.4529,[454]6.4467,[455]6.4477,[456]6.4528,[457]6.4545,[458]6.4522,[459]6.4526,[460]6.4610,[461]6.4584,[462]6.4560,[463]6.4606,[464]6.4595,[465]6.4570,[466]6.4487,[467]6.4493,[468]6.4491,[469]6.4507,[470]6.4510,[471]6.4459,[472]6.4505,[473]6.4448,[474]6.4462,[475]6.4403,[476]6.4424,[477]6.4345,[478]6.4337,[479]6.4401,[480]6.4453,[481]6.4468,[482]6.4425,[483]6.4380,[484]6.4402,[485]6.4389,[486]6.4329,[487]6.4324,[488]6.4300,[489]6.4249,[490]6.4223,[491]6.4191,[492]6.4128,[493]6.4097,[494]6.4081,[495]6.4078,[496]6.4046,[497]6.3986,[498]6.3968,[499]6.3920,[500]6.3821,[501]6.3748,[502]6.3747,[503]6.3745,[504]6.3658,[505]6.3684,[506]6.3694,[507]6.3645,[508]6.3605,[
509]6.3597,[510]6.3635,[511]6.3679,[512]6.3715,[513]6.3734,[514]6.3800,[515]6.3745,[516]6.3732,[517]6.3748,[518]6.3746,[519]6.3774,[520]6.3798,[521]6.3814,[522]6.3839,[523]6.3841,[524]6.3900,[525]6.3935,[526]6.3941,[527]6.3962,[528]6.3912,[529]6.3917,[530]6.3868,[531]6.3856,[532]6.3909,[533]6.3937,[534]6.3918,[535]6.3948,[536]6.3891,[537]6.3868,[538]6.3919,[539]6.3930,[540]6.3964,[541]6.3970,[542]6.3981,[543]6.3991,[544]6.4000,[545]6.3979,[546]6.3990,[547]6.3940,[548]6.3884,[549]6.3884,[550]6.3853,[551]6.3815,[552]6.3795,[553]6.3754,[554]6.3727,[555]6.3695,[556]6.3690,[557]6.3710,[558]6.3671,[559]6.3668,[560]6.3665,[561]6.3662,[562]6.3641,[563]6.3641,[564]6.3684,[565]6.3704,[566]6.3700,[567]6.3677,[568]6.3681,[569]6.3663,[570]6.3693,[571]6.3694,[572]6.3705,[573]6.3703,[574]6.3664,[575]6.3663,[576]6.3664,[577]6.3650,[578]6.3628,[579]6.3637,[580]6.3568,[581]6.3529,[582]6.3520,[583]6.3526,[584]6.3527,[585]6.3449,[586]6.3383,[587]6.3389,[588]6.3436,[589]6.3492,[590]6.3520,[591]6.3541,[592]6.3527,[593]6.3486,[594]6.3496,[595]6.3471,[596]6.3511,[597]6.3488,[598]6.3458,[599]6.3479,[600]6.3476,[601]6.3459,[602]6.3479,[603]6.3511,[604]6.3520,[605]6.3555,[606]6.3574,[607]6.3560,[608]6.3527,[609]6.3532,[610]6.3566,[611]6.3548,[612]6.3575,[613]6.3534,[614]6.3482,[615]6.3406,[616]6.3432,[617]6.3368,[618]6.3315,[619]6.3259,[620]6.3113,[621]6.3041,[622]6.3022,[623]6.3034,[624]6.3042,[625]6.3041,[626]6.3031,[627]6.3053,[628]6.3058,[629]6.3053,[630]6.3084,[631]6.3146,[632]6.3201,[633]6.3186,[634]6.3220,[635]6.3224,[636]6.3191,[637]6.3159,[638]6.3190,[639]6.3158,[640]6.3167,[641]6.3167,[642]6.3236,[643]6.3258,[644]6.3272,[645]6.3249,[646]6.3293,[647]6.3258,[648]6.3266,[649]6.3265,[650]6.3302,[651]6.3358,[652]6.3367,[653]6.3407,[654]6.3341,[655]6.3334,

| Bits-per-weight | Method | Perplexity |
|---|---|---|
| 3 | q2_0 (#1004) | 12.6438 |
| 3 | SCM | 7.8300 |

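
Perplexity is exp of the mean negative log-likelihood (NLL) per token. Because of that, the gap between two runs on the same text is easier to read in log space, as a difference in nats per token. A minimal sketch follows; `nll_gap_nats` is a hypothetical name.

```cpp
#include <cmath>

// ln(ppl) is the mean NLL per token, so the difference in mean NLL
// between run A and run B on the same text is ln(ppl_a / ppl_b) nats/token.
double nll_gap_nats(double ppl_a, double ppl_b) {
    return std::log(ppl_a / ppl_b);
}
```

For the 3 bpw rows, `nll_gap_nats(12.6438, 7.8300)` is about 0.479 nats per token, which is about 0.69 bits per token after dividing by ln 2, in favour of SCM.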
SCM @ 3bpw [1]6.0942,[2]6.7662,[3]7.6038,[4]8.3524,[5]8.2915,[6]8.2419,[7]8.4662,[8]8.5407,[9]8.9734,[10]9.3232,[11]9.6352,[12]9.6208,[13]9.5982,[14]9.7618,[15]10.0687,[16]9.5089,[17]9.3348,[18]9.3130,[19]8.7935,[20]8.7400,[21]8.6163,[22]8.4302,[23]8.3799,[24]8.2907,[25]8.2851,[26]8.0742,[27]7.8200,[28]7.7135,[29]7.6222,[30]7.4274,[31]7.4010,[32]7.4245,[33]7.3594,[34]7.3935,[35]7.4241,[36]7.5015,[37]7.5204,[38]7.5385,[39]7.5784,[40]7.6567,[41]7.6877,[42]7.7363,[43]7.6716,[44]7.7293,[45]7.7264,[46]7.6833,[47]7.7103,[48]7.6628,[49]7.6592,[50]7.5926,[51]7.5754,[52]7.5577,[53]7.6009,[54]7.5788,[55]7.5386,[56]7.5680,[57]7.5956,[58]7.6293,[59]7.6458,[60]7.7060,[61]7.6870,[62]7.7644,[63]7.8026,[64]7.8172,[65]7.8757,[66]7.8862,[67]7.9111,[68]7.9323,[69]7.9644,[70]8.0057,[71]8.0350,[72]8.0740,[73]8.1527,[74]8.1452,[75]8.1574,[76]8.1727,[77]8.1926,[78]8.1813,[79]8.2132,[80]8.2059,[81]8.2202,[82]8.2356,[83]8.1614,[84]8.1447,[85]8.1409,[86]8.1103,[87]8.0484,[88]8.0102,[89]7.9842,[90]7.9620,[91]8.0019,[92]7.9939,[93]7.9899,[94]7.9864,[95]8.0278,[96]8.0277,[97]8.0191,[98]8.0111,[99]7.9845,[100]7.9780,[101]8.0098,[102]8.0020,[103]8.0304,[104]8.0333,[105]8.0344,[106]8.0579,[107]8.0568,[108]8.0709,[109]8.0685,[110]8.0648,[111]8.0911,[112]8.1192,[113]8.1227,[114]8.1192,[115]8.1330,[116]8.1195,[117]8.1280,[118]8.1694,[119]8.1971,[120]8.2447,[121]8.2704,[122]8.3007,[123]8.3528,[124]8.3769,[125]8.3618,[126]8.4117,[127]8.4544,[128]8.4892,[129]8.4616,[130]8.4697,[131]8.4624,[132]8.4480,[133]8.4311,[134]8.4462,[135]8.4397,[136]8.4282,[137]8.4204,[138]8.4024,[139]8.3956,[140]8.3884,[141]8.3676,[142]8.3623,[143]8.3396,[144]8.3209,[145]8.3169,[146]8.3003,[147]8.3116,[148]8.3126,[149]8.3067,[150]8.3081,[151]8.3101,[152]8.2988,[153]8.2741,[154]8.2627,[155]8.2697,[156]8.2616,[157]8.2823,[158]8.2877,[159]8.2935,[160]8.2932,[161]8.3068,[162]8.2680,[163]8.2548,[164]8.2194,[165]8.1754,[166]8.1387,[167]8.0855,[168]8.0441,[169]8.0257,[170]8.0104,[171]7.9727,[172]7.9486,[173]7.9269,[174]7.8846,[175]7.8
547,[176]7.8342,[177]7.8085,[178]7.7779,[179]7.7565,[180]7.7442,[181]7.7159,[182]7.6891,[183]7.6711,[184]7.6641,[185]7.6563,[186]7.6562,[187]7.6613,[188]7.6530,[189]7.6788,[190]7.6767,[191]7.7025,[192]7.7204,[193]7.7429,[194]7.7624,[195]7.7856,[196]7.8058,[197]7.8328,[198]7.8502,[199]7.8540,[200]7.8571,[201]7.8510,[202]7.8813,[203]7.8903,[204]7.8973,[205]7.9138,[206]7.9242,[207]7.9183,[208]7.9300,[209]7.9340,[210]7.9371,[211]7.9514,[212]7.9609,[213]7.9722,[214]7.9795,[215]7.9847,[216]8.0041,[217]8.0265,[218]8.0424,[219]8.0414,[220]8.0366,[221]8.0298,[222]8.0248,[223]8.0090,[224]8.0012,[225]7.9969,[226]8.0198,[227]8.0342,[228]8.0421,[229]8.0493,[230]8.0485,[231]8.0668,[232]8.0503,[233]8.0279,[234]8.0076,[235]7.9944,[236]7.9857,[237]7.9726,[238]7.9767,[239]7.9548,[240]7.9407,[241]7.9492,[242]7.9537,[243]7.9534,[244]7.9381,[245]7.9340,[246]7.9183,[247]7.9039,[248]7.8908,[249]7.8898,[250]7.8929,[251]7.8840,[252]7.8798,[253]7.8658,[254]7.8590,[255]7.8460,[256]7.8218,[257]7.8062,[258]7.7943,[259]7.7919,[260]7.7803,[261]7.7746,[262]7.7672,[263]7.7592,[264]7.7408,[265]7.7394,[266]7.7401,[267]7.7302,[268]7.7402,[269]7.7371,[270]7.7373,[271]7.7471,[272]7.7536,[273]7.7517,[274]7.7548,[275]7.7663,[276]7.7733,[277]7.7928,[278]7.8073,[279]7.8191,[280]7.8232,[281]7.8359,[282]7.8408,[283]7.8568,[284]7.8677,[285]7.8786,[286]7.8955,[287]7.8936,[288]7.9036,[289]7.8939,[290]7.8754,[291]7.8536,[292]7.8324,[293]7.8133,[294]7.8144,[295]7.8116,[296]7.8156,[297]7.8145,[298]7.8204,[299]7.8160,[300]7.8020,[301]7.8014,[302]7.7913,[303]7.7797,[304]7.7671,[305]7.7634,[306]7.7465,[307]7.7468,[308]7.7509,[309]7.7298,[310]7.7215,[311]7.7144,[312]7.7159,[313]7.7076,[314]7.7068,[315]7.6859,[316]7.6841,[317]7.6650,[318]7.6366,[319]7.6549,[320]7.6693,[321]7.6754,[322]7.6688,[323]7.6630,[324]7.6598,[325]7.6713,[326]7.6703,[327]7.6741,[328]7.6798,[329]7.6892,[330]7.6935,[331]7.7098,[332]7.7047,[333]7.7171,[334]7.7100,[335]7.7006,[336]7.7031,[337]7.6972,[338]7.6971,[339]7.6904,[340]7.6860,[341]7.6953,[342
]7.6974,[343]7.7051,[344]7.7057,[345]7.7049,[346]7.7016,[347]7.7060,[348]7.7119,[349]7.7144,[350]7.7094,[351]7.7087,[352]7.7093,[353]7.7014,[354]7.7039,[355]7.7111,[356]7.7149,[357]7.7117,[358]7.7232,[359]7.7273,[360]7.7192,[361]7.7180,[362]7.7272,[363]7.7396,[364]7.7480,[365]7.7544,[366]7.7544,[367]7.7650,[368]7.7614,[369]7.7634,[370]7.7642,[371]7.7557,[372]7.7626,[373]7.7681,[374]7.7655,[375]7.7658,[376]7.7760,[377]7.7690,[378]7.7718,[379]7.7807,[380]7.7695,[381]7.7639,[382]7.7575,[383]7.7548,[384]7.7526,[385]7.7528,[386]7.7535,[387]7.7529,[388]7.7460,[389]7.7382,[390]7.7299,[391]7.7202,[392]7.7156,[393]7.7154,[394]7.7176,[395]7.7144,[396]7.7044,[397]7.7126,[398]7.7159,[399]7.7248,[400]7.7240,[401]7.7266,[402]7.7291,[403]7.7306,[404]7.7405,[405]7.7308,[406]7.7271,[407]7.7284,[408]7.7301,[409]7.7433,[410]7.7576,[411]7.7708,[412]7.7912,[413]7.8041,[414]7.8138,[415]7.8216,[416]7.8309,[417]7.8459,[418]7.8497,[419]7.8587,[420]7.8696,[421]7.8874,[422]7.8922,[423]7.9027,[424]7.9171,[425]7.9283,[426]7.9378,[427]7.9431,[428]7.9530,[429]7.9582,[430]7.9689,[431]7.9871,[432]7.9905,[433]7.9884,[434]7.9804,[435]7.9792,[436]7.9808,[437]7.9941,[438]8.0029,[439]7.9992,[440]7.9960,[441]7.9884,[442]7.9850,[443]7.9869,[444]7.9875,[445]7.9849,[446]7.9866,[447]7.9894,[448]7.9931,[449]7.9890,[450]7.9878,[451]7.9826,[452]7.9743,[453]7.9646,[454]7.9582,[455]7.9588,[456]7.9654,[457]7.9689,[458]7.9673,[459]7.9670,[460]7.9759,[461]7.9721,[462]7.9691,[463]7.9758,[464]7.9752,[465]7.9720,[466]7.9633,[467]7.9652,[468]7.9675,[469]7.9699,[470]7.9708,[471]7.9652,[472]7.9707,[473]7.9631,[474]7.9663,[475]7.9617,[476]7.9658,[477]7.9567,[478]7.9553,[479]7.9661,[480]7.9735,[481]7.9758,[482]7.9713,[483]7.9656,[484]7.9694,[485]7.9689,[486]7.9614,[487]7.9608,[488]7.9583,[489]7.9503,[490]7.9475,[491]7.9441,[492]7.9361,[493]7.9317,[494]7.9284,[495]7.9285,[496]7.9240,[497]7.9165,[498]7.9147,[499]7.9074,[500]7.8953,[501]7.8868,[502]7.8868,[503]7.8858,[504]7.8748,[505]7.8775,[506]7.8780,[507]7.8736,[508]7.8685,
[509]7.8679,[510]7.8724,[511]7.8778,[512]7.8813,[513]7.8836,[514]7.8924,[515]7.8852,[516]7.8835,[517]7.8862,[518]7.8858,[519]7.8890,[520]7.8910,[521]7.8928,[522]7.8958,[523]7.8961,[524]7.9030,[525]7.9072,[526]7.9088,[527]7.9119,[528]7.9057,[529]7.9064,[530]7.8991,[531]7.8963,[532]7.9030,[533]7.9074,[534]7.9037,[535]7.9083,[536]7.9025,[537]7.8983,[538]7.9056,[539]7.9061,[540]7.9109,[541]7.9136,[542]7.9145,[543]7.9169,[544]7.9185,[545]7.9175,[546]7.9181,[547]7.9125,[548]7.9039,[549]7.9034,[550]7.9008,[551]7.8952,[552]7.8933,[553]7.8881,[554]7.8836,[555]7.8795,[556]7.8788,[557]7.8825,[558]7.8783,[559]7.8778,[560]7.8761,[561]7.8754,[562]7.8718,[563]7.8715,[564]7.8764,[565]7.8795,[566]7.8792,[567]7.8757,[568]7.8764,[569]7.8734,[570]7.8766,[571]7.8762,[572]7.8776,[573]7.8760,[574]7.8720,[575]7.8716,[576]7.8719,[577]7.8690,[578]7.8667,[579]7.8665,[580]7.8576,[581]7.8526,[582]7.8514,[583]7.8522,[584]7.8514,[585]7.8437,[586]7.8359,[587]7.8363,[588]7.8421,[589]7.8488,[590]7.8522,[591]7.8542,[592]7.8524,[593]7.8479,[594]7.8488,[595]7.8457,[596]7.8522,[597]7.8495,[598]7.8459,[599]7.8480,[600]7.8468,[601]7.8443,[602]7.8485,[603]7.8524,[604]7.8534,[605]7.8571,[606]7.8580,[607]7.8575,[608]7.8525,[609]7.8527,[610]7.8566,[611]7.8549,[612]7.8580,[613]7.8534,[614]7.8477,[615]7.8371,[616]7.8404,[617]7.8315,[618]7.8244,[619]7.8167,[620]7.7970,[621]7.7866,[622]7.7837,[623]7.7845,[624]7.7866,[625]7.7862,[626]7.7850,[627]7.7888,[628]7.7887,[629]7.7871,[630]7.7914,[631]7.7982,[632]7.8050,[633]7.8036,[634]7.8078,[635]7.8089,[636]7.8061,[637]7.8033,[638]7.8082,[639]7.8039,[640]7.8056,[641]7.8059,[642]7.8140,[643]7.8159,[644]7.8171,[645]7.8146,[646]7.8204,[647]7.8180,[648]7.8202,[649]7.8205,[650]7.8255,[651]7.8332,[652]7.8352,[653]7.8398,[654]7.8316,[655]7.8300,
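
Each bracketed value is a running perplexity, so the perplexity of a single chunk can be recovered from two consecutive entries. This assumes every chunk scores the same number of tokens. Under that assumption, ln(P_i) is the average of the first i per-chunk mean NLLs, so chunk i on its own has mean NLL `i*ln(P_i) - (i-1)*ln(P_{i-1})`. The following is a sketch; `per_chunk_ppl` is a hypothetical helper.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Given running perplexities P_1..P_n (as printed "[1]P_1,[2]P_2,..."),
// return the perplexity of each chunk taken on its own.
// Assumption: every chunk contributes the same number of scored tokens.
std::vector<double> per_chunk_ppl(const std::vector<double>& running) {
    std::vector<double> out;
    out.reserve(running.size());
    double prev_sum = 0.0;  // sum of per-chunk mean NLLs for chunks 1..i-1
    for (std::size_t i = 0; i < running.size(); ++i) {
        // i+1 chunks so far; their mean NLLs sum to (i+1) * ln(P_{i+1})
        const double sum = static_cast<double>(i + 1) * std::log(running[i]);
        out.push_back(std::exp(sum - prev_sum));
        prev_sum = sum;
    }
    return out;
}
```

With the first two SCM @ 3bpw entries, [1]6.0942 and [2]6.7662, chunk 2 on its own comes out at about 7.51.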
SCM @ 3bpw sample inferences

```
sampling: temp = 0.800000, top_k = 40, top_p = 0.950000, repeat_last_n = 64, repeat_penalty = 1.100000
generate: n_ctx = 512, n_batch = 8, n_predict = 512, n_keep = 0
```

```
Building a website can be done in 10 simple steps: Step #1 - Choose your domain name. Make it something short and catchy that people will remember and want to visit again. You need something easy for folks to spell correctly too! Your URL extension should match your operating system – Windows servers require “.com” extensions; Linux based servers use the “.org” extension; Macintosh servers require a ".Mac" extension. Inexpensive web hosting companies like GoDaddy offer domain names with no installation fees — a significant savings over having your registrar handle it for you! Step #2 - Set up your Web Hosting Account Online With Your Registrar (This is the place where YOU PAY TO DO IT!) FREE web hosting services are NOT recommended unless your goal is to build your personal website or social networking profile. The free space they offer is minimal and will cost you in the long run of being able to create a professional web presence with search engine optimization (SEO) done right. When it comes time to upgrade or move web hosts, you'll be looking at hundreds of dollars per year — money that could better serve your business plan! You don’t want your website name and URL extension to be owned by someone else either — they could be holding you hostage! Step #3 - Install Your Web Server on a Hard Drive & Test It First Before Making it Live Online! This is important because you don't want to get to the point of making it live online when suddenly your web server dies. You don’t want your website visitors to see an error message either -- this will be VERY frustrating to them if they can't load your site or your entire web server goes down! Step #4 - Register Your Domain Name Online With Your Registrar (This is the place where YOU PAY TO DO IT!)
This is very important because you need to secure your domain name — even though it’s free and openly available on most registrars websites. Secure your domain name registration with a private registrant that only reveals information about who registered your domain once you pay them money! Step #5 - Install Your Domain Name With Your Registrar's DNS Server (DNS) (This is the place where YOU PAY TO DO IT!) Make sure to use "www." as your domain name extension since that's what most people type when they want to reach your website address or URL extension. It'
```

```
Building a website can be done in 10 simple steps: Step One: Determine whether you want a static site or one that is built dynamically. Step Two: Choose the platform on which you will build your site. Step Three: Create content plans based on personas and visitor profiles, as well as goals for your site visitors Step Four: Map out high level sitemap Step Five: Determine user flows Step Six: Build wireframes Step Seven: Write copy Step Eight: Design Step Nine: Launch Step Ten: Measure success and adjustments The best way to learn is to read about the subject matter you are interested in. There are hundreds of books on the subject matter of web development. It is important that you get a few different ones so you can build upon what you have learnt over time rather than be stuck into repetition without fresh inspiration or stimulation. If this means ordering books from abroad then do it! Go to your local bookshop and ask them to recommend titles based on the information you provide about what you are trying to achieve with building web sites etc. There is a huge amount of material out there so its probably best to narrow things down rather than being overwhelmed by choice which is quite common when searching for books. Have fun researching!
The following list provides some resources that may be useful in your research about web development: http://www.alistairduncan.co.uk/ Tutorials on web design topics (by Alastair Duncan) http://tympan.com/~ A web standards reference by Dean Allen and others http://xmlhammer.net/~ The XML Hammer http://xjane.com/ Web Standards Compliance by Jane Madden http://webstandards.lud.ch/~ W3C Web Site http://www.w3c-org.org~ W3C (Mission statement) Organisation with charter to develop web standards http://www.w3.org/~ W3C Working Groups (Mission Statement), e.g. accessibility working group http://www.w3pub.com/ Web Publishing in Public Spaces - Community Based Publishing http://xmlhammer.net/ http://xmlhammer.net/2001/04/28/ XML Hammer's weblog (Duncan's comments on the XML 1.0 standard)
```
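
The sampling line above sets repeat_penalty = 1.1 over the last 64 tokens, top_k = 40, top_p = 0.95 and temp = 0.8. The following is a simplified sketch of how a pipeline with those parameters turns logits into a token. The step order and the names `SamplingParams` and `sample_token` are assumptions for illustration. llama.cpp's own sampler may order or implement these steps differently.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <unordered_set>
#include <utility>
#include <vector>

// Defaults mirror the sampling line of the run above.
struct SamplingParams {
    float temp = 0.8f;
    int top_k = 40;
    float top_p = 0.95f;
    float repeat_penalty = 1.1f;
};

// Pick the next token id from raw logits; `recent` holds the last
// repeat_last_n generated token ids.
int sample_token(std::vector<float> logits, const std::vector<int>& recent,
                 const SamplingParams& p, std::mt19937& rng) {
    assert(!logits.empty() && p.top_k > 0 && p.temp > 0.0f);
    const std::unordered_set<int> seen(recent.begin(), recent.end());
    std::vector<std::pair<float, int>> cand;  // (scaled logit, token id)
    cand.reserve(logits.size());
    for (int id = 0; id < static_cast<int>(logits.size()); ++id) {
        float l = logits[id];
        // 1. Repetition penalty: push recently seen tokens towards lower scores.
        if (seen.count(id)) l = (l < 0.0f) ? l * p.repeat_penalty : l / p.repeat_penalty;
        // 2. Temperature scaling.
        cand.emplace_back(l / p.temp, id);
    }
    // 3. Top-k: keep the k highest scores, sorted in descending order.
    const std::size_t k = std::min<std::size_t>(p.top_k, cand.size());
    std::partial_sort(cand.begin(), cand.begin() + static_cast<std::ptrdiff_t>(k), cand.end(),
                      [](const auto& a, const auto& b) { return a.first > b.first; });
    cand.resize(k);
    // 4. Softmax over the survivors (subtract the max for numerical stability).
    std::vector<double> probs(k);
    double sum = 0.0;
    for (std::size_t i = 0; i < k; ++i) {
        probs[i] = std::exp(cand[i].first - cand[0].first);
        sum += probs[i];
    }
    // 5. Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    double cum = 0.0;
    std::size_t keep = k;
    for (std::size_t i = 0; i < k; ++i) {
        probs[i] /= sum;
        cum += probs[i];
        if (cum >= p.top_p) { keep = i + 1; break; }
    }
    probs.resize(keep);
    // 6. Draw from what is left; discrete_distribution renormalizes the weights.
    std::discrete_distribution<int> dist(probs.begin(), probs.end());
    return cand[dist(rng)].second;
}
```

A lower temp sharpens the distribution before truncation, so the top-p cut is reached with fewer candidates.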