d4nst / RotNet

https://d4nst.github.io/2017/01/12/image-orientation/
MIT License
537 stars 187 forks source link

Train with smaller classes #50

Open KazuhideMimura opened 1 year ago

KazuhideMimura commented 1 year ago

Hi, thank you for sharing nice programs.

The following changes to utils.py allow training with fewer classes. Rotating with only 4 classes may be especially useful when training on a small dataset. There still seems to be something wrong in the visualization (function display_examples), but I was able to train with four classes.

def angle_difference(x, y, nb_classes=360):
    """
    Return the smallest angular distance, in degrees, between two
    rotation classes.

    x and y are class indices, where every class covers
    360 // nb_classes degrees. With the default nb_classes=360 an index
    is simply an angle in degrees. nb_classes must divide 360 evenly.
    """
    assert 360 % nb_classes == 0, 'nb_classes should be a divisor of 360'
    degrees_per_class = 360 // nb_classes
    # separation in degrees, measured in one direction around the circle
    separation = abs(x - y) * degrees_per_class
    # fold around the half turn so that e.g. 270 degrees counts as 90
    return 180 - abs(separation - 180)

def angle_error(y_true, y_pred):
    """
    Calculate the mean difference, in degrees, between the true angles
    and the predicted angles. Each angle is represented as a binary
    vector; the position of its largest entry is taken as the class
    index, and the width of y_pred is used as the number of classes.
    """
    true_classes = K.argmax(y_true)
    predicted_classes = K.argmax(y_pred)
    # the second dimension of the prediction gives the class count
    class_count = y_pred.shape[1]
    diff = angle_difference(true_classes, predicted_classes,
                            nb_classes=class_count)
    absolute_diff = K.cast(K.abs(diff), K.floatx())
    return K.mean(absolute_diff)

(omitted)

@class RotNetDataGenerator

def __init__(self, input, input_shape=None, color_mode='rgb', batch_size=64,
                one_hot=True, preprocess_func=None, rotate=True, crop_center=False,
                crop_largest_rect=False, shuffle=False, seed=None, nb_classes = 360): # nb_classes is added
        """
        Store the generator configuration.

        nb_classes is the number of rotation classes and must divide 360
        evenly; each class then covers 360 // nb_classes degrees, which is
        kept in self.unit_angle. The default of 360 gives one class per
        degree.

        The other keyword arguments assigned below are copied unchanged
        onto attributes of the same name. `input` and `seed` are not used
        in the visible part of the constructor; the remainder of the body
        is omitted from this excerpt.
        """
        # reject class counts that do not split the circle into equal integer steps
        assert 360 % nb_classes == 0, 'nb_classes should be a divisor of 360' # inserted

        # both start empty here; nothing in the visible part fills them in
        self.images = None
        self.filenames = None
        self.input_shape = input_shape
        self.color_mode = color_mode
        self.batch_size = batch_size
        self.one_hot = one_hot
        self.preprocess_func = preprocess_func
        self.rotate = rotate
        self.crop_center = crop_center
        self.crop_largest_rect = crop_largest_rect
        self.shuffle = shuffle
        self.nb_classes = nb_classes # added: number of rotation classes
        self.unit_angle = 360 // nb_classes # added: degrees covered by one class
        ...

(omitted)

def _get_batches_of_transformed_samples(self, index_array):
    """
    Build one batch of rotated images and their rotation labels.

    index_array holds the indices of the samples to put in the batch.
    Each sample is taken from self.images when self.filenames is None,
    and otherwise loaded from self.filenames with OpenCV. It is then
    rotated by a random multiple of self.unit_angle, or by 0 degrees
    when self.rotate is false. The label is the rotation expressed as
    a class index, i.e. the angle divided by self.unit_angle.

    The end of the method is omitted from this excerpt.
    """
    # create array to hold the images
    batch_x = np.zeros((len(index_array),) + self.input_shape, dtype='float32')
    # create array to hold the labels
    batch_y = np.zeros(len(index_array), dtype='float32')

    # iterate through the current batch
    for i, j in enumerate(index_array):
        if self.filenames is None:
            image = self.images[j]
        else:
            # imread flag: 1 when color_mode is 'rgb', 0 otherwise
            is_color = int(self.color_mode == 'rgb')
            image = cv2.imread(self.filenames[j], is_color)
            if is_color:
                # reorder the channels from BGR to RGB
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.rotate:
            # get a random angle that is a multiple of unit_angle and below 360
            rotation_angle = self.unit_angle * np.random.randint(self.nb_classes)
        else:
            rotation_angle = 0

        # generate the rotated image
        rotated_image = generate_rotated_image(
            image,
            rotation_angle,
            size=self.input_shape[:2],
            crop_center=self.crop_center,
            crop_largest_rect=self.crop_largest_rect
        )

        # add dimension to account for the channels if the image is greyscale
        if rotated_image.ndim == 2:
            rotated_image = np.expand_dims(rotated_image, axis=2)

        # store the image and label in their corresponding batches;
        # the label is the class index, not the angle in degrees
        batch_x[i] = rotated_image
        batch_y[i] = rotation_angle // self.unit_angle

    if self.one_hot:
        # convert the numerical labels to binary labels
        batch_y = to_categorical(batch_y, self.nb_classes) # modified
    else:
        # scale the class indices by the class count, giving values in [0, 1)
        batch_y /= self.nb_classes
    ...

