Open meric-sakarya opened 3 years ago
If you import train_test_split from sklearn.model_selection, you can do something like this:
# CHANGED JPG TO PNG
images = [im for im in os.listdir(args.folder) if im.endswith('.png')]
images = np.array(images)
n_samples = len(images)
if args.n_samples > 0:
    n_samples = min(n_samples, args.n_samples)
# indices for all time steps where the episode continues
indices = np.arange(n_samples, dtype='int64')
np.random.shuffle(indices)
# NEW SECTION THAT SPLITS INDICES INTO A TRAIN AND VAL SET FIRST BEFORE BATCHING
# train_test_split accepts NumPy arrays directly, so no DataFrame is needed
train, val = train_test_split(indices, train_size=0.8)
print("{} images in total".format(n_samples))
print("{} images in training set".format(len(train)))
print("{} images in validation set".format(len(val)))
# split indices into minibatches. Each minibatchlist is a list of arrays; each
# array holds the sorted ids of the observations in one minibatch, and the
# ids stay fixed throughout training. A trailing partial batch is dropped.
train_minibatchlist = [np.array(sorted(train[start_idx:start_idx + args.batch_size]))
                       for start_idx in range(0, len(train) - args.batch_size + 1, args.batch_size)]
val_minibatchlist = [np.array(sorted(val[start_idx:start_idx + args.batch_size]))
                     for start_idx in range(0, len(val) - args.batch_size + 1, args.batch_size)]
train_data_loader = DataLoader(train_minibatchlist, images, n_workers=2, folder=args.folder)
val_data_loader = DataLoader(val_minibatchlist, images, n_workers=2, folder=args.folder)
This goes in train.py; there is no need to edit the DataLoader class.
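For anyone who wants to try the split on its own before touching train.py, here is a minimal NumPy-only sketch of the same idea. The helper name `split_minibatches`, its parameters, and the fixed seed are assumptions made for illustration and are not part of the repository; it also slices a shuffled permutation instead of calling train_test_split, which should give an equivalent random 80/20 split.

```python
import numpy as np


def split_minibatches(n_samples, batch_size, train_fraction=0.8, seed=0):
    """Shuffle sample indices, split them into train/val, and batch each part.

    Hypothetical helper for illustration only; not part of the repository.
    """
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)

    # The permutation is already random, so a plain slice gives a random split.
    n_train = int(round(train_fraction * n_samples))
    train, val = indices[:n_train], indices[n_train:]

    def to_minibatches(idx):
        # Sort the ids inside each batch and drop a trailing partial batch,
        # mirroring the list comprehensions in the snippet above.
        return [np.sort(idx[start:start + batch_size])
                for start in range(0, len(idx) - batch_size + 1, batch_size)]

    return to_minibatches(train), to_minibatches(val)


train_minibatchlist, val_minibatchlist = split_minibatches(n_samples=100, batch_size=8)
print(len(train_minibatchlist), "training batches")
print(len(val_minibatchlist), "validation batches")
```

One thing to watch: because partial batches are dropped, a validation set with fewer images than the batch size produces an empty validation list, so it may be worth checking that `len(val_minibatchlist) > 0` before training.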
Hello, I am trying to train your VAE for a project of my own and I noticed there is no validation part in the training. Is there an easy way to add validation to that training? Could you help me with how I might do that using your DataLoader? Thanks in advance.