dmlc / tl2cgen

TL2cgen (TreeLite 2 C GENerator) is a model compiler for decision tree models
https://tl2cgen.readthedocs.io/en/latest/
Apache License 2.0

Multi-threaded prediction for treelite4j #5

Open thvasilo opened 2 years ago

thvasilo commented 2 years ago

Hello,

I'm running some benchmarks for treelite4j, testing out different batch sizes (splitting up a dataset into batches and predicting for each batch in sequence) and the number of threads passed to the Predictor object.

One thing I'm observing is that the number of threads set in the Predictor only seems to matter when my batch size is larger than 1. If I create a DMatrix with only a single row and call predict on it, the number of threads the Predictor object was created with makes no difference.

Also, batch size doesn't seem to have a large effect when prediction is single-threaded. Is that expected as well?

Is it the case that multi-threading is only relevant when there's more than one row in the input DMatrix?

Would it be possible to use multi-threading for single-instance prediction as well, with each thread evaluating a single tree and the results merged at the end?
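The tree-parallel idea in the question can be sketched in plain Java. This is a minimal illustration under stated assumptions, not how treelite4j or TL2cgen work: the model is a hypothetical list of per-tree functions whose outputs are simply summed (as for the raw margin of a boosted ensemble), and the names TreeParallelSketch and predictOneRow are made up for the example.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.ToDoubleFunction;

public class TreeParallelSketch {

    // Hypothetical model: each tree maps one feature vector to a leaf value and the
    // ensemble output is the plain sum of tree outputs. This is NOT the treelite4j API.
    static double predictOneRow(List<ToDoubleFunction<double[]>> trees, double[] row,
                                ExecutorService pool, int numTasks) throws Exception {
        int numTrees = trees.size();
        List<Future<Double>> partials = new ArrayList<>();
        for (int t = 0; t < numTasks; t++) {
            // Task t evaluates a contiguous slice of trees [begin, end) for the same row.
            int begin = (int) ((long) numTrees * t / numTasks);
            int end = (int) ((long) numTrees * (t + 1) / numTasks);
            partials.add(pool.submit(() -> {
                double sum = 0.0;
                for (int i = begin; i < end; i++) {
                    sum += trees.get(i).applyAsDouble(row);
                }
                return sum;
            }));
        }
        // Merge step: add up the per-task partial sums.
        double total = 0.0;
        for (Future<Double> partial : partials) {
            total += partial.get();
        }
        return total;
    }

    public static void main(String[] args) throws Exception {
        // 100 toy decision stumps: tree i outputs 1.0 when feature 0 < i, else 0.0.
        List<ToDoubleFunction<double[]>> trees = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            final double threshold = i;
            trees.add(x -> x[0] < threshold ? 1.0 : 0.0);
        }
        ExecutorService pool = Executors.newFixedThreadPool(4);
        try {
            System.out.println(predictOneRow(trees, new double[]{42.5}, pool, 4)); // prints 57.0
        } finally {
            pool.shutdown();
        }
    }
}
```

Because floating-point addition is not associative, merging per-slice partial sums can differ from a sequential sum in the last bits when leaf values are not exact; and for a small ensemble, handing tasks to the pool may cost more than the per-tree work it parallelizes.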

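JMH results (throughput mode; one op is one call of treelitePrediction, i.e. predicting all 100,000 rows batch by batch):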

Benchmark           (batchSize)  (datapointNumber)  (treeliteThreads)   Mode  Cnt  Score   Error  Units
treelitePrediction            1             100000                  1  thrpt    3  0.053 ± 0.009  ops/s
treelitePrediction            1             100000                  8  thrpt    3  0.053 ± 0.005  ops/s
treelitePrediction           10             100000                  1  thrpt    3  0.060 ± 0.007  ops/s
treelitePrediction           10             100000                  8  thrpt    3  0.156 ± 0.515  ops/s
treelitePrediction          100             100000                  1  thrpt    3  0.064 ± 0.016  ops/s
treelitePrediction          100             100000                  8  thrpt    3  0.228 ± 0.667  ops/s

Some example code:

package me.tvas.benchmark;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

// Predictor, DMatrix and TreeliteError
import ml.dmlc.treelite4j.java.*;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.results.format.ResultFormatType;

@State(Scope.Benchmark)
public class PredictionBenchmarks {

    @Param({""})
    public String treeliteModelPath;

    @Param({"1000"})
    public int datapointNumber;

    @Param({"1"})
    public int batchSize;

    @Param({"1"})
    public int treeliteThreads;

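    // Output folder for the CSV results. A plain static default is used because JMH
    // only injects @Param values into benchmark instances, never into main().
    public static String destinationFolder = "/tmp";

    Predictor treelitePredictor;
    DataSet randomDataSet;
    long numFeature;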

    @Setup
    public void prepare() throws IOException, TreeliteError {
        // Load the compiled model with the requested number of worker threads.
        this.treelitePredictor = new Predictor(this.treeliteModelPath, this.treeliteThreads, false);
        this.numFeature = this.treelitePredictor.GetNumFeature();

        // Create reproducible random data for prediction.
        long rngSeed = 42;
        long[] shape = new long[]{datapointNumber, numFeature};
        Nd4j.getRandom().setSeed(rngSeed);
        INDArray randomDoubles = Nd4j.rand(0.0, 32000.0, Nd4j.getRandom(), shape);
        INDArray dummyLabels = Nd4j.ones(new long[]{datapointNumber, 1});
        this.randomDataSet = new DataSet(randomDoubles, dummyLabels);
    }

    @Benchmark
    public void treelitePrediction(Blackhole blackhole) throws TreeliteError {
        // Split the dataset into batches of batchSize rows and predict them in sequence.
        List<DataSet> batches = this.randomDataSet.batchBy(batchSize);
        Iterator<DataSet> datasetIterator = batches.iterator();

        while (datasetIterator.hasNext()) {
            DataSet batch = datasetIterator.next();
            INDArray features = batch.getFeatures();
            long currentBatchSize = features.shape()[0];

            // Copy into a contiguous row-major buffer: data() on a view may expose
            // the parent array's buffer rather than just this batch.
            double[] doubleVector = features.dup('c').data().asDouble();
            DMatrix dmat = new DMatrix(doubleVector, Double.NaN, currentBatchSize, this.numFeature);
            INDArray preds = treelitePredictor.predict(dmat, false, false);
            blackhole.consume(preds);
        }
    }

    public static void main(String[] args) throws Exception {

        Options opt = new OptionsBuilder()
                .include(PredictionBenchmarks.class.getSimpleName())
                .result(destinationFolder + "/" + "benchmarkResults.csv")
                .resultFormat(ResultFormatType.CSV)
                .forks(1)
                .threads(1)
                .jvmArgs("-ea")
                .build();

        new Runner(opt).run();
    }
}
hcho3 commented 2 years ago

> Is it the case that multi-threading is only relevant when there's more than one row in the input DMatrix?

Yes, currently multi-threading is only useful when you have multiple rows in the input DMatrix. The rows of the DMatrix get distributed equally across worker threads.
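A minimal sketch of the row split described above, in plain Java, assuming an even contiguous partition of rows; the helper names (RowPartitionSketch, chunk, predictBatch) and the toy model are hypothetical, and the actual TL2cgen/treelite runtime may split or schedule rows differently. It also illustrates why a one-row DMatrix sees no speedup: with a single row, only one chunk is non-empty.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.ToDoubleFunction;

public class RowPartitionSketch {

    // Row range [begin, end) for worker t out of nthread; chunk sizes differ by at most one.
    static int[] chunk(int numRow, int nthread, int t) {
        int begin = (int) ((long) numRow * t / nthread);
        int end = (int) ((long) numRow * (t + 1) / nthread);
        return new int[]{begin, end};
    }

    // Hypothetical batch prediction: each worker runs the whole model on its own rows.
    static double[] predictBatch(double[][] rows, ToDoubleFunction<double[]> model,
                                 ExecutorService pool, int nthread) throws Exception {
        double[] out = new double[rows.length];
        List<Future<?>> tasks = new ArrayList<>();
        for (int t = 0; t < nthread; t++) {
            int[] range = chunk(rows.length, nthread, t);
            if (range[0] == range[1]) {
                continue; // empty chunk: this worker has nothing to do
            }
            tasks.add(pool.submit(() -> {
                for (int i = range[0]; i < range[1]; i++) {
                    out[i] = model.applyAsDouble(rows[i]);
                }
            }));
        }
        for (Future<?> task : tasks) {
            task.get(); // wait for every chunk and surface any exception
        }
        return out;
    }

    public static void main(String[] args) throws Exception {
        // A single-row batch with 8 workers leaves exactly one worker busy.
        int busy = 0;
        for (int t = 0; t < 8; t++) {
            int[] r = chunk(1, 8, t);
            if (r[1] > r[0]) {
                busy++;
            }
        }
        System.out.println("busy workers for a 1-row batch: " + busy); // prints 1

        double[][] rows = new double[10][1];
        for (int i = 0; i < rows.length; i++) {
            rows[i][0] = i;
        }
        ExecutorService pool = Executors.newFixedThreadPool(3);
        try {
            System.out.println(Arrays.toString(predictBatch(rows, x -> 2 * x[0], pool, 3)));
        } finally {
            pool.shutdown();
        }
    }
}
```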

guozhaochen commented 1 year ago

I am trying to call predict concurrently (i.e., multiple threads each calling predict, instead of multiple worker threads inside the Predictor), so I set the Predictor's thread count to 1 so that the calling threads are not blocked by synchronization. However, I found that the JavaCPP library used by ND4J doesn't allow multi-threading either; see https://github.com/bytedeco/javacpp/blob/d23879af7a03a04c12b2374ae9d0850b9dda9d96/src/main/java/org/bytedeco/javacpp/Pointer.java#L699
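The caller-side pattern described in this comment (several Java threads each calling predict, each backed by a single-threaded predictor) can be sketched with a thread pool and a ThreadLocal. Everything here is hypothetical: BatchPredictor, predictConcurrently and the toy model stand in for a real predictor and use plain double[] arrays instead of ND4J types. Whether a treelite4j Predictor, or the JavaCPP/ND4J layer under it, can safely and efficiently be used this way is not established by this sketch.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

public class CallerThreadsSketch {

    // Hypothetical single-threaded predictor working on plain arrays (not a treelite4j type).
    interface BatchPredictor {
        double[] predict(double[][] batch);
    }

    // Runs each batch on a pool of calling threads; every pool thread lazily creates and
    // then reuses its own BatchPredictor, so no predictor object is shared between threads.
    static double[][] predictConcurrently(List<double[][]> batches,
                                          Supplier<BatchPredictor> factory,
                                          int callerThreads) throws Exception {
        ThreadLocal<BatchPredictor> perThread = ThreadLocal.withInitial(factory);
        ExecutorService pool = Executors.newFixedThreadPool(callerThreads);
        try {
            List<Future<double[]>> futures = new ArrayList<>();
            for (double[][] batch : batches) {
                futures.add(pool.submit(() -> perThread.get().predict(batch)));
            }
            double[][] results = new double[batches.size()][];
            for (int i = 0; i < results.length; i++) {
                results[i] = futures.get(i).get(); // results stay in input order
            }
            return results;
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) throws Exception {
        AtomicInteger created = new AtomicInteger();
        // Toy predictor: output = feature 0 + 1.
        Supplier<BatchPredictor> factory = () -> {
            created.incrementAndGet();
            return batch -> {
                double[] out = new double[batch.length];
                for (int i = 0; i < batch.length; i++) {
                    out[i] = batch[i][0] + 1.0;
                }
                return out;
            };
        };
        List<double[][]> batches = new ArrayList<>();
        for (int k = 0; k < 20; k++) {
            batches.add(new double[][]{{k}, {10.0 * k}});
        }
        double[][] results = predictConcurrently(batches, factory, 4);
        System.out.println("batch 3 -> " + results[3][0] + ", " + results[3][1]);
        System.out.println("predictor instances created: " + created.get());
    }
}
```

Giving each calling thread its own instance avoids contention on one shared predictor object; any synchronization inside the native or array libraries underneath would still apply.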

Is there any particular reason we need to use INDArray from ND4J?