libffcv / ffcv

FFCV: Fast Forward Computer Vision (and other ML workloads!)
https://ffcv.io
Apache License 2.0

CIFAR10 Example: Does it really use channels_last order? #119

paulgavrikov closed this issue 2 years ago

paulgavrikov commented 2 years ago

Hi, thank you for this very interesting library. I have a question regarding the cifar10 example you provide. In your transformation pipeline I find the following:

    image_pipeline.extend([
        ToTensor(),
        ToDevice('cuda:0', non_blocking=True),
        ToTorchImage(),
        Convert(torch.float16),
        torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

Now, it seems to me that ToTorchImage() changes the order back to the regular torch (NCHW) format, which I also verified by printing the shape of the loaded batch.
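Printing the shape cannot settle this question, because PyTorch always reports shapes in logical NCHW order no matter how the data sits in memory. The sketch below assumes the decoded batch arrives as NHWC and that turning it into a torch image amounts to a permute; that is an illustrative assumption about what ToTorchImage() does, not something confirmed in this thread. It only uses plain PyTorch calls.

```python
import torch

# Hypothetical decoded batch in NHWC order (batch, height, width, channels).
# This layout is an assumption for illustration, not a statement about FFCV internals.
nhwc = torch.randint(0, 256, (4, 32, 32, 3), dtype=torch.uint8)

# Permuting to NCHW only changes the view's dimension order; no data is copied.
nchw_view = nhwc.permute(0, 3, 1, 2)

print(nchw_view.shape)     # torch.Size([4, 3, 32, 32]): looks like "regular" torch order
print(nchw_view.stride())  # (3072, 1, 96, 3): channels are still innermost in memory
print(nchw_view.is_contiguous())                                   # False
print(nchw_view.is_contiguous(memory_format=torch.channels_last))  # True

# Tensor.to(dtype) uses memory_format=torch.preserve_format by default,
# so a dtype conversion keeps the channels_last layout.
as_float = nchw_view.to(torch.float32)
print(as_float.is_contiguous(memory_format=torch.channels_last))   # True
```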

Then you load your model and call .to(memory_format=torch.channels_last).cuda() on it. At least on my machine, this does not change anything at all. I should add, though, that I have never tried the channels-last order before.
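Because `.to(memory_format=torch.channels_last)` leaves every shape unchanged, its effect on a model is only visible in the strides of the 4D parameters. Here is a minimal sketch of how one might check that the conversion took effect, using a hypothetical two-layer model in place of the example's actual network:

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for the CIFAR-10 network.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)

w = model[0].weight
print(w.shape, w.stride())  # torch.Size([8, 3, 3, 3]) (27, 9, 3, 1)

# Module.to(memory_format=...) re-lays out 4D parameters (the conv weights);
# 1D parameters such as biases are left as they are.
model = model.to(memory_format=torch.channels_last)

w = model[0].weight
print(w.shape, w.stride())  # torch.Size([8, 3, 3, 3]) (27, 1, 9, 3)
print(w.is_contiguous(memory_format=torch.channels_last))  # True
```

The shape printed before and after is identical, which is consistent with the observation that the call "does not change anything"; only the strides reveal the new layout.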

The code works, but it seems to me like everything is in regular torch order. Am I missing something?

I followed your conda env guideline:

name: ffcv
channels:
  - pytorch
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=1_llvm
  - alsa-lib=1.2.3=h516909a_0
  - aom=3.2.0=h9c3ff4c_2
  - asttokens=2.0.5=pyhd8ed1ab_0
  - backcall=0.2.0=pyh9f0ad1d_0
  - backports=1.0=py_2
  - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
  - binutils=2.36.1=hdd6e379_2
  - binutils_impl_linux-64=2.36.1=h193b22a_2
  - binutils_linux-64=2.36=hf3e587d_4
  - black=21.12b0=pyhd8ed1ab_0
  - blas=2.113=mkl
  - blas-devel=3.9.0=13_linux64_mkl
  - bzip2=1.0.8=h7f98852_4
  - c-ares=1.18.1=h7f98852_0
  - c-compiler=1.3.0=h7f98852_0
  - ca-certificates=2021.10.8=ha878542_0
  - cairo=1.16.0=ha00ac49_1009
  - click=8.0.3=py39hf3d152e_1
  - compilers=1.3.0=ha770c72_0
  - cudatoolkit=11.3.1=ha36c431_10
  - cupy=10.1.0=py39hccaf5a2_0
  - cxx-compiler=1.3.0=h4bd325d_0
  - dataclasses=0.8=pyhc8e2a94_3
  - dbus=1.13.6=h5008d03_3
  - debugpy=1.5.1=py39he80948d_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - entrypoints=0.3=pyhd8ed1ab_1003
  - executing=0.8.2=pyhd8ed1ab_0
  - expat=2.4.3=h9c3ff4c_0
  - fastrlock=0.8=py39he80948d_1
  - ffmpeg=4.4.1=h6987444_0
  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
  - font-ttf-inconsolata=3.000=h77eed37_0
  - font-ttf-source-code-pro=2.038=h77eed37_0
  - font-ttf-ubuntu=0.83=hab24e00_0
  - fontconfig=2.13.94=ha180cfb_0
  - fonts-conda-ecosystem=1=0
  - fonts-conda-forge=1=0
  - fortran-compiler=1.3.0=h1990efc_0
  - freeglut=3.2.1=h9c3ff4c_2
  - freetype=2.10.4=h0708190_1
  - gcc=9.4.0=h192d537_4
  - gcc_impl_linux-64=9.4.0=h03d3576_12
  - gcc_linux-64=9.4.0=h391b98a_4
  - gettext=0.19.8.1=h73d1719_1008
  - gfortran=9.4.0=h2018a41_4
  - gfortran_impl_linux-64=9.4.0=h0003116_12
  - gfortran_linux-64=9.4.0=hf0ab688_4
  - gmp=6.2.1=h58526e2_0
  - gnutls=3.6.13=h85f3911_1
  - graphite2=1.3.13=h58526e2_1001
  - gst-plugins-base=1.18.5=hf529b03_3
  - gstreamer=1.18.5=h9f60fe5_3
  - gxx=9.4.0=h192d537_4
  - gxx_impl_linux-64=9.4.0=h03d3576_12
  - gxx_linux-64=9.4.0=h0316aca_4
  - harfbuzz=3.2.0=hb4a5f5f_0
  - hdf5=1.12.1=nompi_h2750804_103
  - icu=69.1=h9c3ff4c_0
  - ipykernel=6.7.0=py39hef51801_0
  - ipython=8.0.1=py39hf3d152e_0
  - jasper=2.0.33=ha77e612_0
  - jbig=2.1=h7f98852_2003
  - jedi=0.18.1=py39hf3d152e_0
  - jpeg=9e=h7f98852_0
  - jupyter_client=7.1.2=pyhd8ed1ab_0
  - jupyter_core=4.9.1=py39hf3d152e_1
  - kernel-headers_linux-64=2.6.32=he073ed8_15
  - krb5=1.19.2=hcc1bbae_3
  - lame=3.100=h7f98852_1001
  - lcms2=2.12=hddcbb42_0
  - ld_impl_linux-64=2.36.1=hea4e1c9_2
  - lerc=3.0=h9c3ff4c_0
  - libblas=3.9.0=13_linux64_mkl
  - libcblas=3.9.0=13_linux64_mkl
  - libclang=13.0.0=default_hc23dcda_0
  - libcurl=7.81.0=h2574ce0_0
  - libdeflate=1.8=h7f98852_0
  - libdrm=2.4.109=h7f98852_0
  - libedit=3.1.20191231=he28a2e2_2
  - libev=4.33=h516909a_1
  - libevent=2.1.10=h9b69904_4
  - libffi=3.4.2=h7f98852_5
  - libgcc-devel_linux-64=9.4.0=hd854feb_12
  - libgcc-ng=11.2.0=h1d223b6_12
  - libgfortran-ng=11.2.0=h69a702a_12
  - libgfortran5=11.2.0=h5c6108e_12
  - libglib=2.70.2=h174f98d_1
  - libglu=9.0.0=he1b5a44_1001
  - libgomp=11.2.0=h1d223b6_12
  - libiconv=1.16=h516909a_0
  - libjpeg-turbo=2.1.1=h7f98852_0
  - liblapack=3.9.0=13_linux64_mkl
  - liblapacke=3.9.0=13_linux64_mkl
  - libllvm11=11.1.0=hf817b99_2
  - libllvm13=13.0.0=hf817b99_0
  - libnghttp2=1.46.0=h812cca2_0
  - libnsl=2.0.0=h7f98852_0
  - libogg=1.3.4=h7f98852_1
  - libopencv=4.5.5=py39h7d09d5f_0
  - libopus=1.3.1=h7f98852_1
  - libpciaccess=0.16=h516909a_0
  - libpng=1.6.37=h21135ba_2
  - libpq=14.1=hd57d9b9_1
  - libprotobuf=3.19.4=h780b84a_0
  - libsanitizer=9.4.0=h79bfe98_12
  - libsodium=1.0.18=h36c2ea0_1
  - libssh2=1.10.0=ha56f1ee_2
  - libstdcxx-devel_linux-64=9.4.0=hd854feb_12
  - libstdcxx-ng=11.2.0=he4da1e4_12
  - libtiff=4.3.0=h6f004c6_2
  - libuuid=2.32.1=h7f98852_1000
  - libuv=1.43.0=h7f98852_0
  - libva=2.13.0=h7f98852_2
  - libvorbis=1.3.7=h9c3ff4c_0
  - libvpx=1.11.0=h9c3ff4c_3
  - libwebp-base=1.2.2=h7f98852_1
  - libxcb=1.13=h7f98852_1004
  - libxkbcommon=1.0.3=he3ba5ed_0
  - libxml2=2.9.12=h885dcf4_1
  - libzlib=1.2.11=h36c2ea0_1013
  - llvm-openmp=12.0.1=h4bd325d_1
  - llvmlite=0.38.0=py39h1bbdace_0
  - lz4-c=1.9.3=h9c3ff4c_1
  - matplotlib-inline=0.1.3=pyhd8ed1ab_0
  - mkl=2022.0.1=h8d4b97c_803
  - mkl-devel=2022.0.1=ha770c72_804
  - mkl-include=2022.0.1=h8d4b97c_803
  - mypy_extensions=0.4.3=py39hf3d152e_4
  - mysql-common=8.0.28=ha770c72_0
  - mysql-libs=8.0.28=hfa10184_0
  - ncurses=6.3=h9c3ff4c_0
  - nest-asyncio=1.5.4=pyhd8ed1ab_0
  - nettle=3.6=he412f7d_0
  - nspr=4.32=h9c3ff4c_1
  - nss=3.74=hb5efdd6_0
  - numba=0.55.0=py39h56b8d98_0
  - numpy=1.21.5=py39haac66dc_0
  - olefile=0.46=pyh9f0ad1d_1
  - opencv=4.5.5=py39hf3d152e_0
  - openh264=2.1.1=h780b84a_0
  - openjpeg=2.4.0=hb52868f_1
  - openssl=1.1.1l=h7f98852_0
  - parso=0.8.3=pyhd8ed1ab_0
  - pathspec=0.9.0=pyhd8ed1ab_0
  - pcre=8.45=h9c3ff4c_0
  - pexpect=4.8.0=pyh9f0ad1d_2
  - pickleshare=0.7.5=py_1003
  - pillow=8.4.0=py39ha612740_0
  - pip=22.0=pyhd8ed1ab_0
  - pixman=0.40.0=h36c2ea0_0
  - pkg-config=0.29.2=h36c2ea0_1008
  - platformdirs=2.3.0=pyhd8ed1ab_0
  - prompt-toolkit=3.0.26=pyha770c72_0
  - pthread-stubs=0.4=h36c2ea0_1001
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - py-opencv=4.5.5=py39hef51801_0
  - pygments=2.11.2=pyhd8ed1ab_0
  - python=3.9.10=h85951f9_0_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.9=2_cp39
  - pytorch=1.10.2=py3.9_cuda11.3_cudnn8.2.0_0
  - pytorch-mutex=1.0=cuda
  - pyzmq=22.3.0=py39h37b5a0c_1
  - qt=5.12.9=ha98a1a1_5
  - readline=8.1=h46c0cb4_0
  - setuptools=60.5.0=py39hf3d152e_0
  - six=1.16.0=pyh6c4a22f_0
  - sqlite=3.37.0=h9cd32fc_0
  - stack_data=0.1.4=pyhd8ed1ab_0
  - svt-av1=0.9.0=h9c3ff4c_0
  - sysroot_linux-64=2.12=he073ed8_15
  - tbb=2021.5.0=h4bd325d_0
  - tk=8.6.11=h27826a3_1
  - tomli=1.2.2=pyhd8ed1ab_0
  - torchvision=0.11.3=py39_cu113
  - tornado=6.1=py39h3811e60_2
  - traitlets=5.1.1=pyhd8ed1ab_0
  - typed-ast=1.5.2=py39h3811e60_0
  - typing_extensions=4.0.1=pyha770c72_0
  - tzdata=2021e=he74cb21_0
  - wcwidth=0.2.5=pyh9f0ad1d_2
  - wheel=0.37.1=pyhd8ed1ab_0
  - x264=1!161.3030=h7f98852_1
  - x265=3.5=h4bd325d_1
  - xorg-fixesproto=5.0=h7f98852_1002
  - xorg-inputproto=2.3.2=h7f98852_1002
  - xorg-kbproto=1.0.7=h7f98852_1002
  - xorg-libice=1.0.10=h7f98852_0
  - xorg-libsm=1.2.3=hd9c2040_1000
  - xorg-libx11=1.7.2=h7f98852_0
  - xorg-libxau=1.0.9=h7f98852_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xorg-libxext=1.3.4=h7f98852_1
  - xorg-libxfixes=5.0.3=h7f98852_1004
  - xorg-libxi=1.7.10=h7f98852_0
  - xorg-libxrender=0.9.10=h7f98852_1003
  - xorg-renderproto=0.11.1=h7f98852_1002
  - xorg-xextproto=7.3.0=h7f98852_1002
  - xorg-xproto=7.0.31=h7f98852_1007
  - xz=5.2.5=h516909a_1
  - zeromq=4.3.4=h9c3ff4c_1
  - zlib=1.2.11=h36c2ea0_1013
  - zstd=1.5.2=ha95c52a_0
  - pip:
    - assertpy==1.1
    - braceexpand==0.1.7
    - cycler==0.11.0
    - fastargs==1.2.0
    - ffcv==0.0.3
    - fonttools==4.29.0
    - imgcat==0.5.0
    - joblib==1.1.0
    - kiwisolver==1.3.2
    - matplotlib==3.5.1
    - packaging==21.3
    - pandas==1.4.0
    - psutil==5.9.0
    - pyparsing==3.0.7
    - pytorch-pfn-extras==0.5.6
    - pytz==2021.3
    - pyyam
GuillaumeLeclerc commented 2 years ago

Hello,

This doesn't change the shape of the tensor, but it does change its strides. In other words, it doesn't change how you use the tensor as a developer, but it does change how it is stored in memory. channels_last generally gives better performance on modern GPUs (especially for convolutions in float16 on Tensor Core hardware), which is why we use it by default. It should not change anything for you other than the improved performance.
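To make the shape-versus-strides point concrete, here is a small illustration on a random tensor (the CIFAR-like sizes are arbitrary choices for the example):

```python
import torch

x = torch.randn(2, 3, 32, 32)                # default contiguous (NCHW) storage
y = x.to(memory_format=torch.channels_last)  # same values, stored in NHWC order

print(x.shape == y.shape)  # True: the logical shape is unchanged
print(x.stride())          # (3072, 1024, 32, 1)
print(y.stride())          # (3072, 1, 96, 3)
print(torch.equal(x, y))   # True: y[n, c, h, w] is the same element as x[n, c, h, w]
```

Code that indexes or reshapes the tensor behaves the same either way; the layout only matters to kernels that read memory directly, which is where any speedup comes from.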

Please refer to PyTorch's documentation on channels_last for more details: https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html