dhlab-epfl / dhSegment

Generic framework for historical document processing
https://dhlab-epfl.github.com/dhSegment
GNU General Public License v3.0

Multilabel training problem #29

Closed: vndee closed this issue 5 years ago

vndee commented 5 years ago

Hi, I have used this code to train a Document Layout Analysis model. I set:

prediction_type = utils.PredictionType.MULTILABEL

and my classes.txt file (9 classes) is:

0 0 0 1 0 0 0 0 0 0 0 0
25 255 255 0 1 0 0 0 0 0 0 0
142 130 255 0 0 1 0 0 0 0 0 0
191 130 74 0 0 0 1 0 0 0 0 0
191 14 74 0 0 0 0 1 0 0 0 0
191 181 74 0 0 0 0 0 1 0 0 0
36 13 249 0 0 0 0 0 0 1 0 0
110 49 7 0 0 0 0 0 0 0 1 0
250 246 7 0 0 0 0 0 0 0 0 1

But I got the following error:

Caused by op 'Label2Img/GatherNd', defined at:
  File "train.py", line 47, in <module>
    @ex.automain
  File "/home/it/.local/lib/python3.6/site-packages/sacred/experiment.py", line 137, in automain
    self.run_commandline()
  File "/home/it/.local/lib/python3.6/site-packages/sacred/experiment.py", line 260, in run_commandline
    return self.run(cmd_name, config_updates, named_configs, {}, args)
  File "/home/it/.local/lib/python3.6/site-packages/sacred/experiment.py", line 209, in run
    run()
  File "/home/it/.local/lib/python3.6/site-packages/sacred/run.py", line 221, in __call__
    self.result = self.main_function(*args)
  File "/home/it/.local/lib/python3.6/site-packages/sacred/config/captured_function.py", line 46, in captured_function
    result = wrapped(*args, **kwargs)
  File "train.py", line 111, in run
    num_threads=32))
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 356, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1181, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1208, in _train_model_default
    input_fn, model_fn_lib.ModeKeys.TRAIN))
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1049, in _get_features_and_labels_from_input_fn
    self._call_input_fn(input_fn, mode))
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1136, in _call_input_fn
    return input_fn(**kwargs)
  File "/home/it/Projects/DLA/dhSegment/dh_segment/io/input.py", line 224, in fn
    label_export = utils.multiclass_to_label_image(label_export, classes_file)
  File "/home/it/Projects/DLA/dhSegment/dh_segment/utils/labels.py", line 67, in multiclass_to_label_image
    return tf.gather_nd(c, tf.cast(class_label_tensor, tf.int32))
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3140, in gather_nd
    "GatherNd", params=params, indices=indices, name=name)
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3272, in create_op
    op_def=op_def)
  File "/home/it/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1768, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): Only indices.shape[-1] values between 1 and 7 are currently supported. Requested rank: 9
	 [[{{node Label2Img/GatherNd}} = GatherNd[Tindices=DT_INT32, Tparams=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Label2Img/GatherNd/params, Cast_1)]]
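Why the limit is hit at 9 labels: in the traceback, multiclass_to_label_image passes the label tensor to tf.gather_nd as its indices, so indices.shape[-1] equals the number of labels (9 here), which is above the maximum of 7 reported in the error. Presumably the color table has one axis per label (an assumption; dhSegment's actual table layout is not shown here). One possible workaround, sketched in plain Python under that assumption, is to read each label bit vector as a single binary integer and look its color up in a flat table of 2^n colors; in TensorFlow this would only need a 1-D tf.gather on that integer instead of an n-dimensional tf.gather_nd. The helper names and the choice of averaging colors for overlapping labels are hypothetical, not dhSegment's code.

```python
from itertools import product


def build_flat_color_table(class_colors):
    """Hypothetical helper: one RGB color per combination of n binary labels.

    class_colors is a list of (R, G, B) tuples, one per label (assumption).
    The table has 2**n entries, ordered so that entry i corresponds to the
    bit vector whose binary value (first label = most significant bit) is i.
    """
    table = []
    # itertools.product varies the last position fastest, which matches
    # most-significant-bit-first binary counting.
    for bits in product((0, 1), repeat=len(class_colors)):
        active = [color for bit, color in zip(bits, class_colors) if bit]
        if not active:
            table.append((0, 0, 0))  # no active label: black
        else:
            # Arbitrary visualization choice: integer-average the active colors.
            table.append(tuple(sum(ch) // len(active) for ch in zip(*active)))
    return table


def bits_to_flat_index(bits):
    """Read a label bit vector as a binary integer (first label = MSB)."""
    index = 0
    for bit in bits:
        index = index * 2 + int(bit)
    return index


def label_to_color(bits, table):
    """1-D lookup: works for any number of labels the table was built for."""
    return table[bits_to_flat_index(bits)]
```

The flat table still grows as 2^n (512 colors for the 9 labels above), so this only removes the rank limit of the n-dimensional gather; it does not make visualizing very many overlapping labels cheap.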

I guess we cannot train a multilabel classification model with more than 7 classes. Can anyone help me fix this problem? Thanks.

solivr commented 5 years ago

Hi @vndee,

It seems that this error is due to multiclass_to_label_image, which is only used for visualization purposes (tf.summary) and not during training. I'll need to take a closer look to see if there is another/better way of visualizing many labels, but for now you can comment out the lines related to image summaries (the ones using multiclass_to_label_image). Also, regarding your classes.txt file, your classes seem to be mutually exclusive, i.e. only one label is given to each color, so you should be using PredictionType.CLASSIFICATION (with only the RGB codes in classes.txt).
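The switch to PredictionType.CLASSIFICATION suggested here can be checked and prepared mechanically. The sketch below assumes the multilabel layout used in this issue (each row is R G B followed by one 0/1 flag per label); it verifies that every color has exactly one active label and, if so, emits the RGB-only rows for a CLASSIFICATION classes.txt. The function names are hypothetical, and file reading/writing is left to the caller.

```python
def parse_multilabel_classes(text):
    """Parse rows of 'R G B b1 ... bn' into ((R, G, B), [b1, ..., bn]) pairs."""
    rows = []
    for line in text.strip().splitlines():
        values = [int(v) for v in line.split()]
        rgb, bits = tuple(values[:3]), values[3:]
        rows.append((rgb, bits))
    return rows


def to_classification_classes(text):
    """Return RGB-only rows if labels are mutually exclusive, else raise ValueError."""
    out = []
    for rgb, bits in parse_multilabel_classes(text):
        active = sum(bits)
        if active != 1:
            # More (or fewer) than one label on a color: this really is a
            # multilabel setup and cannot be expressed as plain classification.
            raise ValueError(f"color {rgb} has {active} active labels")
        out.append(" ".join(str(channel) for channel in rgb))
    return "\n".join(out) + "\n"
```

Applied to the 9-row file from the issue, every row has a single 1, so the function returns the 9 RGB triples, one per line, ready to be written to a new classes.txt.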

vndee commented 5 years ago

I've just fixed my problem by reducing the number of classes to 7. Thanks for your reply; it helped me clearly understand the workflow of the model.