BUZZ-Blimps / CatchingBlimp

4 stars 3 forks source link

Train a new YOLO model (#46)

Status: Open — opened by Rhyme0730 1 month ago

Rhyme0730 commented 1 month ago

import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw

def resize_image(image_path, new_size):
    """Load the image at ``image_path`` and resize it to ``new_size``.

    Args:
        image_path: Path of the image file to read.
        new_size: Target ``(width, height)`` in pixels. The image is resized
            to exactly this size, so the aspect ratio is not preserved.

    Returns:
        A new ``PIL.Image.Image`` produced with Lanczos resampling.
    """
    # ``Image.open`` holds the file handle until the image is closed; the
    # context manager releases it. ``resize`` loads the pixel data and returns
    # an independent image, so the result remains valid after the source is
    # closed.
    with Image.open(image_path) as img:
        return img.resize(new_size, Image.LANCZOS)

def scale_bounding_boxes(labels, original_size, new_size):
    """Rescale YOLO-format label lines from ``original_size`` to ``new_size``.

    Each label line is ``"<class_id> <x_center> <y_center> <width> <height>"``
    with coordinates normalized to [0, 1] by the image dimensions. A plain
    (non-letterboxed) resize stretches the image and its boxes by the same
    per-axis factors, so normalized coordinates are unchanged. They are
    therefore emitted directly instead of being round-tripped through pixel
    space, which only introduced floating-point drift.

    Args:
        labels: Iterable of label lines (trailing newlines allowed). Blank or
            whitespace-only lines are skipped.
        original_size: ``(width, height)`` of the source image. Kept for
            interface compatibility; not needed for a plain resize.
        new_size: ``(width, height)`` of the resized image. Kept for
            interface compatibility; not needed for a plain resize.

    Returns:
        List of label lines, each terminated by ``"\\n"``.

    Raises:
        ValueError: If a non-blank line does not have exactly five fields or
            contains a non-numeric field.
    """
    scaled_labels = []
    for line_number, label in enumerate(labels, start=1):
        fields = label.split()
        if not fields:
            continue  # tolerate blank lines (e.g. a trailing empty line)
        if len(fields) != 5:
            raise ValueError(
                f"label line {line_number}: expected 5 fields "
                f"(class x y w h), got {len(fields)}: {label.strip()!r}"
            )
        class_id, x, y, w, h = map(float, fields)
        scaled_labels.append(f"{int(class_id)} {x} {y} {w} {h}\n")
    return scaled_labels

def draw_bounding_boxes(image, labels):
    """Draw YOLO-format boxes onto ``image`` in place and return it.

    Args:
        image: A ``PIL.Image.Image``. It is drawn on in place, so pass a copy
            if the original must stay untouched.
        labels: Iterable of ``"<class_id> <x_center> <y_center> <width> <height>"``
            lines with coordinates normalized by the image size. Blank lines
            are skipped.

    Returns:
        The same ``image`` object, with a 2-pixel red rectangle per box.
    """
    draw = ImageDraw.Draw(image)
    img_w, img_h = image.width, image.height
    for label in labels:
        fields = label.split()
        if not fields:
            continue  # tolerate blank lines in label files
        _class_id, x, y, w, h = map(float, fields)
        # Convert normalized center/size to absolute pixel corner coordinates.
        left = (x - w / 2) * img_w
        top = (y - h / 2) * img_h
        right = (x + w / 2) * img_w
        bottom = (y + h / 2) * img_h
        draw.rectangle([left, top, right, bottom], outline="red", width=2)
    return image

def _read_labels(label_path):
    """Return the non-blank lines of a label file, or [] if the file is missing."""
    try:
        with open(label_path, 'r') as f:
            return [line for line in f if line.strip()]
    except FileNotFoundError:
        # An image without a label file is treated as having no boxes rather
        # than aborting the whole dataset run.
        return []


def _save_comparison_plot(original_with_boxes, resized_with_boxes, plot_path):
    """Save a side-by-side figure of the annotated original and resized images."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    try:
        ax1.imshow(original_with_boxes)
        ax1.set_title('Original Image')
        ax1.axis('off')
        ax2.imshow(resized_with_boxes)
        ax2.set_title('Resized Image')
        ax2.axis('off')
        fig.tight_layout()
        fig.savefig(plot_path)
    finally:
        # Close this specific figure so memory does not grow per image,
        # even if drawing or saving fails.
        plt.close(fig)


def process_dataset(images_folder, labels_folder, output_folder, plots_folder, new_size, plot_results):
    """Resize every image in a YOLO dataset and write matching label files.

    For each ``.jpg``/``.jpeg``/``.png`` file (extension matched
    case-insensitively) in ``images_folder``, the label file with the same
    stem is read from ``labels_folder``. The resized image is written to
    ``<output_folder>/images`` under the same filename, and the rescaled
    labels to ``<output_folder>/labels/<stem>.txt``. An image with no label
    file gets an empty label file.

    Args:
        images_folder: Directory containing the source images.
        labels_folder: Directory containing ``<stem>.txt`` label files.
        output_folder: Destination root; ``images/`` and ``labels/`` are
            created inside it if missing.
        plots_folder: Directory for comparison plots; only created and used
            when ``plot_results`` is true.
        new_size: Target ``(width, height)`` in pixels.
        plot_results: If true, save ``plot_<stem>.png`` per image showing the
            boxes on the original and on the resized image.
    """
    out_images = os.path.join(output_folder, 'images')
    out_labels = os.path.join(output_folder, 'labels')
    # Create each directory independently: checking only output_folder left
    # the subfolders missing whenever output_folder already existed.
    os.makedirs(out_images, exist_ok=True)
    os.makedirs(out_labels, exist_ok=True)
    if plot_results:
        os.makedirs(plots_folder, exist_ok=True)

    for filename in sorted(os.listdir(images_folder)):
        stem, ext = os.path.splitext(filename)
        if ext.lower() not in ('.jpg', '.png', '.jpeg'):
            continue

        image_path = os.path.join(images_folder, filename)
        original_labels = _read_labels(os.path.join(labels_folder, stem + '.txt'))

        original_with_boxes = None
        with Image.open(image_path) as original_image:
            original_size = original_image.size
            if plot_results:
                # copy() loads the pixels, so the annotated copy outlives the file.
                original_with_boxes = draw_bounding_boxes(original_image.copy(), original_labels)

        resized_image = resize_image(image_path, new_size)
        scaled_labels = scale_bounding_boxes(original_labels, original_size, new_size)

        if plot_results:
            resized_with_boxes = draw_bounding_boxes(resized_image.copy(), scaled_labels)
            # Use the stem rather than filename[:-4], which broke on ".jpeg".
            plot_path = os.path.join(plots_folder, f'plot_{stem}.png')
            _save_comparison_plot(original_with_boxes, resized_with_boxes, plot_path)

        resized_image.save(os.path.join(out_images, filename))
        with open(os.path.join(out_labels, stem + '.txt'), 'w') as f:
            f.writelines(scaled_labels)

        print(f"Processed {filename}")

# Usage: adjust dataset_root (and the options below), then run this file
# directly. The guard keeps the dataset from being processed on import.
if __name__ == "__main__":
    dataset_root = '/home/sahaj/Downloads/150_YOLO_dataset'
    images_folder = os.path.join(dataset_root, 'images')
    labels_folder = os.path.join(dataset_root, 'labels')
    output_folder = os.path.join(dataset_root, 'resized')
    plots_folder = os.path.join(dataset_root, 'plots')
    plot_results = False
    new_size = (640, 640)  # (width, height) in pixels

    process_dataset(images_folder, labels_folder, output_folder, plots_folder, new_size, plot_results)