johertrich / sliced_kernel_fastsum

Fast Kernel Summation via Slicing and Fourier Transforms
https://arxiv.org/abs/2401.08260
MIT License

add keops to your work #1

Open LekangJia opened 1 week ago

LekangJia commented 1 week ago

Hello, this is excellent work! Have you tried adding pykeops to your work to lower its memory cost? Memory can sometimes be a problem. For example, if the output of a kernel-based method is used as a loss function in machine learning, the variables from all the different batches are stored for use in backpropagation. In that situation, batching does not reduce the memory cost. Combining KeOps with your work might be a good direction for lowering memory and time costs at the same time. I wonder whether you have tried this or plan to try it. Best wishes.

johertrich commented 1 week ago

Hi, thanks for your interest! I experimented a bit with using KeOps for the 1D computations, with mixed success (in the sense that I was not really able to beat the NFFT performance-wise). However, you can also use batching within the slicing when using it as a loss function, without storing all the 1D operations (and I would recommend doing so), by customizing the backward pass of the autograd function in torch. Unfortunately, this week is really busy for me, but I can write a few lines on how this works later this week or early next week. Best, Johannes
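A minimal sketch of one way to realize the batched-backward idea, assuming PyTorch and, for concreteness, sums of the form s_i = sum_j w_j ||x_i - y_j|| (the negative-distance kernel up to sign), whose sliced 1D counterpart is a scaled |s - t|. The names `SlicedDistanceSum`, `batch_sum`, and `slicing_constant` are hypothetical and not part of the repository's API, and the 1D sums are computed naively rather than with the fast 1D summation from the paper. The forward pass loops over batches of slicing directions without recording a graph; the backward pass recomputes each batch, takes its gradient, and frees that batch's graph before moving on, so peak memory is governed by one batch instead of all P directions.

```python
import math
import torch


def slicing_constant(d: int) -> float:
    # For xi uniform on the unit sphere S^{d-1}: E|<xi, z>| = ||z|| / c_d,
    # with c_d = sqrt(pi) * Gamma((d + 1) / 2) / Gamma(d / 2).
    return math.sqrt(math.pi) * math.exp(math.lgamma((d + 1) / 2) - math.lgamma(d / 2))


def batch_sum(x, y, w, xi_batch):
    # Project onto a batch of B directions: shapes (N, B) and (M, B).
    px, py = x @ xi_batch.T, y @ xi_batch.T
    # Naive 1D sums sum_j w_j |<xi, x_i> - <xi, y_j>|, summed over the B directions.
    # (Hypothetical stand-in: a fast 1D method would avoid this (N, M, B) tensor.)
    diff = (px[:, None, :] - py[None, :, :]).abs()
    return (diff * w[None, :, None]).sum(dim=(1, 2))


class SlicedDistanceSum(torch.autograd.Function):
    """s_i ~ sum_j w_j ||x_i - y_j||, estimated from P slices processed in batches."""

    @staticmethod
    def forward(ctx, x, y, w, xi, batch_size):
        # forward() runs with graph recording disabled, so no per-batch
        # intermediates are kept; only the inputs are saved.
        ctx.save_for_backward(x, y, w, xi)
        ctx.batch_size = batch_size
        out = x.new_zeros(x.shape[0])
        for start in range(0, xi.shape[0], batch_size):
            out += batch_sum(x, y, w, xi[start:start + batch_size])
        return out * slicing_constant(x.shape[1]) / xi.shape[0]

    @staticmethod
    def backward(ctx, grad_out):
        x, y, w, xi = ctx.saved_tensors
        scale = slicing_constant(x.shape[1]) / xi.shape[0]
        # Fresh leaf tensors; each batch builds a small graph that is freed
        # by autograd.grad before the next batch is processed.
        leaves = [t.detach().requires_grad_(True) for t in (x, y, w)]
        grads = [torch.zeros_like(t) for t in leaves]
        for start in range(0, xi.shape[0], ctx.batch_size):
            with torch.enable_grad():
                s = batch_sum(*leaves, xi[start:start + ctx.batch_size]) * scale
                parts = torch.autograd.grad(s, leaves, grad_out)
            for g, p in zip(grads, parts):
                g += p
        gx, gy, gw = (g if need else None for g, need in zip(grads, ctx.needs_input_grad[:3]))
        return gx, gy, gw, None, None


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(100, 10, requires_grad=True)
    y = torch.randn(200, 10)
    w = torch.full((200,), 1 / 200)
    xi = torch.nn.functional.normalize(torch.randn(1000, 10), dim=1)
    loss = SlicedDistanceSum.apply(x, y, w, xi, 100).mean()
    loss.backward()  # x.grad is accumulated batch by batch
```

The trade-off of this design is that every batch of directions is evaluated twice (once in the forward pass, once in the backward pass), exchanging compute for memory in the same spirit as `torch.utils.checkpoint`.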