tanmay4269 opened this issue 3 months ago (status: open)
At the moment this is not supported out of the box, but you can do it by computing the loss separately for each class and then taking a weighted average of the per-class losses.
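Written out, with $C$ classes, a per-class loss $\mathcal{L}_c$ (for example the binary Jaccard loss on channel $c$) and a user-chosen weight $w_c$ for each class, the combined loss is:

```math
\mathcal{L} = \frac{\sum_{c=1}^{C} w_c \, \mathcal{L}_c}{\sum_{c=1}^{C} w_c}
```

Dividing by $\sum_c w_c$ keeps the loss on the same scale as a single-class loss. Dropping the denominator gives a weighted sum, which only rescales the gradients.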
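Here is a sketch:
import torch
import segmentation_models_pytorch as smp

n_classes = 7
class_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])  # one weight per class
predicted_tensor = torch.ones([5, n_classes, 512, 512])  # model output (logits), B, C, H, W
gt_tensor = torch.ones([5, n_classes, 512, 512])  # one-hot ground truth, B, C, H, W
criterion = smp.losses.JaccardLoss("binary")  # each class is scored as a binary mask

loss = 0
for i in range(n_classes):
    class_prediction = predicted_tensor[:, i].unsqueeze(1)  # B, 1, H, W
    class_gt = gt_tensor[:, i].unsqueeze(1)  # B, 1, H, W
    class_weight = class_weights[i]
    class_loss = criterion(class_prediction, class_gt)
    loss += class_loss * class_weight
loss = loss / class_weights.sum()  # normalize so the result is a weighted average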
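To show what the loop computes without needing PyTorch or smp, here is a dependency-free sketch in plain Python. It uses the common soft Jaccard formula per class, 1 - Σ(p·g) / (Σp + Σg - Σ(p·g)), and then takes the weighted average. The function names `soft_jaccard_loss` and `weighted_jaccard_loss` are hypothetical, and the sketch assumes predictions are already probabilities. smp's `JaccardLoss` applies a sigmoid to logits by default and may differ in smoothing and empty-mask handling, so treat this as an illustration rather than a drop-in replacement.

```python
from typing import List, Sequence


def soft_jaccard_loss(pred: Sequence[float], target: Sequence[float],
                      smooth: float = 0.0, eps: float = 1e-7) -> float:
    """Hypothetical helper: 1 - soft IoU for one class.

    pred holds probabilities in [0, 1], target holds 0/1 labels (flattened mask).
    """
    intersection = sum(p * t for p, t in zip(pred, target))
    union = sum(pred) + sum(target) - intersection
    # eps guards against division by zero when both masks are empty
    return 1.0 - (intersection + smooth) / max(union + smooth, eps)


def weighted_jaccard_loss(preds: List[List[float]], targets: List[List[float]],
                          weights: Sequence[float]) -> float:
    """Hypothetical helper: weighted average of per-class Jaccard losses.

    preds[c] and targets[c] are the flattened maps for class c.
    """
    if not (len(preds) == len(targets) == len(weights)):
        raise ValueError("need one prediction, one target and one weight per class")
    weighted = sum(w * soft_jaccard_loss(p, t)
                   for p, t, w in zip(preds, targets, weights))
    return weighted / sum(weights)


# Two classes, four pixels each: class 0 is predicted perfectly,
# class 1 has soft IoU 0.5 / 1.5 = 1/3, so its loss is 2/3.
preds = [[1.0, 0.0, 1.0, 0.0], [0.5, 0.5, 0.0, 0.0]]
targets = [[1, 0, 1, 0], [1, 0, 0, 0]]

print(weighted_jaccard_loss(preds, targets, [1.0, 1.0]))  # approximately 1/3
print(weighted_jaccard_loss(preds, targets, [1.0, 3.0]))  # approximately 0.5
```

Raising the weight of the poorly predicted class from 1 to 3 pulls the total loss toward that class's loss, which is the effect class weights are meant to have.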
How can I set class weights for loss functions such as JaccardLoss?