Maghoumi / pytorch-softdtw-cuda

Fast CUDA implementation of (differentiable) soft dynamic time warping for PyTorch
MIT License

Sakoe-Chiba band #24

Open toinsson opened 2 years ago

toinsson commented 2 years ago

This PR implements a way of using Sakoe-Chiba bands with non-square matrices. It does so by computing an approximate diagonal: the indices of the smaller input matrix are interpolated onto the indices of the larger input matrix. This should address #8 and #18.

The case where N=4, M=7 and bandwidth=1 gives something similar to:

D = [ [x, x, -, -, -, -, -],
      [-, x, x, -, -, -, -],
      [-, -, -, x, x, -, -],
      [-, -, -, -, -, x, x] ]
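To illustrate the interpolation, here is a minimal sketch that reproduces the band above. It assumes a floored interpolated index, which is only one of several possible rounding conventions (the point discussed below); `naive_band_mask` is a hypothetical helper for illustration, not the PR's actual code.

```python
import numpy as np

def naive_band_mask(N, M, bandwidth):
    """Approximate Sakoe-Chiba band mask for a non-square N x M cost matrix.

    Each row index i of the smaller axis is mapped onto the larger axis via
    a floored interpolation, i_sc = floor(i * M / N); a cell (i, j) is kept
    when i_sc <= j <= i_sc + bandwidth. (Assumed convention, not the PR's
    exact code.)
    """
    mask = np.zeros((N, M), dtype=bool)
    for i in range(N):
        i_sc = (i * M) // N  # floored interpolated diagonal index
        mask[i, i_sc:i_sc + bandwidth + 1] = True
    return mask

# Reproduces the N=4, M=7, bandwidth=1 band drawn above.
print(naive_band_mask(4, 7, 1).astype(int))
```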

Note that:

This was tested like so:
In [1]: import torch

In [2]: import soft_dtw_cuda

In [3]: a = torch.rand((100, 40, 15))

In [4]: b = torch.rand((100, 50, 15))

In [5]: sdtw_cpu_sc3 = soft_dtw_cuda.SoftDTW(False, bandwidth=3)

In [6]: res_a = sdtw_cpu_sc3(a, b)

In [7]: torch.any(res_a == torch.inf)
Out[7]: tensor(False)

In [8]: sdtw_gpu_sc3 = soft_dtw_cuda.SoftDTW(True, bandwidth=3)

In [9]: res_b = sdtw_gpu_sc3(a.cuda(), b.cuda())

In [10]: torch.any(res_b == torch.inf)
Out[10]: tensor(False, device='cuda:0')

In [11]: torch.allclose(res_a, res_b.cpu())
Out[11]: True
Maghoumi commented 2 years ago

Thanks for this great contribution, really appreciate it! :)

I need some time to study and verify it. In the meantime, could you explain this a bit more?

the rounding in the interpolation of the indices (for instance, i_sc = i * N / M) could be formalised. I did not look too much into it.

toinsson commented 2 years ago

Regarding the rounding: there is a proper way of computing the interpolated indices, especially with regard to rounding the result of the division between integers. For example, should the result be rounded to the nearest integer, floored, or ceiled?
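To make the ambiguity concrete, here is a standalone sketch (my own, not from the PR) showing how the three conventions disagree when mapping the N=4 rows of the example above onto M=7 columns:

```python
import math

N, M = 4, 7  # sizes from the example in the PR description

exact  = [i * M / N for i in range(N)]   # [0.0, 1.75, 3.5, 5.25]
floors = [math.floor(x) for x in exact]  # [0, 1, 3, 5]
rounds = [round(x) for x in exact]       # [0, 2, 4, 5] -- note that Python's
                                         # round() rounds ties to even, so 3.5 -> 4
ceils  = [math.ceil(x) for x in exact]   # [0, 2, 4, 6]

print(floors, rounds, ceils)
```

Each convention yields a slightly different band, which is why the choice deserves to be formalised.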

pyts does seem to have done this with care, i.e. in the function pyts.metrics.sakoe_chiba_band, but that is a lot of code :), and I am not sure I want to spend too much time on these intricacies.

The PR as it is now does this in a naive way, but affords OK results for non-square matrices (instead of returning inf). Maybe this could or should be tested against pyts?
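For context on the inf issue, here is a small standalone check (my own sketch, not the PR's tests) of why a plain square-band criterion of the form abs(i - j) <= bandwidth leaves the last columns of a non-square matrix unreachable, while an interpolated band (again assuming a floored index, one possible convention) covers them:

```python
import numpy as np

N, M, bandwidth = 4, 7, 1

# Square-band criterion: |i - j| <= bandwidth, with no rescaling.
square = np.zeros((N, M), dtype=bool)
for i in range(N):
    for j in range(M):
        square[i, j] = abs(i - j) <= bandwidth

# Interpolated band: map row i onto the wider axis first.
interp = np.zeros((N, M), dtype=bool)
for i in range(N):
    i_sc = (i * M) // N  # floored interpolated diagonal index
    interp[i, i_sc:i_sc + bandwidth + 1] = True

# The square band never reaches columns 5 and 6, so the DP recursion can
# only produce inf there; the interpolated band does reach the last column.
print(square[:, -1].any(), interp[:, -1].any())
```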