microsoft / MMdnn

MMdnn is a set of tools to help users inter-operate among different deep learning frameworks, e.g. model conversion and visualization. Convert models between Caffe, Keras, MXNet, TensorFlow, CNTK, PyTorch, ONNX, and CoreML.

Converted ResNet-50 from MXNet is predicting labels incorrectly #184

Closed by LiYingwei 6 years ago

LiYingwei commented 6 years ago

Platform: Ubuntu 14.04

Python version: 2.7

TensorFlow version: 1.4.0 with GPU

Pre-trained model path: downloaded using mmdownload

Running scripts:

mkdir checkpoint
mmdownload -f mxnet -n imagenet1k-resnet-50 -o ./
mmtoir -f mxnet -n resnet-50-symbol.json -w resnet-50-0000.params -d resnet50 --inputShape 3 299 299
mmtocode -f tensorflow --IRModelPath resnet50.pb --IRWeightPath resnet50.npy --dstModelPath mx_resnet50.py
python -m mmdnn.conversion.examples.tensorflow.imagenet_test -n mx_resnet50.py -w resnet50.npy --dump checkpoint/mx_resnet50.ckpt

I successfully generated mx_resnet50.py:

import tensorflow as tf

__weights_dict = dict()

is_train = False

def load_weights(weight_file):
    import numpy as np

    if weight_file == None:
        return

    try:
        weights_dict = np.load(weight_file).item()
    except:
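        # Fallback for Python 3, where loading a Python 2 pickle requires encoding='bytes'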
        weights_dict = np.load(weight_file, encoding='bytes').item()

    return weights_dict

def KitModel(weight_file = None):
    global __weights_dict
    __weights_dict = load_weights(weight_file)

    data            = tf.placeholder(tf.float32, shape = (None, 299, 299, 3), name = 'data')
    bn_data         = batch_normalization(data, variance_epsilon=1.99999994948e-05, name='bn_data')
    conv0_pad       = tf.pad(bn_data, paddings = [[0L, 0L], [3L, 3L], [3L, 3L], [0L, 0L]])
    conv0           = convolution(conv0_pad, group=1, strides=[2, 2], padding='VALID', name='conv0')
    bn0             = batch_normalization(conv0, variance_epsilon=1.99999994948e-05, name='bn0')
    relu0           = tf.nn.relu(bn0, name = 'relu0')
    pooling0_pad    = tf.pad(relu0, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]], constant_values=float('-Inf'))
    pooling0        = tf.nn.max_pool(pooling0_pad, [1, 3, 3, 1], [1, 2, 2, 1], padding='VALID', name='pooling0')
    stage1_unit1_bn1 = batch_normalization(pooling0, variance_epsilon=1.99999994948e-05, name='stage1_unit1_bn1')
    stage1_unit1_relu1 = tf.nn.relu(stage1_unit1_bn1, name = 'stage1_unit1_relu1')
    stage1_unit1_conv1 = convolution(stage1_unit1_relu1, group=1, strides=[1, 1], padding='VALID', name='stage1_unit1_conv1')
    stage1_unit1_sc = convolution(stage1_unit1_relu1, group=1, strides=[1, 1], padding='VALID', name='stage1_unit1_sc')
    stage1_unit1_bn2 = batch_normalization(stage1_unit1_conv1, variance_epsilon=1.99999994948e-05, name='stage1_unit1_bn2')
    stage1_unit1_relu2 = tf.nn.relu(stage1_unit1_bn2, name = 'stage1_unit1_relu2')
    stage1_unit1_conv2_pad = tf.pad(stage1_unit1_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage1_unit1_conv2 = convolution(stage1_unit1_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage1_unit1_conv2')
    stage1_unit1_bn3 = batch_normalization(stage1_unit1_conv2, variance_epsilon=1.99999994948e-05, name='stage1_unit1_bn3')
    stage1_unit1_relu3 = tf.nn.relu(stage1_unit1_bn3, name = 'stage1_unit1_relu3')
    stage1_unit1_conv3 = convolution(stage1_unit1_relu3, group=1, strides=[1, 1], padding='VALID', name='stage1_unit1_conv3')
    plus0           = stage1_unit1_conv3 + stage1_unit1_sc
    stage1_unit2_bn1 = batch_normalization(plus0, variance_epsilon=1.99999994948e-05, name='stage1_unit2_bn1')
    stage1_unit2_relu1 = tf.nn.relu(stage1_unit2_bn1, name = 'stage1_unit2_relu1')
    stage1_unit2_conv1 = convolution(stage1_unit2_relu1, group=1, strides=[1, 1], padding='VALID', name='stage1_unit2_conv1')
    stage1_unit2_bn2 = batch_normalization(stage1_unit2_conv1, variance_epsilon=1.99999994948e-05, name='stage1_unit2_bn2')
    stage1_unit2_relu2 = tf.nn.relu(stage1_unit2_bn2, name = 'stage1_unit2_relu2')
    stage1_unit2_conv2_pad = tf.pad(stage1_unit2_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage1_unit2_conv2 = convolution(stage1_unit2_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage1_unit2_conv2')
    stage1_unit2_bn3 = batch_normalization(stage1_unit2_conv2, variance_epsilon=1.99999994948e-05, name='stage1_unit2_bn3')
    stage1_unit2_relu3 = tf.nn.relu(stage1_unit2_bn3, name = 'stage1_unit2_relu3')
    stage1_unit2_conv3 = convolution(stage1_unit2_relu3, group=1, strides=[1, 1], padding='VALID', name='stage1_unit2_conv3')
    plus1           = stage1_unit2_conv3 + plus0
    stage1_unit3_bn1 = batch_normalization(plus1, variance_epsilon=1.99999994948e-05, name='stage1_unit3_bn1')
    stage1_unit3_relu1 = tf.nn.relu(stage1_unit3_bn1, name = 'stage1_unit3_relu1')
    stage1_unit3_conv1 = convolution(stage1_unit3_relu1, group=1, strides=[1, 1], padding='VALID', name='stage1_unit3_conv1')
    stage1_unit3_bn2 = batch_normalization(stage1_unit3_conv1, variance_epsilon=1.99999994948e-05, name='stage1_unit3_bn2')
    stage1_unit3_relu2 = tf.nn.relu(stage1_unit3_bn2, name = 'stage1_unit3_relu2')
    stage1_unit3_conv2_pad = tf.pad(stage1_unit3_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage1_unit3_conv2 = convolution(stage1_unit3_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage1_unit3_conv2')
    stage1_unit3_bn3 = batch_normalization(stage1_unit3_conv2, variance_epsilon=1.99999994948e-05, name='stage1_unit3_bn3')
    stage1_unit3_relu3 = tf.nn.relu(stage1_unit3_bn3, name = 'stage1_unit3_relu3')
    stage1_unit3_conv3 = convolution(stage1_unit3_relu3, group=1, strides=[1, 1], padding='VALID', name='stage1_unit3_conv3')
    plus2           = stage1_unit3_conv3 + plus1
    stage2_unit1_bn1 = batch_normalization(plus2, variance_epsilon=1.99999994948e-05, name='stage2_unit1_bn1')
    stage2_unit1_relu1 = tf.nn.relu(stage2_unit1_bn1, name = 'stage2_unit1_relu1')
    stage2_unit1_conv1 = convolution(stage2_unit1_relu1, group=1, strides=[1, 1], padding='VALID', name='stage2_unit1_conv1')
    stage2_unit1_sc = convolution(stage2_unit1_relu1, group=1, strides=[2, 2], padding='VALID', name='stage2_unit1_sc')
    stage2_unit1_bn2 = batch_normalization(stage2_unit1_conv1, variance_epsilon=1.99999994948e-05, name='stage2_unit1_bn2')
    stage2_unit1_relu2 = tf.nn.relu(stage2_unit1_bn2, name = 'stage2_unit1_relu2')
    stage2_unit1_conv2_pad = tf.pad(stage2_unit1_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage2_unit1_conv2 = convolution(stage2_unit1_conv2_pad, group=1, strides=[2, 2], padding='VALID', name='stage2_unit1_conv2')
    stage2_unit1_bn3 = batch_normalization(stage2_unit1_conv2, variance_epsilon=1.99999994948e-05, name='stage2_unit1_bn3')
    stage2_unit1_relu3 = tf.nn.relu(stage2_unit1_bn3, name = 'stage2_unit1_relu3')
    stage2_unit1_conv3 = convolution(stage2_unit1_relu3, group=1, strides=[1, 1], padding='VALID', name='stage2_unit1_conv3')
    plus3           = stage2_unit1_conv3 + stage2_unit1_sc
    stage2_unit2_bn1 = batch_normalization(plus3, variance_epsilon=1.99999994948e-05, name='stage2_unit2_bn1')
    stage2_unit2_relu1 = tf.nn.relu(stage2_unit2_bn1, name = 'stage2_unit2_relu1')
    stage2_unit2_conv1 = convolution(stage2_unit2_relu1, group=1, strides=[1, 1], padding='VALID', name='stage2_unit2_conv1')
    stage2_unit2_bn2 = batch_normalization(stage2_unit2_conv1, variance_epsilon=1.99999994948e-05, name='stage2_unit2_bn2')
    stage2_unit2_relu2 = tf.nn.relu(stage2_unit2_bn2, name = 'stage2_unit2_relu2')
    stage2_unit2_conv2_pad = tf.pad(stage2_unit2_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage2_unit2_conv2 = convolution(stage2_unit2_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage2_unit2_conv2')
    stage2_unit2_bn3 = batch_normalization(stage2_unit2_conv2, variance_epsilon=1.99999994948e-05, name='stage2_unit2_bn3')
    stage2_unit2_relu3 = tf.nn.relu(stage2_unit2_bn3, name = 'stage2_unit2_relu3')
    stage2_unit2_conv3 = convolution(stage2_unit2_relu3, group=1, strides=[1, 1], padding='VALID', name='stage2_unit2_conv3')
    plus4           = stage2_unit2_conv3 + plus3
    stage2_unit3_bn1 = batch_normalization(plus4, variance_epsilon=1.99999994948e-05, name='stage2_unit3_bn1')
    stage2_unit3_relu1 = tf.nn.relu(stage2_unit3_bn1, name = 'stage2_unit3_relu1')
    stage2_unit3_conv1 = convolution(stage2_unit3_relu1, group=1, strides=[1, 1], padding='VALID', name='stage2_unit3_conv1')
    stage2_unit3_bn2 = batch_normalization(stage2_unit3_conv1, variance_epsilon=1.99999994948e-05, name='stage2_unit3_bn2')
    stage2_unit3_relu2 = tf.nn.relu(stage2_unit3_bn2, name = 'stage2_unit3_relu2')
    stage2_unit3_conv2_pad = tf.pad(stage2_unit3_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage2_unit3_conv2 = convolution(stage2_unit3_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage2_unit3_conv2')
    stage2_unit3_bn3 = batch_normalization(stage2_unit3_conv2, variance_epsilon=1.99999994948e-05, name='stage2_unit3_bn3')
    stage2_unit3_relu3 = tf.nn.relu(stage2_unit3_bn3, name = 'stage2_unit3_relu3')
    stage2_unit3_conv3 = convolution(stage2_unit3_relu3, group=1, strides=[1, 1], padding='VALID', name='stage2_unit3_conv3')
    plus5           = stage2_unit3_conv3 + plus4
    stage2_unit4_bn1 = batch_normalization(plus5, variance_epsilon=1.99999994948e-05, name='stage2_unit4_bn1')
    stage2_unit4_relu1 = tf.nn.relu(stage2_unit4_bn1, name = 'stage2_unit4_relu1')
    stage2_unit4_conv1 = convolution(stage2_unit4_relu1, group=1, strides=[1, 1], padding='VALID', name='stage2_unit4_conv1')
    stage2_unit4_bn2 = batch_normalization(stage2_unit4_conv1, variance_epsilon=1.99999994948e-05, name='stage2_unit4_bn2')
    stage2_unit4_relu2 = tf.nn.relu(stage2_unit4_bn2, name = 'stage2_unit4_relu2')
    stage2_unit4_conv2_pad = tf.pad(stage2_unit4_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage2_unit4_conv2 = convolution(stage2_unit4_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage2_unit4_conv2')
    stage2_unit4_bn3 = batch_normalization(stage2_unit4_conv2, variance_epsilon=1.99999994948e-05, name='stage2_unit4_bn3')
    stage2_unit4_relu3 = tf.nn.relu(stage2_unit4_bn3, name = 'stage2_unit4_relu3')
    stage2_unit4_conv3 = convolution(stage2_unit4_relu3, group=1, strides=[1, 1], padding='VALID', name='stage2_unit4_conv3')
    plus6           = stage2_unit4_conv3 + plus5
    stage3_unit1_bn1 = batch_normalization(plus6, variance_epsilon=1.99999994948e-05, name='stage3_unit1_bn1')
    stage3_unit1_relu1 = tf.nn.relu(stage3_unit1_bn1, name = 'stage3_unit1_relu1')
    stage3_unit1_conv1 = convolution(stage3_unit1_relu1, group=1, strides=[1, 1], padding='VALID', name='stage3_unit1_conv1')
    stage3_unit1_sc = convolution(stage3_unit1_relu1, group=1, strides=[2, 2], padding='VALID', name='stage3_unit1_sc')
    stage3_unit1_bn2 = batch_normalization(stage3_unit1_conv1, variance_epsilon=1.99999994948e-05, name='stage3_unit1_bn2')
    stage3_unit1_relu2 = tf.nn.relu(stage3_unit1_bn2, name = 'stage3_unit1_relu2')
    stage3_unit1_conv2_pad = tf.pad(stage3_unit1_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage3_unit1_conv2 = convolution(stage3_unit1_conv2_pad, group=1, strides=[2, 2], padding='VALID', name='stage3_unit1_conv2')
    stage3_unit1_bn3 = batch_normalization(stage3_unit1_conv2, variance_epsilon=1.99999994948e-05, name='stage3_unit1_bn3')
    stage3_unit1_relu3 = tf.nn.relu(stage3_unit1_bn3, name = 'stage3_unit1_relu3')
    stage3_unit1_conv3 = convolution(stage3_unit1_relu3, group=1, strides=[1, 1], padding='VALID', name='stage3_unit1_conv3')
    plus7           = stage3_unit1_conv3 + stage3_unit1_sc
    stage3_unit2_bn1 = batch_normalization(plus7, variance_epsilon=1.99999994948e-05, name='stage3_unit2_bn1')
    stage3_unit2_relu1 = tf.nn.relu(stage3_unit2_bn1, name = 'stage3_unit2_relu1')
    stage3_unit2_conv1 = convolution(stage3_unit2_relu1, group=1, strides=[1, 1], padding='VALID', name='stage3_unit2_conv1')
    stage3_unit2_bn2 = batch_normalization(stage3_unit2_conv1, variance_epsilon=1.99999994948e-05, name='stage3_unit2_bn2')
    stage3_unit2_relu2 = tf.nn.relu(stage3_unit2_bn2, name = 'stage3_unit2_relu2')
    stage3_unit2_conv2_pad = tf.pad(stage3_unit2_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage3_unit2_conv2 = convolution(stage3_unit2_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage3_unit2_conv2')
    stage3_unit2_bn3 = batch_normalization(stage3_unit2_conv2, variance_epsilon=1.99999994948e-05, name='stage3_unit2_bn3')
    stage3_unit2_relu3 = tf.nn.relu(stage3_unit2_bn3, name = 'stage3_unit2_relu3')
    stage3_unit2_conv3 = convolution(stage3_unit2_relu3, group=1, strides=[1, 1], padding='VALID', name='stage3_unit2_conv3')
    plus8           = stage3_unit2_conv3 + plus7
    stage3_unit3_bn1 = batch_normalization(plus8, variance_epsilon=1.99999994948e-05, name='stage3_unit3_bn1')
    stage3_unit3_relu1 = tf.nn.relu(stage3_unit3_bn1, name = 'stage3_unit3_relu1')
    stage3_unit3_conv1 = convolution(stage3_unit3_relu1, group=1, strides=[1, 1], padding='VALID', name='stage3_unit3_conv1')
    stage3_unit3_bn2 = batch_normalization(stage3_unit3_conv1, variance_epsilon=1.99999994948e-05, name='stage3_unit3_bn2')
    stage3_unit3_relu2 = tf.nn.relu(stage3_unit3_bn2, name = 'stage3_unit3_relu2')
    stage3_unit3_conv2_pad = tf.pad(stage3_unit3_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage3_unit3_conv2 = convolution(stage3_unit3_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage3_unit3_conv2')
    stage3_unit3_bn3 = batch_normalization(stage3_unit3_conv2, variance_epsilon=1.99999994948e-05, name='stage3_unit3_bn3')
    stage3_unit3_relu3 = tf.nn.relu(stage3_unit3_bn3, name = 'stage3_unit3_relu3')
    stage3_unit3_conv3 = convolution(stage3_unit3_relu3, group=1, strides=[1, 1], padding='VALID', name='stage3_unit3_conv3')
    plus9           = stage3_unit3_conv3 + plus8
    stage3_unit4_bn1 = batch_normalization(plus9, variance_epsilon=1.99999994948e-05, name='stage3_unit4_bn1')
    stage3_unit4_relu1 = tf.nn.relu(stage3_unit4_bn1, name = 'stage3_unit4_relu1')
    stage3_unit4_conv1 = convolution(stage3_unit4_relu1, group=1, strides=[1, 1], padding='VALID', name='stage3_unit4_conv1')
    stage3_unit4_bn2 = batch_normalization(stage3_unit4_conv1, variance_epsilon=1.99999994948e-05, name='stage3_unit4_bn2')
    stage3_unit4_relu2 = tf.nn.relu(stage3_unit4_bn2, name = 'stage3_unit4_relu2')
    stage3_unit4_conv2_pad = tf.pad(stage3_unit4_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage3_unit4_conv2 = convolution(stage3_unit4_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage3_unit4_conv2')
    stage3_unit4_bn3 = batch_normalization(stage3_unit4_conv2, variance_epsilon=1.99999994948e-05, name='stage3_unit4_bn3')
    stage3_unit4_relu3 = tf.nn.relu(stage3_unit4_bn3, name = 'stage3_unit4_relu3')
    stage3_unit4_conv3 = convolution(stage3_unit4_relu3, group=1, strides=[1, 1], padding='VALID', name='stage3_unit4_conv3')
    plus10          = stage3_unit4_conv3 + plus9
    stage3_unit5_bn1 = batch_normalization(plus10, variance_epsilon=1.99999994948e-05, name='stage3_unit5_bn1')
    stage3_unit5_relu1 = tf.nn.relu(stage3_unit5_bn1, name = 'stage3_unit5_relu1')
    stage3_unit5_conv1 = convolution(stage3_unit5_relu1, group=1, strides=[1, 1], padding='VALID', name='stage3_unit5_conv1')
    stage3_unit5_bn2 = batch_normalization(stage3_unit5_conv1, variance_epsilon=1.99999994948e-05, name='stage3_unit5_bn2')
    stage3_unit5_relu2 = tf.nn.relu(stage3_unit5_bn2, name = 'stage3_unit5_relu2')
    stage3_unit5_conv2_pad = tf.pad(stage3_unit5_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage3_unit5_conv2 = convolution(stage3_unit5_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage3_unit5_conv2')
    stage3_unit5_bn3 = batch_normalization(stage3_unit5_conv2, variance_epsilon=1.99999994948e-05, name='stage3_unit5_bn3')
    stage3_unit5_relu3 = tf.nn.relu(stage3_unit5_bn3, name = 'stage3_unit5_relu3')
    stage3_unit5_conv3 = convolution(stage3_unit5_relu3, group=1, strides=[1, 1], padding='VALID', name='stage3_unit5_conv3')
    plus11          = stage3_unit5_conv3 + plus10
    stage3_unit6_bn1 = batch_normalization(plus11, variance_epsilon=1.99999994948e-05, name='stage3_unit6_bn1')
    stage3_unit6_relu1 = tf.nn.relu(stage3_unit6_bn1, name = 'stage3_unit6_relu1')
    stage3_unit6_conv1 = convolution(stage3_unit6_relu1, group=1, strides=[1, 1], padding='VALID', name='stage3_unit6_conv1')
    stage3_unit6_bn2 = batch_normalization(stage3_unit6_conv1, variance_epsilon=1.99999994948e-05, name='stage3_unit6_bn2')
    stage3_unit6_relu2 = tf.nn.relu(stage3_unit6_bn2, name = 'stage3_unit6_relu2')
    stage3_unit6_conv2_pad = tf.pad(stage3_unit6_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage3_unit6_conv2 = convolution(stage3_unit6_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage3_unit6_conv2')
    stage3_unit6_bn3 = batch_normalization(stage3_unit6_conv2, variance_epsilon=1.99999994948e-05, name='stage3_unit6_bn3')
    stage3_unit6_relu3 = tf.nn.relu(stage3_unit6_bn3, name = 'stage3_unit6_relu3')
    stage3_unit6_conv3 = convolution(stage3_unit6_relu3, group=1, strides=[1, 1], padding='VALID', name='stage3_unit6_conv3')
    plus12          = stage3_unit6_conv3 + plus11
    stage4_unit1_bn1 = batch_normalization(plus12, variance_epsilon=1.99999994948e-05, name='stage4_unit1_bn1')
    stage4_unit1_relu1 = tf.nn.relu(stage4_unit1_bn1, name = 'stage4_unit1_relu1')
    stage4_unit1_conv1 = convolution(stage4_unit1_relu1, group=1, strides=[1, 1], padding='VALID', name='stage4_unit1_conv1')
    stage4_unit1_sc = convolution(stage4_unit1_relu1, group=1, strides=[2, 2], padding='VALID', name='stage4_unit1_sc')
    stage4_unit1_bn2 = batch_normalization(stage4_unit1_conv1, variance_epsilon=1.99999994948e-05, name='stage4_unit1_bn2')
    stage4_unit1_relu2 = tf.nn.relu(stage4_unit1_bn2, name = 'stage4_unit1_relu2')
    stage4_unit1_conv2_pad = tf.pad(stage4_unit1_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage4_unit1_conv2 = convolution(stage4_unit1_conv2_pad, group=1, strides=[2, 2], padding='VALID', name='stage4_unit1_conv2')
    stage4_unit1_bn3 = batch_normalization(stage4_unit1_conv2, variance_epsilon=1.99999994948e-05, name='stage4_unit1_bn3')
    stage4_unit1_relu3 = tf.nn.relu(stage4_unit1_bn3, name = 'stage4_unit1_relu3')
    stage4_unit1_conv3 = convolution(stage4_unit1_relu3, group=1, strides=[1, 1], padding='VALID', name='stage4_unit1_conv3')
    plus13          = stage4_unit1_conv3 + stage4_unit1_sc
    stage4_unit2_bn1 = batch_normalization(plus13, variance_epsilon=1.99999994948e-05, name='stage4_unit2_bn1')
    stage4_unit2_relu1 = tf.nn.relu(stage4_unit2_bn1, name = 'stage4_unit2_relu1')
    stage4_unit2_conv1 = convolution(stage4_unit2_relu1, group=1, strides=[1, 1], padding='VALID', name='stage4_unit2_conv1')
    stage4_unit2_bn2 = batch_normalization(stage4_unit2_conv1, variance_epsilon=1.99999994948e-05, name='stage4_unit2_bn2')
    stage4_unit2_relu2 = tf.nn.relu(stage4_unit2_bn2, name = 'stage4_unit2_relu2')
    stage4_unit2_conv2_pad = tf.pad(stage4_unit2_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage4_unit2_conv2 = convolution(stage4_unit2_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage4_unit2_conv2')
    stage4_unit2_bn3 = batch_normalization(stage4_unit2_conv2, variance_epsilon=1.99999994948e-05, name='stage4_unit2_bn3')
    stage4_unit2_relu3 = tf.nn.relu(stage4_unit2_bn3, name = 'stage4_unit2_relu3')
    stage4_unit2_conv3 = convolution(stage4_unit2_relu3, group=1, strides=[1, 1], padding='VALID', name='stage4_unit2_conv3')
    plus14          = stage4_unit2_conv3 + plus13
    stage4_unit3_bn1 = batch_normalization(plus14, variance_epsilon=1.99999994948e-05, name='stage4_unit3_bn1')
    stage4_unit3_relu1 = tf.nn.relu(stage4_unit3_bn1, name = 'stage4_unit3_relu1')
    stage4_unit3_conv1 = convolution(stage4_unit3_relu1, group=1, strides=[1, 1], padding='VALID', name='stage4_unit3_conv1')
    stage4_unit3_bn2 = batch_normalization(stage4_unit3_conv1, variance_epsilon=1.99999994948e-05, name='stage4_unit3_bn2')
    stage4_unit3_relu2 = tf.nn.relu(stage4_unit3_bn2, name = 'stage4_unit3_relu2')
    stage4_unit3_conv2_pad = tf.pad(stage4_unit3_relu2, paddings = [[0L, 0L], [1L, 1L], [1L, 1L], [0L, 0L]])
    stage4_unit3_conv2 = convolution(stage4_unit3_conv2_pad, group=1, strides=[1, 1], padding='VALID', name='stage4_unit3_conv2')
    stage4_unit3_bn3 = batch_normalization(stage4_unit3_conv2, variance_epsilon=1.99999994948e-05, name='stage4_unit3_bn3')
    stage4_unit3_relu3 = tf.nn.relu(stage4_unit3_bn3, name = 'stage4_unit3_relu3')
    stage4_unit3_conv3 = convolution(stage4_unit3_relu3, group=1, strides=[1, 1], padding='VALID', name='stage4_unit3_conv3')
    plus15          = stage4_unit3_conv3 + plus14
    bn1             = batch_normalization(plus15, variance_epsilon=1.99999994948e-05, name='bn1')
    relu1           = tf.nn.relu(bn1, name = 'relu1')
    pool1           = tf.nn.avg_pool(relu1, [1] + relu1.get_shape().as_list()[1:-1] + [1], strides = [1] * 4, padding = 'VALID', name = 'pool1')
    flatten0        = tf.contrib.layers.flatten(pool1)
    fc1             = tf.layers.dense(flatten0, 1000, kernel_initializer = tf.constant_initializer(__weights_dict['fc1']['weights']), bias_initializer = tf.constant_initializer(__weights_dict['fc1']['bias']), use_bias = True)
    softmax         = tf.nn.softmax(fc1, name = 'softmax')
    return data, softmax

def batch_normalization(input, name, **kwargs):
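    # Look up this layer's saved running mean/variance and optional scale/bias by name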
    mean = tf.Variable(__weights_dict[name]['mean'], name = name + "_mean", trainable = is_train)
    variance = tf.Variable(__weights_dict[name]['var'], name = name + "_var", trainable = is_train)
    offset = tf.Variable(__weights_dict[name]['bias'], name = name + "_bias", trainable = is_train) if 'bias' in __weights_dict[name] else None
    scale = tf.Variable(__weights_dict[name]['scale'], name = name + "_scale", trainable = is_train) if 'scale' in __weights_dict[name] else None
    return tf.nn.batch_normalization(input, mean, variance, offset, scale, name = name, **kwargs)

def convolution(input, name, group, **kwargs):
    w = tf.Variable(__weights_dict[name]['weights'], trainable=is_train, name=name + "_weight")
    if group == 1:
        layer = tf.nn.convolution(input, w, **kwargs)
    else:
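        # Grouped convolution: split input and weights channel-wise into `group` slices,
        # convolve each pair separately, then concatenate the results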
        weight_groups = tf.split(w, num_or_size_splits=group, axis=-1)
        xs = tf.split(input, num_or_size_splits=group, axis=-1)
        convolved = [tf.nn.convolution(x, weight, **kwargs) for
                    (x, weight) in zip(xs, weight_groups)]
        layer = tf.concat(convolved, axis=-1)

    if 'bias' in __weights_dict[name]:
        b = tf.Variable(__weights_dict[name]['bias'], trainable=is_train, name=name + "_bias")
        layer = layer + b
    return layer
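
Roughly, the generated module is driven with the standard TF 1.x session pattern; a minimal sketch (batch stands for a preprocessed NHWC float32 image batch shaped to match the data placeholder):

import numpy as np
import tensorflow as tf
import mx_resnet50

# Build the graph; variables are initialized from the converted .npy weights
data, softmax = mx_resnet50.KitModel('resnet50.npy')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    probs = sess.run(softmax, feed_dict={data: batch})  # batch: (1, 299, 299, 3) here
    print(np.argmax(probs, axis=1))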

But when I load the weights and feed in images this way, the predicted label is always 818. Please help.

kitstar commented 6 years ago

Hi @LiYingwei,

For the MXNet ResNet-50:

  1. The input shape is 3×224×224, not 3×299×299, so you need to modify the mmtoir command.

  2. The preprocess function converts the image from RGB to BGR.

Did you apply the preprocess function? Thanks.
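
For illustration, the corrected conversion command only changes the input shape:

mmtoir -f mxnet -n resnet-50-symbol.json -w resnet-50-0000.params -d resnet50 --inputShape 3 224 224

And a minimal preprocessing sketch along those lines (an assumption for illustration; MMdnn's actual preprocess function may also crop or subtract channel means):

import numpy as np
from PIL import Image

def preprocess(path):
    # Resize to the 224x224 input the MXNet ResNet-50 expects
    img = Image.open(path).convert('RGB').resize((224, 224))
    x = np.asarray(img, dtype=np.float32)
    # Reverse the channel axis: RGB -> BGR
    x = x[..., ::-1]
    # Add the batch dimension for the NHWC 'data' placeholder
    return x[np.newaxis, ...]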

LiYingwei commented 6 years ago

@kitstar, thanks for helping me. I indeed hadn't considered these things. I will try to fix them and see if it works.

Also, before you replied, I had tried to debug this myself. First, I wanted to make sure the TensorFlow part works properly, so I followed this page to test whether I can correctly convert a TF model to IR and then convert it back to a TF model.

So I ran these scripts:

  1. mmdownload -f tensorflow -n resnet_v2_152
  2. mmtoir -f tensorflow -d resnet152 -n imagenet_resnet_v2_152.ckpt.meta -w imagenet_resnet_v2_152.ckpt --dstNodeName MMdnn_Output
  3. mmtocode -f tensorflow --IRModelPath resnet152.pb --IRWeightPath resnet152.npy --dstModelPath tf_resnet152.py
  4. python -m mmdnn.conversion.examples.tensorflow.imagenet_test -s tf -p resnet -n tf_resnet152 -w resnet152.npy

When I ran step 4, I got an error:

imagenet_test.py: error: argument -s: invalid choice: u'tf' (choose from 'mxnet', 'keras', 'cntk', 'pytorch', 'caffe', 'tensorflow')

So I changed tf to tensorflow and ran it again. This time the program told me:

IOError: [Errno 2] No such file or directory: u'mmdnn/conversion/examples/data/seagull.jpg'

So I made those directories and downloaded the jpg file from GitHub. Finally, it could run. However, the model generated by commands 1-3 does not pass this test. I edited line 309 of imagenet_test.py (line 241 in the released version) to print the predictions, and found that 3 of the top 5 are incorrect.

[screenshot of the printed predictions]

I am wondering if there is anything wrong in this part.

kitstar commented 6 years ago

Hi @LiYingwei, if you want to check the inference result of the model converted from MXNet, you can try:

python -m mmdnn.conversion.examples.tensorflow.imagenet_test -s mxnet -p imagenet1k-resnet-152 -n mx_resnet50.py -w resnet50.npy -i your_image_path

For seagull.jpg, we get [(21, 0.49690375), (144, 0.40759432), (23, 0.057171866), (146, 0.03169549), (22, 0.0018843169)]

Among the ImageNet-1k labels, 21, 22, 23, and 144 are common prediction results for this image. The model gives you probabilities for all 1000 classes, so you can show the top-n as you want.
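
Getting the top-n entries in the format above is straightforward with numpy; a small sketch (probs being the 1-D array of 1000 class probabilities for one image):

import numpy as np

def top_n(probs, n=5):
    # Indices of the n largest probabilities, highest first
    idx = np.argsort(probs)[::-1][:n]
    return [(int(i), float(probs[i])) for i in idx]

# e.g. top_n(probs) -> [(21, 0.49690375), (144, 0.40759432), ...]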