hellochick / ICNet-tensorflow

TensorFlow-based implementation of "ICNet for Real-Time Semantic Segmentation on High-Resolution Images".

model predicting all zeros after training with own dataset #65

Open FisherShi opened 6 years ago

FisherShi commented 6 years ago

I'm training on my own dataset, trying to classify 3 classes. Training images are (600, 800, 3) RGB and labels are (600, 800, 1) with 0, 1, 2 representing (none, road, cars). I would not be surprised if the model could not predict cars well, since the majority of the training pixels are road and background, but for some reason the loss decreased to ~0.6 and stopped improving, and the model tends to predict everything as 0. Any suggestions on how to debug this are highly appreciated.
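A quick first debugging step is to measure how skewed the label distribution actually is. This is a sketch, not code from the repo; `class_pixel_fractions` is a hypothetical helper that operates on labels already loaded as integer arrays:

```python
import numpy as np

def class_pixel_fractions(labels, num_classes=3):
    """Return the fraction of pixels belonging to each class
    across a list of integer label arrays."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for lbl in labels:
        # bincount over the flattened label image; minlength pads
        # classes that never occur in this particular image
        counts += np.bincount(lbl.ravel(), minlength=num_classes)[:num_classes]
    return counts / counts.sum()
```

If class 0 comes out at, say, 90%+ of all pixels, an unweighted cross-entropy loss can reach a low value by predicting 0 everywhere, which matches the behavior described above.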

alexw92 commented 6 years ago

If the majority of your ground truth is class 0, this can happen. I had a similar problem with my own dataset and solved it by applying class weights to the individual classes. That is, a relatively large weight is applied to the class which is predicted too often, and smaller weights to the rest. This should "punish" the network more for wrong predictions of the frequent class (in your case 0).

cy-goh commented 5 years ago

@alexw92, could you add a small snippet of your recommendation here?

I thought a bigger weight should be added to the class which is predicted less often instead?

alexw92 commented 5 years ago

@chiyuan-goh Yes, you are right. Bigger weights should be applied to the rare classes, of course. For the weight calculation I used the so-called median-frequency weights introduced by Eigen et al. in their experiments on Pascal VOC data ("Predicting Depth, Surface Normals and Semantic Labels with a Common Multi-Scale Convolutional Architecture"). In the train file I added the weights as shown in the following code:

def create_loss(output, label, num_classes, ignore_label):
    raw_pred = tf.reshape(output, [-1, num_classes])
    label = prepare_label(label, tf.stack(output.get_shape()[1:3]), num_classes=num_classes, one_hot=False)
    label = tf.reshape(label, [-1,])

    indices = get_mask(label, num_classes, ignore_label)
    gt = tf.cast(tf.gather(label, indices), tf.int32)
    pred = tf.gather(raw_pred, indices)

    # class weights: fall back to uniform weights if none were supplied;
    # the commented constant is an example of precomputed median-frequency
    # weights for the classes (un, bui, wo, wa, ro, res)
    global dataset_class_weights
    if dataset_class_weights is None:
        # use floats so the weights match the dtype of the loss
        dataset_class_weights = tf.constant([1.0 for i in range(num_classes)])
    class_weights = dataset_class_weights
    # class_weights = tf.constant([0.975644, 1.025603, 0.601745, 6.600600, 1.328684, 0.454776])
    weights = tf.gather(class_weights, gt)

    loss = tf.losses.sparse_softmax_cross_entropy(logits=pred, labels=gt, weights=weights)
    reduced_loss = tf.reduce_mean(loss)

    return reduced_loss
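For completeness, the median-frequency weights mentioned above can be computed offline from the training labels before being pasted into the constant. This is a sketch based on the Eigen et al. description, not code from this repository; `median_frequency_weights` is a hypothetical helper name:

```python
import numpy as np

def median_frequency_weights(label_images, num_classes):
    """Median-frequency balancing (Eigen et al.):
    weight_c = median(freq) / freq_c, where freq_c is the pixel
    frequency of class c over the whole training set, so rare
    classes get weights > 1 and frequent classes get weights < 1."""
    counts = np.zeros(num_classes, dtype=np.float64)
    for lbl in label_images:
        counts += np.bincount(lbl.ravel(), minlength=num_classes)[:num_classes]
    freq = counts / counts.sum()
    # mask classes that never appear so they don't break the division
    freq = np.where(freq > 0, freq, np.nan)
    weights = np.nanmedian(freq) / freq
    return np.nan_to_num(weights, nan=0.0)
```

For a 10-pixel toy set with class counts (6, 3, 1), the frequencies are (0.6, 0.3, 0.1), the median is 0.3, and the resulting weights are (0.5, 1.0, 3.0), i.e. the rare class is weighted 3x.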