machrisaa / tensorflow-vgg

VGG19 and VGG16 on Tensorflow

Regarding training the vgg model #7

Closed wenouyang closed 8 years ago

wenouyang commented 8 years ago

Hi Chris,

If we need to re-train the VGG model, how should we modify your current code to support the training process? Or is there an existing TensorFlow framework that could incorporate your code into a training process?

Thanks,

wenouyang

machrisaa commented 8 years ago

I have actually created another version that enables training, for another project. In that project, I modified the last layer and continued training the network, and the trained result can be saved as an NPY file again.

It may not be a perfect solution, but I will post it here:

import inspect
import os

import numpy as np
import tensorflow as tf

# BGR channel means of the ImageNet training set, as used by the
# original Caffe VGG models.
VGG_MEAN = [103.939, 116.779, 123.68]


class CustomVgg19:
    def __init__(self, vgg19_npy_path=None):
        if vgg19_npy_path is None:
            # Default to a vgg.npy file next to this module.
            path = inspect.getfile(CustomVgg19)
            path = os.path.abspath(os.path.join(path, os.pardir))
            path = os.path.join(path, "vgg.npy")
            vgg19_npy_path = path
            print(vgg19_npy_path)

        # Note: newer NumPy versions need allow_pickle=True here.
        self.data_dict = np.load(vgg19_npy_path).item()
        self.var_dict = {}
        print("npy file loaded")

    def build(self, rgb, train=False, full=True):
        rgb_scaled = rgb * 255.0

        # Convert RGB to BGR and subtract the per-channel ImageNet mean,
        # matching the preprocessing of the original Caffe VGG models.
        red, green, blue = tf.split(3, 3, rgb_scaled)
        # assert red.get_shape().as_list()[1:] == [224, 224, 1]
        # assert green.get_shape().as_list()[1:] == [224, 224, 1]
        # assert blue.get_shape().as_list()[1:] == [224, 224, 1]
        bgr = tf.concat(3, [
            blue - VGG_MEAN[0],
            green - VGG_MEAN[1],
            red - VGG_MEAN[2],
        ])

        self.conv1_1 = self.conv_layer(bgr, "conv1_1", train)
        self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2", train)
        self.pool1 = self.avg_pool(self.conv1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1", train)
        self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2", train)
        self.pool2 = self.avg_pool(self.conv2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1", train)
        self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2", train)
        self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3", train)
        self.conv3_4 = self.conv_layer(self.conv3_3, "conv3_4", train)
        self.pool3 = self.avg_pool(self.conv3_4, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1", train)
        self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2", train)
        self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3", train)
        self.conv4_4 = self.conv_layer(self.conv4_3, "conv4_4", train)
        self.pool4 = self.avg_pool(self.conv4_4, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1", train)
        self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2", train)
        self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3", train)
        self.conv5_4 = self.conv_layer(self.conv5_3, "conv5_4", train)
        self.pool5 = self.avg_pool(self.conv5_4, 'pool5')

        if full:
            if train:
                # Keep the large fc6 weights (25088 x 4096) on the CPU during
                # training, presumably to save GPU memory.
                with tf.device("/cpu:0"):
                    self.fc6 = self.fc_layer(self.pool5, "fc6", train)
            else:
                self.fc6 = self.fc_layer(self.pool5, "fc6", train)

            assert self.fc6.get_shape().as_list()[1:] == [4096]
            self.relu6 = tf.nn.relu(self.fc6)
            if train:
                self.relu6 = tf.nn.dropout(self.relu6, 0.5)

            self.fc7 = self.fc_layer(self.relu6, "fc7", train)
            self.relu7 = tf.nn.relu(self.fc7)
            if train:
                self.relu7 = tf.nn.dropout(self.relu7, 0.5)

            # Replace the original 1000-class fc8 + softmax with a custom
            # output layer (36 classes in this project):
            # self.fc8 = self.fc_layer(self.relu7, "fc8", train)
            # self.prob = tf.nn.softmax(self.fc8, name="prob")

            self.fc_custom = self.fc_layer(self.relu7, "fc_custom", train,
                                           w_init_shape=[4096, 36], b_init_shape=[36])
            self.prob = tf.nn.softmax(self.fc_custom, name="prob")

        self.data_dict = None  # release the npy data once the graph is built

        if full:
            return self.prob

    def avg_pool(self, bottom, name):
        return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                              padding='SAME', name=name)

    def max_pool(self, bottom, name):
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                              padding='SAME', name=name)

    def conv_layer(self, bottom, name, train=False):
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name, train)

            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name, train)
            bias = tf.nn.bias_add(conv, conv_biases)

            relu = tf.nn.relu(bias)
            return relu

    def fc_layer(self, bottom, name, train=False, w_init_shape=None, b_init_shape=None):
        with tf.variable_scope(name):
            # Flatten the input to [batch, features].
            shape = bottom.get_shape().as_list()
            dim = 1
            for d in shape[1:]:
                dim *= d
            x = tf.reshape(bottom, [-1, dim])

            weights = self.get_fc_weight(name, train, w_init_shape)
            biases = self.get_bias(name, train, b_init_shape)

            fc = tf.nn.bias_add(tf.matmul(x, weights), biases)

            return fc

    def get_conv_filter(self, layer_name, train=False, init_shape=None):
        return self.get_var(layer_name, 0, "filter", train, init_shape)

    def get_bias(self, layer_name, train=False, init_shape=None):
        return self.get_var(layer_name, 1, "biases", train, init_shape)

    def get_fc_weight(self, layer_name, train=False, init_shape=None):
        return self.get_var(layer_name, 0, "weights", train, init_shape)

    def get_var(self, layer_name, idx, var_name, train, init_shape):
        if train:
            if layer_name in self.data_dict:
                # Reuse pretrained weights as trainable variables.
                data = self.data_dict[layer_name][idx]
                var = tf.Variable(data, name=var_name)
                print(layer_name, var_name, "loaded from npy")
            else:
                # New layer (e.g. fc_custom): random initialization.
                var = tf.Variable(tf.truncated_normal(init_shape, 0.01, 0.1), name=var_name)
                print(layer_name, var_name, "not found, random var created")
        else:
            # Inference only: freeze the weights as constants.
            data = self.data_dict[layer_name][idx]
            var = tf.constant(data, name=var_name)
            print(layer_name, var_name, "loaded from npy")

        if layer_name not in self.var_dict:
            self.var_dict[layer_name] = [None] * 2
        self.var_dict[layer_name][idx] = var

        return var

    def save_npy(self, sess, file_path="./vgg.npy"):
        assert isinstance(sess, tf.Session)

        data_dict = {}

        # get_all_var() returns parallel sequences of variables, layer names,
        # and slot indices (0 = weights/filter, 1 = biases).
        var, name, idx = self.get_all_var()
        var_out = sess.run(var)

        for i in range(len(name)):
            if name[i] not in data_dict:
                data_dict[name[i]] = [None] * 2
            data_dict[name[i]][idx[i]] = var_out[i]

        np.save(file_path, data_dict)
        print("file saved", file_path)
        return file_path
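
To make the intended workflow concrete, here is a minimal fine-tuning sketch built around the class above. It is a sketch only: num_steps and next_batch are hypothetical stand-ins for your own training loop and data feed, the loss and learning rate are arbitrary, and the TensorFlow 0.x API is used to match the class.

# Hypothetical fine-tuning loop for CustomVgg19 (illustrative values only).
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
labels = tf.placeholder(tf.float32, [None, 36])  # one-hot, 36 custom classes

vgg = CustomVgg19('./vgg19.npy')
prob = vgg.build(images, train=True)

# Simple cross-entropy loss; the epsilon guards against log(0).
cost = -tf.reduce_sum(labels * tf.log(prob + 1e-8))
train_op = tf.train.GradientDescentOptimizer(1e-4).minimize(cost)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())  # TF 0.x initializer
    for step in range(num_steps):                  # num_steps: user-defined
        batch_images, batch_labels = next_batch()  # next_batch: user-defined
        sess.run(train_op, feed_dict={images: batch_images, labels: batch_labels})
    vgg.save_npy(sess, './vgg19-finetuned.npy')  # persist the updated weights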

wenouyang commented 8 years ago

Hi Chris, thank you so much. By the way, I have some questions about the RGB-to-BGR conversion: why do we need it in general? It looks to me like a normalization step, as in statistics. Also, how did you obtain the VGG_MEAN values? I posted a similar question on Stack Overflow: http://stackoverflow.com/questions/38711525/regarding-the-image-scaling-operations-for-running-vgg-model

ebigelow commented 8 years ago

@machrisaa

CustomVgg19.save_npy calls CustomVgg19.get_all_var(), but this method is not provided... could you please post it as well?

Edit:

def get_all_var(self):
    # Flatten var_dict into parallel tuples of (variable, layer name,
    # slot index), matching what save_npy expects.
    D = self.var_dict
    var_list = [(var, name, idx) for name in D for idx, var in enumerate(D[name])]
    return zip(*var_list)

Is this about right?
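
If so, save_npy would consume it like this (an illustrative sketch; vgg is assumed to be an instance whose graph has already been built, and sess an open tf.Session):

# Three parallel sequences: the variables, their layer names, and the
# slot index within each layer (0 = filter/weights, 1 = biases).
var, name, idx = vgg.get_all_var()
values = sess.run(var)  # fetch the current weight values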

machrisaa commented 8 years ago

Sorry for the late reply. I have uploaded a complete, standalone trainable version of VGG19. It should be much better than the code I provided above. Please check it out if you are interested.
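
For reference, usage of the trainable version follows this pattern (a sketch; the module name and the train_mode placeholder are assumptions based on the repo's test script):

import tensorflow as tf
import vgg19_trainable as vgg19  # module name assumed from the repo layout

images = tf.placeholder(tf.float32, [1, 224, 224, 3])
true_out = tf.placeholder(tf.float32, [1, 1000])
train_mode = tf.placeholder(tf.bool)  # toggles dropout at run time

vgg = vgg19.Vgg19('./vgg19.npy')
vgg.build(images, train_mode)

cost = tf.reduce_sum((vgg.prob - true_out) ** 2)
train = tf.train.GradientDescentOptimizer(0.0001).minimize(cost)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    # One training step; the batch and labels come from your own pipeline:
    # sess.run(train, feed_dict={images: batch, true_out: labels, train_mode: True})
    vgg.save_npy(sess, './vgg19-save.npy')  # write the trained weights back to NPY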

wenouyang commented 8 years ago

Thanks.