aws / sagemaker-spark

A Spark library for Amazon SageMaker.
https://aws.github.io/sagemaker-spark/
Apache License 2.0

SageMakerModel.transform() doesn't use model's sagemakerClient #28

Open harthur opened 6 years ago

harthur commented 6 years ago

If you create a new SageMakerModel instance (say, with fromModelS3Path()), you can pass in your own sagemakerClient. However, when you go to use the model after it's been created, transform() does not use that client to send prediction requests. Instead, it appears to hardcode an AmazonSageMakerRuntimeClientBuilder.defaultClient in RequestBatchIterator.

Pardon my ignorance, but is there a reason that it can't just pass the sagemakerClient through?

andremoeller commented 6 years ago

Hi @harthur,

Thanks for using Amazon SageMaker!

There are two SageMaker clients: the AmazonSageMaker client which is used to create and manage Training Jobs, Endpoints and such, and the AmazonSageMakerRuntime which is just used for predictions (with InvokeEndpointRequest in transform()).

Instead of injecting this client, you have to change the value of this var in the RequestBatchIterator singleton:

RequestBatchIterator.sagemakerRuntime = mySageMakerRuntimeClient

Why? Because Spark has to serialize tasks to send them to workers in the mapPartitions() call in SageMakerModel.transform(), and these AWS clients aren't serializable. So instead of serializing the RequestBatchIterator directly, we serialize a factory method that creates a RequestBatchIterator.

https://github.com/aws/sagemaker-spark/blob/master/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/SageMakerModel.scala#L509

https://github.com/aws/sagemaker-spark/blob/master/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/transformation/util/RequestBatchIterator.scala#L228-L239
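The serialization constraint described above can be sketched without Spark or the AWS SDK. This is only an illustration, not the library's code: FakeRuntimeClient, ClientHolder, CapturingTask, FactoryTask and SerializationDemo are hypothetical names, with ClientHolder standing in for the var on the RequestBatchIterator singleton.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Hypothetical stand-in for an AWS client; like the real clients, it is not Serializable.
class FakeRuntimeClient(val name: String) {
  def invoke(payload: String): String = s"$name:$payload"
}

// Hypothetical per-JVM holder, analogous to the var on the RequestBatchIterator singleton.
object ClientHolder {
  @volatile var client: FakeRuntimeClient = new FakeRuntimeClient("default")
}

// Captures the client directly, so Java serialization rejects it.
class CapturingTask(client: FakeRuntimeClient) extends Serializable {
  def run(rows: Iterator[String]): Iterator[String] = rows.map(client.invoke)
}

// Captures only plain data and looks the client up at the moment it runs.
class FactoryTask(val prefix: String) extends Serializable {
  def run(rows: Iterator[String]): Iterator[String] =
    rows.map(row => ClientHolder.client.invoke(prefix + row))
}

object SerializationDemo {
  // Writes a value with Java serialization, the mechanism Spark uses for closures.
  def serialize(value: AnyRef): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(value) finally out.close()
    bytes.toByteArray
  }
}
```

In this sketch, serializing a CapturingTask throws java.io.NotSerializableException, while a FactoryTask serializes cleanly because the client never travels with it. The trade-off is that the client has to be reachable through per-JVM state, which is why the override goes through a singleton rather than a constructor argument.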

Please feel free to reopen this if it doesn't answer your question. Thanks!

harthur commented 6 years ago

That does answer my question, thanks!

However, how do you set RequestBatchIterator.sagemakerRuntime without first unintentionally building the default client? Before I can override it, I believe this line builds the client and looks for environment variables: https://github.com/aws/sagemaker-spark/blob/master/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/transformation/util/RequestBatchIterator.scala#L34. I'm getting this error: Caused by: com.amazonaws.SdkClientException: Unable to find a region via the region provider chain. Must provide an explicit region in the builder or setup environment to supply a region.

Let me know if this isn't the right place to ask more questions.

andremoeller commented 6 years ago

Hey @harthur,

I'm not sure exactly when the client is instantiated. It's possible we should make that a lazy val or otherwise delay instantiation. Do you have a stack trace / code you're running where the client is instantiated before you can set it?
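To illustrate why delaying instantiation would matter, here is a minimal sketch of eager versus deferred initialization in a Scala object. It is not the library's code: FakeBuilder, EagerHolder, DeferredHolder and the demo.region property are hypothetical, and FakeBuilder only imitates a default-client builder that fails when no region is configured.

```scala
// Hypothetical stand-in for a default-client builder that needs a configured region.
object FakeBuilder {
  def defaultClient(): String =
    Option(System.getProperty("demo.region")) match {
      case Some(region) => s"client($region)"
      case None => throw new IllegalStateException("Unable to find a region")
    }
}

// Eager: the initializer runs the first time the object is touched,
// even when that first touch is an assignment meant to replace the default.
object EagerHolder {
  var client: String = FakeBuilder.defaultClient()
}

// Deferred: the default is built only if no override was supplied before first use.
object DeferredHolder {
  @volatile private var overrideClient: Option[String] = None
  private lazy val defaultClient: String = FakeBuilder.defaultClient()

  def set(client: String): Unit = overrideClient = Some(client)
  def client: String = overrideClient.getOrElse(defaultClient)
}
```

With no region configured, assigning to EagerHolder.client fails before the assignment takes effect, because the object's initializer runs first. DeferredHolder.set("mine") succeeds, and the default builder is never invoked.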

Otherwise: you're right about the client builder -- to get around that error, you can set AWS_DEFAULT_REGION to your region (like us-west-2 or us-east-1 or us-east-2 or eu-west-1, any of the SageMaker regions) or set it in your AWS config file with aws configure.

Thanks!

harthur commented 6 years ago

The stacktrace is

Exception in thread "main" java.lang.ExceptionInInitializerError
    at TestSagemakerJob.run(TestSagemakerJob.scala:59)
...
Caused by: com.amazonaws.SdkClientException: Unable to find a region via the region provider chain. Must provide an explicit region in the builder or setup environment to supply a region.
    at com.amazonaws.client.builder.AwsClientBuilder.setRegion(AwsClientBuilder.java:371)
    at com.amazonaws.client.builder.AwsClientBuilder.configureMutableProperties(AwsClientBuilder.java:337)
    at com.amazonaws.client.builder.AwsSyncClientBuilder.build(AwsSyncClientBuilder.java:46)
    at com.amazonaws.services.sagemakerruntime.AmazonSageMakerRuntimeClientBuilder.defaultClient(AmazonSageMakerRuntimeClientBuilder.java:44)
    at com.amazonaws.services.sagemaker.sparksdk.transformation.util.RequestBatchIterator$.<init>(RequestBatchIterator.scala:35)
    at com.amazonaws.services.sagemaker.sparksdk.transformation.util.RequestBatchIterator$.<clinit>(RequestBatchIterator.scala)

Where that line is:

RequestBatchIterator.sagemakerRuntime = sagemakerRuntimeClient

andremoeller commented 6 years ago

Hi @harthur,

Thanks for the stacktrace! Just FYI: I haven't gotten a chance to reproduce this yet, but this definitely seems like a bug. I suppose that workers are still trying to create the standard client. If you're able to, could you post your code?

Otherwise, to unblock yourself in the short term, it seems like you'll have to get the region from the environment in your workers (by setting AWS_DEFAULT_REGION or writing to ~/.aws/config first).

Thanks again!

harthur commented 6 years ago

Yeah, I got around it by adding some Java system properties, but ideally you would be able to build your own client and keep that info isolated. That's the pattern we use for all of our other AWS connections, so it's a bit awkward to break that just for SageMaker.

harthur commented 6 years ago

Sorry, to address your first question, I think it's happening in the driver rather than on any workers (or, that's where this particular line of code is). It happens just by instantiating RequestBatchIterator.

andremoeller commented 6 years ago

Hey @harthur,

Ah, interesting, thanks for the update! Glad to hear you got it working, but you're right, we should let users build their own client. I've put a fix for this on our backlog, thanks for reporting this.

We'll keep this issue open and update this when the fix is in, but I can't give an ETA on when we'll be able to do this.

Thanks again!

sreemani commented 6 years ago

Hi @harthur, I am also facing this same issue. I created a model from an endpoint, and calling transform() gives this error. Could you please share the Java system properties that helped you get around it?

Thanks

harthur commented 6 years ago

@sreemani You have to set these system properties: https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html#credentials-default
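As a rough sketch of that workaround, assuming the AWS SDK for Java v1 and its aws.region system property (credentials use aws.accessKeyId and aws.secretKey in the same way), the property can be set before RequestBatchIterator is first referenced. SdkDefaults and ensureRegion are hypothetical helper names, and "us-west-2" is only an example region.

```scala
object SdkDefaults {
  // System property consulted by the AWS SDK for Java v1 region provider chain.
  val RegionProperty = "aws.region"

  /** Sets the region property unless the JVM already has one; returns the effective value. */
  def ensureRegion(region: String): String = {
    if (System.getProperty(RegionProperty) == null) {
      System.setProperty(RegionProperty, region)
    }
    System.getProperty(RegionProperty)
  }
}

// On the driver, before the first reference to RequestBatchIterator:
//   SdkDefaults.ensureRegion("us-west-2")
//   RequestBatchIterator.sagemakerRuntime = sagemakerRuntimeClient
```

Executors are separate JVMs, so the same property probably has to be set there as well, for example with -Daws.region=... in spark.executor.extraJavaOptions.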

dbtorrance commented 3 years ago

Any update on this issue? This bug has been open since 2018, and it is causing our team some problems.