zzw922cn / TF2_soft_dtw

Custom TensorFlow2 implementations of forward and backward computation of soft-DTW algorithm in batch mode.

How to use it as a custom loss function in tf? #3

Open wangshuo1994 opened 2 years ago

wangshuo1994 commented 2 years ago

Thank you very much for sharing this! Could you help me figure out how to use soft-DTW as a custom loss in a deep learning training framework (in TF)? What I hope to do is replace the RMSE/MAE metrics with the DTW discrepancy. Thank you in advance.
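For anyone landing here, it may help to see what the loss actually computes before wiring it into TF. Below is a minimal NumPy sketch of the soft-DTW recursion for a single pair of sequences (this is the standard algorithm, not this repo's batched implementation): the DP cell is `R[i,j] = D[i-1,j-1] + softmin(R[i,j-1], R[i-1,j], R[i-1,j-1])`, where `softmin` is the gamma-smoothed minimum.

```python
import numpy as np

def softmin(a, b, c, gamma):
    # Smoothed minimum: -gamma * log(sum(exp(-v / gamma))), computed stably
    # via the max trick. As gamma -> 0 this approaches min(a, b, c).
    vals = np.array([a, b, c]) / -gamma
    m = vals.max()
    return -gamma * (m + np.log(np.exp(vals - m).sum()))

def soft_dtw(x, y, gamma=1.0):
    # x: (n, d), y: (m, d). Returns the soft-DTW discrepancy (can be
    # slightly negative because it smooths over all alignment paths).
    n, m = len(x), len(y)
    # Pairwise squared-Euclidean cost matrix.
    D = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    R = np.full((n + 1, m + 1), np.inf)
    R[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            R[i, j] = D[i - 1, j - 1] + softmin(
                R[i, j - 1], R[i - 1, j], R[i - 1, j - 1], gamma)
    return R[n, m]
```

With a small `gamma` the value approaches classic DTW, so identical sequences score near zero while shifted ones score higher.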

jackz314 commented 1 year ago

Did you ever figure this out? I believe the goal is to wrap/transform the function into a subclass of Loss, but I'm not familiar enough with TensorFlow or the implementation here to do this successfully yet.
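The subclass-of-`Loss` route can work roughly as sketched below. This is not this repo's API: `soft_dtw_tf` here is a slow reference version written with plain TF ops, so autodiff supplies the backward pass on its own, and `SoftDTWLoss` is a hypothetical wrapper that averages the per-example discrepancy over the batch.

```python
import tensorflow as tf

BIG = 1e10  # stands in for +inf on the DP border; keeps gradients NaN-free

def soft_dtw_tf(x, y, gamma=1.0):
    """Soft-DTW between two (time, features) tensors using plain TF ops.
    Slow (Python loops) but fully differentiable via autodiff."""
    n, m = x.shape[0], y.shape[0]
    # Pairwise squared-Euclidean cost matrix.
    D = tf.reduce_sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    R = [[tf.constant(BIG, dtype=D.dtype)] * (m + 1) for _ in range(n + 1)]
    R[0][0] = tf.constant(0.0, dtype=D.dtype)
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            # softmin(a, b, c) = -gamma * logsumexp(-a/g, -b/g, -c/g)
            stacked = tf.stack([R[i - 1][j], R[i][j - 1], R[i - 1][j - 1]])
            softmin = -gamma * tf.reduce_logsumexp(-stacked / gamma)
            R[i][j] = D[i - 1, j - 1] + softmin
    return R[n][m]

class SoftDTWLoss(tf.keras.losses.Loss):
    """Hypothetical Keras wrapper: mean soft-DTW over the batch."""

    def __init__(self, gamma=1.0, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma

    def call(self, y_true, y_pred):
        # y_true, y_pred: (batch, time, features)
        losses = [soft_dtw_tf(y_true[b], y_pred[b], self.gamma)
                  for b in range(y_pred.shape[0])]
        return tf.reduce_mean(tf.stack(losses))
```

You would pass `loss=SoftDTWLoss(gamma=0.1)` to `model.compile(...)` as usual. For real training you would want the repo's batched O(nm) forward/backward kernels instead of these Python loops.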

Christoper-Harvey commented 1 year ago

@wangshuo1994 @jackz314 did either of you ever figure this out? I have been trying to implement this with TF2's custom gradient functions but haven't been able to get the backward pass to update anything. It is spitting out a value but the gradient is missing.
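A common cause of "value comes out but the gradient is missing" is that part of the forward pass leaves TensorFlow (e.g. goes through NumPy), which disconnects the tape; `tf.custom_gradient` only helps if the decorated function returns both the value and a gradient closure. Here is a minimal, self-contained demo of that pattern on a toy function (not soft-DTW itself), just to show the shape the decorator expects:

```python
import numpy as np
import tensorflow as tf

@tf.custom_gradient
def np_square_sum(x):
    # Forward pass deliberately drops into NumPy, which would normally
    # break the tape; custom_gradient reattaches it via grad() below.
    value = tf.constant(np.sum(np.asarray(x) ** 2), dtype=x.dtype)

    def grad(upstream):
        # Must return one gradient per input of the decorated function,
        # chained with the upstream gradient.
        return upstream * 2.0 * x

    return value, grad
```

For soft-DTW the same shape applies: the decorated function runs the forward DP, and `grad` runs the backward DP over the alignment matrix and scales by `upstream`. If `tape.gradient` still returns `None`, check that the inputs are watched (or are `tf.Variable`s) and that no intermediate value is cast through `.numpy()` outside the decorated function.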

JayKumarr commented 8 months ago

https://github.com/AISViz/Soft-DTW