a-antoniades / Neuroformer


Dataset Train/Test Split error #6

Closed PPWangyc closed 1 month ago

PPWangyc commented 1 month ago

Hi,

I encountered an issue in the following function:

import numpy as np

def split_data_by_interval(intervals, r_split=0.8, r_split_ft=0.1):
    # np.random.choice samples with replacement by default (replace=True)
    chosen_idx = np.random.choice(len(intervals), int(len(intervals) * r_split))
    train_intervals = intervals[chosen_idx]
    test_intervals = np.array([i for i in intervals if i not in train_intervals])
    finetune_intervals = np.array(train_intervals[:int(len(train_intervals) * r_split_ft)])
    return train_intervals, test_intervals, finetune_intervals

When I run the default command:

python neuroformer_train.py \
    --dataset lateral \
    --config configs/Visnav/lateral/mconf_pretrain.yaml

The function split_data_by_interval produces inconsistent results. Specifically:

The sum of len(train_intervals) and len(test_intervals) is greater than the total number of intervals, because chosen_idx contains repeated indices and train_intervals therefore holds duplicate entries. As a result, the number of unique entries in train_intervals is only 16569, so the training set is not actually 80% of the total as expected.

The discrepancy happens because np.random.choice samples with replacement by default, which inflates the counts.
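A quick standalone check illustrates this (the interval count below is arbitrary, not the actual dataset size):

    import numpy as np

    n_intervals = 1000  # illustrative only
    chosen_idx = np.random.choice(n_intervals, int(n_intervals * 0.8))  # replace=True by default
    print(len(chosen_idx))             # 800
    print(len(np.unique(chosen_idx)))  # < 800, because indices repeat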

Thanks!

a-antoniades commented 1 month ago

Thanks for bringing this to my attention. I added the replace=False flag and confirmed that this indeed solves the issue.
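
For reference, with the flag added the split looks roughly like this (the rest of the function body is unchanged):

    def split_data_by_interval(intervals, r_split=0.8, r_split_ft=0.1):
        # Sample indices without replacement so train_intervals has no duplicates
        chosen_idx = np.random.choice(len(intervals), int(len(intervals) * r_split), replace=False)
        train_intervals = intervals[chosen_idx]
        test_intervals = np.array([i for i in intervals if i not in train_intervals])
        finetune_intervals = np.array(train_intervals[:int(len(train_intervals) * r_split_ft)])
        return train_intervals, test_intervals, finetune_intervals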

Please note that the expected behavior here is the following:

sum_intervals = len(train_intervals) + len(test_intervals) # finetune_intervals is part of train_intervals
assert sum_intervals == len(intervals), f"Sum of intervals is not equal to the original intervals: {sum_intervals} != {len(intervals)}"