tensorflow / decision-forests

A collection of state-of-the-art algorithms for the training, serving and interpretation of Decision Forest models in Keras.
Apache License 2.0
660 stars 110 forks source link

Check failed: NDIMS == dims() (2 vs. 3) when using TFDF RandomForestModel and using TF Serving #170

Open alexctslade opened 1 year ago

alexctslade commented 1 year ago

Summary: Sometimes features passed to TF Serving need to be wrapped in brackets to match the expected dimensions, and the correct dimensions are required. However, in some cases, if incorrect dimensions are passed, TF Serving crashes when it would be expected to handle the error. In particular:

Detail on the cases below: (1) When TF serving is used with a Tensorflow model with numerical and categorical features, using tensorflow.feature_column to dummify the categorical features

(2) When TF serving is used with a TFDF Random Forest Model with only numeric features

(3) When TF serving is used with a TFDF Random Forest Model with numeric features and categorical features, and a preprocessing model

(4) When TF serving is used with a TFDF Random Forest Model with numeric features and categorical features, and the RandomForestModel categorical_set_split_max_num_items and categorical_set_split_min_item_frequency are used to dummify the categorical field

Related issue https://github.com/tensorflow/tensorflow/issues/9505

System details: Python 3.10; TensorFlow 2.11.1; TensorFlow Decision Forests 1.2.0; tensorflow/serving (latest)

Code to recreate

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras import backend, Input, layers, Model
from tensorflow import feature_column, random, string

import requests
import json

"""Make some fake data"""
def _fake_row():
    """Draw one synthetic example as [x1, x2, tmp_x3, y].

    The target is uniform on [0, 10000). x1 and tmp_x3 are the target plus
    Gaussian noise (mean 1000, std 500); x2 is x1 plus the same kind of noise.
    """
    target = np.random.uniform(0, 10000)
    first = target + np.random.normal(1000, scale=500)
    second = first + np.random.normal(1000, scale=500)
    third = target + np.random.normal(1000, scale=500)
    return [first, second, third, target]


rows = [_fake_row() for _ in range(10000)]
df = pd.DataFrame(rows, columns=['x1', 'x2', 'tmp_x3', 'y'])
# Turn tmp_x3 into a string categorical: the index of its 20-quantile bin,
# prefixed with 'cat'.
x3_bin_index = pd.qcut(df['tmp_x3'], q=20, retbins=False, labels=False)
df['x3'] = 'cat' + x3_bin_index.astype('str')
# Hold out 15% of the rows for validation, then 15% of what remains for test.
train, valid = train_test_split(df, test_size=0.15)
train, test = train_test_split(train, test_size=0.15)
# Print the row count of each partition, in the order train, valid, test.
for partition in (train, valid, test):
    print(partition.shape[0])

"""(1) Model with numerical and cat features"""
# Case (1): plain Keras regression model fed through feature columns.
# `feature_layer_input` maps each feature name to its Keras Input, and
# `feature_columns` collects the matching feature_column definitions that
# DenseFeatures consumes below.
feature_layer_input = {}
feature_columns = []
for feature in ['x1', 'x2']:
    # Numeric features: one value per example (Input shape (1,)).
    feature_columns.append(feature_column.numeric_column(feature))
    feature_layer_input[feature] = Input(shape=(1,), name=feature)
for feature in ['x3']:
    # Categorical feature: hash the string into 10 buckets, then expose the
    # bucket through an indicator column so it can feed a dense layer.
    feature_column_emb = feature_column.categorical_column_with_hash_bucket(
        feature, hash_bucket_size=10
    )
    feature_columns.append(feature_column.indicator_column(feature_column_emb))
    # `string` is the name imported from tensorflow at the top of the script.
    feature_layer_input[feature] = Input(shape=(1,), dtype=string, name=feature)

# Network: DenseFeatures -> Dense(80, relu) -> Dense(1, relu).
feature_layer = layers.DenseFeatures(feature_columns)(feature_layer_input)
hidden_layer = layers.Dense(units=80, activation='relu')(feature_layer)
output_layer = layers.Dense(units=1, activation='relu')(hidden_layer)
model = Model(feature_layer_input, output_layer)
from tensorflow.keras.losses import MeanSquaredError
# Mean-squared-error loss with the Adam optimizer.
model.compile(loss=MeanSquaredError(), optimizer='adam')
model.summary()

