microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

Multiclass Classification: assert num_classes >=2 #2205

Open robmarkcole opened 1 month ago

robmarkcole commented 1 month ago

Summary

Both segmentation and object detection require that the background be included, and there is currently a note on these args: num_classes: Number of prediction classes (including the background). Considering every dataset must have at least 1 class, the min value of num_classes is 2. I propose adding an assertion to prevent people (like myself!) from forgetting this and setting num_classes=1 for datasets with a single class.

Rationale

This config error has happened to me several times, and it can pass silently.

Implementation

I suppose we could add the validation to the BaseTask init.
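A minimal sketch of what such a check could look like. The class name BaseTaskSketch and its single-argument constructor are hypothetical stand-ins for illustration, not TorchGeo's actual BaseTask. It raises ValueError rather than using a bare assert, since assert statements are skipped when Python runs with -O.

```python
class BaseTaskSketch:
    """Hypothetical stand-in for a trainer base class (not TorchGeo's BaseTask)."""

    def __init__(self, num_classes: int) -> None:
        # num_classes counts the background, so a dataset with a single
        # foreground class still needs num_classes=2.
        if num_classes < 2:
            raise ValueError(
                f"num_classes must be >= 2 (including the background), got {num_classes}"
            )
        self.num_classes = num_classes
```

With this in place, BaseTaskSketch(num_classes=1) fails at construction time instead of passing silently.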

Alternatives

No response

Additional information

No response

adamjstewart commented 1 month ago

Not to completely derail what should otherwise be a simple fix, but...

This brings up the question of how we want to handle different forms of classification/semantic segmentation: binary, multiclass, and multilabel.

Torchmetrics originally had a single class for Accuracy. In https://github.com/Lightning-AI/torchmetrics/issues/1001, they proposed and implemented separate classes for each of the 3 above types of classification (BinaryAccuracy, etc.). The original plan was to deprecate and remove the old single class, but it seems that plan was aborted at some point.

We should decide whether we want BinaryClassificationTask, etc. or whether we want to add a task='binary', etc. parameter to ClassificationTask.

We could definitely still add such an assertion for now and change it to assert num_classes > 1 if task != 'binary' later if needed.
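The task-aware variant of the assertion could be sketched roughly as follows. The function name check_num_classes, the VALID_TASKS tuple, and the default task value are assumptions made for illustration; only the condition itself (num_classes > 1 unless task is 'binary') comes from the comment above.

```python
# Task names assumed to mirror the strings torchmetrics uses.
VALID_TASKS = ("binary", "multiclass", "multilabel")


def check_num_classes(num_classes: int, task: str = "multiclass") -> None:
    """Hypothetical helper: reject num_classes < 2 unless the task is binary."""
    if task not in VALID_TASKS:
        raise ValueError(f"task must be one of {VALID_TASKS}, got {task!r}")
    # Binary tasks are exempt; every other task needs more than one class.
    if task != "binary" and num_classes <= 1:
        raise ValueError(
            f"num_classes must be > 1 for task={task!r}, got {num_classes}"
        )
```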

robmarkcole commented 1 month ago

As you point out, binary etc. are args torchmetrics accepts, so I think it makes sense to have this functionality with the existing task.

adamjstewart commented 1 month ago

Just waiting for clarity on whether torchmetrics is planning on supporting the old metrics forever before deciding, but I was leaning towards that too.

adamjstewart commented 1 month ago

Looks like I misinterpreted; both are supported.

Is there anything special we need to do in our trainers to support binary and multilabel, or do we literally just need to pass different task values to torchmetrics? If the former, we may want to split, but if the latter, I agree we should just keep the current classes and add a task parameter.
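For the "just pass different task values" option, the metric arguments could be derived from a single task parameter along these lines. The helper name accuracy_kwargs is hypothetical. The sketch assumes the torchmetrics convention that multiclass metrics take num_classes while multilabel metrics take num_labels, and it does not address whether the trainers need other changes (for example a different loss) for binary or multilabel.

```python
def accuracy_kwargs(task: str, num_classes: int) -> dict:
    """Hypothetical helper: build keyword arguments for a task-based metric."""
    if task == "binary":
        # Binary metrics do not take a class count.
        return {"task": "binary"}
    if task == "multiclass":
        return {"task": "multiclass", "num_classes": num_classes}
    if task == "multilabel":
        # Multilabel metrics are parameterized by the number of labels.
        return {"task": "multilabel", "num_labels": num_classes}
    raise ValueError(f"unsupported task: {task!r}")
```

If torchmetrics is installed, the result could then be unpacked into a metric, e.g. torchmetrics.Accuracy(**accuracy_kwargs("multiclass", 5)).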