keras-team / keras-preprocessing

Utilities for working with image data, text data, and sequence data.

Add balance in flow_from_directory to handle data imbalance (using random oversampling) #310

Open DOLARIK opened 3 years ago

DOLARIK commented 3 years ago

Summary

This PR lets flow_from_directory balance imbalanced classes using random oversampling.

It adds a new boolean argument, balance (True/False), to DirectoryIterator, exposed through flow_from_directory.

Example:

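from keras_preprocessing.image import ImageDataGenerator

data_train_dir = 'data/train'
datagen = ImageDataGenerator(validation_split=0.2)

# `balance` is the new argument added by this PR;
# the training subset is oversampled up to the majority-class count
train_gen = datagen.flow_from_directory(data_train_dir, subset='training',
                                        balance=True)

# balance=True is accepted here too, but the validation subset is never oversampled
valid_gen = datagen.flow_from_directory(data_train_dir, subset='validation',
                                        balance=True)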

Only the training subset is oversampled; the validation subset is left untouched, even when balance=True is passed. Random oversampling is used only to increase the number of training samples for more robust learning, while the validation data keeps its original distribution so that evaluation is not skewed by duplicated files.
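The subset rule can be expressed as a small predicate. This is a minimal sketch, not the PR's code: should_oversample is a hypothetical helper, and treating subset=None (no validation_split) like training data is an assumption.

```python
def should_oversample(balance, subset):
    """Hypothetical helper: decide whether a DirectoryIterator-like object oversamples.

    Assumption: subset=None (no validation_split given) is treated like training data.
    """
    if not balance:
        return False
    # The validation subset is never oversampled, so evaluation sees only real files
    return subset != 'validation'


for subset in ('training', 'validation', None):
    print(subset, should_oversample(True, subset))
# training True
# validation False
# None True
```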

I have created a Colab notebook for experimenting with this new feature.

Underlying Concept:

Consider an imbalanced dataset with categories A, B, and C, whose sub-directories contain the following files:

data
|____train________ ___________________ __________________
                  |                   |                  |
                  A                   B                  C
                  |__A_0.jpg          |__B_0.jpg         |__C_0.jpg
                  |__A_1.jpg          |__B_1.jpg         |__C_1.jpg
                  |__A_2.jpg                             |__C_2.jpg
                  |__A_3.jpg           
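To see how imbalanced a directory like this is before training, the per-class file counts can be read straight from the sub-directories. This is a hypothetical helper (count_per_class is not part of keras_preprocessing) that recreates the example layout with empty placeholder files; unlike Keras, it counts every regular file rather than filtering by image extension.

```python
import tempfile
from pathlib import Path


def count_per_class(train_dir):
    """Return {class_name: number_of_files} for each sub-directory of train_dir."""
    return {
        d.name: sum(1 for f in d.iterdir() if f.is_file())
        for d in sorted(Path(train_dir).iterdir())
        if d.is_dir()
    }


# Recreate the example layout with empty placeholder files
with tempfile.TemporaryDirectory() as root:
    train = Path(root) / 'data' / 'train'
    for cls, n in {'A': 4, 'B': 2, 'C': 3}.items():
        (train / cls).mkdir(parents=True)
        for i in range(n):
            (train / cls / f'{cls}_{i}.jpg').touch()
    counts = count_per_class(train)

print(counts)                       # {'A': 4, 'B': 2, 'C': 3}
print(max(counts, key=counts.get))  # A  (the majority class)
```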

Here the majority count is 4 (category A), so every other category is also brought up to 4 files by randomly sampling filenames from that category's original files and appending them:

data
|____train________ ___________________ __________________
                  |                   |                  |
                  A                   B                  C
                  |__A_0.jpg          |__B_0.jpg         |__C_0.jpg
                  |__A_1.jpg          |__B_1.jpg         |__C_1.jpg
                  |__A_2.jpg          |--B_1.jpg         |__C_2.jpg
                  |__A_3.jpg          |--B_0.jpg         |--C_0.jpg           
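The bookkeeping behind this tree can be written out explicitly, where $n_k$ is the number of original files in category $k$ and $n_{\max}$ is the majority count:

$$
n_A = 4,\quad n_B = 2,\quad n_C = 3,\qquad n_{\max} = \max(n_A, n_B, n_C) = 4
$$

$$
\text{extra}_k = n_{\max} - n_k:\qquad \text{extra}_A = 0,\quad \text{extra}_B = 2,\quad \text{extra}_C = 1
$$

$$
N_{\text{before}} = 4 + 2 + 3 = 9,\qquad N_{\text{after}} = 3 \cdot n_{\max} = 12
$$

Balancing therefore appends 3 resampled filenames (2 for B, 1 for C), growing the training set from 9 to 12 entries.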

For B, filenames were drawn at random from [B_0.jpg, B_1.jpg] and appended until the list reached 4 entries (the majority count), giving [B_0.jpg, B_1.jpg, B_1.jpg, B_0.jpg]. The last two are the resampled filenames, drawn with |-- in the tree above.

Similarly, for C, random oversampling gave [C_0.jpg, C_1.jpg, C_2.jpg, C_0.jpg], where the final C_0.jpg is the resampled filename.
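The oversampling step described above can be sketched in a few lines. This is a minimal sketch under stated assumptions, not the PR's implementation: random_oversample is a hypothetical name, the filenames are assumed to be grouped per class in a dict (the PR's DirectoryIterator may store them differently, for example as flat filename and label lists), and extra filenames are drawn uniformly with replacement.

```python
import random


def random_oversample(files_by_class, seed=None):
    """Pad every class's filename list up to the majority count.

    Extra entries are drawn uniformly at random, with replacement, from the
    class's original filenames and appended after the originals.
    Every class is assumed to contain at least one file.
    """
    rng = random.Random(seed)
    target = max(len(files) for files in files_by_class.values())
    balanced = {}
    for cls, files in files_by_class.items():
        extra = [rng.choice(files) for _ in range(target - len(files))]
        balanced[cls] = list(files) + extra
    return balanced


files = {
    'A': ['A_0.jpg', 'A_1.jpg', 'A_2.jpg', 'A_3.jpg'],
    'B': ['B_0.jpg', 'B_1.jpg'],
    'C': ['C_0.jpg', 'C_1.jpg', 'C_2.jpg'],
}
balanced = random_oversample(files, seed=42)
for cls, names in balanced.items():
    print(cls, names)
```

Keeping the originals first and only appending resampled entries means no real file is ever dropped; the exact resampled names depend on the seed.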

After data augmentation (each .{n} suffix stands for one random transformation of that file):

data
|____train________ ___________________ __________________
                  |                   |                  |
                  A                   B                  C
                  |__A_0.jpg.{67}     |__B_0.jpg.{12}    |__C_0.jpg.{42}
                  |__A_1.jpg.{55}     |__B_1.jpg.{43}    |__C_1.jpg.{20}
                  |__A_2.jpg.{32}     |--B_1.jpg.{05}    |__C_2.jpg.{71}
                  |__A_3.jpg.{45}     |--B_0.jpg.{10}    |--C_0.jpg.{80}           

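Notice that B_1.jpg.{43} and B_1.jpg.{05} are effectively two different images: the same file under two different random transformations. This is how random oversampling combined with data augmentation increases the number of distinct training samples in the minority classes instead of repeating exact copies.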
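To illustrate why the two copies of B_1.jpg end up as different training images, here is a toy sketch in plain Python. augment is a hypothetical function applying a random horizontal flip and brightness shift, standing in for ImageDataGenerator's random transformations, and the {43}/{05} suffixes are modelled as per-sample seeds, which is an assumption about what that notation means.

```python
import random


def augment(image, seed):
    """Apply a random horizontal flip and brightness shift to a 2-D grayscale image.

    image: list of rows, each a list of floats in [0, 1].
    seed: per-sample seed; different seeds almost always give different outputs.
    """
    rng = random.Random(seed)
    rows = [list(row) for row in image]   # copy so the input is left untouched
    if rng.random() < 0.5:                # random horizontal flip
        rows = [row[::-1] for row in rows]
    delta = rng.uniform(-0.2, 0.2)        # random brightness shift
    # Clamp pixel values back into [0, 1]
    return [[min(1.0, max(0.0, p + delta)) for p in row] for row in rows]


b_1 = [[0.1, 0.5, 0.9],
       [0.2, 0.6, 0.8]]         # stand-in pixels for B_1.jpg

first = augment(b_1, seed=43)   # plays the role of B_1.jpg.{43}
second = augment(b_1, seed=5)   # plays the role of B_1.jpg.{05}
print(first != second)          # True
```

In keras_preprocessing itself, ImageDataGenerator.random_transform also accepts a seed argument, so the same file fed through it twice with different random draws yields two different arrays.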

This feature has let me handle data imbalance without external libraries while keeping the whole training pipeline clean and simple. I hope it helps others too.

I am still working out the best practices for writing unit tests and updating the docs, so I have not added tests for this feature yet. In the meantime, you can try it out in the Colab notebook, which includes a demo data directory and visualizations. I will add appropriate tests soon.

This new feature passes all the pre-existing tests.

PR Overview