JinheonBaek / GMT

Official Code Repository for the paper "Accurate Learning of Graph Representations with Graph Multiset Pooling" (ICLR 2021)
https://arxiv.org/abs/2102.11533

ask about validation dataset #18

Closed: betteryy closed this issue 2 years ago

betteryy commented 2 years ago

Hi, I'm very interested in this method.

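I have a question about the code in the trainers. Is the validation dataset the same as the test dataset? In the snippet below, `val_idxes` is loaded from a `test_idx-*.txt` file, just like `test_idxes`: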

    import numpy as np
    import torch

    # Excerpt from a trainer method: load the index files for one fold
    train_idxes = torch.as_tensor(np.loadtxt('./datasets/%s/10fold_idx/train_idx-%d.txt' % (self.args.data, fold_number),
                                            dtype=np.int32), dtype=torch.long)
    # Validation indices also come from a test_idx file, but for val_fold_number
    val_idxes = torch.as_tensor(np.loadtxt('./datasets/%s/10fold_idx/test_idx-%d.txt' % (self.args.data, val_fold_number),
                                            dtype=np.int32), dtype=torch.long)
    test_idxes = torch.as_tensor(np.loadtxt('./datasets/%s/10fold_idx/test_idx-%d.txt' % (self.args.data, fold_number),
                                            dtype=np.int32), dtype=torch.long)

I look forward to your reply.

JinheonBaek commented 2 years ago

Hi, thank you for your interest!

It may look like the validation data is the same as the test data, but the validation indices are read from the test file of a different fold (val_fold_number instead of fold_number). In other words, there is no overlap between the training, validation, and test indices in https://github.com/JinheonBaek/GMT/blob/main/trainers/trainer_classification_TU.py#L70.
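One quick way to confirm that the three index sets never overlap is to check their pairwise intersections. This is a minimal sketch, assuming the indices have already been converted to plain Python integers (for example with `train_idxes.tolist()`); the helper name `assert_disjoint` is hypothetical and not part of the GMT repository.

```python
from itertools import combinations


def assert_disjoint(**splits):
    """Raise ValueError if any two named index collections share an element.

    Hypothetical helper (not from the GMT repo): pass each split as a keyword
    argument, e.g. assert_disjoint(train=[...], val=[...], test=[...]).
    """
    sets = {name: set(indices) for name, indices in splits.items()}
    # Compare every pair of splits exactly once.
    for (name_a, a), (name_b, b) in combinations(sets.items(), 2):
        overlap = a & b
        if overlap:
            raise ValueError(f"{name_a} and {name_b} overlap on {sorted(overlap)[:5]}")


# Toy indices for illustration only (not taken from the real split files).
assert_disjoint(train=[0, 1, 2, 3], val=[4], test=[5])
```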


Each dataset is preprocessed for 10-fold cross-validation, i.e., it is divided into 10 folds. For example, when fold_number is 10, the training folds are [1, 2, 3, 4, 5, 6, 7, 8, 9] and the test fold is [10].

However, for a fair evaluation, we additionally use a validation set that does not overlap with the training and test sets. For example, if fold_number is 10, then val_fold_number is 8 (see https://github.com/JinheonBaek/GMT/blob/main/trainers/trainer_classification_TU.py#L169). Therefore, for fold_number 10, the training folds are [1, 2, 3, 4, 5, 6, 7, 8], the validation fold is [9] (val_fold_number 8), and the test fold is [10].
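The fold bookkeeping described above can be sketched in plain Python. This is an illustrative sketch, not the repository's code: the function `assign_folds` is hypothetical, folds are labelled 1 to 10 as in the explanation, and the validation fold is passed in directly rather than derived from `val_fold_number`, since that mapping is defined by the linked trainer code.

```python
def assign_folds(test_fold, val_fold, num_folds=10):
    """Return (train_folds, val_folds, test_folds) for one cross-validation run.

    Hypothetical helper: folds are labelled 1..num_folds; the test and
    validation folds are held out and every remaining fold is used for training.
    """
    if test_fold == val_fold:
        raise ValueError("validation and test folds must differ")
    all_folds = range(1, num_folds + 1)
    train_folds = [f for f in all_folds if f not in (test_fold, val_fold)]
    return train_folds, [val_fold], [test_fold]


# Example from the reply above: test fold 10, validation fold 9.
train, val, test = assign_folds(test_fold=10, val_fold=9)
print(train, val, test)  # [1, 2, 3, 4, 5, 6, 7, 8] [9] [10]
```

Holding out the validation fold from the training folds, rather than reusing the test fold, is what keeps model selection from seeing the test data.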

If you have any further questions, please ask! :)

betteryy commented 2 years ago

Thanks for the reply. :) I didn't see that part. Sorry.

So the whole dataset is split in an 8 (train) : 1 (test) : 1 (validation) ratio, and this is repeated 10 times for 10-fold cross-validation, right?
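The 8:1:1 ratio, and the fact that rotating through all 10 runs tests every sample exactly once, can be checked with a small simulation. This is a minimal sketch under stated assumptions: a toy dataset of 1,000 sample indices split round-robin into 10 equal folds, and a hypothetical rotation rule (the fold just before the test fold, wrapping around) chosen only for illustration. The actual split files under `./datasets/%s/10fold_idx/` are precomputed, and when a dataset's size is not divisible by 10 the ratio is only approximately 8:1:1.

```python
from collections import Counter

NUM_SAMPLES = 1000  # toy dataset size, for illustration only
NUM_FOLDS = 10

# Round-robin assignment: sample i goes to fold (i % NUM_FOLDS) + 1.
folds = {f: [i for i in range(NUM_SAMPLES) if i % NUM_FOLDS + 1 == f]
         for f in range(1, NUM_FOLDS + 1)}

test_counts = Counter()
val_counts = Counter()
for test_fold in range(1, NUM_FOLDS + 1):
    # Hypothetical rule: the previous fold (wrapping 1 -> 10) is the validation fold.
    val_fold = test_fold - 1 if test_fold > 1 else NUM_FOLDS
    train = [i for f, idx in folds.items() if f not in (test_fold, val_fold) for i in idx]
    val, test = folds[val_fold], folds[test_fold]
    assert (len(train), len(val), len(test)) == (800, 100, 100)  # 8:1:1
    test_counts.update(test)
    val_counts.update(val)

# Across the 10 runs, every sample is tested exactly once and validated exactly once.
assert len(test_counts) == NUM_SAMPLES and set(test_counts.values()) == {1}
assert len(val_counts) == NUM_SAMPLES and set(val_counts.values()) == {1}
```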

JinheonBaek commented 2 years ago

Yes, exactly!

betteryy commented 2 years ago

Thank you for your kind reply :)