mthorrell / gboost_module

Gradient Boosting Modules for pytorch
MIT License
9 stars 0 forks source link

Add ordinal loss function #8

Open mthorrell opened 3 months ago

mthorrell commented 3 months ago

One key use case for gboost_module is enabling people to use elaborate loss functions with XGBoost. Ordinal classification loss is a good example: it is fairly complicated and requires fitting parameters (the category breakpoints) inside the loss function itself. A basic version is below. Let's make a directory for examples like this and start filling it in:

class ORDLogistic(torch.nn.Module):
    """Cumulative-logit ordinal loss with learnable category breakpoints.

    For one real-valued score ``pred`` per row and ``n_ord`` ordered
    categories, the model is::

        P(Y <= k | pred) = sigmoid(breakpoints[k] - pred),  k = 0 .. n_ord-2

    The probability of category ``k`` is the difference of consecutive
    cumulative probabilities, with P(Y <= -1) = 0 and P(Y <= n_ord-1) = 1.
    ``forward`` returns the mean negative log-likelihood of the observed
    categories.

    Args:
        n_ord: number of ordered categories.
    """

    def __init__(self, n_ord):
        super().__init__()
        self.n_ord = n_ord
        # n_ord - 1 breakpoints with unit spacing, centred on zero
        # (e.g. n_ord=3 -> [-0.5, 0.5]). An integer arange minus a Python
        # float yields the default floating dtype.
        self.breakpoints = torch.nn.Parameter(
            torch.arange(n_ord - 1) - (n_ord - 2.0) / 2.0
        )

    def get_pred_probs(self, pred):
        """Return per-category probabilities of shape ``(N, n_ord)``.

        Args:
            pred: one score per row, shape ``(N, 1)`` or ``(N,)``.

        Returns:
            Category probabilities floored at ``1e-8`` so that ``log`` stays
            finite. The floor also absorbs negative differences that arise if
            the learned breakpoints stop being increasing.
        """
        pred = pred.reshape(-1, 1)
        # (1, n_ord-1) - (N, 1) broadcasts to (N, n_ord-1); no tiling needed.
        # sigmoid is the numerically stable form of 1 / (exp(pred - b) + 1).
        cum_probs = torch.sigmoid(self.breakpoints.reshape(1, -1) - pred)
        # Build the boundary columns on cum_probs' device/dtype so this also
        # works on GPU and in float64.
        n_rows = cum_probs.shape[0]
        probs = torch.diff(
            cum_probs,
            prepend=cum_probs.new_zeros((n_rows, 1)),
            append=cum_probs.new_ones((n_rows, 1)),
        )
        eps = 1e-8
        return probs.clamp(min=eps)

    def forward(self, pred, actual):
        """Return the mean negative log-likelihood of ``actual`` given ``pred``.

        Args:
            pred: one score per row, shape ``(N, 1)`` or ``(N,)``.
            actual: category labels in ``[0, n_ord)``, shape ``(N,)`` or
                ``(N, 1)``; converted with ``.long()``.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``actual`` does not hold one label per row of
                ``pred``.
        """
        probs = self.get_pred_probs(pred)
        # Flatten labels: an (N, 1) column would otherwise produce an
        # (N, 1, n_ord) one-hot that broadcasts against (N, n_ord) into a
        # wrong (N, N, n_ord) product.
        labels = actual.long().reshape(-1)
        if labels.shape[0] != probs.shape[0]:
            raise ValueError(
                f"expected {probs.shape[0]} labels, got {labels.shape[0]}"
            )
        onehot = torch.nn.functional.one_hot(labels, num_classes=self.n_ord)
        return torch.sum(onehot * -torch.log(probs)) / labels.shape[0]

Usage with gboost_module:

