IrvingShu closed this issue 5 years ago.
@IrvingShu Is there any chance you could share the model or its structure? I did not handle every case in this script, only the ones that occurred in my own models.
This is my model. Can you help me test it? Thank you very much. https://pan.baidu.com/s/1QccXwB5v-wtsgif3etvWVg
@IrvingShu Yes, I will try my best, but could you upload the model to another file host? Downloading from Baidu is quite complicated if you don't speak Chinese.
Yes, thank you very much, I just solved it (version: tf 1.8.0). First I added the stride permutation below, right after the attribute dictionary is built:

```python
# Attributes without data_format; NHWC is the TensorFlow default
atts = {key: n_org.node_def.attr[key] for key in list(n_org.node_def.attr.keys()) if key != 'data_format'}
# First add the code below:
if n_org.type in ['Conv2D']:
    st = atts['strides'].list.i
    # Reorder strides from NCHW [N, C, H, W] to NHWC [N, H, W, C]
    stl = [st[0], st[2], st[3], st[1]]
    atts['strides'] = tf.AttrValue(list=tf.AttrValue.ListValue(i=stl))
```

and modified the `create_op` call:

```python
op = sess.graph.create_op(op_type=n_org.type, inputs=op_inputs, name=n_org.name + '_new', dtypes=[tf.float32], attrs=atts)
```
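The core of the fix is the index reordering `[st[0], st[2], st[3], st[1]]`. Below is a minimal, TensorFlow-free sketch of that mapping; `nchw_to_nhwc` is a hypothetical helper for illustration, not part of the conversion script discussed in this thread.

```python
# Minimal sketch (hypothetical helper): reorder a 4-element, layout-dependent
# attribute list such as Conv2D 'strides' from NCHW order to NHWC order.
def nchw_to_nhwc(values):
    """Map [N, C, H, W] -> [N, H, W, C], i.e. indices [0, 2, 3, 1]."""
    if len(values) != 4:
        raise ValueError("expected 4 values, got %d" % len(values))
    n, c, h, w = values
    return [n, h, w, c]


# Stride 2 on both spatial axes: [1, 1, 2, 2] in NCHW becomes [1, 2, 2, 1] in NHWC.
print(nchw_to_nhwc([1, 1, 2, 2]))  # [1, 2, 2, 1]
```

The fix above only handles Conv2D `strides`. The assumption here is that other 4-element attributes whose order follows `data_format` (for example Conv2D `dilations`, or `ksize`/`strides` on pooling ops) would need the same reordering if they are not all ones.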
Glad to hear that. Is anything else not working as expected, or can we close the issue?
OK, thank you very much. PyTorch -> ONNX -> TF (CPU) conversion completed successfully!
Hey, I'm getting the same error, but I can't resolve it by making the changes you mentioned. Could you please help me solve it?
```
(1, 64, 112, 224)
(1, 64, 112, 112)
Traceback (most recent call last):
  File "model-to-NHWC.py", line 58, in

net input 1 3 224 224
error: (1, 64, 224, 112) (1, 64, 112, 112)
Traceback (most recent call last):
  File "convert-model-to-NWHC.py", line 47, in
    assert out_trans.shape == sess.graph.get_tensor_by_name(n_org.name+':0').shape
AssertionError
```
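The mismatch in the output above (112 vs. 224 on one spatial axis) is what one would expect if a stride of 2 reached only one of the two spatial axes. The thread does not say which layer fails. The sketch below assumes it is the first convolution on the 1x3x224x224 input, with 64 output channels, stride 2 and `SAME`-style padding; `same_conv_out_shape_nchw` is a hypothetical helper.

```python
import math


def same_conv_out_shape_nchw(in_shape, out_channels, strides_nchw):
    """Output shape (NCHW) of a 2-D convolution with TensorFlow 'SAME' padding.

    With 'SAME' padding the output spatial size is ceil(in / stride),
    independent of the kernel size.
    """
    n, _, h, w = in_shape
    _, _, sh, sw = strides_nchw
    return (n, out_channels, math.ceil(h / sh), math.ceil(w / sw))


inp = (1, 3, 224, 224)  # "net input 1 3 224 224" from the log above
print(same_conv_out_shape_nchw(inp, 64, [1, 1, 2, 2]))  # (1, 64, 112, 112): expected
print(same_conv_out_shape_nchw(inp, 64, [1, 1, 2, 1]))  # (1, 64, 112, 224): W not strided
print(same_conv_out_shape_nchw(inp, 64, [1, 1, 1, 2]))  # (1, 64, 224, 112): H not strided
```

Under these assumptions, both mismatched shapes in the log correspond to one spatial axis being convolved with stride 1. That points at how the `strides` attribute was reordered for that layer.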