carstenf opened 6 months ago
I may have found a solution, but I haven't fully tested it yet:
```python
import numpy as np


class MultipleTimeSeriesCV:
    """Generates tuples of train_idx, test_idx pairs
    Assumes the MultiIndex contains levels 'symbol' and 'date'
    purges overlapping outcomes"""

    def __init__(self,
                 n_splits=3,
                 train_period_length=126,
                 test_period_length=21,
                 lookahead=None,
                 date_idx='date',
                 shuffle=False):
        self.n_splits = n_splits
        self.lookahead = lookahead
        self.test_length = test_period_length
        self.train_length = train_period_length
        self.shuffle = shuffle
        self.date_idx = date_idx

    def split(self, X, y=None, groups=None):
        unique_dates = X.index.get_level_values(self.date_idx).unique()
        days = sorted(unique_dates)  # Ascending order
        split_idx = []
        for i in range(self.n_splits):
            # Calculate split indices based on ascending order of days
            train_start_idx = i * self.test_length
            train_end_idx = train_start_idx + self.train_length
            # Leave `lookahead` days between train and test to purge overlapping outcomes
            test_start_idx = train_end_idx + (self.lookahead or 0)
            test_end_idx = test_start_idx + self.test_length
            # days[test_end_idx] is used as an exclusive bound, so it must be a valid index
            if test_end_idx >= len(days):
                break
            split_idx.append((train_start_idx, train_end_idx, test_start_idx, test_end_idx))

        dates = X.reset_index()[[self.date_idx]]
        for train_start, train_end, test_start, test_end in split_idx:
            # Select rows whose date falls in [start, end) of the ascending days
            train_idx = dates[(dates[self.date_idx] >= days[train_start]) &
                              (dates[self.date_idx] < days[train_end])].index.to_numpy()
            test_idx = dates[(dates[self.date_idx] >= days[test_start]) &
                             (dates[self.date_idx] < days[test_end])].index.to_numpy()
            if self.shuffle:
                # Convert to NumPy first: np.random.permutation returns an ndarray,
                # which has no .to_numpy() method
                train_idx = np.random.permutation(train_idx)
            yield train_idx, test_idx

    def get_n_splits(self, X=None, y=None, groups=None):
        return self.n_splits
```
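To see which date windows each fold covers without building a full panel, the index arithmetic in `split` can be reproduced on its own. This is a minimal sketch: `fold_boundaries` is a hypothetical helper (not part of any library), and the positions index into the ascending list of unique dates, with exclusive end points as in `split`.

```python
def fold_boundaries(n_days, n_splits=3, train_length=126, test_length=21, lookahead=None):
    """Hypothetical helper: (train_start, train_end, test_start, test_end) per fold.

    Mirrors the arithmetic of MultipleTimeSeriesCV.split: positions index into
    the ascending list of unique dates, and end positions are exclusive.
    """
    folds = []
    for i in range(n_splits):
        train_start = i * test_length
        train_end = train_start + train_length
        test_start = train_end + (lookahead or 0)  # purge gap of `lookahead` days
        test_end = test_start + test_length
        if test_end >= n_days:  # same stopping rule as split()
            break
        folds.append((train_start, train_end, test_start, test_end))
    return folds


for k, (tr0, tr1, te0, te1) in enumerate(fold_boundaries(n_days=252, lookahead=1)):
    print(f"fold {k}: train days [{tr0}, {tr1}), test days [{te0}, {te1})")
# fold 0: train days [0, 126), test days [127, 148)
# fold 1: train days [21, 147), test days [148, 169)
# fold 2: train days [42, 168), test days [169, 190)
```

With ascending dates, fold 0 covers the earliest window and each later fold moves forward by `test_period_length` days, so fold 1's training window overlaps fold 0's test window.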
The new result:
It looks like the splits are in the wrong order and should be reversed. The first split is computed first, so the second split should not use information from the first calculation. Here it looks the other way around; please see the plot.
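If the goal is simply to get the folds in the opposite order, one option is to leave `split` unchanged and reverse its output. This is a sketch; `reverse_folds` is a hypothetical helper, and it assumes any splitter object with a `split(X, y, groups)` generator such as the class above.

```python
def reverse_folds(cv, X, y=None, groups=None):
    """Hypothetical helper: yield cv.split()'s (train_idx, test_idx) pairs in reverse order.

    Only the order of the folds changes; each fold's train and test rows stay the same.
    """
    folds = list(cv.split(X, y, groups))  # split() is a generator, so materialize it first
    yield from reversed(folds)
```

Usage: `for train_idx, test_idx in reverse_folds(cv, X): ...`. Because this only reorders the folds, it does not change which dates any fold trains or tests on.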
I wrote a small piece of code for plotting:
`