motefly / DeepGBM

SIGKDD'2019: DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks

Question about split_gain #12

Open Aliang-CN opened 4 years ago

Aliang-CN commented 4 years ago

Hello: While reading your code I noticed a problem. The line `self.gain = getItemByTree(self, 'split_gain')` should retrieve the information gain of each node split, but `getFeature` inside `getItemByTree` has no corresponding handling for it.

    def getItemByTree(tree, item='split_feature'):
        root = tree.raw['tree_structure']
        split_nodes = tree.split_nodes
        res = np.zeros(split_nodes+tree.raw['num_leaves'], dtype=np.int32)
        if 'value' in item or 'threshold' in item or 'split_gain' in item:
            res = res.astype(np.float64)
        def getFeature(root, res):
            if 'child' in item:
                if 'split_index' in root:
                    node = root[item]
                    if 'split_index' in node:
                        res[root['split_index']] = node['split_index']
                    else:
                        res[root['split_index']] = node['leaf_index'] + split_nodes # need to check
                else:
                    res[root['leaf_index'] + split_nodes] = -1
            elif 'value' in item:
                if 'split_index' in root:
                    res[root['split_index']] = root['internal_'+item]
                else:
                    res[root['leaf_index'] + split_nodes] = root['leaf_'+item]
            else:
                if 'split_index' in root:
                    res[root['split_index']] = root[item]
                else:
                    res[root['leaf_index'] + split_nodes] = -2
            if 'left_child' in root:
                getFeature(root['left_child'], res)
            if 'right_child' in root:
                getFeature(root['right_child'], res)
        getFeature(root, res)
        return res

Aliang-CN commented 4 years ago

Hello, could you update the repository with the latest model?

Aliang-CN commented 4 years ago

Hello! I have another question. In the `SubGBDTLeaf_cls` function, what information does `all_hav` mainly record in the code below? Also, `treeI[tree].gain[kdx]` should actually be recording `split_feature`; is my understanding correct?

    all_hav = {} # set([i for i in range(MAX)])
    for jdx, tree in enumerate(tree_indices):
        for kdx, f in enumerate(treeI[tree].feature):
            if f == -2:
                continue
            if f not in all_hav:
                all_hav[f] = 0
            all_hav[f] += treeI[tree].gain[kdx]
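Read literally, the quoted loop builds a dictionary from each split feature index to the sum of the matching `gain` entries across the selected trees, skipping leaf slots marked `-2`. The sketch below reproduces that accumulation on made-up data. The arrays `tree_features` and `tree_gains` are hypothetical stand-ins for `treeI[tree].feature` and `treeI[tree].gain`, and the example does not settle what the author intended `all_hav` to represent.

```python
import numpy as np

# Hypothetical per-tree arrays (split nodes first, then leaf slots marked -2).
tree_features = [np.array([3, 0, -2, -2, -2]), np.array([3, -2, -2])]
tree_gains = [np.array([12.5, 4.25, -2.0, -2.0, -2.0]), np.array([1.5, -2.0, -2.0])]

all_hav = {}
for features, gains in zip(tree_features, tree_gains):
    for kdx, f in enumerate(features):
        if f == -2:          # leaf slot, no split feature to credit
            continue
        key = int(f)
        # Add this split's gain to the running total for its feature.
        all_hav[key] = all_hav.get(key, 0.0) + float(gains[kdx])

print(all_hav)
```

Under these assumptions feature 3 collects the gain of two splits and feature 0 the gain of one, so the dictionary behaves like a per-feature total of split gain.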

Aliang-CN commented 4 years ago

Hello! I have one more question. The line `vectors[idx] = set(features[np.where(features>0)])` keeps only the features with `split_feature > 0`, so the feature with `split_feature = 0` gets left out.

    def EqualGroup(self, n_clusters, args):
        vectors = {}
        # n_feature = 256
        for idx,features in enumerate(self.featurelist):
            vectors[idx] = set(features[np.where(features>0)])
        keys = random.sample(vectors.keys(), len(vectors))
        clusterIdx = np.zeros(len(vectors))
        groups = [[] for i in range(n_clusters)]
        trees_per_cluster = len(vectors)//n_clusters
        mod_per_cluster = len(vectors) % n_clusters
        begin = 0
        for idx in range(n_clusters):
            for jdx in range(trees_per_cluster):
                clusterIdx[keys[begin]] = idx
                begin += 1
            if idx < mod_per_cluster:
                clusterIdx[keys[begin]] = idx
                begin += 1
        print([np.where(clusterIdx==i)[0].shape for i in range(n_clusters)])
        return clusterIdx
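The effect of the strict comparison can be checked on a small array. In the sketch below, `features` is a hypothetical per-tree feature array in which `-2` marks leaf slots, and `>= 0` is shown only as one possible alternative, not as a confirmed fix.

```python
import numpy as np

# Hypothetical feature array for one tree: index 0 is a real feature, -2 marks leaves.
features = np.array([0, 3, -2, 7, 0, -2])

strict = set(features[np.where(features > 0)].tolist())      # filter as quoted
inclusive = set(features[np.where(features >= 0)].tolist())  # keeps feature 0 too

print(strict, inclusive)
```

With `> 0`, feature index 0 is dropped along with the `-2` markers, while `>= 0` removes only the markers. In the quoted `EqualGroup`, only the keys and the length of `vectors` are read afterwards, so the filter does not appear to change the resulting grouping. Whether excluding feature 0 is intentional is not answered in this thread.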

motefly commented 4 years ago

Hello, I misjudged this earlier. The gain can in fact be retrieved here: https://github.com/motefly/DeepGBM/blob/master/tree_model_interpreter.py#L36
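One way to see why this works: the string `'split_gain'` contains neither `'child'` nor `'value'`, so it falls through to the final generic branch of `getFeature`, which copies `root[item]` for every split node. The sketch below is a simplified stand-in for that branch only, not the repository's function. `toy_tree` and `collect_item` are hypothetical names, and the nested-dict layout is an assumption modeled on the keys used in the quoted code.

```python
import numpy as np

# Hypothetical tree with two split nodes and three leaves.
toy_tree = {
    "split_index": 0, "split_feature": 3, "split_gain": 12.5,
    "left_child": {"leaf_index": 0, "leaf_value": 0.1},
    "right_child": {
        "split_index": 1, "split_feature": 0, "split_gain": 4.25,
        "left_child": {"leaf_index": 1, "leaf_value": -0.2},
        "right_child": {"leaf_index": 2, "leaf_value": 0.3},
    },
}

def collect_item(root, item, split_nodes, num_leaves):
    """Simplified stand-in for the generic (final else) branch of getFeature."""
    res = np.zeros(split_nodes + num_leaves, dtype=np.float64)

    def visit(node):
        if "split_index" in node:
            res[node["split_index"]] = node[item]        # split node: copy the field
        else:
            res[node["leaf_index"] + split_nodes] = -2   # leaf slot marker
        for side in ("left_child", "right_child"):
            if side in node:
                visit(node[side])

    visit(root)
    return res

item = "split_gain"
# Neither special-case substring matches, so the generic branch handles it.
print("child" in item, "value" in item)
gain = collect_item(toy_tree, item, split_nodes=2, num_leaves=3)
print(gain)
```

Here the first two slots hold the gains of the two split nodes and the remaining three slots hold the `-2` leaf marker, which matches the `f == -2` check used elsewhere in the thread.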