# NOTE(review): df_to_dataset comes from a private package that is not part of
# this report. It presumably turns a DataFrame plus labels into a batched
# tf.data.Dataset of (features dict, label) — confirm, or substitute an
# equivalent helper when reproducing.
from beeswax.hexagon.real_time_models.model_common.model_utils import df_to_dataset
from sklearn.metrics import r2_score
features = ['x1', 'x2', 'x3']
train_ds = df_to_dataset(train[features], labels=train['y'])
valid_ds = df_to_dataset(valid[features], labels=valid['y'])
test_ds = df_to_dataset(test[features], labels=test['y'])

model.fit(train_ds, epochs=10)

# R^2 on the validation split. This assumes predict() returns rows in the same
# order as valid['y'], i.e. that df_to_dataset does not shuffle — TODO confirm.
print(r2_score(valid['y'], model.predict(valid_ds)))
# Saved under a numeric version directory ("1"); the TF Serving container is
# later pointed at the parent directory.
model.save('/home/username/model_assessment/test_model/1/')

""" Now run TF Serving docker run -p 8501:8501 --mount type=bind,source=/home/username/model_assessment/test_model,target=/models/test_model -e MODEL_NAME=test_model -t tensorflow/serving """

# Case (1): three predict requests against the Keras model served as
# `test_model`. The payloads differ only in which feature values are wrapped
# in a list.
headers = {"content-type": "application/json"}
case1_url = 'http://localhost:8501/v1/models/test_model:predict'

case1_payloads = (
    # Every feature wrapped in brackets.
    # Success: '{\n    "predictions": [[250.152435]\n    ]\n}'
    {
        "signature_name": "serving_default",
        "instances": [{"x1": [300], "x2": [220], "x3": ["cat11"]}],
    },
    # Floats are fine without brackets
    {
        "signature_name": "serving_default",
        "instances": [{"x1": 300.0, "x2": 220.0, "x3": ["cat11"]}],
    },
    # But cat needs brackets (this payload sends x3 without them)
    {
        "signature_name": "serving_default",
        "instances": [{"x1": 300.0, "x2": 220.0, "x3": "cat11"}],
    },
)

# Send the payloads in order and print each raw response body.
for test_data in case1_payloads:
    response = requests.post(case1_url, data=json.dumps(test_data), headers=headers)
    print(response.text)

"""(2) Random Forest with TFDF - numerical features"""
# Case (2): TF-DF random forest trained on the two numeric features only.
import tensorflow_decision_forests as tfdf
from tensorflow_decision_forests.tensorflow import core_inference as tf_core
# Short alias for the task enum used by every forest below.
Task = tf_core.Task

# Regression forest; the trailing comments on split_axis / growing_strategy
# are the alternative settings noted by the reporter.
rf = tfdf.keras.RandomForestModel(
    task=Task.REGRESSION,
    num_trees=100,
    max_depth=20,
    min_examples=5,
    num_candidate_attributes_ratio=0.45,
    split_axis="AXIS_ALIGNED", #"SPARSE_OBLIQUE",
    growing_strategy="LOCAL", #"BEST_FIRST_GLOBAL",
    )

# Rebuild the datasets with only x1 and x2. batch_size is passed explicitly
# here, unlike in case (1).
features = ['x1', 'x2']
train_ds = df_to_dataset(train[features], labels=train['y'], batch_size=1000)
valid_ds = df_to_dataset(valid[features], labels=valid['y'], batch_size=1000)
test_ds = df_to_dataset(test[features], labels=test['y'], batch_size=1000)

rf.fit(train_ds)
# R^2 on the validation split.
print(r2_score(valid['y'], rf.predict(valid_ds)))

# Saved under version directory "1" of test_model_rf1 for TF Serving.
rf.save('/home/username/model_assessment/test_model_rf1/1/')

""" Run model in TFServing docker run -p 8501:8501 --mount type=bind,source=/home/username/model_assessment/test_model_rf1,target=/models/test_model_rf1 -e MODEL_NAME=test_model_rf1 -t tensorflow/serving """

