pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Investigate per-sample-grad performance relative to other mechanisms #1020

Open zou3519 opened 1 year ago

zou3519 commented 1 year ago

Opacus has a benchmark suite over at https://github.com/pytorch/opacus/tree/main/benchmarks that computes per-sample-grads using three mechanisms:

kshitij12345 commented 1 year ago

See https://github.com/pytorch/opacus/issues/521
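For context, here is a minimal sketch of two ways to compute per-sample gradients that such a benchmark might compare. It assumes PyTorch 2.x, where functorch's transforms are available as `torch.func.vmap`, `torch.func.grad`, and `torch.func.functional_call`. The model, batch size, and loss are arbitrary illustrative choices, not taken from the Opacus suite. The first mechanism vectorizes a single-sample gradient over the batch with `vmap(grad(...))`. The second is a naive baseline that runs one backward pass per sample.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Illustrative model and batch (sizes are arbitrary assumptions).
model = nn.Linear(4, 3)
batch_x = torch.randn(8, 4)
batch_y = torch.randint(0, 3, (8,))

# Detached parameters, passed explicitly to functional_call.
params = {name: p.detach() for name, p in model.named_parameters()}

def sample_loss(params, x, y):
    # Run the module with explicit params; unsqueeze(0) turns one sample
    # back into a batch of size 1.
    logits = functional_call(model, params, (x.unsqueeze(0),))
    return nn.functional.cross_entropy(logits, y.unsqueeze(0))

# grad differentiates w.r.t. the first argument (params); vmap maps over the
# batch dim of x and y while sharing params (in_dims=None for params).
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(
    params, batch_x, batch_y
)

def loop_per_sample_grads():
    # Baseline mechanism: one forward/backward pass per sample.
    grads = {name: [] for name in params}
    for x, y in zip(batch_x, batch_y):
        model.zero_grad()
        loss = nn.functional.cross_entropy(model(x.unsqueeze(0)), y.unsqueeze(0))
        loss.backward()
        for name, p in model.named_parameters():
            grads[name].append(p.grad.detach().clone())
    return {name: torch.stack(g) for name, g in grads.items()}

reference = loop_per_sample_grads()
for name in params:
    # Each entry has a leading batch dimension of size 8.
    print(name, tuple(per_sample_grads[name].shape))
```

Both mechanisms should produce the same gradients, so a performance investigation can focus on wall-clock time and memory (for example, timing each with `torch.utils.benchmark`) rather than on correctness.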