iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0
2.88k stars 629 forks source link

Error importing 3d Unet TFLite model #12120

Open mariecwhite opened 1 year ago

mariecwhite commented 1 year ago

When importing 3dunet_kits19_1x1x128x128x128.tflite to MLIR using import-iree-tflite, I get the following error:

/tmp/3dunet_kits19_128x128x128.tflite:0:0: error: The following illegal operations still remain: 
    tfl.broadcast_to (count: 49)
    tfl.no_value (count: 1)
    tfl.conv_3d_transpose (count: 5)

This model is part of the MLPerf Inference suite and was taken from: https://github.com/mlcommons/inference/tree/master/vision/medical_imaging/3d-unet-kits19

jpienaar commented 1 year ago

Regarding broadcast_to: all of them seem to feed into broadcasting ops (I think they can all be fused into these mul/add ops):

     1      %24 = "tfl.mul"(%8, %23) {fused_activation_function = "NONE"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
     2      %26 = "tfl.add"(%24, %25) {fused_activation_function = "RELU"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
     3      %45 = "tfl.mul"(%32, %44) {fused_activation_function = "NONE"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
     4      %47 = "tfl.add"(%45, %46) {fused_activation_function = "RELU"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
     5      %67 = "tfl.mul"(%52, %66) {fused_activation_function = "NONE"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
     6      %69 = "tfl.add"(%67, %68) {fused_activation_function = "RELU"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
     7      %87 = "tfl.mul"(%74, %86) {fused_activation_function = "NONE"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
     8      %89 = "tfl.add"(%87, %88) {fused_activation_function = "RELU"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
     9      %109 = "tfl.mul"(%94, %108) {fused_activation_function = "NONE"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
    10      %111 = "tfl.add"(%109, %110) {fused_activation_function = "RELU"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
    11      %129 = "tfl.mul"(%116, %128) {fused_activation_function = "NONE"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
    12      %131 = "tfl.add"(%129, %130) {fused_activation_function = "RELU"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
    13      %151 = "tfl.mul"(%136, %150) {fused_activation_function = "NONE"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
    14      %153 = "tfl.add"(%151, %152) {fused_activation_function = "RELU"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
    15      %171 = "tfl.mul"(%158, %170) {fused_activation_function = "NONE"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
    16      %173 = "tfl.add"(%171, %172) {fused_activation_function = "RELU"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
    17      %193 = "tfl.mul"(%178, %192) {fused_activation_function = "NONE"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
    18      %195 = "tfl.add"(%193, %194) {fused_activation_function = "RELU"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
    19      %213 = "tfl.mul"(%200, %212) {fused_activation_function = "NONE"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
    20      %215 = "tfl.add"(%213, %214) {fused_activation_function = "RELU"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
    21      %234 = "tfl.mul"(%220, %233) {fused_activation_function = "NONE"} : (tensor<1x320x4x4x4xf32>, tensor<1x320x4x4x4xf32>) -> tensor<1x320x4x4x4xf32>
    22      %236 = "tfl.add"(%234, %235) {fused_activation_function = "RELU"} : (tensor<1x320x4x4x4xf32>, tensor<1x320x4x4x4xf32>) -> tensor<1x320x4x4x4xf32>
    23      %254 = "tfl.mul"(%241, %253) {fused_activation_function = "NONE"} : (tensor<1x320x4x4x4xf32>, tensor<1x320x4x4x4xf32>) -> tensor<1x320x4x4x4xf32>
    24      %256 = "tfl.add"(%254, %255) {fused_activation_function = "RELU"} : (tensor<1x320x4x4x4xf32>, tensor<1x320x4x4x4xf32>) -> tensor<1x320x4x4x4xf32>
    25      %282 = "tfl.mul"(%269, %281) {fused_activation_function = "NONE"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
    26      %284 = "tfl.add"(%282, %283) {fused_activation_function = "RELU"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
    27      %302 = "tfl.mul"(%289, %301) {fused_activation_function = "NONE"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
    28      %304 = "tfl.add"(%302, %303) {fused_activation_function = "RELU"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
    29      %312 = "tfl.add"(%308, %311) {fused_activation_function = "NONE"} : (tensor<1x16x16x16x256xf32>, tensor<1x16x16x16x256xf32>) -> tensor<1x16x16x16x256xf32>
    30      %332 = "tfl.mul"(%319, %331) {fused_activation_function = "NONE"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
    31      %334 = "tfl.add"(%332, %333) {fused_activation_function = "RELU"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
    32      %352 = "tfl.mul"(%339, %351) {fused_activation_function = "NONE"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
    33      %354 = "tfl.add"(%352, %353) {fused_activation_function = "RELU"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
    34      %362 = "tfl.add"(%358, %361) {fused_activation_function = "NONE"} : (tensor<1x32x32x32x128xf32>, tensor<1x32x32x32x128xf32>) -> tensor<1x32x32x32x128xf32>
    35      %382 = "tfl.mul"(%369, %381) {fused_activation_function = "NONE"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
    36      %384 = "tfl.add"(%382, %383) {fused_activation_function = "RELU"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
    37      %402 = "tfl.mul"(%389, %401) {fused_activation_function = "NONE"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
    38      %404 = "tfl.add"(%402, %403) {fused_activation_function = "RELU"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
    39      %411 = "tfl.add"(%408, %410) {fused_activation_function = "NONE"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
    40      %431 = "tfl.mul"(%418, %430) {fused_activation_function = "NONE"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
    41      %433 = "tfl.add"(%431, %432) {fused_activation_function = "RELU"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
    42      %451 = "tfl.mul"(%438, %450) {fused_activation_function = "NONE"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
    43      %453 = "tfl.add"(%451, %452) {fused_activation_function = "RELU"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
    44      %461 = "tfl.add"(%457, %460) {fused_activation_function = "NONE"} : (tensor<1x128x128x128x32xf32>, tensor<1x128x128x128x32xf32>) -> tensor<1x128x128x128x32xf32>
    45      %481 = "tfl.mul"(%468, %480) {fused_activation_function = "NONE"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
    46      %483 = "tfl.add"(%481, %482) {fused_activation_function = "RELU"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
    47      %501 = "tfl.mul"(%488, %500) {fused_activation_function = "NONE"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
    48      %503 = "tfl.add"(%501, %502) {fused_activation_function = "RELU"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
    49      %510 = "tfl.add"(%506, %509) {fused_activation_function = "NONE"} : (tensor<1x128x128x128x3xf32>, tensor<1x128x128x128x3xf32>) -> tensor<1x128x128x128x3xf32>

Regarding no_value:

9:    %6 = "tfl.conv_3d"(%3, %4, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x130x130x130x1xf32>, tensor<3x3x3x1x32xf32>, none) -> tensor<1x128x128x128x32xf32>
34:    %31 = "tfl.conv_3d"(%29, %30, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x130x130x130x32xf32>, tensor<3x3x3x32x32xf32>, none) -> tensor<1x128x128x128x32xf32>
54:    %51 = "tfl.conv_3d"(%49, %50, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x130x130x130x32xf32>, tensor<3x3x3x32x64xf32>, none) -> tensor<1x64x64x64x64xf32>
76:    %73 = "tfl.conv_3d"(%71, %72, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x66x66x66x64xf32>, tensor<3x3x3x64x64xf32>, none) -> tensor<1x64x64x64x64xf32>
96:    %93 = "tfl.conv_3d"(%91, %92, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x66x66x66x64xf32>, tensor<3x3x3x64x128xf32>, none) -> tensor<1x32x32x32x128xf32>
118:    %115 = "tfl.conv_3d"(%113, %114, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x34x34x34x128xf32>, tensor<3x3x3x128x128xf32>, none) -> tensor<1x32x32x32x128xf32>
138:    %135 = "tfl.conv_3d"(%133, %134, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x34x34x34x128xf32>, tensor<3x3x3x128x256xf32>, none) -> tensor<1x16x16x16x256xf32>
160:    %157 = "tfl.conv_3d"(%155, %156, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x18x18x18x256xf32>, tensor<3x3x3x256x256xf32>, none) -> tensor<1x16x16x16x256xf32>
180:    %177 = "tfl.conv_3d"(%175, %176, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x18x18x18x256xf32>, tensor<3x3x3x256x320xf32>, none) -> tensor<1x8x8x8x320xf32>
202:    %199 = "tfl.conv_3d"(%197, %198, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x10x10x10x320xf32>, tensor<3x3x3x320x320xf32>, none) -> tensor<1x8x8x8x320xf32>
222:    %219 = "tfl.conv_3d"(%217, %218, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x10x10x10x320xf32>, tensor<3x3x3x320x320xf32>, none) -> tensor<1x4x4x4x320xf32>
243:    %240 = "tfl.conv_3d"(%238, %239, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x6x6x6x320xf32>, tensor<3x3x3x320x320xf32>, none) -> tensor<1x4x4x4x320xf32>
263:    %260 = "tfl.conv_3d_transpose"(%258, %259, %257, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<5xi32>, tensor<2x2x2x320x320xf32>, tensor<1x4x4x4x320xf32>, none) -> tensor<1x8x8x8x320xf32>
271:    %268 = "tfl.conv_3d"(%266, %267, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x10x10x10x640xf32>, tensor<3x3x3x640x320xf32>, none) -> tensor<1x8x8x8x320xf32>
291:    %288 = "tfl.conv_3d"(%286, %287, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x10x10x10x320xf32>, tensor<3x3x3x320x320xf32>, none) -> tensor<1x8x8x8x320xf32>
311:    %308 = "tfl.conv_3d_transpose"(%306, %307, %305, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<5xi32>, tensor<2x2x2x256x320xf32>, tensor<1x8x8x8x320xf32>, none) -> tensor<1x16x16x16x256xf32>
321:    %318 = "tfl.conv_3d"(%316, %317, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x18x18x18x512xf32>, tensor<3x3x3x512x256xf32>, none) -> tensor<1x16x16x16x256xf32>
341:    %338 = "tfl.conv_3d"(%336, %337, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x18x18x18x256xf32>, tensor<3x3x3x256x256xf32>, none) -> tensor<1x16x16x16x256xf32>
361:    %358 = "tfl.conv_3d_transpose"(%356, %357, %355, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<5xi32>, tensor<2x2x2x128x256xf32>, tensor<1x16x16x16x256xf32>, none) -> tensor<1x32x32x32x128xf32>
371:    %368 = "tfl.conv_3d"(%366, %367, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x34x34x34x256xf32>, tensor<3x3x3x256x128xf32>, none) -> tensor<1x32x32x32x128xf32>
391:    %388 = "tfl.conv_3d"(%386, %387, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x34x34x34x128xf32>, tensor<3x3x3x128x128xf32>, none) -> tensor<1x32x32x32x128xf32>
411:    %408 = "tfl.conv_3d_transpose"(%406, %407, %405, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<5xi32>, tensor<2x2x2x64x128xf32>, tensor<1x32x32x32x128xf32>, none) -> tensor<1x64x64x64x64xf32>
420:    %417 = "tfl.conv_3d"(%415, %416, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x66x66x66x128xf32>, tensor<3x3x3x128x64xf32>, none) -> tensor<1x64x64x64x64xf32>
440:    %437 = "tfl.conv_3d"(%435, %436, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x66x66x66x64xf32>, tensor<3x3x3x64x64xf32>, none) -> tensor<1x64x64x64x64xf32>
460:    %457 = "tfl.conv_3d_transpose"(%455, %456, %454, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<5xi32>, tensor<2x2x2x32x64xf32>, tensor<1x64x64x64x64xf32>, none) -> tensor<1x128x128x128x32xf32>
470:    %467 = "tfl.conv_3d"(%465, %466, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x130x130x130x64xf32>, tensor<3x3x3x64x32xf32>, none) -> tensor<1x128x128x128x32xf32>
490:    %487 = "tfl.conv_3d"(%485, %486, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x130x130x130x32xf32>, tensor<3x3x3x32x32xf32>, none) -> tensor<1x128x128x128x32xf32>
509:    %506 = "tfl.conv_3d"(%504, %505, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x128x128x128x32xf32>, tensor<1x1x1x32x3xf32>, none) -> tensor<1x128x128x128x3xf32>

So in this case, if tfl.conv_3d_transpose were supported, no_value would be too.