pudae / tensorflow-densenet

Tensorflow-DenseNet with ImageNet Pretrained Models
Apache License 2.0
169 stars 59 forks source link

Model for transfer-learning #10

Closed ilkarman closed 6 years ago

ilkarman commented 6 years ago

Hi Pudae, thanks a lot for this super-neat implementation. I am trying to load the model (without the pre-processing functions, etc) so that I can chop off the last fc layer and stick my own (to re-train on my data-set). However, I was having a bit of an issue loading the model without any of the functions attached:

def create_symbol(chkpt_dir=CHKPT_DIR):
    """Build DenseNet-121 in a fresh graph and restore its checkpoint weights.

    Args:
        chkpt_dir: Directory joined with 'tf-densenet121.ckpt' to form the
            checkpoint path.

    Returns:
        A tuple (endpoints, sess): the end-points returned by
        densenet.densenet121 and the session the restore was run in.

    NOTE(review): as posted, saver.restore() raises
    "NotFoundError: Key .../Conv/biases not found in checkpoint". The
    maintainer's reply in this thread attributes that to the network being
    built outside slim.arg_scope(densenet.densenet_arg_scope()), which is
    where the conv biases are turned off.
    """
    with tf.Graph().as_default():
        slim.get_or_create_global_step()
        # Not possible to have channels first?
        X = tf.placeholder(tf.float32, 
                           shape=(None, 224, 224, 3))
        # Built without any arg scope here (see the NOTE in the docstring)
        logits, endpoints = densenet.densenet121(
            X,
            num_classes=1000, 
            is_training=True,
            reuse=None)

        print(logits)
        print(endpoints)

        # No exclusions: every variable slim reports is expected to be
        # present in the checkpoint
        variables_to_restore = slim.get_variables_to_restore()

        sess = tf.Session()
        saver = tf.train.Saver(variables_to_restore)

        init_op = tf.group(
          tf.global_variables_initializer(),
          tf.local_variables_initializer())

        checkpoint_path = os.path.join(chkpt_dir, 'tf-densenet121.ckpt')
        # Initialise everything first, then overwrite with checkpoint values
        sess.run(init_op)
        saver.restore(sess, checkpoint_path)
        # The session is handed back open; the caller owns it from here
        return endpoints, sess

sym, sess = create_symbol()

Seems to complain about finding biases:

--------------------------------------------------------------------------
NotFoundError                             Traceback (most recent call last)
/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1322     try:
-> 1323       return fn(*args)
   1324     except errors.OpError as e:

/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1301                                    feed_dict, fetch_list, target_list,
-> 1302                                    status, run_metadata)
   1303 

/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    472             compat.as_text(c_api.TF_Message(self.status.status)),
--> 473             c_api.TF_GetCode(self.status.status))
    474     # Delete the underlying status object from memory otherwise it stays alive

