Deci-AI / super-gradients

Easily train or fine-tune SOTA computer vision models with one open source training library. The home of Yolo-NAS.
https://www.supergradients.com
Apache License 2.0
4.59k stars 509 forks source link

Classification custom dataset #1418

Open Allison2324 opened 1 year ago

Allison2324 commented 1 year ago

Please help me rewrite this training code for a classification task with a ResNet18 model on my own custom dataset. Dataset structure: train: images + CSV label file; val: images + CSV label file.

from super_gradients.training import models
import cv2
import torch

# Use the GPU when CUDA is available and fall back to the CPU otherwise,
# so the reported device matches the machine the script actually runs on.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)
from super_gradients.training import Trainer

# Root directory handed to the Trainer as ckpt_root_dir.
CHECKPOINT_DIR = 'checkpoints'
# Experiment name handed to the Trainer as experiment_name.
# NOTE(review): presumably checkpoints end up under CHECKPOINT_DIR/EXPERIMENT_NAME — confirm.
EXPERIMENT_NAME = 'yolo_nas_s_25e'
# Trainer instance whose train() method is called at the end of the script.
trainer = Trainer(experiment_name=EXPERIMENT_NAME, ckpt_root_dir=CHECKPOINT_DIR)

from super_gradients.training import dataloaders
from super_gradients.training.dataloaders.dataloaders import coco_detection_yolo_format_train, \
    coco_detection_yolo_format_val

# Image/label folder names for each split plus the list of class names.
# NOTE(review): the folders are presumably resolved relative to 'data_dir' — confirm.
# NOTE(review): 'classes' is empty here; the model, loss and metric in this
# script all derive their class count from len(dataset_params['classes']).
dataset_params = dict(
    data_dir='dataset',
    train_images_dir='train/images',
    train_labels_dir='train/labels',
    val_images_dir='val/images',
    val_labels_dir='val/labels',
    test_images_dir='test/images',
    test_labels_dir='test/labels',
    classes=[],
)


def _split_dataset_params(images_key, labels_key):
    """Return the per-split dataset settings, reading the folders from dataset_params."""
    return {
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params[images_key],
        'labels_dir': dataset_params[labels_key],
        'classes': dataset_params['classes'],
    }


# Training dataloader: batches of 16, loaded with a single worker.
train_data = coco_detection_yolo_format_train(
    dataset_params=_split_dataset_params('train_images_dir', 'train_labels_dir'),
    dataloader_params={'batch_size': 16, 'num_workers': 1},
)

# Validation dataloader: same batch size and worker count as training.
val_data = coco_detection_yolo_format_val(
    dataset_params=_split_dataset_params('val_images_dir', 'val_labels_dir'),
    dataloader_params={'batch_size': 16, 'num_workers': 1},
)

from super_gradients.training import models

# Architecture name passed to models.get().
MODEL_ARCH = 'yolo_nas_s'
# Instantiate the model, requesting the "coco" pretrained weights.
# NOTE(review): num_classes is len(dataset_params['classes']); the 'classes'
# list in dataset_params is empty, so this evaluates to 0 until it is filled in.
model = models.get(MODEL_ARCH,
                   num_classes=len(dataset_params['classes']),
                   pretrained_weights="coco"
                   )
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback

# Training configuration; passed to trainer.train() as training_params.
train_params = {
    'silent_mode': False,
    "average_best_models": True,
    # Learning-rate warmup settings.
    # NOTE(review): presumably a 3-epoch linear ramp starting at 1e-6 — confirm against the library docs.
    "warmup_mode": "linear_epoch_step",
    "warmup_initial_lr": 1e-6,
    "lr_warmup_epochs": 3,
    # Learning-rate schedule settings ("cosine" mode with a final ratio of 0.1).
    "initial_lr": 5e-4,
    "lr_mode": "cosine",
    "cosine_final_lr_ratio": 0.1,
    # Optimizer choice and its parameters.
    "optimizer": "Adam",
    "optimizer_params": {"weight_decay": 0.0001},
    "zero_weight_decay_on_bias_and_bn": True,
    # EMA settings.
    # NOTE(review): presumably an exponential moving average of the model weights — confirm.
    "ema": True,
    "ema_params": {"decay": 0.9, "decay_type": "threshold"},
    "max_epochs": 25,
    "mixed_precision": True,
    # Loss object; its class count comes from the same 'classes' list as the model.
    "loss": PPYoloELoss(
        use_static_assigner=False,
        # NOTE: num_classes needs to be defined here
        num_classes=len(dataset_params['classes']),
        reg_max=16
    ),
    # Validation metrics; the metric is given its own post-prediction callback.
    "valid_metrics_list": [
        DetectionMetrics_050(
            score_thres=0.1,
            top_k_predictions=300,
            # NOTE: num_classes needs to be defined here
            num_cls=len(dataset_params['classes']),
            normalize_targets=True,
            post_prediction_callback=PPYoloEPostPredictionCallback(
                score_threshold=0.01,
                nms_top_k=1000,
                max_predictions=300,
                nms_threshold=0.7
            )
        )
    ],
    # Name of the metric to watch.
    # NOTE(review): presumably used to pick the best checkpoint — confirm.
    "metric_to_watch": 'mAP@0.50'
}
# Start training with the model, training configuration, and train/validation dataloaders.
trainer.train(model=model,
              training_params=train_params,
              train_loader=train_data,
              valid_loader=val_data)

Tell me if I need to change my dataset structure.

bit-scientist commented 1 year ago

Hi, @Allison2324, rather than telling you how to write your code, it would be better if you try it yourself until you get an error that you can't solve. Please organize your dataset as shown here and run your code. Ask for help with your actual error. Thank you.