rstudio / vetiver-r

Version, share, deploy, and monitor models
https://rstudio.github.io/vetiver-r/

Getting class probabilities from vetiver / Sagemaker endpoint? #248

Closed: sebsilas closed this issue 11 months ago

sebsilas commented 11 months ago

Hi again!

So, I've got my vetiver-Sagemaker hookup working, using the setup described here.

Thus, when I run:

library(dplyr)    # slice_sample() and %>%
library(vetiver)  # predict() method for vetiver endpoints

new_data <- my_data %>% 
  slice_sample(n = 50) 

predict(new_endpoint, new_data)

I am getting back class predictions, as expected (the model is an xgboost classification model).

But how can I return the class probabilities? Normally I would add type = 'prob' to predict() in a tidymodels setting, but the following does not work:

predict(new_endpoint, new_data, type = 'prob')

Error in sm_runtime$invoke_endpoint(object$model_endpoint, data_json, : unused argument (type = "prod")

Is it possible?

juliasilge commented 11 months ago

Yes, absolutely! This is one of the aspects of your model API you need to set up when you deploy the model, though, not something you pass in when you call your API via predict().

If you are working with a vetiver_api() object, you pass arguments like this through the dots (...):

pr() |> vetiver_api(v, type = "prob")
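A fuller local sketch of the vetiver_api() route, assuming the parsnip, workflows, xgboost, vetiver, and plumber packages are installed. The iris data, the model name "iris-xgb", and port 8080 are placeholders chosen for illustration, not part of the original thread. It contrasts the usual tidymodels call, where type = "prob" goes to predict(), with the API case, where the same argument is fixed when the router is built:

```r
library(parsnip)    # boost_tree(), set_engine(), fit()
library(workflows)  # workflow()
library(vetiver)    # vetiver_model(), vetiver_api()
library(plumber)    # pr(), pr_run()

# Fit a small xgboost classifier; iris stands in for your own data
wf_fit <- workflow(
  Species ~ .,
  boost_tree(mode = "classification") |> set_engine("xgboost")
) |>
  fit(data = iris)

new_data <- iris[1:5, -5]

# Locally, tidymodels returns probabilities when you ask at predict() time
local_probs <- predict(wf_fit, new_data, type = "prob")
local_probs

# For an API, type = "prob" is fixed when the router is built
v <- vetiver_model(wf_fit, "iris-xgb")
api <- pr() |> vetiver_api(v, type = "prob")

# Serving blocks the session, so only start it interactively
if (interactive()) pr_run(api, port = 8080)
```

Once the API is running, a call from a second R session such as predict(vetiver_endpoint("http://127.0.0.1:8080/predict"), new_data) should return one .pred_<class> column per class instead of .pred_class, without any type argument at call time.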

If you are working with one of the deployment functions, such as vetiver_deploy_sagemaker(), then you pass it in via predict_args:

endpoint <- vetiver_deploy_sagemaker(
    board = your_board,
    name = "your_model_name",
    instance_type = "ml.t2.medium",
    predict_args = list(type = "prob")
)

If it seems confusing that you set this up when you create your API rather than when you call it, it may help to think of deployment as the step where you fix exactly what code will run whenever the API is called later. Once you have deployed or set up your API, you can't change what code it runs.
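To make the deploy-time versus call-time distinction concrete, here is a small base-R analogy. make_handler() is a hypothetical helper, not how vetiver is implemented internally, and a logistic regression on mtcars stands in for the deployed classifier. Extra arguments are captured when the handler is created, so the caller can only supply new data:

```r
# Hypothetical helper: an analogy for deployment, not vetiver's internals.
# Extra predict() arguments are captured once, when the handler is created.
make_handler <- function(model, ...) {
  force(model)
  predict_args <- list(...)
  # The returned function accepts only new data, like a deployed endpoint
  function(new_data) {
    do.call(predict, c(list(model, new_data), predict_args))
  }
}

# A logistic regression stands in for the deployed classification model
fit <- glm(am ~ wt, data = mtcars, family = binomial)

# "Deploy" twice: once with default arguments, once fixing type = "response"
link_handler <- make_handler(fit)
prob_handler <- make_handler(fit, type = "response")

new_data <- mtcars[1:3, ]
link_handler(new_data)  # log-odds, the default scale for predict.glm()
prob_handler(new_data)  # probabilities between 0 and 1

# Passing type at call time fails with an "unused argument" error,
# much like the call to the SageMaker endpoint in the question
tryCatch(prob_handler(new_data, type = "link"), error = conditionMessage)
```

The only way to get a different output type is to build (deploy) a new handler with different arguments, which mirrors redeploying the endpoint with a new predict_args.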

sebsilas commented 11 months ago

Brilliant, thank you again @juliasilge. I can confirm that's all working!

What you said makes total sense to me, I just wondered if the class probabilities were already being returned somehow unbeknownst to me.

Thank you for this great package!