araffin / learning-to-drive-in-5-minutes

Implementation of reinforcement learning approach to make a car learn to drive smoothly in minutes
https://towardsdatascience.com/learning-to-drive-smoothly-in-minutes-450a7cdb35f4
MIT License
287 stars 85 forks

[question] [feature request] Validation for VAE #37

Open meric-sakarya opened 3 years ago

meric-sakarya commented 3 years ago

Hello, I am trying to train your VAE for a project of my own and I noticed there is no validation part in the training. Is there an easy way to add validation to that training? Could you help me with how I might do that using your DataLoader? Thanks in advance.

kncrane commented 3 years ago

If you import train_test_split from sklearn.model_selection, you can do something like this:

# CHANGED JPG TO PNG
images = [im for im in os.listdir(args.folder) if im.endswith('.png')]
images = np.array(images)
n_samples = len(images)

if args.n_samples > 0:
    n_samples = min(n_samples, args.n_samples)

# indices for all time steps where the episode continues
indices = np.arange(n_samples, dtype='int64')
np.random.shuffle(indices)

# NEW SECTION: split the indices into a train and a validation set before batching.
# train_test_split accepts NumPy arrays directly and returns NumPy arrays,
# so no pandas DataFrame round trip is needed.
train, val = train_test_split(indices, train_size=0.8)

print("{} images in total".format(n_samples))
print("{} images in training set".format(len(train)))
print("{} images in validation set".format(len(val)))

# Split the indices into minibatches. Each minibatch list is a list of arrays;
# each array holds the sorted ids of the observations in one batch.
# A trailing batch smaller than args.batch_size is dropped.
train_minibatchlist = [np.array(sorted(train[start_idx:start_idx + args.batch_size]))
                       for start_idx in range(0, len(train) - args.batch_size + 1, args.batch_size)]

val_minibatchlist = [np.array(sorted(val[start_idx:start_idx + args.batch_size]))
                     for start_idx in range(0, len(val) - args.batch_size + 1, args.batch_size)]

train_data_loader = DataLoader(train_minibatchlist, images, n_workers=2, folder=args.folder)
val_data_loader = DataLoader(val_minibatchlist, images, n_workers=2, folder=args.folder)

All of this goes in train.py; there is no need to edit the DataLoader class.
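
To check how many batches each split ends up with, here is a minimal standalone sketch of the same split-then-batch logic. It is an illustration only, not code from the repository: it uses the standard library's random module instead of NumPy and scikit-learn so it runs anywhere, and the function name split_and_batch and its parameters are hypothetical.

```python
import random


def split_and_batch(n_samples, batch_size, train_fraction=0.8, seed=0):
    """Shuffle indices, split them into train/val, then cut each part into full minibatches.

    Hypothetical helper for illustration; it mirrors the logic of the snippet
    above but is not part of the repository.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)

    # First train_fraction of the shuffled indices is the training set,
    # the remainder is the validation set.
    n_train = int(n_samples * train_fraction)
    train, val = indices[:n_train], indices[n_train:]

    def to_minibatches(ids):
        # Same range() bounds as the snippet above: a trailing partial batch is dropped.
        return [sorted(ids[start:start + batch_size])
                for start in range(0, len(ids) - batch_size + 1, batch_size)]

    return to_minibatches(train), to_minibatches(val)


train_batches, val_batches = split_and_batch(n_samples=100, batch_size=8)
print(len(train_batches), len(val_batches))
```

With 100 images and a batch size of 8, the 20 validation images give only two full batches and the remaining 4 images are never used, so it may be worth picking a batch size that divides the validation set evenly.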