MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License
169 stars, 22 forks

Fix bug with custom model output and add test #25

Closed: jvendrow closed this pull request 1 year ago

jvendrow commented 1 year ago

Due to a typo, passing any custom model output function caused an error when initializing the TRAKer. This PR fixes the bug and adds a corresponding test.
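For context, the sketch below shows the general pattern such a constructor might follow: accept either the name of a built-in task or a user-supplied output object. It also includes a regression test for the custom path. Every name here (`ModelOutputBase`, `MarginOutput`, `TASK_TO_MODELOUT`, `Attributor`, `SumOutput`) is a hypothetical choice made for illustration. This is not trak's implementation, and it does not reproduce the actual typo fixed in this PR.

```python
# Hypothetical sketch: all names are illustrative assumptions,
# NOT trak's actual classes or constructor logic.
from abc import ABC, abstractmethod


class ModelOutputBase(ABC):
    """Hypothetical base class for model output functions."""

    @abstractmethod
    def get_output(self, logits, labels):
        ...


class MarginOutput(ModelOutputBase):
    """Hypothetical built-in output: correct-class logit minus the largest other logit."""

    def get_output(self, logits, labels):
        margins = []
        for row, label in zip(logits, labels):
            others = [v for i, v in enumerate(row) if i != label]
            margins.append(row[label] - max(others))
        return margins


# Hypothetical registry of built-in tasks, keyed by name.
TASK_TO_MODELOUT = {"image_classification": MarginOutput}


class Attributor:
    """Hypothetical stand-in for the task-handling part of an attributor's __init__."""

    def __init__(self, task):
        if isinstance(task, str):
            # Built-in task: look the class up by name and instantiate it.
            self.modelout_fn = TASK_TO_MODELOUT[task]()
        elif isinstance(task, ModelOutputBase):
            # Custom task: use the user-supplied instance unchanged.
            self.modelout_fn = task
        else:
            raise TypeError("task must be a str or a ModelOutputBase instance")


class SumOutput(ModelOutputBase):
    """A user-defined output function, used only by the regression test."""

    def get_output(self, logits, labels):
        return [sum(row) for row in logits]


def test_custom_model_output():
    # Regression check: initializing with a custom output object must not raise,
    # and the object must be stored and used as given.
    custom = SumOutput()
    attributor = Attributor(task=custom)
    assert attributor.modelout_fn is custom
    assert attributor.modelout_fn.get_output([[1.0, 2.0]], [0]) == [3.0]
```

A test like `test_custom_model_output` exercises the custom-object branch directly. A bug that only affects that branch would then fail in CI rather than at a user's first call.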

sung-max commented 1 year ago

Thank you, Josh!