Deci-AI / data-gradients

Computer Vision dataset analysis
Apache License 2.0
293 stars 33 forks source link

Fixes a crash when some labels/boxes are NaN #230

Closed BloodAxe closed 8 months ago

BloodAxe commented 8 months ago

Below is a test dataset that reproduces the issue. See the attached generated report, Report.pdf.

If a batch/sample containing NaN is detected, it is excluded from processing entirely. However, it is still tracked and added to the list of errors, which is now also written to the PDF report.

import numpy as np
import torch

from data_gradients import DetectionAnalysisManager

# Seed the global torch RNG so the synthetic repro dataset (and therefore the
# reported crash) is identical on every run of this script.
torch.manual_seed(0)

# 4 training images of 3x100x100 random noise; the sample built with
# `num_labels = k` carries k label rows of 5 random integer values in [0, 10),
# stored as float32. (Presumably class id + 4 box coordinates -- confirm
# against the label format DetectionAnalysisManager expects.)
train_samples = [
    (
        torch.randn((3, 100, 100), dtype=torch.float32),
        torch.randint(0, 10, (num_labels, 5), dtype=torch.float32),
    )
    for num_labels in range(1, 5)
]
# Put NaN into column 2 of the first two label rows of the last sample so the
# analysis hits the NaN-handling path this script is meant to reproduce.
train_samples[3][1][0:2, 2] = np.nan

# Validation split: 4 images of 3x100x100 random noise, where the sample built
# with `rows = k` gets k label rows of 5 random integers in [0, 10) as float32.
valid_samples = [
    (
        torch.randn((3, 100, 100), dtype=torch.float32),
        torch.randint(0, 10, (rows, 5), dtype=torch.float32),
    )
    for rows in (1, 2, 3, 4)
]
# Write NaN into columns 2 and 3 of the first label row of the third sample.
valid_samples[2][1][0, 2:4] = np.nan
# Configure a detection analysis over both splits with ten classes named
# "class_1" ... "class_10"; batches_early_stop=None is passed explicitly
# (presumably "no batch limit" -- confirm in the data_gradients docs).
manager = DetectionAnalysisManager(
    train_data=train_samples,
    val_data=valid_samples,
    report_title="Detection Test",
    class_names=[f"class_{idx}" for idx in range(1, 11)],
    batches_early_stop=None,
)
# Run the analysis. Per the PR description, samples containing NaN should be
# excluded and reported as errors rather than crashing the run.
manager.run()