This causes broken Reshape and Concatenation ops: the MLIR passes will correctly infer the Reshape output shape to something like `?x?x2`, and then the flatbuffer exporter simply overrides both with `1`. Other artifacts:

- left-over StridedSlice ops in the graph (which should've been folded out)
- invalid Reshape shapes
- invalid concatenation shapes
- shape mismatches in, e.g., depthwise layers with dilation (I forgot the exact details)

The graphdef converter code has a `batch_size=1` fix in Python, but the saved model converter did not have anything like this yet.
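The override described above can be pictured with a minimal sketch (plain Python, not the actual exporter code) of the failure mode: every unknown (`?`) dimension is replaced with 1 at export time, regardless of what shape inference worked out.

```python
def export_shape(shape):
    # Stand-in for the flatbuffer exporter's behavior: every unknown
    # dimension (None, i.e. "?") is simply replaced with 1.
    return [1 if d is None else d for d in shape]

# MLIR correctly infers a batch-dependent Reshape output such as ?x?x2.
inferred = [None, None, 2]
exported = export_shape(inferred)
print(exported)  # -> [1, 1, 2]

# For a batch-1 input holding 6 elements, the true shape is (1, 3, 2);
# the exported (1, 1, 2) holds only 2 elements, so the Reshape is invalid.
```

The point is that the two `?`s are not independent: shape inference relates them to the batch dimension, while the exporter treats each one in isolation.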
What do these changes do?
This adds a pass to the converter that sets the dynamic batch size to 1 on the input layer.
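The effect of the pass can be illustrated with a toy sketch (made-up names, plain Python, not the real MLIR pass): pin the dynamic leading (batch) dimension of the model's input to 1, so that later shape inference can produce fully static shapes.

```python
def fix_input_batch_size(input_shape, batch_size=1):
    """Replace an unknown leading dim (None, i.e. '?') with batch_size."""
    if input_shape and input_shape[0] is None:
        return [batch_size] + list(input_shape[1:])
    return list(input_shape)

print(fix_input_batch_size([None, 224, 224, 3]))  # -> [1, 224, 224, 3]
print(fix_input_batch_size([8, 10]))              # -> [8, 10] (unchanged)
```

Doing this on the input layer, before the shape-inference passes run, lets the inferred downstream shapes (Reshape, Concatenation, etc.) come out concrete instead of being patched up at export time.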
When `batch_size` is not explicitly set to 1, it will remain a `?` (wildcard) throughout all MLIR passes. Only at the final MLIR-to-Flatbuffer conversion stage are all the `?`s simply converted to `1`: https://github.com/tensorflow/tensorflow/blob/v2.7.0/tensorflow/compiler/mlir/lite/flatbuffer_export.cc#L844

This is what produces the broken Reshape and Concatenation ops and the left-over StridedSlice ops described above.

How Has This Been Tested?
MLIR filecheck tests have been added.