ChenhongyiYang / PPAL

[CVPR 2024] Plug and Play Active Learning for Object Detection
Apache License 2.0

When running diversity_sampler on my own dataset, I get an "assert len(cluster_i) >= 1" error #3

Status: Open. syyxtl opened this issue 1 year ago.

syyxtl commented 1 year ago

[Two screenshots; images not preserved.]

ChenhongyiYang commented 1 year ago

Hi, sorry for the late reply. This assertion fails when one of the clusters ends up empty, typically because the same image has been selected as more than one centroid. This is often caused by problems in the distance computation. Could you please check the computed image distances?
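One way to check the computed image distances is to run a few sanity checks on the distance matrix before clustering. This is a minimal sketch, assuming the distances are held in a square NumPy array like the dis_matrix passed to DiversitySampler.kmeans; check_distance_matrix is a hypothetical helper, not part of PPAL, and the checks are generic rather than taken from the repository.

```python
import numpy as np


def check_distance_matrix(dis_matrix, centroids=None, tol=1e-8):
    """Return a list of problems that can leave a k-medoids cluster empty.

    Hypothetical helper: generic sanity checks, not code from PPAL.
    """
    d = np.asarray(dis_matrix, dtype=float)
    if d.ndim != 2 or d.shape[0] != d.shape[1]:
        return [f"matrix is not square: shape {d.shape}"]
    problems = []
    n_bad = np.count_nonzero(~np.isfinite(d))
    if n_bad:
        problems.append(f"{n_bad} NaN/inf entries")
    if np.any(np.abs(np.diag(d)) > tol):
        problems.append("non-zero diagonal: some image is not at distance 0 from itself")
    if not np.allclose(d, d.T, atol=tol, equal_nan=True):
        problems.append("matrix is not symmetric")
    # Distinct images at (near-)zero distance create ties in np.argmin,
    # which always resolves a tie in favour of the lower cluster index.
    off_diag = d.copy()
    np.fill_diagonal(off_diag, np.inf)
    n_dup = np.count_nonzero(np.triu(off_diag <= tol, k=1))
    if n_dup:
        problems.append(f"{n_dup} pair(s) of distinct images at zero distance")
    # The same image chosen as two centroids: argmin always prefers the
    # first copy, so the second cluster receives no members.
    if centroids is not None:
        c = np.asarray(centroids)
        if len(np.unique(c)) < len(c):
            problems.append("the same image was selected as more than one centroid")
    return problems


# Usage: two identical images (rows 0 and 1) and a duplicated centroid.
dup = np.array([[0.0, 0.0, 1.0],
                [0.0, 0.0, 1.0],
                [1.0, 1.0, 0.0]])
for p in check_distance_matrix(dup, centroids=[0, 0]):
    print(p)
# Cluster sizes when centroids = [0, 0]: every image goes to cluster 0.
print(np.bincount(np.argmin(dup[:, [0, 0]], axis=1), minlength=2))  # [3 0]
```

The last line shows the failure mode directly: with a duplicated centroid, the second cluster has zero members, which is exactly what "assert len(cluster_i) >= 1" rejects.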

BroadswordZhang commented 1 year ago

I also have this problem with my own dataset. Could you tell me what to do once I know the image distances?

dust-removal commented 1 year ago

[Two screenshots; images not preserved.] Hi, I also want to train on my own dataset. Could you tell me what changes you made to use your own dataset?

sharat29ag commented 9 months ago

@ChenhongyiYang I also faced this error while training on COCO.

shiyuanyu123 commented 3 months ago

@ChenhongyiYang I also faced this error while training on my dataset.

Parkkkkk commented 1 month ago

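Try this version of DiversitySampler.kmeans, which re-seeds empty clusters instead of asserting that every cluster is non-empty: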

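@staticmethod
def kmeans(dis_matrix, K, n_iter=100):
    # Intended to replace DiversitySampler.kmeans; assumes numpy is
    # imported as np in the enclosing module.
    N = dis_matrix.shape[0]
    centroids = DiversitySampler.k_centroid_greedy(dis_matrix, K)
    data_indices = np.arange(N)

    assign_dis_records = []
    for _ in range(n_iter):
        # Assign every image to its nearest centroid.
        centroid_dis = dis_matrix[:, centroids]
        cluster_assign = np.argmin(centroid_dis, axis=1)
        assign_dis = centroid_dis.min(axis=1).sum()
        assign_dis_records.append(assign_dis)

        # Medoid update for non-empty clusters; -1 marks an empty cluster.
        new_centroids = np.full(K, -1)
        for i in range(K):
            cluster_i = data_indices[cluster_assign == i]
            if len(cluster_i) > 0:
                dis_mat_i = dis_matrix[cluster_i][:, cluster_i]
                new_centroids[i] = cluster_i[np.argmin(dis_mat_i.sum(axis=1))]
        # Re-seed empty clusters with random images that are not already
        # centroids, instead of failing "assert len(cluster_i) >= 1".
        # Clusters are disjoint, so the medoids above are already distinct.
        for i in np.flatnonzero(new_centroids < 0):
            candidates = np.setdiff1d(data_indices, new_centroids)
            new_centroids[i] = np.random.choice(candidates)
        centroids = new_centroids
    return centroids.tolist()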
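Because the kmeans method above depends on DiversitySampler.k_centroid_greedy, which is not shown in this thread, here is a self-contained sketch of the same k-medoids loop for testing outside PPAL. The farthest-first seeding in greedy_init is an assumption standing in for k_centroid_greedy and may differ from PPAL's version; greedy_init and kmedoids are hypothetical names. Empty clusters are filled with an image that is not yet a centroid, so the returned indices are always distinct.

```python
import numpy as np


def greedy_init(dis_matrix, K, rng):
    """Farthest-first seeding; an assumed stand-in for k_centroid_greedy."""
    N = dis_matrix.shape[0]
    centroids = [int(rng.integers(N))]
    while len(centroids) < K:
        # Distance from each image to its closest already-chosen centroid.
        nearest = dis_matrix[:, centroids].min(axis=1)
        nearest[centroids] = -1.0  # never choose the same image twice
        centroids.append(int(np.argmax(nearest)))
    return centroids


def kmedoids(dis_matrix, K, n_iter=100, seed=0):
    """k-medoids on a precomputed (N, N) distance matrix; returns K distinct indices."""
    dis_matrix = np.asarray(dis_matrix, dtype=float)
    N = dis_matrix.shape[0]
    if not 1 <= K <= N:
        raise ValueError(f"K must be between 1 and {N}, got {K}")
    rng = np.random.default_rng(seed)
    centroids = np.array(greedy_init(dis_matrix, K, rng))
    data_indices = np.arange(N)
    for _ in range(n_iter):
        # Assign each image to its nearest centroid (ties go to the lower index).
        cluster_assign = np.argmin(dis_matrix[:, centroids], axis=1)
        new_centroids = np.full(K, -1)  # -1 marks an empty cluster
        for i in range(K):
            cluster_i = data_indices[cluster_assign == i]
            if len(cluster_i) > 0:
                # Medoid: the member with the smallest summed distance to the others.
                dis_mat_i = dis_matrix[np.ix_(cluster_i, cluster_i)]
                new_centroids[i] = cluster_i[np.argmin(dis_mat_i.sum(axis=1))]
        # Clusters are disjoint, so the medoids above are distinct; fill each
        # empty cluster with a random image that is not yet a centroid.
        for i in np.flatnonzero(new_centroids < 0):
            free = np.setdiff1d(data_indices, new_centroids)
            new_centroids[i] = rng.choice(free)
        centroids = new_centroids
    return centroids.tolist()


# Usage: images 0 and 1 are identical, so with K=4 the first assignment
# leaves one cluster empty under this seeding.
dup4 = np.array([[0, 0, 1, 1],
                 [0, 0, 1, 1],
                 [1, 1, 0, 2],
                 [1, 1, 2, 0]], dtype=float)
print(sorted(kmedoids(dup4, K=4)))  # [0, 1, 2, 3]
print(len(set(kmedoids(dup4, K=3))))  # 3
```

On dup4 with K=4 the two identical images tie in np.argmin, so one cluster receives no members; this is where the original assertion fires, and here it is re-seeded instead. Re-seeding only from images that are not already centroids keeps the selection at K distinct images, at the cost of sometimes keeping both copies of a duplicated image.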