mikel-brostrom / boxmot

BoxMOT: pluggable SOTA tracking modules for segmentation, object detection and pose estimation models
GNU Affero General Public License v3.0
6.71k stars 1.71k forks source link

Loading Custom ReID models and Poor Performance #1618

Closed MahejabeenNidhi closed 1 month ago

MahejabeenNidhi commented 1 month ago

Search before asking

Question

Thank you so much for your work. I have two questions.

Question 1

I trained a ReID model with torchreid and it gives me a .pth file. I noticed that the StrongSORT repo always uses a .pt file. Would I be able to just use the .pth file that my code below generates? This is especially important in my case, as I am tracking non-human/non-vehicle objects.

import os
import torch
import string
import random
import argparse
import torchreid
from glob import glob
import os.path as osp

class NewDataset(torchreid.data.datasets.ImageDataset):
    """Custom torchreid image dataset read from a single flat directory.

    The train, query and gallery splits are all built from ``dataset_dir``.
    Each ``*.jpg`` file name is split on ``_``: the first field carries the
    camera id and the second the identity (pid), each after a one-character
    prefix that is discarded.
    """

    # Directory scanned for ``*.jpg`` files. It is a class attribute, so it
    # has to be assigned before the dataset is instantiated.
    dataset_dir = ''

    def __init__(self, root='', **kwargs):
        # ``root`` is accepted for signature compatibility but is not used by
        # this class; ``dataset_dir`` locates the images instead.
        self.train_dir = self.dataset_dir
        self.query_dir = self.dataset_dir
        self.gallery_dir = self.dataset_dir

        train = self.process_dir(self.train_dir, isQuery=False)
        query = self.process_dir(self.query_dir, isQuery=True)
        gallery = self.process_dir(self.gallery_dir, isQuery=False)

        super(NewDataset, self).__init__(train, query, gallery, **kwargs)

    @staticmethod
    def _parse_name(img_path):
        """Return ``(pid, camid)`` parsed from the file name of *img_path*."""
        # basename handles the platform's path separator; splitting on '/'
        # breaks on Windows-style paths.
        fields = osp.basename(img_path).split('_')
        pid = int(fields[1][1:])
        camid = int(fields[0][1:])
        return pid, camid

    def process_dir(self, dir_path, isQuery, relabel=True):
        """Build the ``(img_path, pid, camid)`` list for one split.

        Args:
            dir_path: Directory containing the ``*.jpg`` images.
            isQuery: When true, every camera id is shifted by 10.
            relabel: When true, pids are remapped to contiguous labels
                starting at 0.

        Returns:
            A list of ``(img_path, pid, camid)`` tuples, one per image.
        """
        img_paths = glob(osp.join(dir_path, '*.jpg'))
        parsed = [(img_path,) + self._parse_name(img_path) for img_path in img_paths]

        # Sort the pids so the pid -> label mapping is explicit and
        # reproducible instead of depending on set iteration order.
        unique_pids = sorted({pid for _, pid, _ in parsed})
        pid2label = {pid: label for label, pid in enumerate(unique_pids)}

        data = []
        for img_path, pid, camid in parsed:
            if isQuery:
                # Query and gallery are built from the same files; the offset
                # gives query samples distinct camera ids. Presumably this
                # keeps the evaluator from discarding same-pid/same-camera
                # gallery matches -- TODO confirm against torchreid's
                # evaluation code.
                camid += 10

            if relabel:
                pid = pid2label[pid]

            data.append((img_path, pid, camid))

        return data

