aws-samples / amazon-sagemaker-local-mode

Amazon SageMaker Local Mode Examples
MIT No Attribution

TypeError: predict() got an unexpected keyword argument 'pred_contribs' with xgboost v0.90 #12

Closed kevalshah90 closed 2 years ago

kevalshah90 commented 2 years ago

I used the inference.py file to make predictions locally. I am getting an error with the pred_contribs argument in the predict_fn() function. This was an issue with old versions, but anything >0.8 recognizes the pred_contribs argument.

eitansela commented 2 years ago

Hello @kevalshah90, this sample code uses pred_contribs=True and completes successfully. What do you mean by old version?

kevalshah90 commented 2 years ago

Hi @eitansela

I am using the inference.py file and have trained my model using xgboost v0.90.

from xgboost import XGBRegressor
model = XGBRegressor()

However, when I run the script and invoke the endpoint to make a prediction, I run into the error. Here's what my inference.py code looks like:

import json
import os
from io import BytesIO
import pickle as pkl
import numpy as np
import sagemaker_xgboost_container.encoder as xgb_encoders
import xgboost as xgb
from os import listdir
from scipy import sparse

# Load your model
def model_fn(model_dir):
    """
    Deserialize and return fitted model.
    """
    model_file = "xgboost-model"
    booster = pkl.load(open(os.path.join(model_dir, model_file), "rb"))

    return booster

def input_fn(request_body, request_content_type):
    """
    The SageMaker XGBoost model server receives the request data body and the content type,
    and invokes the `input_fn`.

    Return a DMatrix (an object that can be passed to predict_fn).
    """

    if request_content_type == "text/csv":

        values = [i for i in request_body.split(',')]

        values = [val.strip() for val in values]

        # to 2-d numpy array
        npa = np.array(values).reshape(-1,1)

        return npa

    if request_content_type == "text/libsvm":

        return xgb_encoders.libsvm_to_dmatrix(request_body)

    else:
        raise ValueError("Content type {} is not supported.".format(request_content_type))

# Run Predictions
def predict_fn(input_data, model):
    """
    SageMaker XGBoost model server invokes `predict_fn` on the return value of `input_fn`.

    Return a two-dimensional NumPy array where the first columns are predictions
    and the remaining columns are the feature contributions (SHAP values) for that prediction.
    """

    names = model.get_booster().feature_names

    prediction = model.predict(input_data, validate_features=False)

    feature_contribs = model.predict(input_data, pred_contribs=True, validate_features=False)

    output = np.hstack((prediction[:, np.newaxis], feature_contribs))

    return output

def output_fn(predictions, content_type):
    """
    After invoking predict_fn, the model server invokes `output_fn`.
    """
    if content_type == "text/csv":
        return ",".join(str(x) for x in predictions[0])
    else:
        raise ValueError("Content type {} is not supported.".format(content_type))
eitansela commented 2 years ago

Got it. Please open a new issue in the SageMaker XGBoost Container GitHub repo.
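
A possible reading of the error, offered as a sketch and not as a confirmed diagnosis: the model in this thread is an XGBRegressor, and the scikit-learn wrapper's predict() does not take pred_contribs, while Booster.predict() does. The snippet below shows one way to check which object accepts the keyword before calling it. StandInBooster and StandInRegressor are hypothetical stand-ins that only mimic the two predict() signatures, and accepts_keyword and resolve_contrib_predictor are illustrative helper names, not part of xgboost or the SageMaker container.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if `func` declares `name` or takes arbitrary **kwargs."""
    params = inspect.signature(func).parameters
    return name in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )


def resolve_contrib_predictor(model):
    """Pick an object whose predict() accepts pred_contribs.

    Tries the model itself first, then model.get_booster() if it exists.
    """
    if accepts_keyword(model.predict, "pred_contribs"):
        return model
    get_booster = getattr(model, "get_booster", None)
    if callable(get_booster):
        booster = get_booster()
        if accepts_keyword(booster.predict, "pred_contribs"):
            return booster
    raise TypeError("no predict() accepting pred_contribs found on the model")


# Hypothetical stand-ins: they copy only the shape of the two signatures.
class StandInBooster:
    def predict(self, data, pred_contribs=False, validate_features=True):
        # Contributions have one column per feature plus one bias column.
        width = len(data[0]) + 1 if pred_contribs else 1
        return [[0.0] * width for _ in data]


class StandInRegressor:
    def __init__(self):
        self._booster = StandInBooster()

    def get_booster(self):
        return self._booster

    def predict(self, data, validate_features=True):
        # No pred_contribs parameter, like the scikit-learn style wrapper.
        return [0.0 for _ in data]


model = StandInRegressor()
predictor = resolve_contrib_predictor(model)
contribs = predictor.predict([[1.0, 2.0, 3.0]], pred_contribs=True)
```

With a real XGBRegressor, the booster's predict() expects an xgboost.DMatrix, so the NumPy array returned by input_fn would need to be wrapped in one before the call.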