larq / compute-engine

Highly optimized inference engine for Binarized Neural Networks
https://docs.larq.dev/compute-engine
Apache License 2.0

QuantConv2D binarized activations with bitpacked output as tf.int32 #599

Closed: simonmaurer closed this issue 3 years ago

simonmaurer commented 3 years ago

@lgeiger @AdamHillier given our latest mail conversation, I'm trying to state one of the questions here so others can profit too. this is somewhat related to #445

given the hints from @lgeiger:

If you use experimental_enable_bitpacked_activations=True during model conversion and have one-padding enabled in Larq, the converter will automatically fuse a convolution followed by a sign function into a single op where the binary convolution returns bitpacked binary feature maps directly.

I modelled a multitask network with binary layers and a sign operation (larq.math.sign, respectively larq.quantizers.SteSign()) afterwards:

output = larq.layers.QuantConv2D(256, kernel_size=(3, 3), strides=(1, 1), padding="same", pad_values=1, input_quantizer="ste_sign", kernel_quantizer="ste_sign", kernel_constraint="weight_clip", use_bias=False)(output)
output = larq.quantizers.SteSign()(output)
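The model was then converted roughly like this (a sketch of the call; the import path is an assumption on my side, the flag is the one from the hint above):

```python
from larq_compute_engine import convert_keras_model

# Sketch: convert the Keras model with bitpacked activations enabled,
# as suggested in the quoted hint.
tflite_buffer = convert_keras_model(
    model, experimental_enable_bitpacked_activations=True
)
```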

However, there is an LceDequantize op and the output is a tf.float32 tensor (identity).

[screenshot: bitpacked_output]

Can you give me a hint on how I could strip that last dequantization layer so I get the actual tf.int32 tensor ?

I get that you marked this as experimental, but from a research perspective accessing the output layers (bitpacked output) directly is interesting (especially since you're thinking about making this default behavior as discussed here: https://github.com/larq/compute-engine/pull/597)

I suppose the answer must lie in modify_integer_quantized_model_io_type as part of https://github.com/larq/compute-engine/blob/master/larq_compute_engine/mlir/python/util.py

Tombana commented 3 years ago

Hi, thanks for your interest in LCE! The converter will indeed leave that LceDequantize in the graph, as this bitpacked datatype is kind of an LCE internal type that is only ever read by other LCE ops, so it won't be exposed to the outside.

I think that your best bet is to write a separate python script that modifies the generated flatbuffer, inspired by that function modify_integer_quantized_model_io_type that you linked to.

You could take that util.py file as an example and then slightly modify it as follows.

https://github.com/larq/compute-engine/blob/579c8fb02f69d6f08e1581206c7b9b00b5394028/larq_compute_engine/mlir/python/util.py#L147-L163

With these changes, you should be able to strip out the "LceDequantize" ops.

simonmaurer commented 3 years ago

@Tombana thanks a lot for the hint, much appreciated. I implemented a script given your input that does exactly that. when executing it after the regular call tflite_buffer = convert_keras_model(model), the following error was thrown:

ImportError: cannot import name 'tflite_schema' from 'larq_compute_engine.mlir.python'

looks like I have to implement the code right in the repo and compile larq_compute_engine manually. would you agree ?

simonmaurer commented 3 years ago

given the last error, I fixed it by importing it as:

from larq_compute_engine.mlir.python import schema_py_generated as tflite_schema

I had to adapt the code line you mentioned above a bit: opcode.customCode == b'LceDequantize'

as you said, I also had to remove the extra checks.
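For reference, the core of the script looks roughly like this (a sketch only; the camelCase flatbuffer field names and the private helpers from util.py are assumptions on my side and may need adjusting):

```python
from larq_compute_engine.mlir.python.util import (
    _convert_model_from_bytearray_to_object,
    _convert_model_from_object_to_bytearray,
    _remove_tensors_from_model,
)


def strip_output_lcedequantize_ops(flatbuffer_bytes):
    """Remove LceDequantize ops that feed the model outputs (sketch, single subgraph)."""
    model = _convert_model_from_bytearray_to_object(flatbuffer_bytes)
    subgraph = model.subgraphs[0]

    # Operator-code indices that correspond to the LceDequantize custom op.
    dequant_opcodes = {
        idx
        for idx, opcode in enumerate(model.operatorCodes)
        if opcode.customCode == b"LceDequantize"
    }

    remove_ops, remove_tensors = [], set()
    for op in subgraph.operators:
        if op.opcodeIndex in dequant_opcodes and op.outputs[0] in subgraph.outputs:
            # Rewire the model output to the op's bitpacked int32 input tensor.
            out_pos = list(subgraph.outputs).index(op.outputs[0])
            subgraph.outputs[out_pos] = op.inputs[0]
            remove_tensors.add(op.outputs[0])
            remove_ops.append(op)

    # Drop the dequantize ops and their now-unused float output tensors.
    for op in remove_ops:
        subgraph.operators.remove(op)
    _remove_tensors_from_model(model, remove_tensors)

    return _convert_model_from_object_to_bytearray(model)
```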

simonmaurer commented 3 years ago

am I right to assume that this will only remove the last LceDequantize layer(s) in a model (right before the (multi-)output), or could it affect intermediate layers within the model ? because I still don't quite get the conversion: looking at the sample picture above, shouldn't there be an LceDequantize layer right before a regular Conv2D on the two paths on the right side (as any output tensor from QuantConv2D/LceBconv2d is tf.int32 per se) ?

Tombana commented 3 years ago

am I right to assume that this will only remove the last LceDequantize layer(s) in a model (right before the (multi)output) or could it affect intermediate layers within the model ?

You are right. It will not affect any intermediate layers.

cause I still don't quite get the conversion: looking at the sample picture above, shouldn't there be a LceDequantize layer right before a regular Conv2D on the two paths on the right side

That is a good question. The following assumption is wrong:

as any output tensor from QuantConv2D/LceBconv2d is tf.int32 per se

The binary convolutions have binary input and weights. The output, however, is generally a float (or int8) number. This is especially important in networks with shortcuts, where you want the full-precision float data to go into the shortcut.

Now in case there are two LceBconv2d ops in sequence, you'd normally have an LceQuantize op in between to convert the full-precision output to binary. But if the intermediate tensor in between is not consumed by any other op, then the MLIR converter detects this and changes the first LceBconv2d op to produce binary output directly, so that the LceQuantize op can be removed. I hope that answers your question, feel free to ask more if it is not clear!
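To make the sequence case concrete: in a toy model like the one below (a sketch with arbitrary shapes and filter counts), the tensor between the two binary convolutions is only consumed by the second one, so the converter can let the first LceBconv2d produce bitpacked output directly and drop the LceQuantize in between:

```python
import tensorflow as tf
import larq


def bconv(filters):
    # One-padded binary convolution, as in the snippet above.
    return larq.layers.QuantConv2D(
        filters, (3, 3), padding="same", pad_values=1,
        input_quantizer="ste_sign", kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip", use_bias=False,
    )


inputs = tf.keras.Input((32, 32, 64))
x = bconv(128)(inputs)  # can write bitpacked int32 output after conversion
x = bconv(128)(x)       # consumes the bitpacked tensor directly
model = tf.keras.Model(inputs, x)
```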

simonmaurer commented 3 years ago

@Tombana it's a bit clearer now. makes sense that you discuss making experimental_enable_bitpacked_activations=True the default, as otherwise there are going to be "unnecessary" LceQuantize ops in between consecutive QuantConv2D layers.

The binary convolutions have binary input and weights. The output, however, is generally a float (or int8) number.

Oh, gotcha. My mistake, I assumed the output to be tf.int32 given the API documentation: Note that the output of a binarized layer is not binary. Instead the output is an integer, due to the summation that appears in most neural network layers.

Two questions remain:

  1. since you said binary convolutions generally output a float (or int8) number: when setting tfl_mdl = convert_keras_model(mdl, inference_input_type=tf.int8, inference_output_type=tf.int8), the tensors all remain tf.float32. why is that, or can we control it elsewhere ?
  2. referring to _find_int8_quantized_inputs_outputs in util.py: are we talking about TFLite QAT quantization/dequantization layers here that are adapted ? in other words, tf.int8 as input/output of the model or for each layer (like the int8 numbers you mentioned above) ?

Tombana commented 3 years ago

Oh, gotcha. My mistake, I assumed the output to be tf.int32 given the API documentation: Note that the output of a binarized layer is not binary. Instead the output is an integer, due to the summation that appears in most neural network layers.

I understand the confusion now. That part of the documentation refers to the fact that the sum of the product of the -1 and +1 values of a binary convolution will always be an integer. In LCE, the binary convolution op also has a fused batchnorm, so then that integer can become a float.

In LCE, the int32 type in tflite files always refers to bitpacked data, so it refers to 32 individual bits and not an integer. That's why the raw int32 tensor is usually not useful for the end user. You mentioned you are interested in the int32 data, so if you meant the sum of the -1,1 product, then this is not what you will get by removing the LceDequantize op.

  1. since you said binary convolutions generally output a float (or int8) number. When setting: tfl_mdl = convert_keras_model(mdl, inference_input_type=tf.int8, inference_output_type=tf.int8) the tensors all remain in tf.float32, why is that or can we control it elsewhere ?

Those parameters (inference_input_type, inference_output_type) refer to the input and output type of the entire model, not the individual layers. To make the individual layers use int8 tensors, you have to add fake quantization nodes (such as this one) to the graph between every layer. (But not in between things that are going to get fused, so it can be a bit tricky to get this right.)
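For illustration, placing such a node between two layers could look roughly like this (a sketch; the layers, shapes and the min/max range are placeholders, and whether a given node survives depends on what gets fused around it):

```python
import tensorflow as tf


def fake_quant(x, q_min=-6.0, q_max=6.0):
    # Fake-quantize the tensor to 8 bits within [q_min, q_max].
    return tf.quantization.fake_quant_with_min_max_args(
        x, min=q_min, max=q_max, num_bits=8
    )


inputs = tf.keras.Input((32, 32, 3))
x = tf.keras.layers.Conv2D(32, 3, padding="same")(inputs)
x = tf.keras.layers.Lambda(fake_quant)(x)  # this tensor can become int8 after conversion
x = tf.keras.layers.Conv2D(32, 3, padding="same")(x)
model = tf.keras.Model(inputs, x)
```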

  2. referring to _find_int8_quantized_inputs_outputs in util.py: are we talking about TFLite QAT quantization/dequantization layers here that are adapted ? in other words tf.int8 as input/output of the model or for each layer (like the int8 numbers you mentioned above) ?

That is correct.

simonmaurer commented 3 years ago

I understand the confusion now. That part of the documentation refers to the fact that the sum of the product of the -1 and +1 values of a binary convolution will always be an integer. In LCE, the binary convolution op also has a fused batchnorm, so then that integer can become a float.

gotcha. should be clear now.

In LCE, the int32 type in tflite files always refers to bitpacked data, so it refers to 32 individual bits and not an integer. That's why the raw int32 tensor is usually not useful for the end user. You mentioned you are interested in the int32 data, so if you meant the sum of the -1,1 product, then this is not what you will get by removing the LceDequantize op.

for the LCE converter the aim is indeed to access the bitpacked data containing binary values as output so I can postprocess them (without dequantizing them back to tf.float32). removing LceDequantize should be the right way to go. sry if I mixed that up

Those parameters inference_input_type refer to the input type of the entire model, not the individual layers. To make the individual layers use int8 tensors, you have to add fake quantization nodes (such as this one) to the graph between every layer. (But not inbetween things that are going to get fused, so it can be a bit tricky to get this right.)

so to get this right, I'd do a manual placement of the quantize layers from the tensorflow package, i.e. fake_quant_with_min_max_vars or fake_quant_with_min_max_args. by doing this, would I need to define the range too using experimental_default_int8_range ? as both of the above already have the arguments min=-6 and max=6

would the default QAT converter tensorflow_model_optimization.quantization.keras.quantize_model be an alternative (as it does such a conversion automatically and adds fake quantize nodes) or is that not compatible with the Larq workflow?

Tombana commented 3 years ago

so to get this one right, I'd do a manual placement of the quantize layers from tensorflow package, ie. fake_quant_with_min_max_vars or fake_quant_with_min_max_args now by doing this would I need to define the range too using experimental_default_int8_range ? as both of the above already exhibit arguments min=-6 and max=6

When you add those fake_quant ops, you do not need to use experimental_default_int8_range. The flag experimental_default_int8_range is meant for benchmarking int8 models without training them properly, and it will effectively add a fake_quant node to any tensor that does not have one yet.
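In other words, the two options look roughly like this (a sketch; I'm assuming experimental_default_int8_range takes a (min, max) tuple, so check the converter docs for the exact signature):

```python
from larq_compute_engine import convert_keras_model

# Benchmarking shortcut: every tensor without a fake_quant node gets this
# default range (the values here are arbitrary).
tflite_benchmark = convert_keras_model(
    model, experimental_default_int8_range=(-3.0, 3.0)
)

# Properly trained int8 model: the fake_quant nodes are already in the graph,
# so the flag is not needed.
tflite_int8 = convert_keras_model(model)
```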

would the default QAT converter tensorflow_model_optimization.quantization.keras.quantize_model be an alternative (as it does such a conversion automatically and adds fake quantize nodes) or is that not compatible with the Larq workflow?

That should work in theory, although we've had some issues with that in the past when we tried it ourselves, so we ended up placing these fake_quant nodes manually.

simonmaurer commented 3 years ago

conversion was successful in the sense that the bitpacked output is now available as a tf.int32 tensor

[screenshot: bitpacked_output_tfint32]

When you add those fake_quant ops, you do not need to use experimental_default_int8_range. The flag experimental_default_int8_range is meant for benchmarking int8 models without training them properly, and it will effectively add a fake_quant node to any tensor that does not have one yet.

thanks for the clarification. definitely get the differences with respect to the parameters now. maybe one could also add that to the API documentation

are you open to such updates (even if it's just documentation changes) coming from outside your team so that others can profit too ? usually that is done via pull requests, how do you see this ?

the idea would be to have an additional parameter enable_bitpacked_int32_output = False in convert_keras_model taking this discussion up:

Originally posted by @lgeiger in https://github.com/larq/compute-engine/issues/445#issuecomment-667088689

Perhaps not as the default behaviour, maybe behind a feature flag, but we can cross that bridge when we come to it.

I agree that this shouldn't be the default behaviour, as exposing the internal bitpacked representation to users seems to be quite error prone (e.g. if the output of the model is post-processed or passed into another TFLite model, writing bitpacked output could potentially lead to unexpected behaviour since only our BConv and BMaxPool ops correctly handle bitpacked int32 input). Though, I am not a big fan of adding a flag to the converter either, as it increases the API surface for a feature that seems to only be used internally for testing. However, if there are more use cases for this I'm not too concerned about adding a new flag to the converter.

I don't know how easy this would be to implement, however. @lgeiger do you have any thoughts about this, is there a nice way to pattern match a 'model output' in MLIR?

Unfortunately I am not aware of an easy way to handle this in our current setup off the top of my head. I guess we could post-process the flatbuffer in Python, but that seems to be a bit brittle, or we could add a new pass that operates on the entire model at once and therefore knows whether a tensor is the model output.

I agree with Tom that for testing individual ops I'd prefer doing it in C++ if possible and building the model on the fly, which gives full control over all details. This would come with the additional benefit of not having to invoke the converter for every parameter combination, which can become quite expensive when trying to test many combinations.

Tombana commented 3 years ago

thanks for the clarification. definitely get the differences with respect to the parameters now. maybe one could also add that to the API documentation

are you open to such updates (even if it's just documentation changes) coming from outside your team so that others can profit too ? usually that is done via pull requests, how do you see this ?

We very much welcome such contributions; everyone is welcome to make pull requests. The docs are generated from here: https://github.com/larq/docs/tree/master/docs/compute-engine , and some parts of the API docs are autogenerated from the docstrings in the code, like here: https://github.com/larq/compute-engine/blob/14da1c84c9cc8269fe12ba2c210ab1c70f0f4a3d/larq_compute_engine/mlir/python/converter.py#L78

the idea would be to have an additional parameter enable_bitpacked_int32_output = False in convert_keras_model taking this discussion up:

Originally posted by @lgeiger in #445 (comment)

Since you already have a working Python script to remove the LceDequantize ops, it would be simple to add this. I still think it would not be useful for the average user, but for debugging and research it can be helpful.

What do you think about adding it under a debugging or research subfolder, like:

from larq_compute_engine.research import strip_lcedequantize_ops

tflite_model = strip_lcedequantize_ops(tflite_model)

@larq/compute-engine thoughts?

simonmaurer commented 3 years ago

@Tombana @lgeiger @AdamHillier thoughts about this ? as suggested by @Tombana the idea would be to have a separate research module, so the function strip_lcedequantize_ops could access the backend methods in util

am I assuming correctly, that the bitpacked tf.int32 outputs have values 0,1 representing 1,-1 of the quantization (QuantConv2D -> SteSign)?

Tombana commented 3 years ago

Hi @simonmaurer,

Apologies for the delay. We discussed it internally and we think it makes the most sense to add this functionality directly to util.py but not expose it as a public API. People who need the functionality can then import it directly with

from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops

Since you already have a working script, could you create a pull request where you add the relevant changes to util.py ?

am I assuming correctly, that the bitpacked tf.int32 outputs have values 0,1 representing 1,-1 of the quantization (QuantConv2D -> SteSign)?

If x is your int32 number, then (x >> i) & 1 will give you the i-th bit of x (with i between 0 and 31). The bit value 0 represents +1.0 and the bit value 1 represents -1.0.
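For illustration, decoding that with NumPy could look like this (a sketch; along which axis the 32 bits are packed depends on the tensor layout and is not handled here):

```python
import numpy as np


def unpack_bits_to_signs(x):
    # Extract bit i of each int32 word with (x >> i) & 1 for i in 0..31;
    # a bit value of 0 maps to +1.0, a bit value of 1 maps to -1.0.
    x = np.asarray(x)
    bits = (x[..., np.newaxis] >> np.arange(32)) & 1  # shape (..., 32)
    return np.where(bits == 0, 1.0, -1.0).astype(np.float32)
```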

simonmaurer commented 3 years ago

@Tombana no problem, very nice. I'll have to adapt the code and put it into util.py; this way we can directly use the built-in functions _convert_model_from_bytearray_to_object, _remove_tensors_from_model and _convert_model_from_object_to_bytearray. once this is done I can issue a pull request.

simonmaurer commented 3 years ago

closing this issue as it has been resolved by PR #611