AISViz / Soft-DTW

Soft-DTW loss function for Keras/TensorFlow
GNU Affero General Public License v3.0

Gradient Calculation #6

Open stellarpower opened 2 months ago

stellarpower commented 2 months ago

Hey,

I'm trying to implement the backward pass explicitly, using the equation from the paper (and other repos' implementations), in an effort to improve the speed.

If I understand correctly, since the master branch here doesn't include a custom gradient, TensorFlow will be using its automatic differentiation to compute the gradients.

However, the algorithm for the forward pass is obviously quite complicated: there are loops, and the softmin is implemented in Cython, which wouldn't be automatically differentiable (although maybe that has no effect on the gradient). I'm therefore wondering: do we know if the gradients TensorFlow computes automatically are correct? Have they been verified so far and checked to be numerically close to those computed with the explicit expression?

Or am I missing something, and it's calculated a different way?
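For context, the kind of check I have in mind is a simple finite-difference comparison against whatever gradient TF's autodiff produces. This is only a minimal sketch: the loss below is a stand-in so the snippet runs on its own, and the actual soft-DTW loss from this repo (or my branch) would be dropped in its place.

```python
import numpy as np
import tensorflow as tf

# Stand-in loss so the snippet is self-contained; swap in the repo's
# soft-DTW loss here instead.
def loss_fn(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred))

rng = np.random.default_rng(0)
y_true = tf.constant(rng.normal(size=(1, 16, 2)))   # float64
y_pred = tf.Variable(rng.normal(size=(1, 16, 2)))   # float64

# Gradient from TensorFlow's automatic differentiation.
with tf.GradientTape() as tape:
    loss = loss_fn(y_true, y_pred)
auto_grad = tape.gradient(loss, y_pred).numpy()

# Central finite differences, one element at a time.
eps = 1e-6
base = y_pred.numpy()
num_grad = np.zeros_like(base)
for i in range(base.size):
    bump = np.zeros(base.size)
    bump[i] = eps
    bump = bump.reshape(base.shape)
    up = loss_fn(y_true, tf.constant(base + bump)).numpy()
    down = loss_fn(y_true, tf.constant(base - bump)).numpy()
    num_grad.ravel()[i] = (up - down) / (2 * eps)

print("max abs difference:", np.abs(auto_grad - num_grad).max())
```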

Thanks

stellarpower commented 2 months ago

So, I have been implementing the backward-pass calculations in a branch here, and while testing that, I've noticed a discrepancy in the gradients between the implementation here and at least one of the Torch implementations (in this case, this version, which uses a numba CUDA kernel).

The losses I get for each sequence in a batch are identical (first row), but the gradients are all zeros bar the last (second row).

[Screenshots comparing results: columns are my branch, master, and the Torch implementation; the first row shows the per-sequence losses, the second row the gradients.]

Given that there are so many zeros, and that (amazingly) my implementation produced the same numbers, I suspect the expression the auto-differentiator comes up with is not valid, which I think would then invalidate use of the loss function in general, unless something else is going on here, e.g. loss reductions(?)
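For anyone following along, the backward pass I'm implementing is the E-matrix recursion from the paper. Here's a plain-NumPy sketch of both passes, written for a precomputed cost matrix; it's not the code from my branch or from master, just the recursions I'm checking against.

```python
import numpy as np

def soft_dtw_grad(D, gamma=1.0):
    """Plain-NumPy soft-DTW forward + backward, following the paper's recursions.

    D: (m, n) pairwise cost matrix between the two sequences.
    Returns (loss, E) where E = dLoss/dD is the soft alignment matrix.
    """
    m, n = D.shape

    # Forward: R[i, j] = D[i-1, j-1] + softmin_gamma(R[i-1, j-1], R[i-1, j], R[i, j-1]).
    R = np.full((m + 2, n + 2), np.inf)
    R[0, 0] = 0.0
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            r = np.array([R[i - 1, j - 1], R[i - 1, j], R[i, j - 1]])
            rmin = r.min()
            softmin = rmin - gamma * np.log(np.exp(-(r - rmin) / gamma).sum())
            R[i, j] = D[i - 1, j - 1] + softmin
    loss = R[m, n]

    # Backward: the E-matrix recursion, filled in from the bottom-right corner.
    D_ = np.zeros((m + 2, n + 2))
    D_[1:m + 1, 1:n + 1] = D
    E = np.zeros((m + 2, n + 2))
    E[m + 1, n + 1] = 1.0
    R[:, n + 1] = -np.inf
    R[m + 1, :] = -np.inf
    R[m + 1, n + 1] = R[m, n]
    for j in range(n, 0, -1):
        for i in range(m, 0, -1):
            a = np.exp((R[i + 1, j] - R[i, j] - D_[i + 1, j]) / gamma)
            b = np.exp((R[i, j + 1] - R[i, j] - D_[i, j + 1]) / gamma)
            c = np.exp((R[i + 1, j + 1] - R[i, j] - D_[i + 1, j + 1]) / gamma)
            E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c

    return loss, E[1:m + 1, 1:n + 1]
```

The gradient with respect to the inputs then comes from chaining E through the Jacobian of the cost matrix; for the squared Euclidean cost, dL/dx_i = sum_j E[i, j] * 2 * (x_i - y_j). That's the explicit expression I'd like to compare the autodiff gradients against.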

gabrielspadon commented 2 months ago

Hi @stellarpower, thanks for the feedback. The implementation follows https://arxiv.org/abs/1703.01541, https://rtavenar.github.io/ml4ts_ensai/contents/align/softdtw.html, and https://github.com/mblondel/soft-dtw. We will review the results next week to check whether anything slipped through.

stellarpower commented 2 months ago

Okay, great, thanks. My testing was not exactly rigorous at this stage, but here's a slightly messy test setup; hopefully it's reasonable to follow, but let me know if anything needs explanation. The testing file for Soft-DTW is in my branch.

Currently I'm trying to see whether the performance can be improved (conversation #4), as TF does not seem to be parallelising it that well. I'm still debating the idea of just writing the kernel in SYCL and wrapping it as a custom TF op.
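For what it's worth, once an explicit backward pass exists, it could be wired in with tf.custom_gradient so TF never has to trace the DP loops. The sketch below wraps a NumPy routine like the soft_dtw_grad one above via tf.numpy_function; the names are mine, and since numpy_function runs eagerly on the CPU, this is really just a placeholder for where a proper custom op / SYCL kernel would slot in.

```python
import tensorflow as tf

GAMMA = 1.0  # smoothing parameter; a fixed constant for this sketch

@tf.custom_gradient
def soft_dtw_from_cost(D):
    # Run the explicit forward + backward pass outside the graph; E is the
    # gradient of the loss w.r.t. the cost matrix D and is captured for later.
    loss, E = tf.numpy_function(
        lambda d: soft_dtw_grad(d, GAMMA), [D], Tout=[tf.float64, tf.float64]
    )

    def grad(upstream):
        # dLoss/dD is the soft alignment matrix E, scaled by the incoming gradient.
        return upstream * E

    return loss, grad
```

Computing the cost matrix D from the predictions with ordinary TF ops would keep that part automatically differentiable, so only the DP recursion needs the hand-written gradient.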