IntelLabs / numba

NumPy aware dynamic Python compiler using LLVM
http://numba.pydata.org/
BSD 2-Clause "Simplified" License

np.sum parallelization with axis #31

Open ehsantn opened 7 years ago

ehsantn commented 7 years ago

We need to parallelize np.sum when it is called with an axis argument, to enable the following simpler version of k-means:

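from math import sqrt

import numba
import numpy as np

@numba.njit(parallel=True)
def kmeans(A, numCenter, numIter):
    N, D = A.shape
    centroids = np.random.ranf((numCenter, D))

    for l in range(numIter):
        # Euclidean distance from every point to every centroid (N x numCenter)
        dist = np.array([[sqrt(np.sum((A[i,:]-centroids[j,:])**2))
                                for j in range(numCenter)] for i in range(N)])
        # Index of the nearest centroid for each point
        labels = np.array([dist[i,:].argmin() for i in range(N)])
        # New centroid = mean of its assigned points; np.sum(..., 0) is the
        # axis reduction this issue asks to parallelize
        centroids = np.array([np.sum(A[labels==i], 0)/np.sum(labels==i)
                                for i in range(numCenter)])

    return centroids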
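
For reference, here is a minimal plain-NumPy sketch of what the np.sum(A[labels==i], 0) call computes, written as an explicit reduction over rows. It is only an illustration of the semantics, not how Numba implements or would implement the parallel version. The names sum_axis0 and centroid_update are hypothetical helpers introduced here, and the loops are ordinary sequential Python loops.

```python
import numpy as np

def sum_axis0(A):
    """Column-wise sum of a 2-D array, equivalent to np.sum(A, 0)."""
    n_rows, n_cols = A.shape
    out = np.zeros(n_cols, dtype=A.dtype)
    # Every row accumulates into the same output vector, so a parallel
    # version would have to treat `out` as an array-valued reduction.
    for i in range(n_rows):
        out += A[i, :]
    return out

def centroid_update(A, labels, num_center):
    """Mean of the points assigned to each cluster (hypothetical helper)."""
    D = A.shape[1]
    centroids = np.empty((num_center, D))
    for k in range(num_center):
        mask = labels == k
        # Same expression as in the k-means loop above, with the axis sum
        # spelled out. An empty cluster would divide by zero, as in the original.
        centroids[k, :] = sum_axis0(A[mask]) / np.sum(mask)
    return centroids
```

The point of the sketch is that the axis-0 sum reduces over rows into a vector rather than a scalar, which is the case that needs parallel support.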