MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License

Trouble Matching Margin Values in cifar_quickstart.ipynb #65

Closed amberyzheng closed 5 months ago

amberyzheng commented 6 months ago

Thank you for the great work!

I've run into a bit of a hiccup with the margin values included in the downloaded dataset: they don't match the values I compute myself. Could you share more details or the exact code used to compute them, such as the training details and how the function $f$ is computed?

Any help would be greatly appreciated!

sung-max commented 5 months ago

Thanks for checking out our work!

For classification models, you can compute the margin values from logits and labels using (a variant of) this model output function: https://github.com/MadryLab/trak/blob/9f2f34b474b86c91af5992c6b4fd8da6ca30c2f1/trak/modelout_functions.py#L101

Note that the above function expects a single image and label, but you can adapt it to use a batch of inputs, for example:

import torch as ch  # the snippet uses ch as the torch alias

def compute_margins(logits, labels):
    # logits: (batch, num_classes); labels: (batch,) integer class indices
    bindex = ch.arange(logits.shape[0], device=logits.device)
    logits_correct = logits[bindex, labels]
    cloned_logits = logits.clone()
    # exclude the correct-class logit from the logsumexp
    # by setting it to -inf
    cloned_logits[bindex, labels] = -ch.inf
    margins = logits_correct - cloned_logits.logsumexp(dim=-1)
    return margins
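
As a sanity check, the margin above is algebraically the same as log(p / (1 - p)), where p is the softmax probability of the correct class, since 1 - p collects exactly the terms left in the logsumexp. Here is a minimal sketch of that check. It is not code from the TRAK repo: the random logits and the helper name margins_from_probs are made up for illustration, and it assumes a batch of logits with shape (batch, num_classes).

```python
import torch as ch


def compute_margins(logits, labels):
    """Batched margin: correct-class logit minus logsumexp of the other logits."""
    bindex = ch.arange(logits.shape[0], device=logits.device)
    logits_correct = logits[bindex, labels]
    cloned_logits = logits.clone()
    cloned_logits[bindex, labels] = -ch.inf
    return logits_correct - cloned_logits.logsumexp(dim=-1)


def margins_from_probs(logits, labels):
    """Hypothetical reference: log(p / (1 - p)) for the correct-class probability p."""
    probs = logits.softmax(dim=-1)
    p = probs.gather(1, labels.unsqueeze(1)).squeeze(1)
    return ch.log(p) - ch.log1p(-p)


ch.manual_seed(0)
# Made-up inputs: 8 examples, 10 classes (float64 to keep rounding error small).
logits = ch.randn(8, 10, dtype=ch.float64)
labels = ch.randint(0, 10, (8,))

margins = compute_margins(logits, labels)
reference = margins_from_probs(logits, labels)
max_abs_diff = (margins - reference).abs().max().item()
print(max_abs_diff)  # expected to be tiny (floating-point rounding only)
```

The logsumexp form is generally the safer one to use in practice, because log(1 - p) loses precision when p is very close to 1.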

Hope that helps!

kristian-georgiev commented 5 months ago

Closing this for now @amberyzheng, feel free to re-open if you have additional questions!