apache / incubator-wayang

Apache Wayang(incubating) is the first cross-platform data processing system.
https://wayang.incubator.apache.org/
Apache License 2.0
174 stars 70 forks source link

Add first draft of ml-in-wayang.md #398

Closed juripetersen closed 3 months ago

juripetersen commented 5 months ago

This PR provides a short .md-guide that shows an exemplary usage of the previously introduced abstraction of the cost model. The guide shows how it can be utilized in order to predict query plan runtimes with a pre-trained ML model.

Using Machine Learning for query optimization in Apache Wayang (incubating)

Apache Wayang (incubating) can be customized with concrete implementations of the EstimatableCost interface in order to optimize for a desired metric. The implementation can be enabled by providing it to a Configuration.

public class CustomEstimatableCost implements EstimatableCost {
    /* Provide concrete implementations to match desired cost function(s)
     * by implementing the EstimatableCost interface methods in this class.
     * The optimizer will then use these cost functions when comparing
     * candidate execution plans.
     */
}
public class WordCount {
    public static void main(String[] args) {
        /* Build a configuration that carries the custom cost model. */
        Configuration configuration = new Configuration();
        /* Register an EstimatableCost implementation with the configuration. */
        configuration.setCostModel(new CustomEstimatableCost());
        /* Create the Wayang context and declare the platforms to consider. */
        WayangContext wayangContext = new WayangContext(configuration);
        wayangContext.withPlugin(Java.basicPlugin());
        wayangContext.withPlugin(Spark.basicPlugin());
        /*... omitted */
    }
}

In combination with an encoding scheme and a third party package to load ML models, the following example shows how to predict runtimes of query execution plans in Apache Wayang (incubating):

import org.apache.wayang.core.optimizer.costs.EstimatableCost;
import org.apache.wayang.core.optimizer.costs.EstimatableCostFactory;
import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval;
import org.apache.wayang.core.optimizer.enumeration.LoopImplementation;
import org.apache.wayang.core.optimizer.enumeration.PlanImplementation;
import org.apache.wayang.core.platform.Junction;
import org.apache.wayang.core.plan.executionplan.ExecutionPlan;
import org.apache.wayang.core.plan.executionplan.ExecutionStage;
import org.apache.wayang.core.plan.wayangplan.Operator;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.ml.encoding.OneHotEncoder;
import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.api.exception.WayangException;
import org.apache.wayang.core.plan.executionplan.Channel;
import org.apache.wayang.ml.OrtMLModel;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Set;
import java.util.List;

/**
 * An {@link EstimatableCost} implementation that delegates cost estimation to a
 * pre-trained ML model (loaded via {@link OrtMLModel}). Plans are one-hot
 * encoded and fed to the model, whose scalar output is interpreted as the
 * predicted runtime cost.
 */
public class MLCost implements EstimatableCost {
    public EstimatableCostFactory getFactory() {
        return new Factory();
    }

    public static class Factory implements EstimatableCostFactory {
        @Override public EstimatableCost makeCost() {
            return new MLCost();
        }
    }

    /**
     * Encodes {@code plan} and runs the ML model to obtain a cost prediction.
     *
     * @param plan the plan implementation to estimate
     * @return the model's predicted cost
     * @throws Exception if the model cannot be loaded or evaluated
     */
    private double predictCost(PlanImplementation plan) throws Exception {
        Configuration config = plan
            .getOptimizationContext()
            .getConfiguration();
        OrtMLModel model = OrtMLModel.getInstance(config);
        return model.runModel(OneHotEncoder.encode(plan));
    }

