Closed: ebgoldstein closed this issue 1 year ago
pseudocode for new dice metric:
for each class:
- argmax (one-hot to label encoding)
- make y_true mask: find all y_true instances for the class and set to 1; set all other classes to 0
- make y_pred mask: find all y_pred instances for the class and set to 1; set all other classes to 0
- reshape, reduce, calculate metric for class
- write to someplace

avg the metrics for each class to return mean dice
```python
for x in range(0, tf.shape(y_true)[3]):
    y_true_am = tf.argmax(y_true, -1)
    ...
```
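The pseudocode above can be sketched end-to-end in NumPy (an illustration of the masking-and-averaging arithmetic only; names are mine, not the TensorFlow code that ended up in Gym):

```python
import numpy as np

def per_class_dice(y_true_lab, y_pred_lab, nclasses, smooth=1e-6):
    """Mean Dice over classes, from integer label maps (i.e. post-argmax)."""
    dices = []
    for c in range(nclasses):
        # binary masks: 1 where the pixel belongs to class c, 0 elsewhere
        t = (y_true_lab == c).astype(np.float32).ravel()
        p = (y_pred_lab == c).astype(np.float32).ravel()
        intersection = np.sum(t * p)
        dices.append((2.0 * intersection + smooth) /
                     (np.sum(t) + np.sum(p) + smooth))
    # average the per-class scores to get mean Dice
    return float(np.mean(dices))
```

A perfect prediction gives mean Dice of 1; a prediction that misses a class entirely drags the mean down by roughly 1/nclasses.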
this is also an implementation: https://stackoverflow.com/questions/72195156/correct-implementation-of-dice-loss-in-tensorflow-keras
```python
def dice_coef(y_true, y_pred, smooth):
    y_true_f = tf.reshape(tf.dtypes.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(tf.dtypes.cast(y_pred, tf.float32), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    dice = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return dice
```
```python
def dice_multi_metric(nclasses, smooth):
    def dice_coef_multilabel(y_true, y_pred):
        dice = 0
        # harden predictions to one-hot (nclasses, rather than a hard-coded 4)
        y_pred = tf.one_hot(tf.argmax(y_pred, -1), nclasses)
        for index in range(nclasses):
            dice += dice_coef(y_true[:, :, :, index], y_pred[:, :, :, index], smooth)
        return dice / nclasses
    return dice_coef_multilabel
```
```python
def dice_coef_loss(nclasses, smooth):
    def dice_MC_coef_loss(y_true, y_pred):
        dice = 0
        y_pred = tf.one_hot(tf.argmax(y_pred, -1), nclasses)
        for index in range(nclasses):
            dice += dice_coef(y_true[:, :, :, index], y_pred[:, :, :, index], smooth)
        return 1 - (dice / nclasses)
    return dice_MC_coef_loss
```
As a metric, it is called with:

```python
dice_multi_metric(4, 10e-6)
```

(for 4 classes and a smoothing epsilon of 10e-6)

for loss, called as:

```python
loss = dice_coef_loss(4, 10e-6)
```

(for 4 classes and a smoothing epsilon of 10e-6)
this might also be a way to incorporate loss weighting per class..
whoops, the argmax needs to be removed for a loss... forgot about differentiability.. I have verified that the loss and the metrics work in Gym.. waiting to see what a model looks like
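A minimal sketch of what dropping the argmax might look like (the function name here is mine, not necessarily what landed in Gym): the loss operates on the raw softmax probabilities so gradients can flow, while the metric can keep the argmax/one-hot hardening.

```python
import tensorflow as tf

def soft_dice_loss(nclasses, smooth):
    def loss_fn(y_true, y_pred):
        # no argmax/one_hot here: keep the softmax probabilities so the
        # loss stays differentiable w.r.t. the model outputs
        dice = 0.0
        for index in range(nclasses):
            t = tf.reshape(tf.cast(y_true[..., index], tf.float32), [-1])
            p = tf.reshape(tf.cast(y_pred[..., index], tf.float32), [-1])
            intersection = tf.reduce_sum(t * p)
            dice += (2.0 * intersection + smooth) / (
                tf.reduce_sum(t) + tf.reduce_sum(p) + smooth)
        return 1.0 - dice / nclasses
    return loss_fn
```

With soft probabilities the loss no longer bottoms out at exactly 0, but the gradient now rewards pushing each pixel's probability mass toward the correct class.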
Fantastic. Looking forward to trying this out
experimental formulation can be seen in this branch of gym: https://github.com/Doodleverse/segmentation_gym/tree/NewDice
can be removed from gym code, and added to utils instead when we confirm it works..
curves look good, results look good, metrics really match what i am seeing... I think this is working well..
now just for fun i will try to do some loss weighting on the rare class...
Note that everything looks worse numerically - losses are larger and metrics are lower.. but the actual segmentation results look identical to what i had before.. i think the metrics are now just in-line with the reality of taking a mean across classes (vs a mean of the entire scene).. so poor performance on rare classes degrades model metrics, and great performance on background classes does not artificially inflate the metrics.
loss weighting seems to be the key for rare class
here are the curves... note loss is dice loss w/ weighting (not raw dice) - every class has weight of 1, but dev has weight of 2
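For reference, the weighting described above (every class at 1, `dev` at 2) boils down to a weighted combination of the per-class Dice terms. A minimal NumPy sketch of that arithmetic, assuming a weighted mean (the exact normalization used in Gym may differ):

```python
import numpy as np

def weighted_mean_dice(per_class_dice, class_weights):
    # upweighting a rare class makes its (poor) Dice count more toward
    # the mean, steering the loss toward fixing that class
    d = np.asarray(per_class_dice, dtype=np.float64)
    w = np.asarray(class_weights, dtype=np.float64)
    return float(np.sum(w * d) / np.sum(w))
```

e.g. with hypothetical per-class Dice of [0.9, 0.9, 0.9, 0.3] and weights [1, 1, 1, 2], the weighted mean is 0.66 vs 0.75 unweighted, so the optimizer feels the rare-class failure more strongly.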
there are two things i want to say about the dice metric right now (the code block is included below for easy reference):

1) predictions (`y_pred`) come out of the model w/o being argmax'ed, so i think we might need a line to actually convert them to true one-hot encoding... for example, this line would argmax each prediction, and then reconvert it to one-hot encoding, converting all values in the matrix to 0 and 1... something like this:

```python
y_pred = tf.one_hot(tf.argmax(y_pred, -1), num_classes)
```

w/o doing this, the `tf.reduce_sum` step might actually be operating on values other than 0 and 1.. and I think values between 0 and 1 are impacting the dice score right now...

2) This metric does not seem to be a 'mean dice' as it stands currently.. The formulation of dice as it stands compares two reshaped one-hot tensors. So it is dice for the entire image, and not mean dice. Dice instead needs to be computed for each class, then averaged...
As a thought experiment... imagine a multiclass problem with strong class imbalance. All classes are predicted well except one, and there might be only a handful of pixels for this rare class that is predicted poorly. Misprediction of a few pixels will not strongly impact the metric as it is written (it's just a few mispredicted pixels). But if the Dice for each class were computed, the Dice for this rare class would be 0, and when averaged with the other classes, could strongly impact the mean Dice score.
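To put rough numbers on that, here is a hypothetical 100-pixel, two-class scene where the model predicts background everywhere and misses a 2-pixel rare class entirely:

```python
import numpy as np

smooth = 1e-6

# hypothetical scene: 100 pixels, 98 background (class 0), 2 rare (class 1);
# the model predicts background everywhere, missing the rare class completely
y_true = np.zeros(100, dtype=int)
y_true[:2] = 1
y_pred = np.zeros(100, dtype=int)

def dice(t, p):
    intersection = np.sum(t * p)
    return (2.0 * intersection + smooth) / (np.sum(t) + np.sum(p) + smooth)

# whole-image dice on flattened one-hot tensors (the current formulation):
one_hot = lambda lab: np.eye(2)[lab].ravel()
whole_image = dice(one_hot(y_true), one_hot(y_pred))  # ~0.98, barely penalized

# per-class dice, then averaged (mean dice):
per_class = [dice((y_true == c).astype(float), (y_pred == c).astype(float))
             for c in range(2)]
mean_dice = np.mean(per_class)  # ~0.49, the rare-class failure dominates
```

The whole-image score stays near 0.98 despite the total miss on the rare class, while the mean over classes drops below 0.5.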
The solution to this issue is to rewrite the dice metric to calculate stats for each class, and then average these stats.
https://github.com/Doodleverse/doodleverse_utils/blob/691abe7f3c3d354df50c2958cf727c4037950222/doodleverse_utils/model_imports.py#L983-L1013