Ivan-E-Johnson opened 3 months ago
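Something like this:

```python
import torch.nn as nn
from monai.losses import DiceCELoss, HausdorffDTLoss  # HausdorffDTLoss: MONAI >= 1.3


class CustomLoss(nn.Module):
    def __init__(self, weights):
        super().__init__()
        self.weights = weights  # e.g. {"dice": 0.8, "hausdorff": 0.2}
        # Build the component losses once here, not on every forward pass.
        # monai.metrics.HausdorffDistanceMetric is a non-differentiable metric, so a
        # distance-transform-based Hausdorff loss stands in for it during training.
        self.dice_ce = DiceCELoss(to_onehot_y=True, softmax=True)
        self.hausdorff = HausdorffDTLoss(to_onehot_y=True, softmax=True)

    def forward(self, outputs, targets):
        # Calculate each loss term
        dice_loss = self.dice_ce(outputs, targets)
        hausdorff_loss = self.hausdorff(outputs, targets)
        # Combine them with weights
        total_loss = self.weights["dice"] * dice_loss + self.weights["hausdorff"] * hausdorff_loss
        return total_loss
```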
Implement a custom loss function that is a weighted combination of metrics, something like `loss = 0.8 * DiceCE + 0.1 * Hausdorff + 0.1 * other`.
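A minimal sketch of the `0.8 / 0.1 / 0.1` combination in plain Python, so it does not depend on MONAI. `make_weighted_loss` is a hypothetical helper (not a MONAI or torchmetrics API) that pairs named loss callables with weights and sums the weighted terms. Because it only uses `*` and `+`, it should also work when each callable is a `torch.nn.Module` returning a tensor. The check that the weights sum to 1 is an extra assumption, not something the issue asks for; it keeps the combined loss on the same scale as its terms.

```python
from typing import Callable, Dict


def make_weighted_loss(
    losses: Dict[str, Callable],
    weights: Dict[str, float],
    tol: float = 1e-6,
) -> Callable:
    """Return a callable computing sum(weights[k] * losses[k](outputs, targets)).

    Hypothetical helper; the names and the sum-to-1 rule are assumptions.
    """
    if set(losses) != set(weights):
        raise ValueError(f"loss names {sorted(losses)} and weight names {sorted(weights)} differ")
    total = sum(weights.values())
    if abs(total - 1.0) > tol:
        # Assumed constraint: weights form a convex combination.
        raise ValueError(f"weights should sum to 1, got {total}")

    def weighted_loss(outputs, targets):
        # Evaluate each named term and scale it by its weight; only * and +
        # are used, so plain floats and torch tensors both work.
        return sum(weights[name] * fn(outputs, targets) for name, fn in losses.items())

    return weighted_loss


if __name__ == "__main__":
    # Constant stand-ins in place of real DiceCE / Hausdorff / other losses.
    loss_fn = make_weighted_loss(
        {"dice_ce": lambda o, t: 0.5, "hausdorff": lambda o, t: 2.0, "other": lambda o, t: 1.0},
        {"dice_ce": 0.8, "hausdorff": 0.1, "other": 0.1},
    )
    print(loss_fn(None, None))  # approximately 0.7
```

With those stand-in values the result is 0.8 * 0.5 + 0.1 * 2.0 + 0.1 * 1.0, about 0.7. If any term has parameters or buffers that must follow `.to(device)`, wrapping the terms in an `nn.Module` (as `CustomLoss` does) is the safer choice, since a closure does not register them as submodules.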
Implement:
https://lightning.ai/docs/torchmetrics/stable/classification/label_ranking_average_precision.html#label-ranking-average-precision
https://lightning.ai/docs/torchmetrics/stable/pages/overview.html
https://github.com/Project-MONAI/tutorials/blob/main/modules/transform_visualization.ipynb
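For the label ranking average precision link, here is a dependency-free reference sketch of the standard definition. For each sample, every relevant label j gets the fraction of labels scored at least as high as j that are also relevant. Those fractions are averaged over the sample's relevant labels and then over samples. The function name is hypothetical, and training code should use the torchmetrics implementation documented at the link. A sketch like this is mainly useful for cross-checking it on a small batch. Scoring a sample with no relevant labels as 1.0 follows scikit-learn's convention and is an assumption here.

```python
from typing import Sequence


def label_ranking_average_precision(
    y_true: Sequence[Sequence[int]], y_score: Sequence[Sequence[float]]
) -> float:
    """Label ranking average precision for multilabel data (higher is better, max 1.0)."""
    if len(y_true) != len(y_score) or len(y_true) == 0:
        raise ValueError("y_true and y_score must be non-empty and have the same length")
    total = 0.0
    for truth, scores in zip(y_true, y_score):
        relevant = [j for j, flag in enumerate(truth) if flag]
        if not relevant:
            # Assumed convention (scikit-learn's): no relevant labels scores 1.0.
            total += 1.0
            continue
        sample_score = 0.0
        for j in relevant:
            # rank: number of labels scored at least as high as label j (ties count against j).
            rank = sum(1 for s in scores if s >= scores[j])
            # hits: number of *relevant* labels scored at least as high as label j.
            hits = sum(1 for k in relevant if scores[k] >= scores[j])
            sample_score += hits / rank
        total += sample_score / len(relevant)
    return total / len(y_true)
```

On the two-sample example from scikit-learn's documentation (`y_true = [[1, 0, 0], [0, 0, 1]]`, `y_score = [[0.75, 0.5, 1], [1, 0.2, 0.1]]`) it returns (1/2 + 1/3) / 2 = 5/12, about 0.417.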