Open mariecwhite opened 1 year ago
Wrt broadcast_to: they all seem to be feeding into broadcasting functions (I think these broadcasts can all be fused into these mul/add ops):
1 %24 = "tfl.mul"(%8, %23) {fused_activation_function = "NONE"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
2 %26 = "tfl.add"(%24, %25) {fused_activation_function = "RELU"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
3 %45 = "tfl.mul"(%32, %44) {fused_activation_function = "NONE"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
4 %47 = "tfl.add"(%45, %46) {fused_activation_function = "RELU"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
5 %67 = "tfl.mul"(%52, %66) {fused_activation_function = "NONE"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
6 %69 = "tfl.add"(%67, %68) {fused_activation_function = "RELU"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
7 %87 = "tfl.mul"(%74, %86) {fused_activation_function = "NONE"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
8 %89 = "tfl.add"(%87, %88) {fused_activation_function = "RELU"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
9 %109 = "tfl.mul"(%94, %108) {fused_activation_function = "NONE"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
10 %111 = "tfl.add"(%109, %110) {fused_activation_function = "RELU"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
11 %129 = "tfl.mul"(%116, %128) {fused_activation_function = "NONE"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
12 %131 = "tfl.add"(%129, %130) {fused_activation_function = "RELU"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
13 %151 = "tfl.mul"(%136, %150) {fused_activation_function = "NONE"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
14 %153 = "tfl.add"(%151, %152) {fused_activation_function = "RELU"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
15 %171 = "tfl.mul"(%158, %170) {fused_activation_function = "NONE"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
16 %173 = "tfl.add"(%171, %172) {fused_activation_function = "RELU"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
17 %193 = "tfl.mul"(%178, %192) {fused_activation_function = "NONE"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
18 %195 = "tfl.add"(%193, %194) {fused_activation_function = "RELU"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
19 %213 = "tfl.mul"(%200, %212) {fused_activation_function = "NONE"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
20 %215 = "tfl.add"(%213, %214) {fused_activation_function = "RELU"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
21 %234 = "tfl.mul"(%220, %233) {fused_activation_function = "NONE"} : (tensor<1x320x4x4x4xf32>, tensor<1x320x4x4x4xf32>) -> tensor<1x320x4x4x4xf32>
22 %236 = "tfl.add"(%234, %235) {fused_activation_function = "RELU"} : (tensor<1x320x4x4x4xf32>, tensor<1x320x4x4x4xf32>) -> tensor<1x320x4x4x4xf32>
23 %254 = "tfl.mul"(%241, %253) {fused_activation_function = "NONE"} : (tensor<1x320x4x4x4xf32>, tensor<1x320x4x4x4xf32>) -> tensor<1x320x4x4x4xf32>
24 %256 = "tfl.add"(%254, %255) {fused_activation_function = "RELU"} : (tensor<1x320x4x4x4xf32>, tensor<1x320x4x4x4xf32>) -> tensor<1x320x4x4x4xf32>
25 %282 = "tfl.mul"(%269, %281) {fused_activation_function = "NONE"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
26 %284 = "tfl.add"(%282, %283) {fused_activation_function = "RELU"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
27 %302 = "tfl.mul"(%289, %301) {fused_activation_function = "NONE"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
28 %304 = "tfl.add"(%302, %303) {fused_activation_function = "RELU"} : (tensor<1x320x8x8x8xf32>, tensor<1x320x8x8x8xf32>) -> tensor<1x320x8x8x8xf32>
29 %312 = "tfl.add"(%308, %311) {fused_activation_function = "NONE"} : (tensor<1x16x16x16x256xf32>, tensor<1x16x16x16x256xf32>) -> tensor<1x16x16x16x256xf32>
30 %332 = "tfl.mul"(%319, %331) {fused_activation_function = "NONE"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
31 %334 = "tfl.add"(%332, %333) {fused_activation_function = "RELU"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
32 %352 = "tfl.mul"(%339, %351) {fused_activation_function = "NONE"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
33 %354 = "tfl.add"(%352, %353) {fused_activation_function = "RELU"} : (tensor<1x256x16x16x16xf32>, tensor<1x256x16x16x16xf32>) -> tensor<1x256x16x16x16xf32>
34 %362 = "tfl.add"(%358, %361) {fused_activation_function = "NONE"} : (tensor<1x32x32x32x128xf32>, tensor<1x32x32x32x128xf32>) -> tensor<1x32x32x32x128xf32>
35 %382 = "tfl.mul"(%369, %381) {fused_activation_function = "NONE"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
36 %384 = "tfl.add"(%382, %383) {fused_activation_function = "RELU"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
37 %402 = "tfl.mul"(%389, %401) {fused_activation_function = "NONE"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
38 %404 = "tfl.add"(%402, %403) {fused_activation_function = "RELU"} : (tensor<1x128x32x32x32xf32>, tensor<1x128x32x32x32xf32>) -> tensor<1x128x32x32x32xf32>
39 %411 = "tfl.add"(%408, %410) {fused_activation_function = "NONE"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
40 %431 = "tfl.mul"(%418, %430) {fused_activation_function = "NONE"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
41 %433 = "tfl.add"(%431, %432) {fused_activation_function = "RELU"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
42 %451 = "tfl.mul"(%438, %450) {fused_activation_function = "NONE"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
43 %453 = "tfl.add"(%451, %452) {fused_activation_function = "RELU"} : (tensor<1x64x64x64x64xf32>, tensor<1x64x64x64x64xf32>) -> tensor<1x64x64x64x64xf32>
44 %461 = "tfl.add"(%457, %460) {fused_activation_function = "NONE"} : (tensor<1x128x128x128x32xf32>, tensor<1x128x128x128x32xf32>) -> tensor<1x128x128x128x32xf32>
45 %481 = "tfl.mul"(%468, %480) {fused_activation_function = "NONE"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
46 %483 = "tfl.add"(%481, %482) {fused_activation_function = "RELU"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
47 %501 = "tfl.mul"(%488, %500) {fused_activation_function = "NONE"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
48 %503 = "tfl.add"(%501, %502) {fused_activation_function = "RELU"} : (tensor<1x32x128x128x128xf32>, tensor<1x32x128x128x128xf32>) -> tensor<1x32x128x128x128xf32>
49 %510 = "tfl.add"(%506, %509) {fused_activation_function = "NONE"} : (tensor<1x128x128x128x3xf32>, tensor<1x128x128x128x3xf32>) -> tensor<1x128x128x128x3xf32>
Wrt no_value:
9: %6 = "tfl.conv_3d"(%3, %4, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x130x130x130x1xf32>, tensor<3x3x3x1x32xf32>, none) -> tensor<1x128x128x128x32xf32>
34: %31 = "tfl.conv_3d"(%29, %30, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x130x130x130x32xf32>, tensor<3x3x3x32x32xf32>, none) -> tensor<1x128x128x128x32xf32>
54: %51 = "tfl.conv_3d"(%49, %50, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x130x130x130x32xf32>, tensor<3x3x3x32x64xf32>, none) -> tensor<1x64x64x64x64xf32>
76: %73 = "tfl.conv_3d"(%71, %72, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x66x66x66x64xf32>, tensor<3x3x3x64x64xf32>, none) -> tensor<1x64x64x64x64xf32>
96: %93 = "tfl.conv_3d"(%91, %92, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x66x66x66x64xf32>, tensor<3x3x3x64x128xf32>, none) -> tensor<1x32x32x32x128xf32>
118: %115 = "tfl.conv_3d"(%113, %114, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x34x34x34x128xf32>, tensor<3x3x3x128x128xf32>, none) -> tensor<1x32x32x32x128xf32>
138: %135 = "tfl.conv_3d"(%133, %134, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x34x34x34x128xf32>, tensor<3x3x3x128x256xf32>, none) -> tensor<1x16x16x16x256xf32>
160: %157 = "tfl.conv_3d"(%155, %156, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x18x18x18x256xf32>, tensor<3x3x3x256x256xf32>, none) -> tensor<1x16x16x16x256xf32>
180: %177 = "tfl.conv_3d"(%175, %176, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x18x18x18x256xf32>, tensor<3x3x3x256x320xf32>, none) -> tensor<1x8x8x8x320xf32>
202: %199 = "tfl.conv_3d"(%197, %198, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x10x10x10x320xf32>, tensor<3x3x3x320x320xf32>, none) -> tensor<1x8x8x8x320xf32>
222: %219 = "tfl.conv_3d"(%217, %218, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x10x10x10x320xf32>, tensor<3x3x3x320x320xf32>, none) -> tensor<1x4x4x4x320xf32>
243: %240 = "tfl.conv_3d"(%238, %239, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x6x6x6x320xf32>, tensor<3x3x3x320x320xf32>, none) -> tensor<1x4x4x4x320xf32>
263: %260 = "tfl.conv_3d_transpose"(%258, %259, %257, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<5xi32>, tensor<2x2x2x320x320xf32>, tensor<1x4x4x4x320xf32>, none) -> tensor<1x8x8x8x320xf32>
271: %268 = "tfl.conv_3d"(%266, %267, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x10x10x10x640xf32>, tensor<3x3x3x640x320xf32>, none) -> tensor<1x8x8x8x320xf32>
291: %288 = "tfl.conv_3d"(%286, %287, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x10x10x10x320xf32>, tensor<3x3x3x320x320xf32>, none) -> tensor<1x8x8x8x320xf32>
311: %308 = "tfl.conv_3d_transpose"(%306, %307, %305, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<5xi32>, tensor<2x2x2x256x320xf32>, tensor<1x8x8x8x320xf32>, none) -> tensor<1x16x16x16x256xf32>
321: %318 = "tfl.conv_3d"(%316, %317, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x18x18x18x512xf32>, tensor<3x3x3x512x256xf32>, none) -> tensor<1x16x16x16x256xf32>
341: %338 = "tfl.conv_3d"(%336, %337, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x18x18x18x256xf32>, tensor<3x3x3x256x256xf32>, none) -> tensor<1x16x16x16x256xf32>
361: %358 = "tfl.conv_3d_transpose"(%356, %357, %355, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<5xi32>, tensor<2x2x2x128x256xf32>, tensor<1x16x16x16x256xf32>, none) -> tensor<1x32x32x32x128xf32>
371: %368 = "tfl.conv_3d"(%366, %367, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x34x34x34x256xf32>, tensor<3x3x3x256x128xf32>, none) -> tensor<1x32x32x32x128xf32>
391: %388 = "tfl.conv_3d"(%386, %387, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x34x34x34x128xf32>, tensor<3x3x3x128x128xf32>, none) -> tensor<1x32x32x32x128xf32>
411: %408 = "tfl.conv_3d_transpose"(%406, %407, %405, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<5xi32>, tensor<2x2x2x64x128xf32>, tensor<1x32x32x32x128xf32>, none) -> tensor<1x64x64x64x64xf32>
420: %417 = "tfl.conv_3d"(%415, %416, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x66x66x66x128xf32>, tensor<3x3x3x128x64xf32>, none) -> tensor<1x64x64x64x64xf32>
440: %437 = "tfl.conv_3d"(%435, %436, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x66x66x66x64xf32>, tensor<3x3x3x64x64xf32>, none) -> tensor<1x64x64x64x64xf32>
460: %457 = "tfl.conv_3d_transpose"(%455, %456, %454, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<5xi32>, tensor<2x2x2x32x64xf32>, tensor<1x64x64x64x64xf32>, none) -> tensor<1x128x128x128x32xf32>
470: %467 = "tfl.conv_3d"(%465, %466, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x130x130x130x64xf32>, tensor<3x3x3x64x32xf32>, none) -> tensor<1x128x128x128x32xf32>
490: %487 = "tfl.conv_3d"(%485, %486, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x130x130x130x32xf32>, tensor<3x3x3x32x32xf32>, none) -> tensor<1x128x128x128x32xf32>
509: %506 = "tfl.conv_3d"(%504, %505, %5) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x128x128x128x32xf32>, tensor<1x1x1x32x3xf32>, none) -> tensor<1x128x128x128x3xf32>
So in this case, if tfl.conv_3d_transpose were supported, no_value would be too.
When importing 3dunet_kits19_1x1x128x128x128.tflite to MLIR using import-iree-tflite, I get the error shown above. This model is part of the MLPerf Inference suite and taken from: https://github.com/mlcommons/inference/tree/master/vision/medical_imaging/3d-unet-kits19