class ORDBoost(torch.nn.Module):
    """Pair a boosted XGBoost score with an ordinal-logistic loss.

    The boosted module produces a single score per row. ``ORDLogistic``
    turns that score and the observed labels into a loss.

    Args:
        train_size: forwarded to ``xgbmodule.XGBModule`` (presumably the
            number of training rows; confirm against gboost_module).
        input_size: forwarded to ``xgbmodule.XGBModule`` (presumably the
            number of input features; confirm).
        n_categories: number of ordered categories for ``ORDLogistic``.
        params: forwarded unchanged to ``xgbmodule.XGBModule``.
    """

    def __init__(self, train_size, input_size, n_categories, params):
        super().__init__()
        # NOTE(review): ``xgbmodule`` is not imported in the visible snippet;
        # assumed to be gboost_module's XGBoost wrapper. The third argument
        # ``1`` presumably requests one output column.
        # Attribute names determine state_dict keys, so they stay as-is.
        self.lgb = xgbmodule.XGBModule(train_size, input_size, 1, params)
        self.ord_loss = ORDLogistic(n_categories)

    def forward(self, input_array, actual):
        """Score ``input_array`` and return ``(scores, loss)``."""
        scores = self.lgb(input_array)
        return scores, self.ord_loss(scores, actual)

    def gb_step(self, input_array):
        """Delegate one boosting step to the wrapped XGBoost module."""
        self.lgb.gb_step(input_array)
mthorrell commented 3 months ago

More numerically stable versions:

Using Logistic:

class ORDLogistic(torch.nn.Module):
    """Cumulative-logit ordinal loss with learnable category breakpoints.

    For one real-valued score ``pred`` per row and ``n_ord`` ordered
    categories, the model is::

        P(Y <= k | pred) = sigmoid(breakpoints[k] - pred),  k = 0 .. n_ord-2

    The probability of category ``k`` is the difference of consecutive
    cumulative probabilities, with P(Y <= -1) = 0 and P(Y <= n_ord-1) = 1.
    ``forward`` returns the mean negative log-likelihood of the observed
    categories.

    Args:
        n_ord: number of ordered categories.
    """

    def __init__(self, n_ord):
        super().__init__()
        self.n_ord = n_ord
        # n_ord - 1 breakpoints with unit spacing, centred on zero
        # (e.g. n_ord=3 -> [-0.5, 0.5]). An integer arange minus a Python
        # float yields the default floating dtype.
        self.breakpoints = torch.nn.Parameter(
            torch.arange(n_ord - 1) - (n_ord - 2.0) / 2.0
        )

    def get_pred_probs(self, pred):
        """Return per-category probabilities of shape ``(N, n_ord)``.

        Args:
            pred: one score per row, shape ``(N, 1)`` or ``(N,)``.

        Returns:
            Category probabilities floored at ``1e-8`` so that ``log`` stays
            finite. The floor also absorbs negative differences that arise if
            the learned breakpoints stop being increasing.
        """
        pred = pred.reshape(-1, 1)
        # (1, n_ord-1) - (N, 1) broadcasts to (N, n_ord-1); no tiling needed.
        cum_probs = torch.sigmoid(self.breakpoints.reshape(1, -1) - pred)
        # Build the boundary columns on cum_probs' device/dtype so this also
        # works on GPU and in float64.
        n_rows = cum_probs.shape[0]
        probs = torch.diff(
            cum_probs,
            prepend=cum_probs.new_zeros((n_rows, 1)),
            append=cum_probs.new_ones((n_rows, 1)),
        )
        eps = 1e-8
        return probs.clamp(min=eps)

    def forward(self, pred, actual):
        """Return the mean negative log-likelihood of ``actual`` given ``pred``.

        Args:
            pred: one score per row, shape ``(N, 1)`` or ``(N,)``.
            actual: category labels in ``[0, n_ord)``, shape ``(N,)`` or
                ``(N, 1)``; converted with ``.long()``.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``actual`` does not hold one label per row of
                ``pred``.
        """
        # Reuse the probability computation instead of duplicating it.
        probs = self.get_pred_probs(pred)
        # Flatten labels: an (N, 1) column would otherwise produce an
        # (N, 1, n_ord) one-hot that broadcasts against (N, n_ord) into a
        # wrong (N, N, n_ord) product.
        labels = actual.long().reshape(-1)
        if labels.shape[0] != probs.shape[0]:
            raise ValueError(
                f"expected {probs.shape[0]} labels, got {labels.shape[0]}"
            )
        onehot = torch.nn.functional.one_hot(labels, num_classes=self.n_ord)
        return torch.sum(onehot * -torch.log(probs)) / labels.shape[0]