NotFoundError: Key densenet121/dense_block2/conv_block7/x1/Conv/biases not found in checkpoint
    [[Node: save/RestoreV2_158 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_158/tensor_names, save/RestoreV2_158/shape_and_slices)]]
    [[Node: save/RestoreV2_16/_1181 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_2392_save/RestoreV2_16", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

During handling of the above exception, another exception occurred:

NotFoundError                             Traceback (most recent call last)
<ipython-input-5-9d5aca7f5637> in <module>()
----> 1 sym, sess = create_symbol()

<ipython-input-4-e283472f7a13> in create_symbol(chkpt_dir)
     25         checkpoint_path = os.path.join(chkpt_dir, 'tf-densenet121.ckpt')
     26         sess.run(init_op)
---> 27         saver.restore(sess, checkpoint_path)
     28         return network_fn, sess

/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py in restore(self, sess, save_path)
   1664     if context.in_graph_mode():
   1665       sess.run(self.saver_def.restore_op_name,
-> 1666                {self.saver_def.filename_tensor_name: save_path})
   1667     else:
   1668       self._build_eager(save_path, build_save=False, build_restore=True)

/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    887     try:
    888       result = self._run(None, fetches, feed_dict, options_ptr,
--> 889                          run_metadata_ptr)
    890       if run_metadata:
    891         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1118     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1119       results = self._do_run(handle, final_targets, final_fetches,
-> 1120                              feed_dict_tensor, options, run_metadata)
   1121     else:
   1122       results = []

/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1315     if handle is None:
   1316       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1317                            options, run_metadata)
   1318     else:
   1319       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1334         except KeyError:
   1335           pass
-> 1336       raise type(e)(node_def, op, message)
   1337 
   1338   def _extend_graph(self):

NotFoundError: Key densenet121/dense_block2/conv_block7/x1/Conv/biases not found in checkpoint
    [[Node: save/RestoreV2_158 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_158/tensor_names, save/RestoreV2_158/shape_and_slices)]]
    [[Node: save/RestoreV2_16/_1181 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_2392_save/RestoreV2_16", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

Caused by op 'save/RestoreV2_158', defined at:
  File "/anaconda/envs/py35/lib/python3.5/runpy.py", line 184, in _run_module_as_main
    "__main__", mod_spec)
  File "/anaconda/envs/py35/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/anaconda/envs/py35/lib/python3.5/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/anaconda/envs/py35/lib/python3.5/site-packages/ipykernel/kernelapp.py", line 486, in start
    self.io_loop.start()
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tornado/platform/asyncio.py", line 112, in start
    self.asyncio_loop.run_forever()
  File "/anaconda/envs/py35/lib/python3.5/asyncio/base_events.py", line 345, in run_forever
    self._run_once()
  File "/anaconda/envs/py35/lib/python3.5/asyncio/base_events.py", line 1312, in _run_once
    handle._run()
  File "/anaconda/envs/py35/lib/python3.5/asyncio/events.py", line 125, in _run
    self._callback(*self._args)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tornado/ioloop.py", line 760, in _run_callback
    ret = callback()
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tornado/stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 536, in <lambda>
    self.io_loop.add_callback(lambda : self._handle_events(self.socket, 0))
  File "/anaconda/envs/py35/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 450, in _handle_events
    self._handle_recv()
  File "/anaconda/envs/py35/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tornado/stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2728, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2850, in run_ast_nodes
    if self.run_code(code, result):
  File "/anaconda/envs/py35/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-5-9d5aca7f5637>", line 1, in <module>
    sym, sess = create_symbol()
  File "<ipython-input-4-e283472f7a13>", line 19, in create_symbol
    saver = tf.train.Saver(variables_to_restore)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1218, in __init__
    self.build()
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1227, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1263, in _build
    build_save=build_save, build_restore=build_restore)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 751, in _build_internal
    restore_sequentially, reshape)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 427, in _AddRestoreOps
    tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 267, in restore_op
    [spec.tensor.dtype])[0])
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1021, in restore_v2
    shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/anaconda/envs/py35/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

NotFoundError (see above for traceback): Key densenet121/dense_block2/conv_block7/x1/Conv/biases not found in checkpoint
    [[Node: save/RestoreV2_158 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_158/tensor_names, save/RestoreV2_158/shape_and_slices)]]
    [[Node: save/RestoreV2_16/_1181 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_2392_save/RestoreV2_16", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

However, the end-points seem ok:

('densenet121/dense_block4', <tf.Tensor 'densenet121/dense_block4/conv_block16/concat:0' shape=(?, 7, 7, 1024) dtype=float32>), ('densenet121/logits', <tf.Tensor 'densenet121/logits/Relu:0' shape=(?, 1, 1, 1000) dtype=float32>), ('predictions', <tf.Tensor 'densenet121/predictions/Reshape_1:0' shape=(?, 1, 1, 1000) dtype=float32>)])

So that I can just extract: 'densenet121/dense_block4/conv_block16/concat:0' and add my own: ('densenet121/logits', <tf.Tensor 'densenet121/logits/Relu:0' shape=(?, 1, 1, 16) dtype=float32>), ('predictions', <tf.Tensor 'densenet121/predictions/Reshape_1:0' shape=(?, 1, 1, 16) dtype=float32>)])

