vincefn / pyvkfft

Python interface to VkFFT
MIT License

Convolution support #33

Open vincefn opened 9 months ago

vincefn commented 9 months ago

Hi @DTolm, now that the release is out, I ran some tests with on-the-fly convolution, following Osamu's email exchange, which piqued my curiosity.

There is now a branch with convolution support

What I have seen (though I have only run a very limited number of tests so far):

The tests are all visible on the pyvkfft-convolve notebook.

DTolm commented 9 months ago

Hello

The convolution code design is 3 years old; at the time I mostly implemented what I needed for my Master's thesis: matrix-vector convolutions for multidimensional systems. So not everything has been fully implemented. R2C convolutions breaking in 1D is expected, as I didn't know back then how to combine R2C decomposition and convolution in one kernel. It will be easier to implement now with the new modular structure.

The 3D convolutions not working seems to be a bug that I think I have fixed on the dev branch (the modified test 51 passes now).

The numberBatches should work the same as coordinateFeatures now as well (unless you use the matrix-vector functionality).

vincefn commented 9 months ago

Thanks ! I am almost finished with the release (needed to update the conda packages), so I can look at this. The 3D convolutions work nicely, thanks !

I was wondering (since I do not yet completely understand how the coordinateFeatures work), if it is possible to perform the following: use an array of shape n_batch*ny*nx, and perform a 2D convolution with a single array of size ny*nx ? Can numberBatches and coordinateFeatures be configured to perform that ?

It can be very useful to compute the cross-correlation of N images vs a single reference, or for near-field propagation (this is practically a 2D convolution of stacks of array, with the same kernel).
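For reference, the operation being asked about can be written on the CPU with NumPy. This is only a minimal sketch of the expected result (a circular 2D convolution of every image in a stack with one kernel), not pyvkfft or VkFFT code; the function name `batched_circular_convolve2d` is made up for the illustration.

```python
import numpy as np

def batched_circular_convolve2d(stack, kernel):
    """Circularly convolve each (ny, nx) image of `stack` with one `kernel`."""
    # Transform only the last two axes; the leading axis is the batch.
    stack_f = np.fft.fft2(stack, axes=(-2, -1))
    kernel_f = np.fft.fft2(kernel, axes=(-2, -1))
    # Broadcasting applies the single (ny, nx) kernel spectrum to every image.
    return np.fft.ifft2(stack_f * kernel_f, axes=(-2, -1))

rng = np.random.default_rng(0)
n_batch, ny, nx = 4, 8, 8
stack = (rng.standard_normal((n_batch, ny, nx))
         + 1j * rng.standard_normal((n_batch, ny, nx)))

# A shifted delta kernel: the convolution reduces to a circular shift.
kernel = np.zeros((ny, nx), dtype=np.complex128)
kernel[1, 2] = 1.0

result = batched_circular_convolve2d(stack, kernel)
expected = np.roll(stack, shift=(1, 2), axis=(-2, -1))
print(np.allclose(result, expected))
```

For a cross-correlation against a reference image, the same structure applies with the reference spectrum conjugated before the multiplication.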

DTolm commented 8 months ago

coordinateFeatures should behave the same as numberBatches, unless you do matrix-vector multiplication convolutions (kernel is a matrix, system is a vector). It is a second form of batching (both work at the same time), since omitDimension is not working with convolutions right now.

As for multiple input - single kernel convolutions, I have only implemented reverse so far (one system, multiple kernels, multiple outputs). I can make it work for this case as well, as using the reverse option should have worse performance.

vincefn commented 8 months ago

> As for multiple input - single kernel convolutions, I have only implemented reverse so far (one system, multiple kernels, multiple outputs). I can make it work for this case as well, as using the reverse option should have worse performance.

Yes, this would be very interesting e.g. for multiple (batched) images alignment vs a single reference: 1 batch of 2D images and a single reference image

DTolm commented 8 months ago

@vincefn I have added an option to do this functionality and an example that shows how to set it up (53). It didn't require any big changes, so, hopefully, it will work straight away.

vincefn commented 8 months ago

Hi @DTolm, I've now updated the code to support various types of batch transforms, e.g. an array shape of (nbatch, ny, nx) with the same kernel shape, or a smaller kernel shape, e.g. (ny, nx) using singleKernelMultipleBatches, or even (nbatchk, ny, nx) using both singleKernelMultipleBatches and coordinateFeatures, as long as nbatch is a multiple of nbatchk. Nice !
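To make the three shape combinations concrete, here is a small shape check. It is a hypothetical helper written for this comment, not part of pyvkfft; it only classifies shapes and says nothing about how VkFFT pairs kernels with batches internally.

```python
def kernel_batch_mode(array_shape, kernel_shape):
    """Classify an (array, kernel) shape pair into one of the three cases."""
    array_shape, kernel_shape = tuple(array_shape), tuple(kernel_shape)
    *batch, ny, nx = array_shape
    if kernel_shape[-2:] != (ny, nx):
        raise ValueError("kernel and array must share the trailing (ny, nx)")
    if kernel_shape == array_shape:
        return "one kernel per batch item"
    if len(kernel_shape) == 2:
        return "single kernel for all batches"
    if (len(kernel_shape) == 3 and len(batch) == 1
            and batch[0] % kernel_shape[0] == 0):
        return "kernel stack, nbatch multiple of nbatchk"
    raise ValueError("unsupported shape combination")

print(kernel_batch_mode((6, 16, 16), (6, 16, 16)))
print(kernel_batch_mode((6, 16, 16), (16, 16)))
print(kernel_batch_mode((6, 16, 16), (3, 16, 16)))
```

A kernel stack of 4 with 6 batches, i.e. `kernel_batch_mode((6, 16, 16), (4, 16, 16))`, raises `ValueError` because 6 is not a multiple of 4.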

Now I have found issues with some odd transforms, e.g. (on my Mac) a 1D+convolution transform of shape 3*7 fails with a compilation error (same for many odd transforms with small primes). This happens for C2C as well as R2C, regardless of other parameters (in-place or out-of-place, batched or not, ...).

Also, in-place R2C transforms work for 2D and 3D, but out-of-place ones do not. I guess it may not be possible, since the complex array needs 1 or 2 extra elements along the fast axis, and I don't think there's an easy way around this. (Incidentally, this is related to the discussion in https://github.com/DTolm/VkFFT/issues/159)
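The size mismatch can be checked with NumPy. This is a sketch of the half-Hermitian layout only (using `np.fft.rfft2` as the reference), not of how VkFFT allocates its buffers; `r2c_shapes` is a name invented for the example.

```python
import numpy as np

def r2c_shapes(ny, nx):
    """Return (complex_shape, padded_real_shape) for a 2D R2C transform."""
    complex_shape = (ny, nx // 2 + 1)
    # Viewed as real numbers, the complex result needs 2 * (nx // 2 + 1) columns.
    padded_real_shape = (ny, 2 * (nx // 2 + 1))
    return complex_shape, padded_real_shape

for nx in (8, 9):
    a = np.zeros((4, nx), dtype=np.float32)
    complex_shape, padded = r2c_shapes(4, nx)
    # The half-Hermitian output of NumPy's real FFT has the same shape.
    assert np.fft.rfft2(a).shape == complex_shape
    print(nx, complex_shape, padded, "extra columns:", padded[1] - nx)
```

An even nx needs 2 extra real columns and an odd nx needs 1, so a plain (ny, nx) real array is never large enough to hold the complex result.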

vincefn commented 8 months ago

Hi @DTolm,

I've updated the code so I can use batch convolution also for cuda, and also the systematic command-line test can be used for convolution.

I've also clarified the systems which work (c2c, inplace r2c ndim>1, radix, single upload only).

Here's an example test run for sizes between 2 and 128, for C2C out-of-place, showing which radix sizes fail (here using CUDA on an A4500). The failures are always compilation errors, with some missing }. (Both odd and even sizes fail, contrary to my previous message.)

  pycuda C2C⨂           (2,2) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=8.6e-08 ninf=9.6e-08 < 2.3e-06 (0.042) 1 buf=    0   OK  
  pycuda C2C⨂           (3,3) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.7e-07 ninf=1.6e-07 < 2.5e-06 (0.066) 1 buf=    0   OK  
  pycuda C2C⨂           (4,4) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.6e-07 ninf=1.4e-07 < 2.6e-06 (0.055) 1 buf=    0   OK  
  pycuda C2C⨂           (5,5) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.1e-07 ninf=2.6e-07 < 2.7e-06 (0.095) 1 buf=    0   OK  
  pycuda C2C⨂           (6,6) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.7e-07 ninf=1.4e-07 < 2.8e-06 (0.049) 1 buf=    0   OK  
  pycuda C2C⨂           (7,7) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.9e-07 ninf=2.2e-07 < 2.8e-06 (0.078) 1 buf=    0   OK  
  pycuda C2C⨂           (8,8) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.6e-07 ninf=1.7e-07 < 2.9e-06 (0.060) 1 buf=    0   OK  
  pycuda C2C⨂           (9,9) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.9e-07 ninf=2.3e-07 < 3.0e-06 (0.079) 1 buf=    0   OK  
  pycuda C2C⨂         (10,10) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.3e-07 ninf=2.4e-07 < 3.0e-06 (0.080) 1 buf=    0   OK  
  pycuda C2C⨂         (11,11) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.2e-07 ninf=2.2e-07 < 3.0e-06 (0.071) 1 buf=    0   OK  
  pycuda C2C⨂         (12,12) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.1e-07 ninf=2.4e-07 < 3.1e-06 (0.079) 1 buf=    0   OK  
  pycuda C2C⨂         (13,13) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.0e-07 ninf=2.6e-07 < 3.1e-06 (0.084) 1 buf=    0   OK  
  pycuda C2C⨂         (14,14) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.8e-07 ninf=2.2e-07 < 3.1e-06 (0.071) 1 buf=    0   OK  
  pycuda C2C⨂         (15,15) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.3e-07 ninf=2.7e-07 < 3.2e-06 (0.085) 1 buf=    0   OK  
  pycuda C2C⨂         (16,16) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=2.6e-07 ninf=3.3e-07 < 3.2e-06 (0.102) 1 buf=    0   OK  
  pycuda C2C⨂         (18,18) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=4.9e-07 ninf=4.9e-07 < 3.3e-06 (0.150) 1 buf=    0   OK  
  pycuda C2C⨂         (20,20) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.9e-07 ninf=6.0e-07 < 3.3e-06 (0.182) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(21, 21), primes='3×7', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(22, 22), primes='2×11', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂         (24,24) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=3.4e-07 ninf=3.4e-07 < 3.4e-06 (0.102) 1 buf=    0   OK  
  pycuda C2C⨂         (25,25) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=7.3e-07 ninf=8.9e-07 < 3.4e-06 (0.263) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(26, 26), primes='2×13', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂         (27,27) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.4e-07 ninf=6.1e-07 < 3.4e-06 (0.178) 1 buf=    0   OK  
  pycuda C2C⨂         (28,28) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=3.0e-07 ninf=3.8e-07 < 3.4e-06 (0.111) 1 buf=    0   OK  
  pycuda C2C⨂         (30,30) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.5e-07 ninf=6.4e-07 < 3.5e-06 (0.185) 1 buf=    0   OK  
  pycuda C2C⨂         (32,32) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=3.4e-07 ninf=3.5e-07 < 3.5e-06 (0.100) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(33, 33), primes='3×11', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂         (35,35) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.2e-07 ninf=5.7e-07 < 3.5e-06 (0.162) 1 buf=    0   OK  
  pycuda C2C⨂         (36,36) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.9e-07 ninf=5.0e-07 < 3.6e-06 (0.142) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(39, 39), primes='3×13', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂         (40,40) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.1e-07 ninf=5.4e-07 < 3.6e-06 (0.149) 1 buf=    0   OK  
  pycuda C2C⨂         (42,42) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=4.6e-07 ninf=4.8e-07 < 3.6e-06 (0.133) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(44, 44), primes='2²×11', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂         (45,45) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.6e-07 ninf=5.3e-07 < 3.7e-06 (0.146) 1 buf=    0   OK  
  pycuda C2C⨂         (48,48) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=4.9e-07 ninf=4.9e-07 < 3.7e-06 (0.134) 1 buf=    0   OK  
  pycuda C2C⨂         (49,49) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.4e-07 ninf=6.1e-07 < 3.7e-06 (0.165) 1 buf=    0   OK  
  pycuda C2C⨂         (50,50) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=7.9e-07 ninf=7.9e-07 < 3.7e-06 (0.213) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(52, 52), primes='2²×13', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂         (54,54) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=8.1e-07 ninf=8.2e-07 < 3.7e-06 (0.221) 1 buf=    0   OK  
  pycuda C2C⨂         (55,55) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=7.2e-07 ninf=6.7e-07 < 3.7e-06 (0.180) 1 buf=    0   OK  
  pycuda C2C⨂         (56,56) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=4.8e-07 ninf=5.3e-07 < 3.7e-06 (0.142) 1 buf=    0   OK  
  pycuda C2C⨂         (60,60) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.0e-07 ninf=6.8e-07 < 3.8e-06 (0.180) 1 buf=    0   OK  
  pycuda C2C⨂         (63,63) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.6e-07 ninf=7.0e-07 < 3.8e-06 (0.184) 1 buf=    0   OK  
  pycuda C2C⨂         (64,64) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.0e-07 ninf=5.6e-07 < 3.8e-06 (0.147) 1 buf=    0   OK  
  pycuda C2C⨂         (65,65) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.5e-07 ninf=7.7e-07 < 3.8e-06 (0.202) 1 buf=    0   OK  
  pycuda C2C⨂         (66,66) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.1e-07 ninf=6.4e-07 < 3.8e-06 (0.167) 1 buf=    0   OK  
  pycuda C2C⨂         (70,70) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.8e-07 ninf=6.1e-07 < 3.8e-06 (0.159) 1 buf=    0   OK  
  pycuda C2C⨂         (72,72) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.0e-07 ninf=6.6e-07 < 3.9e-06 (0.170) 1 buf=    0   OK  
  pycuda C2C⨂         (75,75) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.8e-07 ninf=6.9e-07 < 3.9e-06 (0.178) 1 buf=    0   OK  
  pycuda C2C⨂         (77,77) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.6e-07 ninf=6.8e-07 < 3.9e-06 (0.174) 1 buf=    0   OK  
  pycuda C2C⨂         (78,78) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.9e-07 ninf=7.0e-07 < 3.9e-06 (0.180) 1 buf=    0   OK  
  pycuda C2C⨂         (80,80) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=7.9e-07 ninf=8.3e-07 < 3.9e-06 (0.212) 1 buf=    0   OK  
  pycuda C2C⨂         (81,81) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.7e-07 ninf=7.8e-07 < 3.9e-06 (0.198) 1 buf=    0   OK  
  pycuda C2C⨂         (84,84) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.4e-07 ninf=7.2e-07 < 3.9e-06 (0.184) 1 buf=    0   OK  
  pycuda C2C⨂         (88,88) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.6e-07 ninf=5.5e-07 < 3.9e-06 (0.139) 1 buf=    0   OK  
  pycuda C2C⨂         (90,90) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=9.0e-07 ninf=9.0e-07 < 4.0e-06 (0.229) 1 buf=    0   OK  
  pycuda C2C⨂         (91,91) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.1e-07 ninf=6.6e-07 < 4.0e-06 (0.167) 1 buf=    0   OK  
  pycuda C2C⨂         (96,96) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.9e-07 ninf=7.8e-07 < 4.0e-06 (0.195) 1 buf=    0   OK  
  pycuda C2C⨂         (98,98) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.7e-07 ninf=7.7e-07 < 4.0e-06 (0.193) 1 buf=    0   OK  
  pycuda C2C⨂         (99,99) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.9e-07 ninf=6.4e-07 < 4.0e-06 (0.161) 1 buf=    0   OK  
  pycuda C2C⨂       (100,100) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=8.1e-07 ninf=7.5e-07 < 4.0e-06 (0.186) 1 buf=    0   OK  
  pycuda C2C⨂       (104,104) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=5.7e-07 ninf=5.4e-07 < 4.0e-06 (0.134) 1 buf=    0   OK  
  test_systematic (pyvkfft.test.test_fft.TestFFTSystematic.test_systematic) (backend='pycuda', shape=(105, 105), primes='3×5×7', ndim=2, dtype=dtype('float32'), norm=1, use_lut=False, inplace=False, r2c=False, dct=False, dst=False, fstride=False, convolve=True) ... ERROR
  pycuda C2C⨂       (108,108) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=8.3e-07 ninf=8.2e-07 < 4.0e-06 (0.203) 1 buf=    0   OK  
  pycuda C2C⨂       (110,110) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=7.3e-07 ninf=7.8e-07 < 4.0e-06 (0.193) 1 buf=    0   OK  
  pycuda C2C⨂       (112,112) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.0e-07 ninf=5.7e-07 < 4.0e-06 (0.140) 1 buf=    0   OK  
  pycuda C2C⨂       (117,117) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.2e-07 ninf=6.6e-07 < 4.1e-06 (0.162) 1 buf=    0   OK  
  pycuda C2C⨂       (120,120) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.1e-07 ninf=6.4e-07 < 4.1e-06 (0.157) 1 buf=    0   OK  
  pycuda C2C⨂       (121,121) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.4e-07 ninf=6.3e-07 < 4.1e-06 (0.155) 1 buf=    0   OK  
  pycuda C2C⨂       (125,125) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=1.1e-06 ninf=1.1e-06 < 4.1e-06 (0.262) 1 buf=    0   OK  
  pycuda C2C⨂       (126,126) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=7.3e-07 ninf=7.7e-07 < 4.1e-06 (0.188) 1 buf=    0   OK  
  pycuda C2C⨂       (128,128) axes=        None ndim=   2    rr    11  complex64 lut=False inplace=0  norm=   1 C   FFT: n2=6.8e-07 ninf=8.5e-07 < 4.1e-06 (0.207) 1 buf=    0   OK  
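
To pull the failing sizes out of a log like the one above, a short parser is enough. This is a sketch that assumes the line format shown here (failing lines end in ERROR and carry `shape=(...)` and `primes='...'` fields); `failing_shapes` is a hypothetical helper, not part of the pyvkfft test suite, and the sample lines are abbreviated.

```python
import re

def failing_shapes(log_text):
    """Return (shape, primes) for every line of a test log that ends in ERROR."""
    pattern = re.compile(r"shape=\(([\d, ]+)\), primes='([^']*)'")
    failures = []
    for line in log_text.splitlines():
        if not line.rstrip().endswith("ERROR"):
            continue
        match = pattern.search(line)
        if match:
            shape = tuple(int(n) for n in match.group(1).split(","))
            failures.append((shape, match.group(2)))
    return failures

# Abbreviated sample lines in the same format as the log.
sample = (
    "  pycuda C2C⨂         (20,20) axes=        None ndim=   2   OK  \n"
    "  test_systematic (backend='pycuda', shape=(21, 21), primes='3×7', ndim=2) ... ERROR\n"
    "  test_systematic (backend='pycuda', shape=(22, 22), primes='2×11', ndim=2) ... ERROR\n"
)
for shape, primes in failing_shapes(sample):
    print(shape, primes)
```

On the full log, the failing shapes are 21, 22, 26, 33, 39, 44, 52 and 105 (primes 3×7, 2×11, 2×13, 3×11, 3×13, 2²×11, 2²×13 and 3×5×7).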

DTolm commented 7 months ago

Hello,

Sorry for the late reply; I am currently busy with another project. I will investigate the failing systems in the near future. Thank you for reporting them.

Best regards, Dmitrii