ZainNasrullah / pollution-select-feature-selection

conceptualizing method for feature selection
MIT License

.. highlight:: rst


Background
----------

Pollution Select is a feature selection method based on ideas from Boruta and other iterative selection methods. It finds features that consistently meet a desired performance criterion and are more important than random noise across rounds of Monte Carlo cross-validation.
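
The idea above can be illustrated with a minimal, self-contained sketch. It is not the ``pollution_select`` implementation: the helper names ``pollution_round`` and ``hit_rates``, the 25% test split, the number of noise columns, and the rule that a feature must outrank every noise column are all assumptions made for illustration. Each round draws a fresh random train/test split (Monte Carlo cross-validation), appends uniform random noise columns, ignores the round if the score misses the threshold, and otherwise records which real features are more important than all of the noise.

.. code-block:: python

   import numpy as np
   from sklearn.datasets import load_iris
   from sklearn.ensemble import RandomForestClassifier
   from sklearn.model_selection import train_test_split


   def pollution_round(model, X, y, performance_function, threshold, n_noise, rng):
       """One Monte Carlo round of a hypothetical pollution procedure.

       Returns a boolean array with one entry per real feature (True if that
       feature is more important than every noise column), or None if the model
       misses the performance threshold on this round's test split.
       """
       n_features = X.shape[1]
       # Pollute the data with uniform random noise columns, as in the Iris example.
       X_polluted = np.hstack([X, rng.random((X.shape[0], n_noise))])

       # Monte Carlo cross-validation: a fresh random train/test split every round.
       X_tr, X_te, y_tr, y_te = train_test_split(
           X_polluted, y, test_size=0.25, random_state=int(rng.integers(2**31 - 1))
       )
       model.fit(X_tr, y_tr)
       score = performance_function(y_te, model.predict(X_te))
       if score < threshold:
           return None  # rounds that miss the performance criterion are ignored

       importances = model.feature_importances_
       return importances[:n_features] > importances[n_features:].max()


   def hit_rates(model, X, y, performance_function, threshold,
                 n_iter=20, n_noise=2, seed=0):
       """Fraction of qualifying rounds in which each real feature beat the noise."""
       rng = np.random.default_rng(seed)
       hits = np.zeros(X.shape[1])
       n_valid = 0
       for _ in range(n_iter):
           result = pollution_round(model, X, y, performance_function,
                                    threshold, n_noise, rng)
           if result is not None:
               hits += result
               n_valid += 1
       return hits / max(n_valid, 1)  # avoid dividing by zero if no round qualifies


   # Usage on Iris with the same accuracy metric as the Quick Start
   X, y = load_iris(return_X_y=True)
   rates = hit_rates(
       RandomForestClassifier(random_state=0), X, y,
       performance_function=lambda t, p: np.mean(t == p),
       threshold=0.7,
       n_iter=10,
   )
   print(rates)

Reporting a per-feature hit rate rather than a single pass/fail keeps the "consistently" part of the criterion explicit: a feature that beats the noise in only a few qualifying rounds ends up with a low rate.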


Algorithm
---------


Install
-------

The simplest way to install right now is to clone this repo and then do a local install:

.. code-block:: console

   git clone https://github.com/ZainNasrullah/feature-selection-experiments.git
   cd feature-selection-experiments
   pip install .

Quick Start
-----------

Simple example without dropping any features:

.. code-block:: python

   import numpy as np
   from sklearn.datasets import load_iris
   from sklearn.ensemble import RandomForestClassifier
   from pollution_select import PollutionSelect

   iris = load_iris()
   X = iris.data
   y = iris.target
   # Surround the four real features with one uniform-noise column on each side
   X_noise = np.concatenate(
       (np.random.rand(150, 1), X, np.random.rand(150, 1)), axis=1
   )

   def acc(y, preds):
       return np.mean(y == preds)  # fraction of correct predictions

   selector = PollutionSelect(
       RandomForestClassifier(),
       performance_function=acc,
       performance_threshold=0.7,
   )

   X_transform = selector.fit_transform(X_noise, y)
   print(selector.feature_importances_)
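
In both examples, ``performance_function`` receives the true labels and the predictions, and ``performance_threshold=0.7`` suggests that higher scores are better. As a quick check, the hand-written ``acc`` agrees with scikit-learn's ``accuracy_score``. A metric with the same ``(y_true, y_pred)`` signature, such as ``balanced_accuracy_score``, could presumably be passed instead; that is an assumption and has not been verified against ``pollution_select``.

.. code-block:: python

   import numpy as np
   from sklearn.metrics import accuracy_score, balanced_accuracy_score


   def acc(y, preds):
       # Same metric as in the Quick Start: fraction of correct predictions
       return np.mean(y == preds)


   # Imbalanced toy labels: four samples of class 0, two of class 1
   y_true = np.array([0, 0, 0, 0, 1, 1])
   y_pred = np.array([0, 0, 0, 0, 0, 1])

   plain = acc(y_true, y_pred)
   sk_plain = accuracy_score(y_true, y_pred)
   balanced = balanced_accuracy_score(y_true, y_pred)
   print(plain, sk_plain, balanced)

On imbalanced labels like these, balanced accuracy (the mean of per-class recall) penalizes the missed minority-class prediction more heavily than plain accuracy does, which may matter when choosing a threshold.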

More complex example with feature dropping:

.. code-block:: python

   import numpy as np
   from sklearn.datasets import make_classification
   from sklearn.ensemble import RandomForestClassifier
   from pollution_select import PollutionSelect

   X, y = make_classification(
       n_samples=1000, n_features=20, n_informative=10, n_redundant=5
   )

   def acc(y, preds):
       return np.mean(y == preds)  # fraction of correct predictions

   selector = PollutionSelect(
       RandomForestClassifier(),
       n_iter=100,
       pollute_type="random_k",
       drop_features=True,
       performance_threshold=0.7,
       performance_function=acc,
       min_features=4,
   )

   selector.fit(X, y)

   print(selector.retained_features_)
   print(selector.dropped_features_)
   print(selector.feature_importances_)

   # Plot test scores against iteration and against the number of features
   selector.plot_test_scores_by_iters()
   selector.plot_test_scores_by_n_features()
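
With its default ``shuffle=True``, ``make_classification`` shuffles the columns, so the example above has no known ground truth to compare the selected features against. A minimal sketch, assuming you want such a check: pass ``shuffle=False`` (plus a ``random_state`` for reproducibility) so the columns follow the order documented by scikit-learn. Because ``n_repeated`` defaults to 2, only the last three of the 20 columns are pure noise. The helper ``noise_hits`` is hypothetical, and comparing it against the selector's retained features assumes those are reported as column indices.

.. code-block:: python

   import numpy as np
   from sklearn.datasets import make_classification

   n_features, n_informative, n_redundant, n_repeated = 20, 10, 5, 2

   # shuffle=False keeps scikit-learn's documented column order:
   # informative, then redundant, then repeated, then pure noise.
   X, y = make_classification(
       n_samples=1000,
       n_features=n_features,
       n_informative=n_informative,
       n_redundant=n_redundant,
       n_repeated=n_repeated,  # the default, written out explicitly
       shuffle=False,
       random_state=0,
   )

   n_useful = n_informative + n_redundant + n_repeated
   informative_cols = np.arange(n_informative)
   noise_cols = np.arange(n_useful, n_features)


   def noise_hits(selected):
       """Hypothetical helper: how many selected column indices are pure noise."""
       return int(np.intersect1d(selected, noise_cols).size)


   print(noise_cols)

Ideally, a selector that behaves as described in Background would retain none of the noise columns, so ``noise_hits`` applied to its retained features would be 0.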