haifengl / smile

Statistical Machine Intelligence & Learning Engine
https://haifengl.github.io
Other
6.05k stars 1.13k forks

Issue: Error with Prediction Method in Random Forest and Gradient Boost Regression #771

Closed: ProtossidoDiAzoto closed this issue 6 months ago

ProtossidoDiAzoto commented 6 months ago

Describe the bug

I encountered an issue with the `predict` method when running regression with Random Forest and Gradient Boosting. The problem occurs only in versions later than 3.0.0. Version 2.6.0 does not have this problem.

Reproduction Steps

1. Use the code snippet below to set up the regression.
2. Run the regression with Random Forest or Gradient Boosting on a version later than 3.0.0.
3. Observe the error message shown below.

Code Snippet

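// Imports assumed for Smile 3.x and JUnit 4; the members below go inside a test class.
import java.util.Arrays;

import org.junit.Test;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.vector.DoubleVector;
import smile.math.MathEx;
import smile.regression.RandomForest;

    long[] seeds = {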
            342317953, 521642753, 72070657, 577451521, 266953217, 179976193,
            374603777, 527788033, 303395329, 185759582, 261518209, 461300737,
            483646580, 532528741, 159827201, 284796929, 655932697, 26390017,
            454330473, 867526205, 824623361, 719082324, 334008833, 699933293,
            823964929, 155216641, 150210071, 249486337, 713508520, 558398977,
            886227770, 74062428, 670528514, 701250241, 363339915, 319216345,
            757017601, 459643789, 170213767, 434634241, 414707201, 153100613,
            753882113, 546490145, 412517763, 888761089, 628632833, 565587585,
            175885057, 594903553, 78450978, 212995578, 710952449, 835852289,
            415422977, 832538705, 624345857, 839826433, 260963602, 386066438,
            530942946, 261866663, 269735895, 798436064, 379576194, 251582977,
            349161809, 179653121, 218870401, 415292417, 86861523, 570214657,
            701581299, 805955890, 358025785, 231452966, 584239408, 297276298,
            371814913, 159451160, 284126095, 896291329, 496278529, 556314113,
            31607297, 726761729, 217004033, 390410146, 70173193, 661580775,
            633589889, 389049037, 112099159, 54041089, 80388281, 492196097,
            912179201, 699398161, 482080769, 363844609, 286008078, 398098433,
            339855361, 189583553, 697670495, 709568513, 98494337, 99107427,
            433350529, 266601473, 888120086, 243906049, 414781441, 154685953,
            601194298, 292273153, 212413697, 568007473, 666386113, 712261633,
            802026964, 783034790, 188095005, 742646355, 550352897, 209421313,
            175672961, 242531185, 157584001, 201363231, 760741889, 852924929,
            60158977, 774572033, 311159809, 407214966, 804474160, 304456514,
            54251009, 504009638, 902115329, 870383757, 487243777, 635554282,
            564918017, 636074753, 870308031, 817515521, 494471884, 562424321,
            81710593, 476321537, 595107841, 418699893, 315560449, 773617153,
            163266399, 274201241, 290857537, 879955457, 801949697, 669025793,
            753107969, 424060977, 661877468, 433391617, 222716929, 334154852,
            878528257, 253742849, 480885528, 99773953, 913761493, 700407809,
            483418083, 487870398, 58433153, 608046337, 475342337, 506376199,
            378726401, 306604033, 724646374, 895195218, 523634541, 766543466,
            190068097, 718704641, 254519245, 393943681, 796689751, 379497473,
            50014340, 489234689, 129556481, 178766593, 142540536, 213594113,
            870440184, 277912577};

    public static final double[][] x = {
            {234.289,      235.6,        159.0,    107.608, 1947,   60.323},
            {259.426,      232.5,        145.6,    108.632, 1948,   61.122},
            {258.054,      368.2,        161.6,    109.773, 1949,   60.171},
            {284.599,      335.1,        165.0,    110.929, 1950,   61.187},
            {328.975,      209.9,        309.9,    112.075, 1951,   63.221},
            {346.999,      193.2,        359.4,    113.270, 1952,   63.639},
            {365.385,      187.0,        354.7,    115.094, 1953,   64.989},
            {363.112,      357.8,        335.0,    116.219, 1954,   63.761},
            {397.469,      290.4,        304.8,    117.388, 1955,   66.019},
            {419.180,      282.2,        285.7,    118.734, 1956,   67.857},
            {442.769,      293.6,        279.8,    120.445, 1957,   68.169},
            {444.546,      468.1,        263.7,    121.950, 1958,   66.513},
            {482.704,      381.3,        255.2,    123.366, 1959,   68.655},
            {502.601,      393.1,        251.4,    125.368, 1960,   69.564},
            {518.173,      480.6,        257.2,    127.852, 1961,   69.331},
            {554.894,      400.7,        282.7,    130.081, 1962,   70.551}
    };

    public static final double[] y = {
            83.0,  88.5,  88.2,  89.5,  96.2,  98.1,  99.0, 100.0, 101.2,
            104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
    };

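    // Longley data: six predictor columns plus the response column "deflator".
    public static DataFrame data = DataFrame.of(x, "GNP", "unemployed", "armed_forces", "population", "year", "employed")
            .merge(DoubleVector.of("deflator", y));
    // deflator ~ . (all other columns are predictors)
    public static Formula formula = Formula.lhs("deflator");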

    @Test
    public void tryOutRandomForest(){
        MathEx.setSeed(19650218);
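        // 100 trees, mtry = 3, maxDepth = 20, maxNodes = 10, nodeSize = 3, subsample = 1.0
        RandomForest model = RandomForest.fit(formula, data, 100, 3, 20, 10, 3, 1.0, Arrays.stream(seeds));
        for (int i = 0; i < x.length; i++) {
            // Fails in versions later than 3.0.0 with "Field deflator doesn't exist"
            System.out.println(model.predict(Tuple.of(x[i], model.schema())));
        }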
    }

Expected behavior

Prediction should run without errors, as it does in version 2.6.0.

Actual behavior

An `IllegalArgumentException` is thrown:

Field deflator doesn't exist
java.lang.IllegalArgumentException: Field deflator doesn't exist
    at smile.data.type.StructType.indexOf(StructType.java:103)
    at smile.data.formula.Variable$1.<init>(Variable.java:80)
    at smile.data.formula.Variable.bind(Variable.java:78)
    at smile.data.formula.Formula.bind(Formula.java:360)
    at smile.data.formula.Formula.x(Formula.java:433)
    at smile.regression.RandomForest.predict(RandomForest.java:455)

Additional context

Request for Assistance

Could someone provide insight into what might be causing this error? I'd appreciate any guidance or troubleshooting suggestions.

haifengl commented 6 months ago

Thanks for reporting. But I cannot reproduce the issue. Here is the output of your code in jshell:

jshell> smile.regression.RandomForest model = smile.regression.RandomForest.fit(formula, data, 100, 3, 20, 10, 3, 1.0, Arrays.stream(seeds));
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 92.68%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 92.50%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 95.29%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 45.31%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 72.80%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 96.00%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: -52.84%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 79.68%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 84.25%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 80.84%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 94.74%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 81.39%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 52.97%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 19.96%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 94.88%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 77.90%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 80.58%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 65.82%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: -0.68%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 85.60%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 94.41%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 86.14%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 94.58%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 69.15%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 47.19%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 64.80%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 50.14%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 91.10%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 66.80%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 82.73%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 71.68%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 71.86%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 82.10%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 37.68%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 90.01%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 44.82%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 89.69%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 89.96%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 86.88%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 68.18%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 80.44%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: -67.39%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 83.29%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 68.86%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: -52.48%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 69.99%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 52.94%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 88.30%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 76.42%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 46.09%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 88.84%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 94.47%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 90.22%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 44.55%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 73.57%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 71.46%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 56.63%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 94.19%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 85.09%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 48.94%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 59.37%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 86.04%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 91.90%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 77.22%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 68.18%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 72.97%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 42.43%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 88.18%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 89.88%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 75.93%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 63.56%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 82.64%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 79.55%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: -592.92%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: -40.05%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 92.73%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 72.49%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 80.21%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 75.78%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 81.12%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 81.77%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 84.95%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 84.55%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 77.15%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 44.95%
[ForkJoinPool.commonPool-worker-5] INFO smile.regression.RandomForest - Regression tree OOB R2: 81.46%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 42.47%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 79.70%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 76.79%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 16.79%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 13.83%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 89.02%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 83.28%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: -147.08%
[ForkJoinPool.commonPool-worker-4] INFO smile.regression.RandomForest - Regression tree OOB R2: 50.39%
[ForkJoinPool.commonPool-worker-6] INFO smile.regression.RandomForest - Regression tree OOB R2: 30.56%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 94.31%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 91.06%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 58.76%
[main] INFO smile.regression.RandomForest - Regression tree OOB R2: 55.30%

smpawlowski commented 6 months ago

Hi! I ran into a similar issue when trying to upgrade from 2.6.0 to 3.1.0. Predict requires a DataFrame that contains the predicted variable. The code below works in 2.6.0, but not in 3.1.0.

import org.junit.Assert;
import org.junit.Test;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.vector.DoubleVector;
import smile.regression.LinearModel;
import smile.regression.OLS;

public class TestSmileRegression {

    @Test
    public void test_formula_OLS() {
        double[] x = {1, 2, 3};
        double[] y = {1, 2, 3};  // y = x exactly
        DataFrame df = DataFrame.of(DoubleVector.of("x", x),
                                    DoubleVector.of("y", y));
        LinearModel regr = OLS.fit(Formula.lhs("y"), df);

        double[] x_pred = {4, 5, 6};
        double[] y_pred = regr.predict(DataFrame.of(DoubleVector.of("x", x_pred)));  // no "y" column: throws in 3.1.0
        for (int i = 0; i < x_pred.length; i++) {
            Assert.assertEquals(x_pred[i], y_pred[i], 1e-9);  // with y = x, each prediction equals its input
        }
    }
}

Exception:

java.lang.IllegalArgumentException: Field y doesn't exist
    at smile.data.type.StructType.indexOf(StructType.java:103)
    at smile.data.formula.Variable$1.<init>(Variable.java:80)
    at smile.data.formula.Variable.bind(Variable.java:78)
    at smile.data.formula.Formula.bind(Formula.java:360)
    at smile.data.formula.Formula.x(Formula.java:497)
    at smile.data.formula.Formula.matrix(Formula.java:546)
    at smile.regression.LinearModel.predict(LinearModel.java:358)
    at models.TestSmileRegression.test_formula_OLS(TestSmileRegression.java:22)
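Both stack traces fail at the same point: `Formula.bind` calls `Variable.bind`, which looks each variable up with `StructType.indexOf`, and in 3.x that lookup apparently includes the response column. The plain-Java sketch below is a toy model of that behaviour, not Smile's code; the class and method names (`BindSketch`, `bind`, `indexOf`) are hypothetical. It shows that input listing only the predictors fails with the same message, while input that also lists the response binds.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Toy model (NOT Smile's code) of the failure path in the stack traces:
// Formula.bind calls Variable.bind, which calls StructType.indexOf.
class BindSketch {

    // Mimics a schema lookup that rejects unknown field names.
    static int indexOf(List<String> schema, String field) {
        int i = schema.indexOf(field);
        if (i < 0) {
            throw new IllegalArgumentException("Field " + field + " doesn't exist");
        }
        return i;
    }

    // Resolves every variable of "response ~ predictors", response included,
    // against the input's column names; returns their column indices.
    static int[] bind(String response, List<String> predictors, List<String> schema) {
        int[] idx = new int[predictors.size() + 1];
        idx[0] = indexOf(schema, response);
        for (int j = 0; j < predictors.size(); j++) {
            idx[j + 1] = indexOf(schema, predictors.get(j));
        }
        return idx;
    }

    public static void main(String[] args) {
        List<String> predictors =
                List.of("GNP", "unemployed", "armed_forces", "population", "year", "employed");

        // Input that lists only the six predictors: the response lookup fails.
        try {
            bind("deflator", predictors, predictors);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // prints: Field deflator doesn't exist
        }

        // Input that also lists the response binds successfully.
        List<String> withResponse = new ArrayList<>(predictors);
        withResponse.add("deflator");
        int[] idx = bind("deflator", predictors, withResponse);
        System.out.println(Arrays.toString(idx)); // prints: [6, 0, 1, 2, 3, 4, 5]
    }
}
```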
ProtossidoDiAzoto commented 6 months ago

Yes, exactly: "predict requires DataFrame that contains the predicted variable". I had already worked around the issue last week by building the prediction tuples with a schema that also includes the response field:

    @Test
    public void tryOutRandomForestArrayData(){
        MathEx.setSeed(19650218);
        RandomForest model = RandomForest.fit(formula, data, 100, 3, 20, 10, 3, 1.0, Arrays.stream(seeds));

        // Schema for the prediction tuples: the six predictors plus the response "deflator".
        // Also needs: java.util.List, smile.data.type.StructField, smile.data.type.StructType,
        // smile.data.type.DataTypes.
        List<StructField> fields = Arrays.asList(
                new StructField("GNP", DataTypes.DoubleType),
                new StructField("unemployed", DataTypes.DoubleType),
                new StructField("armed_forces", DataTypes.DoubleType),
                new StructField("population", DataTypes.DoubleType),
                new StructField("year", DataTypes.IntegerType),
                new StructField("employed", DataTypes.DoubleType),
                // x[i] holds only six values, so there is no value behind this field;
                // it only needs to exist so that the formula can bind "deflator".
                new StructField("deflator", DataTypes.DoubleType)
        );

        StructType st = new StructType(fields);
        for (int i = 0; i < x.length; i++) {
            Tuple param = Tuple.of(x[i], st);
            System.out.println(model.predict(param));
        }
    }
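The workaround above hard-codes all seven fields. A slightly more general variant, sketched here in plain Java with hypothetical names (`ResponsePadding`, `withResponse`, `padRow`), derives the column list by appending the response name to the predictor names and pads each row with a placeholder so that the row and the schema have the same width. Using `NaN` as the placeholder is an assumption; since prediction does not appear to read the response value, any number should do.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper generalizing the hand-written seven-field schema:
// keeps the column names and the row values the same width.
class ResponsePadding {

    // Returns the predictor names followed by the response name.
    static List<String> withResponse(List<String> predictors, String response) {
        List<String> names = new ArrayList<>(predictors);
        names.add(response);
        return names;
    }

    // Returns a copy of the row with NaN appended as a placeholder response value;
    // the input array is left unchanged.
    static double[] padRow(double[] row) {
        double[] padded = Arrays.copyOf(row, row.length + 1);
        padded[row.length] = Double.NaN;
        return padded;
    }
}
```

In the workaround, `withResponse(...)` would supply the names for the `StructField` list, and `padRow(x[i])` the values passed to `Tuple.of`.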