ztsv-av / spellbook

Shortcuts to your ML workflow!
4 stars, 1 fork · source link

Load Data in Parts to Speed Up Training #20

Closed ztsv-av closed 2 years ago

ztsv-av commented 2 years ago
ztsv-av commented 2 years ago

Done. Here is the code that you asked for:

train.py

train_paths_list = getFullPaths(train_paths)
val_paths_list = getFullPaths(val_paths)


def _part_bounds(num_items, part, num_parts):
    """Return the [start, end) slice bounds for `part` of `num_parts` splits.

    Pure integer arithmetic guarantees consecutive parts tile the list
    exactly.  The original mixed two different float expressions for the
    start and end bounds, so rounding could duplicate or drop a sample at
    a part boundary.
    """
    return (part * num_items) // num_parts, ((part + 1) * num_items) // num_parts


for epoch in range(num_epochs):

    # Fresh shuffle each epoch so every file-part sees a different subset.
    train_paths_list_shuffled = shuffle(train_paths_list)
    val_paths_list_shuffled = shuffle(val_paths_list)

    total_loss = 0.0
    num_batches = 0

    # ---- training: load and train on one file-part at a time (bounds RAM) ----
    for part in range(max_fileparts_train):

        start_part_time = time.time()

        lo, hi = _part_bounds(len(train_paths_list_shuffled), part, max_fileparts_train)
        train_filepaths_part = train_paths_list_shuffled[lo:hi]

        start_load_data_time = time.time()
        print('Loading Data...', flush=True)

        # Build a distributed tf.data pipeline for just this part.
        train_distributed_part = prepareClassificationDataset(
            batch_size, train_filepaths_part, permutations, normalization, strategy, is_val=False)

        end_load_data_time = time.time()
        print('Finished Loading Data. Time Passed: ' + str(end_load_data_time - start_load_data_time), flush=True)

        for batch in train_distributed_part:

            total_loss += wrapperTrain(
                batch, model, compute_total_loss, optimizer, train_accuracy, strategy)

            num_batches += 1

        end_part_time = time.time()

        print('training: part ' + str(part) + '/' + str(max_fileparts_train) +
            ', passed time: ' + str(end_part_time - start_part_time), flush=True)

    # Guard against a degenerate epoch that produced no batches
    # (e.g. an empty path list) instead of raising ZeroDivisionError.
    train_loss = total_loss / num_batches if num_batches else 0.0

    # ---- validation: same part-by-part streaming, no permutations ----
    for part in range(max_fileparts_val):

        start_part_time = time.time()

        lo, hi = _part_bounds(len(val_paths_list_shuffled), part, max_fileparts_val)
        val_filepaths_part = val_paths_list_shuffled[lo:hi]

        start_load_data_time = time.time()
        print('Loading Data...', flush=True)

        val_distributed_part = prepareClassificationDataset(
            batch_size, val_filepaths_part, None, normalization, strategy, is_val=True)

        end_load_data_time = time.time()
        print('Finished Loading Data. Time Passed: ' + str(end_load_data_time - start_load_data_time), flush=True)

        for batch in val_distributed_part:

            wrapperVal(
                batch, model, loss_object, val_loss, val_accuracy, strategy)

        end_part_time = time.time()

        print('validation: part ' + str(part) + '/' + str(max_fileparts_val) +
            ', passed time: ' + str(end_part_time - start_part_time), flush=True)
prepareClassificationDataset.py

images_list_part = []
labels_list_part = []
for path in filepaths_part:
    images_list_part.append(loadNumpy(path))
    labels_list_part.append(tf.convert_to_tensor(getLabelFromFilename(path)))

images_map = map(lambda image: permuteImageGetLabelBoxes(
    image, permutations, normalization, is_val, bboxes=None, bbox_format=None, is_detection=False), images_list_part)
images_mapped_list_part = list(images_map)

data_part = tf.data.Dataset.from_tensor_slices(
    (images_mapped_list_part, labels_list_part))

data_part = data_part.batch(batch_size)
data_part_dist = strategy.experimental_distribute_dataset(
    data_part)

return data_part_dist
ztsv-av commented 2 years ago

The major time fix was calling tf.convert_to_tensor() on each image and label before calling tf.data.Dataset.from_tensor_slices().