# Case (2): two predict requests against the numeric-only forest served as
# `test_model_rf1`. `headers` is reused from the earlier requests.
case2_url = 'http://localhost:8501/v1/models/test_model_rf1:predict'

case2_payloads = (
    # Without brackets, floats are fine
    {
        "signature_name": "serving_default",
        "instances": [{"x1": 300.0, "x2": 220.0}],
    },
    # With brackets, floats cause tf serving to fail
    {
        "signature_name": "serving_default",
        "instances": [{"x1": [300.0], "x2": [220.0]}],
    },
)

# Send the payloads in order and print each raw response body.
for test_data in case2_payloads:
    response = requests.post(case2_url, data=json.dumps(test_data), headers=headers)
    print(response.text)

"""
Causes 
2023-03-21 10:31:24.921448: F external/org_tensorflow/tensorflow/core/framework/tensor_shape.cc:45] Check failed: NDIMS == dims() (2 vs. 3)Asking for tensor of 2 dimensions from a tensor of 3 dimensions
/usr/bin/tf_serving_entrypoint.sh: line 3:     7 Aborted              (core dumped) tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=${MODEL_NAME} --model_base_path=${MODEL_BASE_PATH}/${MODEL_NAME} "$@"
"""

"""(3) Now make a Random Forest with feature transformation on cat feature"""
# Case (3): TF-DF random forest with a Keras preprocessing model in front.
# The inputs and feature columns are rebuilt exactly as in case (1).
feature_layer_input = {}
feature_columns = []
for feature in ['x1', 'x2']:
    # Numeric features: one value per example (Input shape (1,)).
    feature_columns.append(feature_column.numeric_column(feature))
    feature_layer_input[feature] = Input(shape=(1,), name=feature)
for feature in ['x3']:
    # Categorical feature: hashed into 10 buckets and exposed as an indicator
    # column.
    feature_column_emb = feature_column.categorical_column_with_hash_bucket(
        feature, hash_bucket_size=10
    )
    feature_columns.append(feature_column.indicator_column(feature_column_emb))
    feature_layer_input[feature] = Input(shape=(1,), dtype=string, name=feature)

# The preprocessing model stops at the DenseFeatures output (no Dense layers);
# it only converts the raw feature dict into a dense tensor for the forest.
feature_layer = layers.DenseFeatures(feature_columns)(feature_layer_input)
feat_x3_model = Model(feature_layer_input, feature_layer)

# Same hyperparameters as case (2), plus the preprocessing model above.
rf = tfdf.keras.RandomForestModel(
    task=Task.REGRESSION,
    preprocessing=feat_x3_model,
    num_trees=100,
    max_depth=20,
    min_examples=5,
    num_candidate_attributes_ratio=0.45,
    split_axis="AXIS_ALIGNED", #"SPARSE_OBLIQUE",
    growing_strategy="LOCAL", #"BEST_FIRST_GLOBAL",
    )

# All three features are fed in; x3 is still a raw string at this point.
features = ['x1', 'x2', 'x3']
train_ds = df_to_dataset(train[features], labels=train['y'], batch_size=1000)
valid_ds = df_to_dataset(valid[features], labels=valid['y'], batch_size=1000)
test_ds = df_to_dataset(test[features], labels=test['y'], batch_size=1000)

rf.fit(train_ds)
# R^2 on the validation split.
print(r2_score(valid['y'], rf.predict(valid_ds)))
# Saved under version directory "1" of test_model_rf2 for TF Serving.
rf.save('/home/username/model_assessment/test_model_rf2/1/')

""" Run in TF Serving docker run -p 8501:8501 --mount type=bind,source=/home/username/model_assessment/test_model_rf2,target=/models/test_model_rf2 -e MODEL_NAME=test_model_rf2 -t tensorflow/serving """

# Case (3): four predict requests against the forest with a preprocessing
# model, served as `test_model_rf2`. The reporter marks every variant as fine.
headers = {"content-type": "application/json"}
case3_url = 'http://localhost:8501/v1/models/test_model_rf2:predict'

