Summary:
Fuse the states of each metric to reduce computation overhead. This benefits every model that uses RecMetrics. With fused state we no longer all_gather per state: a metric goes from N all_gather calls to 1, where N is the number of states the metric holds. This is especially helpful for metrics that need many states to be computed.
Backward compatibility is preserved for metrics that are not vectorized, and for loading checkpoints from snapshots taken before this diff.
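
To illustrate the N-to-1 reduction described above, here is a minimal pure-Python sketch. It is not the RecMetrics implementation: `CountingGather`, `gather_per_state`, `gather_fused`, and the state names are hypothetical, and the collective is simulated by a counter rather than by a real all_gather. It assumes every rank holds the same states with the same sizes.

```python
from typing import Dict, List, Tuple

State = Dict[str, List[float]]


class CountingGather:
    """Toy stand-in for a collective; counts calls instead of communicating."""

    def __init__(self) -> None:
        self.calls = 0

    def all_gather(self, per_rank: List[List[float]]) -> List[List[float]]:
        self.calls += 1
        return [list(buf) for buf in per_rank]


def gather_per_state(ranks: List[State], comm: CountingGather) -> List[State]:
    """Baseline: one collective per state name (N calls for N states)."""
    out: List[State] = [{} for _ in ranks]
    for name in ranks[0]:
        gathered = comm.all_gather([r[name] for r in ranks])
        for i, buf in enumerate(gathered):
            out[i][name] = buf
    return out


def gather_fused(ranks: List[State], comm: CountingGather) -> List[State]:
    """Fused: concatenate all states into one buffer, gather once, split back."""
    # Record each state's name and length so the fused buffer can be split again.
    layout: List[Tuple[str, int]] = [(n, len(v)) for n, v in ranks[0].items()]
    fused = [[x for n, _ in layout for x in r[n]] for r in ranks]
    gathered = comm.all_gather(fused)
    out: List[State] = []
    for buf in gathered:
        state: State = {}
        offset = 0
        for name, size in layout:
            state[name] = buf[offset:offset + size]
            offset += size
        out.append(state)
    return out


# Two simulated ranks, each holding three hypothetical states.
ranks = [
    {"state_a": [1.0, 2.0], "state_b": [3.0, 4.0], "state_c": [0.5]},
    {"state_a": [5.0, 6.0], "state_b": [7.0, 8.0], "state_c": [1.5]},
]
baseline_comm, fused_comm = CountingGather(), CountingGather()
baseline = gather_per_state(ranks, baseline_comm)
fused = gather_fused(ranks, fused_comm)
```

Both paths recover the same per-rank states. The baseline issues one simulated collective per state, while the fused path issues a single one regardless of how many states the metric holds.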
Differential Revision: D49957177