turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

Different groupsizes for different bitwidths? #181

Closed: dalistarh closed this issue 10 months ago

dalistarh commented 10 months ago

Hi,

I am one of the authors of the original GPTQ paper; we've been working on a version which specifically targets the EXL2 format, and we would find it useful to be able to use different group sizes for different bit-width "tiles" in the format. (For instance, we could use group size 64 for 2bit, but 1024 for 4bit.) Is this something that could be easily supported?

Best regards, Dan

turboderp commented 10 months ago

I don't know about easily, but it should be possible. The only requirement on group size is that it's a multiple of 32 to support the 3 and 5 bit formats.

To avoid storing any extra metadata with the model, the group size is derived a little clumsily like so:

    groupsize = 1;
    while (groupsize * groups < height) groupsize *= 2;

The matmul kernel also assumes the group size is constant per matrix in a couple of places, so there would have to be a few changes to make it variable. But it can definitely be done.
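For a concrete sense of what that derivation loop produces, here is a minimal Python restatement with a worked example; the function name and argument meanings are mine and only for illustration, not the extension's actual code.

    # Illustrative restatement of the fixed-group-size derivation above.
    # 'groups' is the number of quantization groups stored for the tensor,
    # 'height' is the number of rows in the unpacked weight matrix.
    def derive_group_size(groups: int, height: int) -> int:
        groupsize = 1
        while groupsize * groups < height:
            groupsize *= 2
        return groupsize

    # Example: a 4096-row matrix stored as 32 groups resolves to group size 128.
    assert derive_group_size(32, 4096) == 128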

As for the format itself, the q_groups tensor in each layer is a list of pairs of unsigned shorts: the first element of each pair is the number of bits per weight in that group, and the second is the group's first row index in the q_weight tensor. Here is the code that packs the quantized weights, for reference:

        output[key + ".q_groups"] = self.qgroups

        ...

        i = 0  # <--- group index
        row = 0
        out_row = 0
        while i < self.qscale.shape[0]:

            bits = self.qgroups[i * 2].item()  # <--- bitrate written earlier
            self.qgroups[i * 2 + 1] = out_row  # <--- current row in q_weight
            i += 1

            rows = min(self.group_size, rem_rows)
            wpqr = 32 / bits  # <-- weights per packed uint32 row
            qrows = rows / wpqr
            assert i == self.qgroups.shape[-1] or qrows == int(qrows)
            qrows = math.ceil(qrows)

            # rows is the number of rows in the quantized and offset but unpacked matrix, qwt
            # qrows is the number of uint32 rows after packing (rounded up) 

            g_qwt = qwt[row:row+rows, :].contiguous()
            g_qwt_packed = torch.zeros((qrows, columns + padding), dtype = torch.int32, device = self.device)

            if padding > 0: g_qwt[:, -padding:] = 2 ** (bits - 1)  # <-- zero with offset

            ext_c.pack_columns(g_qwt, g_qwt_packed, bits)  # <-- calls CUDA kernel to pack one group
            qwt_packed.append(g_qwt_packed)

            row += rows
            out_row += qrows
            rem_rows -= rows

        qwt_packed = torch.cat(qwt_packed, dim = 0)
        output[key + ".q_weight"] = qwt_packed

So the format does allow for what you're after. It's just the QMatrix class in the extension that needs to be updated to derive variable group sizes from the q_groups tensor instead of what it's doing now. And then a few corresponding modifications to the kernel.
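For illustration, here is a hedged Python sketch of what that derivation could look like, assuming the .q_groups layout described above (pairs of bits and packed start row). The function and its signature are mine, not part of the extension, and the last group may overcount slightly if its packed size was rounded up during packing.

    import torch

    # Sketch only: recover (bits, unpacked rows) per group from a .q_groups tensor,
    # given the total number of packed rows in the corresponding q_weight tensor.
    def group_sizes_from_q_groups(q_groups: torch.Tensor, q_weight_rows: int):
        pairs = q_groups.view(-1, 2).tolist()          # [(bits, packed_start_row), ...]
        sizes = []
        for idx, (bits, start) in enumerate(pairs):
            # A group ends where the next one starts, or at the end of q_weight.
            end = pairs[idx + 1][1] if idx + 1 < len(pairs) else q_weight_rows
            packed_rows = end - start                  # rows of packed uint32 values
            unpacked_rows = packed_rows * 32 // bits   # original weight rows in the group
            sizes.append((bits, unpacked_rows))
        return sizes

With a fixed group size every entry comes out the same; with variable sizes the kernel would index into this per-group information instead of relying on a single groupsize constant.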

I'll try to fit that in sometime tomorrow.

dalistarh commented 10 months ago

Thanks a lot for the quick (and comprehensive) response!

Indeed, you can reasonably assume that the group size will be a power of 2 that's at least 32, and that we would only use a small number of different group sizes. (Right now we are only looking at a fixed group size per bit-width, e.g. 2-bit patterns are always group size 32, so we aren't planning on going too crazy.)

Generally, we are working on a "V2" version of GPTQ which includes variants of features already discussed on this forum, such as an "optimal" EXL2 bit split w.r.t. L2 error, choosing different compression targets across layers based on output sensitivity, and better clipping. In case you think it might be useful to sync on this directly, please feel free to contact me at my academic email. I'd be happy to chat.

Thanks again!

turboderp commented 10 months ago

In case you missed it, the implementation should support different group sizes within a single quantized tensor. The option I added to the quantizer script corresponds to what you suggested (one group size per bitrate), since that seemed like the simplest way to test it, but the kernels should respect any combination of group sizes. The size of each individual group is derived from the .q_groups tensor.
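Purely as an illustration of how a per-bitrate group size maps onto the format, here is a hedged sketch that lays out q_groups-style pairs from a bits-to-group-size map; the map, helper, and example pattern are hypothetical and not the quantizer script's actual configuration.

    # Hypothetical per-bitrate group sizes (not the quantizer's real config format).
    group_size_for_bits = {2: 32, 3: 64, 4: 128}

    def build_q_groups(bits_per_group):
        # Emit flat [bits, packed_start_row, ...] pairs as described earlier.
        q_groups, out_row = [], 0
        for bits in bits_per_group:
            q_groups += [bits, out_row]
            rows = group_size_for_bits[bits]     # unpacked rows in this group
            out_row += rows * bits // 32         # packed uint32 rows the group occupies
        return q_groups

    # e.g. two 2-bit groups followed by a 4-bit group:
    # build_q_groups([2, 2, 4]) -> [2, 0, 2, 2, 4, 4]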

dalistarh commented 10 months ago

Thanks a lot, we will give this a try and report back if we see any improvements.