Also was curious if there was a reason that shape is channels-last, since I thought channels-first is faster for cuDNN training?

Thanks

pudae commented 6 years ago

Hi ilkarman~

Seems to complain about finding biases:

I turned off the biases with `slim.arg_scope`. Please try it like this:

# Build the network inside DenseNet's arg scope so that the conv layers are
# created without biases, matching the variables stored in the checkpoint.
# densenet_arg_scope must be *called*: slim.arg_scope expects the scope it
# returns, not the function object itself.
with slim.arg_scope(densenet.densenet_arg_scope()):
    logits, endpoints = densenet.densenet121(
        X,
        num_classes=1000,
        is_training=True,
        reuse=None)

Also was curious if there was a reason that shape is channels-last, since I thought channels-first is faster for cuDNN training?

You're right, but I just used slim's default data format, NHWC. If you want to use NCHW, you can do so simply by adding the data_format argument to slim.arg_scope.

Thanks~

ilkarman commented 6 years ago

Thanks that has seemed to do the trick!

I noticed that I wasn't able to pass num_classes=None, however. I want to add a fully-connected layer which I later train with sigmoid_cross_entropy() loss and I think in the code this defaults to softmax activation.

Wanted to ask if this approach looks ok:

# Extract checkpoint (skipped when the target directory already exists)
CHKPT_DIR = 'tfdensenet/'
if not os.path.isdir(CHKPT_DIR):
    with tarfile.open("tf-densenet121.tar.gz") as t:
        t.extractall(CHKPT_DIR)

# Create (or fetch) the global-step variable.
# Author's note: without this nothing is restored.
tf.train.get_or_create_global_step()

# Place-holders: X is fed to the network, y holds the labels used by the loss
X = tf.placeholder(tf.float32, shape=[None, WIDTH, HEIGHT, CHANNELS])
y = tf.placeholder(tf.float32, shape=[None, CLASSES])

# Import symbol
dense_args = densenet.densenet_arg_scope()
print(dense_args)  # Add NCHW later

# Build DenseNet-121 inside its arg scope; the second return value
# (the end-points) is discarded
with slim.arg_scope(dense_args):
    logits, _ = densenet.densenet121(X, num_classes=CLASSES, is_training=True, reuse=None)

# Collect variables to restore from checkpoint, leaving out the
# 'densenet121/logits' and 'predictions' scopes.
# NOTE(review): this runs before the optimizer is created, so variables the
# optimizer adds later are presumably not in this list - confirm.
variables_to_restore = slim.get_variables_to_restore(exclude=['densenet121/logits', 'predictions'])
#print(variables_to_restore)

model_path = os.path.join(CHKPT_DIR, "tf-densenet121.ckpt")
print(model_path)

# init_fn is called later with the session to load the checkpoint values
init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)  

# Reshape logits to (None, CLASSES) since my label is (None, CLASSES)
sym = tf.reshape(logits, shape=[-1, CLASSES])

# Loss: element-wise sigmoid cross-entropy on the reshaped logits, averaged
# into a single scalar
loss_fn = tf.nn.sigmoid_cross_entropy_with_logits(logits=sym, labels=y)
loss = tf.reduce_mean(loss_fn)

optimizer = tf.train.AdamOptimizer(LR, beta1=0.9, beta2=0.999)
training_op = optimizer.minimize(loss)

print("Loading pre-trained weights")
sess = tf.Session()
init_fn(sess)  # Load from checkpoint

# Initialise uninitialised vars (FC layer & Adam)
# init_uninitialized is the helper defined further down in this post
init_uninitialized(sess)

Previously with resnet50 I was able to do it like so:

