wannesm / dtaidistance

Time series distances: Dynamic Time Warping (fast DTW implementation in C)

Change Request | Custom cost function #189

Status: Closed (KarazhovAndrii closed this issue 2 months ago)

KarazhovAndrii commented 1 year ago

Hello DTAI Team,

Thank you for sharing this great project. Please consider letting users define their own cost function in a future release, since they may need different metrics to evaluate the distance when computing DTW. Here is an example of the interface, with an added optional `dist_func` parameter:

```python
def distance(s1, s2, dist_func=None,
             window=None, max_dist=None, max_step=None,
             max_length_diff=None, penalty=None, psi=None,
             use_c=False, use_pruning=False, only_ub=False):
    ...

def warping_paths(s1, s2, dist_func=None, window=None, max_dist=None, use_pruning=False,
                  max_step=None, max_length_diff=None, penalty=None, psi=None, psi_neg=True,
                  use_c=False, use_ndim=False):
    ...
```

where `dist_func` is expected to have the signature:

```python
def dist_func(a, b):
    # Return the cost of aligning a single value a with a single value b
    return some_computation_over_a_b
```
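To make the proposal concrete, here is a minimal pure-Python sketch of a DTW that accepts a `dist_func(a, b)` callable. It is not dtaidistance's implementation: `dtw_distance` and `absolute_difference` are hypothetical names, the default cost is an arbitrary choice for the sketch, and options such as `window`, `max_dist`, `penalty`, and `psi` are ignored. The function returns the accumulated cost of the best warping path with no final transform applied.

```python
import math
from typing import Callable, Optional, Sequence


def absolute_difference(a: float, b: float) -> float:
    """Default cost used by this sketch (an assumption, not the library default)."""
    return abs(a - b)


def dtw_distance(s1: Sequence[float], s2: Sequence[float],
                 dist_func: Optional[Callable[[float, float], float]] = None) -> float:
    """Hypothetical DTW: accumulated cost of the cheapest warping path."""
    if dist_func is None:
        dist_func = absolute_difference
    n, m = len(s1), len(s2)
    # acc[i][j] = cheapest cost of aligning s1[:i] with s2[:j]
    acc = [[math.inf] * (m + 1) for _ in range(n + 1)]
    acc[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = dist_func(s1[i - 1], s2[j - 1])
            acc[i][j] = cost + min(acc[i - 1][j],      # advance in s1 only
                                   acc[i][j - 1],      # advance in s2 only
                                   acc[i - 1][j - 1])  # advance in both
    return acc[n][m]


s1 = [0.0, 1.0, 2.0, 1.0, 0.0]
s2 = [0.0, 0.0, 1.0, 2.0, 1.0, 0.0]
print(dtw_distance(s1, s2))                             # 0.0
print(dtw_distance(s1, s2, lambda a, b: (a - b) ** 2))  # 0.0
```

Both calls print 0.0 because `s2` is `s1` with its first value repeated, which warping absorbs at no cost under either metric.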
wannesm commented 1 year ago

To stay consistent, we would also need to support this in the C code. Although it is not difficult to add in one location, it requires changes in multiple places (e.g., also in the bounds) and increases the maintenance cost. So we will have to think about how to do this consistently throughout the codebase (we already do this to some extent to support ndim series).
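As an illustration of why a custom cost also affects the bounds: an LB_Keogh-style lower bound is only guaranteed to stay below DTW when it uses the same cost. The sketch below is not dtaidistance's bound code; `dtw_windowed` and `lb_keogh` here are hypothetical helpers that assume equal-length series, a band `|i - j| <= window`, and a cost that depends only on `|a - b|` and does not decrease as that difference grows. Under those assumptions every value of `s2` must be matched to some `s1` value inside its envelope, so the cost to the nearest envelope edge can never exceed the real path cost.

```python
import math
from typing import Callable, Sequence

Cost = Callable[[float, float], float]


def dtw_windowed(s1: Sequence[float], s2: Sequence[float], window: int, cost: Cost) -> float:
    """Accumulated DTW cost restricted to |i - j| <= window (equal-length series)."""
    n = len(s1)
    acc = [[math.inf] * (n + 1) for _ in range(n + 1)]
    acc[0][0] = 0.0
    for i in range(1, n + 1):
        # Only cells inside the band are filled; the rest stay infinite.
        for j in range(max(1, i - window), min(n, i + window) + 1):
            acc[i][j] = cost(s1[i - 1], s2[j - 1]) + min(
                acc[i - 1][j], acc[i][j - 1], acc[i - 1][j - 1])
    return acc[n][n]


def lb_keogh(s1: Sequence[float], s2: Sequence[float], window: int, cost: Cost) -> float:
    """LB_Keogh-style lower bound on dtw_windowed, computed with the same cost."""
    n = len(s1)
    total = 0.0
    for j, value in enumerate(s2):
        # Envelope of the s1 values that s2[j] is allowed to be aligned with.
        lo, hi = max(0, j - window), min(n, j + window + 1)
        lower, upper = min(s1[lo:hi]), max(s1[lo:hi])
        # Closest point of [lower, upper] to value (value itself if inside).
        nearest = min(max(value, lower), upper)
        total += cost(value, nearest)
    return total


a = [0.0, 1.0, 3.0, 1.0, 0.0, -1.0]
b = [0.0, 0.5, 1.0, 3.5, 0.5, -1.0]
for name, cost in [("absolute", lambda x, y: abs(x - y)),
                   ("squared", lambda x, y: (x - y) ** 2)]:
    bound, exact = lb_keogh(a, b, 1, cost), dtw_windowed(a, b, 1, cost)
    print(name, bound, exact, bound <= exact)
```

Swapping the cost in `dtw_windowed` without swapping it in `lb_keogh` would break this guarantee, which is why a pluggable cost has to reach every bound as well.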

If you need this urgently, it is not too difficult to change in the Python part (no compilation required). Simply change these two lines in the code:
https://github.com/wannesm/dtaidistance/blob/054e97ed96a1a14f2981d5e76e957148796567a7/dtaidistance/dtw.py#L289
https://github.com/wannesm/dtaidistance/blob/054e97ed96a1a14f2981d5e76e957148796567a7/dtaidistance/dtw.py#L409

wannesm commented 1 year ago

In the master branch, this functionality is now available through the `inner_dist` argument. With the pure Python implementation, this can be any callable (wrapped in an object that exposes `inner_dist` and `result` as callable attributes). With the fast DTW computation in C, it must be one of 'euclidean' or 'squared euclidean'.
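A minimal pure-Python sketch of the "object with `inner_dist` and `result`" idea: `inner_dist` scores one pair of values, and `result` turns the accumulated cost of the best path into the reported distance. The class names `SquaredEuclideanCost` and `AbsoluteCost` and the function `dtw_with_inner` are hypothetical; they only mirror the description above and are not dtaidistance's actual classes or signatures. Windows and other options are ignored.

```python
import math
from typing import Sequence


class SquaredEuclideanCost:
    """Sum squared differences along the path, report the square root."""

    @staticmethod
    def inner_dist(a: float, b: float) -> float:
        return (a - b) ** 2

    @staticmethod
    def result(total: float) -> float:
        return math.sqrt(total)


class AbsoluteCost:
    """Sum absolute differences along the path, report the sum unchanged."""

    @staticmethod
    def inner_dist(a: float, b: float) -> float:
        return abs(a - b)

    @staticmethod
    def result(total: float) -> float:
        return total


def dtw_with_inner(s1: Sequence[float], s2: Sequence[float],
                   inner=SquaredEuclideanCost) -> float:
    """Hypothetical DTW accepting any object with inner_dist() and result()."""
    n, m = len(s1), len(s2)
    acc = [[math.inf] * (m + 1) for _ in range(n + 1)]
    acc[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            acc[i][j] = inner.inner_dist(s1[i - 1], s2[j - 1]) + min(
                acc[i - 1][j], acc[i][j - 1], acc[i - 1][j - 1])
    # The final transform is applied once, to the optimal accumulated cost.
    return inner.result(acc[n][m])


a = [0.0, 3.0, 4.0]
b = [0.0, 0.0, 0.0]
print(dtw_with_inner(a, b))                # 5.0
print(dtw_with_inner(a, b, AbsoluteCost))  # 7.0
```

Applying `result` only at the end yields the optimal distance as long as `result` is non-decreasing, as both the square root and the identity are here.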

KarazhovAndrii commented 2 months ago

@wannesm, thank you for the updated API; it works. Closing the issue.