Closed: szilard closed this issue 7 years ago
@tqchen @hetong007: Any comments/suggestions on the above? (continuing my machine learning benchmark with deep learning and mxnet) https://github.com/szilard/benchm-ml/issues/29
Deep nets are definitely harder to tune; if things converge too fast, try a smaller learning rate and shuffle the data. It seems much of the gain on the airline dataset comes from combinations of categories, which a deep net may not be very good at.
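The "combinations of categories" point can be made concrete: a tree ensemble picks up an interaction like Origin x Dest through successive splits, while a net fed independent one-hot columns has to learn it. One remedy is to cross the two columns explicitly into a single categorical feature. A minimal language-agnostic sketch (the `Origin`/`Dest` column names are illustrative, not the exact benchmark schema):

```python
def cross_features(rows, col_a, col_b, new_col="cross"):
    """Add a combined categorical column col_a + ':' + col_b to each row,
    so the interaction is available as a single feature."""
    out = []
    for r in rows:
        r = dict(r)
        r[new_col] = f"{r[col_a]}:{r[col_b]}"
        out.append(r)
    return out

rows = [{"Origin": "SFO", "Dest": "JFK"}, {"Origin": "SFO", "Dest": "LAX"}]
crossed = cross_features(rows, "Origin", "Dest")
print(crossed[0]["cross"])  # -> SFO:JFK
```

The crossed column can then be one-hot encoded like any other categorical variable; the cost is a much larger cardinality.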
I've been playing around with params, see also discussion here https://github.com/szilard/benchm-ml/issues/29
The data was already shuffled. I was actually asking Arno (H2O) a few minutes ago precisely about using a smaller manual learning rate (instead of the adaptive one; see the other thread). Maybe I should just try...
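For intuition on why the fixed learning rate matters, here is a toy sketch of plain gradient descent on f(w) = (w - 3)^2: with a small step size the iterates converge to the minimum, while too large a step makes them diverge. This is a generic illustration under simplified assumptions, not the mxnet optimizer configuration itself:

```python
def sgd(lr, steps=50, w=0.0):
    """Gradient descent on f(w) = (w - 3)^2 with a fixed learning rate."""
    for _ in range(steps):
        grad = 2.0 * (w - 3.0)  # derivative of (w - 3)^2
        w -= lr * grad
    return w

small = sgd(lr=0.1)  # converges close to the minimum at w = 3
large = sgd(lr=1.1)  # update factor |1 - 2*1.1| > 1, so it diverges
print(round(small, 3), abs(large - 3.0) > 1e3)
```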
Trying to see if DL can match RF/GBM in accuracy on the airline dataset (where the training set is sampled from years 2005-2006, while the validation and test sets are sampled disjointly from 2007). Also, some variables are artificially kept categorical and are intentionally not encoded as ordinal variables (to better match the structure of business datasets).
Recap: with 10M records of training data (the largest in the benchmark), RF gets AUC 0.80 and GBM 0.81 (on the test set). So far I get 0.72 with DL with mxnet on the 1M training set: https://github.com/szilard/benchm-ml/blob/master/4-DL/2-mxnet.R. Comparably, on the 1M training set xgboost has achieved 0.77, and with some tuning I think it can get to 0.79.

I tried a few architectures (number of hidden layers etc.), but nothing beat the settings I took from an mxnet example. It runs about 1 minute to train on an EC2 g2.8xlarge box using 1 GPU (using all 4 GPUs was actually slower).
nvidia-smi shows GPU utilization around 20% and memory usage around 2 GB (out of 4 GB). On CPU (32 cores) training takes about 5 minutes.

The "problem" is that DL learns very fast: the best AUC (on a validation set) is reached after 2 epochs, whereas xgboost runs about 1 hour to reach good accuracy. That is, the DL model seems underfitted to me.
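Since all the comparisons above are in terms of AUC, here is a minimal rank-based AUC (equivalent to the normalized Mann-Whitney U statistic) for reference, assuming 0/1 labels and predicted scores; in practice the benchmark uses a library implementation:

```python
def auc(labels, scores):
    """Probability that a random positive is scored above a random
    negative, counting ties as 1/2 (Mann-Whitney U / (n_pos * n_neg))."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

print(auc([1, 1, 0, 0], [0.9, 0.8, 0.3, 0.1]))  # perfect ranking -> 1.0
```

This O(n_pos * n_neg) version is only for small sanity checks; a sort-based version is the practical choice at 1M+ rows.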
Surely, DL might not beat GBM on this kind of data (a proxy for general business data such as credit risk or fraud detection), but it should do better than 0.72.

Datasets:
https://s3.amazonaws.com/benchm-ml--main/train-1m.csv
https://s3.amazonaws.com/benchm-ml--main/train-10m.csv
https://s3.amazonaws.com/benchm-ml--main/valid.csv
https://s3.amazonaws.com/benchm-ml--main/test.csv