MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License

all NaN or all zero scores? #41

Closed: nickk124 closed this issue 1 year ago

nickk124 commented 1 year ago

Hi,

I've been using trak to attribute binary classification predictions on images. (My model is built to output a single probability for class 0, so I modified it to give trak a class 0 logit and a corresponding class 1 logit, which lets me use trak's prebuilt image classification task.) Are there common reasons why I might get attribution scores that are all NaN or all zero? I've run into this problem in several situations.

Thank you for your excellent codebase and documentation; it's been great experimenting with it.

Thanks!
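The probability-to-logits conversion described above can be done in more than one way. The sketch below shows one option, assuming the model outputs p = P(class 0) as a float in (0, 1): set the class 0 logit to the log-odds log(p / (1 - p)) and fix the class 1 logit at 0, so that a two-way softmax returns exactly p. The helper names `probability_to_two_logits` and `softmax2` are hypothetical and not part of trak, and the clamping constant `eps` is an illustrative choice. Plain Python is used so the arithmetic is easy to check.

```python
import math
from typing import Tuple


def probability_to_two_logits(p0: float, eps: float = 1e-6) -> Tuple[float, float]:
    """Hypothetical helper (not part of trak): map P(class 0) to logits (z0, z1).

    softmax([z0, z1])[0] equals p0, up to the clamping below.
    """
    # Clamp away from 0 and 1: log(p / (1 - p)) is infinite at exactly 0 or 1.
    p = min(max(p0, eps), 1.0 - eps)
    z0 = math.log(p / (1.0 - p))  # log-odds of class 0
    z1 = 0.0  # class 1 serves as the fixed reference logit
    return z0, z1


def softmax2(z0: float, z1: float) -> Tuple[float, float]:
    """Numerically stable softmax over two logits (hypothetical helper)."""
    m = max(z0, z1)  # subtract the max before exponentiating to avoid overflow
    e0 = math.exp(z0 - m)
    e1 = math.exp(z1 - m)
    total = e0 + e1
    return e0 / total, e1 / total


if __name__ == "__main__":
    for p in (0.5, 0.8, 1.0):
        z0, z1 = probability_to_two_logits(p)
        print(p, (z0, z1), softmax2(z0, z1))
```

If the model computes p with a final sigmoid, returning the pre-sigmoid score s directly as the class 0 logit (with 0 for class 1) avoids the round trip through a possibly saturated probability, since softmax([s, 0])[0] equals sigmoid(s).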

andrewilyas commented 1 year ago

Hi @nickk124, can you give some additional info about the dataset and model you're using? Happy to help!

nickk124 commented 1 year ago

Hi @andrewilyas, thanks for the quick reply! I can't share too many details, but I'm attributing binary classification predictions made by a ResNet-18 pre-trained on ImageNet. The training and test sets are both subsets of ImageNet (a selection of certain classes plus images from other, randomly chosen classes); the training set has about 50,000 images and the test set about 2,000. I'm attributing predictions on test images back to the training set.

The issue is that attribution works well when I use a random 10% sample of the data (about 5,000 training and 200 test images): the trak scores and features look reasonable. However, when I use the full training and test sets, the trak features are all zero and the trak scores are all NaN.

Thanks for the help, and sorry that I can only provide limited details.

kristian-georgiev commented 1 year ago

Hi @nickk124, can you try using use_half_precision=False when initializing the TRAKer? I suspect this is a precision issue caused by using float16 instead of float32.
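The float16 explanation can be illustrated outside of trak. The sketch below uses NumPy only to show float16's range limits: its largest finite value is 65504 and its smallest positive (subnormal) value is about 6e-8. A product that is perfectly ordinary in float32 can therefore overflow to inf in float16, inf - inf then gives NaN, and very small products round to zero. The function name `float16_vs_float32_demo` is hypothetical; the code does not reproduce trak's internal computations, and the thread does not pinpoint where in trak the problem arises. It only shows one plausible way half precision can produce zeros and NaNs.

```python
import numpy as np


def float16_vs_float32_demo():
    """Run the same arithmetic in float16 and float32 (illustration only)."""
    results = {}
    # Silence NumPy's overflow/invalid/underflow warnings for this demo.
    with np.errstate(over="ignore", invalid="ignore", under="ignore"):
        for dtype in (np.float16, np.float32):
            big = np.array([300.0], dtype=dtype)
            small = np.array([1e-4], dtype=dtype)
            product = big * big  # 90000 exceeds float16's max of 65504 -> inf
            difference = product - product  # inf - inf is NaN
            tiny = small * small  # about 1e-8, below float16's smallest subnormal -> 0
            results[np.dtype(dtype).name] = (product[0], difference[0], tiny[0])
    return results


if __name__ == "__main__":
    print("float16 max:", float(np.finfo(np.float16).max))
    for name, (product, difference, tiny) in float16_vs_float32_demo().items():
        print(name, product, difference, tiny)
```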

nickk124 commented 1 year ago

@kristian-georgiev yes, will do! I'll let you know once I have results.

nickk124 commented 1 year ago

Turning off half precision worked: I'm now getting reasonable attribution scores and features. Thanks for the help and for being so responsive!

kristian-georgiev commented 1 year ago

Nice! We'll make use_half_precision default to False in the next release.