(omitted)

def display_examples(model, input, num_images=5, size=None, crop_center=False,
                    crop_largest_rect=False, preprocess_func=None, save_path=None,
                    nb_classes = 360): # nb_class was added
    """
    Given a model that predicts the rotation angle of an image,
    and a NumPy array of images or a list of image paths, display
    the specified number of example images in three columns:
    Original, Rotated and Corrected.

    nb_classes is the number of rotation classes and must divide 360
    evenly; each example is rotated by a random multiple of
    360 // nb_classes degrees. It defaults to 360, so it has to be
    passed explicitly for a model trained with fewer classes.

    The plotting part of the function is omitted from this excerpt.
    """
    assert 360 % nb_classes == 0, 'nb_classes should be a divisor of 360' # added
    unit_angle = 360 // nb_classes # added: degrees covered by one class

    if isinstance(input, (np.ndarray)):
        images = input
        N, h, w = images.shape[:3]
        if not size:
            # fall back to the height and width of the given images
            size = (h, w)
        # NOTE(review): np.random.choice samples with replacement by
        # default, so the same image may be picked more than once
        indexes = np.random.choice(N, num_images)
        images = images[indexes, ...]
    else:
        # input is treated as a list of image paths; in this branch
        # `size` is left exactly as the caller passed it
        images = []
        filenames = input
        N = len(filenames)
        indexes = np.random.choice(N, num_images)
        for i in indexes:
            image = cv2.imread(filenames[i])
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            images.append(image)
        images = np.asarray(images)

    # x collects the rotated images, y the matching class indices
    x = []
    y = []
    for image in images:
        # random rotation that is a multiple of unit_angle and below 360
        rotation_angle = np.random.randint(nb_classes) * unit_angle
        rotated_image = generate_rotated_image(
            image,
            rotation_angle,
            size=size,
            crop_center=crop_center,
            crop_largest_rect=crop_largest_rect
        )
        x.append(rotated_image)
        y.append(rotation_angle // unit_angle)

    x = np.asarray(x, dtype='float32')
    y = np.asarray(y, dtype='float32')

    # add a channel axis when the stacked images have none
    if x.ndim == 3:
        x = np.expand_dims(x, axis=3)

    y = to_categorical(y, nb_classes)

    # keep an unpreprocessed copy of the rotated images
    x_rot = np.copy(x)

    if preprocess_func:
        x = preprocess_func(x)

    # reduce the true and predicted class vectors to class indices
    y = np.argmax(y, axis=1)
    y_pred = np.argmax(model.predict(x), axis=1)

    plt.figure(figsize=(10.0, 2 * num_images))

    title_fontdict = {
        'fontsize': 14,
        'fontweight': 'bold'
    }

    fig_number = 0
    for rotated_image, true_angle, predicted_angle in zip(x_rot, y, y_pred):
        # convert class indices to angles in degrees
        true_angle *= unit_angle # added
        predicted_angle *= unit_angle # added
        ...
HAsarvesh commented 8 months ago

Hi @KazuhideMimura, I tried to train the model for 4 classes, but it's giving ValueError: Error when checking target: expected fc360 to have shape (4,) but got array with shape (360,). Do we need to change any files other than utils.py?

KazuhideMimura commented 8 months ago

Hi @HAsarvesh, thank you for the report. Could you provide me with a detailed error log? I'd like to know where the error occurred.

HAsarvesh commented 8 months ago

Hi @KazuhideMimura, thanks for the reply. The error is fixed: I just passed 4 as the value of the nb_classes parameter of the RotNetDataGenerator class. But after training the model, I'm not getting the proper output when testing, i.e., the images are left as they are; they are not being rotated. I'm using custom data and giving it the raw images. Do I need to manipulate the images before training?

KazuhideMimura commented 8 months ago

You don't need to apply any special treatment to the original images. I think the possibilities are:

  1. Training did not go well
  2. nb_classes is not designated

For example, if you run the function display_examples...

# Example call for a model trained with 4 rotation classes.
display_examples(
    model, 
    test_filenames,
    num_images=num_images,
    size=(224, 224),
    crop_center=True,
    crop_largest_rect=True,
    preprocess_func=preprocess_input,
    nb_classes=4, # must be passed explicitly; display_examples defaults to 360
)

Please note that nb_classes will be set to 360 unless it is passed explicitly. This is because I didn't want to disturb the original concept.

HAsarvesh commented 8 months ago

Thanks for the clarification, I'll train it again and check.

KazuhideMimura commented 8 months ago

Here?

https://github.com/d4nst/RotNet/blob/a56ea59818bbdd76d4dd8d83b8bbbaae6a802310/correct_rotation.py#L32

    # Run the model over the images without applying any rotation (rotate=False).
    predictions = model.predict_generator(
        RotNetDataGenerator(
            image_paths,
            input_shape=(224, 224, 3),
            batch_size=64,
            one_hot=True,
            preprocess_func=preprocess_input,
            rotate=False,
            crop_largest_rect=True,
            crop_center=True,
            # added: take the class count from the `units` attribute of the
            # model's last layer, so the generator matches the model.
            # NOTE(review): untested, the author was not sure it works.
            nb_classes = model.layers[-1].units,
        ),
        val_samples=len(image_paths)
    )
HAsarvesh commented 8 months ago

Thanks @KazuhideMimura, I'll try this.