SankhaSubhra / GAMO

Generative Adversarial Minority Oversampling
GNU General Public License v3.0

GAMO: Generative Adversarial Minority Oversampling

The following is an implementation of an end-to-end deep oversampling approach for joint feature extraction and classification in the presence of class imbalance in image datasets. The algorithm can be described as a game between three players: a classifier that performs its usual task, a generator that attempts to create convex combinations of points inside a class which are likely to be misclassified by the classifier, and a discriminator that forces the generator to adhere to the class distribution.
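The convex-combination idea can be illustrated with a minimal NumPy sketch. This is not the GAMO generator: in GAMO the combination weights are produced by a trained network, whereas here random Dirichlet weights stand in for them. The function name `convex_oversample` and the toy data are hypothetical and chosen only for illustration.

```python
import numpy as np

def convex_oversample(class_points, n_new, seed=None):
    """Return n_new points, each a convex combination of the rows of class_points.

    Assumption: weights are drawn at random from a Dirichlet distribution;
    GAMO instead learns them with a generator network.
    """
    rng = np.random.default_rng(seed)
    n_points = class_points.shape[0]
    # Every row of `weights` is non-negative and sums to 1.
    weights = rng.dirichlet(np.ones(n_points), size=n_new)
    # Each new point is a weighted average of the existing class points.
    return weights @ class_points

# Toy minority class: three points in 2-D (hypothetical data).
minority = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
synthetic = convex_oversample(minority, n_new=5, seed=0)
```

Because the weights are non-negative and sum to one, every synthetic point lies inside the convex hull of the class, which is the property the generator in the description above relies on.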

Reference

@InProceedings{Mullick_2019_ICCV,
author = {Mullick, Sankha Subhra and Datta, Shounak and Das, Swagatam},
title = {Generative Adversarial Minority Oversampling},
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
month = {October},
year = {2019}
} 

Data and code files

The GAMO framework can be used on pre-computed feature vectors or flattened images. Additionally, GAMO can extract useful convolutional features from images by itself in an end-to-end fashion. To illustrate both modes, we provide two example codes: one for MNIST (the flattened image is taken as the feature vector) and one for Fashion-MNIST (deep convolutional features are extracted by a network trained simultaneously with the classifier).

Dependencies:

You can use either Python 2.7 or Python 3. You will also need keras (with any backend), scikit-learn, scipy, numpy, opencv, and matplotlib as supporting libraries. The codes also use os, sys, and pickle, which are part of the Python standard library and need no installation.
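A quick way to check which of the third-party libraries are missing is sketched below. It assumes Python 3 and the usual import names (`sklearn` for scikit-learn, `cv2` for opencv); the helper `missing_packages` is hypothetical and not part of this repository.

```python
import importlib.util

# Package name -> import name (the import names are the commonly used ones).
REQUIRED = {
    "keras": "keras",
    "scikit-learn": "sklearn",
    "scipy": "scipy",
    "numpy": "numpy",
    "opencv": "cv2",
    "matplotlib": "matplotlib",
}

def missing_packages(required=REQUIRED):
    """Return the package names whose import module cannot be found."""
    return [
        name
        for name, module in required.items()
        if importlib.util.find_spec(module) is None
    ]

if __name__ == "__main__":
    missing = missing_packages()
    print("Missing packages:", ", ".join(missing) if missing else "none")
```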

Data preparation:

Neither MNIST nor Fashion-MNIST is sufficiently imbalanced to test the efficacy of GAMO. Therefore, we subsample from the different classes to form a new training set with an imbalance ratio of 100. You can download MNIST and Fashion-MNIST from their sources, convert them to CSV, and run the preprocessing scripts MNIST_process.py and fMNIST_process.py, respectively, to create training and test sets similar to those used in our experiments.
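The subsampling step can be sketched as follows. This is not the content of MNIST_process.py or fMNIST_process.py: the geometric decay of class sizes, the function names, and the toy data are all assumptions made for illustration, and the imbalance ratio is taken here to mean the size of the largest class divided by the size of the smallest.

```python
import numpy as np

def imbalanced_counts(n_classes, n_max, ratio):
    """Class sizes decaying geometrically from n_max to n_max / ratio (assumed profile)."""
    exponents = np.arange(n_classes) / (n_classes - 1)
    return np.maximum(1, np.round(n_max * ratio ** (-exponents))).astype(int)

def subsample(X, y, counts, seed=0):
    """Keep counts[c] randomly chosen samples of each class c (labels assumed to be 0..C-1)."""
    rng = np.random.default_rng(seed)
    keep = []
    for label, count in enumerate(counts):
        idx = np.flatnonzero(y == label)
        keep.append(rng.choice(idx, size=min(count, idx.size), replace=False))
    keep = np.concatenate(keep)
    return X[keep], y[keep]

# Toy stand-in for a balanced 10-class dataset (hypothetical data, not MNIST).
y = np.repeat(np.arange(10), 5000)
X = np.zeros((y.size, 4), dtype=np.float32)

counts = imbalanced_counts(n_classes=10, n_max=4000, ratio=100)
X_imb, y_imb = subsample(X, y, counts)
```

With these settings the largest class keeps 4000 samples and the smallest keeps 40, giving an imbalance ratio of 100 under the stated definition.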

MNIST codes:

Fashion-MNIST codes:

GAMO2pix:

A tool to visualize the feature vectors generated by GAMO in the original image space. This may be useful for applications that explicitly require the artificially generated images in addition to the trained classifier. As an example, we provide the code only for Fashion-MNIST.

For example, some of the generated images for CIFAR10, Fashion-MNIST, and SVHN when visualized by GAMO2pix are as follows, where the imbalance ratio compared to the majority class decreases from the top: