ry / tensorflow-resnet

ResNet model in TensorFlow
MIT License

Retrain resnet model on new data. #12

Open SilviaLauraPintea opened 8 years ago

SilviaLauraPintea commented 8 years ago

Hi,

I am trying to add a new fc layer, with a different number of outputs to suit my problem, on top of the avg_pool layer of the ResNet. I do not want to retrain only the new fc layer but also the previous layers, so I need the gradients of the previous layers as well. Unfortunately, this does not seem to work. As a test, I created a dummy net and saved it (without the gradients, so similar to the provided ResNet meta and ckpt files), then loaded it and added a new fc layer, and this worked without problems.
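Concretely, the difference is where the loss gradient goes: retraining only the new fc layer updates just its weights, while fine-tuning also propagates the gradient back through the fc layer into the earlier layers. Below is a minimal plain-Python sketch under stated assumptions, not TensorFlow and not the ResNet code: `W1` is a hypothetical stand-in for the pretrained layers (a single linear layer with ReLU), `W2` for the new fc layer, and the hypothetical `train_step` runs one SGD step on softmax cross-entropy; `train_backbone=False` corresponds to training only the new layer.

```python
import math


def softmax(z):
    # Numerically stable softmax over a list of logits.
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def matvec(W, x):
    # W is a list of rows, one row per output unit.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def train_step(W1, W2, x, label, lr, train_backbone=True):
    """One SGD step for: h = relu(W1 x), logits = W2 h, softmax cross-entropy.

    W1 plays the role of the pretrained layers, W2 the new fc layer.
    Both are updated in place; returns the loss before the update.
    """
    pre = matvec(W1, x)
    h = [max(0.0, v) for v in pre]
    p = softmax(matvec(W2, h))
    loss = -math.log(p[label])

    # dL/dlogits = p - onehot(label)
    g_logits = [pk - (1.0 if k == label else 0.0) for k, pk in enumerate(p)]
    # Gradient w.r.t. the features, computed before W2 is modified.
    g_h = [sum(g_logits[k] * W2[k][j] for k in range(len(W2))) for j in range(len(h))]

    # Update the new fc layer.
    for k in range(len(W2)):
        for j in range(len(h)):
            W2[k][j] -= lr * g_logits[k] * h[j]

    if train_backbone:
        # Pass the gradient through the ReLU and update the earlier layer too.
        g_pre = [g if v > 0 else 0.0 for g, v in zip(g_h, pre)]
        for j in range(len(W1)):
            for i in range(len(x)):
                W1[j][i] -= lr * g_pre[j] * x[i]
    return loss
```

With `train_backbone=True` both weight matrices change on every step; with `False`, `W1` stays fixed. In TensorFlow 1.x, `Optimizer.minimize` without a `var_list` already differentiates with respect to every variable in the trainable-variables collection, which corresponds to the `True` case.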

Here is a snapshot of my retraining code:

import tensorflow as tf

# Start the session:
sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))

# Get the data batches (dataAsTensors is my own input-pipeline helper).
trainimages, trainlabels = dataAsTensors(is_training=True, batch_size=FLAGS.batch_size)

# In the default graph:
graph = tf.get_default_graph()
with graph.as_default():

    # Import the graph structure from the meta file; the returned saver
    # is used below to restore the ResNet weights.
    dataSaver = tf.train.import_meta_graph('ResNet-L50.meta')

    for op in graph.get_operations():
        print(op.name)

    # Get both the 'images' and the 'avg_pool' tensors.
    images = graph.get_tensor_by_name("images:0")
    avgpool = graph.get_tensor_by_name('avg_pool:0')

    # Define a new fc layer on top of the avg_pool layer.
    logits, _ = fc_num_outs(avgpool, FLAGS.num_classes, FLAGS.avgpool_size)

    # Define the loss on top of the new fc and a placeholder for the labels.
    labelsVar = tf.placeholder(tf.int64, shape=(FLAGS.batch_size,), name='labelsVar')
    loss_ = loss(logits, labelsVar)

    # Define the gradients and get the training operation.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    ops = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
    train_op = ops.minimize(loss_, global_step=global_step)

    # The saver returned by import_meta_graph only covers the original ResNet
    # variables; this one also covers the new fc layer and global_step.
    saver = tf.train.Saver()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    with sess.as_default():

        # Initialize all variables.
        sess.run(tf.initialize_all_variables())

        # Restore the ResNet checkpoint after initialization.
        dataSaver.restore(sess, "ResNet-L50.ckpt")

        for i in range(0, FLAGS.max_steps):
            # Fetch images and labels in one run call so both come from the same batch.
            npImages, npLabels = sess.run([trainimages, trainlabels])

            # Run 1 step of the gradient optimization.
            sess.run(train_op, {images: npImages, labelsVar: npLabels})
            print("Done running grad step.. %d" % i)

            if i % 100 == 0:  # Save the checkpoint
                saver.save(sess, 'resnet_retrained' + str(i) + '.ckpt')

    coord.request_stop()
    coord.join(threads)
    sess.close()
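Whatever the cause of the gradient error, one more detail of the training loop matters once it runs: if `dataAsTensors` returns tensors from a batching queue (an assumption; the helper is not shown), fetching the image tensor and the label tensor in two separate run calls dequeues two different batches, so images get paired with the labels of the next batch. Fetching both in a single `sess.run([trainimages, trainlabels])` call avoids that. The plain-Python simulation here uses hypothetical helper names, a `deque` standing in for the TensorFlow queue, and labels equal to image ids so that correct pairing is easy to check:

```python
from collections import deque


def make_queue(n):
    # Hypothetical input pipeline: example i has image id i and label i,
    # so a correctly paired batch has images == labels.
    return deque((i, i) for i in range(n))


def run_dequeue(q, batch_size):
    # One "run" of a batching op: removes batch_size examples from the
    # queue and returns their images and labels.
    batch = [q.popleft() for _ in range(batch_size)]
    return [img for img, _ in batch], [lbl for _, lbl in batch]


def fetch_separately(q, batch_size):
    # Like evaluating the image and label tensors in two run calls:
    # each call dequeues a fresh batch, and half of each batch is dropped.
    images, _ = run_dequeue(q, batch_size)
    _, labels = run_dequeue(q, batch_size)
    return images, labels


def fetch_together(q, batch_size):
    # Like sess.run([images, labels]): a single dequeue feeds both outputs.
    return run_dequeue(q, batch_size)
```

With a queue of eight examples and a batch size of 2, the separate fetch returns images `[0, 1]` with labels `[2, 3]`, while the joint fetch returns matching ids.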

I am not sure why I get this error with the ResNet model:

  File "retrain.py", line 278, in main
    retrain()
  File "retrain.py", line 244, in retrain
    trainop = ops.minimize(loss, global_step=global_step)
  File "tensorflow/python/training/optimizer.py", line 193, in minimize
    grad_loss=grad_loss)
  File "tensorflow/python/training/optimizer.py", line 250, in compute_gradients
    colocate_gradients_with_ops=colocate_gradients_with_ops)
  File "tensorflow/python/ops/gradients.py", line 467, in gradients
    out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
  File "tensorflow/python/ops/control_flow_ops.py", line 1047, in ZerosLikeOutsideLoop
    pred = op_ctxt.pred
AttributeError: 'NoneType' object has no attribute 'pred'

while for my own toy model the same code seems to work.
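One hedged reading of the last two frames: in the TensorFlow versions of that time, `ZerosLikeOutsideLoop` appears to read `op_ctxt.pred` only for Switch ops, which `tf.cond` uses to route tensors into its branches. The `NoneType` error would then mean the gradient code reached a Switch op whose control-flow context was not set after `import_meta_graph`, and a toy net without `tf.cond` has no Switch ops at all. This is an interpretation of the traceback, not a confirmed diagnosis. A quick check is to list the Switch ops in each imported graph; `find_switch_ops` is a hypothetical helper that only relies on the `.name` and `.type` attributes of the objects returned by `graph.get_operations()`:

```python
# Switch op types: tf.cond builds its branches from Switch ops, and
# RefSwitch is the variant used for reference-typed inputs.
SWITCH_TYPES = ("Switch", "RefSwitch")


def find_switch_ops(operations):
    """Return the names of Switch-type ops among `operations`.

    Works with any objects exposing `.name` and `.type`, for example the
    tf.Operation objects returned by graph.get_operations().
    """
    return [op.name for op in operations if op.type in SWITCH_TYPES]
```

Calling `find_switch_ops(graph.get_operations())` right after `import_meta_graph` for both models shows whether the ResNet graph contains Switch ops that the toy graph lacks, which would be consistent with this explanation.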

Thanks a lot. Cheers, Silvia

KalraA commented 7 years ago

I'm getting the same error! @SilviaLauraPintea Have you found a solution?

leiup commented 7 years ago

How can this problem be solved? I am getting the same error too! @SilviaLauraPintea @KalraA @ry Thank you very much~

nikste commented 7 years ago

Has anybody solved this?

nu1ptr commented 7 years ago

I'm having this issue as well. Has anybody been able to figure it out yet?

rener1199 commented 7 years ago

I am getting the same error too! How can it be solved? @SilviaLauraPintea @KalraA @ry @nu1ptr Thanks