Goals :soccer:

This PR fixes the `RecallAt` metric, which was giving wrong results for recall@1 (values higher than those of the larger cut-offs, and sometimes greater than 1) when 1 was among the cut-off values (e.g. `RecallAt([1, 5])`).
It also adds support for providing a scalar cut-off (e.g. `RecallAt(5)`) instead of a list or tuple.
Finally, it fixes the error raised when running `trainer.evaluate()` with a ranking metric that has a single cut-off in the list (e.g. `RecallAt([5])`), as the torchmetrics base `Metric` class we inherit from converts single-element results to a scalar.
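To illustrate the intended behaviour, here is a minimal, hypothetical sketch of per-cut-off recall (not the library's implementation). Truncating the ranked list at each cut-off `k` is what guarantees recall@1 ≤ recall@5 ≤ 1, and normalizing a scalar cut-off to a list mirrors the new `RecallAt(5)` form; the function name and return format are illustrative only.

```python
def recall_at(scores, labels, cutoffs):
    """Compute recall@k for one ranked list (illustrative sketch only).

    scores: predicted relevance per item; labels: 0/1 ground truth.
    """
    # Accept a scalar cut-off as well as a list/tuple, mirroring the
    # RecallAt(5) vs. RecallAt([5]) support described above.
    if not isinstance(cutoffs, (list, tuple)):
        cutoffs = [cutoffs]
    # Rank item indices by descending score.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    total_relevant = sum(labels)
    results = {}
    for k in cutoffs:
        # Count hits strictly within the top-k; truncating here is what
        # keeps recall@1 bounded by recall at larger cut-offs (and by 1).
        hits = sum(labels[i] for i in order[:k])
        results[k] = hits / max(total_relevant, 1)
    return results
```

For example, with `scores = [0.9, 0.1, 0.8, 0.4]` and `labels = [1, 0, 0, 1]`, this gives recall@1 = 0.5 and recall@4 = 1.0.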
Testing Details :mag:
Added new unit tests covering the `RecallAt` metric, for both 2D and 3D inputs, with cut-off 1 and with a scalar/single cut-off.
I have also verified the fix with the paper reproducibility script: recall@1 is now smaller than recall at the larger cut-offs.
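The single-cut-off case exercised by these tests can be guarded with a small re-wrap, sketched here (the helper name and key format are hypothetical, not the library's API): when a base class squeezes a one-element result to a bare scalar, restoring the sequence before pairing values with cut-offs keeps `trainer.evaluate()`-style consumers working.

```python
def results_by_cutoff(cutoffs, values):
    """Pair computed metric values with their cut-offs (hypothetical helper).

    torchmetrics' base Metric can return a single-element result as a bare
    scalar rather than a length-1 sequence, so re-wrap it before zipping.
    """
    if not isinstance(values, (list, tuple)):
        values = [values]
    if len(cutoffs) != len(values):
        raise ValueError("expected one value per cut-off")
    return {f"recall_at_{k}": v for k, v in zip(cutoffs, values)}
```

Both `results_by_cutoff([5], 0.42)` and `results_by_cutoff([1, 5], [0.3, 0.42])` then produce a well-formed per-cut-off mapping.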
Fixes #464, Fixes #700