BenCrulis opened this issue 5 months ago
Could you try again with a QuantConv2D layer instead of a QuantDense layer?
The Larq Compute Engine does not support fully-connected binary layers at the moment.
To check if the conversion was successful, I recommend opening the .tflite file in Netron. There you can check whether the layer was recognized as a binary layer: it should have Lce in the name.
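For reference, a minimal conversion sketch (the shapes and layer parameters are placeholders, not from your model) would look something like this:

```python
import larq as lq
import larq_compute_engine as lce
import tensorflow as tf

# Minimal model with a binary conv layer; sizes are placeholders.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    lq.layers.QuantConv2D(
        64, (3, 3),
        input_quantizer="ste_sign",
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
        use_bias=False),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10),
])

# Convert with the LCE converter and write the flatbuffer to disk.
tflite_bytes = lce.convert_keras_model(model)
with open("model.tflite", "wb") as f:
    f.write(tflite_bytes)
# Open model.tflite in Netron and look for ops with "Lce" in the name
# to confirm the layer was recognized as a binary layer.
```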
@Tombana It does not seem to recognize QuantConv2D either:
Assuming I manage to get QuantConv2D working, can I then use it to replace the dense layers and benefit from the weight compression with a kernel size of 1x1? I am asking because I am not sure whether a 1x1 kernel allows efficient bit packing; I don't know how it is implemented in LCE.
Also, if you think it should not be too hard to implement QuantDense in LCE, I would perhaps be interested in implementing it, since I would like to do ternary LLM inference on Android, but at the moment I can't even save a TFLite file of the model because the unquantized model is bigger than 2GB.
can I then use it to replace the dense layers and benefit from the weight compression with a kernel size of 1x1?
Yes, it will benefit from the weight compression. However, depending on the shape parameters, it might not benefit from the speed improvement, because there are currently no optimized kernels implemented for that case.
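For reference, a dense layer can be rewritten as a 1x1 QuantConv2D roughly like this (the sizes below are just placeholders, and whether the optimized kernels cover this shape is the separate question mentioned above):

```python
import larq as lq
import tensorflow as tf

features, units = 512, 256  # illustrative sizes

# A Dense(units) on a (batch, features) input can be expressed as a
# 1x1 convolution over a (batch, 1, 1, features) "image".
inputs = tf.keras.Input(shape=(features,))
x = tf.keras.layers.Reshape((1, 1, features))(inputs)
x = lq.layers.QuantConv2D(
    units, kernel_size=1,
    input_quantizer="ste_sign",
    kernel_quantizer="ste_sign",
    kernel_constraint="weight_clip",
    use_bias=False)(x)
outputs = tf.keras.layers.Reshape((units,))(x)
dense_as_conv = tf.keras.Model(inputs, outputs)
```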
if you think it should not be too hard to implement QuantDense in LCE, I would perhaps be interested in implementing it
We welcome pull requests for this. Implementing this would consist of two main parts:
1. Add support for this in the converter. This requires some knowledge of MLIR and TableGen, but the binary conv2d code can be used as a starting point: https://github.com/larq/compute-engine/blob/3f57a37a919f86f242d6c5f2ad6299fb01717759/larq_compute_engine/mlir/transforms/prepare_patterns_common.td#L97
2. Add an optimized kernel. Since you are interested in Android, I assume it will be for 64-bit ARM CPUs, which means it requires some knowledge of NEON assembly; the kernel code would look something like this: https://github.com/larq/compute-engine/blob/3f57a37a919f86f242d6c5f2ad6299fb01717759/larq_compute_engine/core/bgemm/kernels_aarch64.h#L185
@Tombana I don't know if this changes anything, but I forgot to say that the inputs to the binary and ternary layers are not quantized; only the kernel weights are. I was assuming that there would be some kind of general mechanism to store quantized weights efficiently, independently of the layer, and then unquantize them on-device to get back to the world of classical f32 operations that I assume are all already supported.
But of course, implementing a ternary matmul algorithm would probably be best, since even for f32 inputs, the binary and ternary matmul can be implemented with only addition and subtraction operations. It should be more efficient than unquantizing the kernel weights and falling back to the normal f32 matmul algorithm.
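To make that concrete, here is a toy sketch (plain NumPy, not LCE code) of a ternary matmul on float inputs that only adds and subtracts input columns:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)        # float inputs
W = rng.integers(-1, 2, size=(8, 16)).astype(np.float32)  # ternary weights in {-1, 0, +1}

y_ref = x @ W  # regular float matmul

# Same result using only additions and subtractions:
# add the inputs whose weight is +1, subtract those whose weight is -1.
y = np.zeros((x.shape[0], W.shape[1]), dtype=np.float32)
for j in range(W.shape[1]):
    plus = x[:, W[:, j] == 1].sum(axis=1)    # additions
    minus = x[:, W[:, j] == -1].sum(axis=1)  # additions, then one subtraction
    y[:, j] = plus - minus

assert np.allclose(y, y_ref)
```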
I will look into the files you linked, but I'm afraid I know nothing about MLIR, TableGen or NEON assembly.
The LCE binary conv2d layer only supports the case where both inputs and weights are binary.
unquantize them on-device to get back to the world of classical f32 operations that I assume are all already supported.
This is a good idea. In theory this is supported, because there would be a set of binary weights, followed by an LceDequantize operator that converts them to int8 or float32. The output of that LceDequantize could then go into the weights input of the convolution.
However, in practice I think that does not work, because the MLIR converter will do 'constant folding': when it detects this LceDequantize operator with constant input, it will run the operator ahead of time and store the output instead, because this is an "optimization" (in this case it isn't, but the converter does not know that).
I'm not sure if this can be easily disabled.
As a workaround, maybe there is a way to prevent constant folding by adding an Input variable that will be just zero at inference time and added to the weights before LceDequantize.
Is there any way to store bit-packed weights without relying on QuantConv2D or other completely binary layers at the moment? I can try this workaround if there is.
Something like that might work. Perhaps you have to use the input variable even before the ste_sign.
Note that this LceDequantize trick would work with any layer, including dense layers, so in this case there is no need to rely on QuantConv2D. The ste_sign is converted to LceQuantize+LceDequantize, and that works independently of the other layers around it.
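A rough sketch of that workaround as I understand it (the layer name and shapes below are made up, and whether the converter actually keeps the weights bit-packed still depends on the constant folding discussed above): an extra model input that is all zeros at inference time is added to the latent weights before ste_sign, so the quantizer no longer has a purely constant input.

```python
import larq as lq
import tensorflow as tf

class DenseWithBinaryWeights(tf.keras.layers.Layer):
    """Hypothetical dense layer: float activations, ste_sign-quantized weights."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # input_shape is a list because call() receives [x, zero].
        self.kernel = self.add_weight(
            name="kernel", shape=(input_shape[0][-1], self.units),
            initializer="glorot_uniform", trainable=True)

    def call(self, inputs):
        x, zero = inputs  # `zero` is an extra model input, all zeros at inference
        # reduce_mean of an all-zeros input is 0, but only known at runtime,
        # so the quantize/dequantize pair should not be constant-folded away.
        w = lq.quantizers.ste_sign(self.kernel + tf.reduce_mean(zero))
        return tf.matmul(x, w)

features, units = 512, 256  # illustrative
x_in = tf.keras.Input(shape=(features,))
zero_in = tf.keras.Input(shape=(1,))  # feed an all-zeros array at inference time
out = DenseWithBinaryWeights(units)([x_in, zero_in])
model = tf.keras.Model([x_in, zero_in], out)
```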
@Tombana I tried using the ste_sign function to quantize the kernel parameters myself in a custom layer; it seemed to do constant folding as you said, so the parameters were stored unquantized.
I tried my idea of adding a value computed from the input to prevent constant folding, but it also didn't store the quantized weights. I think this is because the addition is placed before LceQuantize, when it should probably be after it and before LceDequantize (ignore the splitting of the fully connected operation at the end):
I also tried having a fully binary dense module inside my custom layer and just using its kernel instead of calling the layer, but it has the same behavior with and without the trick: the quantization is lost.
Do you have any other idea how to preserve quantization on disk with the current API?
In any case, if it is possible, I suggest creating additional separate functions for quantization and dequantization in the Python API, to give fine-grained control over the process and over the storage of parameters when needed. If such functions can be created in user space, I am very interested, as I need a working solution urgently. In the future it would be nice to implement the partially binarized and ternarized layer variants as we discussed (float inputs, binary/ternary weights), ideally with an optimized addition/subtraction-based matmul.
I don't know if this will work, but this is the only other idea I have left: disable constant folding of the LceDequantize op by setting this to zero:
https://github.com/larq/compute-engine/blob/3f57a37a919f86f242d6c5f2ad6299fb01717759/larq_compute_engine/mlir/ir/lce_ops.td#L121
Then rebuild the converter (just a heads-up, this can take several hours) and see if it works.
After lots of tries, I found that tf.raw_ops.DebugIdentityV2 seems to be preserved through the constant-folding pass, so I will stick with it for now, but I fear the behavior may change in the future.
Edit: this method doesn't seem to allow the resulting TFLite interpreter to run; the operation is unrecognized.
It is also possible to add tf.cast(tf.random.uniform((1,), minval=0.0, maxval=0.0), dtype=tf.uint8) to the variable instead, but this creates a broadcast and an addition.
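In context, that trick looks roughly like this (the packed uint8 tensor is just a placeholder):

```python
import tensorflow as tf

# Hypothetical uint8 tensor holding bit-packed weights.
packed = tf.constant([0b10110001, 0b01101100], dtype=tf.uint8)

def block_constant_folding(packed):
    # random.uniform with minval == maxval == 0.0 always yields 0.0, but the
    # converter cannot know that, so it cannot fold the ops consuming `packed`.
    zero = tf.cast(tf.random.uniform((1,), minval=0.0, maxval=0.0),
                   dtype=tf.uint8)
    return packed + zero  # adds nothing, but costs a broadcast and an add
```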
I haven't tried changing hasFolder yet; I would be grateful if somebody else tried, though, as I am not comfortable with either the TensorFlow or the Larq codebase. Bonus points if someone adds a no-op operation that is specifically made to prevent constant folding.
Should I make a feature request for this?
Should I make a feature request for this?
I think it might be a nice feature to be able to run regular layers but with bitpacked weights that get converted to float/int8 at runtime. I'm not sure it has many use cases: binarizing weights often requires a new model architecture that is specifically tuned for this, but perhaps it works for your neural network.
The Larq team currently does not have time to pick this up, but perhaps someone else in the community can work on it.
It is useful in cases where having a small memory footprint matters most. The TFLite converter even fails if the resulting TFLite file is larger than 2GB, which prevents using most LLMs on a smartphone through TFLite or Larq.
Quantization of activations is apparently more destructive than quantization of weights, so it makes sense to quantize only the weights in a post-training quantization step. Currently, only models trained from scratch with both quantized weights and activations can benefit from the Larq primitives, if I understand correctly, but supporting mixed quantization schemes would also bring fast inference and reduced model size to post-training quantized models.
I am currently writing my own packing/unpacking algorithm with TensorFlow primitives. I fear that the runtime cost might be too large, though, since I can see it creates lots of nodes in the computational graph. But to even allow this in the first place, we need an operator specifically made to prevent the constant-folding step anyway, so this (hopefully) small addition to either the TensorFlow or the Larq codebase would still be very useful. I also opened tensorflow/tensorflow#69598 about the constant-folding problem on the TensorFlow GitHub.
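For what it's worth, a naive version of such packing/unpacking with plain TensorFlow ops could look like the sketch below (my own sketch, not the LCE bit-packing format). The constant-folding caveat above still applies: without something like the zero-add trick, the converter may simply fold the unpacked weights back into float constants.

```python
import tensorflow as tf

_SHIFTS = tf.constant([1, 2, 4, 8, 16, 32, 64, 128], dtype=tf.int32)

def pack_signs(w):
    """Pack the signs of a float tensor (length a multiple of 8) into uint8."""
    bits = tf.cast(tf.reshape(w, (-1, 8)) >= 0, tf.int32)  # 1 for +1, 0 for -1
    return tf.cast(tf.reduce_sum(bits * _SHIFTS, axis=1), tf.uint8)

def unpack_signs(packed, n):
    """Recover the first n weights as +1/-1 float32 values."""
    b = tf.cast(packed, tf.int32)[:, None]                  # (num_bytes, 1)
    bits = tf.math.floormod(b // _SHIFTS, 2)                # (num_bytes, 8)
    signs = tf.cast(tf.reshape(bits, (-1,)), tf.float32) * 2.0 - 1.0
    return signs[:n]

w = tf.constant([1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0])
packed = pack_signs(w)            # 1 byte instead of 8 float32 values
restored = unpack_signs(packed, 8)
```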
I am trying to save a quantized ternary model to a .tflite file, but Larq doesn't seem to save the weights using reduced-precision datatypes and thus compress the file: after converting and writing to disk, the size of the file is about the same as the one predicted by larq.models.summary for float32 parameters.
Even if I try to do the same thing with a simple QuantDense layer, the weights are saved in float32.
I am using this kind of code:
Am I doing something wrong?