Deci-AI / super-gradients

Easily train or fine-tune SOTA computer vision models with one open-source training library. The home of Yolo-NAS.
https://www.supergradients.com
Apache License 2.0
4.53k stars, 490 forks

Detection Visualization Callback #1109

Closed (sewty closed this issue 1 year ago)

sewty commented 1 year ago

💡 Your Question

Hey all! I'm trying to set up a detection visualization callback in code (that is, I'm not using configuration files). I started from https://github.com/Deci-AI/super-gradients/blob/master/documentation/source/ObjectDetection.md and defined the following phase callback inside my training hyperparameter dictionary:

    "phase_callbacks": [
        DetectionVisualizationCallback(
            phase=Phase.VALIDATION_EPOCH_END,
            freq=1,
            post_prediction_callback=YoloPostPredictionCallback(),
            classes=data_config["names"],
        )
    ],

The model completes training and validation, but when I load TensorBoard with the events file specified, I can only find my standard validation metrics (DetectionMetrics_050). No image data appears to be available.

Does anyone have a working example of a DetectionVisualizationCallback set up in code rather than in a config file? Also, is there any way to get the image data produced by this callback (assuming it IS producing something and not failing) other than through TensorBoard?

NOTE: Since it is not clear from the versions list below, I am running a local copy of super-gradients 3.1.1, edited to fix the issue described in #999.

Versions

absl-py==1.4.0
alabaster==0.7.13
antlr4-python3-runtime==4.9.3
asttokens==2.2.1
astunparse==1.6.3
attrs==23.1.0
Babel==2.12.1
backcall==0.2.0
beautifulsoup4==4.12.2
blinker==1.6.2
boto3==1.26.142
botocore==1.29.142
Brotli==1.0.9
build==0.10.0
cachetools==5.3.1
certifi==2023.5.7
chardet==5.1.0
charset-normalizer==3.1.0
click==8.1.3
colorama==0.4.6
coloredlogs==15.0.1
contourpy==1.0.7
coverage==5.3.1
cycler==0.11.0
daemonize==2.5.0
debugpy==1.6.7
decorator==5.1.1
Deprecated==1.2.14
distlib==0.3.6
docutils==0.17.1
einops==0.3.2
exceptiongroup==1.1.1
executing==1.2.0
filelock==3.12.0
Flask==2.3.2
Flask-Compress==1.13
flatbuffers==23.5.26
fonttools==4.39.4
future==0.18.3
gast==0.4.0
google-auth==2.19.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
grpcio==1.54.2
guildai==0.9.0
h5py==3.8.0
hiplot==0.1.33
humanfriendly==10.0
hydra-core==1.3.2
idna==3.4
imagesize==1.4.1
iniconfig==2.0.0
ipython==8.13.2
itsdangerous==2.1.2
jax==0.4.10
jedi==0.18.2
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.2.0
json-tricks==3.16.1
jsonschema==4.17.3
keras==2.12.0
kiwisolver==1.4.4
libclang==16.0.0
Markdown==3.4.3
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.6
mdurl==0.1.2
ml-dtypes==0.1.0
mpmath==1.3.0
natsort==8.3.1
networkx==3.1
numpy==1.23.0
oauthlib==3.2.2
omegaconf==2.3.0
onnx==1.13.0
onnx-simplifier==0.4.28
onnxruntime==1.13.1
opencv-python==4.7.0.72
opt-einsum==3.3.0
packaging==23.1
pandas==2.0.2
parso==0.8.3
pickleshare==0.7.5
Pillow==9.5.0
pip-tools==6.13.0
pkginfo==1.9.6
platformdirs==3.5.1
pluggy==1.0.0
prompt-toolkit==3.0.38
protobuf==3.20.3
psutil==5.9.5
pure-eval==0.2.2
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycocotools==2.0.4
pyDeprecate==0.3.2
Pygments==2.15.1
pyparsing==2.4.5
pyproject_hooks==1.0.0
pyreadline3==3.4.1
pyrsistent==0.19.3
pytest==7.3.1
python-dateutil==2.8.2
pytz==2023.3
PyYAML==6.0
rapidfuzz==3.0.0
requests==2.31.0
requests-oauthlib==1.3.1
rich==13.3.5
rsa==4.9
s3transfer==0.6.1
scikit-learn==1.2.2
scipy==1.10.1
six==1.16.0
snowballstemmer==2.2.0
soupsieve==2.4.1
Sphinx==4.0.3
sphinx-rtd-theme==1.2.1
sphinxcontrib-applehelp==1.0.4
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.1
sphinxcontrib-jquery==4.1
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
stack-data==0.6.2
stringcase==1.2.0
-e git+https://github.com/Deci-AI/super-gradients.git@b969d7c0761c49aae819b921ecde394b7288867c#egg=super_gradients
sympy==1.12
tabview==1.4.4
tensorboard==2.12.3
tensorboard-data-server==0.7.0
tensorflow==2.12.0
tensorflow-estimator==2.12.0
tensorflow-intel==2.12.0
tensorflow-io-gcs-filesystem==0.31.0
termcolor==1.1.0
threadpoolctl==3.1.0
tomli==2.0.1
torch==2.0.1+cu117
torchinfo==1.8.0
torchmetrics==0.8.0
torchvision==0.15.2+cu117
tqdm==4.65.0
traitlets==5.9.0
treelib==1.6.1
typing_extensions==4.6.2
tzdata==2023.3
urllib3==1.26.16
virtualenv==20.23.0
wcwidth==0.2.6
Werkzeug==2.3.4
windows-curses==2.3.1
wrapt==1.14.1

NatanBagrov commented 1 year ago

Try looking at the "Images" tab; I believe you're in the "Scalars" tab. Let me know if that helps :)

sewty commented 1 year ago

@NatanBagrov Thanks for the suggestion, but no, the images are definitely not appearing in TensorBoard, not even in the "Images" tab. To narrow down the problem, I've taken TensorBoard out of the loop entirely for now.

It appears that my callback is not generating any images at all. I set up a small experiment in which I create a dataloader and an undo-preprocessing function (see https://github.com/Deci-AI/super-gradients/blob/master/documentation/source/ObjectDetection.md), then run the following:

    model = models.get("yolo_nas_s", pretrained_weights="coco", num_classes=data_config["nc"]).to(device)

    imgs, targets = next(iter(train_data))
    imgs, targets = imgs.cuda(), targets.cuda()
    preds = YoloPostPredictionCallback(conf=0.1, iou=0.6)(model(imgs))

    DetectionVisualization.visualize_batch(
        imgs, preds, targets,
        batch_name="train",
        class_names=data_config["names"],
        checkpoint_dir="/dvc",
        gt_alpha=0.5,
        undo_preprocessing_func=undo_img_prep,
    )

So, if I understand correctly, the preds = YoloPostPredictionCallback(...) line essentially uses model to run inference on my images (imgs). DetectionVisualization.visualize_batch should then take the predictions, draw them on the images, and save the results to the specified checkpoint_dir.

Unfortunately, after running all of this, my checkpoint_dir remains empty (no images have been saved!).

Edit: I'm now using super-gradients 3.1.2 instead of the custom version I mentioned before. The issue persists on both versions.
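In the experiment above, model(imgs) produces raw candidate boxes and the post-prediction callback turns them into final detections. A minimal pure-Python sketch of what the conf and iou arguments control, assuming a plain confidence filter followed by greedy, class-agnostic non-maximum suppression. The names iou and postprocess are hypothetical; this is not super-gradients' implementation, which works on batched tensors and may apply NMS per class.

```python
from typing import List, Tuple

# Hypothetical box format for this sketch: (x1, y1, x2, y2) in pixels.
Box = Tuple[float, float, float, float]


def iou(a: Box, b: Box) -> float:
    # Intersection-over-union of two axis-aligned boxes.
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def postprocess(boxes: List[Box], scores: List[float],
                conf: float = 0.1, iou_thr: float = 0.6) -> List[int]:
    """Return indices of kept detections: confidence filter, then greedy NMS."""
    # 1) Drop low-confidence candidates (the role of conf=...).
    candidates = [i for i, s in enumerate(scores) if s >= conf]
    # 2) Visit the survivors from highest to lowest score.
    candidates.sort(key=lambda i: scores[i], reverse=True)
    kept: List[int] = []
    for i in candidates:
        # 3) Keep a box only if it does not overlap an already-kept box
        #    by more than iou_thr (the role of iou=...).
        if all(iou(boxes[i], boxes[k]) <= iou_thr for k in kept):
            kept.append(i)
    return kept


boxes = [(0, 0, 10, 10), (1, 1, 10, 10), (20, 20, 30, 30), (0, 0, 5, 5)]
scores = [0.9, 0.8, 0.7, 0.05]
print(postprocess(boxes, scores, conf=0.1, iou_thr=0.6))  # [0, 2]
```

Box 1 overlaps box 0 with an IoU of 0.81 and is suppressed, while box 3 never reaches NMS because its score is below conf. If every candidate is filtered out, there are simply no predicted boxes to draw.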

NatanBagrov commented 1 year ago

Thanks for the details. @Louis-Dupont?

karl-joan commented 1 year ago

@sewty I got it working by setting the phase to Phase.VALIDATION_BATCH_END. This is because DetectionVisualizationCallback only draws when the current batch index matches the batch_idx given at initialization. Namely, it performs the following check:

if context.epoch % self.freq == 0 and context.batch_idx == self.batch_idx:

If you are using Phase.VALIDATION_EPOCH_END, you need to set batch_idx to the index of your last batch.
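A minimal pure-Python simulation of that check makes the failure mode concrete. It rests on two assumptions: that at VALIDATION_EPOCH_END the context's batch_idx still holds the index of the last validation batch (which is what the advice above implies), and that the callback's batch_idx argument defaults to 0 when it is not passed, as in the original snippet. FakeContext, should_visualize, last_batch_index, and fired_batches are hypothetical names for illustration, not super-gradients APIs.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class FakeContext:
    # Hypothetical stand-in for the context object SG passes to callbacks;
    # only the two fields used by the quoted check are modelled.
    epoch: int
    batch_idx: int


def should_visualize(context: FakeContext, freq: int, target_batch_idx: int) -> bool:
    # Same condition as the check quoted above.
    return context.epoch % freq == 0 and context.batch_idx == target_batch_idx


def last_batch_index(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    # Index of the final batch a loader yields for num_samples items.
    if drop_last:
        num_batches = num_samples // batch_size
    else:
        num_batches = math.ceil(num_samples / batch_size)
    return num_batches - 1


def fired_batches(phase: str, num_batches: int, target_batch_idx: int,
                  epoch: int = 0, freq: int = 1) -> List[int]:
    """Batch indices at which the callback would draw images in one validation epoch."""
    if phase == "VALIDATION_BATCH_END":
        # Invoked after every batch: batch_idx = 0, 1, ..., num_batches - 1.
        contexts = [FakeContext(epoch, i) for i in range(num_batches)]
    elif phase == "VALIDATION_EPOCH_END":
        # Assumption: invoked once, with batch_idx still holding the last batch's index.
        contexts = [FakeContext(epoch, num_batches - 1)]
    else:
        raise ValueError(f"unsupported phase: {phase}")
    return [c.batch_idx for c in contexts if should_visualize(c, freq, target_batch_idx)]


last = last_batch_index(num_samples=1000, batch_size=16)  # 63 batches -> last index 62
print(fired_batches("VALIDATION_EPOCH_END", last + 1, target_batch_idx=0))     # []
print(fired_batches("VALIDATION_BATCH_END", last + 1, target_batch_idx=0))     # [0]
print(fired_batches("VALIDATION_EPOCH_END", last + 1, target_batch_idx=last))  # [62]
```

With the default index, the epoch-end variant never matches unless the validation set fits in a single batch, which is consistent with an empty "Images" tab. For a map-style PyTorch DataLoader, len(loader) already returns the number of batches (respecting drop_last), so len(valid_loader) - 1 is the index to pass when keeping VALIDATION_EPOCH_END.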

Also, be aware of which model you are using: I am experiencing some bugs with YOLO-NAS, so I had to modify the super_gradients code to get the predictions visualized.

BloodAxe commented 1 year ago

The detection visualization callback is currently being worked on by @shaydeci. I expect it will be made publicly available in upcoming versions of SG.