databrickslabs / tempo

API for manipulating time series on top of Apache Spark: lagged time values, rolling statistics (mean, avg, sum, count, etc.), AS OF joins, downsampling, and interpolation
https://pypi.org/project/dbl-tempo
Other
303 stars, 50 forks

Time-series split for cross validation not available in Spark ML library #409

Open · cblyton-byte opened this issue 1 week ago

cblyton-byte commented 1 week ago

scikit-learn includes a useful time-series splitter for cross-validation, `TimeSeriesSplit`. Importantly, it does not shuffle the data, and each test set follows its training set in time. See here: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.TimeSeriesSplit.html
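For reference, the way `TimeSeriesSplit` assigns indices with its default settings (no `gap`, no `max_train_size`, `test_size=None`) can be sketched in plain Python. This is a re-sketch of the behavior described in the scikit-learn documentation, not scikit-learn's own code, and the name `expanding_window_splits` is hypothetical. Each test fold has `n_samples // (n_splits + 1)` samples, and any leftover samples go to the first training set.

```python
def expanding_window_splits(n_samples, n_splits):
    """Yield (train_indices, test_indices) mimicking TimeSeriesSplit defaults.

    Hypothetical helper: no shuffling, every test fold comes right after its
    training fold, and the training window grows with each split.
    """
    test_size = n_samples // (n_splits + 1)
    if test_size == 0:
        raise ValueError("n_splits is too large for the number of samples")
    # Leftover samples (n_samples % (n_splits + 1)) extend the first training fold.
    first_test_start = n_samples - n_splits * test_size
    for test_start in range(first_test_start, n_samples, test_size):
        train = list(range(test_start))
        test = list(range(test_start, test_start + test_size))
        yield train, test


for train, test in expanding_window_splits(10, n_splits=3):
    print("TRAIN:", train, "TEST:", test)
# TRAIN: [0, 1, 2, 3] TEST: [4, 5]
# TRAIN: [0, 1, 2, 3, 4, 5] TEST: [6, 7]
# TRAIN: [0, 1, 2, 3, 4, 5, 6, 7] TEST: [8, 9]
```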

Spark ML has no equivalent. I have written an implementation (it may not be very efficient):

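```python
from pyspark.sql import DataFrame, Window
from pyspark.sql.functions import col, row_number


def time_series_split(output: DataFrame, n_splits: int, ts_col: str = "timestamp"):
    # creates a new column in Spark DF 'output' called 'row_num', numbered in time order.
    # ts_col is an added parameter (the column name "timestamp" is an assumption):
    # ordering by a constant literal does not guarantee that row numbers follow time.
    # A window without partitionBy moves all rows to one partition, like the original.
    w = Window.orderBy(col(ts_col))
    output = output.withColumn("row_num", row_number().over(w))

    # creates three lists of indices for 'row_num' (1-based, inclusive bounds)
    num_rows = output.count()
    num_rows_per_split = num_rows // (n_splits + 1)
    # leftover rows extend the first training set, as in scikit-learn's TimeSeriesSplit,
    # so test sets are equal-sized, do not overlap, and the last one ends at the last row
    offset = num_rows - n_splits * num_rows_per_split

    end_train_row_list = []
    start_test_row_list = []
    end_test_row_list = []
    for i in range(n_splits):
        end_train_row_list.append(offset + i * num_rows_per_split)
        start_test_row_list.append(offset + i * num_rows_per_split + 1)
        end_test_row_list.append(offset + (i + 1) * num_rows_per_split)

    # creates new Spark DFs using lists of indices for 'row_num'
    train_data = {}
    test_data = {}
    for i in range(n_splits):
        train_data[f"train_data_{i}"] = output.filter(col("row_num").between(1, end_train_row_list[i]))
        test_data[f"test_data_{i}"] = output.filter(col("row_num").between(start_test_row_list[i], end_test_row_list[i]))

    return train_data, test_data
```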
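The index arithmetic is the error-prone part of a row-based split, and it can be checked without a Spark cluster. Below is a plain-Python sketch that computes 1-based, inclusive `(start, end)` bounds suitable for `col("row_num").between(start, end)` filters. It follows scikit-learn's convention (equal-size, non-overlapping test folds; leftover rows go to the first training fold); the name `row_num_bounds` is hypothetical.

```python
def row_num_bounds(num_rows, n_splits):
    """Return [((train_start, train_end), (test_start, test_end)), ...].

    Hypothetical helper. Bounds are 1-based and inclusive, matching
    row_number() and Column.between().
    """
    test_size = num_rows // (n_splits + 1)
    if test_size == 0:
        raise ValueError("not enough rows for the requested number of splits")
    # Leftover rows (num_rows % (n_splits + 1)) extend the first training fold.
    first_test_start = num_rows - n_splits * test_size + 1
    bounds = []
    for i in range(n_splits):
        test_start = first_test_start + i * test_size
        test_end = test_start + test_size - 1
        bounds.append(((1, test_start - 1), (test_start, test_end)))
    return bounds


for train, test in row_num_bounds(11, n_splits=3):
    print("train rows", train, "test rows", test)
# train rows (1, 5) test rows (6, 7)
# train rows (1, 7) test rows (8, 9)
# train rows (1, 9) test rows (10, 11)
```

Each pair of bounds could then be applied on the numbered DataFrame with a single filter, e.g. `output.filter(col("row_num").between(*train))`.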

Such a feature in Spark ML would be very useful.

tnixon commented 1 week ago

It looks like your code creates splits based on row count (keeping roughly the same number of rows in each split). I assume that is the desired behavior for your use case.

Another option is to split by equal slices of time, regardless of how many records fall into each interval. I can see some use cases wanting to split that way, too.
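Splitting by equal slices of time can be sketched the same way: divide the span between the first and last timestamp into `n_splits + 1` equal intervals, train on everything before a boundary, and test on the next interval. The sketch below works on a plain list of `datetime` values; the function name, the sample timestamps, and the choice to let the final test interval run through the last timestamp are assumptions. In Spark, the same boundaries could be used in timestamp filters instead of row-number filters.

```python
from datetime import datetime, timedelta


def time_slice_splits(timestamps, n_splits):
    """Return [(train_indices, test_indices), ...], split by equal slices of time.

    Hypothetical helper. Boundaries are t_min + k * (t_max - t_min) / (n_splits + 1).
    Test intervals are half-open [lo, hi), except the last one, which runs
    through t_max so the newest record is never dropped.
    """
    t_min, t_max = min(timestamps), max(timestamps)
    step = (t_max - t_min) / (n_splits + 1)
    bounds = [t_min + k * step for k in range(n_splits + 2)]
    folds = []
    for k in range(1, n_splits + 1):
        lo, hi = bounds[k], bounds[k + 1]
        is_last = k == n_splits
        train = [i for i, t in enumerate(timestamps) if t < lo]
        test = [
            i for i, t in enumerate(timestamps)
            if lo <= t and (is_last or t < hi)
        ]
        folds.append((train, test))
    return folds


# Made-up, unevenly spaced sample: minutes after an arbitrary start time.
start = datetime(2024, 1, 1)
offsets_min = [0, 10, 20, 25, 28, 50, 75, 100, 110, 120]
ts = [start + timedelta(minutes=m) for m in offsets_min]

for train, test in time_slice_splits(ts, n_splits=3):
    print(len(train), "train rows,", len(test), "test rows")
# 5 train rows, 1 test rows
# 6 train rows, 1 test rows
# 7 train rows, 3 test rows
```

With unevenly spaced data, the folds cover equal durations but hold different numbers of rows, which is exactly the trade-off against a row-based split.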

cblyton-byte commented 1 week ago

In this application, each row is a 15-minute average, i.e. the timestamps are evenly spaced. Hence even time splits and even row splits would be equivalent here, and either would work for our application.
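The equivalence for evenly spaced data can be checked directly. The sketch below builds one hypothetical day of 15-minute timestamps (96 rows), computes the size of each training fold once by row count and once by equal time slices, and compares them. The date, the fold count, and the variable names are assumptions for illustration. Because each test fold runs from one training boundary to the next, matching training sizes imply matching test folds as well.

```python
from datetime import datetime, timedelta

N_SPLITS = 3
start = datetime(2024, 1, 1)  # hypothetical day of data
ts = [start + timedelta(minutes=15 * i) for i in range(96)]  # one row per 15 minutes

# Row-based: test folds of num_rows // (N_SPLITS + 1) rows; leftovers go to the first training fold.
test_size = len(ts) // (N_SPLITS + 1)
first_test = len(ts) - N_SPLITS * test_size
train_sizes_by_rows = [first_test + k * test_size for k in range(N_SPLITS)]

# Time-based: split [t_min, t_max] into N_SPLITS + 1 equal intervals and count rows before each boundary.
step = (ts[-1] - ts[0]) / (N_SPLITS + 1)
train_sizes_by_time = [
    sum(1 for t in ts if t < ts[0] + k * step) for k in range(1, N_SPLITS + 1)
]

print(train_sizes_by_rows)  # [24, 48, 72]
print(train_sizes_by_time)  # [24, 48, 72]
```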

In general (nice to have, but not required for our application), it would be good to support even time splits for unevenly spaced data.