rstudio / vetiver-r

Version, share, deploy, and monitor models
https://rstudio.github.io/vetiver-r/

Predict function only returns class not probability? #268

Closed HanselPalencia closed 8 months ago

HanselPalencia commented 8 months ago

Let's look at the following example:

Say we have a logistic regression model (GLM) for the mtcars dataset, which we deploy using vetiver as below:

glm(am ~ mpg + wt + drat, data = mtcars, family = "binomial")

Everything works correctly, and when you use the predict function as shown in the docs it returns a 0 or 1, since those are the class labels in the data.

Usually, predict() for GLM models accepts a type parameter that lets you choose whether a probability or a class is returned.

I haven't been able to get this functionality to work here: I've tried passing the type parameter to the regular predict() call, but to no avail.
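For context, this is how the type argument behaves with a plain base-R GLM, fit locally without vetiver (a minimal sketch of stats::predict.glm):

```r
# Base-R GLM on mtcars: am (0/1) predicted from mpg, wt, drat
fit <- glm(am ~ mpg + wt + drat, data = mtcars, family = "binomial")

predict(fit, newdata = head(mtcars))                    # default type = "link": log-odds
predict(fit, newdata = head(mtcars), type = "response") # probabilities in [0, 1]
```

The question is where to pass the equivalent of type = "response" once the model is behind a vetiver endpoint.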

For reference, this is my starting point:

predict(
  endpoint,
  new_data = data_predict(),
  httr::add_headers(Authorization = paste("Key", api_key))
) %>%
  pull(.pred_class)
ncullen93 commented 8 months ago

Hard to tell without seeing how your endpoint is created, but I believe the type parameter has to be passed in when you create your endpoint. Extra parameters to the predict function are essentially hard-coded when the endpoint is created, so they don't appear to be something you can pass dynamically when the endpoint is called.

So you would add type = "response" when building the API, e.g. vetiver_api(..., type = "response") or vetiver_pr_post(..., type = "response").

juliasilge commented 8 months ago

Yes, you'll need to specify what kind of predictions you want when you create your API, not when you call the API. This is because when you create your API, you are specifying what code will be run when you call it later. So something like this:

library(vetiver)
library(plumber)

mtcars_glm <- glm(mpg ~ ., data = mtcars)
v <- vetiver_model(mtcars_glm, "cars_glm")

pr() |> vetiver_api(v, type = "response")
#> # Plumber router with 4 endpoints, 4 filters, and 1 sub-router.
#> # Use `pr_run()` on this object to start the API.
#> ├──[queryString]
#> ├──[body]
#> ├──[cookieParser]
#> ├──[sharedSecret]
#> ├──/logo
#> │  │ # Plumber static router serving from directory: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library/vetiver
#> ├──/metadata (GET)
#> ├──/ping (GET)
#> ├──/predict (POST)
#> └──/prototype (GET)

Created on 2024-01-08 with reprex v2.0.2

If you are using vetiver_deploy_rsconnect() or vetiver_prepare_docker(), you can pass this through like predict_args = list(type = "response").

juliasilge commented 8 months ago

Thanks for your question @HanselPalencia! 🙌 Let us know if further issues come up as you use vetiver.

HanselPalencia commented 7 months ago

Hi @juliasilge & @ncullen93! Thanks for the expert assist. I've implemented the solution from above, but unfortunately it results in a 500 error when calling the API.

I used the vetiver_deploy_rsconnect() function and passed predict_args = list(type = "response"). Has anyone reported any issues with this in the past?

I did some testing on my end, and when I remove this argument from the vetiver_deploy_rsconnect() call I am able to successfully get a prediction back (i.e. a 200 status code).

See below how I generated my API.

v <- vetiver_model(final_fit_to_deploy, model_name = "name")

model_board <- pins::board_connect(auth = "manual", server = "myserver.com", key = "myauthkey", versioned = TRUE)
model_board |> vetiver_pin_write(v)
model_board |>
  vetiver_write_plumber("USER/name")

vetiver_deploy_rsconnect(
  model_board, 
  "USER/name",
  predict_args = list(type = "response") # If I remove this argument, it works, returning a Yes/No response
)

I also went into the plumber.R file generated by the vetiver_write_plumber() function and added the type = "response" argument to the vetiver_api() call, which did not change anything either.

juliasilge commented 7 months ago

@HanselPalencia Can you try setting the debug argument so that errors from R are surfaced in your API response, like predict_args = list(type = "response", debug = TRUE)? If you have further problems (i.e. you don't understand the error you get back), can you open a new issue outlining what you see, rather than commenting on this old, closed one?
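Putting that together with the earlier deploy call, the troubleshooting version would look something like this (a sketch reusing the pin name from the comment above; debug = TRUE exposes error details in responses, so it is only appropriate while debugging, not in production):

```r
vetiver_deploy_rsconnect(
  model_board,
  "USER/name",
  predict_args = list(type = "response", debug = TRUE)
)
```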

Good luck! 🙌