Deci-AI / super-gradients

Easily train or fine-tune SOTA computer vision models with one open source training library. The home of Yolo-NAS.
https://www.supergradients.com
Apache License 2.0

RuntimeError: CUDA error: device-side assert triggered when training #1537

Closed 714273725 closed 11 months ago

714273725 commented 11 months ago

🐛 Describe the bug

I'm not sure whether this is a bug. When I try to train, I get RuntimeError: CUDA error: device-side assert triggered, and the GPU does not seem to be doing any work during training.
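For anyone hitting the same error: CUDA kernels run asynchronously, so the Python stack trace attached to a device-side assert often points at a later, unrelated call. A minimal sketch of one common way to get a more accurate trace, assuming the variable is set before any CUDA work starts (placing it at the very top of the training script is my suggestion, not something from the original report):

```python
import os

# Make CUDA kernel launches synchronous so the failing op is reported
# where it actually happens. This must be set before the first CUDA call,
# so it goes ahead of "import torch" / "import super_gradients".
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# ... the rest of the training script (imports, train()) follows here.
```

Training is slower with this setting, so it is only worth enabling while debugging.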


The training code:

from super_gradients import setup_device
from super_gradients.training import Trainer
from super_gradients.training.datasets import YoloDarknetFormatDetectionDataset
from super_gradients.training.losses import YoloXDetectionLoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
from super_gradients.training.transforms.transforms import DetectionRescale
from torch.utils.data import DataLoader
from super_gradients.training.utils.detection_utils import DetectionCollateFN
from super_gradients.training import models
from prettyformatter import pprint

CHECKPOINT_DIR = '.\\test'

def train():
    trainer = Trainer(experiment_name='breast_detection',
                      ckpt_root_dir=CHECKPOINT_DIR)
    train_dataset = YoloDarknetFormatDetectionDataset(data_dir="datasets\\yolo-darknet",
                                                      images_dir="train2007", labels_dir="train2007",
                                                      classes=["belt_logo"],
                                                      transforms=[DetectionRescale(output_shape=(1280, 640))])
    val_dataset = YoloDarknetFormatDetectionDataset(data_dir="datasets\\yolo-darknet",
                                                    images_dir="val2007", labels_dir="val2007",
                                                    classes=["belt_logo"],
                                                    transforms=[DetectionRescale(output_shape=(1280, 640))])
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=8,
                                  collate_fn=DetectionCollateFN())
    val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=8,
                                collate_fn=DetectionCollateFN())
    model = models.get("yolox_n", pretrained_weights="coco", num_classes=1)
    train_params = {
        # Silent mode is disabled, so training progress is printed
        'silent_mode': False,
        "average_best_models": True,
        "warmup_mode": "linear_epoch_step",
        "warmup_initial_lr": 1e-6,
        "lr_warmup_epochs": 3,
        "initial_lr": 5e-4,
        "lr_mode": "cosine",
        "cosine_final_lr_ratio": 0.1,
        "optimizer": "Adam",
        "optimizer_params": {"weight_decay": 0.0001},
        "zero_weight_decay_on_bias_and_bn": True,
        "ema": True,
        "ema_params": {"decay": 0.9, "decay_type": "threshold"},
        # Train for 20 epochs
        "max_epochs": 20,
        "mixed_precision": True,
        "loss": YoloXDetectionLoss(
            # NOTE: num_classes needs to be defined here
            num_classes=1,
            strides=[8, 16, 32]
        ),
        "valid_metrics_list": [
            DetectionMetrics_050(
                score_thres=0.1,
                top_k_predictions=300,
                # NOTE: num_classes needs to be defined here
                num_cls=1,
                normalize_targets=True,
                post_prediction_callback=PPYoloEPostPredictionCallback(
                    score_threshold=0.01,
                    nms_top_k=1000,
                    max_predictions=300,
                    nms_threshold=0.7
                )
            )
        ],
        "metric_to_watch": "mAP@0.50:0.95"
    }
    # train_params["num_gpus"] = 1
    pprint(train_params, json=True)
    setup_device(num_gpus=1)
    trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=val_dataloader)

if __name__ == "__main__":
    train()
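A device-side assert during detection training is frequently caused by a label whose class index falls outside the range the loss expects (here num_classes=1, so only class id 0 would be valid). Whether that was the cause in this report is not confirmed by the thread. The sketch below is a hypothetical stdlib-only helper, `find_bad_labels`, that scans Darknet-style label files (one `class cx cy w h` line per box, normalized coordinates) and lists suspicious lines; the function name and the exact checks are my own assumptions, not part of super-gradients.

```python
from pathlib import Path


def find_bad_labels(labels_dir, num_classes):
    """Return (file name, line number, reason) for every suspicious label line.

    Hypothetical helper: assumes Darknet-style *.txt files where each
    non-blank line is "class_id cx cy w h" with coordinates in [0, 1].
    """
    problems = []
    for path in sorted(Path(labels_dir).glob("*.txt")):
        lines = path.read_text().splitlines()
        for line_no, line in enumerate(lines, start=1):
            parts = line.split()
            if not parts:
                continue  # blank lines are ignored
            if len(parts) != 5:
                problems.append((path.name, line_no, "expected 5 fields"))
                continue
            try:
                class_id = int(parts[0])
                coords = [float(p) for p in parts[1:]]
            except ValueError:
                problems.append((path.name, line_no, "unparsable field"))
                continue
            if not 0 <= class_id < num_classes:
                problems.append(
                    (path.name, line_no, f"class id {class_id} out of range")
                )
            elif any(c < 0.0 or c > 1.0 for c in coords):
                problems.append((path.name, line_no, "coordinate outside [0, 1]"))
    return problems
```

With the layout from the script above, it could be called as `find_bad_labels("datasets\\yolo-darknet\\train2007", num_classes=1)`; an empty list means none of these particular checks fired, not that the dataset is necessarily fine.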

Versions

packages in environment at C:\Users\Ge\anaconda3\envs\py310:

Name Version Build Channel
absl-py 2.0.0 pypi_0 pypi
accelerate 0.20.3 pypi_0 pypi
aiofiles 23.1.0 pypi_0 pypi
aiohttp 3.8.4 pypi_0 pypi
aiosignal 1.3.1 pypi_0 pypi
alabaster 0.7.13 pypi_0 pypi
altair 5.0.1 pypi_0 pypi
antlr4-python3-runtime 4.9.3 pypi_0 pypi
anyio 3.7.0 pypi_0 pypi
appdirs 1.4.4 pypi_0 pypi
async-timeout 4.0.2 pypi_0 pypi
attrs 23.1.0 pypi_0 pypi
babel 2.13.0 pypi_0 pypi
boto3 1.28.64 pypi_0 pypi
botocore 1.31.64 pypi_0 pypi
build 1.0.3 pypi_0 pypi
bzip2 1.0.8 h8ffe710_4 conda-forge
ca-certificates 2022.12.7 h5b45459_0 conda-forge
cachetools 5.3.1 pypi_0 pypi
certifi 2023.5.7 pypi_0 pypi
charset-normalizer 3.1.0 pypi_0 pypi
click 8.1.3 pypi_0 pypi
colorama 0.4.6 pypi_0 pypi
coverage 5.3.1 pypi_0 pypi
cpm-kernels 1.0.11 pypi_0 pypi
deprecated 1.2.14 pypi_0 pypi
docutils 0.17.1 pypi_0 pypi
einops 0.3.2 pypi_0 pypi
exceptiongroup 1.1.1 pypi_0 pypi
fastapi 0.99.1 pypi_0 pypi
ffmpy 0.3.0 pypi_0 pypi
filelock 3.12.2 pypi_0 pypi
frozenlist 1.3.3 pypi_0 pypi
fsspec 2023.6.0 pypi_0 pypi
future 0.18.3 pypi_0 pypi
google-auth 2.23.0 pypi_0 pypi
google-auth-oauthlib 1.0.0 pypi_0 pypi
gradio 3.35.2 pypi_0 pypi
gradio-client 0.2.7 pypi_0 pypi
grpcio 1.58.0 pypi_0 pypi
h11 0.14.0 pypi_0 pypi
httpcore 0.17.2 pypi_0 pypi
httpx 0.24.1 pypi_0 pypi
huggingface-hub 0.15.1 pypi_0 pypi
hydra-core 1.3.2 pypi_0 pypi
idna 3.4 pypi_0 pypi
imagesize 1.4.1 pypi_0 pypi
jinja2 3.1.2 pypi_0 pypi
jmespath 1.0.1 pypi_0 pypi
json-tricks 3.16.1 pypi_0 pypi
jsonschema 4.17.3 pypi_0 pypi
labelimg 1.8.6 pypi_0 pypi
latex2mathml 3.76.0 pypi_0 pypi
libffi 3.4.2 h8ffe710_5 conda-forge
libsqlite 3.40.0 hcfcfb64_0 conda-forge
libzlib 1.2.13 hcfcfb64_4 conda-forge
linkify-it-py 2.0.2 pypi_0 pypi
loguru 0.7.2 pypi_0 pypi
lxml 4.9.3 pypi_0 pypi
mako 1.2.4 pypi_0 pypi
markdown 3.4.3 pypi_0 pypi
markdown-it-py 2.2.0 pypi_0 pypi
markupsafe 2.1.2 pypi_0 pypi
mdit-py-plugins 0.3.3 pypi_0 pypi
mdtex2html 1.2.0 pypi_0 pypi
mdurl 0.1.2 pypi_0 pypi
mpmath 1.3.0 pypi_0 pypi
multidict 6.0.4 pypi_0 pypi
ncnn 1.0.20220420 pypi_0 pypi
ninja 1.11.1 pypi_0 pypi
numpy 1.23.0 pypi_0 pypi
oauthlib 3.2.2 pypi_0 pypi
omegaconf 2.3.0 pypi_0 pypi
onnx 1.13.0 pypi_0 pypi
onnx-simplifier 0.4.10 pypi_0 pypi
onnxruntime 1.13.1 pypi_0 pypi
opencv-python 4.8.1.78 pypi_0 pypi
openssl 3.1.0 hcfcfb64_0 conda-forge
orjson 3.9.1 pypi_0 pypi
pandas 2.0.3 pypi_0 pypi
pillow 9.3.0 pypi_0 pypi
pip 23.0.1 pyhd8ed1ab_0 conda-forge
pip-tools 7.3.0 pypi_0 pypi
platformdirs 3.2.0 pypi_0 pypi
portalocker 2.8.2 pypi_0 pypi
prettyformatter 2.0.13 pypi_0 pypi
protobuf 3.20.3 pypi_0 pypi
psutil 5.9.5 pypi_0 pypi
py-cpuinfo 9.0.0 pypi_0 pypi
pyasn1 0.5.0 pypi_0 pypi
pyasn1-modules 0.3.0 pypi_0 pypi
pycocotools 2.0.6 pypi_0 pypi
pycuda 2022.1+cuda116 pypi_0 pypi
pydantic 1.10.10 pypi_0 pypi
pydeprecate 0.3.2 pypi_0 pypi
pydub 0.25.1 pypi_0 pypi
pygments 2.15.1 pypi_0 pypi
pyparsing 2.4.5 pypi_0 pypi
pyproject-hooks 1.0.0 pypi_0 pypi
pyqt5 5.15.9 pypi_0 pypi
pyqt5-plugins 5.15.9.2.3 pypi_0 pypi
pyqt5-qt5 5.15.2 pypi_0 pypi
pyqt5-sip 12.12.2 pypi_0 pypi
pyqt5-tools 5.15.9.3.3 pypi_0 pypi
pyrsistent 0.19.3 pypi_0 pypi
python 3.10.10 h4de0772_0_cpython conda-forge
python-dotenv 1.0.0 pypi_0 pypi
python-multipart 0.0.6 pypi_0 pypi
pytools 2022.1.14 pypi_0 pypi
pytz 2023.3 pypi_0 pypi
pywin32 306 pypi_0 pypi
pyyaml 6.0 pypi_0 pypi
qt5-applications 5.15.2.2.3 pypi_0 pypi
qt5-tools 5.15.2.1.3 pypi_0 pypi
rapidfuzz 3.4.0 pypi_0 pypi
regex 2023.6.3 pypi_0 pypi
requests 2.31.0 pypi_0 pypi
requests-oauthlib 1.3.1 pypi_0 pypi
rich 13.5.3 pypi_0 pypi
rsa 4.9 pypi_0 pypi
s3transfer 0.7.0 pypi_0 pypi
safetensors 0.3.1 pypi_0 pypi
seaborn 0.13.0 pypi_0 pypi
semantic-version 2.10.0 pypi_0 pypi
sentencepiece 0.1.99 pypi_0 pypi
setuptools 68.2.2 pypi_0 pypi
sniffio 1.3.0 pypi_0 pypi
snowballstemmer 2.2.0 pypi_0 pypi
sphinx 4.0.3 pypi_0 pypi
sphinx-rtd-theme 1.3.0 pypi_0 pypi
sphinxcontrib-applehelp 1.0.4 pypi_0 pypi
sphinxcontrib-devhelp 1.0.2 pypi_0 pypi
sphinxcontrib-htmlhelp 2.0.1 pypi_0 pypi
sphinxcontrib-jquery 4.1 pypi_0 pypi
sphinxcontrib-jsmath 1.0.1 pypi_0 pypi
sphinxcontrib-qthelp 1.0.3 pypi_0 pypi
sphinxcontrib-serializinghtml 1.1.5 pypi_0 pypi
starlette 0.27.0 pypi_0 pypi
stringcase 1.2.0 pypi_0 pypi
super-gradients 3.2.1 pypi_0 pypi
sympy 1.12 pypi_0 pypi
tabulate 0.9.0 pypi_0 pypi
tensorboard 2.14.0 pypi_0 pypi
tensorboard-data-server 0.7.1 pypi_0 pypi
termcolor 1.1.0 pypi_0 pypi
thop 0.1.1-2209072238 pypi_0 pypi
tk 8.6.12 h8ffe710_0 conda-forge
tokenizers 0.13.3 pypi_0 pypi
tomli 2.0.1 pypi_0 pypi
toolz 0.12.0 pypi_0 pypi
torch 2.0.1+cu118 pypi_0 pypi
torchaudio 2.0.2+cu118 pypi_0 pypi
torchmetrics 0.8.0 pypi_0 pypi
torchvision 0.15.2+cu118 pypi_0 pypi
tqdm 4.65.0 pypi_0 pypi
transformers 4.30.2 pypi_0 pypi
treelib 1.6.1 pypi_0 pypi
typing-extensions 4.5.0 pypi_0 pypi
tzdata 2023.3 pypi_0 pypi
uc-micro-py 1.0.2 pypi_0 pypi
ucrt 10.0.22621.0 h57928b3_0 conda-forge
ultralytics 8.0.197 pypi_0 pypi
urllib3 1.26.16 pypi_0 pypi
uvicorn 0.22.0 pypi_0 pypi
vc 14.3 hb6edc58_10 conda-forge
vs2015_runtime 14.34.31931 h4c5c07a_10 conda-forge
websockets 11.0.3 pypi_0 pypi
werkzeug 2.3.7 pypi_0 pypi
wheel 0.41.2 pypi_0 pypi
win32-setctime 1.1.0 pypi_0 pypi
wrapt 1.15.0 pypi_0 pypi
xz 5.2.6 h8d14728_0 conda-forge
yarl 1.9.2 pypi_0 pypi
yolox 0.3.0 dev_0

714273725 commented 11 months ago

I tried what is suggested in issue #1400, and it works for me.