Using Probit:

class ORDProbit(torch.nn.Module):
    """Cumulative-probit ordinal loss with learnable category breakpoints.

    For one real-valued score ``pred`` per row and ``n_ord`` ordered
    categories, the model is::

        P(Y <= k | pred) = Phi(breakpoints[k] - pred),  k = 0 .. n_ord-2

    ``Phi`` is the standard normal CDF. The probability of category ``k`` is
    the difference of consecutive cumulative probabilities, with
    P(Y <= -1) = 0 and P(Y <= n_ord-1) = 1. All computation is done in
    float64.

    Args:
        n_ord: number of ordered categories.
    """

    def __init__(self, n_ord):
        super().__init__()
        self.n_ord = n_ord
        # Build the breakpoints directly in float64 and wrap that tensor.
        # Calling ``.double()`` on a Parameter returns a plain non-leaf
        # tensor: it is not registered in ``parameters()``, so an optimizer
        # never updates it and ``Module.to`` never moves it.
        # Initial values have unit spacing, centred on zero
        # (e.g. n_ord=3 -> [-0.5, 0.5]).
        self.breakpoints = torch.nn.Parameter(
            torch.arange(n_ord - 1, dtype=torch.float64) - (n_ord - 2.0) / 2.0
        )
        self.normal = torch.distributions.Normal(0, 1)

    def get_pred_probs(self, pred, eps=1e-8):
        """Return per-category probabilities of shape ``(N, n_ord)`` in float64.

        Args:
            pred: one score per row, shape ``(N, 1)`` or ``(N,)``; cast to
                float64.
            eps: lower bound applied to every probability so that ``log``
                stays finite. It also absorbs negative differences if the
                learned breakpoints stop being increasing.

        Returns:
            Float64 tensor of clamped category probabilities.
        """
        pred = pred.double().reshape(-1, 1)
        # (1, n_ord-1) - (N, 1) broadcasts to (N, n_ord-1); no tiling needed.
        cum_probs = self.normal.cdf(self.breakpoints.reshape(1, -1) - pred)
        # Build the boundary columns on cum_probs' device/dtype so the
        # concatenation inside torch.diff needs no promotion or transfer.
        n_rows = cum_probs.shape[0]
        probs = torch.diff(
            cum_probs,
            prepend=cum_probs.new_zeros((n_rows, 1)),
            append=cum_probs.new_ones((n_rows, 1)),
        )
        return probs.clamp(min=eps)

    def forward(self, pred, actual):
        """Return the mean negative log-likelihood of ``actual`` given ``pred``.

        Args:
            pred: one score per row, shape ``(N, 1)`` or ``(N,)``.
            actual: category labels in ``[0, n_ord)``, shape ``(N,)`` or
                ``(N, 1)``; converted with ``.long()``.

        Returns:
            Scalar float64 loss tensor.

        Raises:
            ValueError: if ``actual`` does not hold one label per row of
                ``pred``.
        """
        # The loss floors probabilities at 1e-32 (representable in float64),
        # far below the 1e-8 default of get_pred_probs. Clipping therefore
        # only kicks in once a CDF difference has essentially underflowed.
        probs = self.get_pred_probs(pred, eps=1e-32)
        # Flatten labels: an (N, 1) column would otherwise produce an
        # (N, 1, n_ord) one-hot that broadcasts against (N, n_ord) into a
        # wrong (N, N, n_ord) product.
        labels = actual.long().reshape(-1)
        if labels.shape[0] != probs.shape[0]:
            raise ValueError(
                f"expected {probs.shape[0]} labels, got {labels.shape[0]}"
            )
        onehot = torch.nn.functional.one_hot(labels, num_classes=self.n_ord)
        return torch.sum(onehot * -torch.log(probs)) / labels.shape[0]