haifengl / smile

Statistical Machine Intelligence & Learning Engine
https://haifengl.github.io

ArrayIndexOutOfBoundsException inside RegressionTree.findBestSplit #786

Closed Mikkomario closed 3 months ago

Mikkomario commented 3 months ago

Describe the bug

When training a Gradient Boosted Regression Tree model using gbm(...), the function throws an ArrayIndexOutOfBoundsException because it attempts to access a non-existent array index.

Expected behavior

I expected the function to run normally and complete the training. If the input data is faulty, I would have expected the function to throw an IllegalArgumentException or something similar.

Actual behavior

gbm throws an exception. Here is the stack trace:

Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 643
    at smile.regression.RegressionTree.findBestSplit(RegressionTree.java:150)
    at smile.base.cart.CART.lambda$findBestSplit$10(CART.java:373)
    at java.util.stream.IntPipeline$4$1.accept(IntPipeline.java:250)
    at java.util.Spliterators$IntArraySpliterator.forEachRemaining(Spliterators.java:1032)
    at java.util.Spliterator$OfInt.forEachRemaining(Spliterator.java:693)
    at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:481)
    at java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:471)
    at java.util.stream.ReduceOps$ReduceTask.doLeaf(ReduceOps.java:747)
    at java.util.stream.ReduceOps$ReduceTask.doLeaf(ReduceOps.java:721)
    at java.util.stream.AbstractTask.compute(AbstractTask.java:316)
    at java.util.concurrent.CountedCompleter.exec(CountedCompleter.java:731)
    at java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java:289)
    at java.util.concurrent.ForkJoinTask.doInvoke(ForkJoinTask.java:401)
    at java.util.concurrent.ForkJoinTask.invoke(ForkJoinTask.java:734)
    at java.util.stream.ReduceOps$ReduceOp.evaluateParallel(ReduceOps.java:714)
    at java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:233)
    at java.util.stream.ReferencePipeline.reduce(ReferencePipeline.java:479)
    at java.util.stream.ReferencePipeline.max(ReferencePipeline.java:515)
    at smile.base.cart.CART.findBestSplit(CART.java:379)
    at smile.regression.RegressionTree.<init>(RegressionTree.java:255)
    at smile.regression.GradientTreeBoost.fit(GradientTreeBoost.java:240)
    at smile.regression.package$.$anonfun$gbm$1(package.scala:390)
    at smile.util.package$time$.apply(package.scala:67)
    at smile.regression.package$.gbm(package.scala:390)
    at ...

The stack trace points to this line in RegressionTree.java: trueCount[idx] += samples[o];
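To see why an out-of-range nominal value surfaces as an ArrayIndexOutOfBoundsException rather than a clearer error, here is a simplified, hypothetical sketch of per-level counting for one nominal feature. It is not Smile's actual implementation; it only assumes, as the line above suggests, that per-level statistics live in an array sized by the number of levels and indexed directly by each row's stored integer value. The object and method names are made up for illustration.

```scala
object NominalSplitSketch {
  // Hypothetical simplification: accumulate per-level sample counts for one
  // nominal feature with `numLevels` levels. Each row's stored value is used
  // directly as an array index, so it must lie in [0, numLevels).
  def levelCounts(values: Array[Int], samples: Array[Int], numLevels: Int): Array[Int] = {
    val trueCount = new Array[Int](numLevels)
    for (i <- values.indices) {
      val idx = values(i)          // raw nominal code stored in the data frame
      trueCount(idx) += samples(i) // throws if idx >= numLevels or idx < 0
    }
    trueCount
  }

  def main(args: Array[String]): Unit = {
    // Codes 0..2 for a 3-level feature are counted without problems.
    println(levelCounts(Array(0, 2, 1, 2), Array(1, 1, 1, 1), 3).mkString(","))
    // A code such as 643 (for example a database row id) is past the end of
    // the 3-element array, so the update throws ArrayIndexOutOfBoundsException.
    try levelCounts(Array(0, 643), Array(1, 1), 3)
    catch { case e: ArrayIndexOutOfBoundsException => println(s"failed: ${e.getMessage}") }
  }
}
```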

Code snippet

Here is the code that produced this error, with additional prints and similar lines removed. The training data was read from a local database. I have attached a CSV file generated using Write(df, ...).

import smile.data.formula.Formula
import smile.data.{DataFrame, Tuple}
import smile.regression._
import scala.jdk.CollectionConverters._

// Collects the training data
val (data, testingData, context) = collectTrainingData(...)
implicit val c: OrderSizePredictionContext = context
val schema = OrderSizeTrainingRow.schema

// Converts the read data into a data frame
val df = DataFrame.of(data.view.map { row => Tuple.of(row.features, schema) }.toList.asJava, schema)

// Define the formula for the regression model
val formula = Formula.lhs("orderSize")

// Trains the gradient-boosted regression tree model
val model = gbm(formula = formula, data = df)

Here is the StructType used, printed as a string, for reference:

[airline: int nominal[GN, HO, ENT, XX, QR, FH, BT, NJE, HV, JTD, EE, 3U, CAT, PC, LO, AF, ABF, QS, AA, KL, FI, DK, FR, OR, JEF, WT, TF, AEG, HT, FX, 5X, D8, BB, DY, C3, Q7, SK, QY, JL, A3, TK, LH, EW, 3V, 5Y, 6X, TE, KK, RP, TT, N7, BIX, HCX, MB, VKA, MMD, GAV, ZT, BID, CVK], aircraftType: int nominal[???, A20N, A21N, A319, A320, A321, A333, A359, BCS3, AT72, AT75, C25C, E190, TST12], destination: int nominal[TOS, FNC, SJY, LEJ, BGO, KUN, BQH, BKK, LIN, LGW, PRG, CRL, BCN, JOE, BUD, OSL, ARN, BLR, SVL, BRQ, NHT, VNO, BVA, SGD, FMM, DUS, KEF, TRF, BGY, MAD, HKG, BER, ZTH, FCO, JYV, VBY, NCE, ATH, MHQ, RIX, KOK, FRA, MUC, PVG, KAJ, IST, EDI, LCY, HND, LPA, MMX, GDN, WMI, CHQ, CFU, KTT, CDG, BLQ, ERF, LBG, JTR, ORB, AYT, MIK, STN, AHO, GZP, TAY, MXP, KUO, FAB, AMS, MAN, VRN, ORD, DUB, AGP, PRN, CTA, RHO, HAM, BRU, DOL, OUL, ICN, ZAD, GVA, SAW, LHR, KAO, CGN, TIA, JFK, RVN, HER, PMI, IVL, SPU, SMI, VIE, RTM, TLL, STR, PVK, KGS, ZRH, CPH, VAA, KEM, RUH, KIX, DOH, LPI, LCA, ALC, SZG, BMA, KTW, LLA, VCE, POR, WAW, RKE, WRO, DEL, LIS, GOT, SNN, DLM, NRT, CGO, GRO, TMP, DBV, TKU, DRS, OTP, SIN, BLL, LAX, KRK, DFW, ACH, LJU, KIV, TDR, KLX, SEA, PSA, TIV, NAP, JKG, AOK, SKG, ANR, LBA, YFB, REU, HYV, NGO, RNN, YUL, DEN, BZG, PZY, AZI, ESS, LPX, SOO, PMO, TPS, CHR], weekday: int nominal[Monday, Tuesday, Wednesday, Thursday, Friday, Saturday, Sunday], schTime: double interval, orderSize: double ratio]

Here is code that performs the same training on the same data, but with a manually created StructType instance and a DataFrame read from the CSV file. This version does not reproduce the error.

import org.apache.commons.csv.CSVFormat
import smile.data.`type`.{DataTypes, StructField, StructType}
import smile.data.formula.Formula
import smile.data.measure.{IntervalScale, NominalScale, RatioScale}
import smile.io.Read
import smile.regression.gbm

import java.text.NumberFormat

object ModelTrainTest extends App
{
    val schema = {
        val weekday = new StructField("weekday", DataTypes.IntegerType,
            new NominalScale(Array(0,1,2,3,4,5,6),
                Array("Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday")))
        val scheduledDeparture = new StructField("schTime", DataTypes.DoubleType,
            new IntervalScale(NumberFormat.getIntegerInstance))
        val airline = new StructField("airline", DataTypes.IntegerType, new NominalScale(
            "GN", "HO", "ENT", "XX", "QR", "FH", "BT", "NJE", "HV", "JTD", "EE", "3U", "CAT", "PC", "LO", "AF", "ABF",
            "QS", "AA", "KL", "FI", "DK", "FR", "OR", "JEF", "WT", "TF", "AEG", "HT", "FX", "5X", "D8", "BB", "DY",
            "C3", "Q7", "SK", "QY", "JL", "A3", "TK", "LH", "EW", "3V", "5Y", "6X", "TE", "KK", "RP", "TT", "N7",
            "BIX", "HCX", "MB", "VKA", "MMD", "GAV", "ZT", "BID", "CVK"))
        val aircraftType = new StructField("aircraftType", DataTypes.IntegerType,
            new NominalScale("???", "A20N", "A21N", "A319", "A320", "A321", "A333", "A359", "BCS3", "AT72", "AT75",
                "C25C", "E190", "TST12"))
        val destination = new StructField("destination", DataTypes.IntegerType,
            new NominalScale("TOS", "FNC", "SJY", "LEJ", "BGO", "KUN", "BQH", "BKK", "LIN", "LGW", "PRG", "CRL", "BCN",
                "JOE", "BUD", "OSL", "ARN", "BLR", "SVL", "BRQ", "NHT", "VNO", "BVA", "SGD", "FMM", "DUS", "KEF", "TRF",
                "BGY", "MAD", "HKG", "BER", "ZTH", "FCO", "JYV", "VBY", "NCE", "ATH", "MHQ", "RIX", "KOK", "FRA", "MUC",
                "PVG", "KAJ", "IST", "EDI", "LCY", "HND", "LPA", "MMX", "GDN", "WMI", "CHQ", "CFU", "KTT", "CDG", "BLQ",
                "ERF", "LBG", "JTR", "ORB", "AYT", "MIK", "STN", "AHO", "GZP", "TAY", "MXP", "KUO", "FAB", "AMS", "MAN",
                "VRN", "ORD", "DUB", "AGP", "PRN", "CTA", "RHO", "HAM", "BRU", "DOL", "OUL", "ICN", "ZAD", "GVA", "SAW",
                "LHR", "KAO", "CGN", "TIA", "JFK", "RVN", "HER", "PMI", "IVL", "SPU", "SMI", "VIE", "RTM", "TLL", "STR",
                "PVK", "KGS", "ZRH", "CPH", "VAA", "KEM", "RUH", "KIX", "DOH", "LPI", "LCA", "ALC", "SZG", "BMA", "KTW",
                "LLA", "VCE", "POR", "WAW", "RKE", "WRO", "DEL", "LIS", "GOT", "SNN", "DLM", "NRT", "CGO", "GRO", "TMP",
                "DBV", "TKU", "DRS", "OTP", "SIN", "BLL", "LAX", "KRK", "DFW", "ACH", "LJU", "KIV", "TDR", "KLX", "SEA",
                "PSA", "TIV", "NAP", "JKG", "AOK", "SKG", "ANR", "LBA", "YFB", "REU", "HYV", "NGO", "RNN", "YUL", "DEN",
                "BZG", "PZY", "AZI", "ESS", "LPX", "SOO", "PMO", "TPS", "CHR"))
        val orderSizeField = new StructField("orderSize", DataTypes.DoubleType,
            new RatioScale(NumberFormat.getNumberInstance))

        new StructType(airline, aircraftType, destination, weekday, scheduledDeparture, orderSizeField)
    }

    // Reads the CSV file into a data frame using the schema above
    val df = Read.csv("Fuel-AI/data/test-data/dataframe-censored.csv", CSVFormat.DEFAULT, schema)

    // Define the formula for the regression model
    val formula = Formula.lhs("orderSize")

    // Trains the gradient-boosted regression tree model
    val model = gbm(formula = formula, data = df)

    println(s"\nTrained model:\n$model")
}

Could the issue somehow be related to the DataFrame instance? df.toString and df.summary yield the same results for both versions. However, in the original code the nominal values (airline, destination, aircraftType) are represented as integers, whereas in the CSV-based code they are represented only as strings.

I used this code to construct the original NominalScale instances:

private def codeField(fieldName: String, allCodeValues: Iterable[StoredCode]) = {
    val orderedCodes = allCodeValues.toSeq.sortBy { _.id }
    new StructField(fieldName, DataTypes.IntegerType, nominalScaleFrom(orderedCodes) { _.id } { _.code })
}

private def nominalScaleFrom[A](values: Seq[A])(idOf: A => Int)(nameOf: A => String) =
    new NominalScale(values.view.map(idOf).toArray, values.view.map(nameOf).toArray)

These use a StoredCode class, which maps a String code to an Int database row id.

Input data

The DataFrame instance (df) used in the code above is attached as a separate CSV file: dataframe-censored.csv

Additional context

Java version: 1.8.0_402 (OpenJDK)
Scala version: 2.13.14
SMILE version: 3.1.1
OS: Linux

haifengl commented 3 months ago

It is probably related to the data schema. I'm not sure how this code works:

val (data, testingData, context) = collectTrainingData(...)

So I cannot reproduce the issue, even with your data file. Nominal values are stored as integers (with a string representation). Note that nominal values start at 0. I suspect that your data values are not in the range [0, k), where k is the number of levels of the nominal variable.
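A minimal, Smile-independent sketch of a pre-flight check for this hypothesis. It assumes you can extract each nominal column's raw integer codes and its level count (for example, the number of names passed to NominalScale) from your own data before building the DataFrame. The object and helper names are hypothetical, and this is not a Smile API.

```scala
object NominalRangeCheck {
  /** Distinct codes of one nominal column that fall outside [0, numLevels), sorted. */
  def outOfRange(values: Seq[Int], numLevels: Int): Seq[Int] =
    values.filter(v => v < 0 || v >= numLevels).distinct.sorted

  /** Throws IllegalArgumentException if any code is outside [0, numLevels). */
  def requireCodesInRange(column: String, values: Seq[Int], numLevels: Int): Unit = {
    val bad = outOfRange(values, numLevels)
    if (bad.nonEmpty)
      throw new IllegalArgumentException(
        s"Nominal column '$column' has $numLevels levels, so its values must be in [0, $numLevels), " +
          s"but found: ${bad.take(10).mkString(", ")}")
  }

  def main(args: Array[String]): Unit = {
    requireCodesInRange("weekday", Seq(0, 3, 6), 7) // passes: all codes are in [0, 7)
    println(outOfRange(Seq(0, 3, 7, 640, 7), 7))    // List(7, 640)
  }
}
```

Running such a check on every nominal column before calling gbm would turn the ArrayIndexOutOfBoundsException into an explicit error naming the offending column and values.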

Mikkomario commented 3 months ago

Your suspicion appears to be correct. The javadoc says little about the relationship between levels and values, so I didn't realize I needed to make that distinction. Thank you for your insights.

Mikkomario commented 3 months ago

I have confirmed that you were correct: the error was caused by nominal values that were not assigned correctly. It might be useful to document this requirement in the javadoc, or alternatively to add a validation check that throws an IllegalArgumentException. Thank you very much for your help, @haifengl.
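A minimal sketch of one way to adjust the codeField helper from earlier in this thread: keep the level names ordered by database id, but map each database id to its 0-based position instead of using the raw id as the nominal value. StoredCode is redefined here as a hypothetical case class with only an id and a code; the real class may differ. The resulting levels array could then be passed to the varargs NominalScale(String...) constructor used in the CSV example (new NominalScale(levels: _*)), which we understand assigns the values 0 to k-1, and each row's database id translated through idToValue before building its Tuple.

```scala
object NominalRemap {
  // Hypothetical stand-in for the StoredCode class described above:
  // a string code plus its database row id.
  final case class StoredCode(id: Int, code: String)

  /** Level names ordered by database id, plus a map from database id to 0-based nominal value. */
  def encode(codes: Iterable[StoredCode]): (Array[String], Map[Int, Int]) = {
    val ordered = codes.toSeq.sortBy(_.id)
    val levels = ordered.map(_.code).toArray
    val idToValue = ordered.map(_.id).zipWithIndex.toMap
    (levels, idToValue)
  }

  def main(args: Array[String]): Unit = {
    val (levels, idToValue) =
      encode(Seq(StoredCode(643, "LH"), StoredCode(12, "GN"), StoredCode(97, "HO")))
    println(levels.mkString(",")) // GN,HO,LH
    println(idToValue(643))       // 2, not 643
  }
}
```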