    @Override public ProbabilisticDoubleInterval getEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            return ProbabilisticDoubleInterval.ofExactly(this.predictCost(plan));
        } catch (Exception e) {
            // Fall back to a zero estimate so optimization can proceed even if
            // the model is unavailable; callers treat zero as "no information".
            return ProbabilisticDoubleInterval.zero;
        }
    }

    @Override public ProbabilisticDoubleInterval getParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            return ProbabilisticDoubleInterval.ofExactly(this.predictCost(plan));
        } catch (Exception e) {
            return ProbabilisticDoubleInterval.zero;
        }
    }

    /** Returns a squashed cost estimate. */
    @Override public double getSquashedEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            return this.predictCost(plan);
        } catch (Exception e) {
            return 0;
        }
    }

    @Override public double getSquashedParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            return this.predictCost(plan);
        } catch (Exception e) {
            return 0;
        }
    }

    @Override public Tuple<List<ProbabilisticDoubleInterval>, List<Double>> getParallelOperatorJunctionAllCostEstimate(PlanImplementation plan, Operator operator) {
        List<ProbabilisticDoubleInterval> intervalList = new ArrayList<>();
        List<Double> doubleList = new ArrayList<>();
        intervalList.add(this.getEstimate(plan, true));
        doubleList.add(this.getSquashedEstimate(plan, true));

        return new Tuple<>(intervalList, doubleList);
    }

    /**
     * Picks the candidate with the smallest squashed cost estimate.
     *
     * @throws WayangException if {@code executionPlans} is empty
     */
    public PlanImplementation pickBestExecutionPlan(
            Collection<PlanImplementation> executionPlans,
            ExecutionPlan existingPlan,
            Set<Channel> openChannels,
            Set<ExecutionStage> executedStages) {
        return executionPlans.stream()
                .min(java.util.Comparator.comparingDouble(PlanImplementation::getSquashedCostEstimate))
                .orElseThrow(() -> new WayangException("Could not find an execution plan."));
    }
}

Third-party packages such as OnnxRuntime can be used to load pre-trained .onnx files that contain desired ML models.

import org.apache.wayang.core.api.Configuration;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;

import java.util.Vector;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import java.util.Map;
import java.util.function.BiFunction;

/**
 * Thin wrapper around an ONNX Runtime session that serves runtime-cost
 * predictions from a pre-trained model. The model file path is read from the
 * {@code wayang.ml.model.file} configuration property.
 *
 * <p>Implemented as a lazily initialized singleton; {@link #getInstance} is
 * synchronized so concurrent first calls cannot create two instances.</p>
 */
public class OrtMLModel {

    private static OrtMLModel INSTANCE;

    private OrtSession session;
    private OrtEnvironment env;

    /**
     * Returns the shared model instance, creating it on first use.
     *
     * @param configuration supplies {@code wayang.ml.model.file}
     * @throws OrtException if the model cannot be loaded
     */
    public static synchronized OrtMLModel getInstance(Configuration configuration) throws OrtException {
        if (INSTANCE == null) {
            INSTANCE = new OrtMLModel(configuration);
        }

        return INSTANCE;
    }

    private OrtMLModel(Configuration configuration) throws OrtException {
        this.loadModel(configuration.getStringProperty("wayang.ml.model.file"));
    }

    /**
     * Creates the ONNX environment and session if not yet present.
     * Idempotent: subsequent calls with a different path are no-ops.
     */
    public void loadModel(String filePath) throws OrtException {
        if (this.env == null) {
            this.env = OrtEnvironment.getEnvironment();
        }

        if (this.session == null) {
            this.session = env.createSession(filePath, new OrtSession.SessionOptions());
        }
    }

    /** Releases the session and environment; reset so a later load can recreate them. */
    public void closeSession() throws OrtException {
        this.session.close();
        this.env.close();
        this.session = null;
        this.env = null;
    }

    /**
     * Runs the model on an encoded plan vector.
     *
     * @param encodedVector one-hot encoded plan features
     * @return NaN on error, and a predicted cost on any other value.
     * @throws OrtException if tensor creation or inference fails
     */
    public double runModel(Vector<Long> encodedVector) throws OrtException {
        // Build the inputs per call. The previous version reused instance-level
        // collections and never closed the tensors it put into them, leaking
        // native memory on every invocation.
        BiFunction<Result, String, Double> unwrapFunc = (r, s) -> {
            try {
                return ((double[]) r.get(s).get().getValue())[0];
            } catch (OrtException e) {
                return Double.NaN;
            }
        };

        // OnnxTensor holds native memory; try-with-resources guarantees release.
        try (OnnxTensor tensor = OnnxTensor.createTensor(env, encodedVector)) {
            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input", tensor);
            Set<String> requestedOutputs = new HashSet<>();
            requestedOutputs.add("output");

            try (Result r = session.run(inputs, requestedOutputs)) {
                return unwrapFunc.apply(r, "output");
            }
        }
    }
}