visipedia / tf_classification

Training, evaluation and testing code for image classification using TensorFlow
MIT License

Question on image processing and data augmentation #10

Open Arkadeep-sophoIITG opened 6 years ago

Arkadeep-sophoIITG commented 6 years ago

How do I turn off the data augmentation? I just want to resize the original images and train the model on them, not on augmented images, because I have sufficient data. Can you please guide me on this?

gvanhorn38 commented 6 years ago

Sure. To turn off the augmentation, the IMAGE_PROCESSING section of your config file needs to look something like this:

IMAGE_PROCESSING : {
    # All images will be resized to the [INPUT_SIZE, INPUT_SIZE, 3]
    INPUT_SIZE : 299,

    # 1) First we extract regions from the image
    # What type of region should be extracted, either 'image' or 'bbox'
    REGION_TYPE : 'image',

    # Specific whole image region extraction configuration
    WHOLE_IMAGE_CFG : {},

    # 2) Then we take a random crop from the region
    # The fraction of time to take a random crop, 0 is never, 1 is always
    DO_RANDOM_CROP : 0,
    RANDOM_CROP_CFG : {
        MIN_AREA : 0.5, # between 0 and 1, how much of the region must be included
        MAX_AREA : 1.0, # between 0 and 1, how much of the region can be included
        MIN_ASPECT_RATIO : 0.7, # minimum aspect ratio of the crop
        MAX_ASPECT_RATIO : 1.33, # maximum aspect ratio of the crop
        MAX_ATTEMPTS : 100, # maximum number of attempts before returning the whole region
    },

    # Alternatively we can take a central crop from the image
    DO_CENTRAL_CROP : 0, # Fraction of the time to take a central crop, 0 is never, 1 is always
    CENTRAL_CROP_FRACTION : 0.875, # Between 0 and 1, fraction of size to crop

    # 3) We need to resize the extracted regions to feed into the network.
    MAINTAIN_ASPECT_RATIO : false,
    # Avoid slower resize operations (bi-cubic, etc.)
    RESIZE_FAST : false,

    # 4) We can flip the regions
    # Randomly flip the image left right, 50% chance of flipping
    DO_RANDOM_FLIP_LEFT_RIGHT : false,

    # 5) We can distort the colors of the regions
    # The fraction of time to distort the color, 0 is never, 1 is always
    DO_COLOR_DISTORTION : 0,
    # Avoids slower ops (random_hue and random_contrast)
    COLOR_DISTORT_FAST : false
}

This configuration will simply resize the images to 299x299.
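
As a quick sanity check, the switches above can also be tested programmatically. The following is only a minimal sketch: it assumes the IMAGE_PROCESSING section has already been loaded into a plain Python dict, and the helper name `enabled_augmentations` is hypothetical, not part of the repo.

```python
# Hypothetical helper (not part of tf_classification): lists the processing
# switches from the IMAGE_PROCESSING section that are still turned on.
AUGMENTATION_FLAGS = (
    "DO_RANDOM_CROP",
    "DO_CENTRAL_CROP",
    "DO_RANDOM_FLIP_LEFT_RIGHT",
    "DO_COLOR_DISTORTION",
)

def enabled_augmentations(image_processing_cfg):
    """Return the names of the flags whose value is non-zero / true."""
    # Assumption of this sketch: a missing key is treated as "off".
    return [flag for flag in AUGMENTATION_FLAGS
            if image_processing_cfg.get(flag)]

cfg = {
    "INPUT_SIZE": 299,
    "DO_RANDOM_CROP": 0,
    "DO_CENTRAL_CROP": 0,
    "DO_RANDOM_FLIP_LEFT_RIGHT": False,
    "DO_COLOR_DISTORTION": 0,
}
print(enabled_augmentations(cfg))
```

An empty list means that only the resize to INPUT_SIZE remains active.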

You can visualize the inputs by running:

$ python visualize_train_inputs.py \
--tfrecords <path to your tfrecords> \
--config <path to your config.yaml file>

Arkadeep-sophoIITG commented 6 years ago

I did the same, but TensorBoard still shows 4 identical images with different summary names: 0.original_image, 1.image_with_random_crop, 2.cropped_resized_image and 3.final_distorted_image, as defined in the summaries.

I think this has to do with the piece of code below from input.py. Even if every parameter in the IMAGE_PROCESSING section is set to 0 or false, augmentation still seems to take place. Can you please tell me how to turn this off so that I can train on only the resized original images?

def bbox_crop_loop_cond(original_image, bboxes, distorted_inputs, image_summaries, current_index):
    num_bboxes = tf.shape(bboxes)[0]
    return current_index < num_bboxes

def get_distorted_inputs(original_image, bboxes, cfg, add_summaries):

    distorter = DistortedInputs(cfg, add_summaries)
    num_bboxes = tf.shape(bboxes)[0]
    distorted_inputs = tf.TensorArray(
        dtype=tf.float32,
        size=num_bboxes,
        element_shape=tf.TensorShape([1, cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3])
    )

    if add_summaries:
        image_summaries = tf.TensorArray(
            dtype=tf.float32,
            size=4,
            element_shape=tf.TensorShape([1, cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3])
        )
    else:
        image_summaries = tf.constant([])

    current_index = tf.constant(0, dtype=tf.int32)

    loop_vars = [original_image, bboxes, distorted_inputs, image_summaries, current_index]
    original_image, bboxes, distorted_inputs, image_summaries, current_index = tf.while_loop(
        cond=bbox_crop_loop_cond,
        body=distorter.apply,
        loop_vars=loop_vars,
        parallel_iterations=10, back_prop=False, swap_memory=False
    )

    distorted_inputs = distorted_inputs.concat()

    if add_summaries:
        tf.summary.image('0.original_image', image_summaries.read(0))
        tf.summary.image('1.image_with_random_crop', image_summaries.read(1))
        tf.summary.image('2.cropped_resized_image', image_summaries.read(2))
        tf.summary.image('3.final_distorted_image', image_summaries.read(3))

    return distorted_inputs

gvanhorn38 commented 6 years ago

Ah, I see. But in TensorBoard, are you seeing the same image for all of the summaries? Hopefully that is the case. The summary images will still be produced (and shown in TensorBoard) regardless of whether augmentation is occurring.

Arkadeep-sophoIITG commented 6 years ago

Yes, the images are the same. I will train again and confirm. So is my model being trained on all 4 input images (duplicates), or only on the original image? And if I set add_summaries=False, will only one image appear in TensorBoard?

gvanhorn38 commented 6 years ago

Cool. The only image that is being queued for training is the last summary image (3.final_distorted_image). The other image summaries are a convenience to see the intermediate augmentations that were applied to the original image.

Arkadeep-sophoIITG commented 6 years ago

Thanks for the help. One more question: can you please suggest how I should evaluate the model on my training set? I have 78k images, so what batch size and number of batches should I use? And the test config file should be changed in the same way as the train config file, right?

gvanhorn38 commented 6 years ago

Yeah, you can keep the same IMAGE_PROCESSING configs for your test situation.

For the batch size, what I'll do is use the following code snippet (taken from here) to find a few different numbers that evenly divide the total number of test images:

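from functools import reduce

def factors(n):
    # Collect each divisor pair (i, n // i) for i up to sqrt(n);
    # the set removes the duplicate that appears when n is a perfect square.
    return set(reduce(list.__add__,
                ([i, n//i] for i in range(1, int(pow(n, 0.5) + 1)) if n % i == 0)))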

So factors(78000) gives a lot of options, and I choose something reasonable, like 30. Then I use the test script like so:

python test.py \
--tfrecords <path to tfrecords> \
--save_dir <path to store the tensorboard summary files> \
--checkpoint_path <path to the model> \
--config <path to the test_config.yaml file> \
--batch_size 30 \
--batches 2600
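
The arithmetic behind the last two flags can be sketched as follows. This assumes the goal is to cover every test image exactly once, as in the command above; `batch_plan` is a hypothetical helper, not part of the repo.

```python
def batch_plan(num_images, batch_size):
    """Return how many batches cover num_images exactly once.

    Raises ValueError when batch_size does not evenly divide num_images,
    mirroring the choice above of picking a batch size from the factors.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    batches, remainder = divmod(num_images, batch_size)
    if remainder:
        raise ValueError(
            f"{batch_size} does not evenly divide {num_images}")
    return batches

print(batch_plan(78000, 30))  # 2600
```

Any other divisor of 78000 works the same way; a smaller batch size just means more batches.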

Arkadeep-sophoIITG commented 6 years ago

Thanks for the help. One more question: what is the ACCURACY_AT_K_METRIC parameter in the test config file? Here's what I understood; correct me if I'm wrong. The parameter is set to 1 for a binary classification problem. For a multi-class classification problem with n classes, if I set it to [k, m] (k < m < n), then a prediction counts as correct if the true label falls between the top k and top m predicted labels (both inclusive)?

gvanhorn38 commented 6 years ago

The ACCURACY_AT_K_METRIC is a way of specifying the accuracy measurements you would like rendered in TensorBoard. Specifying ACCURACY_AT_K_METRIC : [3, 5] means that you would like accuracy at 3 and accuracy at 5 rendered (two separate plots will be created; this assumes you have at least 5 classes for your task). "Accuracy at 3" refers to the metric under which an image is classified correctly if the correct class is among the top 3 most confident predictions.
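
To make the definition concrete, here is a small pure-Python sketch of accuracy at k. It is an illustration of the metric only, not the code the repo uses to compute it; the function name `accuracy_at_k` and the toy scores are made up for this example.

```python
def accuracy_at_k(scores, labels, k):
    """Fraction of examples whose true label is among the k highest scores.

    scores: list of per-class score lists, one list per example.
    labels: list of integer class indices, one per example.
    """
    if not scores:
        raise ValueError("need at least one example")
    correct = 0
    for example_scores, label in zip(scores, labels):
        # Class indices sorted from most to least confident.
        ranked = sorted(range(len(example_scores)),
                        key=lambda c: example_scores[c], reverse=True)
        if label in ranked[:k]:
            correct += 1
    return correct / len(scores)

# Toy data: 3 examples, 5 classes.
scores = [
    [0.10, 0.50, 0.20, 0.15, 0.05],  # true class 1 is ranked 1st
    [0.30, 0.10, 0.05, 0.35, 0.20],  # true class 4 is ranked 3rd
    [0.40, 0.30, 0.15, 0.10, 0.05],  # true class 4 is ranked 5th
]
labels = [1, 4, 4]

for k in (1, 3, 5):
    print(k, accuracy_at_k(scores, labels, k))
```

With these toy scores, accuracy at 1 is 1/3, accuracy at 3 is 2/3 and accuracy at 5 is 1.0, so each value of k yields its own independent measurement rather than a range between two ranks.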

Arkadeep-sophoIITG commented 6 years ago

Can we use the repo with other network architectures, like mobilenet_v2, or do we need to make appropriate changes? (https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py)