Open TheMrguiller opened 4 months ago
Sorry for the delay. I'd divide the intervals pretty evenly from 0 to 1 into something like 50 bins. The loss is just as in the paper and based on those intervals. You should be able to use our library if that's easier?
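For concreteness, a minimal sketch of that binning (the 50-bin count and the exact placement are just an example, not something you have to match):

import torch

n_bins = 50
edges = torch.linspace(0, 1, n_bins + 1)     # 51 evenly spaced bin edges on [0, 1]
midpoints = (edges[:-1] + edges[1:]) / 2     # one midpoint per bin, passed to the loss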
Thank you so much for your reply, @EtashGuha. I just wanted to use your loss function as I am running a set of trials with different objectives. I used the loss function published in TorchCP, which is based on the definition from your paper.
During my experimentation I ran into a rather strange situation: setting a non-negligible tau seemed to make my model overfit more than using no tau at all, or a very small one on the order of 1e-3. Keeping your paper and your thought process in mind, I came up with a new version of the loss function. I'm not sure if it aligns exactly with your original intent, but it works quite well in my case:
import torch
import torch.nn as nn


class R2ccpLoss(nn.Module):
"""
Conformal Prediction via Regression-as-Classification (Etash Guha et al., 2023).
Paper: https://neurips.cc/virtual/2023/80610
:param p: norm of distance measure.
:param tau: weight of the ‘entropy’ term.
:param midpoints: the midpoint of each bin.
"""
    def __init__(self, p, tau, midpoints, sigma=0.1):
        super().__init__()
        self.p = p
        self.tau = tau
        self.midpoints = midpoints
        self.sigma = sigma
        self.distance_matrix = self.generate_distance_matrix(midpoints)

    def generate_distance_matrix(self, values):
        """
        Generate a pairwise distance matrix for a set of continuous or discrete values.

        Args:
            values (torch.Tensor): Continuous or discrete class values (the bin midpoints).

        Returns:
            torch.Tensor: Matrix of pairwise absolute differences.
        """
        values = values.unsqueeze(1)  # convert to a column vector
        distance_matrix = torch.abs(values - values.T)  # pairwise absolute differences
        return distance_matrix

    def forward(self, preds, target, weights=None):
        """
        Compute the cross-entropy-style loss with the distance-based regularization.

        :param preds: prediction logits of the model, shape (batch, K).
        :param target: the true values, shape (batch, 1).
        :param weights: optional per-sample weights, shape (batch, 1).
        """
        assert not target.requires_grad
        if preds.size(0) != target.size(0):
            raise IndexError("Batch size of preds must be equal to the batch size of target.")

        target = target.view(-1, 1)
        # |y - m_k|^p for each bin midpoint m_k
        abs_diff = torch.abs(target - self.midpoints.to(preds.device).unsqueeze(0))
        # convert logits to per-bin probabilities
        preds_ = torch.nn.functional.softmax(preds, dim=1)
        # first term: expected p-distance between the target and the bin midpoints
        cross_entropy = torch.sum((abs_diff ** self.p) * preds_, dim=1)
        # second term: penalize probability mass placed far from the bin closest to the target
        closest_index = torch.argmin(abs_diff, dim=1)
        self.distance_matrix = self.distance_matrix.to(preds.device)
        penalties = self.distance_matrix[closest_index]
        penalties_values = torch.sum(preds_ * penalties, dim=1)
        losses = cross_entropy + self.tau * penalties_values
        if weights is not None:
            losses = losses * weights
        loss = losses.mean()
        return loss
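For reference, here is a rough usage sketch of the class above (the bin count, batch size, and hyperparameter values are just illustrative assumptions, not taken from your setup):

midpoints = torch.linspace(0, 1, 50)              # hypothetical evenly spaced midpoints on [0, 1]
criterion = R2ccpLoss(p=2, tau=0.1, midpoints=midpoints)

logits = torch.randn(8, 50, requires_grad=True)   # model outputs: one logit per bin
targets = torch.rand(8, 1)                        # true values in [0, 1]

loss = criterion(logits, targets)
loss.backward()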
Another question I have is related to your code, specifically the get_all_scores part. It seems that with bad_indices you eliminate labels that fall below or above the outermost midpoints, which seems strange because, for example, in my case the values between 0 and 0.025 and between 0.975 and 1 are not taken into account. Is there any particular reason?
def get_all_scores(self, range_vals, cal_pred, y):
    # spacing between consecutive midpoints (range_vals is assumed evenly spaced)
    step_val = (max(range_vals) - min(range_vals)) / (len(range_vals) - 1)
    # indices of the midpoints just above and just below each y
    indices_up = torch.ceil((y - min(range_vals)) / step_val).squeeze()
    indices_down = torch.floor((y - min(range_vals)) / step_val).squeeze()
    # fractional position of y between its two neighbouring midpoints
    how_much_each_direction = (y.squeeze() - min(range_vals)) / step_val - indices_down
    weight_up = how_much_each_direction
    weight_down = 1 - how_much_each_direction
    # labels that fall outside [min(range_vals), max(range_vals)]
    bad_indices = torch.where(torch.logical_or(y.squeeze() > max(range_vals), y.squeeze() < min(range_vals)))
    indices_up[bad_indices] = 0
    indices_down[bad_indices] = 0
    scores = cal_pred
    # linearly interpolate the predicted score at y between the two neighbouring midpoints
    all_scores = scores[torch.arange(cal_pred.shape[0]), indices_up.long()] * weight_up + scores[torch.arange(cal_pred.shape[0]), indices_down.long()] * weight_down
    # out-of-range labels are forced to a score of 0
    all_scores[bad_indices] = 0
    return scores, all_scores
So the idea with bad_indices is that, for $y$ outside the prescribed range (below min(range_vals) or above max(range_vals)), we automatically do not include that $y$ in our interval, i.e. we give a score of 0 to any candidate value outside range_vals. This is provably not too harmful to coverage with high probability, and it also keeps the code cleaner. In practice, this edge case only arises when you run inference on a datapoint whose true label lies outside the boundary of the training dataset, which should be rare in most cases. Hope that helps!
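As a made-up illustration of that rule (the numbers here are hypothetical and this is not the library's exact code path, just the out-of-range behaviour):

import torch

range_vals = torch.linspace(0.025, 0.975, 20)   # hypothetical midpoints
y = torch.tensor([0.01, 0.50, 0.99])            # 0.01 and 0.99 fall outside [0.025, 0.975]

bad = (y < range_vals.min()) | (y > range_vals.max())
scores = torch.ones(3)                          # pretend these were the interpolated scores
scores[bad] = 0                                 # out-of-range labels are forced to a score of 0
print(scores)                                   # tensor([0., 1., 0.]) -> 0.01 and 0.99 never enter the interval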
Hello @shloknatarajan,
I am studying your excellent article and would like some further clarification on how the loss is computed and how the intervals are determined. In my case I am already working with targets in the range 0 to 1, so your methodology seems very well suited. I am also trying to implement your technique for a project where it could be beneficial, and I would appreciate any recommendations you may have.
Thank you very much