A highly optimized fork of the popular mlrose-hiive package for Machine Learning, Randomized Optimization and SEarch algorithms.
BSD 3-Clause "New" or "Revised" License
`NeuralNetwork()` fails with multiclass problems #20
Open
nkapila6 opened 1 month ago
Issue Summary: The `NeuralNetwork` class fails on multiclass problems. This appears to be an unresolved issue inherited from mlrose-hiive.
Steps to Reproduce / Describe the Request: Using `NeuralNetwork` on a multiclass dataset raises an error when `.predict()` is called.
Additional Information:
scikit-learn's multilayer perceptron one-hot encodes the class labels before training: https://github.com/scikit-learn/scikit-learn/blob/5c4aa5d0d90ba66247d675d4c3fc2fdfba3c39ff/sklearn/neural_network/_multilayer_perceptron.py#L1124
The same encoding was added to mlrose-hiive in this commit: https://github.com/hiive/mlrose/commit/2a6d48b62aeb3ed932fba785c8265cdec2a1387f
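For reference, the label encoding discussed above can be sketched in a few lines of NumPy. This is a minimal illustration of one-hot encoding and argmax decoding under the assumption of one label per sample; the helper names `one_hot_encode` and `one_hot_decode` are hypothetical and are not scikit-learn's or mlrose's implementation. Note that scikit-learn's encoder may differ in edge cases (for example, binary problems).

```python
import numpy as np

def one_hot_encode(y):
    """Hypothetical helper: return (classes, Y) where Y[i, k] == 1 iff y[i] == classes[k]."""
    # np.unique returns the sorted class labels and, for each sample, its class index.
    classes, idx = np.unique(np.asarray(y), return_inverse=True)
    # Selecting rows of the identity matrix gives one one-hot row per sample.
    Y = np.eye(len(classes), dtype=float)[idx]
    return classes, Y

def one_hot_decode(classes, scores):
    """Hypothetical helper: map each row of network outputs back to a label via argmax."""
    return classes[np.argmax(scores, axis=1)]

y = np.array(["cat", "dog", "bird", "dog"])
classes, Y = one_hot_encode(y)
print(classes)                     # ['bird' 'cat' 'dog']
print(Y.shape)                     # (4, 3)
print(one_hot_decode(classes, Y))  # ['cat' 'dog' 'bird' 'dog']
```

Keeping `classes` alongside the encoded targets is what lets a classifier's `predict` turn one score per class back into the original label.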
Proposed Solution (Optional): This will need to be debugged with a toy dataset.
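One possible starting point for the toy-dataset debugging is a small, separable three-class problem whose expected output shapes are known in advance. The sketch below uses only NumPy; the blob centers, sample counts, and the random linear-plus-softmax "network" are hypothetical stand-ins, not mlrose code. In an actual debugging session, `X` and `Y` would be passed to `NeuralNetwork.fit()` and the output of `.predict()` compared against these expected shapes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy dataset: three well-separated 2-D Gaussian blobs, 20 points each.
centers = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
X = np.vstack([c + rng.normal(scale=0.5, size=(20, 2)) for c in centers])
y = np.repeat(np.arange(3), 20)  # integer labels 0, 1, 2 in blob order

# One-hot targets, shape (60, 3).
Y = np.eye(3)[y]

# Stand-in for a trained network: one linear layer followed by softmax.
# The weights are random, so this only checks shapes, not accuracy.
W = rng.normal(size=(2, 3))
b = np.zeros(3)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

probs = softmax(X @ W + b)        # one probability per class per sample
pred = np.argmax(probs, axis=1)   # one integer label per sample

# Expectations a multiclass predict() path should satisfy.
assert probs.shape == Y.shape
assert pred.shape == y.shape
```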