Visual-modelling / Admin

Repo to store high-level issues
0 stars 0 forks source link

Basic CNN implementation #5

Open ghomasHudson opened 4 years ago

zhemingzuo commented 4 years ago

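For efficient 3-D convolution, I suggest using C3D (the 3-D ConvNet from Tran et al., "Learning Spatiotemporal Features with 3D Convolutional Networks", ICCV 2015).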

zhemingzuo commented 4 years ago

Proposed loss functions: focal loss and smooth L1 loss (PyTorch implementations below).

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Sigmoid focal loss - https://arxiv.org/abs/1708.02002

    With ``p = sigmoid(logits)`` and target ``t``, computes element-wise::

        p_t     = p * t + (1 - p) * (1 - t)
        alpha_t = alpha * t + (1 - alpha) * (1 - t)
        loss    = alpha_t * (1 - p_t) ** gamma * BCE(logits, t)

    Args:
        alpha: Weight for the positive class; negatives are weighted ``1 - alpha``.
        gamma: Focusing exponent; ``gamma = 0`` gives alpha-weighted BCE.
        reduction: ``'none'`` (default: element-wise loss, same shape as the
            input), ``'mean'`` or ``'sum'``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    def __init__(self, alpha=0.25, gamma=2, reduction='none'):
        super().__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, pred_logits, target):
        """Return the focal loss of raw logits against targets.

        Args:
            pred_logits: Unnormalized scores (before sigmoid).
            target: Same shape as ``pred_logits``; hard 0/1 labels or soft
                labels in [0, 1]. Integer/bool tensors are cast to the
                logits' dtype.
        """
        # BCE-with-logits requires a floating target of matching dtype.
        target = target.to(pred_logits.dtype)
        pred = pred_logits.sigmoid()
        # Computing CE from logits is the numerically stable form.
        ce = F.binary_cross_entropy_with_logits(pred_logits, target, reduction='none')
        # Blend by the target instead of testing `target == 1`, so soft labels get
        # a p_t consistent with alpha_t; identical to the where() form for 0/1 targets.
        p_t = pred * target + (1. - pred) * (1. - target)
        alpha_t = self.alpha * target + (1. - self.alpha) * (1. - target)
        loss = alpha_t * (1. - p_t) ** self.gamma * ce
        return self._reduce(loss)

    def _reduce(self, loss):
        """Apply ``self.reduction`` to an element-wise loss tensor."""
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

class SmoothL1Loss(nn.Module):
    """Smooth L1 loss.

    With ``x = |pred - target|``, computes element-wise::

        loss = 0.5 * x ** 2 / beta    if x < beta
               x - 0.5 * beta         otherwise

    Args:
        beta: Threshold where the loss switches from quadratic to linear.
            Values below 1e-5 (including 0) degenerate to plain L1 loss.
        reduction: ``'none'`` (default: element-wise loss), ``'mean'`` or ``'sum'``.

    Raises:
        ValueError: If ``beta`` is negative or ``reduction`` is unsupported.
    """

    def __init__(self, beta=0.11, reduction='none'):
        super().__init__()
        if beta < 0:
            raise ValueError(f"beta must be non-negative, got {beta!r}")
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.beta = beta
        self.reduction = reduction

    def forward(self, pred, target):
        """Return the smooth L1 loss between ``pred`` and ``target``."""
        x = (pred - target).abs()
        if self.beta < 1e-5:
            # The quadratic branch would divide by ~0. torch.where still
            # backpropagates through the unselected branch (0 * inf -> NaN),
            # so use the L1 limit directly instead.
            loss = x
        else:
            quadratic = 0.5 * x ** 2 / self.beta
            linear = x - 0.5 * self.beta
            loss = torch.where(x < self.beta, quadratic, linear)
        return self._reduce(loss)

    def _reduce(self, loss):
        """Apply ``self.reduction`` to an element-wise loss tensor."""
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss