pytorch / torcheval

A library that contains a rich collection of performant PyTorch model metrics, a simple interface to create new metrics, a toolkit to facilitate metric computation in distributed training, and tools for PyTorch model evaluation.
https://pytorch.org/torcheval

support int and float types in _add_state #158

Closed: JKSenthil closed this 1 year ago

JKSenthil commented 1 year ago

Summary: the current sync method only supports syncing states of type Tensor, List[Tensor], and Dict[str, Tensor].

This can cause issues for certain metrics (primarily windowed metrics) that rely on integer attributes when merging states.
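To illustrate why integer attributes matter during a merge, here is a minimal pure-Python sketch. `WindowedSum`, `next_inserted`, `total_updates`, and `merged_sum` are hypothetical names chosen for illustration, not torcheval's actual windowed metric code. The buffer alone is not enough to merge correctly; the int states determine which slots hold real data, so a sync that copies the buffer but leaves the ints at their defaults silently drops data.

```python
from typing import List


class WindowedSum:
    """Hypothetical windowed metric: sums the most recent `window` update values."""

    def __init__(self, window: int) -> None:
        self.window = window
        self.buffer: List[float] = [0.0] * window  # circular buffer of recent values
        self.next_inserted = 0  # int state: slot the next update overwrites
        self.total_updates = 0  # int state: number of updates seen so far

    def update(self, value: float) -> "WindowedSum":
        self.buffer[self.next_inserted] = value
        self.next_inserted = (self.next_inserted + 1) % self.window
        self.total_updates += 1
        return self

    def recent_values(self) -> List[float]:
        # The int states decide which part of the buffer holds real data.
        if self.total_updates <= self.window:
            return self.buffer[: self.total_updates]
        # Buffer has wrapped around: the oldest value sits at next_inserted.
        return self.buffer[self.next_inserted :] + self.buffer[: self.next_inserted]

    def compute(self) -> float:
        return sum(self.recent_values())


def merged_sum(metrics: List[WindowedSum]) -> float:
    """Merge step: combine the valid window of every instance."""
    return sum(sum(m.recent_values()) for m in metrics)


rank0 = WindowedSum(window=2)
for v in (1.0, 2.0, 3.0):
    rank0.update(v)
rank1 = WindowedSum(window=2)
for v in (10.0, 20.0):
    rank1.update(v)
print(merged_sum([rank0, rank1]))  # 35.0

# Simulate a sync that copied rank1's buffer but left its int states at the defaults.
stale = WindowedSum(window=2)
stale.buffer = list(rank1.buffer)
print(merged_sum([rank0, stale]))  # 5.0: rank1's values are silently dropped
```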

Added support for syncing int states. Changes:

  1. Added int and float as supported types to TState in metric.py
  2. Updated the impacted metrics to register these int attributes via the _add_state() method so that they are synced
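The two changes above can be sketched in pure Python as follows. This is a hypothetical stand-in, not torcheval's implementation: `FakeTensor` replaces `torch.Tensor` so the sketch runs without PyTorch, and `ToyMetric`, `is_supported_state`, and `simulate_object_gather` are illustrative names. The point is that only states registered through `_add_state()` are collected by the sync step, and that plain `int` and `float` values now pass the type check alongside the tensor-based types.

```python
from typing import Dict, List, Union


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so this sketch runs without PyTorch."""

    def __init__(self, value: float) -> None:
        self.value = value


# Mirrors the idea of TState after this change: int and float join the tensor-based types.
TState = Union[FakeTensor, List[FakeTensor], Dict[str, FakeTensor], int, float]


def is_supported_state(value: object) -> bool:
    """Return True if `value` matches one of the TState shapes above."""
    if isinstance(value, (FakeTensor, int, float)):
        return True
    if isinstance(value, list):
        return all(isinstance(v, FakeTensor) for v in value)
    if isinstance(value, dict):
        return all(isinstance(k, str) and isinstance(v, FakeTensor) for k, v in value.items())
    return False


class ToyMetric:
    """Hypothetical metric base class; not torcheval's actual Metric."""

    def __init__(self) -> None:
        self._state_name_to_default: Dict[str, TState] = {}

    def _add_state(self, name: str, default: TState) -> None:
        # Registering a state is what makes the sync step aware of it;
        # attributes assigned directly on the instance are skipped.
        if not is_supported_state(default):
            raise TypeError(f"unsupported state type for {name!r}: {type(default).__name__}")
        self._state_name_to_default[name] = default
        setattr(self, name, default)

    def state_dict(self) -> Dict[str, TState]:
        return {name: getattr(self, name) for name in self._state_name_to_default}


def simulate_object_gather(per_rank: List[ToyMetric]) -> Dict[str, List[TState]]:
    """Collect every registered state from each simulated rank into per-name lists."""
    gathered: Dict[str, List[TState]] = {}
    for metric in per_rank:
        for name, value in metric.state_dict().items():
            gathered.setdefault(name, []).append(value)
    return gathered


m0, m1 = ToyMetric(), ToyMetric()
for metric, n in ((m0, 3), (m1, 5)):
    metric._add_state("total_updates", 0)
    metric.total_updates = n
    metric.untracked_counter = n  # not registered, so it is never gathered
print(simulate_object_gather([m0, m1]))  # {'total_updates': [3, 5]}
```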

Note: adding int support and syncing via gather_object was chosen over converting these variables to single-value tensors, because if those tensors live on GPU, using them in the merge_state logic can incur significantly more latency.

Reviewed By: bobakfb, ananthsub

Differential Revision: D46775377

facebook-github-bot commented 1 year ago

This pull request was exported from Phabricator. Differential Revision: D46775377

codecov[bot] commented 1 year ago

Codecov Report

Merging #158 (fe4c6ab) into main (76ea267) will increase coverage by 0.01%. The diff coverage is 3.84%.

@@            Coverage Diff             @@
##             main     #158      +/-   ##
==========================================
+ Coverage   24.45%   24.46%   +0.01%     
==========================================
  Files         176      176              
  Lines       10359    10350       -9     
==========================================
- Hits         2533     2532       -1     
+ Misses       7826     7818       -8     

Impacted Files Coverage Δ
tests/metrics/aggregation/test_cat.py 35.89% <ø> (ø)
tests/metrics/test_metric.py 11.85% <0.00%> (-0.09%) :arrow_down:
tests/metrics/test_toolkit.py 39.50% <ø> (ø)
tests/metrics/window/test_auroc.py 16.43% <ø> (ø)
tests/metrics/window/test_click_through_rate.py 20.68% <ø> (ø)
tests/metrics/window/test_mean_squared_error.py 11.57% <ø> (ø)
tests/metrics/window/test_weighted_calibration.py 21.42% <ø> (ø)
torcheval/metrics/aggregation/cat.py 50.00% <0.00%> (ø)
torcheval/metrics/toolkit.py 37.64% <0.00%> (+4.65%) :arrow_up:
torcheval/metrics/window/auroc.py 15.85% <0.00%> (ø)
... and 5 more
