Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

Fix `top_k` for `multiclass-f1score` #2839

Open rittik9 opened 5 days ago

rittik9 commented 5 days ago

What does this PR do?

Fixes #1653

Before submitting

- [x] Was this **discussed/agreed** via a Github issue? (no need for typos and docs improvements)
- [x] Did you read the [contributor guideline](https://github.com/Lightning-AI/torchmetrics/blob/master/.github/CONTRIBUTING.md), Pull Request section?
- [x] Did you make sure to **update the docs**?
- [x] Did you write any new **necessary tests**?

PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


📚 Documentation preview 📚: https://torchmetrics--2839.org.readthedocs.build/en/2839/

eneserdo commented 4 days ago

Thanks for the effort. IMHO, you could have done the "refining" directly on the preds tensor, using something like this:

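# Indices of the top_k highest-scoring classes per sample
preds_topk = torch.argsort(preds, dim=-1, descending=True)[:, :top_k]
# Top-1 prediction per sample
preds_top1 = preds_topk[:, 0]
# Keep the target where it appears in the top-k set, otherwise fall back to top-1
preds = torch.where((target.view(-1, 1) == preds_topk).sum(dim=-1).bool(), target, preds_top1)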

This is a more compact way of doing the same job. (Cloning and reshaping are omitted here.)
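For readers who want to try the suggestion, here is a minimal self-contained sketch. It assumes `preds` is an `(N, C)` tensor of per-class scores and `target` is an `(N,)` tensor of class indices. The helper name `refine_preds_topk` and the sample values are hypothetical and not part of torchmetrics. It uses `.any(dim=-1)` in place of `.sum(dim=-1).bool()`, which should be equivalent here.

```python
import torch


def refine_preds_topk(preds: torch.Tensor, target: torch.Tensor, top_k: int) -> torch.Tensor:
    """Return one label per sample: the target if it is in the top-k set, else the top-1 class.

    Hypothetical helper for illustration only; not part of the torchmetrics API.
    """
    # Indices of the k highest-scoring classes per sample, shape (N, top_k).
    preds_topk = torch.argsort(preds, dim=-1, descending=True)[:, :top_k]
    # Highest-scoring class per sample, shape (N,).
    preds_top1 = preds_topk[:, 0]
    # True where the target appears anywhere in the sample's top-k set.
    hit = (target.view(-1, 1) == preds_topk).any(dim=-1)
    return torch.where(hit, target, preds_top1)


# Made-up scores for three samples and three classes.
preds = torch.tensor([
    [0.1, 0.6, 0.3],  # top-2 classes: 1, 2
    [0.5, 0.2, 0.3],  # top-2 classes: 0, 2
    [0.2, 0.3, 0.5],  # top-2 classes: 2, 1
])
target = torch.tensor([2, 1, 0])

refined_k1 = refine_preds_topk(preds, target, top_k=1)
refined_k2 = refine_preds_topk(preds, target, top_k=2)
```

With `top_k=1` the refined labels are simply the argmax. With `top_k=2` only the first sample is replaced by its target, because class 2 is in that sample's top-2 set. The refined labels could then be passed to an ordinary label-based F1 computation.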

Also, these changes will break the current tests for all top_k-related classes and functions, e.g. recall, accuracy, F1, and so on. I think it is important to rewrite these tests. Additionally, for top-k accuracy you could take scikit-learn's top_k_accuracy_score as a reference.
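As a rough illustration of what such a reference check might compare against, the sketch below computes top-k accuracy as the fraction of samples whose target is among the k highest scores. The helper `topk_accuracy` and the sample data are hypothetical. Where scikit-learn is installed, the same inputs could be cross-checked against `sklearn.metrics.top_k_accuracy_score`, which is not called here so that the snippet depends on PyTorch only.

```python
import torch


def topk_accuracy(preds: torch.Tensor, target: torch.Tensor, top_k: int) -> float:
    """Fraction of samples whose target is among the top_k highest scores.

    Hypothetical helper for illustration only; not part of the torchmetrics API.
    """
    # Indices of the k highest-scoring classes per sample, shape (N, top_k).
    topk_idx = preds.topk(top_k, dim=-1).indices
    # True where the target matches any of the sample's top-k classes.
    hits = (topk_idx == target.unsqueeze(-1)).any(dim=-1)
    return hits.float().mean().item()


# Made-up scores for four samples and three classes.
preds = torch.tensor([
    [0.1, 0.6, 0.3],
    [0.5, 0.2, 0.3],
    [0.2, 0.3, 0.5],
    [0.7, 0.2, 0.1],
])
target = torch.tensor([2, 1, 0, 0])

acc_k1 = topk_accuracy(preds, target, top_k=1)  # one of four samples is a top-1 hit
acc_k2 = topk_accuracy(preds, target, top_k=2)  # two of four samples are top-2 hits
```

A test could parametrize over `top_k` and assert that the library result matches such a reference within a small tolerance.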

rittik9 commented 4 days ago

Thanks for your suggestions. I've noticed that some of the tests have failed, and I'm working on them. I am also comparing the results with other libraries' implementations. I'll keep posting updates here.

codecov[bot] commented 3 days ago

Codecov Report

Attention: Patch coverage is 11.11111% with 8 lines in your changes missing coverage. Please review.

Project coverage is 41%. Comparing base (0d3494f) to head (910c537).

Warning: There is a different number of reports uploaded between BASE (0d3494f) and HEAD (910c537).

HEAD has 338 uploads less than BASE

| Flag | BASE (0d3494f) | HEAD (910c537) |
|------|------|------|
| macOS | 21 | 3 |
| python3.10 | 63 | 9 |
| cpu | 98 | 14 |
| torch2.0.1 | 14 | 2 |
| torch2.0.1+cpu | 21 | 3 |
| Windows | 14 | 2 |
| python3.12 | 21 | 3 |
| torch2.5.0 | 7 | 1 |
| torch2.5.0+cpu | 7 | 1 |
| gpu | 1 | 0 |
| unittest | 1 | 0 |
| Linux | 63 | 9 |
| torch2.4.1+cu121 | 14 | 2 |
| torch2.1.2+cpu | 7 | 1 |
| python3.11 | 7 | 1 |
| torch2.3.1+cpu | 7 | 1 |
| torch2.2.2+cpu | 7 | 1 |
| torch2.5.0+cu124 | 14 | 2 |
| python3.9 | 7 | 1 |
Additional details and impacted files

```diff
@@           Coverage Diff            @@
##           master   #2839     +/-   ##
========================================
- Coverage      69%     41%    -28%
========================================
  Files         346     332     -14
  Lines       19129   18962    -167
========================================
- Hits        13227    7736   -5491
- Misses       5902   11226   +5324
```


rittik9 commented 1 day ago

Sure