combust / mleap

MLeap: Deploy ML Pipelines to Production
https://combust.github.io/mleap-docs/
Apache License 2.0

How to use XGBoost PySpark API with MLeap? #867

Open venkatacrc opened 11 months ago

venkatacrc commented 11 months ago

Problem description: We were able to serialize the XGBoost model with MLeap using the older PySpark API (https://github.com/dmlc/xgboost/pull/4656) as shown below:

import mleap.pyspark
from mleap.pyspark.spark_support import SimpleSparkSerializer  # patches serializeToBundle onto Spark transformers

trans_model = model.transform(df)
local_path = "jar:file:/tmp/pyspark.model.zip"
model.serializeToBundle(local_path, trans_model)

But we are not able to do this with the official PySpark API (https://xgboost.readthedocs.io/en/stable/tutorials/spark_estimator.html). Are there any flags that I need to use to make this work? We use MLeap to store the model as a serialized bundle and load it in a Java runtime environment for model serving.
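For context, a quick way to see whether a fitted model can go through MLeap's bundle path is to check for the `_to_java` hook that MLeap's PySpark wrapper calls internally. A minimal sketch — the stub classes below are hypothetical stand-ins for the real Spark models, not actual pyspark/xgboost classes:

```python
def can_mleap_serialize(model):
    """MLeap's PySpark serializer ultimately calls transformer._to_java(),
    so only models wrapping a JVM object expose the hook it needs."""
    return hasattr(model, "_to_java")

# Hypothetical stand-ins for the two kinds of models:
class JavaBackedModel:      # like the old XGBoostClassificationModel
    def _to_java(self):
        return "jvm-handle"

class PurePythonModel:      # like the new SparkXGBClassifierModel
    pass

print(can_mleap_serialize(JavaBackedModel()))   # True
print(can_mleap_serialize(PurePythonModel()))   # False
```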

Steps Followed:

Download spark

wget https://dlcdn.apache.org/spark/spark-3.3.3/spark-3.3.3-bin-hadoop3.tgz
tar -xvf spark-3.3.3-bin-hadoop3.tgz

Download xgboost jars

wget https://repo1.maven.org/maven2/ml/dmlc/xgboost4j_2.12/1.7.3/xgboost4j_2.12-1.7.3.jar
wget https://repo1.maven.org/maven2/ml/dmlc/xgboost4j-spark_2.12/1.7.3/xgboost4j-spark_2.12-1.7.3.jar

Build MLeap fat jar:

https://github.com/combust/mleap/blob/master/mleap-databricks-runtime-fat/README.md

git clone --recursive https://github.com/combust/mleap.git
cd mleap
git checkout tags/v0.22.0
sbt mleap-databricks-runtime-fat/assembly
cp mleap-databricks-runtime-fat/target/scala-2.12/mleap-databricks-runtime-fat-assembly-0.22.0.jar ../spark-3.3.3-bin-hadoop3/jars

Install python requirements

pip install mleap==0.22.0
pip install xgboost==1.7.3
pip install pyarrow

Running the example

cd ../spark-3.3.3-bin-hadoop3
./bin/spark-submit example.py

error log

23/09/13 22:06:29 INFO CodeGenerator: Code generated in 25.565617 ms
+-----+-----+-----+-----+-----+-----+-----+-----+------+------+------+-----+--------------------+-------------+----------+-----------+
|feat1|feat2|feat3|feat4|feat5|feat7|feat8|feat9|feat10|feat11|feat12|label|            features|rawPrediction|prediction|probability|
+-----+-----+-----+-----+-----+-----+-----+-----+------+------+------+-----+--------------------+-------------+----------+-----------+
|    7|   20|    3|    6|    1|   10|    3|53948|245351|     1|     2|    1|[7.0,20.0,3.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    7|   20|    3|    6|    1|   10|    3|53948|245351|     1|     2|    1|[7.0,20.0,3.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    7|   20|    1|    6|    1|   10|    3|53948|245351|     1|     2|    0|[7.0,20.0,1.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    7|   20|    1|    6|    1|   10|    3|53948|245351|     1|     2|    0|[7.0,20.0,1.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    5|   20|    1|    6|    1|   10|    3|53948|245351|     1|     2|    0|[5.0,20.0,1.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    5|   20|    3|    6|    1|   10|    3|53948|245351|     1|     2|    1|[5.0,20.0,3.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
+-----+-----+-----+-----+-----+-----+-----+-----+------+------+------+-----+--------------------+-------------+----------+-----------+

Traceback (most recent call last):
  File "/Users/s0a018g/local_setup/spark-3.3.3-bin-hadoop3/bin/code1.py", line 47, in <module>
    model.serializeToBundle(local_path, predictions)
  File "/Users/s0a018g/opt/anaconda3/lib/python3.8/site-packages/mleap/pyspark/spark_support.py", line 25, in serializeToBundle
    serializer.serializeToBundle(self, path, dataset=dataset)
  File "/Users/s0a018g/opt/anaconda3/lib/python3.8/site-packages/mleap/pyspark/spark_support.py", line 42, in serializeToBundle
    self._java_obj.serializeToBundle(transformer._to_java(), path, dataset._jdf)
  File "/Users/s0a018g/local_setup/spark-3.3.3-bin-hadoop3/python/lib/pyspark.zip/pyspark/ml/pipeline.py", line 363, in _to_java
AttributeError: 'SparkXGBClassifierModel' object has no attribute '_to_java'
23/09/13 22:06:30 INFO SparkContext: Invoking stop() from shutdown hook
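The failure can be reproduced in miniature: MLeap's `spark_support` wrapper hands the transformer to the JVM via `transformer._to_java()`, and the new `SparkXGBClassifierModel` simply has no such method. A minimal sketch of that call chain, using stub classes rather than the real MLeap/pyspark code:

```python
class StubSerializer:
    """Mimics the call chain in mleap.pyspark.spark_support (stub, not the real class)."""
    def serializeToBundle(self, transformer, path, dataset):
        # This is the step that fails for the new XGBoost API:
        jvm_transformer = transformer._to_java()
        return (jvm_transformer, path)

class SparkXGBClassifierModelStub:
    """Stand-in for the pure-Python model: no _to_java method."""
    pass

try:
    StubSerializer().serializeToBundle(SparkXGBClassifierModelStub(), "/tmp/m.zip", None)
except AttributeError as e:
    print(e)  # 'SparkXGBClassifierModelStub' object has no attribute '_to_java'
```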

example.py

Thanks @agsachin for creating instructions that can be easily reproduced on a Mac.

venkatacrc commented 11 months ago

@WeichenXu123 Any insights on this issue? Please help us.

austinzh commented 10 months ago

@venkatacrc MLeap's serializeToBundle works only with Java-backed objects. The latest XGBoost estimator is a pure Python implementation rather than a Java estimator.

Old API:

class XGboostEstimator(JavaEstimator, XGBoostReadable, JavaMLWritable, ParamGettersSetters):

vs. new API:

class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
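The difference can be sketched with stub base classes: the old estimator inherits from `JavaEstimator`, whose JVM-wrapping machinery supplies `_to_java`, while the new one inherits from the plain `Estimator` base, which does not. The classes below are illustrative stubs, not the real pyspark/xgboost classes:

```python
class JavaEstimator:
    """Stub: pyspark's JavaEstimator wraps a JVM object and exposes _to_java."""
    def _to_java(self):
        return "jvm-handle"

class Estimator:
    """Stub: pyspark's plain Estimator base has no JVM counterpart."""
    pass

class OldXGboostEstimator(JavaEstimator):    # old API: Java-backed
    pass

class NewSparkXGBEstimator(Estimator):       # new API: pure Python
    pass

print(hasattr(OldXGboostEstimator(), "_to_java"))   # True
print(hasattr(NewSparkXGBEstimator(), "_to_java"))  # False
```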

venkatacrc commented 10 months ago

@austinzh Thank you. Will there be Java-backed support, like the old API, in the future?