NicolasHug opened 5 years ago
I am not sure I understand. Feel free to submit a PR with the change so I can see the diff directly.
I'm proposing what I have done in https://github.com/scikit-learn/scikit-learn/pull/12807
`grower._compute_spittability()` would then look like this:
```python
if node.hist_subtraction:
    if node is node.parent.right_child:
        sum_gradients = node.parent.split_info.gradient_right
        sum_hessians = node.parent.split_info.hessian_right
    else:
        sum_gradients = node.parent.split_info.gradient_left
        sum_hessians = node.parent.split_info.hessian_left
    split_info = self.splitter.find_node_split_subtraction(
        node.sample_indices,
        sum_gradients, sum_hessians, node.parent.histograms,
        node.sibling.histograms, histograms)
else:
    split_info = self.splitter.find_node_split(
        node.sample_indices, histograms)
```
and we can remove these lines from `find_node_split_subtraction`:
```python
# We can pick any feature (here the first) in the histograms to
# compute the gradients: they must be the same across all features
# anyway, we have tests ensuring this. Maybe a more robust way would
# be to compute an average but it's probably not worth it.
context.sum_gradients = (parent_histograms[0]['sum_gradients'].sum() -
                         sibling_histograms[0]['sum_gradients'].sum())

n_samples = sample_indices.shape[0]
if context.constant_hessian:
    context.sum_hessians = \
        context.constant_hessian_value * float32(n_samples)
else:
    context.sum_hessians = (parent_histograms[0]['sum_hessians'].sum() -
                            sibling_histograms[0]['sum_hessians'].sum())
```
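To make the quoted code concrete, here is a minimal NumPy sketch of histogram subtraction and of why summing the bins of any single feature (here feature 0) recovers the node totals. The helper `build_histograms`, the `(n_features, n_bins)` array layout, and the index-based split are hypothetical simplifications, not pygbm's actual data structures.

```python
import numpy as np


def build_histograms(X_binned, gradients, hessians, n_bins):
    """Per-feature sums of gradients and hessians, each of shape (n_features, n_bins).

    Hypothetical layout: pygbm stores histograms in a structured array instead.
    """
    n_features = X_binned.shape[1]
    sum_gradients = np.stack([
        np.bincount(X_binned[:, f], weights=gradients, minlength=n_bins)
        for f in range(n_features)
    ])
    sum_hessians = np.stack([
        np.bincount(X_binned[:, f], weights=hessians, minlength=n_bins)
        for f in range(n_features)
    ])
    return sum_gradients, sum_hessians


rng = np.random.default_rng(0)
n_samples, n_features, n_bins = 1000, 3, 255
X_binned = rng.integers(0, n_bins, size=(n_samples, n_features))
gradients = rng.normal(size=n_samples)
hessians = rng.uniform(0.5, 1.5, size=n_samples)

# Simplification: pretend the parent's split sends the first 400 samples left.
sibling_idx = np.arange(400)                 # the sibling (left child)
sample_indices = np.arange(400, n_samples)   # the node being split (right child)

parent_g, parent_h = build_histograms(X_binned, gradients, hessians, n_bins)
sibling_g, sibling_h = build_histograms(
    X_binned[sibling_idx], gradients[sibling_idx], hessians[sibling_idx], n_bins)

# Histogram subtraction: the node's histograms without a pass over its samples.
node_g = parent_g - sibling_g
node_h = parent_h - sibling_h

# Every feature's histogram covers the same samples, so any row sums to the
# node's total; the quoted snippet uses feature 0.
sum_gradients = parent_g[0].sum() - sibling_g[0].sum()
sum_hessians = parent_h[0].sum() - sibling_h[0].sum()
print(np.allclose(node_g.sum(axis=1), sum_gradients))
print(np.isclose(sum_gradients, gradients[sample_indices].sum()))
```

Each total here is a float difference of two sums over up to 255 bins, so it only matches the true node total up to rounding.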
Ok +1.
Actually, as we just discussed, we don't need to recompute `sum_gradients` and `sum_hessians` at all, regardless of which histogram computation method is used.
Instead of summing over all the bins of an arbitrary feature histogram to compute `context.sum_gradients` and `context.sum_hessians`, we can pass those values directly to `find_node_split_subtraction`, since the parent's `split_info` already contains `gradient_left/right` and `hessian_left/right`. We're summing over at most 255 bins, so the performance gain is minimal, but it would be clearer and more accurate.
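A minimal NumPy sketch of this proposal, under stated assumptions: `SplitInfo` is a hypothetical stand-in holding only the four fields mentioned above, `child_sums` is a hypothetical helper mirroring the `if`/`else` in the proposed `_compute_spittability()` snippet, and the split finder is assumed to accumulate the left-side sums over the bins up to the threshold and derive the right-side sums from the parent totals. The sketch checks that the values read from `split_info` match each child's true totals.

```python
from typing import NamedTuple

import numpy as np


class SplitInfo(NamedTuple):
    """Hypothetical stand-in for the parent's split_info (only the fields used here)."""
    gradient_left: float
    gradient_right: float
    hessian_left: float
    hessian_right: float


def child_sums(parent_split_info, is_right_child):
    """Read a child's gradient/hessian totals straight from the parent's split_info."""
    if is_right_child:
        return parent_split_info.gradient_right, parent_split_info.hessian_right
    return parent_split_info.gradient_left, parent_split_info.hessian_left


rng = np.random.default_rng(0)
n_samples, n_bins, threshold = 1000, 255, 100
binned_feature = rng.integers(0, n_bins, size=n_samples)
gradients = rng.normal(size=n_samples)
hessians = rng.uniform(0.5, 1.5, size=n_samples)

# Parent histograms on the feature the parent was split on.
grad_hist = np.bincount(binned_feature, weights=gradients, minlength=n_bins)
hess_hist = np.bincount(binned_feature, weights=hessians, minlength=n_bins)

# Assumed split-finder behaviour: accumulate left-side sums over bins <= threshold
# and derive the right-side sums from the parent totals.
gradient_left = grad_hist[: threshold + 1].sum()
hessian_left = hess_hist[: threshold + 1].sum()
split_info = SplitInfo(
    gradient_left=gradient_left,
    gradient_right=grad_hist.sum() - gradient_left,
    hessian_left=hessian_left,
    hessian_right=hess_hist.sum() - hessian_left,
)

# The right child gets its totals without touching any histogram.
sum_gradients, sum_hessians = child_sums(split_info, is_right_child=True)
goes_right = binned_feature > threshold
print(np.isclose(sum_gradients, gradients[goes_right].sum()),
      np.isclose(sum_hessians, hessians[goes_right].sum()))
```

The lookup is constant time and reuses numbers the split finder has already computed, which is the clarity argument above; the cost it saves is one pass over a single feature's bins.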