tslearn-team / tslearn

The machine learning toolkit for time series analysis in Python
https://tslearn.readthedocs.io
BSD 2-Clause "Simplified" License

KNeighborsTimeSeriesClassifier with DTW slowness #407

Open liorLimudim opened 2 years ago

liorLimudim commented 2 years ago

Greetings,

As part of testing different machine learning models, I am testing the KNN model together with DTW, and I run into severe slowness in the classifier's predict function. I would be very grateful if you could explain why it is so slow and how to calibrate the classifier so that it is faster. The code I run is attached below.

Thank you,
Lior

import numpy as np
from tslearn.generators import random_walk_blobs
from tslearn.neighbors import KNeighborsTimeSeriesClassifier, KNeighborsTimeSeries
import time

np.random.seed(0)
n_ts_per_blob, sz = 20000, 100  # d and n_blobs are left at their defaults
X, y = random_walk_blobs(n_ts_per_blob=n_ts_per_blob, sz=sz)
print(y)
print(X.shape)
print(len(y))

start_time = time.time()
knn = KNeighborsTimeSeriesClassifier(
    n_neighbors=1,
    metric="dtw",
    metric_params={"global_constraint": "sakoe_chiba", "sakoe_chiba_radius": 3},
    n_jobs=-1,
)
end_time = time.time()
print("Creating the Classifier took: {} seconds".format(end_time - start_time))

start_time = time.time()
knn.fit(X, y)
end_time = time.time()
print("FIT function took: {} seconds".format(end_time - start_time))

start_time = time.time()
y_predict = knn.predict(X)
end_time = time.time()
print("Predict function took: {} seconds".format(end_time - start_time))

GillesVandewiele commented 2 years ago

DTW is a rather slow function (O(m**2), with m the length of the time series), and KNN will internally keep track of pairwise distances between those 20K time series (although with optimized data structures such as BallTree and KDTree). That is why it is rather slow.
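To make the cost argument concrete, here is a minimal pure-Python sketch that counts how many dynamic-programming cells one DTW computation fills, with and without a Sakoe-Chiba band. It assumes the common formulation with a squared-difference local cost and a square root at the end; `dtw_banded` is a hypothetical helper for illustration, not tslearn's implementation. The last lines are a rough estimate for the script above, assuming `random_walk_blobs` keeps `n_blobs` at its default of 2 (so `X` holds 40,000 series) and that `predict(X)` compares every query with every training series.

```python
import math
import random


def dtw_banded(s1, s2, radius=None):
    """Return (DTW distance, number of DP cells filled) for two 1-D sequences.

    Local cost is the squared difference; the result is the square root of
    the accumulated cost. If `radius` is given, only cells with
    |i - j| <= radius (a Sakoe-Chiba band) are filled.
    """
    n, m = len(s1), len(s2)
    inf = float("inf")
    cost = [[inf] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0
    cells = 0
    for i in range(1, n + 1):
        lo, hi = (1, m) if radius is None else (max(1, i - radius), min(m, i + radius))
        for j in range(lo, hi + 1):
            d = (s1[i - 1] - s2[j - 1]) ** 2
            cost[i][j] = d + min(cost[i - 1][j], cost[i][j - 1], cost[i - 1][j - 1])
            cells += 1
    return math.sqrt(cost[n][m]), cells


random.seed(0)
a = [random.gauss(0.0, 1.0) for _ in range(100)]
b = [random.gauss(0.0, 1.0) for _ in range(100)]

full_dist, full_cells = dtw_banded(a, b)            # unconstrained: m * m cells
band_dist, band_cells = dtw_banded(a, b, radius=3)  # banded: roughly m * (2r + 1) cells
print(full_cells, band_cells)  # 10000 688

# Rough count for the script above (assumption: n_blobs=2, so 2 * 20000 series,
# and predict(X) evaluates DTW between every query and every training series).
n_series = 2 * 20000
n_pairs = n_series * n_series
print(n_pairs)               # 1600000000 DTW evaluations
print(n_pairs * band_cells)  # total DP cells with radius 3
```

The band already cuts the per-pair work from m**2 to about m * (2r + 1) cells, so with `sakoe_chiba_radius=3` most of the remaining cost comes from the sheer number of pairs rather than from the length of each series.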

liorLimudim commented 2 years ago

Hi Gilles,

Thanks for the answer. Is there a way to calibrate the classifier, for example with the dtw_path_limited_warping_length parameter?

Lior
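One way to speed up 1-NN with DTW without changing its predictions is lower-bound pruning with LB_Keogh: for each query an envelope is computed once, and any training series whose cheap lower bound already reaches the best DTW distance found so far is skipped. The sketch below is pure Python for readability and assumes equal-length univariate series and the same Sakoe-Chiba radius as in `metric_params`. The helpers `dtw_sc`, `envelope`, `lb_keogh`, `nn1_predict` and `walk` are hypothetical illustrations, not tslearn API, and not necessarily how `KNeighborsTimeSeriesClassifier` works internally.

```python
import math
import random


def dtw_sc(s1, s2, radius):
    """DTW restricted to a Sakoe-Chiba band |i - j| <= radius (sqrt of summed squared differences)."""
    n, m = len(s1), len(s2)
    inf = float("inf")
    prev = [0.0] + [inf] * m  # row i = 0 of the cost matrix
    for i in range(1, n + 1):
        cur = [inf] * (m + 1)
        for j in range(max(1, i - radius), min(m, i + radius) + 1):
            d = (s1[i - 1] - s2[j - 1]) ** 2
            cur[j] = d + min(prev[j], cur[j - 1], prev[j - 1])
        prev = cur
    return math.sqrt(prev[m])


def envelope(s, radius):
    """Running max/min of s over a window of +/- radius (the LB_Keogh envelope)."""
    n = len(s)
    windows = [s[max(0, i - radius):min(n, i + radius + 1)] for i in range(n)]
    return [max(w) for w in windows], [min(w) for w in windows]


def lb_keogh(candidate, upper, lower):
    """Lower bound on dtw_sc(query, candidate) from the query's envelope (equal lengths)."""
    total = 0.0
    for c, u, l in zip(candidate, upper, lower):
        if c > u:
            total += (c - u) ** 2
        elif c < l:
            total += (c - l) ** 2
    return math.sqrt(total)


def nn1_predict(query, train_X, train_y, radius):
    """1-NN under banded DTW; skips the full DTW when LB_Keogh cannot beat the best so far."""
    upper, lower = envelope(query, radius)
    best, best_label, n_dtw = float("inf"), None, 0
    for x, label in zip(train_X, train_y):
        if lb_keogh(x, upper, lower) >= best:
            continue  # pruned: this candidate cannot be closer than the current best
        n_dtw += 1
        d = dtw_sc(query, x, radius)
        if d < best:
            best, best_label = d, label
    return best_label, best, n_dtw


def walk(n, drift):
    """Toy random walk with drift (hypothetical data, loosely like random_walk_blobs)."""
    x, out = 0.0, []
    for _ in range(n):
        x += drift + random.gauss(0.0, 1.0)
        out.append(x)
    return out


random.seed(0)
train_X = [walk(100, 0.5) for _ in range(50)] + [walk(100, -0.5) for _ in range(50)]
train_y = [0] * 50 + [1] * 50
query = walk(100, 0.5)

label, dist, n_dtw = nn1_predict(query, train_X, train_y, radius=3)
print(label, n_dtw, "full DTW computations out of", len(train_X))
```

Pruning tends to pay off most when a small radius is used, because a narrow band gives a tight envelope. Reducing `sakoe_chiba_radius` or the series length lowers the per-pair cost instead, but unlike pruning it can change the predictions.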