MadryLab / trak

A fast, effective data attribution method for neural networks in PyTorch
https://trak.csail.mit.edu/
MIT License

Save scores in half-precision format by default #30

Closed: kristian-georgiev closed this issue 1 year ago

kristian-georgiev commented 1 year ago

Rough idea: use mixed precision for gradients by default. Only instantiate float16 gradients and pass them to the projectors; for BasicProjector, do everything in float16; for CudaProjector, convert the output to float16. Then save the scores to disk in float16.

Finally, we should have a flag in TRAKer along the lines of use_float32: bool = False.
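To make the proposed pipeline concrete, here is a minimal NumPy sketch, assuming gradients have already been computed in float32 and that the projection is a plain Gaussian JL matrix. `project_gradients` is a hypothetical helper (not TRAK's API), and its `use_float32` argument only mirrors the flag suggested above.

```python
import numpy as np


def project_gradients(grads, proj_dim, use_float32=False, seed=0):
    """Toy JL projection of per-example gradients (hypothetical helper).

    grads: (n_examples, n_params) gradients, assumed computed in float32.
    use_float32: mirrors the proposed flag; when False, both the gradients
    and the projection matrix are cast to float16 before projecting.
    """
    dtype = np.float32 if use_float32 else np.float16
    rng = np.random.default_rng(seed)
    # Gaussian JL matrix scaled by 1/sqrt(proj_dim), so squared norms are
    # preserved in expectation.
    proj = rng.standard_normal((grads.shape[1], proj_dim)) / np.sqrt(proj_dim)
    return grads.astype(dtype) @ proj.astype(dtype)


grads = np.random.default_rng(1).standard_normal((8, 512)).astype(np.float32)
p16 = project_gradients(grads, proj_dim=64)
p32 = project_gradients(grads, proj_dim=64, use_float32=True)

# float16 halves the bytes that would be written to disk (e.g. with np.save).
print(p16.nbytes, p32.nbytes)  # → 1024 2048

# Same seed, same projection matrix: the gap is due only to precision.
rel_err = np.linalg.norm(p16.astype(np.float32) - p32) / np.linalg.norm(p32)
```

Because both calls reuse the same seed, `rel_err` isolates the cost of projecting in float16 rather than differences in the random matrix.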

kristian-georgiev commented 1 year ago

Projecting and storing in float16 by default: https://github.com/MadryLab/trak/commit/ec068c375e116304294e8180ce3dbf5431ac88c4

kristian-georgiev commented 1 year ago

It would be nice to also enable mixed precision by default. I'm not sure how to combine torch.func.functional_call with torch.autocast, tbh. Maybe the easiest way would be to add an @autocast() decorator to get_output().
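One way the decorator idea could look is sketched below, under some assumptions: `get_output` and `grads_fp16` are hypothetical stand-ins (TRAK's real `get_output()` signature may differ), and the sketch uses CPU autocast with bfloat16 so it runs without a GPU. The forward pass runs under autocast while the backward pass stays outside it, following the pattern in PyTorch's AMP examples; gradients therefore come back in the parameters' float32 dtype and are cast to float16 only afterwards.

```python
import torch
from torch.func import functional_call


# Hypothetical stand-in for a model-output function; not TRAK's exact signature.
# The decorator runs autocast-eligible ops (e.g. the matmuls inside nn.Linear)
# in bfloat16 on CPU; on a GPU one would use device_type="cuda" (float16 default).
@torch.autocast(device_type="cpu", dtype=torch.bfloat16)
def get_output(model, weights, buffers, inputs):
    return functional_call(model, (weights, buffers), (inputs,))


def grads_fp16(model, inputs):
    """Forward under autocast, backward outside it, then cast gradients to float16."""
    weights = {k: v.detach().requires_grad_(True) for k, v in model.named_parameters()}
    buffers = {k: v.detach() for k, v in model.named_buffers()}
    out = get_output(model, weights, buffers, inputs)
    loss = out.float().sum()
    # The backward pass runs outside the autocast region; gradients match the
    # dtype of the float32 parameters they are taken with respect to.
    grads = torch.autograd.grad(loss, list(weights.values()))
    return [g.half() for g in grads]


model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))
x = torch.randn(4, 8)
g16 = grads_fp16(model, x)
print([g.dtype for g in g16])  # every entry is torch.float16
```

Keeping the parameters in float32 means autocast only lowers the precision of the forward activations, which is a milder change than converting the whole model.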

Another thing from the AMP docs (e.g. https://pytorch.org/docs/stable/notes/amp_examples.html):

Backward passes under autocast are not recommended.

As far as I can tell, the reason is that training otherwise becomes unstable: storing gradients in float16 causes a lot of underflow, so standard AMP training accumulates gradients in float32. But here we're projecting float16 gradients anyway, so I think we could actually wrap both the forward and backward passes in autocast()?
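The underflow concern can be seen with a few scalar values in NumPy. This is a toy illustration, not TRAK code: it casts small float32 "gradients" straight to float16, then repeats the cast after multiplying by a constant, which is the loss-scaling idea behind AMP's GradScaler (scaling the loss scales every gradient by the same factor). The constant 2**16 matches GradScaler's default initial scale; the gradient values are made up for the example.

```python
import numpy as np

grads = np.array([1e-3, 1e-5, 1e-8], dtype=np.float32)

# Direct cast: float16's smallest positive (subnormal) value is 2**-24 ~ 5.96e-8,
# so 1e-8 rounds to zero.
naive = grads.astype(np.float16)

# Loss scaling: scale up before the float16 cast, then unscale in float32.
scale = np.float32(2.0**16)
recovered = (grads * scale).astype(np.float16).astype(np.float32) / scale

print(naive)      # last entry is 0.0
print(recovered)  # all three entries survive, to float16 precision
```

Projecting float16 gradients without any scaling would discard components like the last one before the JL projection ever sees them.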

kristian-georgiev commented 1 year ago

On second thought, if we're doing everything in float16, we might as well skip autocast() altogether and just .half() the model. My main concern is that this could be less numerically stable than calling .half() on the gradients after computing them in float32. Given that compute is dominated by the JL projection, I'll leave things as they are for now, but we can reopen this if we think it's worth it.
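A toy scalar example hints at why computing everything in float16 can be less stable than computing in float32 and casting at the end. It is only an analogy for the two options above (it sums a constant rather than running a model): one loop keeps its running total in float16, the other accumulates in float32 and casts once.

```python
import numpy as np

step = np.float16(1e-4)
n = 10_000  # exact total is about 1.0

# Everything in float16, loosely analogous to .half()-ing the model: the running
# total stops growing once `step` falls below half a float16 ulp of the total.
acc16 = np.float16(0.0)
for _ in range(n):
    acc16 = np.float16(acc16 + step)

# Accumulate in float32 and cast to float16 only at the end, loosely analogous
# to computing gradients in float32 and calling .half() on them afterwards.
acc32 = np.float32(0.0)
for _ in range(n):
    acc32 += np.float32(step)
result = np.float16(acc32)

print(acc16, result)  # acc16 stalls far below 1.0; result is close to 1.0
```

Real models mix many operations, and PyTorch kernels often accumulate internally at higher precision, so this only shows the direction of the risk, not its size.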

Added a small script that tries out a few options: https://github.com/MadryLab/trak/commit/bc627e0e84e415d949a1b36c4dbca41d4be9b77a.