src-d / kmcuda

Large scale K-means and K-nn implementation on NVIDIA GPU / CUDA

RuntimeError: cudaMemcpy failed #60

Open bosmart opened 5 years ago

bosmart commented 5 years ago

I'm getting the "cudaMemcpy failed" error without any other information, despite using verbose=2 mode. I have a rather large dataset (100k samples) and am trying to find k = 10k clusters. CUDA memory usage never goes above 600 MB (4-GPU configuration).

vmarkovtsev commented 5 years ago

Hi @bosmart, what is your feature vector dimension? Also, can you please attach the full log here? Do you use Python, R, or call the library directly?

bosmart commented 5 years ago

I use Python in a Jupyter Notebook and "RuntimeError: cudaMemcpy failed" is literally all I'm getting. I also have only 3 dimensions.

vmarkovtsev commented 5 years ago

I see. Unfortunately, Jupyter hides the native standard output, so you don't see the logs. Can you please run your code as a script in a terminal?

bosmart commented 5 years ago

    internal bug in kmeans_init_centroids: j = 0
    step 173
    cudaMemcpyAsync( host_dists + offset, (dists)[devi].get(), length * sizeof(float), cudaMemcpyDeviceToHost)
    /tmp/pip-req-build-0_qkkf5o/src/kmeans.cu:814 -> an illegal memory access was encountered
    kmeans_cuda_plus_plus failed
    kmeans_init_centroids() failed for yinyang groups: an illegal memory access was encountered
    kmeans_cuda_yy failed: no error
    Traceback (most recent call last):
      File "/tmp/decim.py", line 65, in <module>
        decimate_and_plot(zm, datam, tit=fname.split('/')[-1], decimation_levels=decimation_levels)
      File "/tmp/decim.py", line 48, in decimate_and_plot
        centroids, assignments = kmeans_cuda(datam, k, init=datam[ids], verbosity=2, seed=3, device=0)
    RuntimeError: cudaMemcpy failed

vmarkovtsev commented 5 years ago

OK, check your samples for NaNs. kmcuda is tolerant to NaNs, but it can still fail if there are too many of them or if whole vectors are NaN. Are there extreme values, e.g., very large or very small ones?

This error means that kmcuda cannot initialize the centroids, because some of the distances from the samples to the centroids chosen so far turned out to be NaN.
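A quick way to act on this advice before calling kmeans_cuda is to count NaN and inf values, rows that are entirely NaN, and magnitudes large enough to overflow when squared in float32 (the float32 maximum is about 3.4e38, so squaring a coordinate difference of roughly 1.8e19 or more gives inf, assuming Euclidean distances). This is a minimal sketch assuming the samples are a 2-D NumPy array; the helper name `screen_samples` and the 1e18 threshold are hypothetical choices, not part of kmcuda.

```python
import numpy as np


def screen_samples(samples: np.ndarray, max_abs: float = 1e18) -> dict:
    """Count values that may break centroid initialization.

    Hypothetical helper; max_abs is a heuristic chosen so that squared
    coordinate differences stay well below the float32 maximum (~3.4e38).
    """
    if samples.ndim != 2:
        raise ValueError("expected a (n_samples, n_features) array")
    nan_mask = np.isnan(samples)
    finite = np.isfinite(samples)
    # Replace non-finite entries with 0 so they are not counted as "huge".
    magnitudes = np.abs(np.where(finite, samples, 0))
    return {
        "nan_values": int(nan_mask.sum()),
        "inf_values": int(np.isinf(samples).sum()),
        "all_nan_rows": int(nan_mask.all(axis=1).sum()),
        "huge_values": int((magnitudes > max_abs).sum()),
    }


data = np.array([[1.0, 2.0, 3.0],
                 [np.nan, np.nan, np.nan],
                 [np.inf, 0.0, 1e31]], dtype=np.float32)
print(screen_samples(data))
# {'nan_values': 3, 'inf_values': 1, 'all_nan_rows': 1, 'huge_values': 1}
```

Non-zero counts do not prove that kmcuda will fail; they only point at the inputs worth cleaning or dropping first.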

vmarkovtsev commented 5 years ago

If that does not help, can you please send your data and the script to my email (listed on my profile page)?

Teeeto commented 5 years ago

Possibly an error in python.cc

    cudaMalloc(reinterpret_cast<void **>(&neighbors),
               samples_size * k * sizeof(float)) != cudaSuccess

I was struggling with various memory errors, then noticed a compiler warning saying that the multiplication is performed in uint32_t and may result in an arithmetic overflow. With 4 million samples and 3000 clusters, the product (4,000,000 × 3,000 = 12,000,000,000) does not fit in uint32_t, whose maximum is 4,294,967,295. After I applied a static_cast to uint64_t before the multiplication in all the mallocs, the issue seems to have disappeared.

So now it looks like this:

    cudaMalloc(reinterpret_cast<void **>(&neighbors),
               static_cast<uint64_t>(samples_size) * k * sizeof(float)) != cudaSuccess
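The arithmetic behind this report can be checked without a GPU. The sketch below emulates the uint32 wraparound in plain Python, assuming (as the compiler warning suggests) that `samples_size * k` is evaluated in 32-bit unsigned arithmetic and only then multiplied by `sizeof(float)` (4 bytes) as a 64-bit `size_t`, as on typical 64-bit platforms. The function names are hypothetical.

```python
UINT32_MASK = 0xFFFFFFFF  # 2**32 - 1
FLOAT_SIZE = 4            # sizeof(float) on common platforms


def alloc_bytes_uint32(samples_size: int, k: int) -> int:
    """Bytes requested when samples_size * k wraps modulo 2**32 before
    being widened and multiplied by sizeof(float)."""
    product = (samples_size * k) & UINT32_MASK  # emulate uint32_t overflow
    return product * FLOAT_SIZE


def alloc_bytes_uint64(samples_size: int, k: int) -> int:
    """Bytes requested after static_cast<uint64_t>(samples_size):
    the whole product is computed in 64 bits, so nothing wraps here."""
    return samples_size * k * FLOAT_SIZE


samples, clusters = 4_000_000, 3000
print(alloc_bytes_uint32(samples, clusters))  # 13640261632
print(alloc_bytes_uint64(samples, clusters))  # 48000000000
```

Without the cast, the allocation would be roughly 13.6 GB instead of the intended 48 GB, so later writes into `neighbors` could run past the end of the buffer, which would be consistent with the memory errors described above.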

Apologies if this is the wrong place to report it; this is my first GitHub post.

vmarkovtsev commented 5 years ago

@Teeeto Good catch! Could you please PR this? Looks like a proper bugfix to me.

Teeeto commented 5 years ago

I'm afraid that will be hard. I don't have a Linux machine to prepare a proper PR with tests, test builds, etc. I only have a Windows machine, and I'm trying to get a working Windows build for my project.

vmarkovtsev commented 5 years ago

No worries, there is no need to cover it with tests. Just PR this one-liner.

wilsonwong2014 commented 4 years ago

I reproduced the 'RuntimeError: cudaMemcpy failed!' error with Python code and found that it is caused by the following code (lines 330 to 332 in kmcuda.cc) when j = 0:

        RETERR(cuda_copy_sample_t(
            j - 1, i * features_size, samples_size, features_size, devs,
            verbosity, samples, centroids));

Detailed Description

    # test code
    import numpy as np
    from libKMCUDA import kmeans_cuda

    data = np.array([10.371722, 10.458154, 10.501911, 10.516579, 10.546037, 10.546037,
                     10.516579, 10.516579, 10.516579, 10.531287, 10.531287, 4.78087],
                    dtype=np.float32)
    means, y_pred = kmeans_cuda(data.reshape(-1, 1), 10, tolerance=0.01, verbosity=5, seed=137)
    print(means.reshape(-1))
    print(y_pred.reshape(-1))
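One observation about this input that may help explain the failure: the 12 samples contain only 7 distinct values, while k = 10 clusters are requested. In kmeans++, each new centroid is drawn with probability proportional to the squared distance to the nearest centroid chosen so far, so once every distinct value has been picked, all of those distances are 0 and there is nothing left to draw. The check below uses only NumPy and does not call kmcuda.

```python
import numpy as np

data = np.array([10.371722, 10.458154, 10.501911, 10.516579, 10.546037, 10.546037,
                 10.516579, 10.516579, 10.516579, 10.531287, 10.531287, 4.78087],
                dtype=np.float32)
k = 10

# Count unique sample vectors (rows), since kmeans_cuda receives an (n, 1) array.
distinct = np.unique(data.reshape(-1, 1), axis=0)
print(len(data), len(distinct), k)  # 12 7 10
if len(distinct) < k:
    print(f"only {len(distinct)} distinct samples for k={k}: "
          f"kmeans++ cannot pick {k} distinct centroids")
```

This would be consistent with `choice_sum` dropping to 0.000000 in the output that follows.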

Output Information

    arguments: 1 0x7ffc72bb065c 0.010 0.10 0 12 1 10 137 0 0 5 0x56252d841550 0x56252e0ebd90 0x56252df935d0 (nil)
    reassignments threshold: 0
    yinyang groups: 1
    [0] *dest: 0x7f4b51400000 - 0x7f4b51400030 (48)
    [0] device_centroids: 0x7f4b51400200 - 0x7f4b51400228 (40)
    [0] device_assignments: 0x7f4b51400400 - 0x7f4b51400430 (48)
    [0] device_assignments_prev: 0x7f4b51400600 - 0x7f4b51400630 (48)
    [0] device_ccounts: 0x7f4b51400800 - 0x7f4b51400828 (40)
    [0] device_assignments_yy: 0x7f4b51400a00 - 0x7f4b51400a28 (40)
    [0] device_bounds_yy: 0x7f4b51400c00 - 0x7f4b51400c60 (96)
    [0] device_drifts_yy: 0x7f4b51400e00 - 0x7f4b51400e50 (80)
    [0] device_passed_yy: 0x7f4b51401000 - 0x7f4b51401030 (48)
    reusing passed_yy for centroids_yy
    GPU #0 memory: used 739966976 bytes (8.8%), free 7626817536 bytes, total 8366784512 bytes
    GPU #0 has 49152 bytes of shared memory per block
    transposing the samples...
    transpose <<<(1, 1), (32, 8)>>> 12, 1
    performing kmeans++...
    kmeans++: dump 12 1 0x56252e0e4500
    kmeans++: dev #0: 0x7f4b51400000 0x7f4b51400200 0x7f4b51400400
    step 1
    [0] dev_dists: 0x7f4b51401200 - 0x7f4b51401240 (64)
    [my debug]: samples_size = 12, choice_approx = 6, choice_sum = 3.487407
    step 2
    [0] dev_dists: 0x7f4b51401200 - 0x7f4b51401240 (64)
    [my debug]: samples_size = 12, choice_approx = 5, choice_sum = 0.155162
    step 3
    [0] dev_dists: 0x7f4b51401200 - 0x7f4b51401240 (64)
    [my debug]: samples_size = 12, choice_approx = 10, choice_sum = 0.159036
    step 4
    [0] dev_dists: 0x7f4b51401200 - 0x7f4b51401240 (64)
    [my debug]: samples_size = 12, choice_approx = 5, choice_sum = 0.044974
    step 5
    [0] dev_dists: 0x7f4b51401200 - 0x7f4b51401240 (64)
    [my debug]: samples_size = 12, choice_approx = 2, choice_sum = 0.010118
    step 6
    [0] dev_dists: 0x7f4b51401200 - 0x7f4b51401240 (64)
    [my debug]: samples_size = 12, choice_approx = 11, choice_sum = 0.027562
    step 7
    [0] dev_dists: 0x7f4b51401200 - 0x7f4b51401240 (64)
    [my debug]: samples_size = 12, choice_approx = 4, choice_sum = 0.000000
    internal bug in kmeans_init_centroids: j = 0
    step 8
    [0] dev_dists: 0x7f4b51401200 - 0x7f4b51401240 (64)
    cudaMemcpyAsync( host_dists + offset, (dists)[devi].get(), length * sizeof(float), cudaMemcpyDeviceToHost)
    /home/hjw/work/platform/kmcuda/src/kmeans.cu:814 -> an illegal memory access was encountered
    kmeans_cuda_plus_plus failed
    kmeans_init_centroids failed: an illegal memory access was encountered
    Traceback (most recent call last):
      File "cudamemcopy_fail.py", line 20, in <module>
        means, y_pred = kmeans_cuda(data.reshape(-1,1), 10, tolerance=0.01,verbosity=5, seed=137)
    RuntimeError: cudaMemcpy failed

Trace

choice_sum = 0 leads to j = 0, which means the call

    RETERR(cuda_copy_sample_t(
        j - 1, i * features_size, samples_size, features_size, devs,
        verbosity, samples, centroids));

is made with j - 1 = 0 - 1 = -1 as the sample index, which causes the illegal memory access.

Fix

    // Only copy a sample when kmeans++ actually selected one (1 <= j <= samples_size).
    if (j > 0 && j <= samples_size) {
      RETERR(cuda_copy_sample_t(
          j - 1, i * features_size, samples_size, features_size, devs,
          verbosity, samples, centroids));
    }
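To see why a zero sum leads to j = 0, here is a hypothetical, CPU-only sketch of one kmeans++ selection step using inverse-CDF sampling: walk the cumulative sum of squared distances until it reaches r × total, where r is uniform in [0, 1). This is not kmcuda's code; the function names and the loop shape are assumptions chosen to mirror the `j - 1` indexing and the `j > 0 && j <= samples_size` guard discussed above.

```python
import random
from typing import Optional, Sequence


def pick_index_unguarded(dists: Sequence[float], r: float) -> int:
    """Hypothetical D^2 selection: dists are squared distances to the nearest
    centroid. Returns j; the caller copies sample j - 1."""
    target = r * sum(dists)
    acc, j = 0.0, 0
    while j < len(dists) and acc < target:
        acc += dists[j]
        j += 1
    return j  # stays 0 when target == 0, i.e. when every distance is 0


def pick_index_guarded(dists: Sequence[float], r: float) -> Optional[int]:
    """Same walk, but mirrors the `j > 0 && j <= samples_size` check:
    returns the selected sample index, or None if nothing can be selected."""
    j = pick_index_unguarded(dists, r)
    if 0 < j <= len(dists):
        return j - 1
    return None


zeros = [0.0] * 12  # every sample already coincides with a chosen centroid
print(pick_index_unguarded(zeros, random.random()) - 1)  # -1
print(pick_index_guarded(zeros, random.random()))        # None
print(pick_index_guarded([0.0, 1.0, 3.0], 0.5))          # 2
```

In Python, an index of -1 would silently refer to the last element; in the CUDA code the same -1 turned into an out-of-range copy, which is what the guard prevents.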

Test Result

OK! The 'RuntimeError: cudaMemcpy failed!' error no longer occurs!

I don't know if it's completely right, but I hope it's useful to the maintainer.