venkatacrc opened 11 months ago
@WeichenXu123 Any insights on this issue? Please help us.
@venkatacrc MLeap `serializeToBundle` works only with Java objects.
The latest XGBoost estimator is a pure Python implementation rather than a Java estimator.
Old API:

```python
class XGboostEstimator(JavaEstimator, XGBoostReadable, JavaMLWritable, ParamGettersSetters):
```

vs. new API:

```python
class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
```
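The practical consequence can be shown with a minimal, self-contained sketch. The classes below are stand-ins for illustration, not the real hierarchies: PySpark's Java wrapper mixins supply `_to_java()`, which MLeap's serializer calls, while a pure-Python model has no such method.

```python
# Stand-in classes (hypothetical, for illustration only): the old API
# inherited from PySpark's Java wrappers, which provide _to_java(); the
# new pure-Python estimator does not, which is what MLeap trips over.

class JavaMLWritableStandIn:
    """Mimics the Java-wrapper mixin that supplies _to_java()."""
    def _to_java(self):
        return "<py4j handle to the JVM-side model>"

class OldStyleModel(JavaMLWritableStandIn):   # old API: Java-backed
    pass

class NewStyleModel:                          # new API: pure Python
    pass

def mleap_can_serialize(transformer):
    """MLeap's spark_support.py ultimately calls transformer._to_java()."""
    return callable(getattr(transformer, "_to_java", None))

print(mleap_can_serialize(OldStyleModel()))   # True
print(mleap_can_serialize(NewStyleModel()))   # False
```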
@austinzh Thank you. Will there be Java support, like the old API, in the future?
Problem description: We were able to serialize the XGBoost model with MLeap using the older PySpark API (https://github.com/dmlc/xgboost/pull/4656) as shown below:
But we are not able to do this with the official PySpark API (https://xgboost.readthedocs.io/en/stable/tutorials/spark_estimator.html). Are there any flags I need to use to make this work? We use MLeap to store the model as a serialized bundle and use it in the Java runtime environment for model serving.
Steps Followed:
Download Spark:

```shell
wget https://dlcdn.apache.org/spark/spark-3.3.3/spark-3.3.3-bin-hadoop3.tgz
tar -xvf spark-3.3.3-bin-hadoop3.tgz
```
Download the XGBoost jars:

```shell
wget https://repo1.maven.org/maven2/ml/dmlc/xgboost4j_2.12/1.7.3/xgboost4j_2.12-1.7.3.jar
wget https://repo1.maven.org/maven2/ml/dmlc/xgboost4j-spark_2.12/1.7.3/xgboost4j-spark_2.12-1.7.3.jar
```
Build the MLeap fat jar (following https://github.com/combust/mleap/blob/master/mleap-databricks-runtime-fat/README.md):

```shell
git clone --recursive https://github.com/combust/mleap.git
cd mleap
git checkout tags/v0.22.0
sbt mleap-databricks-runtime-fat/assembly
cp mleap-databricks-runtime-fat/target/scala-2.12/mleap-databricks-runtime-fat-assembly-0.22.0.jar ../spark-3.3.3-bin-hadoop3/jars
```
Install Python requirements:

```shell
pip install mleap==0.22.0
pip install xgboost==1.7.3
pip install pyarrow
```
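A quick sanity check that may be worth running at this point (an assumption on my part, not something from the thread: mismatched Python-package and jar versions are a common failure mode in mixed Python/JVM setups). The snippet prints the installed Python-side versions so they can be compared against the xgboost4j 1.7.3 and MLeap 0.22.0 jars on the Spark classpath:

```python
# Report installed versions of the Python packages so they can be checked
# against the jars copied into spark-3.3.3-bin-hadoop3/jars above.
import importlib.metadata as md  # stdlib since Python 3.8

def installed_versions(pkgs=("mleap", "xgboost", "pyarrow")):
    """Return a name -> version mapping; None if a package is missing."""
    out = {}
    for pkg in pkgs:
        try:
            out[pkg] = md.version(pkg)
        except md.PackageNotFoundError:
            out[pkg] = None
    return out

print(installed_versions())
```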
Run the example:

```shell
cd ../spark-3.3.3-bin-hadoop3
./bin/spark-submit example.py
```
Error log:

```
23/09/13 22:06:29 INFO CodeGenerator: Code generated in 25.565617 ms
+-----+-----+-----+-----+-----+-----+-----+-----+------+------+------+-----+--------------------+-------------+----------+-----------+
|feat1|feat2|feat3|feat4|feat5|feat7|feat8|feat9|feat10|feat11|feat12|label|            features|rawPrediction|prediction|probability|
+-----+-----+-----+-----+-----+-----+-----+-----+------+------+------+-----+--------------------+-------------+----------+-----------+
|    7|   20|    3|    6|    1|   10|    3|53948|245351|     1|     2|    1|[7.0,20.0,3.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    7|   20|    3|    6|    1|   10|    3|53948|245351|     1|     2|    1|[7.0,20.0,3.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    7|   20|    1|    6|    1|   10|    3|53948|245351|     1|     2|    0|[7.0,20.0,1.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    7|   20|    1|    6|    1|   10|    3|53948|245351|     1|     2|    0|[7.0,20.0,1.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    5|   20|    1|    6|    1|   10|    3|53948|245351|     1|     2|    0|[5.0,20.0,1.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    5|   20|    3|    6|    1|   10|    3|53948|245351|     1|     2|    1|[5.0,20.0,3.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
+-----+-----+-----+-----+-----+-----+-----+-----+------+------+------+-----+--------------------+-------------+----------+-----------+
```
```
Traceback (most recent call last):
  File "/Users/s0a018g/local_setup/spark-3.3.3-bin-hadoop3/bin/code1.py", line 47, in <module>
    model.serializeToBundle(local_path, predictions)
  File "/Users/s0a018g/opt/anaconda3/lib/python3.8/site-packages/mleap/pyspark/spark_support.py", line 25, in serializeToBundle
    serializer.serializeToBundle(self, path, dataset=dataset)
  File "/Users/s0a018g/opt/anaconda3/lib/python3.8/site-packages/mleap/pyspark/spark_support.py", line 42, in serializeToBundle
    self._java_obj.serializeToBundle(transformer._to_java(), path, dataset._jdf)
  File "/Users/s0a018g/local_setup/spark-3.3.3-bin-hadoop3/python/lib/pyspark.zip/pyspark/ml/pipeline.py", line 363, in _to_java
AttributeError: 'SparkXGBClassifierModel' object has no attribute '_to_java'
23/09/13 22:06:30 INFO SparkContext: Invoking stop() from shutdown hook
```
example.py
Thanks @agsachin for creating instructions that can be easily reproduced on a Mac.