xmc-aalto / cascadexml

Code for our paper "CascadeXML: Rethinking Transformers for End-to-end Multi-resolution Training in Extreme Multi-label Classification"

Settings for Amazon-670k #1

Open fafeeeeee opened 1 year ago

fafeeeeee commented 1 year ago

Hello, nice job! But I could only find the training command for Wiki10-31K. Could you please provide the settings for the other datasets, such as Amazon-670k?

caseware66 commented 1 year ago

Hello, have you managed to run Wiki10-31K?

fafeeeeee commented 1 year ago

There are some problems in the code, so I couldn't run Wiki10-31K directly. Here are the modifications I made to get it running (though I still can't reproduce the paper's Wiki10-31K results):

diff --git a/src/CascadeXML.py b/src/CascadeXML.py
index 9109fc6..c71361b 100755
--- a/src/CascadeXML.py
+++ b/src/CascadeXML.py
@@ -64,6 +64,11 @@ class CascadeXML(nn.Module):
             clusters[-1][i] = np.pad(clusters[-1][i], (0, max_cluster-len(clusters[-1][i])), 
                                         constant_values=self.num_ele[-1]).astype(np.int32)

+        for c in clusters[:-1]:
+            max_cluster = max([len(c_) for c_ in c])
+            for i in range(len(c)):
+                c[i] = np.pad(c[i], (0, max_cluster-len(c[i])), constant_values=c[i][0]).astype(np.int32)
+
         clusters = [np.stack(c) for c in clusters]
         self.clusters = [torch.LongTensor(c).to(device) for c in clusters]

@@ -217,4 +222,4 @@ class CascadeXML(nn.Module):
                 sum_loss += l * self.rw_loss[i]
             return all_probs, all_candidates, sum_loss
         else:
-            return all_probs, all_candidates, all_probs_weighted
\ No newline at end of file
+            return all_probs, all_candidates, all_probs_weighted
diff --git a/src/data_utils.py b/src/data_utils.py
index 363b1a7..60f09aa 100755
--- a/src/data_utils.py
+++ b/src/data_utils.py
@@ -101,10 +101,15 @@ def make_csr_tfidf(dataset):
     if os.path.exists(file_name):
         tfidf_mat = sp.load_npz(file_name)
     else:
+        sub = False
         with open(f'{dataset}/train.txt') as fil:
             row_idx, col_idx, val_idx = [], [], []
             for i, data in enumerate(fil.readlines()):
                 data = data.split()[1:]
+                if i == 0 and len(data) == 2:
+                    print('skip head')
+                    sub = True
+                    continue
                 for tfidf in data:
                     try:
                         token, weight = tfidf.split(':')
@@ -115,6 +120,9 @@ def make_csr_tfidf(dataset):
                     col_idx.append(int(token))
                     val_idx.append(float(weight))
             m = max(row_idx) + 1
+            if sub:
+                m -= 1
+                row_idx = [x-1 for x in row_idx]
             n = max(col_idx) + 1
             tfidf_mat = sp.csr_matrix((val_idx, (row_idx, col_idx)), shape=(m, n))
             sp.save_npz(file_name, tfidf_mat)
diff --git a/src/dataset.py b/src/dataset.py
index e58f7c0..99f315f 100755
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -191,7 +191,7 @@ class MultiXMLGeneral(Dataset):
         _features = self.tf_X
         lbl_sparse = _labels.dot(_features).tocsr()
         lbl_sparse = retain_topk(lbl_sparse, k=1000)
-        return lbl_sparse   
+        return lbl_sparse

     def __getitem__(self, idx):
         cluster_ids = [torch.LongTensor(self.Y[idx].indices)]
@@ -208,7 +208,9 @@ class MultiXMLGeneral(Dataset):
         cluster_ids = cluster_ids[::-1]

         if self.train_W:
-            return torch.FloatTensor(self.x[idx]), torch.ones(128), *cluster_ids[::-1]
+            result = [torch.FloatTensor(self.x[idx]), torch.ones(128)]
+            for a in cluster_ids[::-1]: result.append(a)
+            return tuple(result)

         input_ids = self.x[idx]

@@ -224,7 +226,9 @@ class MultiXMLGeneral(Dataset):
         attention_mask = torch.tensor([1] * len(input_ids) + [0] * padding_length)
         input_ids = torch.tensor(input_ids + ([0] * padding_length))

-        return input_ids, attention_mask, *cluster_ids
+        result = [input_ids, attention_mask]
+        for a in cluster_ids: result.append(a)
+        return tuple(result)

@@ -296,7 +300,9 @@ class PoolableMultiXMLGeneral(MultiXMLGeneral):
             cluster_ids = cluster_ids[::-1]

         if self.train_W:
-            return torch.FloatTensor(self.x[idx]), torch.ones(128), *cluster_ids[::-1]
+            result = [torch.FloatTensor(self.x[idx]), torch.ones(128)]
+            for a in cluster_ids[::-1]: result.append(a)
+            return tuple(result)

         input_ids = self.x[idx]
         tfidfs = np.array(self.bert_tfidf[idx].todense())[0] #[input_ids]
@@ -318,8 +324,10 @@ class PoolableMultiXMLGeneral(MultiXMLGeneral):

         input_ids, attention_mask = self.word_pool(input_ids, tfidfs, 8)

-        return input_ids, attention_mask, *cluster_ids
-        
+        result = [input_ids, attention_mask]
+        for a in cluster_ids: result.append(a)
+        return tuple(result)
+

 class PecosDataset(Dataset):
     def __init__(self, x, y, num_labels, max_length, groups=None, 
@@ -401,7 +409,10 @@ class PecosDataset(Dataset):
             if cluster_ids[-1][0] == -1:
                 cluster_ids[-1] == cluster_ids[-1][1:]

-        return input_ids, attention_mask, *cluster_ids[::-1]
+        #return input_ids, attention_mask, *cluster_ids[::-1]
+        result = [torch.FloatTensor(self.x[idx]), torch.ones(128)]
+        for a in cluster_ids[::-1]: result.append(a)
+        return tuple(result)

 class XMLData(Dataset):
     def __init__(self, x, y, num_labels, max_length, params, group_y=None, 
diff --git a/src/tree.py b/src/tree.py
index 8d8f2ff..82cf143 100755
--- a/src/tree.py
+++ b/src/tree.py
@@ -1,4 +1,5 @@
 import numpy as np
+import scipy
 import torch
 import copy
 import time
@@ -105,6 +106,8 @@ def b_kmeans_sparse_dense(lf_sparse, lf_dense, index, metric='cosine', tol=1e-4,

 def b_kmeans_sparse(labels_features, index, metric='cosine', tol=1e-4, leakage=None):
+    if type(labels_features) is not scipy.sparse._csr.csr_matrix:
+        labels_features = scipy.sparse.csr_matrix(labels_features)
     labels_features = _normalize(labels_features)
     if labels_features.shape[0] == 1:
         return [index]
@@ -169,6 +172,7 @@ def cluster_labels(labels, clusters, verbose_label_index, num_nodes, splitter, f
     #     del temp_cluster_list
     # print(cpu_count()-1)
     with Pool(cpu_count()-1) as p:
+    #with Pool(50) as p:
         while len(clusters) < num_nodes:
             if isinstance(labels, list):
                 temp_cluster_list = functools.reduce(
@@ -177,7 +181,7 @@ def cluster_labels(labels, clusters, verbose_label_index, num_nodes, splitter, f
             else:    
                 temp_cluster_list = functools.reduce(
                     operator.iconcat,
-                    p.starmap(splitter, map(lambda x: (labels[x], x), clusters)), [])
+                    p.starmap(splitter, map(lambda x: (dict(labels[x].todok().items()), x), clusters)), [])

             end = time.time()
             print("Total clusters {}".format(len(temp_cluster_list)),
@@ -325,4 +329,4 @@ class build_tree:
         self.__dict__ = pik.load(open(fname, 'rb'))

     def save(self, fname):
-        pik.dump(self.__dict__, open(fname, 'wb'))
\ No newline at end of file
+        pik.dump(self.__dict__, open(fname, 'wb'))
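
For context on the patch: most of the dataset.py hunks replace returns of the form "return input_ids, attention_mask, *cluster_ids" because unparenthesized starred unpacking in return statements only became valid syntax in Python 3.8; on older interpreters those lines are SyntaxErrors. A minimal sketch (hypothetical function name, not from the repo) of an equivalent one-line workaround, valid from Python 3.5 onward:

    def example_getitem(input_ids, attention_mask, cluster_ids):
        # Parenthesizing the tuple makes the starred unpacking legal on
        # Python < 3.8; equivalent to building the tuple in a loop.
        return (input_ids, attention_mask, *cluster_ids)

Two smaller notes: the data_utils.py hunk appears to skip the standard header line (num_points num_features num_labels) at the top of train.txt in repository-format XMC datasets, and the tree.py check against the private scipy.sparse._csr module can be written with the public API as "if not scipy.sparse.isspmatrix_csr(labels_features):". Incidentally, the unchanged context line "cluster_ids[-1] == cluster_ids[-1][1:]" in PecosDataset is a comparison with no effect and presumably should be an assignment ("=").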
caseware66 commented 1 year ago

Thank you very much. I will refer to the modifications.

caseware66 commented 1 year ago

Could you please send me the modified files via email? After I made the modifications, it still doesn't work. My email is 157500647@qq.com.

Atom-101 commented 1 year ago

I apologize for the issues. I will try to solve them after May 17.