Bing-su / adetailer

Auto detecting, masking and inpainting with detection model.
GNU Affero General Public License v3.0
4.19k stars 327 forks source link

[Feature Request]: Classes support for basic YoloV8 models (DIFFERENT FROM WORLD MODELS) #715

Closed swimmingyoshi closed 1 month ago

swimmingyoshi commented 2 months ago

Is your feature request related to a problem? Please describe.

I have a trained YOLOv8 model that detects multiple classes. When I use it with adetailer, it always detects and details ALL detected areas.

Currently, adetailer only supports classes for "-world" models. I tried renaming the model to include "-world", but it always returned an error.

Describe the solution you'd like

Other detailers let you choose which classes to detail, so that the rest are ignored. For example, an image contains (head, hands, eyes), but I only want to detail (eyes, hands). Most other detailers work by detecting everything and then filtering the results to keep only the wanted areas.

Describe alternatives you've considered

I managed to find a workaround, but it involved changing the code inside ultralytics.py (I'm not 100% sure whether I broke anything).

To make the "classes" text input appear in the UI, you can simply rename your model to include "-world". That can be confusing, so you can instead edit ui.py so it also recognizes a new name.

In my example code, I used the word "Bean" because my models were named "YOLO_BeanV8.pt".

# \extensions\adetailer\aaaaaa\ui.py
def on_ad_model_update(model: str):
    """Toggle the "classes" textbox when the selected detection model changes.

    The textbox is shown (with a usage placeholder) only when the model name
    contains "-world" or "Bean"; for any other model it is hidden and its
    placeholder cleared.
    """
    accepts_class_filter = any(tag in model for tag in ("-world", "Bean"))
    if not accepts_class_filter:
        return gr.update(visible=False, placeholder="")
    return gr.update(
        visible=True,
        placeholder="Comma separated class names to detect, ex: 'person,cat'. default: COCO 80 classes",
    )
# \extensions\adetailer\adetailer\ultralytics.py
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, List, Union

import cv2
import numpy as np
from PIL import Image, ImageDraw
from torchvision.transforms.functional import to_pil_image

from adetailer import PredictOutput
from adetailer.common import create_mask_from_bbox

if TYPE_CHECKING:
    import torch
    from ultralytics import YOLO

def ultralytics_predict(
    model_path: str | Path,
    image: Image.Image,
    confidence: float = 0.3,
    device: str = "",
    classes: str = "",
) -> PredictOutput[float]:
    """Run an ultralytics YOLO model on *image* and return the kept detections.

    Parameters
    ----------
    model_path: str | Path
        Path to the YOLO weights file.
    image: Image.Image
        Input image.
    confidence: float
        Minimum confidence passed to the model as ``conf``.
    device: str
        Device string forwarded to the model call ("" lets ultralytics choose).
    classes: str
        Comma-separated class names to keep; empty keeps every class.

    Returns
    -------
    PredictOutput[float]
        Kept bounding boxes (xyxy lists), one PIL mask per box, and a preview
        image; an empty ``PredictOutput()`` when nothing is kept.
    """
    from ultralytics import YOLO

    model = YOLO(model_path)
    filtered_class_indices = apply_classes(model, model_path, classes)

    pred = model(image, conf=confidence, device=device)
    result = pred[0]

    bboxes = result.boxes.xyxy.cpu().numpy()
    keep = None
    if filtered_class_indices is not None:
        # Boolean row mask: True for detections whose class was requested.
        # Reused below so boxes and segmentation masks stay aligned.
        keep = np.isin(result.boxes.cls.cpu().numpy(), filtered_class_indices)
        bboxes = bboxes[keep]

    if bboxes.size == 0:
        return PredictOutput()
    bboxes = bboxes.tolist()

    if result.masks is None:
        # Detection-only model: derive rectangular masks from the boxes.
        masks = create_mask_from_bbox(bboxes, image.size)
    else:
        mask_data = result.masks.data if keep is None else result.masks.data[keep]
        masks = mask_to_pil(mask_data, image.size)

    # Preview shows only the classes that survived filtering.
    preview = create_filtered_preview(result, filtered_class_indices, image)

    return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)

def _match_class(requested: str, available: list[tuple[int, str]]) -> int | None:
    """Return the class index whose lowercased name best matches *requested*.

    An exact name match wins; otherwise the first name that contains, or is
    contained in, *requested* is used. Returns None when nothing matches.
    """
    for idx, name in available:
        if name == requested:
            return idx
    for idx, name in available:
        if requested in name or name in requested:
            return idx
    return None


