aws / sagemaker-spark

A Spark library for Amazon SageMaker.
https://aws.github.io/sagemaker-spark/
Apache License 2.0

What is the correct way to construct a ProtobufResponseRowDeserializer in PySpark? #118

Open RunshengSong opened 4 years ago

RunshengSong commented 4 years ago


Describe the problem

I have the following PySpark code, which tries to construct a SageMakerEstimator for the Random Cut Forest image:

# Random Cut Forest Estimator
import boto3
from pyspark.sql.types import StructType, StructField, DoubleType
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker_pyspark import IAMRole
from sagemaker_pyspark import SageMakerEstimator
from sagemaker_pyspark import RandomNamePolicyFactory
from sagemaker_pyspark import EndpointCreationPolicy
from sagemaker_pyspark.transformation.serializers.serializers import ProtobufRequestRowSerializer
from sagemaker_pyspark.transformation.deserializers.deserializers import ProtobufResponseRowDeserializer

region = boto3.Session().region_name  # region of the current AWS session
role = get_execution_role()           # IAM role ARN of the SageMaker notebook

# Schema of the prediction rows: one non-nullable "score" column
response_schema = StructType([StructField("score", DoubleType(), False)])

estimator = SageMakerEstimator(
    trainingImage=get_image_uri(region, "randomcutforest"),  # Training image
    modelImage=get_image_uri(region, "randomcutforest"),     # Model image
    requestRowSerializer=ProtobufRequestRowSerializer(featuresColumnName="features"),
    # This list argument is what triggers the Py4JError below
    responseRowDeserializer=ProtobufResponseRowDeserializer(response_schema, protobufKeys=["score"]),
    sagemakerRole=IAMRole(role),
    hyperParameters={"feature_dim": "6"},
    trainingInstanceType="ml.m4.4xlarge",
    trainingInstanceCount=1,
    endpointInstanceType="ml.t2.medium",
    endpointInitialInstanceCount=1,
    trainingSparkDataFormat="sagemaker",
    namePolicyFactory=RandomNamePolicyFactory("sparksm-4-"),
    endpointCreationPolicy=EndpointCreationPolicy.CREATE_ON_CONSTRUCT,
)

When I run this code in PySpark, I get the following error:

Py4JError: An error occurred while calling None.com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.ProtobufResponseRowDeserializer. Trace:
py4j.Py4JException: Constructor com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.ProtobufResponseRowDeserializer([class org.apache.spark.sql.types.StructType, class scala.collection.immutable.$colon$colon]) does not exist
    at py4j.reflection.ReflectionEngine.getConstructor(ReflectionEngine.java:179)
    at py4j.reflection.ReflectionEngine.getConstructor(ReflectionEngine.java:196)
    at py4j.Gateway.invoke(Gateway.java:237)
    at py4j.commands.ConstructorCommand.invokeConstructor(ConstructorCommand.java:80)
    at py4j.commands.ConstructorCommand.execute(ConstructorCommand.java:69)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:745)

The problem is in the ProtobufResponseRowDeserializer. According to the Scala source code of this class, it should accept a Seq.

What is the correct counterpart in PySpark? Evidently it does not accept a list of strings.

I searched the sagemaker-spark SDK but couldn't find any reference to this.
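The bracketed list in the trace gives the JVM types of the arguments Py4J passed to the constructor. The second one, `scala.collection.immutable.$colon$colon`, is the compiled name of Scala's `::` class (the non-empty immutable `List`), so the Python list arrived on the JVM side as a Scala `List`, and no `ProtobufResponseRowDeserializer` constructor accepting `(StructType, List)` was found. The helper below is a hypothetical, naive sketch (not part of Py4J or sagemaker-pyspark) that decodes such names back into the symbols they encode, using the operator names from `scala.reflect.NameTransformer`.

```python
import re

# Operator names that the Scala compiler substitutes for symbolic characters
# in JVM-level identifiers (the encoding used by scala.reflect.NameTransformer).
SCALA_OPERATOR_NAMES = {
    "tilde": "~", "eq": "=", "less": "<", "greater": ">", "bang": "!",
    "hash": "#", "percent": "%", "up": "^", "amp": "&", "bar": "|",
    "times": "*", "div": "/", "plus": "+", "minus": "-", "colon": ":",
    "bslash": "\\", "qmark": "?", "at": "@",
}
_ENCODED = re.compile(r"\$(" + "|".join(SCALA_OPERATOR_NAMES) + ")")


def decode_scala_name(jvm_name: str) -> str:
    """Replace each $colon, $plus, ... with the character it encodes."""
    return _ENCODED.sub(lambda match: SCALA_OPERATOR_NAMES[match.group(1)], jvm_name)


print(decode_scala_name("scala.collection.immutable.$colon$colon"))
# scala.collection.immutable.::
```

This is plain text substitution, so it does not cover every case the real `NameTransformer.decode` handles; it is only meant for reading class names in stack traces like the one above.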

chuyang-deng commented 4 years ago

Hi @RunshengSong, to use ProtobufResponseRowDeserializer with pyspark-sdk, the constructor accepts a StructType instead of a string: https://github.com/aws/sagemaker-spark/blob/master/sagemaker-pyspark-sdk/src/sagemaker_pyspark/transformation/deserializers/deserializers.py#L30

You will need to build a StructType (https://spark.apache.org/docs/1.1.1/api/python/pyspark.sql.StructType-class.html) that contains the feature column field and feed it to the ProtobufResponseRowDeserializer constructor.

RunshengSong commented 4 years ago

Hi @ChuyangDeng, thanks for the reply. I understand that I need to pass a StructType as the schema to ProtobufResponseRowDeserializer, which is already the case in the code I provided above.

However, my question is about the protobufKeys parameter. When I don't pass it, I get an NPE (NullPointerException) when I display the DataFrame of prediction output.

What should the type of the protobufKeys parameter be?

Thanks again.
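As a pure-Python model only, the sketch below shows what a well-formed `protobufKeys` value would look like if, as assumed here and not confirmed above, it lists the keys to read from each decoded response record, one string per schema field and in schema order. Records are plain dicts standing in for real recordio-protobuf messages, the sample scores are made up, and `validate_protobuf_keys` and `records_to_rows` are hypothetical helpers, not sagemaker-pyspark API.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def validate_protobuf_keys(field_names: Sequence[str],
                           protobuf_keys: Optional[Sequence[str]]) -> List[str]:
    """Check that there is exactly one string key per schema field (hypothetical helper)."""
    if protobuf_keys is None:
        # This sketch requires explicit keys; it does not guess a default.
        raise ValueError("protobuf_keys must be given explicitly in this sketch")
    keys = list(protobuf_keys)
    if not all(isinstance(key, str) for key in keys):
        raise TypeError("every protobuf key must be a str")
    if len(keys) != len(field_names):
        raise ValueError(f"expected {len(field_names)} keys, got {len(keys)}")
    return keys


def records_to_rows(records: Sequence[Dict[str, List[float]]],
                    field_names: Sequence[str],
                    protobuf_keys: Optional[Sequence[str]]) -> List[Tuple[float, ...]]:
    """Turn decoded response records into tuples ordered like the schema fields."""
    keys = validate_protobuf_keys(field_names, protobuf_keys)
    rows = []
    for record in records:
        # For each schema position, take the first value stored under the matching key.
        rows.append(tuple(record[key][0] for key in keys))
    return rows


# Hypothetical decoded responses for two input rows.
example_records = [{"score": [0.93]}, {"score": [1.42]}]
print(records_to_rows(example_records, ["score"], ["score"]))  # [(0.93,), (1.42,)]
```

Under these assumptions a list such as `["score"]` has the right shape, while the bare string `"score"` does not: `list("score")` splits it into five characters, which fails the length check against the one-field schema.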