STMicroelectronics / stm32ai-modelzoo

AI Model Zoo for STM32 devices

Issue with layers.Input for a UNet model #8

Closed kirilllzaitsev closed 11 months ago

kirilllzaitsev commented 1 year ago

Hi, do you have any examples of how to fit architectures such as UNet, autoencoders, etc. onto an STM32 device? When I try it with the UNet defined below, I get the error: NOT IMPLEMENTED: Order of dimensions of input cannot be interpreted

The issue must lie in the way I define the input, layers.Input(shape=(*img_size, in_channels), name="input"), but I have seen many similar cases that work. Could the skip-connection architecture affect the TFLite conversion and cause this error?
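
One way to check what the TFLite converter actually records for such an input is sketched below, assuming TensorFlow 2.x. The model is a hypothetical stand-in that reuses the same layers.Input definition plus an Add skip connection; it is traced with a fully static NHWC signature (batch fixed to 1), converted, and the input tensor's shape and shape_signature are printed. Fixing the batch dimension is an assumption worth testing, not a confirmed fix for the STM32 error.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

img_size = (28, 28)  # spatial size taken from the model summary
in_channels = 1

# Hypothetical stand-in model: same Input definition plus a skip connection.
inputs = layers.Input(shape=(*img_size, in_channels), name="input")
x = layers.Conv2D(8, 3, padding="same", activation="relu")(inputs)
x = layers.Conv2D(in_channels, 1, padding="same")(x)
outputs = layers.Add()([x, inputs])  # skip connection merged with Add
model = keras.Model(inputs, outputs)

# Trace with a fully static NHWC signature (batch size fixed to 1).
@tf.function(input_signature=[tf.TensorSpec([1, *img_size, in_channels], tf.float32)])
def serve(images):
    return model(images)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [serve.get_concrete_function()], model
)
tflite_bytes = converter.convert()

# Inspect the input tensor declared in the converted flatbuffer.
interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
input_details = interpreter.get_input_details()[0]
print("shape:", input_details["shape"])
print("shape_signature:", input_details["shape_signature"])
```

A -1 in the first position of shape_signature would indicate a dynamic batch dimension, which is what a conversion from an Input declared without batch_size typically yields.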

My model is:

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 28, 28, 1)]  0           []                               

 conv2d (Conv2D)                (None, 14, 14, 16)   32          ['input_1[0][0]']                

 batch_normalization (BatchNorm  (None, 14, 14, 16)  64          ['conv2d[0][0]']                 
 alization)                                                                                       

 activation (Activation)        (None, 14, 14, 16)   0           ['batch_normalization[0][0]']    

 activation_1 (Activation)      (None, 14, 14, 16)   0           ['activation[0][0]']             

 separable_conv2d (SeparableCon  (None, 14, 14, 32)  688         ['activation_1[0][0]']           
 v2D)                                                                                             

 batch_normalization_1 (BatchNo  (None, 14, 14, 32)  128         ['separable_conv2d[0][0]']       
 rmalization)                                                                                     

 activation_2 (Activation)      (None, 14, 14, 32)   0           ['batch_normalization_1[0][0]']  

 separable_conv2d_1 (SeparableC  (None, 14, 14, 32)  1344        ['activation_2[0][0]']           
 onv2D)                                                                                           

 batch_normalization_2 (BatchNo  (None, 14, 14, 32)  128         ['separable_conv2d_1[0][0]']     
 rmalization)                                                                                     

 max_pooling2d (MaxPooling2D)   (None, 7, 7, 32)     0           ['batch_normalization_2[0][0]']  

 conv2d_1 (Conv2D)              (None, 7, 7, 32)     544         ['activation[0][0]']             

 add (Add)                      (None, 7, 7, 32)     0           ['max_pooling2d[0][0]',          
                                                                  'conv2d_1[0][0]']               

 activation_3 (Activation)      (None, 7, 7, 32)     0           ['add[0][0]']                    

 conv2d_transpose (Conv2DTransp  (None, 7, 7, 32)    9248        ['activation_3[0][0]']           
 ose)                                                                                             

 batch_normalization_3 (BatchNo  (None, 7, 7, 32)    128         ['conv2d_transpose[0][0]']       
 rmalization)                                                                                     

 activation_4 (Activation)      (None, 7, 7, 32)     0           ['batch_normalization_3[0][0]']  

 conv2d_transpose_1 (Conv2DTran  (None, 7, 7, 32)    9248        ['activation_4[0][0]']           
 spose)                                                                                           

 batch_normalization_4 (BatchNo  (None, 7, 7, 32)    128         ['conv2d_transpose_1[0][0]']     
 rmalization)                                                                                     

 up_sampling2d_1 (UpSampling2D)  (None, 14, 14, 32)  0           ['add[0][0]']                    

 up_sampling2d (UpSampling2D)   (None, 14, 14, 32)   0           ['batch_normalization_4[0][0]']  

 conv2d_2 (Conv2D)              (None, 14, 14, 32)   1056        ['up_sampling2d_1[0][0]']        

 add_1 (Add)                    (None, 14, 14, 32)   0           ['up_sampling2d[0][0]',          
                                                                  'conv2d_2[0][0]']               

 activation_5 (Activation)      (None, 14, 14, 32)   0           ['add_1[0][0]']                  

 conv2d_transpose_2 (Conv2DTran  (None, 14, 14, 16)  4624        ['activation_5[0][0]']           
 spose)                                                                                           

 batch_normalization_5 (BatchNo  (None, 14, 14, 16)  64          ['conv2d_transpose_2[0][0]']     
 rmalization)                                                                                     

 activation_6 (Activation)      (None, 14, 14, 16)   0           ['batch_normalization_5[0][0]']  

 conv2d_transpose_3 (Conv2DTran  (None, 14, 14, 16)  2320        ['activation_6[0][0]']           
 spose)                                                                                           

 batch_normalization_6 (BatchNo  (None, 14, 14, 16)  64          ['conv2d_transpose_3[0][0]']     
 rmalization)                                                                                     

 up_sampling2d_3 (UpSampling2D)  (None, 28, 28, 32)  0           ['add_1[0][0]']                  

 up_sampling2d_2 (UpSampling2D)  (None, 28, 28, 16)  0           ['batch_normalization_6[0][0]']  

 conv2d_3 (Conv2D)              (None, 28, 28, 16)   528         ['up_sampling2d_3[0][0]']        

 add_2 (Add)                    (None, 28, 28, 16)   0           ['up_sampling2d_2[0][0]',        
                                                                  'conv2d_3[0][0]']               

 conv2d_4 (Conv2D)              (None, 28, 28, 1)    17          ['add_2[0][0]']                  

==================================================================================================
Total params: 30,353
Trainable params: 30,001
Non-trainable params: 352
__________________________________________________________________________________________________
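
The summary appears to follow the pattern of the Keras example "Image segmentation with a U-Net-like architecture", scaled down to 28x28 single-channel inputs. The tf.keras sketch below reproduces the same layer sequence and parameter counts. Assumptions not stated in the issue: ReLU activations, a sigmoid output, and kernel sizes and strides inferred from the parameter counts (for example, conv2d has 32 params for 1 to 16 channels, hence a 1x1 kernel with stride 2). The input is named "input" as in the code above, so it will not appear as input_1, and auto-generated layer names only match in a fresh session.

```python
from tensorflow import keras
from tensorflow.keras import layers


def build_unet(img_size=(28, 28), in_channels=1, out_channels=1):
    inputs = layers.Input(shape=(*img_size, in_channels), name="input")

    # Entry block: 1x1 strided conv, 28x28 -> 14x14, 16 channels.
    x = layers.Conv2D(16, 1, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    previous = x  # kept for the residual branch

    # Downsampling block: two 3x3 separable convs, then max-pool 14x14 -> 7x7.
    for filters in [32]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
        # Residual: project the block input to the same shape, then add.
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(previous)
        x = layers.add([x, residual])
        previous = x

    # Upsampling blocks: 7x7 -> 14x14 (32 ch), then 14x14 -> 28x28 (16 ch).
    for filters in [32, 16]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.UpSampling2D(2)(x)
        # Residual: upsample the block input and match channels with a 1x1 conv.
        residual = layers.UpSampling2D(2)(previous)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])
        previous = x

    # Per-pixel 1x1 output conv (17 params for 16 -> 1 channels).
    outputs = layers.Conv2D(out_channels, 1, activation="sigmoid")(x)
    return keras.Model(inputs, outputs)


model = build_unet()
model.summary()
```

If the reconstruction is right, the printed summary reports 30,353 total parameters, matching the one above.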

LFOSTM commented 1 year ago

Hello, this should not actually happen. Could you please share the model with us (.h5, .tflite, or .onnx)? Thanks!
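
For the .h5 option, a minimal sketch assuming tf.keras with h5py installed: save the model in HDF5 format and reload it to confirm the file round-trips before attaching it. The small model and the file name unet.h5 are hypothetical stand-ins for the actual UNet.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical stand-in; replace with the actual UNet instance.
inputs = layers.Input(shape=(28, 28, 1), name="input")
x = layers.Conv2D(4, 3, padding="same", activation="relu")(inputs)
x = layers.Conv2D(1, 1, padding="same")(x)
outputs = layers.Activation("sigmoid")(layers.Add()([x, inputs]))
model = keras.Model(inputs, outputs)

# The ".h5" suffix selects the HDF5 format.
h5_path = os.path.join(tempfile.mkdtemp(), "unet.h5")
model.save(h5_path)

# Reload and check that predictions match on a random batch.
restored = keras.models.load_model(h5_path)
sample = np.random.rand(2, 28, 28, 1).astype("float32")
same_outputs = np.allclose(
    model.predict(sample, verbose=0), restored.predict(sample, verbose=0), atol=1e-6
)
print("outputs match:", same_outputs)
```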

Shahnawax commented 11 months ago

Closing this issue, as we could not reproduce the problem on our side and received no additional information from the reporter.