case3_payloads = (
    # Without brackets. Fine
    {
        "signature_name": "serving_default",
        "instances": [
            {"x1": 300.0, "x2": 220.0, "x3": "cat11"},
            {"x1": 400.0, "x2": 320.0, "x3": "cat11"},
        ],
    },
    # With brackets. Fine
    {
        "signature_name": "serving_default",
        "instances": [{"x1": [300.0], "x2": [220.0], "x3": ["cat11"]}],
    },
    # Brackets on floats, none on cat. Fine
    {
        "signature_name": "serving_default",
        "instances": [{"x1": [[300.0]], "x2": [[[[220.0]]]], "x3": "cat11"}],
    },
    # Brackets on cat, none on floats. Fine
    {
        "signature_name": "serving_default",
        "instances": [{"x1": 300.0, "x2": 220.0, "x3": ["cat11"]}],
    },
)

# Send the payloads in order and print each raw response body.
for test_data in case3_payloads:
    response = requests.post(case3_url, data=json.dumps(test_data), headers=headers)
    print(response.text)

#(4) Build a model with no preprocessing model, categorical dummifying handled by the model
# Same hyperparameters as cases (2) and (3), without `preprocessing`, and with
# the two categorical_set_split_* options set.
# NOTE(review): by their names these options target categorical-set features;
# whether they influence the plain string column x3 is not shown here — confirm.
rf = tfdf.keras.RandomForestModel(
    task=Task.REGRESSION,
    num_trees=100,
    max_depth=20,
    min_examples=5,
    num_candidate_attributes_ratio=0.45,
    split_axis="AXIS_ALIGNED", #"SPARSE_OBLIQUE",
    growing_strategy="LOCAL", #"BEST_FIRST_GLOBAL",
    categorical_set_split_max_num_items=100,
    categorical_set_split_min_item_frequency=1,
    )

# The raw string column x3 goes straight to the forest alongside x1 and x2.
features = ['x1', 'x2', 'x3']
train_ds = df_to_dataset(train[features], labels=train['y'], batch_size=1000)
valid_ds = df_to_dataset(valid[features], labels=valid['y'], batch_size=1000)
test_ds = df_to_dataset(test[features], labels=test['y'], batch_size=1000)

rf.fit(train_ds)
# R^2 on the validation split.
print(r2_score(valid['y'], rf.predict(valid_ds)))

# Note the doubled "username" directory: this path differs from the earlier
# save paths but matches the docker mount used for test_model_rf3.
rf.save('/home/username/username/model_assessment/test_model_rf3/1/')

""" Run in TF Serving docker run -p 8501:8501 --mount type=bind,source=/home/username/username/model_assessment/test_model_rf3,target=/models/test_model_rf3 -e MODEL_NAME=test_model_rf3 -t tensorflow/serving """

# Case (4): two predict requests against the forest without preprocessing,
# served as `test_model_rf3`. `headers` is reused from the earlier requests.
case4_url = 'http://localhost:8501/v1/models/test_model_rf3:predict'

case4_payloads = (
    # No brackets works
    {
        "signature_name": "serving_default",
        "instances": [
            {"x1": 300.0, "x2": 220.0, "x3": "cat11"},
            {"x1": 330.0, "x2": 230.0, "x3": "cat13"},
        ],
    },
    # Brackets on any (float or categorical) causes TF serving to fail
    {
        "signature_name": "serving_default",
        "instances": [{"x1": [300.0], "x2": [220.0], "x3": ["cat11"]}],
    },
)

# Send the payloads in order and print each raw response body.
for test_data in case4_payloads:
    response = requests.post(case4_url, data=json.dumps(test_data), headers=headers)
    print(response.text)

"""
Causes
2023-03-21 10:39:01.342750: F external/org_tensorflow/tensorflow/core/framework/tensor_shape.cc:45] Check failed: NDIMS == dims() (2 vs. 3)Asking for tensor of 2 dimensions from a tensor of 3 dimensions
/usr/bin/tf_serving_entrypoint.sh: line 3:     7 Aborted                 (core dumped) tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=${MODEL_NAME} --model_base_path=${MODEL_BASE_PATH}/${MODEL_NAME} "$@"
"""
rstz commented 1 year ago

Hi, thank you for the detailed report, I will look into this and report back!