siboehm / lleaves

Compiler for LightGBM gradient-boosted trees, based on LLVM. Speeds up prediction by ≥10x.
https://lleaves.readthedocs.io/en/latest/
MIT License
364 stars 29 forks source link

Profile guided optimization: Branch annotation #2

Closed siboehm closed 3 years ago

siboehm commented 3 years ago

In model.txt, LightGBM stores `leaf_count` and `internal_count`, which record how many training samples reach each leaf / internal node (i.e. how often it is visited when predicting on the full training set). These counts should be converted into `branch_weights` metadata in the generated LLVM IR, giving the compiler profile-guided hints and potentially enabling binary layout optimizations.

siboehm commented 3 years ago

Profiling results

Intel(R) Core(TM) i7-4770 CPU @ 3.40GHz

Baseline (without annotation)

c25cdd134d7efdf5aa1b8339a19e36727958c7d0

../tests/models/NYC_taxi/model.txt 
---- NYC_TAXI (100 samples) ---
lleaves setup: 1.7
lleaves (Batchsize 500000, nthread 4): 38519.86μs

../tests/models/mtpl2/model_small.txt 
---- MTPL2 (100 samples) ---
lleaves setup: 7.31
lleaves (Batchsize 500000, nthread 4): 68963.59μs

../tests/models/airline/model.txt 
---- AIRLINE (100 samples) ---
lleaves setup: 8.62
lleaves (Batchsize 500000, nthread 4): 120513.7μs

With branch-weight annotation

../tests/models/NYC_taxi/model.txt 
---- NYC_TAXI (100 samples) ---
lleaves setup: 2.0
lleaves (Batchsize 500000, nthread 4): 36889.66μs

../tests/models/mtpl2/model_small.txt 
---- MTPL2 (100 samples) ---
lleaves setup: 7.55
lleaves (Batchsize 500000, nthread 4): 70896.09μs

../tests/models/airline/model.txt 
---- AIRLINE (100 samples) ---
lleaves setup: 7.98
lleaves (Batchsize 500000, nthread 4): 117501.83μs

Summary

No significant speedup, so it is not worth the extra complexity. Runtimes changed by only a few percent: NYC_TAXI and AIRLINE got about 3–4% faster, while MTPL2 got about 3% slower. Looking at some model.txt files, the left/right branch counts don't seem very imbalanced, which might explain why the annotations don't help. Treelite also has an option for branch annotation, and when I last tested it, that option didn't yield any speedup either.

Code dump for reference

diff --git a/lleaves/compiler/ast/nodes.py b/lleaves/compiler/ast/nodes.py
index 9f167ec..f8c5265 100644
--- a/lleaves/compiler/ast/nodes.py
+++ b/lleaves/compiler/ast/nodes.py
@@ -48,6 +48,7 @@ class DecisionNode(Node):
         decision_type_id: int,
         left_idx: int,
         right_idx: int,
+        n_visits: int,
     ):
         self.idx = idx
         self.split_feature = split_feature
@@ -55,6 +56,7 @@ class DecisionNode(Node):
         self.decision_type = DecisionType(decision_type_id)
         self.right_idx = right_idx
         self.left_idx = left_idx
+        self.n_visits = n_visits

     def add_children(self, left, right):
         self.left = left
@@ -75,9 +77,10 @@ class DecisionNode(Node):

 class LeafNode(Node):
-    def __init__(self, idx, value):
+    def __init__(self, idx, value, n_visits):
         self.idx = idx
         self.value = value
+        self.n_visits = n_visits

     def __str__(self):
         return f"leaf_{self.idx}"
diff --git a/lleaves/compiler/ast/parser.py b/lleaves/compiler/ast/parser.py
index f1244e9..ebf4126 100644
--- a/lleaves/compiler/ast/parser.py
+++ b/lleaves/compiler/ast/parser.py
@@ -21,14 +21,24 @@ class Feature:
 def _parse_tree_to_ast(tree_struct, features):
     n_nodes = len(tree_struct["decision_type"])
     leaves = [
-        LeafNode(idx, value) for idx, value in enumerate(tree_struct["leaf_value"])
+        LeafNode(idx, value, n_visits)
+        for idx, (value, n_visits) in enumerate(
+            zip(tree_struct["leaf_value"], tree_struct["leaf_count"])
+        )
     ]
+    assert len(leaves) == tree_struct["num_leaves"]

     # Create the nodes using all non-specific data
     # categorical nodes are finalized later
     nodes = [
         DecisionNode(
-            idx, split_feature, threshold, decision_type_id, left_idx, right_idx
+            idx,
+            split_feature,
+            threshold,
+            decision_type_id,
+            left_idx,
+            right_idx,
+            n_visits,
         )
         for idx, (
             split_feature,
@@ -36,6 +46,7 @@ def _parse_tree_to_ast(tree_struct, features):
             decision_type_id,
             left_idx,
             right_idx,
+            n_visits,
         ) in enumerate(
             zip(
                 tree_struct["split_feature"],
@@ -43,6 +54,7 @@ def _parse_tree_to_ast(tree_struct, features):
                 tree_struct["decision_type"],
                 tree_struct["left_child"],
                 tree_struct["right_child"],
+                tree_struct["internal_count"],
             )
         )
     ]
diff --git a/lleaves/compiler/ast/scanner.py b/lleaves/compiler/ast/scanner.py
index 4b20597..47082a6 100644
--- a/lleaves/compiler/ast/scanner.py
+++ b/lleaves/compiler/ast/scanner.py
@@ -85,8 +85,10 @@ TREE_SCAN_KEYS = {
     "left_child": ScannedValue(int, True),
     "right_child": ScannedValue(int, True),
     "leaf_value": ScannedValue(float, True),
+    "leaf_count": ScannedValue(int, True, True),
     "cat_threshold": ScannedValue(int, True, True),
     "cat_boundaries": ScannedValue(int, True, True),
+    "internal_count": ScannedValue(int, True, True),
 }

diff --git a/lleaves/compiler/codegen/codegen.py b/lleaves/compiler/codegen/codegen.py
index 7a5e1ba..313a759 100644
--- a/lleaves/compiler/codegen/codegen.py
+++ b/lleaves/compiler/codegen/codegen.py
@@ -141,7 +141,8 @@ def _gen_decision_node(func, node_block, node):
         ret = builder.select(comp, dconst(node.left.value), dconst(node.right.value))
         builder.ret(ret)
     else:
-        builder.cbranch(comp, left_block, right_block)
+        branch = builder.cbranch(comp, left_block, right_block)
+        branch.set_weights([node.left.n_visits, node.right.n_visits])

     # populate generated child blocks
     if left_block:
@@ -299,7 +300,8 @@ def _populate_categorical_node_block(
         val,
         iconst(32 * len(node.cat_threshold)),
     )
-    builder.cbranch(comp, bitset_comp_block, right_block)
+    branch = builder.cbranch(comp, bitset_comp_block, right_block)
+    branch.set_weights([9, 1])

     idx = bitset_comp_builder.udiv(val, iconst(32))
     bit_vecs = ir.Constant(