RubixML / ML

A high-level machine learning and deep learning library for the PHP language.
https://rubixml.com
MIT License

RandomForest example? #165

Open zenichanin opened 3 years ago

zenichanin commented 3 years ago

Is there a full Random Forest example/demo anywhere? I'm kinda trying to stitch together things using the Divorce example, but I'm not sure if that will even work. The things it is returning do not make sense to me right now.

Here's what I have so far:

$logger = new Screen();

$logger->info('Loading data into memory');

$estimator = new RandomForest(new ClassificationTree(10), 300, 0.1, true);

$dataset = Labeled::fromIterator(new NDJSON('dataset.ndjson'));

$estimator->train($dataset);

$predict = $estimator->predict($dataset);
$predict = print_r($predict, true);

$logger->info("Prediction is $predict");

And the result I get is an array with 170 items that say married for each.

It would be nice to see what type of data it accepts, the formatting, etc...

andrewdalpino commented 3 years ago

Hello @zenichanin thank you for the great question!

I did some experimenting of my own with Random Forest and the Divorce dataset, and indeed RF performs poorly when the features are represented as continuous values. However, I got great results (about 97% accuracy on the testing set) from a categorical representation of the features. In other words, I treated the Likert scale as a set of 5 discrete categories rather than as an interval between 1 and 5. To convert the continuous features to categorical you could use the Interval Discretizer. However, since the repo also provides the dataset in CSV format, and the CSV extractor imports every value as a string by default, you could instead import the data from CSV and skip converting the numeric strings to integers. Numeric strings are treated as categorical data by convention.

See https://docs.rubixml.com/latest/representing-your-data.html and https://docs.rubixml.com/latest/extracting-data.html#csv for further info
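
To make the string-versus-number convention above concrete, here is a minimal plain-PHP sketch. The featureType() helper is hypothetical and written only for illustration; it is not the library's code, it just mirrors the rule described above (strings are categorical, integers and floats are continuous).

```php
<?php

// Hypothetical helper that mimics the convention described above:
// any string (even a numeric one) counts as categorical,
// while integers and floats count as continuous.
function featureType($value): string
{
    if (is_string($value)) {
        return 'categorical';
    }

    if (is_int($value) || is_float($value)) {
        return 'continuous';
    }

    return 'other';
}

// A Likert answer as it arrives from a CSV file: a numeric string.
$fromCsv = '3';

// The same answer after an explicit cast to integer.
$cast = (int) $fromCsv;

echo featureType($fromCsv), PHP_EOL; // categorical
echo featureType($cast), PHP_EOL;    // continuous
```

So leaving the CSV values uncast keeps the Likert answers categorical, while casting them to integers would turn them back into continuous features.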

use Rubix\ML\Other\Loggers\Screen;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Extractors\CSV;
use Rubix\ML\Classifiers\RandomForest;
use Rubix\ML\Classifiers\ClassificationTree;
use Rubix\ML\CrossValidation\Metrics\Accuracy;

$logger = new Screen();

$logger->info('Loading data into memory');

$dataset = Labeled::fromIterator(new CSV('dataset.csv'));

[$training, $testing] = $dataset->stratifiedSplit(0.8);

$estimator = new RandomForest(new ClassificationTree(10), 300, 0.1, true);

$logger->info('Training');

$estimator->train($training);

$logger->info('Making predictions');

$predictions = $estimator->predict($testing);

$metric = new Accuracy();

$score = $metric->score($predictions, $testing->labels());

$logger->info("Accuracy is $score");

My intuition is that since the variance of the features on the Likert scale is low, the splitting algorithm has a hard time inducing an appropriate ruleset. That said, I found some interesting behavior in the quantile-based splitting algorithm we use to induce decision tree rulesets for continuous features. In particular, we tend to compute the same split twice, and sometimes four times, when the variance is low and the number of samples to split is above some threshold. Addressing this will likely not affect the accuracy (a categorical representation of the features will fix that), but it may allow us to optimize quantile splitting even further by computing only unique split points.

See https://github.com/RubixML/ML/blob/master/src/Graph/Trees/CART.php#L419
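
The duplicate-split effect can be reproduced with a small plain-PHP sketch. This is not the code in CART.php: the quantileSplits() helper, the nearest-rank quantile rule, and the sample column are all assumptions made for illustration. It only shows why a low-variance column such as a 1 to 5 Likert answer yields repeated candidate split points, and how deduplicating them reduces the work.

```php
<?php

// Hypothetical helper: pick $k candidate split points at evenly spaced
// quantiles of a feature column (simple rank lookup, no interpolation).
function quantileSplits(array $values, int $k): array
{
    sort($values);

    $n = count($values);
    $splits = [];

    for ($i = 1; $i <= $k; $i++) {
        $p = $i / ($k + 1);                      // quantile in (0, 1)
        $index = (int) floor($p * ($n - 1));     // rank of that quantile
        $splits[] = $values[$index];
    }

    return $splits;
}

// A low-variance column: 16 answers on a 1 to 5 Likert scale.
$column = [1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5, 5];

// Seven quantiles over only five distinct values must repeat some of them.
$candidates = quantileSplits($column, 7);

// Keeping only the unique split points avoids evaluating a split twice.
$unique = array_values(array_unique($candidates));

printf("%d candidates, %d unique\n", count($candidates), count($unique));
// prints "7 candidates, 5 unique"
```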

Thank you very much again! Let me know if you get stuck again.

P.S. Rubix ML version 0.2.x does not use Flysystem, in case you need to stay compatible with the current version of Laravel (8), which has not upgraded to Flysystem 2.0 yet.

zenichanin commented 3 years ago

Awesome @andrewdalpino, thanks so much. I am experimenting with it and will let you know if I run into any hiccups. Thanks for the tip about v0.2.x and Laravel v8. :)

andrewdalpino commented 3 years ago

You're welcome @zenichanin, also Rubix ML 1.0.0 will not use Flysystem so you can try 1.0.0-beta when it's released as well.

zenichanin commented 3 years ago

Hey @andrewdalpino,

So I have been playing around with it some more, although I'm not having much luck applying my data to it.

I wanted to see if you can point me in the right direction regarding how to format the data correctly for RandomForest. I am trying to apply this to stock market data similar to #38, except I'm not trying to predict trend, but rather the price, volume, or other fields I already have.

I just don't know how to best format the data.

I am using a CSV data source and have these columns right now:

vw_average
median
trend
bid
ask
5ma
8ma
13ma
volume
created_at

All columns are numeric data types except created_at which is a date/time string.

Whenever I try to load the data into the dataset, I get the exception: Rubix\ML\Exceptions\EmptyDataset : Dataset must contain at least 1 sample.

andrewdalpino commented 3 years ago

Sure @zenichanin, I will do my best to help

Can you post your code? That would be very helpful to me. Thank you!

zenichanin commented 3 years ago

Hey @andrewdalpino, this is roughly my code right now.

public function __construct($dataset)
{
    $this->dataset = Labeled::fromIterator(new CSV($dataset));
}

public function predict()
{
    $logger = new Screen();

    $logger->info('Loading data into memory');

    [$training, $testing] = $this->dataset->stratifiedSplit(0.8);

    $estimator = new RandomForest(new ClassificationTree(10), 300, 0.1, true);

    $logger->info('Training');

    $estimator->train($training);

    $logger->info('Making predictions');

    $predictions = $estimator->predict($testing);

    $metric = new Accuracy();

    $score = $metric->score($predictions, $testing->labels());

    $logger->info("Accuracy is $score");

    //$probability = $estimator->proba($this->dataset);
    //$predict = print_r($predict, true);

    //return $predict;
}

And my actual dataset file is a CSV file with the fields mentioned in the earlier post, with the 1st row being the column names. Here's a sample file for one day that you can use for testing. https://file.io/1k8VPDdM7vDe

andrewdalpino commented 3 years ago

> And my actual dataset file is a CSV file with the fields mentioned in the earlier post, with 1st row being the column names.

Make sure to set the 'header' argument to true on the CSV extractor if your CSV file contains a header row.

https://docs.rubixml.com/latest/extractors/csv.html

Try that and let me know if you still get the Empty Dataset error
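
With the extractor itself, my understanding of the linked docs is that the header flag is the second constructor argument, e.g. new CSV('dataset.csv', true). The plain-PHP sketch below only illustrates what treating the first row as a header amounts to. The readCsv() helper and the three-column sample are made up for illustration and are not the library's implementation.

```php
<?php

// Hypothetical helper: parse CSV text, optionally treating the first
// row as column names instead of as a sample.
function readCsv(string $text, bool $header): array
{
    $lines = preg_split('/\R/', trim($text));

    $columns = null;
    $records = [];

    foreach ($lines as $line) {
        $row = str_getcsv($line, ',', '"', '\\');

        if ($header && $columns === null) {
            $columns = $row; // the first row names the columns
            continue;
        }

        // With a header, key each value by its column name.
        $records[] = $header ? array_combine($columns, $row) : $row;
    }

    return $records;
}

$text = "bid,ask,volume\n10.1,10.2,500\n10.3,10.4,750";

// Without the flag, the column names are read as if they were a sample.
$withoutHeader = readCsv($text, false);

// With the flag, only the two real samples remain, keyed by column name.
$withHeader = readCsv($text, true);

echo count($withoutHeader), ' vs ', count($withHeader), PHP_EOL; // 3 vs 2
```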

zenichanin commented 3 years ago

Thanks @andrewdalpino.

So that helps get rid of the Empty Dataset error.

Next exception I get is: Classifiers require categorical labels, continuous given.

My guess is RandomForest needs to have some categories/labels and not just numeric values. And if so, do you have any suggestion what those categories could be to be useful?

Thanks a bunch for your help!