arcelien / pba

Efficient Learning of Augmentation Policy Schedules
https://bair.berkeley.edu/blog/2019/06/07/data_aug/
Apache License 2.0

use for own dataset #10

Open tienduchoang opened 5 years ago

tienduchoang commented 5 years ago

Hello everyone, can you point me to a guide for running augmentation on my own dataset?

arcelien commented 5 years ago

You should update the dataloader (the DataSet class in pba/data_utils.py) to load your dataset. It should work for any image size.

You can see load_test() in pba/data_utils.py or scripts/test_search.sh for an example.
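As a rough sketch of what that involves (everything below is hypothetical, not code from the repo), a custom loader only needs to populate the same attributes the existing load_*() methods set:

import numpy as np

def load_custom(self, hparams):
    """Hypothetical loader: fill in the attributes that load_data() expects."""
    # hypothetical .npy files; images in NCHW format (as load_data() assumes),
    # labels as integer class ids
    images = np.load('my_images.npy')
    labels = np.load('my_labels.npy')
    train_size, val_size = hparams.train_size, hparams.validation_size
    self.num_classes = int(labels.max()) + 1
    self.train_images = images[:train_size]
    self.train_labels = labels[:train_size]
    self.val_images = images[train_size:train_size + val_size]
    self.val_labels = labels[train_size:train_size + val_size]
    self.test_images = images[train_size + val_size:]
    self.test_labels = labels[train_size + val_size:]

The new method would then be called from the dataset dispatch in load_data(), with self.image_size set to match your images.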

tienduchoang commented 5 years ago

Hello @arcelien, thanks for your response. Your input data is reshaped to a fixed size, like below:

train_data = train_data.reshape(-1, 3, 32, 32)

How can I work with a folder that contains images of different sizes?

sayakpaul commented 5 years ago

I was looking for this too, and I am pretty sure many people will be. An example notebook might be extremely helpful in that case, maybe an end-to-end demo covering augmenting your own dataset with PBA and training a model on the result. It would really be immensely helpful.

I ran the pba.ipynb demo notebook and it went pretty smoothly.

arcelien commented 5 years ago

Note that the line train_data = train_data.reshape(-1, 3, 32, 32) is in the load_cifar() function, which loads CIFAR data. You should instead define a new function that loads your dataset with the correct size.

See load_test() for an example of loading 224 x 224 data. The image size is passed around in a self.image_size variable in the dataloader and is not constrained to any fixed value.

I will look into such a notebook, but may not have the bandwidth for that very soon.

tienduchoang commented 5 years ago

@arcelien @sayakpaul Thanks for your help.

tienduchoang commented 5 years ago

Hi @arcelien, I have a word-image dataset and I want to use your code to augment it for my OCR system. Each image has a width in the range 26-150 and a height in the range 26-40. The total number of classes is 1, because the dataset contains only word images.

I don't know how to get started with this dataset. Do you have any guidance on how to augment it?

My sample dataset is at this link:

https://drive.google.com/drive/folders/1pUiWY6PZwvNWCSX46vvD3g9QCfKH7yTk?usp=sharin

Thank you!

zengruizhao commented 5 years ago

Hi @arcelien, can we only load all the data at once in the data_loader? That isn't friendly to large datasets. I would like to know whether we can build a data generator instead. Looking forward to your reply. Thanks.

arcelien commented 5 years ago

The way the current dataloader is implemented only supports loading all data at once, but it should not be too difficult to adapt this to generate data on demand or use a new dataloader which does so.

To adapt this, I'd imagine we could use something like the torchvision.datasets.SVHN data generator to load data in the next_batch() function and apply all the augmentation/transformations then instead of in the initialization.
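A rough, untested sketch of that idea (the class name and structure here are hypothetical, not from the repo): keep only a dataset handle around and do the per-image conversion/augmentation inside next_batch() rather than at initialization.

import numpy as np
import torchvision

class LazyBatchLoader(object):
    """Hypothetical loader that materializes one minibatch at a time."""

    def __init__(self, data_path, batch_size):
        # indexing a torchvision dataset returns a (PIL image, label) pair
        self.dataset = torchvision.datasets.SVHN(
            root=data_path, split='train', download=True)
        self.batch_size = batch_size
        self.index = 0

    def next_batch(self):
        images, labels = [], []
        for _ in range(self.batch_size):
            img, label = self.dataset[self.index % len(self.dataset)]
            # augmentation/transformations would be applied here, per image
            images.append(np.asarray(img))
            labels.append(label)
            self.index += 1
        return np.stack(images), np.array(labels)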

441711335 commented 5 years ago

Hi @arcelien, how can we get "MEAN" and "STD" in augmentation_transforms.py when using PBT on my own dataset?

arcelien commented 5 years ago

If you're using the default dataloader, you can pass the flag --recompute_dset_stats while training. Otherwise, you can just compute the per-channel mean and standard deviation of your dataset (i.e. call mean and std on your data tensors).
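A minimal sketch of that computation, assuming the images are already in a NumPy array with shape (N, H, W, 3) and values in [0, 255]:

import numpy as np

def per_channel_stats(images):
    """Return per-channel mean and std with pixel values scaled to [0, 1]."""
    images = images.astype(np.float32) / 255.0
    return images.mean(axis=(0, 1, 2)), images.std(axis=(0, 1, 2))

# random data standing in for a real dataset
mean, std = per_channel_stats(
    np.random.randint(0, 256, size=(100, 32, 32, 3), dtype=np.uint8))
print(mean, std)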

441711335 commented 5 years ago

@arcelien thank you for your help!

jaspreet-sambee commented 5 years ago

The way the current dataloader is implemented only supports loading all data at once, but it should not be too difficult to adapt this to generate data on demand or use a new dataloader which does so.

To adapt this, I'd imagine we could use something like the torchvision.datasets.SVHN data generator to load data in the next_batch() function and apply all the augmentation/transformations then instead of in the initialization.

Hi @arcelien, I have a doubt on this; I probably did not understand it clearly.

1. Even if we load the data in the next_batch() function, aren't we still loading the whole dataset into memory? Or did you mean creating a kind of data iterator which loads only a batch inside the next_batch() function?
2. I am currently using a dataset with ImageNet-like images of size 224x224x3 and a ResNet-50 model, but I see it takes a lot of GPU memory for each of the workers, so other workers have to wait for the first few workers to finish. I am not clear how so much GPU memory (roughly 13 GB) is being used per worker. Am I missing something on the data_loader side?

Thanks!

arcelien commented 5 years ago
  1. I'm not certain, but I don't think PyTorch's dataloader loads the whole dataset into memory.

  2. Perhaps you can look at the batch size? Or you can set the fractional GPU size per worker to something smaller (see the note below).
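For reference, the per-worker GPU fraction corresponds to the --gpu argument passed to pba/search.py; in the test_search.sh script shown later in this thread it is set with --gpu 0.49, i.e. roughly half a GPU per trial, so lowering it packs more trials onto one GPU at the cost of less memory headroom per trial.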

jaspreet-sambee commented 5 years ago
  2. Perhaps you can look at the batch size? Or you can set the fractional GPU size per worker to something smaller.

My batch size is only 16, and if I set the GPU size to anything less than 0.6 (~13 GB), it runs out of GPU memory. I see that all the workers call load_data from Ray and hence each of them loads all the data, and in the next_batch function each worker samples a batch. I hope each worker is not putting all the data on the GPU, or is my understanding wrong here?

arcelien commented 5 years ago

It could be that with ImageNet you just need to use much smaller batches, distribute over multiple GPUs, or run fewer trials concurrently. For reference, with a (32x32x3) input like CIFAR/SVHN, the size per trial is ~2 GB with a network like ResNet-50.

jaspreet-sambee commented 5 years ago

I actually changed the dataloader so that it loads the data on the fly. Unfortunately, I could not get very good results. Have you tried getting numbers for ImageNet by any chance?

arcelien commented 5 years ago

Unfortunately, I haven't tried ImageNet. The original method from AutoAugment requires some special handling (Section 4.2 in https://arxiv.org/pdf/1805.09501.pdf):

we use a reduced subset of the ImageNet training set, with 120 classes (randomly chosen) and 6,000 samples, to search for policies. We train a Wide-ResNet 40-2 using cosine decay for 200 epochs...

AutoAugment hadn't released their code for ImageNet at the time, and I unfortunately did not have the compute to redo their hyperparameter tuning to achieve good results. I'd imagine you wouldn't get great results without carefully tuning things like LR/WD/BS in both the search phase and the re-training phase.

441711335 commented 5 years ago

Hi @arcelien, while using TFRecords as my dataset, I have run into the problem that I can't start the "coords" and "threads" to get my data in data_loader.next_batch(). Can you give me some advice? Thanks!

Xujianzhong commented 4 years ago

Hi @tienduchoang, I have the same problem as you. Have you solved it yet? Can you give me any advice?

441711335 commented 4 years ago

Hi @arcelien, I have run the search.sh script on my own dataset and obtained some files in "results". The question is how to get a schedule like the one shown in schedules/rcifar10_16_wrn.txt, which can be used directly in the training procedure. Looking forward to your reply!

arcelien commented 4 years ago

Hey, there should be a set of files in the results folder with the schedules corresponding to each worker. You can take a look at TensorBoard to select which worker's schedule to pick.
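For example, assuming the default --local_dir of "$PWD/results/" used by the search scripts in this thread, running tensorboard --logdir results/ lets you compare the workers and then pick the schedule file corresponding to the best-performing one.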

441711335 commented 4 years ago

@arcelien Your help is much appreciated!

monkeyDemon commented 4 years ago

Hey, guys. I think I've successfully applied the search to my own dataset, which required some changes to data_utils.py. The core idea is to store image paths in self.train_images instead of loading the actual data into memory as the original file does. On this basis, the images are actually read into memory in next_batch(), and you can do whatever preprocessing you want there. So the problem raised by @tienduchoang of working with a folder that contains images of different sizes is solved, and you can also run on a large dataset without memory trouble. Don't forget to do normalization after loading the images in next_batch(); "MEAN" and "STD" need to be computed by yourself (I used a separate script to calculate them) and added to augmentation_transforms.py, like this:

MEANS = {
    'cifar10_50000': [0.49139968, 0.48215841, 0.44653091],
    'cifar10_4000': [0.49056774, 0.48116026, 0.44726052],
    'cifar100_50000': [0.50707516, 0.48654887, 0.44091784],
    'svhn_1000': [0.45163885, 0.4557915, 0.48093327],
    'svhn-full_604388': [0.43090966, 0.4302428, 0.44634357],
    'test_10000': [0.458078, 0.4862745, 0.530588]  # TODO:
}
STDS = {
    'cifar10_50000': [0.24703223, 0.24348513, 0.26158784],
    'cifar10_4000': [0.24710728, 0.24451308, 0.26235099],
    'cifar100_50000': [0.26733429, 0.25643846, 0.27615047], 
    'svhn_1000': [0.20385217, 0.20957996, 0.20804394],
    'svhn-full_604388': [0.19652855, 0.19832038, 0.19942076],
    'test_10000': [0.16196, 0.15887, 0.1668]  # TODO:
}

I put the detailed modification below, which may be a little long.

modify test_search.sh to:

#!/bin/bash
export PYTHONPATH="$(pwd)"

python pba/search.py \
    --local_dir "$PWD/results/" \
    --model_name wrn_40_2 \
    --dataset test \
    --train_size 10000 --val_size 3000 \
    --checkpoint_freq 0 \
    --name "test_search" --gpu 0.49 --cpu 2 \
    --num_samples 2 --perturbation_interval 3 --epochs 6 \
    --explore cifar10 --aug_policy cifar10 \
    --lr 0.1 --wd 0.0005 --bs 2 --test_bs 2

modify data_utils.py to:

# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
try:
    import cPickle as pickle
except:
    import pickle
import os
import random
import numpy as np
from PIL import Image
import tensorflow as tf
import torchvision

from autoaugment.data_utils import unpickle
import pba.policies as found_policies
from pba.utils import parse_log_schedule
import pba.augmentation_transforms_hp as augmentation_transforms_pba
import pba.augmentation_transforms as augmentation_transforms_autoaug

# pylint:disable=logging-format-interpolation

def parse_policy(policy_emb, augmentation_transforms):
    policy = []
    num_xform = augmentation_transforms.NUM_HP_TRANSFORM
    xform_names = augmentation_transforms.HP_TRANSFORM_NAMES
    assert len(policy_emb
               ) == 2 * num_xform, 'policy was: {}, supposed to be: {}'.format(
                   len(policy_emb), 2 * num_xform)
    for i, xform in enumerate(xform_names):
        policy.append((xform, policy_emb[2 * i] / 10., policy_emb[2 * i + 1]))
    return policy

def shuffle_data(data, labels):
    """Shuffle data using numpy."""
    np.random.seed(0)
    perm = np.arange(len(data))
    np.random.shuffle(perm)
    data = data[perm]
    labels = labels[perm]
    return data, labels

class DataSet(object):
    """Dataset object that produces augmented training and eval data."""

    def __init__(self, hparams):
        self.hparams = hparams
        self.epochs = 0
        self.curr_train_index = 0

        self.parse_policy(hparams)
        self.load_data(hparams)

        ## Apply normalization
        #self.train_images = self.train_images.transpose(0, 2, 3, 1) / 255.0
        #self.val_images = self.val_images.transpose(0, 2, 3, 1) / 255.0
        #self.test_images = self.test_images.transpose(0, 2, 3, 1) / 255.0
        #if not hparams.recompute_dset_stats:
        #    mean = self.augmentation_transforms.MEANS[hparams.dataset + '_' +
        #                                              str(hparams.train_size)]
        #    std = self.augmentation_transforms.STDS[hparams.dataset + '_' +
        #                                            str(hparams.train_size)]
        #else:
        #    mean = self.train_images.mean(axis=(0, 1, 2))
        #    std = self.train_images.std(axis=(0, 1, 2))
        #    self.augmentation_transforms.MEANS[hparams.dataset + '_' + str(hparams.train_size)] = mean
        #    self.augmentation_transforms.STDS[hparams.dataset + '_' + str(hparams.train_size)] = std
        #tf.logging.info('mean:{}    std: {}'.format(mean, std))

        #self.train_images = (self.train_images - mean) / std
        #self.val_images = (self.val_images - mean) / std
        #self.test_images = (self.test_images - mean) / std

        assert len(self.test_images) == len(self.test_labels)
        assert len(self.train_images) == len(self.train_labels)
        assert len(self.val_images) == len(self.val_labels)
        tf.logging.info('train dataset size: {}, test: {}, val: {}'.format(
            len(self.train_images), len(self.test_images), len(self.val_images)))

    def parse_policy(self, hparams):
        """Parses policy schedule from input, which can be a list, list of lists, text file, or pickled list.

        If list is not nested, then uses the same policy for all epochs.

        Args:
        hparams: tf.hparams object.
        """
        # Parse policy
        if hparams.use_hp_policy:
            self.augmentation_transforms = augmentation_transforms_pba

            if isinstance(hparams.hp_policy,
                          str) and hparams.hp_policy.endswith('.txt'):
                if hparams.num_epochs % hparams.hp_policy_epochs != 0:
                    tf.logging.warning(
                        "Schedule length (%s) doesn't divide evenly into epochs (%s), interpolating.",
                        hparams.num_epochs, hparams.hp_policy_epochs)
                tf.logging.info(
                    'schedule policy trained on {} epochs, parsing from: {}, multiplier: {}'
                    .format(
                        hparams.hp_policy_epochs, hparams.hp_policy,
                        float(hparams.num_epochs) / hparams.hp_policy_epochs))
                raw_policy = parse_log_schedule(
                    hparams.hp_policy,
                    epochs=hparams.hp_policy_epochs,
                    multiplier=float(hparams.num_epochs) /
                    hparams.hp_policy_epochs)
            elif isinstance(hparams.hp_policy,
                            str) and hparams.hp_policy.endswith('.p'):
                assert hparams.num_epochs % hparams.hp_policy_epochs == 0
                tf.logging.info('custom .p file, policy number: {}'.format(
                    hparams.schedule_num))
                with open(hparams.hp_policy, 'rb') as f:
                    policy = pickle.load(f)[hparams.schedule_num]
                raw_policy = []
                for num_iters, pol in policy:
                    for _ in range(num_iters * hparams.num_epochs //
                                   hparams.hp_policy_epochs):
                        raw_policy.append(pol)
            else:
                raw_policy = hparams.hp_policy

            if isinstance(raw_policy[0], list):
                self.policy = []
                split = len(raw_policy[0]) // 2
                for pol in raw_policy:
                    cur_pol = parse_policy(pol[:split],
                                           self.augmentation_transforms)
                    cur_pol.extend(
                        parse_policy(pol[split:],
                                     self.augmentation_transforms))
                    self.policy.append(cur_pol)
                tf.logging.info('using HP policy schedule, last: {}'.format(
                    self.policy[-1]))
            elif isinstance(raw_policy, list):
                split = len(raw_policy) // 2
                self.policy = parse_policy(raw_policy[:split],
                                           self.augmentation_transforms)
                self.policy.extend(
                    parse_policy(raw_policy[split:],
                                 self.augmentation_transforms))
                tf.logging.info('using HP Policy, policy: {}'.format(
                    self.policy))

        else:
            self.augmentation_transforms = augmentation_transforms_autoaug
            tf.logging.info('using ENAS Policy or no augmentation policy')
            if 'svhn' in hparams.dataset:
                self.good_policies = found_policies.good_policies_svhn()
            else:
                assert 'cifar' in hparams.dataset
                self.good_policies = found_policies.good_policies()

    def reset_policy(self, new_hparams):
        self.hparams = new_hparams
        self.parse_policy(new_hparams)
        tf.logging.info('reset aug policy')
        return

    def load_cifar(self, hparams):
        train_labels = []
        test_labels = []
        num_data_batches_to_load = 5
        total_batches_to_load = num_data_batches_to_load
        train_batches_to_load = total_batches_to_load
        assert hparams.train_size + hparams.validation_size <= 50000
        # Determine how many images we have loaded
        train_dataset_size = 10000 * num_data_batches_to_load

        if hparams.dataset == 'cifar10':
            train_data = np.empty(
                (total_batches_to_load, 10000, 3072), dtype=np.uint8)
            datafiles = [
                'data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4',
                'data_batch_5'
            ]
            datafiles = datafiles[:train_batches_to_load]
            test_data = np.empty((1, 10000, 3072), dtype=np.uint8)
            datafiles.append('test_batch')
            num_classes = 10
        elif hparams.dataset == 'cifar100':
            assert num_data_batches_to_load == 5
            train_data = np.empty((1, 50000, 3072), dtype=np.uint8)
            datafiles = ['train']
            test_data = np.empty((1, 10000, 3072), dtype=np.uint8)
            datafiles.append('test')
            num_classes = 100

        for file_num, f in enumerate(datafiles):
            d = unpickle(os.path.join(hparams.data_path, f))
            if f == 'test' or f == 'test_batch':
                test_data[0] = copy.deepcopy(d['data'])
            else:
                train_data[file_num] = copy.deepcopy(d['data'])
            if hparams.dataset == 'cifar10':
                labels = np.array(d['labels'])
            else:
                labels = np.array(d['fine_labels'])
            nsamples = len(labels)
            for idx in range(nsamples):
                if f == 'test' or f == 'test_batch':
                    test_labels.append(labels[idx])
                else:
                    train_labels.append(labels[idx])
        train_data = train_data.reshape(train_dataset_size, 3072)
        test_data = test_data.reshape(10000, 3072)
        train_data = train_data.reshape(-1, 3, 32, 32)
        test_data = test_data.reshape(-1, 3, 32, 32)
        train_labels = np.array(train_labels, dtype=np.int32)
        test_labels = np.array(test_labels, dtype=np.int32)

        self.test_images, self.test_labels = test_data, test_labels
        train_data, train_labels = shuffle_data(train_data, train_labels)
        train_size, val_size = hparams.train_size, hparams.validation_size
        assert 50000 >= train_size + val_size
        self.train_images = train_data[:train_size]
        self.train_labels = train_labels[:train_size]
        self.val_images = train_data[train_size:train_size + val_size]
        self.val_labels = train_labels[train_size:train_size + val_size]
        self.num_classes = num_classes

    def load_svhn(self, hparams):
        train_labels = []
        test_labels = []
        if hparams.dataset == 'svhn':
            assert hparams.train_size == 1000
            assert hparams.train_size + hparams.validation_size <= 73257
            train_loader = torchvision.datasets.SVHN(
                root=hparams.data_path, split='train', download=True)
            test_loader = torchvision.datasets.SVHN(
                root=hparams.data_path, split='test', download=True)
            num_classes = 10
            train_data = train_loader.data
            test_data = test_loader.data
            train_labels = train_loader.labels
            test_labels = test_loader.labels
        elif hparams.dataset == 'svhn-full':
            assert hparams.train_size == 73257 + 531131
            assert hparams.validation_size == 0
            train_loader = torchvision.datasets.SVHN(
                root=hparams.data_path, split='train', download=True)
            test_loader = torchvision.datasets.SVHN(
                root=hparams.data_path, split='test', download=True)
            extra_loader = torchvision.datasets.SVHN(
                root=hparams.data_path, split='extra', download=True)
            num_classes = 10
            train_data = np.concatenate(
                [train_loader.data, extra_loader.data], axis=0)
            test_data = test_loader.data
            train_labels = np.concatenate(
                [train_loader.labels, extra_loader.labels], axis=0)
            test_labels = test_loader.labels
        else:
            raise ValueError(hparams.dataset)

        self.test_images, self.test_labels = test_data, test_labels
        train_data, train_labels = shuffle_data(train_data, train_labels)
        train_size, val_size = hparams.train_size, hparams.validation_size
        if hparams.dataset == 'svhn-full':
            assert train_size + val_size <= 604388
        else:
            assert train_size + val_size <= 73257
        self.train_images = train_data[:train_size]
        self.train_labels = train_labels[:train_size]
        self.val_images = train_data[-val_size:]
        self.val_labels = train_labels[-val_size:]
        self.num_classes = num_classes

    def load_test(self, hparams):
        """Load random data and labels."""
        test_size = 3000
        self.num_classes = 2
        img_root_dir = 'the root img dir of yours'

        labels_dir_list = []
        for label_name in os.listdir(img_root_dir):
            labels_dir_list.append(os.path.join(img_root_dir, label_name))

        total_imgs_path_list = []
        total_imgs_labels = []
        for idx, label_dir in enumerate(labels_dir_list):
            for img_name in os.listdir(label_dir):
                img_path = os.path.join(label_dir, img_name)
                label = idx
                total_imgs_path_list.append(img_path)
                total_imgs_labels.append(label)

        # shuffle
        perm = np.arange(len(total_imgs_labels))
        np.random.shuffle(perm)
        total_imgs_path_list = np.array(total_imgs_path_list)
        total_imgs_labels = np.array(total_imgs_labels)
        total_imgs_path_list = total_imgs_path_list[perm]
        total_imgs_labels = total_imgs_labels[perm]

        train_size, val_size = hparams.train_size, hparams.validation_size
        assert train_size + val_size + test_size <= len(total_imgs_labels)
        self.train_images = total_imgs_path_list[:train_size] 
        self.train_labels = total_imgs_labels[:train_size] 
        self.val_images = total_imgs_path_list[train_size:train_size+val_size] 
        self.val_labels = total_imgs_labels[train_size:train_size+val_size] 
        self.test_images = total_imgs_path_list[-test_size:]
        self.test_labels = total_imgs_labels[-test_size:] 

    def load_data(self, hparams):
        """Load raw data from specified dataset.

        Assumes data is in NCHW format.

        Populates:
            self.train_images: Training image data.
            self.train_labels: Training ground truth labels.
            self.val_images: Validation/holdout image data.
            self.val_labels: Validation/holdout ground truth labels.
            self.test_images: Testing image data.
            self.test_labels: Testing ground truth labels.
            self.num_classes: Number of classes.
            self.num_train: Number of training examples.
            self.image_size: Width/height of image.

        Args:
            hparams: tf.hparams object.
        """
        if hparams.dataset == 'cifar10' or hparams.dataset == 'cifar100':
            self.load_cifar(hparams)
        elif hparams.dataset == 'svhn' or hparams.dataset == 'svhn-full':
            self.load_svhn(hparams)
        elif hparams.dataset == 'test':
            self.load_test(hparams)
        else:
            raise ValueError('unimplemented')

        self.num_train = self.train_images.shape[0]
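        # NOTE: the image size is hard-coded here; _process_single_img() below
        # resizes and pads every image to (image_size x image_size)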
        self.image_size = 160 

        # make one-hot label
        self.train_labels = np.eye(self.num_classes)[np.array(self.train_labels, dtype=np.int32)]
        self.val_labels = np.eye(self.num_classes)[np.array(self.val_labels, dtype=np.int32)]
        self.test_labels = np.eye(self.num_classes)[np.array(self.test_labels, dtype=np.int32)]
        assert len(self.train_images) == len(self.train_labels)
        assert len(self.val_images) == len(self.val_labels)
        assert len(self.test_images) == len(self.test_labels)

        # convert img_path to real train data
        val_imgs = []
        for img_path in self.val_images:
            cur_img = self._process_single_img(img_path) 
            val_imgs.append(cur_img)
        self.val_images = np.array(val_imgs)
        test_imgs = []
        for img_path in self.test_images:
            cur_img = self._process_single_img(img_path) 
            test_imgs.append(cur_img)
        self.test_images = np.array(test_imgs)

    def _process_single_img(self, img_path):
        img = Image.open(img_path, 'r')

        # make sure channel order is RGB
        if img.mode != 'RGB':
            img = img.convert("RGB")

        # resize(maintain aspect ratio) 
        width, height = img.size
        long_edge_size = self.image_size
        if width > height:
            height = int(height * long_edge_size / width)
            width = long_edge_size
        else:
            width = int(width * long_edge_size / height)
            height = long_edge_size
        img = img.resize((width, height), Image.BILINEAR)

        # padding
        rgb_mean = self.augmentation_transforms.MEANS[self.hparams.dataset + '_' + str(self.hparams.train_size)]
        rgb_std = self.augmentation_transforms.STDS[self.hparams.dataset + '_' + str(self.hparams.train_size)]
        back_color = tuple([int(x * 255.0) for x in rgb_mean])
        img_padd = Image.new('RGB', (long_edge_size, long_edge_size), back_color)
        if width > height:
            h_st = int((long_edge_size - height)/2)
            img_padd.paste(img, (0, h_st))
        else:
            w_st = int((long_edge_size - width)/2)
            img_padd.paste(img, (w_st, 0))

        # convert PIL to numpy
        img = np.array(img_padd)

        # normalization
        img = img / 255.0
        img = (img - rgb_mean) / rgb_std
        return img

    def next_batch(self, iteration=None):
        """Return the next minibatch of augmented data."""
        next_train_index = self.curr_train_index + self.hparams.batch_size
        if next_train_index > self.num_train:
            # Increase epoch number
            epoch = self.epochs + 1
            self.reset()
            self.epochs = epoch

        # convert img_path to real train data
        batch_train_imgs = []
        for img_path in self.train_images[self.curr_train_index:self.curr_train_index + self.hparams.batch_size]:
            cur_img = self._process_single_img(img_path) 
            batch_train_imgs.append(cur_img)

        batched_data = (np.array(batch_train_imgs),
            self.train_labels[self.curr_train_index:self.curr_train_index +
                              self.hparams.batch_size])
        final_imgs = []

        dset = self.hparams.dataset + '_' + str(self.hparams.train_size)
        images, labels = batched_data
        for data in images:
            if not self.hparams.no_aug:
                if not self.hparams.use_hp_policy:
                    # apply autoaugment policy
                    epoch_policy = self.good_policies[np.random.choice(
                        len(self.good_policies))]
                    final_img = self.augmentation_transforms.apply_policy(
                        epoch_policy,
                        data,
                        dset=dset,
                        image_size=self.image_size)
                else:
                    # apply PBA policy
                    if isinstance(self.policy[0], list):
                        # single policy
                        if self.hparams.flatten:
                            final_img = self.augmentation_transforms.apply_policy(
                                self.policy[random.randint(
                                    0,
                                    len(self.policy) - 1)],
                                data,
                                self.hparams.aug_policy,
                                dset,
                                image_size=self.image_size)
                        else:
                            final_img = self.augmentation_transforms.apply_policy(
                                self.policy[iteration],
                                data,
                                self.hparams.aug_policy,
                                dset,
                                image_size=self.image_size)
                    elif isinstance(self.policy, list):
                        # policy schedule
                        final_img = self.augmentation_transforms.apply_policy(
                            self.policy,
                            data,
                            self.hparams.aug_policy,
                            dset,
                            image_size=self.image_size)
                    else:
                        raise ValueError('Unknown policy.')
            else:
                final_img = data
            if self.hparams.dataset == 'cifar10' or self.hparams.dataset == 'cifar100':
                final_img = self.augmentation_transforms.random_flip(
                    self.augmentation_transforms.zero_pad_and_crop(
                        final_img, 4))
            elif 'svhn' in self.hparams.dataset:
                pass
            else:
                tf.logging.log_first_n(tf.logging.WARN, 'Using default random flip and crop.', 1)
                final_img = self.augmentation_transforms.random_flip(
                    self.augmentation_transforms.zero_pad_and_crop(
                        final_img, 4))
            # Apply cutout
            if not self.hparams.no_cutout:
                if 'cifar10' == self.hparams.dataset:
                    final_img = self.augmentation_transforms.cutout_numpy(
                        final_img, size=16)
                elif 'cifar100' == self.hparams.dataset:
                    final_img = self.augmentation_transforms.cutout_numpy(
                        final_img, size=16)
                elif 'svhn' in self.hparams.dataset:
                    final_img = self.augmentation_transforms.cutout_numpy(
                        final_img, size=20)
                else:
                    tf.logging.log_first_n(tf.logging.WARN, 'Using default cutout size (16x16).', 1)
                    final_img = self.augmentation_transforms.cutout_numpy(
                        final_img)
            final_imgs.append(final_img)
        batched_data = (np.array(final_imgs, np.float32), labels)
        self.curr_train_index += self.hparams.batch_size
        return batched_data

    def reset(self):
        """Reset training data and index into the training data."""
        self.epochs = 0
        # Shuffle the training data
        perm = np.arange(self.num_train)
        np.random.shuffle(perm)
        assert self.num_train == self.train_images.shape[
            0], 'Error incorrect shuffling mask'
        self.train_images = self.train_images[perm]
        self.train_labels = self.train_labels[perm]
        self.curr_train_index = 0
arcelien commented 4 years ago

Appreciate the detailed investigation @monkeyDemon! I'll see if I can check in something along these lines to be helpful for future users as well.