def get_parser(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv``, as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # (flag, type, default, help) for every supported option.
    options = (
        ('--name', str, 'osnet_x1_0', "ReID model name"),
        ('--img_h', int, 256, "image height"),
        ('--img_w', int, 128, "image width"),
        ('--bs', int, 32, "batch size"),
        ('--optim', str, 'adam', "optimizer"),
        ('--lr', float, 0.003, "learning rate"),
        ('--lr_sch', str, "single_step", "learning rate scheduler"),
        ('--step', int, 5, "learning rate scheduler's step size"),
        ('--epochs', int, 20, "epoch count for the training loop"),
        ('--eval_freq', int, 5, "evaluation frequency"),
        ('--data_path', str, 'path/to/data', "path to the custom dataset"),
        ('--save_path', str, 'path/to/save', "path to save the model"),
    )
    for flag, arg_type, default, help_text in options:
        parser.add_argument(flag, type=arg_type, default=default, help=help_text)

    return parser.parse_args(argv)

def main(args):
    """Train a torchreid model on the custom dataset and save its weights.

    Args:
        args: Parsed options providing ``name``, ``img_h``, ``img_w``, ``bs``,
            ``optim``, ``lr``, ``lr_sch``, ``step``, ``epochs``,
            ``eval_freq``, ``data_path`` and ``save_path``.

    The final state dict is written to
    ``<save_path>/<name>_model.pth``.
    """
    # Create the output directory up front: otherwise a missing directory is
    # only discovered by torch.save after the whole training run.
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    NewDataset.dataset_dir = args.data_path
    # The dataset is registered under a freshly generated random name
    # (1-25 uppercase letters/digits) on every call.
    dataset_name = ''.join(random.choices(string.ascii_uppercase + string.digits, k=random.randint(1, 25)))
    torchreid.data.register_image_dataset(dataset_name, NewDataset)

    datamanager = torchreid.data.ImageDataManager(
        sources=dataset_name,
        height=args.img_h,
        width=args.img_w,
        batch_size_train=args.bs,
        batch_size_test=100,
        transforms=["random_flip", "random_crop"]
    )

    model = torchreid.models.build_model(
        name=args.name,
        num_classes=datamanager.num_train_pids,
        loss="triplet",
        pretrained=True
    ).to(device).train()

    optimizer = torchreid.optim.build_optimizer(
        model,
        optim=args.optim,
        lr=args.lr,
    )

    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer,
        lr_scheduler=args.lr_sch,
        stepsize=args.step,
    )

    engine = torchreid.engine.ImageTripletEngine(
        datamanager,
        model,
        optimizer=optimizer,
        scheduler=scheduler,
        margin=0.3,  # by default 0.3
        weight_t=1,  # weight for triplet loss
        weight_x=50,  # weight for softmax loss
    )

    engine.run(
        save_dir=f"log/{args.name}",
        max_epoch=args.epochs,
        eval_freq=args.eval_freq,
        print_freq=50,
        test_only=False
    )

    # Save the trained model
    model_save_path = os.path.join(args.save_path, f"{args.name}_model.pth")
    torch.save(model.state_dict(), model_save_path)
    print(f"Trained model saved at: {model_save_path}")

# Script entry point: parse the command line, then run training.
if __name__ == "__main__":
    args = get_parser()
    main(args=args)

Question 2

I used the built-in tracker in YOLOv8 and the performance was much better. I used custom weights for my unique class, and it worked a lot better than when I used this repo.

python tracking/track.py --source ../LabelledTracking/D05-AA-01_LM --yolo-model tracking/weights/yolov8_best.pt --tracking-method botsort --imgsz 2160 --save --save-txt

I get better performance when using the following code:

import cv2
import numpy as np
from ultralytics import YOLO
import os
import random

# Function to generate a random RGB color
def random_color():
    """Return a tuple of three random integers, each in the range 0-255."""
    return tuple(random.randint(0, 255) for _ in range(3))

# Function to draw bounding boxes and center points
def draw_boxes_and_centers(frame, boxes, clss, track_ids, confs, img_size, object_colors):
    """Draw each tracked box, its track id and its center point on a copy of *frame*.

    Args:
        frame: Image to annotate; it is copied, not modified.
        boxes: Iterable of ``(x1, y1, x2, y2)`` box corners.
        clss: Per-box class ids. Accepted for interface compatibility; not
            used for drawing.
        track_ids: Per-box track ids, parallel to ``boxes``.
        confs: Per-box confidences. Accepted for interface compatibility; not
            used for drawing.
        img_size: ``(width, height)`` used to normalise the center points.
        object_colors: Mapping of track id to color. It is updated in place
            with a new random color for every track id not yet present.

    Returns:
        ``(img, centers, object_colors)`` where ``centers`` is a list of
        ``(center_x / width, center_y / height, track_id)`` tuples.
    """
    img = frame.copy()
    centers = []

    for box, raw_track_id in zip(boxes, track_ids):
        track_id = int(raw_track_id)

        x1, y1, x2, y2 = map(int, box)
        center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2

        # Get color for the object (generate a new color if not assigned)
        if track_id not in object_colors:
            object_colors[track_id] = random_color()
        color = object_colors[track_id]

        # Draw bounding box
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

        # Draw track ID and center point
        label = f"{track_id}"
        cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        cv2.circle(img, (center_x, center_y), 3, color, -1)

        centers.append((center_x / img_size[0], center_y / img_size[1], track_id))

    return img, centers, object_colors

