Doodleverse / doodleverse_utils

A set of common Doodleverse tools and utilities
MIT License

Dice #7

Closed ebgoldstein closed 1 year ago

ebgoldstein commented 1 year ago

There are two things I want to say about the Dice metric right now (the code block is linked below for easy reference):

1) Predictions (y_pred) come out of the model without being argmax'ed, so I think we need a line to convert them to true one-hot encoding. For example, the following line would argmax each prediction and then reconvert it to one-hot encoding, so that every value in the matrix becomes 0 or 1: y_pred = tf.one_hot(tf.argmax(y_pred, -1), num_classes)

Without doing this, the tf.reduce_sum step may be operating on values other than 0 and 1, and I think those soft values between 0 and 1 are affecting the Dice score right now.
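To make the argmax-then-one-hot round trip concrete, here is a minimal numpy sketch (standing in for tf.argmax / tf.one_hot; the probability values are hypothetical) showing how soft softmax outputs become hard 0/1 values:

```python
import numpy as np

# Two pixels' softmax outputs over 3 classes (hypothetical values):
y_pred = np.array([[0.2, 0.7, 0.1],
                   [0.6, 0.3, 0.1]])

# numpy equivalent of tf.one_hot(tf.argmax(y_pred, -1), num_classes):
labels = np.argmax(y_pred, axis=-1)            # -> [1, 0]
y_pred_hard = np.eye(y_pred.shape[-1])[labels]  # rows of the identity = one-hot

print(y_pred_hard)
# Every entry is now exactly 0 or 1, so a reduce_sum counts pixels
# instead of summing soft probabilities.
```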

2) This metric does not seem to be a 'mean Dice' as it stands. The current formulation compares two reshaped one-hot tensors, so it is Dice for the entire image, not mean Dice. Dice instead needs to be computed for each class and then averaged.

As a thought experiment, imagine a multiclass problem with strong class imbalance: all classes are predicted well except one, and there may be only a handful of pixels for this poorly predicted rare class. Mispredicting a few pixels will not strongly impact the metric as currently written (it is just a few mispredicted pixels out of the whole image). But if Dice were computed per class, the Dice for the rare class would be 0, and when averaged with the other classes it could strongly lower the mean Dice score.

The solution is to rewrite the Dice metric to calculate stats for each class and then average them.
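The thought experiment above can be put into numbers with a small numpy sketch (a hypothetical 10-pixel, 2-class example; the real code uses TF tensors): whole-image Dice stays high even though the rare class is missed entirely, while the per-class mean drops sharply.

```python
import numpy as np

def dice(a, b, smooth=1e-6):
    # plain Dice on flattened binary arrays
    inter = np.sum(a * b)
    return (2 * inter + smooth) / (np.sum(a) + np.sum(b) + smooth)

# 10 pixels, label-encoded; class 1 is the rare class (a single pixel)
y_true = np.array([0] * 9 + [1])
y_pred = np.array([0] * 10)        # model predicts the common class everywhere

nclasses = 2
t = np.eye(nclasses)[y_true]       # one-hot, shape (10, 2)
p = np.eye(nclasses)[y_pred]

# "whole image" Dice on the flattened one-hot tensors (current behavior)
print(dice(t.ravel(), p.ravel()))  # ~0.9 -- looks great despite missing class 1

# per-class Dice, then averaged (proposed behavior)
per_class = [dice(t[:, c], p[:, c]) for c in range(nclasses)]
print(np.mean(per_class))          # ~0.47 -- the rare-class miss drags it down
```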

https://github.com/Doodleverse/doodleverse_utils/blob/691abe7f3c3d354df50c2958cf727c4037950222/doodleverse_utils/model_imports.py#L983-L1013

ebgoldstein commented 1 year ago

Pseudocode for the new Dice metric:

for each class
    argmax (one-hot to label encoding)
    make y_true mask - find all y_true instances for the class and set to 1. set all other classes to 0
    make y_pred mask - find all y_pred instances for the class and set to 1. set all other classes to 0
    reshape, reduce, calculate metric for class
    write to someplace

avg the metrics for each class to return mean dice

for x in range(0,tf.shape(y_true)[3]):
    y_true_am = tf.argmax(y_true, -1)
    ....
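The masking loop above can be sketched in numpy (a hypothetical 4-class example with random tensors; the real implementation would use tf ops like tf.argmax and tf.reduce_sum):

```python
import numpy as np

nclasses = 4
rng = np.random.default_rng(0)
# hypothetical one-hot (batch, H, W, nclasses) tensors
y_true = np.eye(nclasses)[rng.integers(0, nclasses, (2, 8, 8))]
y_pred = np.eye(nclasses)[rng.integers(0, nclasses, (2, 8, 8))]

# argmax: one-hot -> label encoding (done once, outside the loop)
t_lab = np.argmax(y_true, -1)
p_lab = np.argmax(y_pred, -1)

smooth = 1e-6
dices = []
for c in range(nclasses):
    # mask: 1 where the pixel belongs to class c, 0 elsewhere
    t_mask = (t_lab == c).astype(np.float32).ravel()
    p_mask = (p_lab == c).astype(np.float32).ravel()
    inter = np.sum(t_mask * p_mask)
    dices.append((2 * inter + smooth) / (t_mask.sum() + p_mask.sum() + smooth))

mean_dice = np.mean(dices)   # average over classes -> mean Dice
```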
ebgoldstein commented 1 year ago

this is also an implementation: https://stackoverflow.com/questions/72195156/correct-implementation-of-dice-loss-in-tensorflow-keras

ebgoldstein commented 1 year ago

import tensorflow as tf

def dice_coef(y_true, y_pred, smooth):
    # flatten both tensors and compute Dice on the flattened vectors
    y_true_f = tf.reshape(tf.dtypes.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(tf.dtypes.cast(y_pred, tf.float32), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    dice = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return dice

def dice_multi_metric(nclasses, smooth):

    def dice_coef_multilabel(y_true, y_pred):
        dice = 0
        # harden the soft predictions to one-hot before scoring
        y_pred = tf.one_hot(tf.argmax(y_pred, -1), nclasses)
        for index in range(nclasses):
            # Dice per class channel, averaged below
            dice += dice_coef(y_true[:,:,:,index], y_pred[:,:,:,index], smooth)
        return dice/nclasses

    return dice_coef_multilabel

def dice_coef_loss(nclasses, smooth):

    def dice_MC_coef_loss(y_true, y_pred):
        dice = 0
        y_pred = tf.one_hot(tf.argmax(y_pred, -1), nclasses)
        for index in range(nclasses):
            dice += dice_coef(y_true[:,:,:,index], y_pred[:,:,:,index], smooth)
        return 1 - (dice/nclasses)

    return dice_MC_coef_loss

As a metric, it is called with:

dice_multi_metric(4, 10e-6) (for 4 classes and a smoothing epsilon of 10e-6)

For the loss, it is called as:

loss = dice_coef_loss(4, 10e-6) (again for 4 classes and a smoothing epsilon of 10e-6)

ebgoldstein commented 1 year ago

This might also be a way to incorporate per-class loss weighting..
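One hedged sketch of that idea: instead of a plain mean over the per-class Dice terms, take a weighted mean, so a rare class can be given a larger weight (the function name, scores, and weights here are hypothetical):

```python
import numpy as np

def weighted_mean_dice(per_class_dice, weights):
    # weighted average of per-class Dice scores; upweighting a rare
    # class makes its errors count more in the final metric/loss
    w = np.asarray(weights, dtype=np.float64)
    d = np.asarray(per_class_dice, dtype=np.float64)
    return np.sum(w * d) / np.sum(w)

# hypothetical per-class Dice scores for 4 classes; the rare class
# (index 3) scores poorly and is given weight 2
scores = [0.95, 0.90, 0.92, 0.40]
print(weighted_mean_dice(scores, [1, 1, 1, 2]))  # rare class pulls harder
```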

ebgoldstein commented 1 year ago

Whoops, the argmax needs to be removed for a loss; I forgot about differentiability. I have verified that the loss and the metrics work in Gym. Waiting to see what a model looks like.
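For the loss, the per-class Dice can be computed directly on the soft softmax outputs, skipping the argmax/one_hot hardening so gradients flow end to end. A numpy sketch of that soft variant (function name and shapes are hypothetical; the real loss uses tf ops):

```python
import numpy as np

def soft_dice_loss(y_true, y_pred, smooth=1e-6):
    # y_pred stays as soft probabilities -- no argmax/one_hot step,
    # so the computation remains differentiable for training
    nclasses = y_true.shape[-1]
    dice = 0.0
    for c in range(nclasses):
        t = y_true[..., c].ravel()
        p = y_pred[..., c].ravel()
        inter = np.sum(t * p)
        dice += (2 * inter + smooth) / (np.sum(t) + np.sum(p) + smooth)
    return 1.0 - dice / nclasses

# two pixels, two classes: a perfect prediction gives a loss near 0
y_true = np.array([[[1.0, 0.0], [0.0, 1.0]]])
print(soft_dice_loss(y_true, y_true))  # -> ~0.0
```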

dbuscombe-usgs commented 1 year ago

Fantastic. Looking forward to trying this out

ebgoldstein commented 1 year ago

An experimental formulation can be seen in this branch of Gym:

https://github.com/Doodleverse/segmentation_gym/tree/NewDice

It can be removed from the Gym code and added to utils once we confirm it works..

ebgoldstein commented 1 year ago

[training-history plot: NOAA_NewLoss_trainhist_8]

ebgoldstein commented 1 year ago

Curves look good, results look good, and the metrics really match what I am seeing. I think this is working well.

ebgoldstein commented 1 year ago

Now, just for fun, I will try some loss weighting on the rare class...

ebgoldstein commented 1 year ago

Note that everything looks worse numerically: losses are larger and metrics are lower. But the actual segmentation results look identical to what I had before. I think the metrics are now just in line with the reality of taking a mean across classes (vs. a mean over the entire scene), so poor performance on rare classes degrades the model metrics, and great performance on background classes does not artificially inflate them.

ebgoldstein commented 1 year ago

Loss weighting seems to be the key for the rare class. [image]

Here are the curves. Note the loss is Dice loss with weighting (not raw Dice): every class has a weight of 1, but dev has a weight of 2. [image]