onnx / onnx-tensorflow

Tensorflow Backend for ONNX

Train onnx model from another framework #939

Open letruongthanh3698 opened 3 years ago

letruongthanh3698 commented 3 years ago

Hi everyone,

I am testing `example/train_onnx_model.py` on Google Colab with an ONNX model generated by MATLAB's Deep Learning Toolbox, and it fails with the following error:

```
==> Train the model..

ValueError                                Traceback (most recent call last)
in ()
      1 if __name__ == "__main__":
----> 2     train_onnx_model()
      3     run_onnx_model(trained_onnx_model)

2 frames
in train_onnx_model()
     42         feed_dict[training_flag_placeholder] = True
     43         loss, accuracy, _ = sess.run([loss_op, eval_op, opt_op],
---> 44                                      feed_dict=feed_dict)
     45         if (step % 100) == 0:
     46             print('Epoch {}, train step {}, loss:{}, accuracy:{}'.format(

/usr/local/lib/python3.7/dist-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    966     try:
    967       result = self._run(None, fetches, feed_dict, options_ptr,
--> 968                          run_metadata_ptr)
    969       if run_metadata:
    970         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1165                 'Cannot feed value of shape %r for Tensor %r, '
   1166                 'which has shape %r' %
-> 1167                 (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
   1168           if not self.graph.is_feedable(subfeed_t):
   1169             raise ValueError('Tensor %s may not be fed.' % subfeed_t)

ValueError: Cannot feed value of shape (32, 28, 28, 1) for Tensor 'imageinput_Mean/Read/ReadVariableOp:0', which has shape '(1, 1, 1, 1)'
```

Does the model have to be trained in TensorFlow first in order to be retrained with TensorFlow, or can it be trained with a different framework? Can anyone help me? Thank you.
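The `ValueError` above is raised by TensorFlow's `Session.run`, which checks each fed NumPy array against the static shape of the tensor it is fed into. That check can be sketched in plain Python as below; `is_compatible` is a hypothetical helper written only for illustration, a simplification of the rule behind `tf.TensorShape.is_compatible_with`, not TensorFlow's own code:

```python
from typing import Optional, Sequence

# None stands for an unknown dimension, like a placeholder dimension declared as None.
Shape = Sequence[Optional[int]]


def is_compatible(declared: Shape, fed: Sequence[int]) -> bool:
    """Simplified feed check: same rank, and every known declared dim equals the fed dim.

    Tensors of unknown rank are not modelled here.
    """
    if len(declared) != len(fed):
        return False
    return all(d is None or d == f for d, f in zip(declared, fed))


# Shapes from the traceback: a fixed (1, 1, 1, 1) tensor rejects a batch of 32 images.
print(is_compatible((1, 1, 1, 1), (32, 28, 28, 1)))       # False
# A tensor with an unknown batch dimension and matching image dims would accept it.
print(is_compatible((None, 28, 28, 1), (32, 28, 28, 1)))  # True
```

So the feed fails not because of which framework produced the model, but because the tensor used as the `feed_dict` key has a fixed shape that does not match the batch.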
chinhuang007 commented 3 years ago

I don't think the model has to be initially trained in TensorFlow. @chudegao, maybe you can take a look. Thanks.

chudegao commented 3 years ago

An ONNX model can be exported from other frameworks; I have tried ONNX models exported from both PyTorch and TensorFlow. Just make sure the model's input and your `feed_dict` are consistent. From the error message, I guess the ONNX model's input shape is [1,1,1,1] while you are trying to feed data with shape [32,28,28,1].
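One way to check that consistency is to list the model's declared inputs and their shapes before building the `feed_dict`. In ONNX, an initializer (a stored weight or constant, such as a normalization mean) may also be listed in `graph.input`, so the first listed input is not necessarily the data input. That this is what happened here is only an assumption suggested by the tensor name `imageinput_Mean` in the error. The helpers `onnx_input_shapes` and `real_inputs` below are hypothetical, a minimal sketch rather than part of onnx-tensorflow:

```python
from typing import Dict, List, Optional, Sequence

Shape = List[Optional[int]]


def onnx_input_shapes(model) -> Dict[str, Shape]:
    """Declared shape of every entry in model.graph.input of an onnx.ModelProto.

    Symbolic (dim_param) or unset dimensions are returned as None.
    """
    shapes = {}
    for vi in model.graph.input:
        dims = vi.type.tensor_type.shape.dim
        shapes[vi.name] = [d.dim_value if d.HasField("dim_value") else None for d in dims]
    return shapes


def real_inputs(graph_inputs: Dict[str, Shape],
                initializer_names: Sequence[str]) -> Dict[str, Shape]:
    """Keep only graph inputs that are not initializers (stored weights or constants)."""
    init = set(initializer_names)
    return {name: shape for name, shape in graph_inputs.items() if name not in init}


# With a real file (needs the onnx package; the path is hypothetical):
#   import onnx
#   model = onnx.load("model.onnx")
#   init_names = [t.name for t in model.graph.initializer]
#   print(real_inputs(onnx_input_shapes(model), init_names))

# Hypothetical names and shapes, only to show the filtering step.
declared = {"imageinput": [None, 28, 28, 1], "imageinput_Mean": [1, 1, 1, 1]}
print(real_inputs(declared, ["imageinput_Mean"]))  # {'imageinput': [None, 28, 28, 1]}
```

The remaining entries are the inputs that actually expect data; the batch passed through `feed_dict` should target one of them, with a shape compatible with its declared dimensions.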