google / model_search

Apache License 2.0

Training on Multiclass Image Dataset #57

Open: Monkeydion opened this issue 3 years ago

Monkeydion commented 3 years ago

Is there already a way to use the code for multiclass image datasets? The documentation only covers binary image datasets. I tried changing the "label_mode" variable in image_data.py to "categorical" and changing the return value of the "number_of_classes" function to the number of classes, but I still get an error.
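For context on what the "label_mode" change does: assuming image_data.py passes this value through to Keras' `image_dataset_from_directory` (not verified here), "int" yields integer class indices of shape (batch,), while "categorical" yields float32 one-hot vectors of shape (batch, num_classes). The NumPy sketch below mimics those encodings; `encode_labels` is a hypothetical helper, not part of model_search or Keras.

```python
import numpy as np

def encode_labels(class_ids, num_classes, label_mode):
    """Hypothetical helper mimicking the label encodings Keras documents
    for label_mode. class_ids: 1-D sequence of integer class indices."""
    ids = np.asarray(class_ids, dtype=np.int32)
    if label_mode == "int":
        # Integer class indices, shape (batch,).
        return ids
    if label_mode == "categorical":
        # One-hot float32 vectors, shape (batch, num_classes).
        return np.eye(num_classes, dtype=np.float32)[ids]
    if label_mode == "binary":
        # Only valid for 2 classes: float32 0.0/1.0, shape (batch, 1).
        if num_classes != 2:
            raise ValueError("binary label_mode requires exactly 2 classes")
        return ids.astype(np.float32).reshape(-1, 1)
    raise ValueError(f"unknown label_mode: {label_mode}")

batch = [0, 3, 1, 2]
print(encode_labels(batch, 4, "int").shape)          # (4,)
print(encode_labels(batch, 4, "categorical").shape)  # (4, 4)
```

The key point is that switching to "categorical" adds a class dimension to every label, so any downstream code written for scalar labels sees a different shape.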

Monkeydion commented 3 years ago

The first error was in metrics_fns.py, caused by a shape mismatch between the predictions (None,) and the labels (None, {number of classes}). To work around it, I added labels=tf.argmax(labels, 1) inside the _metric_fn(labels, predictions, weights=None) function defined in make_accuracy_metric_fn(label_vocabulary=None). This converts the one-hot encodings back to a vector holding the index of the highest-probability class for each example.
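The argmax workaround can be illustrated with a small NumPy stand-in for `tf.argmax(labels, 1)`: each one-hot row collapses to its class index, so the labels drop from shape (batch, num_classes) to (batch,) and line up with the predicted class ids. `accuracy_with_one_hot_labels` is a hypothetical name used only for this sketch; it is not the model_search metric function.

```python
import numpy as np

def accuracy_with_one_hot_labels(one_hot_labels, predicted_class_ids):
    """Accuracy when labels are one-hot (batch, num_classes) but
    predictions are class indices (batch,)."""
    # Same idea as labels = tf.argmax(labels, 1): index of the 1.0 entry.
    label_ids = np.argmax(one_hot_labels, axis=1)
    preds = np.asarray(predicted_class_ids)
    if label_ids.shape != preds.shape:
        raise ValueError(f"shape mismatch: {label_ids.shape} vs {preds.shape}")
    return float(np.mean(label_ids == preds))

labels = np.eye(4, dtype=np.float32)[[0, 3, 1, 2]]  # (4, 4) one-hot
preds = np.array([0, 3, 2, 2])                      # (4,) class ids
print(accuracy_with_one_hot_labels(labels, preds))  # 0.75
```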

Now the error says that the logits and labels have incompatible sizes in loss_fns.py:

tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: logits and labels must be broadcastable: logits_size=[16,4] labels_size=[64,4]
  [[node Phoenix/Trainer/softmax_cross_entropy_loss/xentropy (defined at D:\Programming\TuKoy\model_search-master\model_search\loss_fns.py:95) ]]
  [[Phoenix/Trainer/Mean/_161]]
(1) Invalid argument: logits and labels must be broadcastable: logits_size=[16,4] labels_size=[64,4]
  [[node Phoenix/Trainer/softmax_cross_entropy_loss/xentropy (defined at D:\Programming\TuKoy\model_search-master\model_search\loss_fns.py:95) ]]
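The sizes in this error hint at the cause: 64 = 16 × 4. One plausible explanation, which is a hypothesis and not confirmed against loss_fns.py, is that the loss function expects integer class ids and flattens the labels to 1-D before one-hot encoding them itself. Already one-hot labels of shape (16, 4) would then flatten to 64 entries and be re-encoded as (64, 4). The NumPy sketch below reproduces that arithmetic; `simulated_loss_label_prep` is a hypothetical stand-in for a `tf.reshape` followed by `tf.one_hot`.

```python
import numpy as np

NUM_CLASSES = 4
BATCH = 16

def simulated_loss_label_prep(labels, num_classes):
    """Hypothetical stand-in for a loss fn that expects integer class ids:
    flatten labels to 1-D, then one-hot encode them."""
    flat = np.reshape(labels, [-1]).astype(np.int64)
    return np.eye(num_classes, dtype=np.float32)[flat]

rng = np.random.default_rng(0)
class_ids = rng.integers(0, NUM_CLASSES, size=BATCH)        # shape (16,)
one_hot = np.eye(NUM_CLASSES, dtype=np.float32)[class_ids]  # shape (16, 4)
logits_shape = (BATCH, NUM_CLASSES)                         # (16, 4)

# Integer labels: the re-encoded shape matches the logits.
print(simulated_loss_label_prep(class_ids, NUM_CLASSES).shape)  # (16, 4)

# One-hot labels: 16 * 4 = 64 entries after flattening -> (64, 4),
# the same mismatch reported in the traceback.
print(simulated_loss_label_prep(one_hot, NUM_CLASSES).shape)    # (64, 4)

# Applying argmax first (as done for the metric) restores (16, 4).
print(simulated_loss_label_prep(np.argmax(one_hot, axis=1), NUM_CLASSES).shape)
```

If this hypothesis holds, keeping labels as integer class ids (for example label_mode="int") would avoid patching the metric and loss functions separately.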