Closed: PPWangyc closed this issue 1 month ago
Thanks for bringing this to my attention. I added the `replace=False` flag and confirmed that this indeed solves the issue.
Please note that the expected behavior here is the following:

```python
sum_intervals = len(train_intervals) + len(test_intervals)  # finetune_intervals is part of train_intervals
assert sum_intervals == len(intervals), f"Sum of intervals is not equal to the original intervals: {sum_intervals} != {len(intervals)}"
```
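The repository's actual `split_data_by_interval` is not shown in this thread, but the fix can be sketched minimally. The function signature, the `train_ratio` parameter, and the seeding are illustrative assumptions; the essential point is passing `replace=False` so `chosen_idx` contains no duplicate indices and the two splits partition the intervals exactly:

```python
import numpy as np

def split_data_by_interval(intervals, train_ratio=0.8, seed=0):
    """Split intervals into disjoint train/test sets (illustrative sketch).

    With replace=False, chosen_idx holds unique indices, so
    len(train_intervals) + len(test_intervals) == len(intervals).
    """
    rng = np.random.default_rng(seed)
    n = len(intervals)
    n_train = int(n * train_ratio)
    chosen_idx = rng.choice(n, size=n_train, replace=False)  # no repeats
    mask = np.zeros(n, dtype=bool)
    mask[chosen_idx] = True
    train_intervals = [intervals[i] for i in np.flatnonzero(mask)]
    test_intervals = [intervals[i] for i in np.flatnonzero(~mask)]
    return train_intervals, test_intervals
```

With this change the assertion above holds: every interval lands in exactly one of the two splits.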
Hi,

I encountered an issue in the function `split_data_by_interval`. When I run the default command, it produces inconsistent results. Specifically:

```
len(intervals): 30116
len(train_intervals): 24092
len(test_intervals): 13547
```

The sum of `train_intervals` and `test_intervals` is greater than the total number of intervals, meaning that `chosen_idx` is sampling repeated indices into `train_intervals`. As a result, the number of unique entries in `train_intervals` is only 16569, so the training set is not actually 80% of the total as expected. This discrepancy happens because `np.random.choice` samples with replacement by default, which leads to inflated counts.

Thanks!