rittik9 opened this pull request 5 days ago (status: open)
Thanks for the effort. IMHO, you could have handled the "refining" directly on the `preds` tensor using something like this:
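```python
import torch
# Indices of the top_k highest-scoring classes per sample, shape (N, top_k)
preds_topk = torch.argsort(preds, dim=-1, descending=True)[:, :top_k]
preds_top1 = preds_topk[:, 0]
# Keep the target as the label if it is among the top_k classes, else fall back to the top-1 class
preds = torch.where((target.view(-1, 1) == preds_topk).any(dim=-1), target, preds_top1)
```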
This is a more compact way of doing the same job (cloning and reshaping are omitted here).
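For readers without PyTorch at hand, the same refinement logic can be sketched in plain Python. This is a hypothetical helper (`refine_topk_preds` is not part of torchmetrics); it assumes `scores` holds one list of per-class scores per sample and `target` holds integer class labels. Ties are broken by lower class index here, which may differ from `torch.argsort`.

```python
from typing import List, Sequence


def refine_topk_preds(
    scores: Sequence[Sequence[float]], target: Sequence[int], top_k: int
) -> List[int]:
    """Collapse per-class scores into one label per sample (hypothetical helper).

    If the true class is among the top_k highest-scoring classes, the sample is
    labelled with the true class; otherwise it falls back to the top-1 class.
    """
    refined: List[int] = []
    for row, true_label in zip(scores, target):
        # Class indices sorted by descending score; the stable sort keeps the lower index first on ties.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        refined.append(true_label if true_label in ranked[:top_k] else ranked[0])
    return refined


probs = [[0.1, 0.6, 0.3], [0.5, 0.2, 0.3], [0.2, 0.3, 0.5]]
target = [2, 1, 0]
print(refine_topk_preds(probs, target, top_k=2))  # [2, 0, 2]
```

Only the first sample has its true class (2) within its top-2 scores, so it is relabelled as the target while the other two keep their top-1 class. Plain accuracy computed on the refined labels is then the top-k accuracy.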
Also, these changes will break the current tests for all `top_k`-related classes/functions, e.g. recall, accuracy, F1, and so on. I think it is important to rewrite these tests. Additionally, for top-k accuracy you could take scikit-learn's `top_k_accuracy_score` as a reference.
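A minimal plain-Python sketch of the definition behind scikit-learn's `top_k_accuracy_score` (a sample counts as correct when its true label is among the `k` highest-scoring classes) could serve as a reference point when rewriting the tests. The function name `top_k_accuracy` is hypothetical, it only covers the multiclass case with one score per class, and its tie-breaking may differ from scikit-learn's.

```python
from typing import Sequence


def top_k_accuracy(scores: Sequence[Sequence[float]], target: Sequence[int], k: int) -> float:
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, true_label in zip(scores, target):
        # Class indices sorted by descending score.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        hits += true_label in ranked[:k]  # bool counts as 0 or 1
    return hits / len(target)


probs = [[0.1, 0.6, 0.3], [0.5, 0.2, 0.3], [0.2, 0.3, 0.5]]
target = [2, 1, 2]
print(top_k_accuracy(probs, target, k=2))  # 0.6666666666666666
```

With `k` equal to the number of classes every sample is a hit, so the score is 1.0; with `k=1` it reduces to ordinary accuracy.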
Thanks for your suggestions. I've noticed that some of the tests have failed, and I'm working on them. I'm also comparing the results with other libraries' implementations and will keep updating here.
Attention: Patch coverage is 11.11111% with 8 lines in your changes missing coverage. Please review.
Project coverage is 41%. Comparing base (0d3494f) to head (910c537).
:exclamation: There is a different number of reports uploaded between BASE (0d3494f) and HEAD (910c537).
HEAD has 338 fewer uploads than BASE.
| Flag | BASE (0d3494f) | HEAD (910c537) |
|------|------|------|
| macOS | 21 | 3 |
| python3.10 | 63 | 9 |
| cpu | 98 | 14 |
| torch2.0.1 | 14 | 2 |
| torch2.0.1+cpu | 21 | 3 |
| Windows | 14 | 2 |
| python3.12 | 21 | 3 |
| torch2.5.0 | 7 | 1 |
| torch2.5.0+cpu | 7 | 1 |
| gpu | 1 | 0 |
| unittest | 1 | 0 |
| Linux | 63 | 9 |
| torch2.4.1+cu121 | 14 | 2 |
| torch2.1.2+cpu | 7 | 1 |
| python3.11 | 7 | 1 |
| torch2.3.1+cpu | 7 | 1 |
| torch2.2.2+cpu | 7 | 1 |
| torch2.5.0+cu124 | 14 | 2 |
| python3.9 | 7 | 1 |
Sure
What does this PR do?
Fixes #1653
Before submitting
- [x] Was this **discussed/agreed** via a GitHub issue? (not needed for typo fixes and docs improvements)
- [x] Did you read the [contributor guideline](https://github.com/Lightning-AI/torchmetrics/blob/master/.github/CONTRIBUTING.md), Pull Request section?
- [x] Did you make sure to **update the docs**?
- [x] Did you write any new **necessary tests**?

PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?
Make sure you had fun coding 🙃
📚 Documentation preview 📚: https://torchmetrics--2839.org.readthedocs.build/en/2839/