JSchuurmans / tddl

Tensor Decomposition for Deep Learning
European Union Public License 1.1

class-balanced validation set #15

Closed: JSchuurmans closed this issue 2 years ago

JSchuurmans commented 2 years ago

Make the validation set balanced with respect to the classes.

JSchuurmans commented 2 years ago

PyTorch: split data while keeping equal proportions of each class

https://discuss.pytorch.org/t/how-to-split-test-and-train-data-keeping-equal-proportions-of-each-class/21063

from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import TensorDataset

# stratify=WholeTargetArray keeps the class proportions of the full set
# identical in the train and validation parts
train_D, valid_D, train_L, valid_L = train_test_split(
    WholeData.numpy(), WholeTargetArray,
    test_size=0.2, train_size=0.8, shuffle=True, stratify=WholeTargetArray)

DatasetTrain = Dataset(train_D, train_L)  # Dataset is a custom wrapper defined in the linked thread
DatasetValid = Dataset(valid_D, valid_L)

trainloader = torch.utils.data.DataLoader(DatasetTrain, batch_size=32, shuffle=True, drop_last=True, num_workers=0)
validationloader = torch.utils.data.DataLoader(DatasetValid, batch_size=6, drop_last=True, num_workers=0)

The same split without a custom Dataset class, wrapping the numpy arrays in a TensorDataset instead:

DatasetTrain = TensorDataset(torch.from_numpy(train_D), torch.from_numpy(train_L))
DatasetValid = TensorDataset(torch.from_numpy(valid_D), torch.from_numpy(valid_L))

trainloader = torch.utils.data.DataLoader(DatasetTrain, batch_size=32, shuffle=True, drop_last=True, num_workers=0)
validationloader = torch.utils.data.DataLoader(DatasetValid, batch_size=6, drop_last=True, num_workers=0)
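
A quick sanity check (not in the thread) could confirm the split really is class-balanced, assuming WholeTargetArray holds non-negative integer class labels:

import numpy as np

# with stratify=WholeTargetArray the class frequencies should be (nearly)
# identical in the full set, the train split and the validation split
print(np.bincount(WholeTargetArray) / len(WholeTargetArray))
print(np.bincount(train_L) / len(train_L))
print(np.bincount(valid_L) / len(valid_L))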


JSchuurmans commented 2 years ago

KFold validation with sklearn and PyTorch

https://www.machinecurve.com/index.php/2021/02/03/how-to-use-k-fold-cross-validation-with-pytorch/#summary-and-code-example-k-fold-cross-validation-with-pytorch

import os

import torch
from sklearn.model_selection import KFold
from torch.utils.data import ConcatDataset
from torchvision import transforms
from torchvision.datasets import MNIST

torch.manual_seed(42)

# merge the official train and test parts so every fold draws from all data
dataset_train_part = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor(), train=True)
dataset_test_part = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor(), train=False)
dataset = ConcatDataset([dataset_train_part, dataset_test_part])

kfold = KFold(n_splits=k_folds, shuffle=True)  # k_folds is set earlier in the tutorial

for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

    trainloader = torch.utils.data.DataLoader(
        dataset, batch_size=10, sampler=train_subsampler)
    testloader = torch.utils.data.DataLoader(
        dataset, batch_size=10, sampler=test_subsampler)
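
Plain KFold does not keep the class proportions equal within each fold. A small variation, not from the tutorial, would swap in sklearn's StratifiedKFold and feed it the labels collected from the concatenated dataset:

import numpy as np
from sklearn.model_selection import StratifiedKFold

# collect the labels once; each item of the ConcatDataset is an (image, label) pair
targets = [int(y) for _, y in dataset]

skfold = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)
for fold, (train_ids, test_ids) in enumerate(skfold.split(np.zeros(len(targets)), targets)):
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
    # build the loaders exactly as above
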
JSchuurmans commented 2 years ago

CIFAR dataloader based on indices

https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb

# the snippet below is taken from a larger helper in the gist; data_dir, valid_size,
# shuffle, random_seed, batch_size, num_workers, pin_memory and the two
# transforms are arguments of that helper
import numpy as np
import torch
from torchvision import datasets
from torch.utils.data.sampler import SubsetRandomSampler

# load the dataset twice so train and validation can use different transforms
train_dataset = datasets.CIFAR10(
    root=data_dir, train=True,
    download=True, transform=train_transform,
)

valid_dataset = datasets.CIFAR10(
    root=data_dir, train=True,
    download=True, transform=valid_transform,
)

num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))

if shuffle:
    np.random.seed(random_seed)
    np.random.shuffle(indices)

train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, sampler=train_sampler,
    num_workers=num_workers, pin_memory=pin_memory,
)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=batch_size, sampler=valid_sampler,
    num_workers=num_workers, pin_memory=pin_memory,
)
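
The gist draws the validation indices by random shuffling, so the validation set is only balanced in expectation. A small variation (not part of the gist) makes the index split stratified, using the labels that torchvision's CIFAR10 exposes as targets:

from sklearn.model_selection import train_test_split

# stratify on the CIFAR-10 labels so every class contributes the same
# fraction of samples to the validation indices
train_idx, valid_idx = train_test_split(
    indices, test_size=valid_size,
    stratify=train_dataset.targets, random_state=random_seed,
)
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
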
JSchuurmans commented 2 years ago

CIFAR dataloader based on indices

https://discuss.pytorch.org/t/train-on-a-fraction-of-the-data-set/16743/12

import torch
import torchvision
from torch.utils.data.sampler import SubsetRandomSampler

# transform is assumed to be defined earlier in the thread
cifar_dataset = torchvision.datasets.CIFAR10(root='./data', transform=transform)

# select train indices according to your rule
train_indices = ...

# select test indices according to your rule
test_indices = ...

# shuffle must stay off when a sampler is passed; the original snippet also
# set shuffle=True, which DataLoader rejects
train_loader = torch.utils.data.DataLoader(
    cifar_dataset, batch_size=32, sampler=SubsetRandomSampler(train_indices))
test_loader = torch.utils.data.DataLoader(
    cifar_dataset, batch_size=32, sampler=SubsetRandomSampler(test_indices))

JSchuurmans commented 2 years ago

A combination of the above solved the issue. (TODO check commit)
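
For reference, a minimal end-to-end sketch of how the pieces above could be combined into a class-balanced validation split for CIFAR-10; it is an illustration, not necessarily what the fixing commit does:

import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms

dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                           transform=transforms.ToTensor())

# stratified index split over the CIFAR-10 labels, then index-based samplers
indices = np.arange(len(dataset))
train_idx, valid_idx = train_test_split(
    indices, test_size=0.1, stratify=dataset.targets, random_state=42)

train_loader = DataLoader(dataset, batch_size=32,
                          sampler=SubsetRandomSampler(train_idx))
valid_loader = DataLoader(dataset, batch_size=32,
                          sampler=SubsetRandomSampler(valid_idx))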