larq / compute-engine

Highly optimized inference engine for Binarized Neural Networks
https://docs.larq.dev/compute-engine
Apache License 2.0

Question: model with bitpacked output #445

Closed. xross closed this issue 4 years ago.

xross commented 4 years ago

For testing our kernels I would like to produce a single-operator model with bitpacked output; the toy model looks like the following:

import larq as lq
import tensorflow as tf


def quant(x):
    return tf.quantization.fake_quant_with_min_max_vars(x, -3.0, 3.0)

def toy_model_int8(**kwargs):
    img = tf.keras.layers.Input(shape=(224, 224, 3))
    x = quant(img)

    x = lq.layers.QuantConv2D(
        12, 3, input_quantizer="ste_sign", kernel_quantizer="ste_sign", activation=quant
    )(x)

    return tf.keras.Model(img, x)

model = toy_model_int8()

I have been successful in producing bitpacked intermediaries between successive Conv2Ds, but not at the model output (granted, this is not a normal use case!).

Setting activation to quant produces an int8 output and activation=None produces a float32 output, as expected. My question is: is it possible to produce an int32 bitpacked output from this model?
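
A minimal sketch of how the output type can be checked after conversion (assuming the LCE converter API larq_compute_engine.convert_keras_model and the stock TFLite interpreter; not code from this thread):

import larq_compute_engine as lce
import tensorflow as tf

# Convert the toy model above and look at the output tensor's dtype.
# Only metadata is inspected here; running inference would need the
# LCE interpreter with the custom ops registered.
flatbuffer = lce.convert_keras_model(model)

interpreter = tf.lite.Interpreter(model_content=flatbuffer)
print(interpreter.get_output_details()[0]["dtype"])
# Reported above: int8 with activation=quant, float32 with activation=None;
# the question is whether an int32 (bitpacked) output is possible.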

AdamHillier commented 4 years ago

That is a good point, and something I'd thought of before. At the moment, the converter will set the bconv op output tensor type to be int32 (i.e. bitpacked output) only when the subsequent (and only) use of the tensor is as input into another bconv or binary maxpool op. So the answer is that currently this is not possible.
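
As an illustration of that rule, a sketch (not taken from the compute-engine test suite; it assumes larq, larq_compute_engine and the stock TFLite interpreter): a model where the first QuantConv2D feeds only into a second one should end up with an int32 bitpacked intermediate tensor after conversion, while the model output does not.

import larq as lq
import larq_compute_engine as lce
import tensorflow as tf

def two_bconv_model():
    img = tf.keras.layers.Input(shape=(224, 224, 3))
    x = lq.layers.QuantConv2D(
        12, 3, input_quantizer="ste_sign", kernel_quantizer="ste_sign"
    )(img)
    x = lq.layers.QuantConv2D(
        12, 3, input_quantizer="ste_sign", kernel_quantizer="ste_sign"
    )(x)
    return tf.keras.Model(img, x)

flatbuffer = lce.convert_keras_model(two_bconv_model())

# Inspect tensor metadata only; actually running the model would need the
# LCE interpreter with the custom ops registered.
interpreter = tf.lite.Interpreter(model_content=flatbuffer)
for t in interpreter.get_tensor_details():
    print(t["name"], t["dtype"])
# The tensor between the two bconv ops should show up as int32 (bitpacked),
# while the model output itself stays float32.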

I think it's fair to say that we will want to support this if it's easy to do, as it seems like a reasonable use-case. Perhaps not as the default behaviour, maybe behind a feature flag, but we can cross that bridge when we come to it.

I don't know how easy this would be to implement, however. @lgeiger do you have any thoughts about this, is there a nice way to pattern match a 'model output' in MLIR? I worry that since it's not currently possible in TF2.3 to convert a QAT model to have proper Int8 input/output tensors, there might be challenges here, but hopefully not.

Tombana commented 4 years ago

In our TFLite code (not micro), we have unit tests that test models like that: https://github.com/larq/compute-engine/blob/master/larq_compute_engine/tflite/tests/bconv2d_test.cc These tests build a single-op model on the fly and run the interpreter on it, using some TFLite testing utilities. I'm not sure whether a valid flatbuffer file that you might be able to extract is actually created in the process.

lgeiger commented 4 years ago

Perhaps not as the default behaviour, maybe behind a feature flag, but we can cross that bridge when we come to it.

I agree that this shouldn't be the default behaviour, as exposing the internal bitpacked representation to users seems quite error-prone (e.g. if the output of the model is post-processed or passed into another TFLite model, writing bitpacked output could lead to unexpected behaviour, since only our BConv and BMaxPool ops correctly handle bitpacked int32 input). That said, I'm not a big fan of adding a flag to the converter either, as it increases the API surface for a feature that seems to only be used internally for testing. However, if there are more use cases for this I'm not too concerned about adding a new flag to the converter.

I don't know how easy this would be to implement, however. @lgeiger do you have any thoughts about this, is there a nice way to pattern match a 'model output' in MLIR?

Unfortunately I am not aware of an easy way to handle this in our current setup off the top of my head. I guess we could post-process the flatbuffer in Python, but that seems a bit brittle; or we could add a new pass that operates on the entire model at once and therefore knows whether a tensor is the model output.

I agree with Tom that for testing individual ops I'd prefer doing it in C++ if possible, building the model on the fly, which gives full control over all details. This would come with the additional benefit of not having to invoke the converter for every parameter combination, which can become quite expensive when trying to test many combinations.

xross commented 4 years ago

Thanks for the responses. To clarify my query: we do test individual ops at the C level as you might expect, but we also run a bunch of single-op models through our integration test suite in order to test end-to-end operation (i.e. model -> LCE converter -> xformer). We'll likely generate a dual-op model and add a pass to snip off the second op in our test system; if/when such a feature is made available we can move to it.

Again, thanks for the replies :)

lgeiger commented 4 years ago

@xross I think after #457 is merged (probably in a few weeks' time), this will become a bit easier. While we will still not be able to have bitpacked model output, it will be possible to generate a single-layer model, e.g.:

import larq as lq
import tensorflow as tf

model = tf.keras.Sequential(
    [
        lq.layers.QuantConv2D(
            16,
            3,
            input_quantizer="ste_sign",
            kernel_quantizer="ste_sign",
            input_shape=(224, 224, 3),
        ),
        lq.quantizers.SteSign(),
    ]
)

This will be converted to something like

[screenshot of the converted model graph]

with quantize/dequantize ops that handle bitpacking as described in #447.

I think this should make testing easier, and if needed the dequantize op that converts the bitpacked output back to floating point or int8 can be removed from the flatbuffer, similarly to how we handle true int8 output in Python.
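
For illustration only, a rough sketch of that kind of Python flatbuffer post-processing. It assumes TensorFlow's internal flatbuffer_utils module (not a public API; names and module paths can differ between TF versions), uses hypothetical file names, and is not the actual LCE/xformer implementation:

from tensorflow.lite.tools import flatbuffer_utils

# Read the converted model, drop the final (dequantize-style) op and expose
# its input tensor, which carries the bitpacked data, as the model output.
model = flatbuffer_utils.read_model("converted_model.tflite")
subgraph = model.subgraphs[0]

last_op = subgraph.operators[-1]
subgraph.outputs = [last_op.inputs[0]]
subgraph.operators.pop()

flatbuffer_utils.write_model(model, "bitpacked_output_model.tflite")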

Let me know if that would work for you.

xross commented 4 years ago

Great, thanks for the update; this should do nicely, I believe. Our xformer already strips the builtin QUANTIZE/DEQUANTIZE ops, so we can extend it to strip the LceQuantize/Dequantize ops.

lgeiger commented 4 years ago

Great, thanks for the update; this should do nicely, I believe. Our xformer already strips the builtin QUANTIZE/DEQUANTIZE ops, so we can extend it to strip the LceQuantize/Dequantize ops.

This has been implemented in #457 and will be part of the 0.4 release. I'm closing this issue for now; feel free to reopen if you run into any issues.