ml-stat-Sustech / TorchCP

A Python toolbox for conformal prediction research on deep learning models, using PyTorch.
GNU Lesser General Public License v3.0

Margin Implementation #23

Closed. ThomasNorr closed this issue 6 days ago.

ThomasNorr commented 2 months ago

Hello,

and thanks again for your helpful toolbox. I have a question regarding the implementation of "Margin()":

You mention:

    indices = torch.arange(num_labels).to(probs.device)
    temp_probs[:, indices, indices] = -1
    # torch.max(temp_probs, dim=-1) are the second highest probabilities

Why would that comment hold? As I see it, `indices` would need to be something like `torch.max(probs, dim=-1).indices` for this to make sense. I also don't understand why you unsqueeze anything here. Could you explain that?
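A minimal sketch of what the quoted lines appear to compute helps answer both questions. It assumes `probs` is an (N, K) tensor of probabilities and that `temp_probs` is built by unsqueezing `probs` to (N, 1, K) and repeating it to (N, K, K). That construction is an assumption inferred from the mention of unsqueeze, since the full `Margin()` source is not quoted here, and the helper name is hypothetical, not TorchCP's API. Row `y` of each K×K slice is a copy of the probability vector with entry `y` masked, so the max over the last dimension is the largest probability excluding label `y`, not the second-highest probability overall.

```python
import torch


def max_excluding_each_label(probs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out[i, y] = max_{j != y} probs[i, j].

    Assumes probs has shape (N, K) with values in [0, 1], so -1 works as a mask.
    """
    num_labels = probs.shape[1]
    # (N, K) -> (N, 1, K) -> (N, K, K): row y of sample i is a copy of probs[i].
    # repeat() allocates a new tensor, so the in-place write below is safe.
    temp_probs = probs.unsqueeze(1).repeat(1, num_labels, 1)
    indices = torch.arange(num_labels, device=probs.device)
    # Mask the diagonal: in row y, hide the entry belonging to label y itself.
    temp_probs[:, indices, indices] = -1
    # Max over the last dim: the largest probability among labels j != y.
    return temp_probs.max(dim=-1).values


probs = torch.tensor([[0.6, 0.3, 0.1]])
print(max_excluding_each_label(probs))
```

For this input the result is `[[0.3, 0.6, 0.6]]`: only for the top label (index 0) is the masked max the second-highest probability; for every other label it is the overall maximum. This is why the comment is misleading and why the unsqueeze is needed. It gives one masked row per candidate label.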

Best Regards,

Thomas

Jianguo99 commented 2 months ago

Dear Thomas,

The comment and variable name here are misleading. The result is the largest probability among all labels other than the current label $y$, as shown in the equation below, not the second-highest probability overall. We will fix the comment.

$S(\boldsymbol{x},y) = \max_{j=1, \ldots, |\mathcal{Y}|:\ j \neq y} P_j - P_y$
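The equation can be checked numerically. The sketch below assumes `probs` is an (N, K) tensor of softmax probabilities with K ≥ 2; both function names are hypothetical, not TorchCP's API. It computes $S(\boldsymbol{x}, y)$ for every label in two ways. The first is a direct loop over labels that follows the formula. The second is a top-2 shortcut that uses the argmax idea raised in the question: $\max_{j \neq y} P_j$ equals the top-1 probability unless $y$ is itself the argmax, in which case it is the second-highest.

```python
import torch


def margin_scores_loop(probs: torch.Tensor) -> torch.Tensor:
    """Hypothetical: S(x, y) = max_{j != y} P_j - P_y, one label at a time."""
    num_labels = probs.shape[1]
    scores = torch.empty_like(probs)
    for y in range(num_labels):
        # Drop column y, then take the max over the remaining labels.
        others = torch.cat([probs[:, :y], probs[:, y + 1:]], dim=1)
        scores[:, y] = others.max(dim=1).values - probs[:, y]
    return scores


def margin_scores_top2(probs: torch.Tensor) -> torch.Tensor:
    """Hypothetical: same score from the two largest probabilities (K >= 2)."""
    top2 = probs.topk(2, dim=1)
    top1_val = top2.values[:, 0:1]   # (N, 1) largest probability
    top2_val = top2.values[:, 1:2]   # (N, 1) second-largest probability
    top1_idx = top2.indices[:, 0:1]  # (N, 1) argmax label
    labels = torch.arange(probs.shape[1], device=probs.device).unsqueeze(0)  # (1, K)
    # For the argmax label the competitor is the runner-up; for all others it is top-1.
    best_other = torch.where(labels == top1_idx, top2_val, top1_val)  # (N, K)
    return best_other - probs


probs = torch.tensor([[0.6, 0.3, 0.1]])
print(margin_scores_loop(probs))
print(margin_scores_top2(probs))
```

Comparing label indices against the argmax is close to what the question suggests. On its own, though, it only yields the second-highest probability for the argmax label; every other label must be compared against the overall maximum. The diagonal-mask approach quoted in the issue handles both cases in one batched max, at the cost of an (N, K, K) temporary tensor.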

Please let us know if you have any more questions.

Best regards,
The authors