szilard / benchm-ml

A minimal benchmark for scalability, speed and accuracy of commonly used open source implementations (R packages, Python scikit-learn, H2O, xgboost, Spark MLlib etc.) of the top machine learning algorithms for binary classification (random forests, gradient boosted trees, deep neural networks etc.).
MIT License

DL with mxnet #29

Closed szilard closed 7 years ago

szilard commented 8 years ago

Trying to see if DL can match RF/GBM in accuracy on the airline dataset (where the training set is sampled from years 2005-2006, while the validation and test sets are sampled disjointly from 2007). Also, some variables are artificially kept categorical and intentionally not encoded as ordinal variables (to better match the structure of business datasets).

Recap: with 10M training records (the largest size in the benchmark), RF gets AUC 0.80 and GBM 0.81 (on the test set).

So far I get AUC 0.72 with DL in mxnet on the 1M training set: https://github.com/szilard/benchm-ml/blob/master/4-DL/2-mxnet.R
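
For reference, a rough sketch of how such an mxnet run can be set up in R (this is illustrative, not the exact script linked above; the column names assume the usual benchm-ml airline schema with `dep_delayed_15min` as the Y/N target, and the layer sizes / hyperparameters are placeholder values, not tuned settings):

```r
library(readr)
library(mxnet)

d_train <- read_csv("train-1m.csv")
d_valid <- read_csv("valid.csv")

## one-hot encode the categorical columns consistently across train/valid
## (a dense design matrix with several hundred columns, so it needs a fair amount of RAM)
d_all    <- rbind(d_train, d_valid)
X_all    <- model.matrix(dep_delayed_15min ~ . - 1, data = d_all)
X_train  <- X_all[1:nrow(d_train), ]
X_valid  <- X_all[(nrow(d_train) + 1):nrow(d_all), ]
y_train  <- as.integer(d_train$dep_delayed_15min == "Y")
y_valid  <- as.integer(d_valid$dep_delayed_15min == "Y")

## 2 hidden layers of 200 ReLU units with a softmax output
md <- mx.mlp(X_train, y_train,
             hidden_node = c(200, 200), out_node = 2,
             activation = "relu", out_activation = "softmax",
             num.round = 10, array.batch.size = 128,
             learning.rate = 0.01, momentum = 0.9,
             eval.metric = mx.metric.accuracy,
             ctx = mx.gpu())
```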

For comparison, on the 1M training set xgboost achieves 0.77, and with some tuning I think it can get to 0.79.
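
A rough sketch of an xgboost run on the same 1M training set looks like this (reusing `d_train` / `y_train` from the sketch above; the hyperparameters are illustrative, not the benchmark's tuned ones):

```r
library(xgboost)
library(Matrix)

## sparse one-hot encoding of the categorical columns
X_train_sp <- sparse.model.matrix(dep_delayed_15min ~ . - 1, data = d_train)

md_xgb <- xgboost(data = X_train_sp, label = y_train,
                  objective = "binary:logistic",
                  nrounds = 300, max_depth = 10, eta = 0.1,
                  verbose = 0)
```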

I tried a few architectures (number of hidden layers etc.), but nothing beats the settings I took from an mxnet example. It takes about 1 minute to train on an EC2 g2.8xlarge box using 1 GPU (using all 4 GPUs was slower). nvidia-smi shows GPU utilization ~20% and memory usage ~2GB (out of 4GB). On CPU (32 cores), training takes about 5 minutes.

The "problem" is DL learns very fast, the best AUC (on a validation set) is reached after 2 epochs. On the other hand xgboost runs ~1hr to get good accuracy. That is the DL model seems underfitted to me.

Sure, DL might not beat GBM on this kind of data (a proxy for general business data such as credit risk or fraud detection), but it should do better than 0.72.

Datasets:
- https://s3.amazonaws.com/benchm-ml--main/train-1m.csv
- https://s3.amazonaws.com/benchm-ml--main/train-10m.csv
- https://s3.amazonaws.com/benchm-ml--main/valid.csv
- https://s3.amazonaws.com/benchm-ml--main/test.csv

szilard commented 8 years ago

@tqchen @hetong007: Any comments/suggestions on the above? (continuing my machine learning benchmark with deep learning and mxnet) https://github.com/szilard/benchm-ml/issues/29

tqchen commented 8 years ago

Deep nets are definitely harder to tune. If things converge too fast, try a smaller learning rate and shuffle the data. It seems much of the gain on the airline dataset comes from combinations of categories, which deep nets may not be very good at.

szilard commented 8 years ago

I've been playing around with the params; see also the discussion here: https://github.com/szilard/benchm-ml/issues/29

The data was already shuffled. I was actually asking Arno (H2O) a few minutes ago precisely about using a smaller manual learning rate (instead of an adaptive one; see the other thread). Maybe I should just try...
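
"Just trying" a smaller fixed (non-adaptive) learning rate with more epochs could look roughly like this (reusing `X_train` / `y_train` from the earlier sketch; plain SGD with guessed values, not tuned settings):

```r
md_slow <- mx.mlp(X_train, y_train,
                  hidden_node = c(200, 200), out_node = 2,
                  activation = "relu", out_activation = "softmax",
                  optimizer = "sgd",
                  learning.rate = 0.001, momentum = 0.9,
                  num.round = 50, array.batch.size = 128,
                  eval.metric = mx.metric.accuracy,
                  ctx = mx.gpu())
```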