Thanks for checking out our work!
For classification models, you can compute the margin values from logits and labels using a variant of this model output function: https://github.com/MadryLab/trak/blob/9f2f34b474b86c91af5992c6b4fd8da6ca30c2f1/trak/modelout_functions.py#L101
Note that the above function expects a single image and label, but you can adapt it to use a batch of inputs, for example:
import torch as ch

def compute_margins(logits, labels):
    # logits: (batch, num_classes); labels: (batch,) integer class indices
    bindex = ch.arange(logits.shape[0], device=logits.device)
    logits_correct = logits[bindex, labels]
    cloned_logits = logits.clone()
    # remove the logits of the correct labels from the sum
    # in logsumexp by setting to -ch.inf
    cloned_logits[bindex, labels] = -ch.inf
    margins = logits_correct - cloned_logits.logsumexp(dim=-1)
    return margins
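As a quick sanity check, here is a minimal sketch (not the exact code used to produce the released margins) showing that the batched function agrees with the log-odds form $\log\frac{p}{1-p}$, where $p$ is the softmax probability of the correct class. The random logits, the label values, and the tolerance are made up for illustration.

```python
import torch as ch

def compute_margins(logits, labels):
    # Batched margin: correct-class logit minus logsumexp over the other classes.
    bindex = ch.arange(logits.shape[0], device=logits.device)
    logits_correct = logits[bindex, labels]
    cloned_logits = logits.clone()
    # Mask out the correct class so it does not contribute to logsumexp.
    cloned_logits[bindex, labels] = -ch.inf
    return logits_correct - cloned_logits.logsumexp(dim=-1)

# Hypothetical inputs: 4 examples, 10 classes.
ch.manual_seed(0)
logits = ch.randn(4, 10)
labels = ch.tensor([3, 0, 7, 1])

margins = compute_margins(logits, labels)

# Reference: log(p / (1 - p)) with p the softmax probability of the true class.
p_correct = logits.softmax(dim=-1)[ch.arange(logits.shape[0]), labels]
reference = ch.log(p_correct / (1 - p_correct))

print(margins.shape)                                  # one margin per example
print(ch.allclose(margins, reference, atol=1e-4))     # the two forms should agree
```

The logsumexp form is usually preferable in practice because it avoids computing $1 - p$ directly, which loses precision when $p$ is close to 1.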
Hope that helps!
Closing this for now @amberyzheng, feel free to re-open if you have additional questions!
Thank you for the great work!
I've run into a bit of a hiccup with the margin values included in the downloaded dataset: they don't match what I expected. Could you share more details, or the exact code used to compute these values, such as the training details and how the function $f$ is computed?
Any help would be greatly appreciated!