fafeeeeee opened this issue 1 year ago
Hello, have you run the code on Wiki10-31K?
There are some problems in the code, and I can't run Wiki10-31K directly. Here are some modifications that fix them (though I still can't reproduce the Wiki10-31K results reported in the paper):
diff --git a/src/CascadeXML.py b/src/CascadeXML.py
index 9109fc6..c71361b 100755
--- a/src/CascadeXML.py
+++ b/src/CascadeXML.py
@@ -64,6 +64,11 @@ class CascadeXML(nn.Module):
clusters[-1][i] = np.pad(clusters[-1][i], (0, max_cluster-len(clusters[-1][i])),
constant_values=self.num_ele[-1]).astype(np.int32)
+ for c in clusters[:-1]:
+ max_cluster = max([len(c_) for c_ in c])
+ for i in range(len(c)):
+ c[i] = np.pad(c[i], (0, max_cluster-len(c[i])), constant_values=c[i][0]).astype(np.int32)
+
clusters = [np.stack(c) for c in clusters]
self.clusters = [torch.LongTensor(c).to(device) for c in clusters]
@@ -217,4 +222,4 @@ class CascadeXML(nn.Module):
sum_loss += l * self.rw_loss[i]
return all_probs, all_candidates, sum_loss
else:
- return all_probs, all_candidates, all_probs_weighted
\ No newline at end of file
+ return all_probs, all_candidates, all_probs_weighted
diff --git a/src/data_utils.py b/src/data_utils.py
index 363b1a7..60f09aa 100755
--- a/src/data_utils.py
+++ b/src/data_utils.py
@@ -101,10 +101,15 @@ def make_csr_tfidf(dataset):
if os.path.exists(file_name):
tfidf_mat = sp.load_npz(file_name)
else:
+ sub = False
with open(f'{dataset}/train.txt') as fil:
row_idx, col_idx, val_idx = [], [], []
for i, data in enumerate(fil.readlines()):
data = data.split()[1:]
+ if i == 0 and len(data) == 2:
+ print('skip head')
+ sub = True
+ continue
for tfidf in data:
try:
token, weight = tfidf.split(':')
@@ -115,6 +120,9 @@ def make_csr_tfidf(dataset):
col_idx.append(int(token))
val_idx.append(float(weight))
m = max(row_idx) + 1
+ if sub:
+ m -= 1
+ row_idx = [x-1 for x in row_idx]
n = max(col_idx) + 1
tfidf_mat = sp.csr_matrix((val_idx, (row_idx, col_idx)), shape=(m, n))
sp.save_npz(file_name, tfidf_mat)
diff --git a/src/dataset.py b/src/dataset.py
index e58f7c0..99f315f 100755
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -191,7 +191,7 @@ class MultiXMLGeneral(Dataset):
_features = self.tf_X
lbl_sparse = _labels.dot(_features).tocsr()
lbl_sparse = retain_topk(lbl_sparse, k=1000)
- return lbl_sparse
+ return lbl_sparse
def __getitem__(self, idx):
cluster_ids = [torch.LongTensor(self.Y[idx].indices)]
@@ -208,7 +208,9 @@ class MultiXMLGeneral(Dataset):
cluster_ids = cluster_ids[::-1]
if self.train_W:
- return torch.FloatTensor(self.x[idx]), torch.ones(128), *cluster_ids[::-1]
+ result = [torch.FloatTensor(self.x[idx]), torch.ones(128)]
+ for a in cluster_ids[::-1]: result.append(a)
+ return tuple(result)
input_ids = self.x[idx]
@@ -224,7 +226,9 @@ class MultiXMLGeneral(Dataset):
attention_mask = torch.tensor([1] * len(input_ids) + [0] * padding_length)
input_ids = torch.tensor(input_ids + ([0] * padding_length))
- return input_ids, attention_mask, *cluster_ids
+ result = [input_ids, attention_mask]
+ for a in cluster_ids: result.append(a)
+ return tuple(result)
@@ -296,7 +300,9 @@ class PoolableMultiXMLGeneral(MultiXMLGeneral):
cluster_ids = cluster_ids[::-1]
if self.train_W:
- return torch.FloatTensor(self.x[idx]), torch.ones(128), *cluster_ids[::-1]
+ result = [torch.FloatTensor(self.x[idx]), torch.ones(128)]
+ for a in cluster_ids[::-1]: result.append(a)
+ return tuple(result)
input_ids = self.x[idx]
tfidfs = np.array(self.bert_tfidf[idx].todense())[0] #[input_ids]
@@ -318,8 +324,10 @@ class PoolableMultiXMLGeneral(MultiXMLGeneral):
input_ids, attention_mask = self.word_pool(input_ids, tfidfs, 8)
- return input_ids, attention_mask, *cluster_ids
-
+ result = [input_ids, attention_mask]
+ for a in cluster_ids: result.append(a)
+ return tuple(result)
+
class PecosDataset(Dataset):
def __init__(self, x, y, num_labels, max_length, groups=None,
@@ -401,7 +409,10 @@ class PecosDataset(Dataset):
if cluster_ids[-1][0] == -1:
cluster_ids[-1] == cluster_ids[-1][1:]
- return input_ids, attention_mask, *cluster_ids[::-1]
+ #return input_ids, attention_mask, *cluster_ids[::-1]
+ result = [torch.FloatTensor(self.x[idx]), torch.ones(128)]
+ for a in cluster_ids[::-1]: result.append(a)
+ return tuple(result)
class XMLData(Dataset):
def __init__(self, x, y, num_labels, max_length, params, group_y=None,
diff --git a/src/tree.py b/src/tree.py
index 8d8f2ff..82cf143 100755
--- a/src/tree.py
+++ b/src/tree.py
@@ -1,4 +1,5 @@
import numpy as np
+import scipy
import torch
import copy
import time
@@ -105,6 +106,8 @@ def b_kmeans_sparse_dense(lf_sparse, lf_dense, index, metric='cosine', tol=1e-4,
def b_kmeans_sparse(labels_features, index, metric='cosine', tol=1e-4, leakage=None):
+ if type(labels_features) is not scipy.sparse._csr.csr_matrix:
+ labels_features = scipy.sparse.csr_matrix(labels_features)
labels_features = _normalize(labels_features)
if labels_features.shape[0] == 1:
return [index]
@@ -169,6 +172,7 @@ def cluster_labels(labels, clusters, verbose_label_index, num_nodes, splitter, f
# del temp_cluster_list
# print(cpu_count()-1)
with Pool(cpu_count()-1) as p:
+ #with Pool(50) as p:
while len(clusters) < num_nodes:
if isinstance(labels, list):
temp_cluster_list = functools.reduce(
@@ -177,7 +181,7 @@ def cluster_labels(labels, clusters, verbose_label_index, num_nodes, splitter, f
else:
temp_cluster_list = functools.reduce(
operator.iconcat,
- p.starmap(splitter, map(lambda x: (labels[x], x), clusters)), [])
+ p.starmap(splitter, map(lambda x: (dict(labels[x].todok().items()), x), clusters)), [])
end = time.time()
print("Total clusters {}".format(len(temp_cluster_list)),
@@ -325,4 +329,4 @@ class build_tree:
self.__dict__ = pik.load(open(fname, 'rb'))
def save(self, fname):
- pik.dump(self.__dict__, open(fname, 'wb'))
\ No newline at end of file
+ pik.dump(self.__dict__, open(fname, 'wb'))
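For context on why the data_utils.py change above skips the first line: this is a minimal sketch of how I read the Wiki10-31K BoW files. It assumes the first line of train.txt is a header of the form "num_points num_features num_labels" and every following line is "labels tok:weight tok:weight ..." (my assumption from the BoW dump I downloaded; adjust if your files differ). The function name tfidf_from_bow is just for illustration, not part of the repository.

import scipy.sparse as sp

def tfidf_from_bow(path):
    # Build a documents-by-tokens CSR tf-idf matrix from a BoW text file.
    rows, cols, vals = [], [], []
    with open(path) as fil:
        lines = fil.readlines()
    # A header line carries only counts, so it contains no "token:weight" pairs.
    start = 1 if ':' not in lines[0] else 0
    for row, line in enumerate(lines[start:]):
        for pair in line.split()[1:]:  # first field holds the comma-separated labels
            token, weight = pair.split(':')
            rows.append(row)
            cols.append(int(token))
            vals.append(float(weight))
    shape = (max(rows) + 1, max(cols) + 1)
    return sp.csr_matrix((vals, (rows, cols)), shape=shape)

Without the header skip, the counts line is treated as document 0 and every row index is shifted by one, which is why make_csr_tfidf also subtracts 1 from row_idx when the header is detected.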
Thank you very much. I will refer to the modifications.
Could you please send me the modified files via email? After I made the modifications, it still doesn't work. My email is 157500647@qq.com
I apologize for the issues. I will try to solve them after May 17.
Hello, nice job. However, I can only find the training command for Wiki10-31K. Could you please provide the settings for other datasets, such as Amazon-670K?