Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains
https://lightning-flash.readthedocs.io
Apache License 2.0

Ability to pass the `TargetFormatter` to use with classification data modules #1131

Closed daMichaelB closed 2 years ago

daMichaelB commented 2 years ago

❓ Questions and Help

What is your question?

I have a highly imbalanced dataset in which some minority classes are very rare. I put them ONLY into the validation set. I want to check whether the model can at least tell that they do not belong to the majority class.

The Datamodule was created with:

        datamodule = ImageClassificationData.from_data_frame(
            "file", "label",
            train_images_root=...,
            val_images_root=...,
            test_images_root=...,
            train_data_frame=train_df,
            val_data_frame=valid_df,
            test_data_frame=test_df,
            ...
        )

As I understand it, I can create the ImageClassifier with the number of ALL classes:

        model = ImageClassifier(backbone=...,
                                num_classes=self.num_classes,
                                pretrained=...)

However, my training crashes right at the start, during the validation sanity check. Traceback:

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.8/dist-packages/flash/core/data/io/input.py", line 317, in __getitem__
    return self._call_load_sample(self.data[index])
  File "/usr/local/lib/python3.8/dist-packages/flash/core/data/io/input.py", line 236, in _call_load_sample
    return load_sample(copy(sample))
  File "/usr/local/lib/python3.8/dist-packages/flash/image/classification/input.py", line 49, in load_sample
    sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
  File "/usr/local/lib/python3.8/dist-packages/flash/core/data/io/classification_input.py", line 79, in format_target
    return self.target_formatter(target)
  File "/usr/local/lib/python3.8/dist-packages/flash/core/data/utilities/classification.py", line 163, in __call__
    return self.format(target)
  File "/usr/local/lib/python3.8/dist-packages/flash/core/data/utilities/classification.py", line 182, in format
    return self.label_to_idx[(target[0] if not isinstance(target, str) else target).strip()]
KeyError: '14'

I found that label 14 is in the validation set but not in the training set.
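A minimal sketch of this failure mode, assuming the label-to-index mapping is built from the training labels alone. The `label_to_idx` name comes from the traceback; the label values and the `build_label_to_idx` helper are made up for illustration and are not Flash's actual implementation:

```python
# Hypothetical labels: "14" appears only in the validation split.
train_labels = ["0", "1", "2", "1", "0"]
val_labels = ["0", "14", "2"]


def build_label_to_idx(labels):
    """Map each distinct label to a contiguous index, in sorted order."""
    return {label: idx for idx, label in enumerate(sorted(set(labels)))}


# The mapping only knows about labels seen in the training split.
label_to_idx = build_label_to_idx(train_labels)

# Validation labels that the mapping cannot translate.
missing = [label for label in val_labels if label not in label_to_idx]
print("Labels missing from the training split:", missing)

# Looking up an unseen label fails the same way as in the traceback.
try:
    label_to_idx["14"]
except KeyError as err:
    print(f"KeyError: {err}")
```

Running a check like the `missing` list on the real data frames before training would surface the problem without waiting for the sanity check to crash.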

Question

Is there a way to train on a subset of the classes but validate on all classes?

What have you tried?

I have no idea how to work around this...

What's your environment?

ethanwharris commented 2 years ago

Hey @daMichaelB, thanks for reporting this! This is something we now support internally but don't yet expose to the user. All labels, the number of classes, etc. for classification problems are handled by a TargetFormatter object (see the API reference here: https://lightning-flash.readthedocs.io/en/latest/api/data.html#flash-core-data-utilities-classification ).

These objects are usually inferred from the training data, but in cases where that inference is not possible (e.g. where we can't efficiently get a list of all targets) we have begun to expose this object. So you could have, for example:

        datamodule = ImageClassificationData.from_data_frame(
            ...,
            target_formatter=MultiLabelTargetFormatter(labels=["label_1", ..., "label_n"]),
        )

Would this API work for you? If so, I can get to work on adding the target_formatter argument to all of our from_* methods :smiley:
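One way to assemble the full label list for such a formatter is to take the union of the labels across all splits. This is only a sketch, assuming plain Python lists of string labels (with data frames these would come from the label column); the label values and the `collect_all_labels` helper are invented for illustration:

```python
# Hypothetical per-split labels; "14" is absent from the training split.
train_labels = ["0", "1", "2", "1"]
val_labels = ["0", "14", "2"]
test_labels = ["1", "2"]


def collect_all_labels(*splits):
    """Return the sorted, de-duplicated union of labels from every split."""
    seen = set()
    for split in splits:
        seen.update(split)
    # Sorting gives a deterministic order (lexicographic for strings).
    return sorted(seen)


all_labels = collect_all_labels(train_labels, val_labels, test_labels)
num_classes = len(all_labels)

print(all_labels)
print(num_classes)
```

The resulting list is what would be passed as `labels=...` to the formatter, and `num_classes` is the value the model would be created with, so both agree on the full set of classes.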

daMichaelB commented 2 years ago

Hey @ethanwharris. This would solve a lot of trouble on my side 🎉! I think that would be a great feature for dealing with imbalanced datasets!

Thank you for the suggestion, and let me know if I can help with testing it!

daMichaelB commented 2 years ago

Thank you for the great support and implementation 👍