EtashGuha / R2CCP

MIT License

Further information #9

Open TheMrguiller opened 4 months ago

TheMrguiller commented 4 months ago

Hello @shloknatarajan,

I am analysing your magnificent article and would like to obtain further clarifications on the calculation of the loss and the determination of the intervals. In my case, I am already working with intervals from 0 to 1, so your methodology seems very appropriate to me. Additionally, I am trying to implement your technique for a project, as it could be beneficial. I would also like to receive any recommendations if possible.

Thank you very much

EtashGuha commented 2 months ago

Sorry for the delay. I'd divide the range from 0 to 1 evenly into something like 50 bins. The loss is exactly as in the paper and is based on the intervals here. You should be able to use our library, if that's easier?
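To make the binning suggestion concrete, here is a minimal sketch in plain Python. The helper name `bin_midpoints` is hypothetical, and using bin midpoints (rather than grid points that include the endpoints) is an assumption on my part, not something confirmed for the library:

```python
def bin_midpoints(low, high, n_bins):
    """Split [low, high] into n_bins equal-width bins and return their midpoints."""
    width = (high - low) / n_bins
    # The k-th bin spans [low + k*width, low + (k+1)*width]; its centre is half a width in.
    return [low + (k + 0.5) * width for k in range(n_bins)]


# 50 bins on [0, 1]: bin width 0.02, midpoints running from about 0.01 to about 0.99.
midpoints = bin_midpoints(0.0, 1.0, 50)

# A coarser grid for comparison: 20 bins of width 0.05, midpoints from about 0.025 to about 0.975.
coarse = bin_midpoints(0.0, 1.0, 20)
```

With torch, the same list can be wrapped in `torch.tensor(midpoints)` before being passed to a loss that expects a tensor.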

TheMrguiller commented 2 months ago

Thank you so much for your reply, @EtashGuha. I only wanted to use your loss function, as I am running a set of trials with different objectives. I used the loss function published in TorchCP, which is based on the definition from your paper.
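For reference, the loss under discussion can be sketched as follows. This is my reading of the parameter descriptions (`p` as the norm of the distance, `tau` as the weight of the entropy term, bins represented by midpoints $\hat{y}_k$); the symbol $q_k(x)$ for the softmax probability of bin $k$ and the negative sign on the entropy term are assumptions, not something stated in this thread:

```latex
\ell(x, y) \;=\; \sum_{k=1}^{K} \lvert y - \hat{y}_k \rvert^{p} \, q_k(x) \;-\; \tau \, H\bigl(q(x)\bigr),
\qquad
H(q) \;=\; -\sum_{k=1}^{K} q_k \log q_k
```

Under this form, a larger `tau` rewards flatter distributions over the bins, while `tau = 0` leaves only the distance-weighted term.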

During my experimentation, I encountered a rather strange situation: when selecting the tau parameter, it seemed to cause my model to overfit more compared to when using no tau at all or a very small tau, on the order of 1e-3. Keeping your paper and your thought process in mind, I came up with a new version of the loss function. I’m not sure if it aligns exactly with your original intent, but it works quite well in my case:

import torch
import torch.nn as nn


class R2ccpLoss(nn.Module):
    """
    Conformal Prediction via Regression-as-Classification (Etash Guha et al., 2023).
    Paper: https://neurips.cc/virtual/2023/80610

    :param p: norm of distance measure.
    :param tau: weight of the ‘entropy’ term.
    :param midpoints: the midpoint of each bin.
    :param sigma: stored on the module but not used in the loss below.
    """

    def __init__(self, p, tau, midpoints, sigma=0.1):
        super().__init__()
        self.p = p
        self.tau = tau
        self.midpoints = midpoints
        self.sigma = sigma
        self.distance_matrix = self.generate_distance_matrix(midpoints)

    def generate_distance_matrix(self, values):
        """
        Generate a distance matrix for a set of continuous or discrete values.

        Args:
        - values (torch.Tensor): Continuous or discrete class values.

        Returns:
        - torch.Tensor: Distance matrix.
        """
        values = values.unsqueeze(1)  # Convert to column vector
        distance_matrix = torch.abs(values - values.T)  # Compute pairwise absolute differences
        return distance_matrix

    def forward(self, preds, target, weights=None):
        """
        Compute the cross-entropy loss with regularization

        :param preds: the predictions logits of the model. The shape is batch*K.
        :param target: the truth values. The shape is batch*1.
        :param weights: optional weights for each sample. The shape is batch*1.
        """
        assert not target.requires_grad
        if preds.size(0) != target.size(0):
            raise IndexError("Batch size of preds must be equal to the batch size of target.")

        # Distance from each target to every bin midpoint, shape batch*K.
        target = target.view(-1, 1)
        abs_diff = torch.abs(target - self.midpoints.to(preds.device).unsqueeze(0))

        preds_ = torch.nn.functional.softmax(preds, dim=1)

        # NOTE: this reassignment discards the softmax above, so the terms
        # below are computed on `preds` exactly as passed in.
        preds_ = preds
        cross_entropy = torch.sum((abs_diff ** self.p) * preds_, dim=1)

        # Penalise mass placed far from the bin closest to the target:
        # row `closest_index` of the distance matrix holds the distance
        # from that bin to every other bin.
        closest_index = torch.argmin(abs_diff, dim=1)
        self.distance_matrix = self.distance_matrix.to(preds.device)
        penalties = self.distance_matrix[closest_index]
        penalties_values = torch.sum(preds_ * penalties, dim=1)
        losses = cross_entropy + self.tau * penalties_values
        if weights is not None:
            # Flatten batch*1 weights so they scale each sample's loss
            # instead of broadcasting to a batch*batch matrix.
            losses = losses * weights.view(-1)
        loss = losses.mean()
        return loss

Another question I have concerns your code, specifically get_all_scores. It seems that bad_indices discards the labels that fall below the smallest midpoint or above the largest one, which seems strange: in my case, for example, values between 0 and 0.025 and between 0.975 and 1 are not taken into account. Is there any particular reason?

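def get_all_scores(self, range_vals, cal_pred, y):
    # Spacing between adjacent grid values (assumes range_vals is evenly spaced).
    step_val = (max(range_vals) - min(range_vals)) / (len(range_vals) - 1)
    # Indices of the grid values just above and just below each label.
    indices_up = torch.ceil((y - min(range_vals)) / step_val).squeeze()
    indices_down = torch.floor((y - min(range_vals)) / step_val).squeeze()

    # Fractional position of each label between its two neighbouring grid values.
    how_much_each_direction = ((y.squeeze() - min(range_vals)) / step_val - indices_down)

    weight_up = how_much_each_direction
    weight_down = 1 - how_much_each_direction

    # Labels outside [min(range_vals), max(range_vals)] get a placeholder index
    # here and a score of 0 below.
    bad_indices = torch.where(torch.logical_or(y.squeeze() > max(range_vals), y.squeeze() < min(range_vals)))
    indices_up[bad_indices] = 0
    indices_down[bad_indices] = 0

    # Linear interpolation of the predicted scores between the two neighbours.
    scores = cal_pred
    all_scores = scores[torch.arange(cal_pred.shape[0]), indices_up.long()] * weight_up + scores[torch.arange(cal_pred.shape[0]), indices_down.long()] * weight_down
    all_scores[bad_indices] = 0
    return scores, all_scores
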
EtashGuha commented 1 month ago

So the idea with bad_indices is that, for any $y$ outside the prescribed range (from the min of range_vals to the max of range_vals), we automatically exclude that $y$ from our interval, i.e. we give a score of 0 to any candidate value outside range_vals. This can be shown, with high probability, not to hurt coverage much. It also makes the code cleaner. In practice, this edge case only arises at inference on a datapoint whose true label lies outside the range of the training dataset, which should be rare in most cases. Hope that helps!
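The rule described above can be illustrated for a single label with a small plain-Python sketch. The function name `interpolated_score` is hypothetical, the grid is assumed to be evenly spaced, and the index clamping is a defensive addition of this sketch rather than part of the library code:

```python
import math


def interpolated_score(range_vals, probs, y):
    """Score of label y: 0 outside the grid, linear interpolation inside it."""
    lo, hi = min(range_vals), max(range_vals)
    if y < lo or y > hi:
        # Out-of-range labels are never included in the interval.
        return 0.0
    step = (hi - lo) / (len(range_vals) - 1)
    pos = (y - lo) / step  # fractional position of y on the grid
    last = len(range_vals) - 1
    # Clamp so floating-point rounding at the top edge cannot index past the grid.
    down = min(math.floor(pos), last)
    up = min(math.ceil(pos), last)
    weight_up = pos - down
    return probs[up] * weight_up + probs[down] * (1.0 - weight_up)


# 20 evenly spaced grid values from about 0.025 to about 0.975, uniform predicted mass.
grid = [0.025 + 0.05 * k for k in range(20)]
uniform = [0.05] * 20

inside = interpolated_score(grid, uniform, 0.5)    # between two grid values
outside = interpolated_score(grid, uniform, 0.01)  # below the smallest grid value
```

Here `inside` is interpolated from the two neighbouring grid values, while `outside` is exactly 0 because 0.01 lies below the smallest grid value. If labels near 0 or 1 matter, one option is to choose a grid whose smallest and largest values reach the ends of the label range.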