# Function to create an image showing center point tracks
def create_tracks_image(centers, img_size, object_colors):
    """Render the path of every track as a polyline on a white canvas.

    Args:
        centers: Iterable of ``(center_x, center_y, track_id)`` where the
            coordinates are scaled back to pixels by multiplying with
            ``img_size``.
        img_size: ``(width, height)`` of the canvas.
        object_colors: Mapping of track id to the color used for its path.

    Returns:
        A ``uint8`` image of shape ``(height, width, 3)``.
    """
    width, height = img_size[0], img_size[1]
    canvas = np.full((height, width, 3), 255, dtype=np.uint8)

    # Group the pixel-space points by track, preserving their input order.
    paths = {}
    for norm_x, norm_y, track_id in centers:
        pixel = (int(norm_x * width), int(norm_y * height))
        paths.setdefault(track_id, []).append(pixel)

    for track_id, pts in paths.items():
        color = object_colors[track_id]
        # Connect consecutive points of the track.
        for start, end in zip(pts, pts[1:]):
            cv2.line(canvas, start, end, color, 2)
        # Mark the most recent position with a filled dot.
        cv2.circle(canvas, pts[-1], 3, color, -1)

    return canvas

# Run the YOLO tracker over every image in ``image_folder`` (in sorted file
# name order) and write one annotated frame per image to ``output_directory``.
output_directory = "../../path"
os.makedirs(output_directory, exist_ok=True)

model_path = '../../runs/detect/train17/weights/best.pt'
model = YOLO(model_path)  # Load a custom trained model
names = model.model.names

image_folder = "../../dataset"
# Explicit raise instead of assert: asserts are stripped under ``python -O``.
if not os.path.exists(image_folder):
    raise FileNotFoundError("Image folder not found")

object_colors = {}
all_centers = []
frame_count = 0
for filename in sorted(os.listdir(image_folder)):
    if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        continue

    img_path = os.path.join(image_folder, filename)
    frame = cv2.imread(img_path)
    if frame is None:
        raise RuntimeError(f"Failed to read image {filename}")

    h, w, _ = frame.shape
    img_size = (w, h)

    # Extract prediction results
    results = model.track(frame, persist=True, verbose=False)
    detections = results[0].boxes
    if detections.id is None:
        # No track ids on this frame (e.g. nothing detected/tracked):
        # calling .int() on None would crash, so emit the frame unannotated.
        boxes, clss, track_ids, confs = [], [], [], []
    else:
        boxes = detections.xyxy.cpu().numpy()
        clss = detections.cls.cpu().tolist()
        track_ids = detections.id.int().cpu().tolist()
        confs = detections.conf.float().cpu().tolist()

    # Draw bounding boxes and center points
    annotated_frame, centers, object_colors = draw_boxes_and_centers(frame, boxes, clss, track_ids, confs, img_size, object_colors)
    all_centers.extend(centers)

    # Frames are numbered consecutively so output order matches input order.
    frame_filename = f"output_{frame_count}.jpg"
    frame_path = os.path.join(output_directory, frame_filename)
    cv2.imwrite(frame_path, annotated_frame)

    frame_count += 1

The detections when I use Boxmot look very poor comparatively, making me wonder if the weight loaded properly. What would you advise?

Thank you so much for your time!

mikel-brostrom commented 1 month ago

Would I be able to just use the .pth file that my code below generates?

Yup, no problem. Just change the suffix from .pth to .pt

github-actions[bot] commented 1 month ago

👋 Hello, this issue has been automatically marked as stale because it has not had recent activity. Please note it will be closed if no further activity occurs. Feel free to inform us of any other issues you discover or feature requests that come to mind in the future. Pull Requests (PRs) are also always welcomed!