knakamura13 / mlrose-ky

A highly optimized fork of the popular mlrose-hiive package. For Machine Learning, Randomized Optimization and SEarch algorithms.
https://nkapila6.github.io/mlrose-ky/
BSD 3-Clause "New" or "Revised" License

`NeuralNetwork()` fails with multiclass problems #20

Open nkapila6 opened 1 month ago

nkapila6 commented 1 month ago

Issue Summary

The `NeuralNetwork` class fails on multiclass problems. This appears to be an outstanding issue inherited from mlrose-hiive.

Steps to Reproduce / Describe the Request

Fitting `NeuralNetwork` on a multiclass dataset and then calling `.predict()` raises an error.

Additional Information

  1. scikit-learn one-hot encodes the labels: https://github.com/scikit-learn/scikit-learn/blob/5c4aa5d0d90ba66247d675d4c3fc2fdfba3c39ff/sklearn/neural_network/_multilayer_perceptron.py#L1124

  2. The same encoding was added to mlrose-hiive: https://github.com/hiive/mlrose/commit/2a6d48b62aeb3ed932fba785c8265cdec2a1387f
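A minimal NumPy sketch of the label encoding referenced above. This is not mlrose-ky's code: it one-hot encodes class labels before training and maps per-class network outputs back to labels with argmax. scikit-learn's MLPClassifier does this with `LabelBinarizer`; the helper names `one_hot_encode` and `decode_predictions` are hypothetical. Unlike `LabelBinarizer`, which collapses the two-class case to a single column, this sketch always produces one column per class.

```python
import numpy as np


def one_hot_encode(y):
    """Map a 1-D array of class labels to a one-hot matrix.

    Returns the sorted unique classes (needed later for decoding) and an
    (n_samples, n_classes) float matrix with a single 1.0 per row.
    """
    classes, idx = np.unique(np.asarray(y), return_inverse=True)
    return classes, np.eye(len(classes))[idx]


def decode_predictions(classes, scores):
    """Turn per-class outputs of shape (n_samples, n_classes) back into labels.

    Each row's label is the class whose output node has the largest value.
    """
    return classes[np.argmax(scores, axis=1)]


y = np.array(["cat", "dog", "bird", "dog"])
classes, Y = one_hot_encode(y)
print(classes)                         # ['bird' 'cat' 'dog']
print(Y.shape)                         # (4, 3)
print(decode_predictions(classes, Y))  # ['cat' 'dog' 'bird' 'dog']
```

The decoding step matters as much as the encoding: a network trained on one-hot targets emits one value per class, so `predict` has to reduce each row to a single label before it can be compared with the original `y`.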

Proposed Solution (Optional)

This will need to be debugged with a toy dataset.
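One way the toy-dataset debugging could be set up, sketched with NumPy only: generate a small, clearly separable three-class dataset, then check that `predict` returns exactly one known label per sample. `make_toy_multiclass`, `check_multiclass_predict`, and the `ToyCentroidModel` stand-in are all hypothetical. In practice you would pass an mlrose-ky `NeuralNetwork` instance, configured however you normally would, to `check_multiclass_predict` in place of the stand-in.

```python
import numpy as np


def make_toy_multiclass(n_per_class=30, n_classes=3, seed=0):
    """Well-separated 2-D Gaussian blobs, one blob per integer class label."""
    rng = np.random.default_rng(seed)
    centers = np.array([[4.0 * k, 4.0 * (k % 2)] for k in range(n_classes)])
    X = np.vstack([c + rng.normal(scale=0.5, size=(n_per_class, 2)) for c in centers])
    y = np.repeat(np.arange(n_classes), n_per_class)
    return X, y


def check_multiclass_predict(model, X, y):
    """Fit `model`, then verify `.predict` yields one known label per sample."""
    model.fit(X, y)
    pred = np.asarray(model.predict(X))
    assert pred.shape[0] == len(X), f"expected {len(X)} predictions, got shape {pred.shape}"
    if pred.ndim == 2 and pred.shape[1] == 1:
        pred = pred[:, 0]  # tolerate an (n, 1) column vector
    assert pred.ndim == 1, f"expected one label per sample, got shape {pred.shape}"
    unknown = set(np.unique(pred).tolist()) - set(np.unique(y).tolist())
    assert not unknown, f"predicted labels not seen in training: {unknown}"
    return pred


class ToyCentroidModel:
    """Stand-in classifier used only to exercise the harness."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.centroids_ = np.array([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict(self, X):
        # Squared distance from every sample to every class centroid: (n, k).
        d = ((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=2)
        return self.classes_[np.argmin(d, axis=1)]


X, y = make_toy_multiclass()
pred = check_multiclass_predict(ToyCentroidModel(), X, y)
print("accuracy:", (pred == y).mean())
```

If a model's `predict` returns one row of per-class values instead of a label, the shape check fails and reports the offending shape, which helps narrow down whether the problem lies in encoding during `fit` or in decoding during `predict`.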

zlyin commented 1 month ago

Same issue here.