def get_symbol(model_name, in_tensor, chkpoint=CHKPOINT, out_features=CLASSES):
    """Build a pre-trained backbone with a fresh dense output layer on top.

    Args:
        model_name: Backbone to build; only 'resnet50' is implemented.
        in_tensor: Input tensor passed to the backbone network.
        chkpoint: Checkpoint path handed to slim.assign_from_checkpoint_fn.
        out_features: Number of units in the new output layer.

    Returns:
        A tuple (sym, init_fn). `sym` is the new output layer reshaped to
        [-1, out_features]; no activation is applied because it is folded
        into the loss. `init_fn` is the callable produced by
        slim.assign_from_checkpoint_fn; call it with a session to load the
        checkpoint.

    Raises:
        ValueError: If model_name is 'densenet121' (not yet implemented) or
            is not a recognised model name.
    """
    # Reject unsupported models up front so the happy path stays flat
    if model_name == 'densenet121':
        # TODO: https://github.com/pudae/tensorflow-densenet/issues/10
        raise ValueError("Densenet is not yet implemented")
    if model_name != 'resnet50':
        raise ValueError("Unknown model-name")

    # Load variables into model (without this nothing is restored)
    tf.train.get_or_create_global_step()
    # Import symbol, using the caller's tensor rather than a module-level
    # placeholder so the function works with any input
    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
        base_model, _ = resnet_v1.resnet_v1_50(in_tensor, None, is_training=True)
    # Collect variables to restore from checkpoint. This must happen before
    # the extra layer is attached so that only backbone variables are listed.
    variables_to_restore = slim.get_variables_to_restore()
    init_fn = slim.assign_from_checkpoint_fn(chkpoint, variables_to_restore)
    # Attach extra layers
    fc = tf.layers.dense(base_model, out_features, name='output')
    # Activation function will be included in loss
    sym = tf.reshape(fc, shape=[-1, out_features])

    return sym, init_fn

def init_symbol(sym, out_tensor, lr=LR):
    """Attach a sigmoid cross-entropy loss and an Adam training op to `sym`.

    Args:
        sym: Logits tensor produced by the model.
        out_tensor: Label tensor the logits are compared against.
        lr: Learning rate passed to the Adam optimizer.

    Returns:
        A tuple (training_op, loss): the op returned by optimizer.minimize
        and the scalar mean of the element-wise loss.
    """
    # Use the labels supplied by the caller instead of a module-level
    # placeholder, so the function does not depend on global state
    loss_fn = tf.nn.sigmoid_cross_entropy_with_logits(logits=sym, labels=out_tensor)
    loss = tf.reduce_mean(loss_fn)
    optimizer = tf.train.AdamOptimizer(lr, beta1=0.9, beta2=0.999)
    training_op = optimizer.minimize(loss)
    return training_op, loss

def init_uninitialized(sess):
    """Initialise only those global variables that `sess` has not initialised.

    Variables that already hold values (for example ones loaded from a
    checkpoint) are left untouched.

    Args:
        sess: Session used both to query the variables and to run the
            initializer.
    """
    all_vars = tf.global_variables()
    # One boolean per variable: True when it already has a value in `sess`
    init_flags = sess.run([tf.is_variable_initialized(v) for v in all_vars])
    pending = [var for var, ready in zip(all_vars, init_flags) if not ready]
    if not pending:
        return
    sess.run(tf.variables_initializer(pending))

# Place-holders: X is passed to get_symbol as the network input, y is the
# label tensor passed to init_symbol
X = tf.placeholder(tf.float32, shape=[None, WIDTH, HEIGHT, CHANNELS])
y = tf.placeholder(tf.float32, shape=[None, CLASSES])

# Create symbol
sym, init_fn = get_symbol(model_name='resnet50', in_tensor=X)

# Create training operation
# NOTE: init_symbol returns (training_op, loss), so `model` holds the
# training op rather than a model object
model, loss = init_symbol(sym=sym, out_tensor=y)

# Launch session and load model from checkpoint
sess = tf.Session()

# Temp
if PRETRAINED_WEIGHTS:
    print("Loading pre-trained weights")
    init_fn(sess)  # Load from checkpoint

# Initialise uninitialised vars (FC layer & Adam)
init_uninitialized(sess)

Edit: When I later want to use this model for scoring (having trained it on my data-set), would it be possible to pass in a training-flag (for dropout and batch-norm to act correctly) to feed-dict or would the process be a bit more complicated?

Thanks very much! Ilia