getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

GpuConv1D.cu:613 : an illegal memory access was encountered #209

Open gjhuizing opened 2 years ago

gjhuizing commented 2 years ago

Hello,

I'm hoping to use KeOps to save some memory in my PyTorch model (thanks for all the work developing this tool!).

The line pykeops.test_torch_bindings() works fine.

I first create a cosine distance matrix as a LazyTensor:

# Normalizing rows
X = X.T
X /= torch.norm(X, dim=0)
X = X.T

# Initialize the LazyTensors.
x_i = LazyTensor(X[:, None, :]) # size m, 1, n
x_j = LazyTensor(X[None, :, :]) # size 1, m, n

# This is the cosine distance.
C_lazy = 1 - (x_i * x_j).sum(dim=2) # size m, m

# This is the cosine distance kernel.
K = (-C_lazy / eps).exp() # size m, m

Then, I want to compute the following loss:

loss = (A*torch.log(K@torch.exp(Y/eps))).sum()

where A and Y are torch Tensors and eps is a float. Y requires a gradient, as it is the variable being optimized.

On my CPU, this runs fine. On GPU, however, I get the following error: GpuConv1D.cu:613 : an illegal memory access was encountered

Actually, even K.sum() throws the same error.

Is something wrong with my installation on the GPU cluster, or am I doing something wrong?

Best,

GJ

jeanfeydy commented 2 years ago

Hi @gjhuizing ,

I hope that you are doing well! I was not able to reproduce your bug on Google Colab: could you provide us with a minimal script that reproduces this error, including the generation of the data arrays? This will allow us to understand whether the problem is configuration-specific or not. As far as I can tell, the main «suspect» causes for this type of error could be that:

  1. The arrays X and Y are not stored on the same GPU device.
  2. Some arrays are not contiguous.
  3. A strange compilation bug with a specific version of CUDA that we are not familiar with.

Furthermore, it's very likely that your current implementation will run into numerical stability issues: you should consider using a logsumexp() reduction, as detailed in the routine sinkhorn_loop_stable, page 30 of our NeurIPS 2021 paper.
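To illustrate the stability concern, here is a minimal sketch in plain Python (standard library only, not the sinkhorn_loop_stable routine from the paper). The function names and the sample exponents are hypothetical. The sketch shows why evaluating log(sum(exp(...))) directly can overflow when eps is small, and how shifting by the maximum avoids it. A logsumexp() reduction relies on the same idea.

```python
import math

def naive_log_sum_exp(values):
    # Direct formula: fails as soon as one exponent is too large for a float.
    return math.log(sum(math.exp(v) for v in values))

def stable_log_sum_exp(values):
    # Shift by the maximum so that the largest exponent is exactly 0.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

# Hypothetical exponents, e.g. terms of the form (Y_j - C_ij) / eps with a small eps.
exponents = [1000.0, 999.0, 998.0]

try:
    naive_log_sum_exp(exponents)
    naive_overflowed = False
except OverflowError:
    # math.exp(1000.0) does not fit in a double-precision float.
    naive_overflowed = True

stable = stable_log_sum_exp(exponents)
print(naive_overflowed, stable)
```

On moderate inputs both functions agree, so the shifted version can replace the naive one without changing the result.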

What do you think? Best regards, Jean

gjhuizing commented 2 years ago

Hi @jeanfeydy, thanks for the quick answer! Hope you're also doing well.

I'm working on a minimal example that fails; I'll post it here as soon as possible.

I'm working with a single GPU so I think point (1) should be fine. Point (2) might very well be the reason, fingers crossed as that would be an easier fix than (3).

Yes, that's a good point about numerical stability. Until now I had to work with double precision, but I might as well use KeOps to stabilize it too :)

Best

GJ

gjhuizing commented 2 years ago

Okay, here is a minimal gist for Colab that fails for me: https://gist.github.com/gjhuizing/298123f2811f562b5802d031b8a1eef1. It fails on both CUDA and CPU.

jeanfeydy commented 2 years ago

Hi @gjhuizing ,

Thanks for your report! I have made some tests, and here are the main conclusions:

  1. You have found a real bug, which is visibly triggered by the number of columns for the right hand side term G in your matmul. For whatever reason, on Colab, your code works when the number of columns is smaller than 146 but fails beyond 147. The bug also holds with our newest branch (python_engine), so we definitely have to fix it. I suspect that this is due to optimizations that @joanglaunes implemented about a year ago, and that we didn't test carefully enough. What do you think @joanglaunes ?

  2. Fortunately, until we fix the issue, there is a simple work-around for your problem: cut your array G vertically into chunks of at most 50 or 100 columns and apply the matrix-matrix product chunk-by-chunk using a python for loop. This should still be very efficient.

  3. Generally speaking, KeOps is not really optimized to be used with very high-dimensional variables (i.e. arrays A and G in your code that have "too many columns"): as detailed in this benchmark, there is a strong incentive to keep the dimension of your vector below 50. This is due to the size of CUDA registers, which tend to overflow when manipulating high-dimensional vectors. As a consequence, you should expect a significant speed-up if you use a dimension reduction technique (such as PCA) to a space of dimension 16 or 32 before using KeOps on your point clouds "x_i".

  4. By default, C_lazy.max() implements a maximum on the "symbolic dimension" -1 (= 2, in your example). To compute the maximum value of C_lazy over all indices i and j, please use C_lazy.max(1).max().

  5. Explicit LogSumExp reductions are much more stable and fast than float64 computations, so you should definitely try this :-)
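The chunk-by-chunk work-around from point 2 can be sketched as follows. This is a dense, plain-Python stand-in with hypothetical helper names (matmul, chunked_matmul) and toy data, not KeOps code. With PyTorch and KeOps, the same pattern would presumably slice G along its columns, apply K to each slice and concatenate the results.

```python
def matmul(K, G):
    # Dense reference product: out[i][c] = sum_j K[i][j] * G[j][c].
    return [[sum(K[i][j] * G[j][c] for j in range(len(G)))
             for c in range(len(G[0]))]
            for i in range(len(K))]

def chunked_matmul(K, G, chunk_size):
    # Multiply K with one block of columns of G at a time,
    # then append each partial result to the output rows.
    n_cols = len(G[0])
    out = [[] for _ in K]
    for start in range(0, n_cols, chunk_size):
        G_chunk = [row[start:start + chunk_size] for row in G]
        partial = matmul(K, G_chunk)
        for out_row, partial_row in zip(out, partial):
            out_row.extend(partial_row)
    return out

# Toy data: K is 3 x 2, G is 2 x 7, processed in chunks of at most 3 columns.
K = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
G = [[float(j * 7 + c) for c in range(7)] for j in range(2)]

full = matmul(K, G)
chunked = chunked_matmul(K, G, chunk_size=3)
print(chunked == full)
```

Because each output column depends on a single column of G, splitting G by columns changes only how the work is scheduled, not the result.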

Does this all make sense to you?

Best regards, Jean

gjhuizing commented 2 years ago

  1. Yay! Wasn't expecting to find a bug :smile:
  2. Hmm, actually the columns are my samples (weird convention, I know) and I'm expecting ~10,000 of them. Rows (features) should be ~1,000. Do you think a for loop would still be the way to go?
  3. Interesting! So I'd be better off using vanilla PyTorch? I could compute C on a PCA reduction, but G will still need to have a lot of columns.
  4. Okay, thank you!
  5. Yes, working on it right now! Do you think that this might also work around this bug?

jeanfeydy commented 2 years ago

Hi @gjhuizing ,

I see! So in other words, you are solving an OT problem in parallel ~10,000 times with ~1,000 different point locations? If this is the case, I now remember the e-mails that we exchanged back in June 2020 (12th-13th): the main stumbling block is that Elem does not currently support "variable" input. (This still hasn't changed since 2020, but should be very easy to fix...)

Are you often at the ENS these days? I am back in Paris now, and will probably come and visit Thibault next week: it may be good to chat about this in person :-)

Best regards, Jean

gjhuizing commented 2 years ago

Hi, exactly! That would be great, I'd love to have a chat about it :)

joanglaunes commented 2 years ago

Hello @gjhuizing,

In fact, for the example you provide in Colab, the problem comes from your operation C_lazy /= C_lazy.max(). This does not divide by the max of all elements at all, but rather divides each value by itself: by default the max is equivalent to max(..., dim=2), and since it is applied here to a scalar matrix, what you do is in fact C_lazy /= C_lazy, which will give NaN values on the diagonal.

The max operation in KeOps takes either dim=0, dim=1 or dim=2. The first two options return the max reduction of the tensor with respect to the i or j variables, and the dim=2 option is the lazy operation that you used. There is no way in KeOps to do a max over all elements of a lazy matrix as you intended.

Now, the fact that it crashes instead of returning NaN values for this example is definitely a bug, so we need to investigate. Can you tell us if, in your initial issue with the cosine distance matrix, you also used such a max operation?
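The difference between the three dim options can be illustrated with a small dense stand-in in plain Python. Nested lists play the role of a symbolic matrix whose entries are vectors of length 1, and the data are hypothetical. This only mimics the semantics described in this thread and is not KeOps code.

```python
# Dense stand-in for a symbolic matrix of shape (m, m, 1):
# entry C[i][j] is a vector of length 1, with zeros on the diagonal.
C = [[[0.0], [0.5], [0.2]],
     [[0.5], [0.0], [0.9]],
     [[0.2], [0.9], [0.0]]]

# "dim=2": maximum along the trailing axis. That axis holds a single value,
# so every entry is returned unchanged and C / max_dim2 amounts to C / C.
max_dim2 = [[[max(v)] for v in row] for row in C]

# "dim=1": a genuine reduction over the index j, one value per row i.
row_max = [max(v[0] for v in row) for row in C]

# An ordinary maximum over the reduced values gives the global maximum.
global_max = max(row_max)

print(max_dim2 == C, row_max, global_max)
```

The last two steps correspond to the two-stage pattern C_lazy.max(1).max() suggested earlier in the thread, where the second maximum acts on an ordinary tensor rather than on a lazy matrix.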

joanglaunes commented 2 years ago

I replied a bit too fast, sorry. What I said about the max operation is true, but in fact the bug is still present when you remove this operation from your script. So it definitely has to do with the special computation mode for high-dimensional data, as Jean suggested. We need to investigate. In the meantime, there is a (quite hidden...) way to disable this special mode. You need to replace the call to out = K@G with the following lines:

KG = K * LazyTensor(G[None,:,:])
del KG.rec_multVar_highdim # this is how the option is disabled
out = KG.sum(dim=1)

gjhuizing commented 2 years ago

Hi, thanks for your answer! I removed the division by the max following Jean's comments. I'll try disabling the option before reducing, thank you. I see that you're using *: how do matrix multiplications work with LazyTensors?

joanglaunes commented 2 years ago

Hello @gjhuizing, OK, yes indeed, I had not seen that Jean had already noticed the problem with the max operation, sorry! About your question: the @ operation in KeOps is just a shortcut for an element-wise multiplication followed by a sum reduction (lines 1 and 3 in my previous message). The element-wise multiplication is a symbolic operation: it does not compute anything yet, so KG is still a LazyTensor. The sum(dim=1) command then performs all the computations in one single call, i.e. it evaluates the sequence of operations registered in KG together with the sum reduction.
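As a sanity check of this identity, here is a small dense sketch in plain Python with hypothetical toy matrices. Unlike KeOps, it materializes the intermediate product explicitly, purely for illustration.

```python
K = [[1.0, 2.0], [3.0, 4.0]]              # K[i][j], shape (2, 2)
G = [[1.0, 0.0, 2.0], [0.0, 1.0, 3.0]]    # G[j][c], shape (2, 3)

# Step 1: broadcasted element-wise product, KG[i][j][c] = K[i][j] * G[j][c].
# KeOps only records this step symbolically; here it is stored in full.
KG = [[[K[i][j] * G[j][c] for c in range(3)] for j in range(2)]
      for i in range(2)]

# Step 2: sum reduction over the index j (the analogue of sum(dim=1)).
out = [[sum(KG[i][j][c] for j in range(2)) for c in range(3)]
       for i in range(2)]

# Reference: the usual matrix product, out[i][c] = sum_j K[i][j] * G[j][c].
reference = [[sum(K[i][j] * G[j][c] for j in range(2)) for c in range(3)]
             for i in range(2)]

print(out == reference)
```

The dense version needs storage proportional to m * m * (number of columns of G) for the intermediate array, which is exactly what the lazy evaluation avoids by fusing both steps into one reduction.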

gjhuizing commented 2 years ago

Oh yes I understand now! Thank you very much :) With all this information I'm sure I'll get something working! Best, GJ