paulbauriegel / tensorflow-tools

Python Scripts for working with Tensorflow
MIT License
48 stars, 7 forks

error #1

Closed IrvingShu closed 5 years ago

IrvingShu commented 5 years ago

net input 1 3 224 224
error: (1, 64, 224, 112) (1, 64, 112, 112)
Traceback (most recent call last):
  File "convert-model-to-NWHC.py", line 47, in
    assert out_trans.shape == sess.graph.get_tensor_by_name(n_org.name+':0').shape
AssertionError

paulbauriegel commented 5 years ago

@IrvingShu Any chance you can share the model or its structure? I did not handle all cases in this script, only those that came up in my own models.

IrvingShu commented 5 years ago

This is my model. Can you help me test it? Thank you very much. https://pan.baidu.com/s/1QccXwB5v-wtsgif3etvWVg

paulbauriegel commented 5 years ago

@IrvingShu Yes, I will try my best, but can you upload the model to another file host? Downloading from Baidu is quite complicated without speaking Chinese.

IrvingShu commented 5 years ago

Yes, thank you very much, I just solved it. Version: tf1.8.0.

`
    # Attributes without data_format; NHWC is the default
    atts = {key: n_org.node_def.attr[key] for key in list(n_org.node_def.attr.keys()) if key != 'data_format'}
    # First, add the code below
    if n_org.type in ['Conv2D']:
        # Reorder the strides from NCHW to NHWC
        st = atts['strides'].list.i
        stl = [st[0], st[2], st[3], st[1]]
        atts['strides'] = tf.AttrValue(list=tf.AttrValue.ListValue(i=stl))
`

and modified:

`op = sess.graph.create_op(op_type=n_org.type, inputs=op_inputs, name=n_org.name+'_new', dtypes=[tf.float32], attrs=atts)`
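For anyone reading along, here is a minimal sketch of why reordering the strides would explain the two shapes in the assertion error. It is plain Python, not part of the script. The helper names are hypothetical, and the stride values `[1, 1, 2, 2]`, the 64 output channels, and SAME padding (output size = ceil(input / stride)) are assumptions chosen to match the reported shapes, not values read from the model.

```python
import math


def nchw_to_nhwc(values):
    """Reorder a 4-element per-axis list from NCHW order to NHWC order."""
    n, c, h, w = values
    return [n, h, w, c]


def nhwc_to_nchw(shape):
    """Reorder a 4-element shape from NHWC order back to NCHW order."""
    n, h, w, c = shape
    return (n, c, h, w)


def conv_output_shape_nhwc(input_shape, strides, out_channels):
    """Output shape of a Conv2D in NHWC layout, assuming SAME padding."""
    n, h, w, _ = input_shape
    _, stride_h, stride_w, _ = strides
    return (n, math.ceil(h / stride_h), math.ceil(w / stride_w), out_channels)


# Assumed strides on the original NCHW node: stride 2 on both H and W.
strides_nchw = [1, 1, 2, 2]
# The reported "net input 1 3 224 224", expressed in NHWC order.
input_nhwc = (1, 224, 224, 3)

# Reusing the NCHW list unchanged puts the stride of 2 on W and C instead of H and W.
wrong = nhwc_to_nchw(conv_output_shape_nhwc(input_nhwc, strides_nchw, 64))
# Reordering first, as in the fix above, puts the stride of 2 back on H and W.
right = nhwc_to_nchw(conv_output_shape_nhwc(input_nhwc, nchw_to_nhwc(strides_nchw), 64))

print(wrong)  # (1, 64, 224, 112)
print(right)  # (1, 64, 112, 112)
```

Under these assumptions, the unpermuted strides give the first shape from the error message and the permuted strides give the second, which is the shape the original graph produces.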

paulbauriegel commented 5 years ago

Glad to hear that. Anything else not working as expected or can we close the issue?

IrvingShu commented 5 years ago

OK, thank you very much. The PyTorch -> ONNX -> TF (CPU) conversion completed successfully!

codinghamster12 commented 3 years ago

Hey, I'm getting the same error, but I can't resolve it by making the changes you mentioned. Can you please help me solve it?

(1, 64, 112, 224) (1, 64, 112, 112)
Traceback (most recent call last):
  File "model-to-NHWC.py", line 58, in
    assert out_trans.shape == sess.graph.get_tensor_by_name(n_org.name+':0').shape
AssertionError