tensorflow / decision-forests

A collection of state-of-the-art algorithms for the training, serving and interpretation of Decision Forest models in Keras.
Apache License 2.0

tfdf.model_plotter.plot_model() is broken for GradientBoostedTreesModel and CartModel #129

Closed: fhossfel closed this issue 2 years ago

fhossfel commented 2 years ago

I am using tfdf 0.2.4 and can successfully train a model and plot it using the plot_model() function.

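import tensorflow_decision_forests as tfdf

# train_ds and test_ds are tf.data.Dataset objects built elsewhere.
model = tfdf.keras.RandomForestModel()
model.fit(train_ds)
model.compile(metrics=["accuracy"])
evaluation = model.evaluate(test_ds)
# Write an interactive HTML plot of one tree (tree_idx=0) to disk.
with open("model.html", "w") as html_file:
    html_file.write(tfdf.model_plotter.plot_model(model, tree_idx=0, max_depth=10))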

For my current task I get a decision tree graph consisting of two decision nodes and three outputs (leaves). The key line in the generated HTML file seems to be this one:

display_tree({"margin": 10, "node_x_size": 160, "node_y_size": 28, "node_x_offset": 180, "node_y_offset": 33, "font_size": 10, "edge_rounding": 20, "node_padding": 2, "show_plot_bounding_box": false}, {"value": {"type": "PROBABILITY", "distribution": [0.006622516556291391, 0.695364238410596, 0.2781456953642384, 0.019867549668874173], "num_examples": 151.0}, "condition": {"type": "CATEGORICAL_IS_IN", "attribute": "product_group", "mask": ["DIY"]}, "children": [{"value": {"type": "PROBABILITY", "distribution": [0.0, 0.0, 1.0, 0.0], "num_examples": 42.0}}, {"value": {"type": "PROBABILITY", "distribution": [0.009174311926605505, 0.963302752293578, 0.0, 0.027522935779816515], "num_examples": 109.0}, "condition": {"type": "NUMERICAL_IS_HIGHER_THAN", "attribute": "height", "threshold": 42.0}, "children": [{"value": {"type": "PROBABILITY", "distribution": [0.0, 1.0, 0.0, 0.0], "num_examples": 103.0}}, {"value": {"type": "PROBABILITY", "distribution": [0.16666666666666666, 0.3333333333333333, 0.0, 0.5], "num_examples": 6.0}}]}]}, "#tree_plot_24de9183c1d54e6b8c963d372b714bc0")
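To inspect these trees without a browser, the JSON passed to display_tree can be pulled out of that line with Python's standard json module. This is a minimal sketch, assuming the call always has the display_tree(options, tree, "#element_id") shape seen here; parse_display_tree is a hypothetical helper, not part of TF-DF.

```python
import json


def parse_display_tree(line: str) -> tuple[dict, dict]:
    """Split a display_tree(options, tree, "#element_id") call into two dicts.

    Hypothetical helper: assumes the call has the shape shown in the issue.
    Raises ValueError (or json.JSONDecodeError, a subclass) otherwise.
    """
    decoder = json.JSONDecoder()
    pos = line.index("display_tree(") + len("display_tree(")
    # First argument: the plot options object.
    options, pos = decoder.raw_decode(line, pos)
    # Move past the comma and any spaces separating the two JSON arguments.
    pos = line.index(",", pos) + 1
    while line[pos].isspace():
        pos += 1
    # Second argument: the tree itself (nested "condition"/"children"/"value").
    tree, _ = decoder.raw_decode(line, pos)
    return options, tree
```

Passing the display_tree line from model.html returns the options dict and the tree dict, whose nodes carry "value", "condition" and "children" keys as in the output above.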

If I use exactly the same code but replace the RandomForestModel with a GradientBoostedTreesModel, I only get one decision node and two outputs:

display_tree({"margin": 10, "node_x_size": 160, "node_y_size": 28, "node_x_offset": 180, "node_y_offset": 33, "font_size": 10, "edge_rounding": 20, "node_padding": 2, "show_plot_bounding_box": false}, {"value": {"type": "REGRESSION", "value": -0.09703703969717026, "num_examples": 135.0, "standard_deviation": 0.08574694002066838}, "condition": {"type": "NUMERICAL_IS_HIGHER_THAN", "attribute": "length", "threshold": 227.0}, "children": [{"value": {"type": "REGRESSION", "value": -0.020000001415610313, "num_examples": 5.0, "standard_deviation": 0.4}}, {"value": {"type": "REGRESSION", "value": -0.10000000149011612, "num_examples": 130.0, "standard_deviation": 0.0}}]}, "#tree_plot_73421ac8ea9a47a88761b7441afab47c")
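The two plots can be compared by counting nodes in the tree JSON passed to display_tree (decoded into a dict). A minimal sketch, assuming a node is a decision node when it has a "condition" key and children, and a leaf otherwise; tree_stats is a hypothetical helper. Depth is counted in edges from the root, which appears to match the summary's "Depth by leafs" section (its Min: 1 is what a 3-node tree gives).

```python
def tree_stats(node: dict, depth: int = 0) -> tuple[int, int, int]:
    """Return (num_decision_nodes, num_leaves, max_leaf_depth) for a plotter tree.

    Hypothetical helper. Depth counts edges from the root, so a lone root
    leaf has depth 0.
    """
    children = node.get("children", [])
    if "condition" not in node or not children:
        # Leaf: contributes one leaf at the current depth.
        return 0, 1, depth
    decisions, leaves, max_depth = 1, 0, depth
    for child in children:
        c, l, d = tree_stats(child, depth + 1)
        decisions += c
        leaves += l
        max_depth = max(max_depth, d)
    return decisions, leaves, max_depth
```

Applied to the two trees above, this gives (2, 3, 2) for the RandomForestModel tree and (1, 2, 1) for tree 0 of the GradientBoostedTreesModel. A 3-node tree is within the range the model summary reports (Min: 3 nodes per tree), so the plot itself is not necessarily wrong.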

This can't be right, since the predictions of the GradientBoostedTreesModel are perfect (100% correct, thanks!), and that requires taking more features into account than the length of the classified object. Additionally

The model summary is below (I have replaced some sensitive feature names). I am not really an expert, but if I read the summary correctly, then the decision tree should have a depth of 5 and 26 to 27 nodes. On the other hand, I would have expected more nodes to show for the RandomForestModel, too. ¯\_(ツ)_/¯

If there is any additional information I can provide please let me know.

Model: "gradient_boosted_trees_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
=================================================================
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "GRADIENT_BOOSTED_TREES"
Task: CLASSIFICATION
Label: "__LABEL"

Input Features (11):
    parcel_count
    ft_ot_text
    girth
    height
    length
    product_group
    tipping_risk
    shipping_mode
    volume
    weight
    width

No weights

Variable Importance: MEAN_MIN_DEPTH:
    1.             "parcel_count"  3.890381 ################
    2.            "__LABEL"  3.890381 ################
    3. "shipping_mode"  3.889688 ###############
    4.           "girth"  3.541476 #############
    5.         "ft_ot_text"  3.516287 #############
    6.              "width"  3.287958 ###########
    7.             "volume"  3.184331 ##########
    8.             "length"  3.039927 #########
    9.             "height"  2.885094 ########
   10.       "tipping_risk"  2.538273 ######
   11.             "weight"  2.267620 ####
   12.           "product_group"  1.719362 

Variable Importance: NUM_AS_ROOT:
    1.     "product_group" 616.000000 ################
    2.       "height" 183.000000 ####
    3.       "weight" 172.000000 ####
    4.       "length" 117.000000 ##
    5.        "width" 47.000000 
    6.       "volume" 41.000000 
    7. "tipping_risk" 24.000000 

Variable Importance: NUM_NODES:
    1.             "weight" 2592.000000 ################
    2.       "tipping_risk" 2367.000000 ##############
    3.             "volume" 1318.000000 ########
    4.             "height" 1271.000000 #######
    5.           "product_group" 1195.000000 #######
    6.              "width" 1062.000000 ######
    7.           "girth" 968.000000 #####
    8.         "ft_ot_text" 730.000000 ####
    9.             "length" 689.000000 ####
   10. "shipping_mode"  5.000000 

Variable Importance: SUM_SCORE:
    1.           "product_group" 212.827222 ################
    2.             "height" 17.159601 #
    3.             "weight"  3.552953 
    4.       "tipping_risk"  2.266512 
    5.             "length"  1.447021 
    6.             "volume"  0.999544 
    7.           "girth"  0.891605 
    8.              "width"  0.525099 
    9.         "ft_ot_text"  0.106717 
   10. "shipping_mode"  0.000000 

Loss: MULTINOMIAL_LOG_LIKELIHOOD
Validation loss value: 2.87221e-06
Number of trees per iteration: 4
Node format: NOT_SET
Number of trees: 1200
Total number of nodes: 25594

Number of nodes by tree:
Count: 1200 Average: 21.3283 StdDev: 3.18991
Min: 3 Max: 27 Ignored: 0
----------------------------------------------
[  3,  4)   2   0.17%   0.17%
[  4,  5)   0   0.00%   0.17%
[  5,  6)   2   0.17%   0.33%
[  6,  8)   0   0.00%   0.33%
[  8,  9)   0   0.00%   0.33%
[  9, 10)   0   0.00%   0.33%
[ 10, 11)   0   0.00%   0.33%
[ 11, 13)   8   0.67%   1.00%
[ 13, 14)  12   1.00%   2.00%
[ 14, 15)   0   0.00%   2.00%
[ 15, 16)  21   1.75%   3.75% #
[ 16, 18)  73   6.08%   9.83% ##
[ 18, 19)   0   0.00%   9.83%
[ 19, 20) 262  21.83%  31.67% #######
[ 20, 21)   0   0.00%  31.67%
[ 21, 23) 372  31.00%  62.67% ##########
[ 23, 24) 199  16.58%  79.25% #####
[ 24, 25)   0   0.00%  79.25%
[ 25, 26) 156  13.00%  92.25% ####
[ 26, 27]  93   7.75% 100.00% ###

Depth by leafs:
Count: 13397 Average: 3.9155 StdDev: 1.0663
Min: 1 Max: 5 Ignored: 0
----------------------------------------------
[ 1, 2)  178   1.33%   1.33%
[ 2, 3) 1354  10.11%  11.44% ###
[ 3, 4) 3100  23.14%  34.57% ######
[ 4, 5) 3555  26.54%  61.11% #######
[ 5, 5] 5210  38.89% 100.00% ##########

Number of training obs by leaf:
Count: 13397 Average: 12.0923 StdDev: 18.5167
Min: 5 Max: 130 Ignored: 0
----------------------------------------------
[   5,  11) 11675  87.15%  87.15% ##########
[  11,  17)   419   3.13%  90.27%
[  17,  23)    42   0.31%  90.59%
[  23,  30)    40   0.30%  90.89%
[  30,  36)    63   0.47%  91.36%
[  36,  42)     7   0.05%  91.41%
[  42,  49)     1   0.01%  91.42%
[  49,  55)    40   0.30%  91.71%
[  55,  61)   158   1.18%  92.89%
[  61,  68)   320   2.39%  95.28%
[  68,  74)    53   0.40%  95.68%
[  74,  80)   306   2.28%  97.96%
[  80,  86)   226   1.69%  99.65%
[  86,  93)    27   0.20%  99.85%
[  93,  99)    16   0.12%  99.97%
[  99, 105)     2   0.01%  99.99%
[ 105, 112)     1   0.01%  99.99%
[ 112, 118)     0   0.00%  99.99%
[ 118, 124)     0   0.00%  99.99%
[ 124, 130]     1   0.01% 100.00%

Attribute in nodes:
    2592 : weight [NUMERICAL]
    2367 : tipping_risk [NUMERICAL]
    1318 : volume [NUMERICAL]
    1271 : height [NUMERICAL]
    1195 : product_group [CATEGORICAL]
    1062 : width [NUMERICAL]
    968 : girth [NUMERICAL]
    730 : ft_ot_text [CATEGORICAL]
    689 : length [NUMERICAL]
    5 : shipping_mode [CATEGORICAL]

Attribute in nodes with depth <= 0:
    616 : product_group [CATEGORICAL]
    183 : height [NUMERICAL]
    172 : weight [NUMERICAL]
    117 : length [NUMERICAL]
    47 : width [NUMERICAL]
    41 : volume [NUMERICAL]
    24 : tipping_risk [NUMERICAL]

Attribute in nodes with depth <= 1:
    709 : weight [NUMERICAL]
    627 : product_group [CATEGORICAL]
    468 : height [NUMERICAL]
    457 : tipping_risk [NUMERICAL]
    378 : length [NUMERICAL]
    314 : volume [NUMERICAL]
    218 : width [NUMERICAL]
    156 : girth [NUMERICAL]
    95 : ft_ot_text [CATEGORICAL]

Attribute in nodes with depth <= 2:
    1550 : weight [NUMERICAL]
    1225 : tipping_risk [NUMERICAL]
    767 : volume [NUMERICAL]
    741 : product_group [CATEGORICAL]
    675 : height [NUMERICAL]
    479 : length [NUMERICAL]
    437 : width [NUMERICAL]
    361 : girth [NUMERICAL]
    277 : ft_ot_text [CATEGORICAL]

Attribute in nodes with depth <= 3:
    2342 : weight [NUMERICAL]
    1860 : tipping_risk [NUMERICAL]
    1077 : volume [NUMERICAL]
    937 : product_group [CATEGORICAL]
    927 : height [NUMERICAL]
    778 : girth [NUMERICAL]
    734 : width [NUMERICAL]
    601 : length [NUMERICAL]
    336 : ft_ot_text [CATEGORICAL]

Attribute in nodes with depth <= 5:
    2592 : weight [NUMERICAL]
    2367 : tipping_risk [NUMERICAL]
    1318 : volume [NUMERICAL]
    1271 : height [NUMERICAL]
    1195 : product_group [CATEGORICAL]
    1062 : width [NUMERICAL]
    968 : girth [NUMERICAL]
    730 : ft_ot_text [CATEGORICAL]
    689 : length [NUMERICAL]
    5 : shipping_mode [CATEGORICAL]

Condition type in nodes:
    10267 : HigherCondition
    1930 : ContainsBitmapCondition
Condition type in nodes with depth <= 0:
    616 : ContainsBitmapCondition
    584 : HigherCondition
Condition type in nodes with depth <= 1:
    2700 : HigherCondition
    722 : ContainsBitmapCondition
Condition type in nodes with depth <= 2:
    5494 : HigherCondition
    1018 : ContainsBitmapCondition
Condition type in nodes with depth <= 3:
    8319 : HigherCondition
    1273 : ContainsBitmapCondition
Condition type in nodes with depth <= 5:
    10267 : HigherCondition
    1930 : ContainsBitmapCondition

None
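The "depth of 5 and 26 to 27 nodes" read from this summary are maxima across all trees, not the size of one tree. The sketch below recomputes a few figures copied from the summary above to show how they fit together: 1200 trees in 300 boosting iterations of 4 trees (one per class, matching the four-entry distributions in the plots), about 21.3 nodes per tree on average, and exactly one more leaf than decision node in each binary tree.

```python
# Figures copied from the GradientBoostedTreesModel summary above.
num_trees = 1200            # "Number of trees"
trees_per_iteration = 4     # "Number of trees per iteration" (one per class)
total_nodes = 25594         # "Total number of nodes"
num_leaves = 13397          # "Depth by leafs: Count"
higher_conditions = 10267   # "Condition type in nodes: HigherCondition"
bitmap_conditions = 1930    # "Condition type in nodes: ContainsBitmapCondition"

num_iterations = num_trees // trees_per_iteration       # boosting rounds
decision_nodes = higher_conditions + bitmap_conditions  # non-leaf nodes
avg_nodes_per_tree = total_nodes / num_trees

# Every node is either a decision node or a leaf.
assert decision_nodes + num_leaves == total_nodes
# Each binary tree has exactly one more leaf than it has decision nodes.
assert num_leaves == decision_nodes + num_trees

print(f"{num_iterations} iterations x {trees_per_iteration} trees per iteration")
print(f"{avg_nodes_per_tree:.4f} nodes per tree on average")
# prints:
# 300 iterations x 4 trees per iteration
# 21.3283 nodes per tree on average
```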
fhossfel commented 2 years ago

CartModel has a similar problem of showing only one decision node, but at least the mouseover is working.

display_tree({"margin": 10, "node_x_size": 160, "node_y_size": 28, "node_x_offset": 180, "node_y_offset": 33, "font_size": 10, "edge_rounding": 20, "node_padding": 2, "show_plot_bounding_box": false}, {"value": {"type": "PROBABILITY", "distribution": [0.007407407407407408, 0.7333333333333333, 0.24444444444444444, 0.014814814814814815], "num_examples": 135.0}, "condition": {"type": "CATEGORICAL_IS_IN", "attribute": "product_Group", "mask": ["DIY"]}, "children": [{"value": {"type": "PROBABILITY", "distribution": [0.0, 0.0, 1.0, 0.0], "num_examples": 33.0}}, {"value": {"type": "PROBABILITY", "distribution": [0.00980392156862745, 0.9705882352941176, 0.0, 0.0196078431372549], "num_examples": 102.0}}]}, "#tree_plot_e7010c332612435caae222c9a1230050")
rstz commented 2 years ago

Hi, I'm not sure I correctly understand the problem just yet, but let me summarize what I think is going on.

The GradientBoostedTrees model you're building has Number of trees: 1200, i.e. it consists of 1200 trees. You inspect only the first tree of this collection with tfdf.model_plotter.plot_model(model, tree_idx=0, max_depth=10): tree_idx selects which single tree to plot. This tree alone might not be great, but this is expected; all 1200 trees together give great performance, not a single tree.
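Why a single tree says little on its own can be sketched with a simplified model of multinomial gradient boosting. This sketch assumes trees are stored iteration by iteration (so tree_idx = iteration * 4 + class for the 4 trees per iteration reported in the summary), and that each class score is an initial prediction plus the sum of that class's leaf outputs, turned into probabilities with softmax. The ordering, the initial_predictions parameter and both helper names are assumptions for illustration, not TF-DF's documented internals.

```python
import math
from typing import Sequence


def tree_position(tree_idx: int, trees_per_iteration: int) -> tuple[int, int]:
    """Map tree_idx to (iteration, class index), assuming iteration-major order."""
    return divmod(tree_idx, trees_per_iteration)


def gbt_class_probabilities(leaf_values: Sequence[float],
                            trees_per_iteration: int,
                            initial_predictions: Sequence[float]) -> list[float]:
    """Combine the leaf outputs of every tree into class probabilities.

    leaf_values[i] is the value of the leaf one example reaches in tree i.
    Each tree only adds a small correction to the score of its own class.
    """
    scores = list(initial_predictions)
    for tree_idx, value in enumerate(leaf_values):
        _, class_idx = tree_position(tree_idx, trees_per_iteration)
        scores[class_idx] += value
    # Softmax (shifted by the max score for numerical stability).
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]
```

Under these assumptions, tree_idx=0 is the first-round tree for one class, a single additive correction, which is consistent with its leaves holding small REGRESSION values such as -0.1 rather than class probabilities.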

For CART, there is indeed just a single tree, but for most problems CART models do not perform as well as Random Forests or Gradient Boosted Trees.

fhossfel commented 2 years ago

Ahh, okay. I did not read the manual properly and misinterpreted the tree_idx parameter.

I also noticed that the class distribution bars are missing for the gradient boosted trees. Is that intentional?
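The JSON payloads quoted earlier in this thread suggest where the difference comes from: RandomForestModel and CartModel leaves have "type": "PROBABILITY" with a four-entry "distribution", while GradientBoostedTreesModel leaves have "type": "REGRESSION" with a scalar "value". Presumably the plotter can only draw class bars when a distribution is present. The sketch below checks this on a decoded tree dict; leaf_values and has_class_distributions are hypothetical helpers.

```python
from typing import Iterator


def leaf_values(node: dict) -> Iterator[dict]:
    """Yield the "value" dict of every leaf in a plotter tree, depth-first."""
    children = node.get("children")
    if not children:
        yield node["value"]
        return
    for child in children:
        yield from leaf_values(child)


def has_class_distributions(tree: dict) -> bool:
    """True if every leaf stores a per-class probability distribution."""
    return all("distribution" in value for value in leaf_values(tree))
```

For the CartModel tree above this returns True; for tree 0 of the GradientBoostedTreesModel it returns False, since each boosted tree outputs a score increment rather than a class distribution.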

rstz commented 2 years ago

Can you please clarify what you mean by "missing class distribution bars"?

rstz commented 2 years ago

Closing this as stale