MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License

fixed model output function when computing gradients in float16 #36

Closed: AlaaKhaddaj closed this 1 year ago

AlaaKhaddaj commented 1 year ago

When computing margins in the image_classification task, ch.tensor(-ch.inf) is created with the default dtype, float32. If the model's output and gradients are computed in float16, this causes a dtype mismatch.
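A minimal sketch of the kind of fix described above. It assumes `ch` is an alias for `torch` and that the margin is the correct-class logit minus the logsumexp of the remaining logits. `margin_from_logits` is a hypothetical helper for illustration, not trak's actual model output function. The essential change is creating the `-inf` fill value with the logits' own dtype and device:

```python
import torch


def margin_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: correct-class logit minus logsumexp of the other logits.

    Assumes logits has shape (batch, num_classes) and labels has shape (batch,).
    """
    bindex = torch.arange(logits.shape[0], device=logits.device)
    logits_correct = logits[bindex, labels]

    # Work on a copy so the caller's logits are not modified.
    cloned_logits = logits.clone()
    # Build -inf with the SAME dtype/device as the logits; a bare
    # torch.tensor(float("-inf")) would be float32 by default and would not
    # match float16 logits in the indexed assignment below.
    neg_inf = torch.tensor(float("-inf"), dtype=logits.dtype, device=logits.device)
    cloned_logits[bindex, labels] = neg_inf

    # Masking the correct class with -inf excludes it from the logsumexp.
    return logits_correct - cloned_logits.logsumexp(dim=-1)
```

Under default settings `torch.tensor(float("-inf")).dtype` is `torch.float32`, which is why the `dtype=` argument matters once the logits are float16. The output keeps the logits' dtype, so the result stays consistent with gradients computed in half precision.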

kristian-georgiev commented 1 year ago

Great catch, thanks!