How to load the pretrained model? #3

Closed: UpCoder closed this issue 6 years ago

UpCoder commented 6 years ago

Hi, I am interested in how to load the pretrained parameters of the encoder part. Can you give me some suggestions? Thanks!

HasnainRaz commented 6 years ago

Hey, you can nest the encoder in its own variable scope, then use tf.get_collection to get all the variables in the encoder as a list, and pass this list to tf.train.Saver as its var_list argument before calling restore. Essentially something like this:

import tensorflow as tf  # TensorFlow 1.x API

def model(self, x, training):
    """Defines the complete graph model for the Tiramisu.

    Args:
        x: Tensor, input image to segment.
        training: Bool Tensor, indicating whether training or not.

    Returns:
        x: Tensor, raw unscaled logits of predicted segmentation.
    """
    concats = []
    # Every variable created inside this scope gets the 'encoder/' name prefix.
    with tf.variable_scope('encoder'):
        x = tf.layers.conv2d(x,
                             filters=48,  # assumed value; not shown in the original comment
                             kernel_size=[3, 3],
                             strides=[1, 1],
                             padding='SAME',  # assumed value; not shown in the original comment
                             dilation_rate=[1, 1])
        print("First Convolution Out: ", x.get_shape())
        for block_nb in range(0, self.nb_blocks):
            dense = self.dense_block(x, training, block_nb, 'down_dense_block_' + str(block_nb))

            if block_nb != self.nb_blocks - 1:
                x = tf.concat([x, dense], axis=3, name='down_concat_' + str(block_nb))
                x = self.transition_down(x, training, x.get_shape()[-1], 'trans_down_' + str(block_nb))
                print("Downsample Out:", x.get_shape())

        x = dense
        print("Bottleneck Block: ", dense.get_shape())

    # ..... decoder continues below (keep it outside the 'encoder' scope)
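A note on the scope argument: tf.get_collection keeps only the items whose name matches scope using re.match, which anchors at the start of the name, so scope='encoder' would also pick up variables from any scope that merely begins with "encoder". The sketch below mimics that filtering on plain strings so it runs without TensorFlow; filter_by_scope is a hypothetical helper and the variable names, including the encoder_aux scope, are made up for illustration.

```python
import re


def filter_by_scope(names, scope):
    """Mimic the scope filter of tf.get_collection(key, scope):
    keep names that match `scope` with re.match, i.e. at the start of the name."""
    pattern = re.compile(scope)
    return [name for name in names if pattern.match(name)]


# Hypothetical variable names (TF1 variable names carry a ':0' suffix).
names = [
    "encoder/conv2d/kernel:0",
    "encoder/down_dense_block_0/conv2d/kernel:0",
    "encoder_aux/dense/kernel:0",  # hypothetical scope that only starts with "encoder"
    "decoder/up_dense_block_0/conv2d/kernel:0",
]

print(filter_by_scope(names, "encoder"))   # includes encoder_aux/...
print(filter_by_scope(names, "encoder/"))  # encoder scope only
```

Passing scope='encoder/' with the trailing slash is a simple way to restrict the match to the encoder scope itself.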

Then, once the graph is built, when you want to restore only the encoder:

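import tensorflow as tf  # TensorFlow 1.x API

encoder_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='encoder')
# var_list is an argument of the Saver constructor, not of restore().
saver = tf.train.Saver(var_list=encoder_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize the decoder (and everything else) first
    saver.restore(sess, ckpt_name)  # then overwrite the encoder weights; ckpt_name is your checkpoint path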
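To see what a partial restore does and does not touch, here is a dependency-free stand-in that models the variables and the checkpoint as plain dictionaries; restore_subset is a hypothetical helper, not a TensorFlow function, and all names and values are made up. Only the variables selected by the mapping are overwritten; everything else keeps its initialized value. The mapping is keyed by checkpoint name because tf.train.Saver also accepts a dict {name_in_checkpoint: variable} as var_list, which can help if a pretrained checkpoint was saved without the 'encoder/' prefix, e.g. {v.op.name[len('encoder/'):]: v for v in encoder_vars}.

```python
def restore_subset(model_vars, checkpoint, var_map):
    """Stand-in for tf.train.Saver(var_list=...).restore(), on plain dicts.

    model_vars: {variable_name: value} after initialization.
    checkpoint: {checkpoint_key: value} as saved on disk.
    var_map:    {checkpoint_key: variable_name} for the variables to restore.
    Returns a new dict; variables not in var_map keep their initialized value.
    """
    restored = dict(model_vars)
    for ckpt_key, var_name in var_map.items():
        if ckpt_key not in checkpoint:
            # The real Saver also fails when a requested key is missing.
            raise KeyError(f"{ckpt_key!r} not found in checkpoint")
        restored[var_name] = checkpoint[ckpt_key]
    return restored


# Hypothetical values: the graph right after global_variables_initializer().
model_vars = {
    "encoder/conv2d/kernel": 0.0,
    "encoder/down_dense_block_0/kernel": 0.0,
    "decoder/up_dense_block_0/kernel": 0.0,
}
# Hypothetical checkpoint saved by a model built without the 'encoder' scope.
checkpoint = {
    "conv2d/kernel": 1.5,
    "down_dense_block_0/kernel": 2.5,
}

prefix = "encoder/"
# Strip the scope prefix so graph names line up with checkpoint keys.
var_map = {name[len(prefix):]: name for name in model_vars if name.startswith(prefix)}

restored = restore_subset(model_vars, checkpoint, var_map)
print(restored)
```

The decoder entry stays at its initialized value, which is why the variable initializer has to run before the restore.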
UpCoder commented 6 years ago

Thank you very much. I would also like to know how to download the specific pretrained model. Could you provide a download link for the pretrained DenseNet parameters? Thanks again.

HasnainRaz commented 6 years ago

I currently do not have a pretrained model file; you'll probably have to train the model yourself first.

UpCoder commented 6 years ago

OK, thank you!