RubixML / ML

A high-level machine learning and deep learning library for the PHP language.
https://rubixml.com
MIT License

OneHotEncoder feature mismatch on training/test samples #33

Closed. Kalyse closed this issue 5 years ago.

Kalyse commented 5 years ago

I'm unsure whether this is just a problem with how I handle things, but let's say that I train my estimator and have a OneHotEncoder transformer.

If one of my categorical features has 4 distinct values (let's use colors for the sake of simplicity), I will get:

Red n_1 Green n_2 Yellow n_3 Blue n_4

So my training samples will have values that match these additional columns, because all of those colors appear somewhere in my training data.

However, if I have test samples that do not have those columns, there is no way to "fit" them to the columns of the training data. Additionally, the methods I would need to get the categories from the transformers within my pipeline are protected, so I can't actually see how the columns were built.

Is this normal, and if so, am I supposed to do some pre-processing/detection before I use the OneHotEncoder on my learner so that I can also fit it manually myself?
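The following minimal plain-PHP sketch (no Rubix ML involved) makes the mismatch concrete by showing what happens when a one-hot encoder is fit separately on each split. The helpers `fitCategories()` and `encode()` are hypothetical. They only mimic the idea of one-hot encoding and are not the Rubix ML API. The all-zero row for unseen values is a choice made for this sketch only.

```php
<?php

// Hypothetical helpers for illustration only; not the Rubix ML API.

// Collect the distinct values of one categorical column, in first-seen order.
function fitCategories(array $values): array
{
    return array_values(array_unique($values));
}

// Turn each value into a row of 0/1 indicators, one column per known category.
// In this sketch, a value never seen during fitting becomes an all-zero row.
function encode(array $values, array $categories): array
{
    $rows = [];
    foreach ($values as $value) {
        $row = [];
        foreach ($categories as $category) {
            $row[] = $value === $category ? 1 : 0;
        }
        $rows[] = $row;
    }
    return $rows;
}

$trainColors = ['red', 'green', 'yellow', 'blue', 'red'];
$testColors = ['green', 'red'];

// Fitting a separate encoder on each split gives each split its own column layout.
$trainEncoded = encode($trainColors, fitCategories($trainColors));
$testEncoded = encode($testColors, fitCategories($testColors));

echo count($trainEncoded[0]), ' vs ', count($testEncoded[0]), PHP_EOL; // 4 vs 2
```

The widths are not the only thing that differs. `green` is the second column in the training layout but the first column in the test layout, so even matching widths would not guarantee matching meanings.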

andrewdalpino commented 5 years ago

Hi @Kalyse, thanks for the question ... I'm assuming you got it working since the issue is closed now, but just in case ...

The One Hot Encoder can be fit to whatever dataset you want using the fit() method ... If you are using a training/testing split for cross validation and, for some reason, your testing set contains categories of a categorical feature that are not present in the training set, you can simply fit the One Hot Encoder to the entire dataset before splitting.
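The idea of fitting once on the full dataset and only then splitting can be sketched in plain PHP. The sketch uses hypothetical `fitCategories()` and `encode()` helpers, which are illustrative only and are not Rubix ML's OneHotEncoder. An unshuffled 50/50 cut via `array_chunk()` stands in for `split()`.

```php
<?php

// Hypothetical helpers for illustration only; not the Rubix ML API.
function fitCategories(array $values): array
{
    // Distinct values in first-seen order; each one gets a fixed column.
    return array_values(array_unique($values));
}

function encode(array $values, array $categories): array
{
    // One 0/1 indicator per known category, in the fitted column order.
    $rows = [];
    foreach ($values as $value) {
        $row = [];
        foreach ($categories as $category) {
            $row[] = $value === $category ? 1 : 0;
        }
        $rows[] = $row;
    }
    return $rows;
}

// The full categorical column, before any training/testing split.
$colors = ['red', 'green', 'yellow', 'blue', 'green', 'red'];

// Fit once on everything, so every category gets a fixed column.
$categories = fitCategories($colors);

// Split afterwards (a plain unshuffled 50/50 cut; a real split would usually shuffle).
[$trainColors, $testColors] = array_chunk($colors, intdiv(count($colors), 2));

$trainEncoded = encode($trainColors, $categories);
$testEncoded = encode($testColors, $categories);

// Both halves share the same 4-column layout, even though neither half
// contains all four colors on its own.
echo count($trainEncoded[0]), ' and ', count($testEncoded[0]), PHP_EOL; // 4 and 4
```

The category list is fixed before the split. A half that happens to lack a color (here, `yellow` in the test half and `blue` in the training half) still gets a column for it, filled with zeros.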

Example

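```php
<?php

require 'vendor/autoload.php';

use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Transformers\OneHotEncoder;

// $samples and $labels hold your full dataset, before any training/testing split
$dataset = new Labeled($samples, $labels);

$transformer = new OneHotEncoder();

// Fit on the entire dataset so the encoder learns every category
$transformer->fit($dataset);

[$training, $testing] = $dataset->split(0.5);
```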

Or you can fit and transform in one go using the apply() method on the dataset object:

```php
// ...

// apply() fits the encoder to the whole dataset and transforms its samples
$dataset->apply(new OneHotEncoder());

[$training, $testing] = $dataset->split(0.5);
```

I added a categories() method to OneHotEncoder in the next release so you can see the categories that the transformer computed for each categorical feature column after fitting. This method should primarily be used for exploratory purposes.

Let me know if that helps, and thank you again for the question.