Maghoumi / pytorch-softdtw-cuda

Fast CUDA implementation of (differentiable) soft dynamic time warping for PyTorch
MIT License

Batch of variable sequence length #11

Closed · daidedou closed this 3 years ago

daidedou commented 3 years ago

Hi, thanks for the implementation! I was wondering if there is a way to handle a batch of sequences with variable lengths. Suppose I have length1 and length2 variables that contain all the sequence lengths.

I guess for the forward function we just have to change R[:, -2, -2] to something like R[:, -length1-1, -length2-1], but I'm not sure what to do in the backward function.
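
For the forward pass, that change amounts to a per-sequence gather instead of a fixed corner read. A minimal sketch, assuming R is the padded (B, N+2, M+2) accumulated-cost matrix used in this repo, and len_x/len_y are hypothetical (B,) integer tensors holding each pair's true lengths:

```python
import torch

def gather_final_costs(R, len_x, len_y):
    """Gather each pair's terminal accumulated cost from R.

    R:            (B, N + 2, M + 2) accumulated-cost matrix, padded as in this repo
    len_x, len_y: (B,) integer tensors with the true length of each sequence pair
    """
    b = torch.arange(R.shape[0], device=R.device)
    # The DP cells occupy indices 1..N and 1..M, so sequence i's terminal cell
    # sits at R[i, len_x[i], len_y[i]] rather than the fixed corner R[i, -2, -2].
    return R[b, len_x, len_y]
```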

Do you know if there is any mathematical document with the detailed computation for the backward function?

Maghoumi commented 3 years ago

Hello,

Yes, I actually did an implementation with variable sequence lengths a while back. The way I handled it was to derive a boolean mask for the batch from the length of each sequence. In the forward() function, all irrelevant values were set to inf, while in the backward() function all irrelevant entries were simply discarded. The reason for using a mask (rather than calculating things the way you described) was to squeeze out some performance by avoiding unnecessary branching in the CUDA kernel and leveraging mass-indexing operations.
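
For illustration only, a minimal sketch of how such a mask could be built from per-sequence lengths (the function name, tensor names, and the masked_fill usage below are assumptions, not the original implementation):

```python
import torch

def make_length_mask(len_x, len_y, N, M):
    """Boolean (B, N, M) mask that is True exactly for the valid (i, j) cells."""
    row_ok = torch.arange(N, device=len_x.device)[None, :] < len_x[:, None]  # (B, N)
    col_ok = torch.arange(M, device=len_y.device)[None, :] < len_y[:, None]  # (B, M)
    return row_ok[:, :, None] & col_ok[:, None, :]

# forward:  D = D.masked_fill(~mask, float('inf'))  # padded cells can never win the soft-min
# backward: grad = grad.masked_fill(~mask, 0.0)     # discard gradients at padded cells
```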

Though the implementation seemed functional (I never tested it rigorously), it had significant performance issues, not only in the variable-length version but also in the current fixed-length one. I had a few ideas to explore, but unfortunately I got carried away with other things and never got back to it.

Overall, I personally cannot think of a CUDA-friendly approach that runs fast enough to justify adding support for variable-length batches.

daidedou commented 3 years ago

Hi, thank you for your reply! Too bad the implementation is not as easy as it looks at first glance! I guess I'll just do an ugly for loop or something like that when I need it (a rough sketch below). Also, it seems that nobody is really trying to handle variable lengths during training (in the context of dynamic alignment), so I'll do the same.
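
A minimal sketch of that per-item fallback, assuming the SoftDTW module from this repo (callable as sdtw(x, y), returning a per-batch loss) and hypothetical (B,) length tensors; it trades speed for correctness:

```python
import torch

def per_item_softdtw(sdtw, x, y, len_x, len_y):
    """Call soft-DTW on each pair separately, sliced to its true length.

    sdtw:         a SoftDTW module from this repo
    x, y:         (B, N, d) and (B, M, d) zero-padded batches
    len_x, len_y: (B,) integer tensors with the true lengths
    """
    losses = []
    for i in range(x.shape[0]):
        xi = x[i : i + 1, : int(len_x[i])]  # (1, len_x[i], d)
        yi = y[i : i + 1, : int(len_y[i])]  # (1, len_y[i], d)
        losses.append(sdtw(xi, yi))         # (1,) loss for this pair
    return torch.cat(losses).mean()
```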