swuxyj / DeepHash-pytorch

Implementation of Some Deep Hash Algorithms, Including DPSH, DSH, DHN, HashNet, DSDH, DTSH, DFH, GreedyHash, CSQ.
MIT License
495 stars, 116 forks

CSQ: large number of classes #21

Closed: sophieyl closed this issue 2 years ago

sophieyl commented 3 years ago

Hi! I want to train CSQ on a person re-ID task with more than 40,000 classes. The code that sets the hash target centers (below) is extremely time-consuming. Do you have any suggestions? @swuxyj

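```python
import random

import numpy as np
import torch
from scipy.linalg import hadamard


def get_hash_targets(n_class, bit):
    # Hadamard rows and their negations give 2 * bit centers that are
    # pairwise at Hamming distance >= bit / 2.
    H_K = hadamard(bit)
    H_2K = np.concatenate((H_K, -H_K), 0)
    hash_targets = torch.from_numpy(H_2K[:n_class]).float()

    if H_2K.shape[0] < n_class:
        hash_targets.resize_(n_class, bit)
        for k in range(20):
            for index in range(H_2K.shape[0], n_class):
                ones = torch.ones(bit)
                # Random balanced code: exactly bit // 2 entries set to -1
                sa = random.sample(range(bit), bit // 2)
                ones[sa] = -1
                hash_targets[index] = ones
            # To find the average/min pairwise distance.
            # This double loop is O(n_class^2) in pure Python, which is
            # what makes it so slow for 40,000+ classes.
            c = []
            for i in range(n_class):
                for j in range(i + 1, n_class):
                    TF = (hash_targets[i] != hash_targets[j]).sum().item()
                    c.append(TF)
            c = np.array(c)

            # choose min(c) in the range of K/4 to K/3
            # see in https://github.com/yuanli2333/Hadamard-Matrix-for-hashing/issues/1
            # but it is hard when bit is small
            if c.min() > bit / 4 and c.mean() >= bit / 2:
                print(c.min(), c.mean())
                break
    return hash_targets
```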
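
The double loop above compares every pair of centers in pure Python, which is roughly 800 million pairs at 40,000 classes. Because every center is a ±1 vector, the Hamming distance can be read off a dot product, hamming(a, b) = (bit − a·b) / 2, so the same min/mean check can be vectorized. A minimal sketch in NumPy, assuming the centers are an (n_class, bit) array of ±1 values with n_class ≥ 2; `pairwise_hamming_stats` and its `chunk` parameter are hypothetical names, not part of DeepHash-pytorch.

```python
import numpy as np


def pairwise_hamming_stats(targets, chunk=1024):
    """Min and mean Hamming distance over all pairs i < j of +/-1 codes.

    For codes in {-1, +1}^bit, hamming(a, b) = (bit - a @ b) / 2, so one
    matrix product replaces the inner Python loop. Rows are processed in
    chunks so the full n x n distance matrix is never stored at once.
    """
    t = np.asarray(targets, dtype=np.float32)
    n, bit = t.shape
    total, count, best = 0.0, 0, np.inf
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        # Distances from rows start..stop-1 to rows start..n-1.
        dist = (bit - t[start:stop] @ t[start:].T) / 2
        rows = np.arange(start, stop)[:, None]
        cols = np.arange(start, n)[None, :]
        upper = dist[cols > rows]  # keep only pairs with i < j
        if upper.size:
            total += float(upper.sum(dtype=np.float64))
            count += upper.size
            best = min(best, float(upper.min()))
    return best, total / count
```

Inside the loop, `c_min, c_mean = pairwise_hamming_stats(hash_targets.numpy())` would take the place of building `c`, followed by the same `c_min > bit / 4 and c_mean >= bit / 2` test. The work is still quadratic in n_class, but it runs as matrix products on a chunk × n_class block at a time instead of one Python comparison per pair.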

sophieyl commented 3 years ago

God! I modified the code as below, but it is still very time-consuming :(

```python
import random

import numpy as np
import torch
from scipy.linalg import hadamard
from tqdm import tqdm

random.seed(0)
n_class = 43828
bit = 64
H_K = hadamard(bit)
H_2K = np.concatenate((H_K, -H_K), 0)
hash_targets = torch.from_numpy(H_2K[:n_class]).float()

if H_2K.shape[0] < n_class:
    hash_targets.resize_(n_class, bit)
    invalid_list = []  # sign patterns that were already rejected
    for index in tqdm(range(H_2K.shape[0], n_class)):
        while True:
            ones = torch.ones(bit)
            sa = random.sample(range(bit), bit // 2)
            # Skip a pattern that was rejected before.
            if set(sa) in invalid_list:
                continue
            ones[sa] = -1
            c = []
            # Compare the candidate with every existing center, one row at a time.
            for hash_target in hash_targets[:index]:
                TF = (hash_target != ones).sum().item()
                c.append(TF)
                if TF < bit / 4:
                    invalid_list.append(set(sa))
                    break
            else:
                # Reached only if no existing center was too close.
                if np.mean(c) >= bit / 2:
                    break
        hash_targets[index] = ones
```

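
The modified loop still compares each candidate with every earlier center one row at a time in Python. Below is a minimal sketch of the same acceptance rule (every distance ≥ bit/4, mean distance ≥ bit/2) with that per-row loop replaced by one matrix-vector product. Assumptions: NumPy instead of PyTorch; a hand-written Sylvester Hadamard construction (the construction `scipy.linalg.hadamard` documents for powers of two) so the sketch has no SciPy dependency; hypothetical names `sylvester_hadamard`, `generate_hash_centers` and `max_tries`; and no `invalid_list`, since repeated draws among C(64, 32) ≈ 1.8 × 10^18 balanced 64-bit codes are very unlikely. The mean test uses a running sum s of accepted centers: the mean distance from v to the first k centers is (bit − v·s/k)/2, so it is ≥ bit/2 exactly when v·s ≤ 0.

```python
import numpy as np


def sylvester_hadamard(bit):
    """Sylvester construction of a bit x bit Hadamard matrix."""
    if bit <= 0 or bit & (bit - 1):
        raise ValueError("bit must be a power of two")
    h = np.ones((1, 1), dtype=np.int8)
    while h.shape[0] < bit:
        h = np.block([[h, h], [h, -h]])
    return h


def generate_hash_centers(n_class, bit, seed=0, max_tries=100_000):
    rng = np.random.default_rng(seed)
    h = sylvester_hadamard(bit)
    base = np.concatenate((h, -h), 0)[:n_class]
    centers = np.empty((n_class, bit), dtype=np.float32)
    centers[: len(base)] = base
    # Running sum of accepted centers, used for the O(bit) mean test.
    total = centers[: len(base)].sum(axis=0)
    for index in range(len(base), n_class):
        for _ in range(max_tries):
            v = np.ones(bit, dtype=np.float32)
            v[rng.choice(bit, bit // 2, replace=False)] = -1  # balanced code
            if v @ total > 0:  # mean distance would be below bit / 2
                continue
            # Distances to all previous centers in one matrix-vector product.
            dist = (bit - centers[:index] @ v) / 2
            if dist.min() >= bit / 4:
                break
        else:
            raise RuntimeError(f"no valid center found for class {index}")
        centers[index] = v
        total += v
    return centers
```

With this sketch, `generate_hash_centers(43828, 64)` returns a float32 array that `torch.from_numpy` can wrap. The total work is still quadratic in n_class, but each candidate costs one vectorized product instead of a Python loop over earlier centers.
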
swuxyj commented 3 years ago

I am not sure whether 64 bits can generate hash target centers for so many classes. I think many classes will end up with the same hash target center.
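
One way to check the capacity question is the Gilbert–Varshamov bound from coding theory: for length n and minimum distance d, some binary code has at least 2^n / Σ_{i<d} C(n, i) codewords. A small sketch that evaluates it exactly with Python integers; `gv_lower_bound` is a hypothetical helper name.

```python
from math import comb


def gv_lower_bound(n, d):
    """Gilbert-Varshamov: some binary code of length n with minimum Hamming
    distance d has at least ceil(2**n / V) codewords, where V is the number
    of words within distance d - 1 of a fixed word."""
    volume = sum(comb(n, i) for i in range(d))
    return -(-2**n // volume)  # ceiling division on exact integers


print(gv_lower_bound(64, 16))  # minimum distance >= bit / 4
print(gv_lower_bound(64, 17))  # minimum distance > bit / 4
```

For n = 64 this gives about 82,000 codewords at minimum distance 16 (the `TF < bit / 4` rejection rule in the modified code) and about 26,000 at minimum distance 17 (the strict `c.min() > bit / 4` test in the original snippet). It is only an existence lower bound: it ignores CSQ's balanced codes and mean-distance condition, and it says nothing about how quickly a random search finds such a code.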

sophieyl commented 3 years ago

I think I should try another deep hashing method: in the past 67 hours it has only generated 25,000 hash centers, which is really time-consuming. Do you have any suggestions for a method that handles datasets with a large number of classes well?

swuxyj commented 3 years ago

Sorry, I don't have a suitable suggestion.