def apply_classes(model: YOLO, model_path: str | Path, classes: str) -> Union[List[int], None]:
    """Resolve a comma-separated class filter against the model's class names.

    Each requested name is compared case-insensitively with ``model.names``
    (see ``_match_class``). Duplicate matches are dropped, keeping request order.

    For YOLO-World style models (``model.model`` has ``set_classes`` and the
    file stem contains "-world" or "Bean"), the matched names are pushed into
    the model via ``set_classes``. After that call the model's predicted class
    ids are positions in the *new* list, so ``0..len(matched)-1`` is returned
    instead of the original indices.

    Parameters
    ----------
    model: YOLO
        Loaded ultralytics model; only ``model.names`` and ``model.model`` are used.
    model_path: str | Path
        Weights path; its stem is used to recognise world models.
    classes: str
        Comma-separated class names, e.g. "eyes, hands". Empty means no filter.

    Returns
    -------
    list[int] | None
        Class ids to keep, or None when no filtering should be applied
        (empty request or no requested class matched).
    """
    if not classes:
        print("No classes specified. Using all classes.")
        return None

    parsed = [c.strip().lower() for c in classes.split(",") if c.strip()]
    available = [(idx, name.lower()) for idx, name in model.names.items()]

    matched_indices: List[int] = []
    for requested_class in parsed:
        idx = _match_class(requested_class, available)
        if idx is None:
            print(f"Warning: No match found for '{requested_class}'")
        elif idx not in matched_indices:
            matched_indices.append(idx)

    # Must be checked before set_classes: an empty list would leave a world
    # model with no classes at all.
    if not matched_indices:
        print("No classes were matched. Using all classes.")
        return None

    stem = Path(model_path).stem
    if hasattr(model.model, "set_classes") and ("-world" in stem or "Bean" in stem):
        matched_classes = [model.names[i] for i in matched_indices]
        print(f"Setting classes for YOLOv8-world model: {matched_classes}")
        model.model.set_classes(matched_classes)
        # The head now only knows these classes, numbered by list position.
        return list(range(len(matched_classes)))

    return matched_indices

def create_filtered_preview(prediction, filtered_class_indices: Union[List[int], None], original_image: Image.Image) -> Image.Image:
    """Draw labelled red boxes for the kept detections on a copy of the image.

    Parameters
    ----------
    prediction
        A single ultralytics result; ``prediction.boxes.xyxy``, ``.cls``,
        ``.conf`` and ``prediction.names`` are read.
    filtered_class_indices: list[int] | None
        Class ids to draw; None draws every detection.
    original_image: Image.Image
        Source image; it is copied, never modified.

    Returns
    -------
    Image.Image
        The annotated copy.
    """
    preview = original_image.copy()
    draw = ImageDraw.Draw(preview)
    keep = None if filtered_class_indices is None else set(filtered_class_indices)
    boxes = prediction.boxes

    # Convert to plain Python numbers once so PIL receives floats, not tensors,
    # and iterate each box together with its *own* confidence.
    for box, cls, conf in zip(boxes.xyxy.tolist(), boxes.cls.tolist(), boxes.conf.tolist()):
        class_id = int(cls)
        if keep is not None and class_id not in keep:
            continue

        draw.rectangle(box, outline="red", width=2)

        label = f"{prediction.names[class_id]} {conf:.2f}"
        # Clamp so labels of boxes touching the top edge stay inside the image.
        draw.text((box[0], max(box[1] - 10, 0)), label, fill="red")

    return preview

def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]:
    """
    Convert a stack of segmentation masks into resized grayscale PIL images.

    Parameters
    ----------
    masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
        The device can be CUDA, but `to_pil_image` takes care of that.

    shape: tuple[int, int]
        (W, H) of the original image

    Returns
    -------
    list[Image.Image]
        One mode "L" image per mask, each resized to `shape`.
    """
    n = masks.shape[0]
    return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]

Additional context

No response

Bing-su commented 2 months ago

The YOLO model's filename shows what it predicts. If we support selecting certain classes, how does the user know which class names are supported by their YOLO model? I tried to make this easier for the user without adding more "need to know" for the user, but I couldn't. I can't think of a better way than the current structure.

This feature isn't implemented because of UI problems, not implementation difficulties.