mickeysjm / TaxoExpan

The source code for TaxoExpan, a self-supervised taxonomy expansion method published at WWW 2020
Apache License 2.0

Cannot generate negative samples #11

Open · dorbodwolf opened this issue 3 months ago

dorbodwolf commented 3 months ago

```
Alert in _get_exactly_k_negatives, query_node: 2555, current negative size: 0
Alert in _get_exactly_k_negatives, query_node: 4086, current negative size: 0
Alert in _get_exactly_k_negatives, query_node: 9162, current negative size: 0
Alert in _get_exactly_k_negatives, query_node: 9305, current negative size: 0
Alert in _get_exactly_k_negatives, query_node: 1819, current negative size: 0
Alert in _get_exactly_k_negatives, query_node: 11334, current negative size: 0
Alert in _get_exactly_k_negatives, query_node: 1401, current negative size: 0
Alert in _get_exactly_k_negatives, query_node: 383, current negative size: 0
Alert in _get_exactly_k_negatives, query_node: 2228, current negative size: 0
Alert in _get_exactly_k_negatives, query_node: 12267, current negative size: 0
Alert in _get_exactly_k_negatives, query_node: 14564, current negative size: 0
Alert in _get_exactly_k_negatives, query_node: 720, current negative size: 0
Alert in _get_exactly_k_negatives, query_node: 12083, current n
```
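Every alert line has the same fixed format, so when the log is long it can help to tally which query nodes hit the fallback and with what size. The sketch below is a hypothetical helper (`parse_alerts` is not part of TaxoExpan); it assumes only the message format printed by `_get_exactly_k_negatives` and skips incomplete entries such as a truncated final line.

```python
import re

# Message format printed by _get_exactly_k_negatives (taken from the log above).
ALERT_RE = re.compile(
    r"Alert in _get_exactly_k_negatives, query_node: (\d+), current negative size: (\d+)"
)


def parse_alerts(log_text):
    """Return {query_node: negative_size} for every complete alert in log_text.

    Works whether the alerts are on separate lines or run together on one line;
    a truncated trailing alert does not match the pattern and is ignored.
    """
    alerts = {}
    for match in ALERT_RE.finditer(log_text):
        query_node, size = map(int, match.groups())
        alerts[query_node] = size
    return alerts
```

For example, `[q for q, n in parse_alerts(log).items() if n == 0]` lists the query nodes for which no negative at all was found.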

dorbodwolf commented 3 months ago
```python
def _get_exactly_k_negatives(self, query_node, negative_size):
    """ Generate EXACTLY negative_size samples for the query node
    """
    if self.pointer == 0:
        random.shuffle(self.queue)
    masks = self.node2masks[query_node]  # nodes that must be excluded from query_node's negatives
    negatives = []
    max_try = 0
    while len(negatives) != negative_size:
        n_lack = negative_size - len(negatives)
        # NOTE: `sels` is not defined in this snippet; because `and` short-circuits,
        # it is only evaluated for candidates that already pass `ele not in masks`
        negatives.extend([ele for ele in self.queue[self.pointer: self.pointer+n_lack] if ele not in masks and ele not in sels])
        self.pointer += n_lack
        if self.pointer >= len(self.queue):  # wrapped around: restart and reshuffle the queue
            self.pointer = 0
            random.shuffle(self.queue)
        max_try += 1
        if max_try > 10:  # corner cases, trim/expand negatives to the size
            print(f"Alert in _get_exactly_k_negatives, query_node: {query_node}, current negative size: {len(negatives)}")
            if len(negatives) > negative_size:
                nega
```
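To see when this loop can only end in the alert with size 0, the same queue-based sampling can be pulled out into a standalone function. This is a sketch under stated assumptions, not the repository's code: `queue` stands in for `self.queue`, `masks` for `self.node2masks[query_node]`, the `sels` filter is left out because its definition is not shown, and both function names are hypothetical. If every id in the queue is masked, or the queue is empty, each of the 11 windows yields nothing and the result is empty.

```python
import random


def get_exactly_k_negatives(queue, masks, negative_size, max_tries=10, seed=0):
    """Standalone sketch of the sampling loop quoted above (hypothetical name).

    queue: candidate node ids (stands in for self.queue)
    masks: ids that must not be returned (stands in for self.node2masks[query_node])
    """
    rng = random.Random(seed)
    queue = list(queue)
    pointer = 0
    rng.shuffle(queue)
    negatives = []
    tries = 0
    while len(negatives) != negative_size:
        n_lack = negative_size - len(negatives)
        window = queue[pointer:pointer + n_lack]
        # Keep only unmasked candidates from the current window.
        negatives.extend(ele for ele in window if ele not in masks)
        pointer += n_lack
        if pointer >= len(queue):  # wrapped around: restart and reshuffle
            pointer = 0
            rng.shuffle(queue)
        tries += 1
        if tries > max_tries:
            # The quoted code prints its "Alert" at this point.
            break
    return negatives


def count_valid_candidates(queue, masks):
    """Number of distinct queue ids that are not masked (hypothetical helper)."""
    return len(set(queue) - set(masks))
```

If `count_valid_candidates(self.queue, self.node2masks[query_node])` returns 0 for the alerted nodes, the mask covers the entire candidate pool. A nonzero count would instead mean valid candidates exist but none appeared in the windows the loop visited.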