Deci-AI / super-gradients

Easily train or fine-tune SOTA computer vision models with one open source training library. The home of Yolo-NAS.
https://www.supergradients.com
Apache License 2.0
4.54k stars 496 forks

YOLO NAS: inference time depending on training? #1960

Open ullsen opened 5 months ago

ullsen commented 5 months ago

💡 Your Question

Hello everybody and first of all thanks to the super_gradients team for the great work!

I have a theory question about the YOLO NAS object detection models. I am training on a custom dataset and checking the CPU inference time after training. I observed that longer model training leads to longer inference times, which surprises me a lot. I thought inference time is mainly determined by the number of weights, which should be the same for a given model size. Instead, I see the inference time increase the longer I train the model, even though the model sizes remain more or less the same.

For training the models I used an identical setup (same input image size, model size, pre-processing pipeline, training hardware, etc.) with different numbers of epochs (I tried both training from scratch and resuming a model by loading weights). I tried the medium and the large model and saw the same behaviour. After converting the models to ONNX I see the same effect.

Let me know if you need more details and thanks for the help!

Versions

No response

ofrimasad commented 5 months ago

Hi @ullsen .

Thank you for your kind words.

My main suspect here is NMS.

If you are exporting your model with NMS (the default behavior, which can be overridden), then you are benchmarking the NMS along with the model.

The NMS is an iterative algorithm that starts with thresholding and sorting. When the model is better trained, more predictions will get higher confidence, and more predictions will cross the threshold and be included in the sorting part and the iterative part of the NMS.
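To make the argument above concrete, here is a minimal pure-Python sketch of greedy NMS that counts how much work it does. It is not the implementation used by super_gradients or by the exported ONNX graph; the function names, the counter and the toy boxes are all illustrative assumptions. It only shows that the same boxes with higher scores produce more candidates and more IoU checks.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(
    boxes: List[Box],
    scores: List[float],
    score_threshold: float,
    iou_threshold: float,
) -> Tuple[List[int], int, int]:
    """Return (kept indices, number of candidates, number of IoU checks)."""
    # Thresholding: only boxes above the confidence threshold survive.
    candidates = [i for i, s in enumerate(scores) if s > score_threshold]
    # Sorting: highest confidence first.
    candidates.sort(key=lambda i: scores[i], reverse=True)

    kept: List[int] = []
    iou_checks = 0
    # Iterative part: each candidate is compared with the boxes kept so far.
    for i in candidates:
        suppressed = False
        for j in kept:
            iou_checks += 1
            if iou(boxes[i], boxes[j]) > iou_threshold:
                suppressed = True
                break
        if not suppressed:
            kept.append(i)
    return kept, len(candidates), iou_checks


if __name__ == "__main__":
    # Two pairs of heavily overlapping boxes (toy data).
    boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30), (21, 21, 31, 31)]
    early_scores = [0.6, 0.1, 0.2, 0.1]   # hypothetical "short training" scores
    late_scores = [0.9, 0.8, 0.85, 0.7]   # hypothetical "long training" scores
    print(greedy_nms(boxes, early_scores, 0.5, 0.5))  # ([0], 1, 0)
    print(greedy_nms(boxes, late_scores, 0.5, 0.5))   # ([0, 2], 4, 4)
```

In this toy case the confident scores raise the candidate count from 1 to 4 and the IoU checks from 0 to 4, while the boxes themselves are unchanged.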

To validate this assumption, you can export with the following flag:

model.export(... , postprocessing=False, ...)

You can also try calibrating the following params of the export function:

num_pre_nms_predictions: (int) Number of predictions to keep before NMS.
nms_threshold: (float) NMS threshold for the exported model.
confidence_threshold: (float) Confidence threshold for the exported model.

Hope that helps.

ullsen commented 5 months ago

Hi @ofrimasad and thanks for the quick reply!

I tried both of your suggestions, but unfortunately there is no difference to observe.

Is the model architecture adapted during training, or does NAS refer to an architecture optimisation prior to training?

I am open to further tests.

ullsen

shaydeci commented 5 months ago

@ullsen the model architecture is not adapted during training. There should be no difference in the time the feed-forward pass takes. However, as @ofrimasad stated, the number of predictions might vary throughout training, which might explain differences in forward pass times. It would be helpful if you could share your code and environment details.
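When comparing checkpoints, it may help to time them with one fixed harness (same input, warm-up runs, median over many runs) so that measurement noise is not mistaken for a training effect. The sketch below uses only the standard library; `benchmark` and `fake_forward` are hypothetical names, and `fake_forward` is a stand-in workload, not a real model call.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(fn: Callable[[], object], warmup: int = 5, runs: int = 30) -> Dict[str, float]:
    """Call fn() repeatedly and return timing statistics in milliseconds."""
    # Warm-up calls are not timed (caches, lazy initialisation, thread pools).
    for _ in range(warmup):
        fn()

    timings_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1000.0)

    return {
        "median_ms": statistics.median(timings_ms),
        "min_ms": min(timings_ms),
        "max_ms": max(timings_ms),
    }


def fake_forward() -> int:
    """Hypothetical stand-in for one inference call on a fixed input."""
    return sum(i * i for i in range(10_000))


if __name__ == "__main__":
    print(benchmark(fake_forward))
```

In practice `fake_forward` would be replaced by a zero-argument callable that runs one checkpoint on a fixed image, for example a lambda around an ONNX Runtime session's `run` call. Timing an export with `postprocessing=False` and one with NMS using the same harness would show which stage the extra time belongs to.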

ullsen commented 5 months ago

Sure! I am using a notebook instance on AWS with the PyTorch conda environment.

environment

Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_kmp_llvm conda-forge
absl-py 2.1.0 pypi_0 pypi
alabaster 0.7.16 pypi_0 pypi
albumentations 1.3.1 pypi_0 pypi
aniso8601 9.0.1 pypi_0 pypi
annotated-types 0.6.0 pyhd8ed1ab_0 conda-forge
ansi2html 1.9.1 pypi_0 pypi
antlr4-python3-runtime 4.9.3 pypi_0 pypi
anyio 4.3.0 pyhd8ed1ab_0 conda-forge
arabic-reshaper 3.0.0 pypi_0 pypi
argon2-cffi 23.1.0 pyhd8ed1ab_0 conda-forge
argon2-cffi-bindings 21.2.0 py310h2372a71_4 conda-forge
arrow 1.3.0 pyhd8ed1ab_0 conda-forge
asn1crypto 1.5.1 pypi_0 pypi
asttokens 2.4.1 pyhd8ed1ab_0 conda-forge
async-lru 2.0.4 pyhd8ed1ab_0 conda-forge
attrs 23.2.0 pyh71513ae_0 conda-forge
autovizwidget 0.21.0 pypi_0 pypi
aws-ofi-nccl 1.7.4 aws_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
awscli 1.32.78 pypi_0 pypi
babel 2.14.0 pyhd8ed1ab_0 conda-forge
beautifulsoup4 4.12.3 pyha770c72_0 conda-forge
blas 2.116 mkl conda-forge
blas-devel 3.9.0 16_linux64_mkl conda-forge
bleach 6.1.0 pyhd8ed1ab_0 conda-forge
blinker 1.7.0 pypi_0 pypi
bokeh 3.3.4 pyhd8ed1ab_0 conda-forge
boto3 1.34.78 pypi_0 pypi
botocore 1.34.78 pypi_0 pypi
brotli 1.1.0 hd590300_1 conda-forge
brotli-bin 1.1.0 hd590300_1 conda-forge
brotli-python 1.1.0 py310hc6cd4ac_1 conda-forge
build 1.2.1 pypi_0 pypi
bzip2 1.0.8 hd590300_5 conda-forge
c-ares 1.26.0 hd590300_0 conda-forge
ca-certificates 2024.2.2 hbcca054_0 conda-forge
cached-property 1.5.2 hd8ed1ab_1 conda-forge
cached_property 1.5.2 pyha770c72_1 conda-forge
captum 0.6.0 pyhd8ed1ab_0 conda-forge
certifi 2024.2.2 pyhd8ed1ab_0 conda-forge
cffi 1.16.0 py310h2fee648_0 conda-forge
charset-normalizer 3.3.2 pyhd8ed1ab_0 conda-forge
click 8.1.7 pypi_0 pypi
cloudpickle 2.2.1 pypi_0 pypi
cmake 3.26.4 hcfe8598_0 conda-forge
colorama 0.4.4 pyh9f0ad1d_0 conda-forge
coloredlogs 15.0.1 pypi_0 pypi
comm 0.2.1 pyhd8ed1ab_0 conda-forge
contextlib2 21.6.0 pypi_0 pypi
contourpy 1.2.0 py310hd41b1e2_0 conda-forge
coverage 5.3.1 pypi_0 pypi
cryptography 42.0.4 py310h75e40e8_0 conda-forge
cssselect2 0.7.0 pypi_0 pypi
cuda-cudart 12.1.105 0 nvidia
cuda-cupti 12.1.105 0 nvidia
cuda-libraries 12.1.0 0 nvidia
cuda-nvrtc 12.1.105 0 nvidia
cuda-nvtx 12.1.105 0 nvidia
cuda-opencl 12.3.101 0 nvidia
cuda-runtime 12.1.0 0 nvidia
cycler 0.12.1 pyhd8ed1ab_0 conda-forge
data-gradients 0.3.2 pypi_0 pypi
debugpy 1.8.1 py310hc6cd4ac_0 conda-forge
decorator 5.1.1 pyhd8ed1ab_0 conda-forge
defusedxml 0.7.1 pyhd8ed1ab_0 conda-forge
deprecated 1.2.14 pypi_0 pypi
dill 0.3.8 pypi_0 pypi
docker 6.1.3 pypi_0 pypi
docutils 0.16 py310hff52083_4 conda-forge
dparse 0.6.3 pypi_0 pypi
einops 0.3.2 pypi_0 pypi
entrypoints 0.4 pyhd8ed1ab_0 conda-forge
environment-kernels 1.2.0 pypi_0 pypi
exceptiongroup 1.2.0 pyhd8ed1ab_2 conda-forge
executing 2.0.1 pyhd8ed1ab_0 conda-forge
expat 2.5.0 hcb278e6_1 conda-forge
ffmpeg 4.2 h3fd9d12_1 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
filelock 3.13.1 pyhd8ed1ab_0 conda-forge
flask 3.0.2 pypi_0 pypi
flask-restful 0.3.10 pypi_0 pypi
flatbuffers 24.3.25 pypi_0 pypi
fonttools 4.49.0 py310h2372a71_0 conda-forge
fqdn 1.5.1 pyhd8ed1ab_0 conda-forge
freetype 2.12.1 h267a509_2 conda-forge
fsspec 2024.2.0 pypi_0 pypi
future 1.0.0 pypi_0 pypi
gettext 0.21.1 h27087fc_0 conda-forge
gmp 6.3.0 h59595ed_0 conda-forge
gmpy2 2.1.2 py310h3ec546c_1 conda-forge
gnutls 3.6.15 he1e5248_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
google-pasta 0.2.0 pypi_0 pypi
grpcio 1.62.1 pypi_0 pypi
gssapi 1.8.3 pypi_0 pypi
gym 0.26.2 pypi_0 pypi
gym-notices 0.0.8 pypi_0 pypi
h11 0.14.0 pyhd8ed1ab_0 conda-forge
h2 4.1.0 pyhd8ed1ab_0 conda-forge
hdijupyterutils 0.21.0 pypi_0 pypi
hpack 4.0.0 pyh9f0ad1d_0 conda-forge
html5lib 1.1 pypi_0 pypi
httpcore 1.0.4 pyhd8ed1ab_0 conda-forge
httpx 0.27.0 pyhd8ed1ab_0 conda-forge
humanfriendly 10.0 pypi_0 pypi
hwloc 2.9.2 h2bc3f7f_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
hydra-core 1.3.2 pypi_0 pypi
hyperframe 6.0.1 pyhd8ed1ab_0 conda-forge
icu 73.2 h59595ed_0 conda-forge
idna 3.6 pyhd8ed1ab_0 conda-forge
imagededup 0.3.2 pypi_0 pypi
imageio 2.34.0 pyh4b66e23_0 conda-forge
imagesize 1.4.1 pypi_0 pypi
importlib-metadata 6.11.0 pypi_0 pypi
importlib_metadata 7.0.1 hd8ed1ab_0 conda-forge
importlib_resources 6.1.1 pyhd8ed1ab_0 conda-forge
ipykernel 6.29.2 pyhd33586a_0 conda-forge
ipython 8.22.0 pyh707e725_0 conda-forge
ipywidgets 8.1.2 pyhd8ed1ab_0 conda-forge
isoduration 20.11.0 pyhd8ed1ab_0 conda-forge
itsdangerous 2.1.2 pypi_0 pypi
jedi 0.19.1 pyhd8ed1ab_0 conda-forge
jinja2 3.1.3 pyhd8ed1ab_0 conda-forge
jmespath 1.0.1 pyhd8ed1ab_0 conda-forge
joblib 1.3.2 pypi_0 pypi
json-tricks 3.16.1 pypi_0 pypi
json5 0.9.17 pyhd8ed1ab_0 conda-forge
jsonpointer 2.4 py310hff52083_3 conda-forge
jsonschema 4.21.1 pyhd8ed1ab_0 conda-forge
jsonschema-specifications 2023.12.1 pyhd8ed1ab_0 conda-forge
jsonschema-with-format-nongpl 4.21.1 pyhd8ed1ab_0 conda-forge
jupyter 1.0.0 pypi_0 pypi
jupyter-console 6.6.3 pypi_0 pypi
jupyter-lsp 2.2.2 pyhd8ed1ab_0 conda-forge
jupyter_client 8.6.0 pyhd8ed1ab_0 conda-forge
jupyter_core 5.7.1 py310hff52083_0 conda-forge
jupyter_events 0.9.0 pyhd8ed1ab_0 conda-forge
jupyter_server 2.12.5 pyhd8ed1ab_0 conda-forge
jupyter_server_terminals 0.5.2 pyhd8ed1ab_0 conda-forge
jupyterlab 4.1.2 pyhd8ed1ab_0 conda-forge
jupyterlab_pygments 0.3.0 pyhd8ed1ab_1 conda-forge
jupyterlab_server 2.25.3 pyhd8ed1ab_0 conda-forge
jupyterlab_widgets 3.0.10 pyhd8ed1ab_0 conda-forge
keyutils 1.6.1 h166bdaf_0 conda-forge
kiwisolver 1.4.5 py310hd41b1e2_1 conda-forge
krb5 0.5.1 pypi_0 pypi
lame 3.100 h166bdaf_1003 conda-forge
lazy-loader 0.4 pypi_0 pypi
lcms2 2.16 hb7c19ff_0 conda-forge
ld_impl_linux-64 2.40 h41732ed_0 conda-forge
lerc 4.0.0 h27087fc_0 conda-forge
libblas 3.9.0 16_linux64_mkl conda-forge
libbrotlicommon 1.1.0 hd590300_1 conda-forge
libbrotlidec 1.1.0 hd590300_1 conda-forge
libbrotlienc 1.1.0 hd590300_1 conda-forge
libcblas 3.9.0 16_linux64_mkl conda-forge
libcublas 12.1.0.26 0 nvidia
libcufft 11.0.2.4 0 nvidia
libcufile 1.8.1.2 0 nvidia
libcurand 10.3.4.107 0 nvidia
libcurl 8.5.0 hca28451_0 conda-forge
libcusolver 11.4.4.55 0 nvidia
libcusparse 12.0.2.55 0 nvidia
libdeflate 1.19 hd590300_0 conda-forge
libedit 3.1.20191231 he28a2e2_2 conda-forge
libev 4.33 hd590300_2 conda-forge
libexpat 2.5.0 hcb278e6_1 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc-ng 13.2.0 h807b86a_5 conda-forge
libgfortran-ng 13.2.0 h69a702a_5 conda-forge
libgfortran5 13.2.0 ha4646dd_5 conda-forge
libiconv 1.17 hd590300_2 conda-forge
libidn2 2.3.7 hd590300_0 conda-forge
libjpeg-turbo 3.0.0 hd590300_1 conda-forge
liblapack 3.9.0 16_linux64_mkl conda-forge
liblapacke 3.9.0 16_linux64_mkl conda-forge
libnghttp2 1.58.0 h47da74e_1 conda-forge
libnpp 12.0.2.50 0 nvidia
libnsl 2.0.1 hd590300_0 conda-forge
libnvjitlink 12.1.105 0 nvidia
libnvjpeg 12.1.1.14 0 nvidia
libpng 1.6.42 h2797004_0 conda-forge
libprotobuf 3.21.12 hfc55251_2 conda-forge
libsodium 1.0.18 h36c2ea0_1 conda-forge
libsqlite 3.45.1 h2797004_0 conda-forge
libssh2 1.11.0 h0841786_0 conda-forge
libstdcxx-ng 13.2.0 h7e041cc_5 conda-forge
libtasn1 4.19.0 h166bdaf_0 conda-forge
libtiff 4.6.0 ha9c0a0a_2 conda-forge
libunistring 0.9.10 h7f98852_0 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libuv 1.47.0 hd590300_0 conda-forge
libwebp-base 1.3.2 hd590300_0 conda-forge
libxcb 1.15 h0b41bf4_0 conda-forge
libxcrypt 4.4.36 hd590300_1 conda-forge
libxml2 2.11.6 h232c23b_0 conda-forge
libzlib 1.2.13 hd590300_5 conda-forge
llvm-openmp 15.0.7 h0cdce71_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
llvmlite 0.42.0 pypi_0 pypi
lxml 5.2.1 pypi_0 pypi
markdown 3.6 pypi_0 pypi
markdown-it-py 3.0.0 pypi_0 pypi
markupsafe 2.1.5 py310h2372a71_0 conda-forge
matplotlib-base 3.8.3 py310h62c0568_0 conda-forge
matplotlib-inline 0.1.6 pyhd8ed1ab_0 conda-forge
mdurl 0.1.2 pypi_0 pypi
mistune 3.0.2 pyhd8ed1ab_0 conda-forge
mkl 2022.1.0 h84fe81f_915 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
mkl-devel 2022.1.0 ha770c72_916 conda-forge
mkl-include 2022.1.0 h84fe81f_915 conda-forge
mpc 1.3.1 hfe3b2da_0 conda-forge
mpfr 4.2.1 h9458935_0 conda-forge
mpi4py 3.1.5 pypi_0 pypi
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
multiprocess 0.70.16 pypi_0 pypi
munkres 1.1.4 pyh9f0ad1d_0 conda-forge
nbclient 0.8.0 pyhd8ed1ab_0 conda-forge
nbconvert-core 7.16.1 pyhd8ed1ab_0 conda-forge
nbformat 5.9.2 pyhd8ed1ab_0 conda-forge
ncurses 6.4 h59595ed_2 conda-forge
nest-asyncio 1.6.0 pyhd8ed1ab_0 conda-forge
nettle 3.7.3 hbbd107a_1 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
networkx 3.2.1 pyhd8ed1ab_0 conda-forge
notebook 7.1.0 pyhd8ed1ab_0 conda-forge
notebook-shim 0.2.4 pyhd8ed1ab_0 conda-forge
numba 0.59.0 pypi_0 pypi
numpy 1.23.0 pypi_0 pypi
nvgpu 0.10.0 pypi_0 pypi
nvidia-ml-py 12.535.133 pypi_0 pypi
omegaconf 2.3.0 pypi_0 pypi
onnx 1.15.0 pypi_0 pypi
onnxruntime 1.15.0 pypi_0 pypi
onnxsim 0.4.36 pypi_0 pypi
opencv-python 4.9.0.80 pypi_0 pypi
opencv-python-headless 4.9.0.80 pypi_0 pypi
openh264 2.1.1 h780b84a_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
openjpeg 2.5.0 h488ebb8_3 conda-forge
openssl 3.2.1 hd590300_0 conda-forge
oscrypto 1.3.0 pypi_0 pypi
overrides 7.7.0 pyhd8ed1ab_0 conda-forge
packaging 21.3 pypi_0 pypi
pandas 1.5.3 pypi_0 pypi
pandocfilters 1.5.0 pyhd8ed1ab_0 conda-forge
parso 0.8.3 pyhd8ed1ab_0 conda-forge
pathos 0.3.2 pypi_0 pypi
patsy 0.5.6 pyhd8ed1ab_0 conda-forge
pexpect 4.9.0 pyhd8ed1ab_0 conda-forge
pickleshare 0.7.5 py_1003 conda-forge
pillow 10.2.0 py310h01dd4db_0 conda-forge
pip 24.0 pyhd8ed1ab_0 conda-forge
pip-tools 7.4.1 pypi_0 pypi
pkgutil-resolve-name 1.3.10 pyhd8ed1ab_1 conda-forge
platformdirs 4.2.0 pyhd8ed1ab_0 conda-forge
plotly 5.19.0 pypi_0 pypi
pox 0.3.4 pypi_0 pypi
ppft 1.7.6.8 pypi_0 pypi
prometheus_client 0.20.0 pyhd8ed1ab_0 conda-forge
prompt-toolkit 3.0.42 pyha770c72_0 conda-forge
protobuf 3.20.3 pypi_0 pypi
psutil 5.9.8 py310h2372a71_0 conda-forge
pthread-stubs 0.4 h36c2ea0_1001 conda-forge
ptyprocess 0.7.0 pyhd3deb0d_0 conda-forge
pure_eval 0.2.2 pyhd8ed1ab_0 conda-forge
py4j 0.10.9.5 pypi_0 pypi
pyarrow 15.0.0 pypi_0 pypi
pyasn1 0.5.1 pyhd8ed1ab_0 conda-forge
pybind11 2.11.1 py310hd41b1e2_2 conda-forge
pybind11-global 2.11.1 py310hd41b1e2_2 conda-forge
pycparser 2.21 pyhd8ed1ab_0 conda-forge
pydantic 2.6.1 pyhd8ed1ab_0 conda-forge
pydantic-core 2.16.2 py310hcb5633a_1 conda-forge
pydeprecate 0.3.2 pypi_0 pypi
pyfunctional 1.4.3 pypi_0 pypi
pygame 2.5.2 pypi_0 pypi
pygments 2.17.2 pyhd8ed1ab_0 conda-forge
pyhanko 0.23.2 pypi_0 pypi
pyhanko-certvalidator 0.26.3 pypi_0 pypi
pynvml 11.5.0 pypi_0 pypi
pyparsing 2.4.5 pypi_0 pypi
pypdf 4.2.0 pypi_0 pypi
pypng 0.20220715.0 pypi_0 pypi
pyproject-hooks 1.0.0 pypi_0 pypi
pysocks 1.7.1 pyha2e5f31_6 conda-forge
pyspark 3.3.0 pypi_0 pypi
pyspnego 0.10.2 pypi_0 pypi
python 3.10.13 hd12c33a_1_cpython conda-forge
python-bidi 0.4.2 pypi_0 pypi
python-dateutil 2.8.2 pyhd8ed1ab_0 conda-forge
python-fastjsonschema 2.19.1 pyhd8ed1ab_0 conda-forge
python-json-logger 2.0.7 pyhd8ed1ab_0 conda-forge
python-tzdata 2024.1 pyhd8ed1ab_0 conda-forge
python_abi 3.10 4_cp310 conda-forge
pytorch 2.1.0 aws_py3.10_cuda12.1_cudnn8.9.2_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
pytorch-cuda 12.1 ha16c6d3_5 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
pytorch-mutex 1.0 cuda https://aws-ml-conda.s3.us-west-2.amazonaws.com/
pytz 2024.1 pyhd8ed1ab_0 conda-forge
pywavelets 1.6.0 pypi_0 pypi
pyyaml 6.0.1 py310h2372a71_1 conda-forge
pyzmq 25.1.2 py310h795f18f_0 conda-forge
qrcode 7.4.2 pypi_0 pypi
qtconsole 5.5.1 pypi_0 pypi
qtpy 2.4.1 pypi_0 pypi
qudida 0.0.4 pypi_0 pypi
rapidfuzz 3.8.1 pypi_0 pypi
readline 8.2 h8228510_1 conda-forge
referencing 0.33.0 pyhd8ed1ab_0 conda-forge
reportlab 3.6.13 pypi_0 pypi
requests 2.31.0 pyhd8ed1ab_0 conda-forge
requests-kerberos 0.14.0 pypi_0 pypi
rfc3339-validator 0.1.4 pyhd8ed1ab_0 conda-forge
rfc3986-validator 0.1.1 pyh9f0ad1d_0 conda-forge
rhash 1.4.3 hd590300_2 conda-forge
rich 13.7.1 pypi_0 pypi
rpds-py 0.18.0 py310hcb5633a_0 conda-forge
rsa 4.7.2 pyh44b312d_0 conda-forge
ruamel-yaml 0.18.6 pypi_0 pypi
ruamel-yaml-clib 0.2.8 pypi_0 pypi
s3fs 0.4.2 pypi_0 pypi
s3transfer 0.10.0 pyhd8ed1ab_0 conda-forge
sagemaker 2.214.3 pypi_0 pypi
sagemaker-pyspark 1.4.5 pypi_0 pypi
schema 0.7.5 pypi_0 pypi
scikit-image 0.23.1 pypi_0 pypi
scikit-learn 1.4.1.post1 pypi_0 pypi
scipy 1.12.0 py310hb13e2d6_2 conda-forge
seaborn 0.13.2 hd8ed1ab_0 conda-forge
seaborn-base 0.13.2 pyhd8ed1ab_0 conda-forge
send2trash 1.8.2 pyh41d4057_0 conda-forge
setuptools 69.1.0 pyhd8ed1ab_1 conda-forge
shap 0.40.0 pypi_0 pypi
six 1.16.0 pyh6c4a22f_0 conda-forge
slicer 0.0.7 pypi_0 pypi
smclarify 0.5 pypi_0 pypi
smdebug-rulesconfig 1.0.1 pypi_0 pypi
sniffio 1.3.0 pyhd8ed1ab_0 conda-forge
snowballstemmer 2.2.0 pypi_0 pypi
soupsieve 2.5 pyhd8ed1ab_1 conda-forge
sparkmagic 0.21.0 pypi_0 pypi
sphinx 4.0.3 pypi_0 pypi
sphinx-rtd-theme 1.3.0 pypi_0 pypi
sphinxcontrib-applehelp 1.0.8 pypi_0 pypi
sphinxcontrib-devhelp 1.0.6 pypi_0 pypi
sphinxcontrib-htmlhelp 2.0.5 pypi_0 pypi
sphinxcontrib-jquery 4.1 pypi_0 pypi
sphinxcontrib-jsmath 1.0.1 pypi_0 pypi
sphinxcontrib-qthelp 1.0.7 pypi_0 pypi
sphinxcontrib-serializinghtml 1.1.10 pypi_0 pypi
stack_data 0.6.2 pyhd8ed1ab_0 conda-forge
statsmodels 0.14.1 py310h1f7b6fc_0 conda-forge
stringcase 1.2.0 pypi_0 pypi
super-gradients 3.6.1 pypi_0 pypi
svglib 1.5.1 pypi_0 pypi
sympy 1.12 pypyh9d50eac_103 conda-forge
tabulate 0.9.0 pypi_0 pypi
tbb 2021.8.0 hdb19cb5_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
tblib 2.0.0 pypi_0 pypi
tenacity 8.2.3 pypi_0 pypi
tensorboard 2.16.2 pypi_0 pypi
tensorboard-data-server 0.7.2 pypi_0 pypi
termcolor 1.1.0 pypi_0 pypi
terminado 0.18.0 pyh0d859eb_0 conda-forge
threadpoolctl 3.3.0 pypi_0 pypi
tifffile 2024.2.12 pypi_0 pypi
tinycss2 1.2.1 pyhd8ed1ab_0 conda-forge
tk 8.6.13 noxft_h4845f30_101 conda-forge
tomli 2.0.1 pyhd8ed1ab_0 conda-forge
torch-model-archiver 0.7.1 py310_0 pytorch
torch-workflow-archiver 0.2.11 py310_0 pytorch
torchaudio 2.1.0 py310_cu121 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
torchdata 0.7.0 py310 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
torchmetrics 0.8.0 pypi_0 pypi
torchserve 0.8.2 py310_0 pytorch
torchtext 0.16.0 py310 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
torchtriton 2.1.0 py310 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
torchvision 0.16.0 py310_cu121 https://aws-ml-conda.s3.us-west-2.amazonaws.com/
tornado 6.4 py310h2372a71_0 conda-forge
tqdm 4.66.2 pyhd8ed1ab_0 conda-forge
traitlets 5.14.1 pyhd8ed1ab_0 conda-forge
treelib 1.6.1 pypi_0 pypi
types-python-dateutil 2.8.19.20240106 pyhd8ed1ab_0 conda-forge
typing-extensions 4.9.0 hd8ed1ab_0 conda-forge
typing_extensions 4.9.0 pyha770c72_0 conda-forge
typing_utils 0.1.0 pyhd8ed1ab_0 conda-forge
tzdata 2024a h0c530f3_0 conda-forge
tzlocal 5.2 pypi_0 pypi
ujson 5.9.0 pypi_0 pypi
unicodedata2 15.1.0 py310h2372a71_0 conda-forge
uri-template 1.3.0 pyhd8ed1ab_0 conda-forge
uritools 4.0.2 pypi_0 pypi
urllib3 2.0.7 pyhd8ed1ab_0 conda-forge
wcwidth 0.2.13 pyhd8ed1ab_0 conda-forge
webcolors 1.13 pyhd8ed1ab_0 conda-forge
webencodings 0.5.1 pyhd8ed1ab_2 conda-forge
websocket-client 1.7.0 pyhd8ed1ab_0 conda-forge
werkzeug 3.0.1 pypi_0 pypi
wheel 0.42.0 pyhd8ed1ab_0 conda-forge
widgetsnbextension 4.0.10 pyhd8ed1ab_0 conda-forge
wrapt 1.16.0 pypi_0 pypi
xhtml2pdf 0.2.11 pypi_0 pypi
xorg-libxau 1.0.11 hd590300_0 conda-forge
xorg-libxdmcp 1.1.3 h7f98852_0 conda-forge
xyzservices 2023.10.1 pyhd8ed1ab_0 conda-forge
xz 5.2.6 h166bdaf_0 conda-forge
yaml 0.2.5 h7f98852_2 conda-forge
zeromq 4.3.5 h59595ed_0 conda-forge
zipp 3.17.0 pyhd8ed1ab_0 conda-forge
zlib 1.2.13 hd590300_5 conda-forge
zstd 1.5.5 hfc55251_0 conda-forge

yolo code

import datetime
print('\nsuccessful start at ', datetime.datetime.now(), '\n')

from libs_yolo import *

EPOCHS = 50
BATCH_SIZE = 4#16
WORKERS = 4

resumePath = r'.../ckpt_latest.pth'

trainShape = [960, 1280]#[768, 1024]#[960, 1280]#[480, 640]#/4*3
scoreThresh = 0.5

ROOT_DIR = '...'
train_imgs_dir = 'images_split/train'
train_labels_dir = 'labels_split/train'
val_imgs_dir = 'images_split/validate'
val_labels_dir = 'labels_split/validate'

DEVICE = 'cuda'# if torch.cuda.is_available() else 'cpu'
classes = ['className']
model_to_train = 'yolo_nas_m'

addSamples = 2#, 11
lr, opt = 'CosineLRScheduler', 'Adam'
transformSettings = {'degrees':20, 'translate':0.2, 'scales':0.2, 'shear':10, 'target_size':trainShape[::-1]}

CHECKPOINT_DIR = os.path.join(ROOT_DIR, datetime.datetime.today().strftime('%Y%m%d')+'_ckpts')#'checkpoints'

trainer = Trainer(
    experiment_name=model_to_train,
    ckpt_root_dir=CHECKPOINT_DIR
)

transformsTrain = [
    DetectionTargetsFormatTransform(input_dim=(trainShape), output_format="LABEL_CXCYWH"),
    DetectionRandomAffine(**transformSettings),
    DetectionHSV(prob=.5),
    DetectionHorizontalFlip(prob=.5),
    DetectionStandardize(max_value=255),
]

transformsValidate = [
    DetectionTargetsFormatTransform(input_dim=(trainShape), output_format="LABEL_CXCYWH"),
    DetectionStandardize(max_value=255),
]

train_params = {
    'silent_mode': True,
    "average_best_models":True,
    "warmup_mode": 'LinearEpochLRWarmup', #"linear_epoch_step",
    "warmup_initial_lr": 1e-6,
    "lr_warmup_epochs": 3,
    "initial_lr": 5e-4,
    "lr_mode": lr,
    'lr_updates': [10, 20, 40, 70],# epochs at which lr is changed. applicable for
    'lr_decay_factor':5e-4, # size of changed lr
    "cosine_final_lr_ratio": 0.1,
    "optimizer": opt,#'Adam','SGD','RMSProp'
    "optimizer_params": {"weight_decay": 0.0001},
    "zero_weight_decay_on_bias_and_bn": True,
    "ema": True,
    "ema_params": {"decay": 0.9, "decay_type": "threshold"},
    "max_epochs": EPOCHS,
    'save_ckpt_epoch_list': [25, 50, 75,100, 125],
    "mixed_precision": True,
    "loss": PPYoloELoss(
        use_static_assigner=False,
        num_classes=len(classes),
        reg_max=16
    ),
    "train_metrics_list": [
        DetectionMetrics_050(score_thres=scoreThresh,
                             top_k_predictions=300,
                             num_cls=len(classes),
                             normalize_targets=True,
                             post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.01,
                                                                                    nms_top_k=1000,
                                                                                    max_predictions=300,nms_threshold=0.7)
        )
    ],
    "valid_metrics_list": [
        DetectionMetrics_050(score_thres=scoreThresh,
                             top_k_predictions=300,
                             num_cls=len(classes),
                             normalize_targets=True,
                             post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.01,
                                                                                    nms_top_k=1000,
                                                                                    max_predictions=300,
                                                                                    nms_threshold=0.7)
        )
    ],
    "metric_to_watch": 'mAP@0.50',
    'phase_callbacks' : [logTrainingPerformance(trainer)]
}

train_data = coco_detection_yolo_format_train(
    dataset_params={
        'data_dir': ROOT_DIR,
        'images_dir': train_imgs_dir,
        'labels_dir': train_labels_dir,
        'classes': classes,
        'input_dim': trainShape,
        'transforms': transformsTrain
    },
    dataloader_params={
        'batch_size':BATCH_SIZE,
        'num_workers':WORKERS
    }
)

val_data = coco_detection_yolo_format_val(
    dataset_params={
        'data_dir': ROOT_DIR,
        'images_dir': val_imgs_dir,
        'labels_dir': val_labels_dir,
        'classes': classes,
        'input_dim': trainShape,
        'transforms': transformsValidate
    },
    dataloader_params={
        'batch_size':BATCH_SIZE,
        'num_workers':WORKERS
    }
)

for k in train_data.dataset.transforms:
    k.additional_samples_count = addSamples

if resumePath:
    print('\nresuming training from %s \n' % resumePath)
    model = models.get(model_to_train, num_classes=len(classes), checkpoint_path=resumePath).to(DEVICE)
else:
    print('\ntraining model from scratch\n')
    model = models.get(model_to_train, num_classes=len(classes)).to(DEVICE)

trainer.train(
    model=model,
    training_params=train_params,
    train_loader=train_data,
    valid_loader=val_data
)

del model, trainer, train_data, val_data
gc.collect()
torch.cuda.empty_cache()

question

I tested model exports with postprocessing=False and called the model without NMS. Are you suggesting that during training the NMS settings still influence the model, e.g. through the maximum number of allowed detections per image? In that case the settings of the post prediction callback should be changed:

PPYoloEPostPredictionCallback(score_threshold=0.01, nms_top_k=1000, max_predictions=300,nms_threshold=0.7)

thanks for the support!
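For readers weighing the question above, the sketch below shows how many predictions could pass each stage when settings like those in the callback are applied to a list of confidence scores. It is a pure-Python illustration, not the super_gradients implementation. The stage order (score threshold, then top-k, then NMS, then a cap on returned predictions), the strict `>` comparison and the function name `count_after_each_stage` are assumptions, and the NMS step itself is left out, so only an upper bound on the final count is reported.

```python
from typing import Dict, List


def count_after_each_stage(
    scores: List[float],
    score_threshold: float,
    nms_top_k: int,
    max_predictions: int,
) -> Dict[str, int]:
    """Count how many predictions remain after each filtering stage."""
    # Stage 1: drop predictions at or below the confidence threshold.
    above = [s for s in scores if s > score_threshold]
    # Stage 2: keep at most nms_top_k of the most confident predictions.
    into_nms = sorted(above, reverse=True)[:nms_top_k]
    # Stage 3 (NMS) is omitted; it can only reduce the count further.
    # Stage 4: the number returned per image is capped at max_predictions.
    return {
        "raw": len(scores),
        "above_threshold": len(above),
        "into_nms": len(into_nms),
        "max_returned": min(len(into_nms), max_predictions),
    }


if __name__ == "__main__":
    # 2000 synthetic scores spread evenly over [0, 1); not real model output.
    scores = [i / 2000 for i in range(2000)]
    print(count_after_each_stage(scores, 0.01, 1000, 300))
    # {'raw': 2000, 'above_threshold': 1979, 'into_nms': 1000, 'max_returned': 300}
    print(count_after_each_stage(scores, 0.5, 1000, 300))
    # {'raw': 2000, 'above_threshold': 999, 'into_nms': 999, 'max_returned': 300}
```

With a threshold as low as 0.01 almost every prediction passes the first stage, so the top-k cap is what bounds the NMS input in this toy case. All of these stages operate on predictions after the forward pass, and none of them touches the weights.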