ChenhongyiYang / PPAL

[CVPR 2024] Plug and Play Active Learning for Object Detection
Apache License 2.0

When running diversity_sampler on my own dataset, I hit the "assert len(cluster_i) >= 1" error #3

Open syyxtl opened 2 years ago

syyxtl commented 2 years ago


ChenhongyiYang commented 1 year ago

Hi, sorry for the late reply. This problem occurs when an image is assigned to more than one centroid, which is often caused by errors in the distance computation. Could you please check the computed image distances?
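A toy illustration of one way a cluster can end up empty (a hypothetical 4×4 matrix, not PPAL data). If two centroids have identical distances to every image, for example because they are duplicate images at distance 0, np.argmin breaks the tie in favor of the first centroid. The second cluster then receives no images and an `assert len(cluster_i) >= 1` check fails. This assumes, as in the kmeans snippet later in this thread, that images are assigned with argmin over the distance columns of the current centroids.

```python
import numpy as np

# Toy symmetric distance matrix (hypothetical data): images 0 and 1 are
# identical (distance 0); images 2 and 3 are farther away.
dis_matrix = np.array([
    [0.0, 0.0, 1.0, 2.0],
    [0.0, 0.0, 1.0, 2.0],
    [1.0, 1.0, 0.0, 1.5],
    [2.0, 2.0, 1.5, 0.0],
])

centroids = [0, 1]  # two centroids that are indistinguishable by distance
centroid_dis = dis_matrix[:, centroids]
# np.argmin returns the first index on ties, so every image picks centroid 0.
cluster_assign = np.argmin(centroid_dis, axis=1)
cluster_sizes = [int((cluster_assign == i).sum()) for i in range(len(centroids))]

print(cluster_assign.tolist())  # [0, 0, 0, 0]
print(cluster_sizes)            # [4, 0]: cluster 1 is empty
```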

BroadswordZhang commented 1 year ago

I also have this problem with my own dataset. Could you tell me what to do once I have checked the image distances?
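How the matrix is built depends on your pipeline (not shown in this thread), but a generic sanity check on the (N, N) distance matrix fed to the diversity sampler can catch several common problems. This is a minimal sketch assuming `dis_matrix` is a NumPy array of pairwise image distances; `check_distance_matrix` and its `atol` tolerance are hypothetical names, not part of PPAL.

```python
import numpy as np


def check_distance_matrix(dis_matrix, atol=1e-8):
    """Summarize common problems in an (N, N) image distance matrix.

    Hypothetical helper for debugging; it is not part of PPAL.
    """
    d = np.asarray(dis_matrix, dtype=float)
    if d.ndim != 2 or d.shape[0] != d.shape[1]:
        raise ValueError(f"expected a square matrix, got shape {d.shape}")
    n = d.shape[0]

    finite = np.isfinite(d)
    # Compare d[i, j] with d[j, i] only where both entries are finite.
    both = finite & finite.T
    asym = np.abs(d[both] - d.T[both])
    # Upper triangle without the diagonal: each unordered pair (i, j) once.
    upper = np.triu(np.ones((n, n), dtype=bool), k=1)

    return {
        # NaN/inf entries make argmin-based assignment unreliable.
        "non_finite": int((~finite).sum()),
        # Should be ~0 for a symmetric distance.
        "max_asymmetry": float(asym.max()) if asym.size else 0.0,
        # An image should be at distance 0 from itself.
        "nonzero_diagonal": int((np.abs(np.diag(d)) > atol).sum()),
        # Distinct images at distance ~0 can tie as centroids.
        "zero_pairs": int(((np.abs(d) <= atol) & upper).sum()),
    }


report = check_distance_matrix([[0.0, 0.0, 1.0],
                                [0.0, 0.0, 1.0],
                                [1.0, 1.0, 0.0]])
print(report)
# {'non_finite': 0, 'max_asymmetry': 0.0, 'nonzero_diagonal': 0, 'zero_pairs': 1}
```

A healthy matrix should report `non_finite == 0`, `max_asymmetry` close to 0 and `nonzero_diagonal == 0`. A large `zero_pairs` count means many images are indistinguishable (for example duplicates), which can leave two centroids tied and one cluster empty.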

dust-removal commented 1 year ago

Hi, I also want to train on my own dataset. Could you tell me what changes you made to train on yours?

sharat29ag commented 1 year ago

@ChenhongyiYang I also faced this error while training on COCO.

shiyuanyu123 commented 7 months ago

@ChenhongyiYang I also faced this error while training on my dataset.

Parkkkkk commented 5 months ago

Try this code:

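import numpy as np  # needed at module level if not already imported


@staticmethod
def kmeans(dis_matrix, K, n_iter=100):
    """K-medoids clustering on a precomputed (N, N) image distance matrix.

    Returns K centroid indices. Empty clusters are re-seeded instead of
    failing `assert len(cluster_i) >= 1`.
    """
    N = dis_matrix.shape[0]
    centroids = np.asarray(DiversitySampler.k_centroid_greedy(dis_matrix, K))
    data_indices = np.arange(N)

    assign_dis_records = []  # total assignment distance per iteration (debugging)
    for _ in range(n_iter):
        # Assign every image to its nearest current centroid.
        centroid_dis = dis_matrix[:, centroids]
        cluster_assign = np.argmin(centroid_dis, axis=1)
        assign_dis_records.append(centroid_dis.min(axis=1).sum())

        new_centroids = []
        for i in range(K):
            cluster_i = data_indices[cluster_assign == i]
            if len(cluster_i) == 0:
                # Empty cluster (e.g. tied or duplicate centroids): re-seed with
                # a random image not yet chosen as a centroid in this round.
                free = np.setdiff1d(data_indices, new_centroids)
                new_centroid_i = np.random.choice(free)
            else:
                # Medoid: the member with the smallest summed distance to the others.
                dis_mat_i = dis_matrix[np.ix_(cluster_i, cluster_i)]
                new_centroid_i = cluster_i[np.argmin(dis_mat_i.sum(axis=1))]
            new_centroids.append(new_centroid_i)
        new_centroids = np.array(new_centroids)

        # Stop early once the centroids no longer change.
        if np.array_equal(new_centroids, centroids):
            break
        centroids = new_centroids
    return centroids.tolist()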