Imbalanced Classification with Deep Reinforcement Learning.
This repository contains a (Double) Deep Q-Network implementation of binary classification on imbalanced datasets using TensorFlow 2.3+ and TF Agents 0.6+. The Double DQN, as published by van Hasselt et al. (2015), uses a custom environment based on the paper by Lin, Chen & Qi (2019).
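The key difference between Double DQN and plain DQN is the bootstrapped target. The online network chooses the greedy next action and the target network scores it: y = r + γ · Q_target(s', argmax_a Q_online(s', a)). The NumPy sketch below computes that target for one batch. It illustrates the technique under our own assumptions (array shapes, a `dones` mask, a discount of 0.99). It is not imbDRL's code, which builds its agent on TF Agents.

```python
import numpy as np


def double_dqn_targets(rewards, next_q_online, next_q_target, dones, gamma=0.99):
    """Double DQN targets for a batch of transitions (van Hasselt et al., 2015).

    rewards:       shape (batch,)
    next_q_online: shape (batch, n_actions), Q(s', .) from the online network
    next_q_target: shape (batch, n_actions), Q(s', .) from the target network
    dones:         shape (batch,), 1.0 where the episode ended at s'
    gamma:         discount factor (0.99 is an illustrative default)
    """
    # The online network selects the greedy next action...
    best_actions = np.argmax(next_q_online, axis=1)
    # ...and the target network evaluates it, reducing overestimation bias.
    evaluated = next_q_target[np.arange(len(best_actions)), best_actions]
    # Terminal transitions get no bootstrapped value.
    return rewards + gamma * (1.0 - dones) * evaluated


rewards = np.array([1.0, -0.1])
next_q_online = np.array([[0.2, 0.8], [0.5, 0.1]])
next_q_target = np.array([[0.9, 0.6], [0.4, 0.9]])
dones = np.array([0.0, 1.0])
targets = double_dqn_targets(rewards, next_q_online, next_q_target, dones)
# → approximately [1.594, -0.1]
```

In this batch, plain DQN would bootstrap the first transition from max(0.9, 0.6) = 0.9. Double DQN uses 0.6 instead, because the online network prefers action 1.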
Example scripts for the MNIST, Fashion-MNIST, Credit Card Fraud and Titanic datasets can be found in the ./imbDRL/examples/ddqn/ folder.
Results were collected with the scripts in the appendix repository, imbDRLAppendix. The experiments were conducted on the latest release of imbDRL and are based on the paper by Lin, Chen & Qi (2019).
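As we read Lin, Chen & Qi (2019), the environment treats each training sample as one step, and the agent's action is the predicted label. A minority-class sample earns a reward of ±1, while a majority-class sample earns ±λ, where λ is the ratio of minority to majority samples. Misclassifying a minority sample ends the episode. The sketch below encodes only that reward rule. The function names and the assumption that label 1 is the minority class are ours. The sketch also leaves out the episode ending once the data is exhausted.

```python
from typing import Sequence, Tuple


def imbalance_ratio(labels: Sequence[int], minority_label: int = 1) -> float:
    """Ratio of minority to majority samples (lambda in the reward rule)."""
    n_min = sum(1 for y in labels if y == minority_label)
    n_maj = len(labels) - n_min
    return n_min / n_maj


def imbalanced_reward(action: int, label: int, minority_label: int,
                      imb_ratio: float) -> Tuple[float, bool]:
    """Return (reward, episode_done) for classifying one sample.

    Minority samples: +1 if correct, -1 and end the episode if wrong.
    Majority samples: +imb_ratio if correct, -imb_ratio if wrong.
    """
    correct = action == label
    if label == minority_label:
        return (1.0 if correct else -1.0), not correct
    return (imb_ratio if correct else -imb_ratio), False


# Example: 10 minority (label 1) vs 90 majority (label 0) samples.
lam = imbalance_ratio([1] * 10 + [0] * 90)
print(imbalanced_reward(0, 1, 1, lam))  # minority sample missed: (-1.0, True)
```

With λ = 1/9, one minority error costs as much as nine majority errors, or 1/λ in general. This is what pushes the agent toward recognising the minority class.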
Requirements: the packages listed in requirements.txt.
Logs are written to the ./logs/ folder and trained models to the ./models/ folder. Data files go in the ./data/ folder, located at the root of this repository.
To use the Credit Card Fraud dataset, download creditcard.csv from Kaggle into ./data/. creditcard.csv needs to be split into separate train and test files; please use the function imbDRL.utils.split_csv for this.
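The exact signature of imbDRL.utils.split_csv is not given here. To show what such a split involves, the helper below shuffles each class separately, so the rare fraud class keeps the same proportion in both output files. The helper is hypothetical: stratified_split_csv is not part of imbDRL and uses only the standard library. It assumes the label column is named Class, as in the Kaggle file, and uses a 20% test share as an arbitrary default.

```python
import csv
import random


def stratified_split_csv(src, train_path, test_path, label_col="Class",
                         test_size=0.2, seed=42):
    """Hypothetical helper (not imbDRL.utils.split_csv).

    Splits a CSV into train/test files while preserving the class ratio.
    Returns (n_train_rows, n_test_rows).
    """
    with open(src, newline="") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        rows = list(reader)

    # Group rows by label so that each class is split on its own.
    by_label = {}
    for row in rows:
        by_label.setdefault(row[label_col], []).append(row)

    rng = random.Random(seed)  # fixed seed for a reproducible split
    train, test = [], []
    for group in by_label.values():
        rng.shuffle(group)
        n_test = round(len(group) * test_size)
        test.extend(group[:n_test])
        train.extend(group[n_test:])

    # Write both subsets with the original header.
    for path, subset in ((train_path, train), (test_path, test)):
        with open(path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(subset)
    return len(train), len(test)
```

Stratifying matters here because a plain random split of a heavily imbalanced file can leave the test set with very few, or even no, fraud cases.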
Install via pip:
pip install imbDRL
Run any of the following scripts:
python .\imbDRL\examples\ddqn\train_credit.py
python .\imbDRL\examples\ddqn\train_famnist.py
python .\imbDRL\examples\ddqn\train_mnist.py
python .\imbDRL\examples\ddqn\train_titanic.py
To view the logs in TensorBoard, run:
tensorboard --logdir logs
Run the tests and the linter with the commands below; extra arguments are handled by the ./tox.ini file.
python -m pytest
flake8
The coverage report is generated in the ./htmlcov folder.
The appendix can be found in the imbDRLAppendix repository.