aws-samples / awsome-distributed-training

Collection of best practices, reference architectures, model training examples and utilities to train large models on AWS.
MIT No Attribution

Example 10.FSDP is broken again with dependencies issue #219

Closed verdimrc closed 5 months ago

verdimrc commented 5 months ago

PR #118 causes the error below, and I had to revert it to get the example to work.

0:   File "/fsx/marcverd/awsome-distributed-training/3.test_cases/10.FSDP/pt_fsdp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1076, in _update_causal_mask
0:     causal_mask = torch.triu(causal_mask, diagonal=1)
0: RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
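
A launch script could check for this condition up front instead of failing inside the model's forward pass. The sketch below is a minimal, hypothetical guard in plain Python. It assumes that the missing BFloat16 CUDA kernel for `torch.triu` affects PyTorch releases older than 2.1; that boundary is an assumption drawn from the versions reported in this thread, not a confirmed fact. The names `MIN_TORCH_FOR_BF16_TRIU`, `parse_major_minor`, and `bf16_triu_likely_supported` are illustrative only.

```python
import re

# Assumption (not confirmed in this thread): BFloat16 support for
# torch.triu on CUDA is taken to start with PyTorch 2.1.
MIN_TORCH_FOR_BF16_TRIU = (2, 1)


def parse_major_minor(version: str) -> tuple[int, int]:
    """Extract (major, minor) from strings such as '2.0.1' or '2.2.1+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def bf16_triu_likely_supported(torch_version: str) -> bool:
    """Return True if the version meets the assumed minimum."""
    return parse_major_minor(torch_version) >= MIN_TORCH_FOR_BF16_TRIU


if __name__ == "__main__":
    # The two PyTorch versions that appear in the package lists of this issue.
    for version in ("2.0.1", "2.2.1"):
        print(version, bf16_triu_likely_supported(version))
```

In a real job the version string would come from `torch.__version__` rather than a hard-coded literal.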

Package versions (TL;DR):

With PR #118 the environment has pytorch 2.0.1 and transformers 4.39.0; without it, torch 2.2.1 and transformers 4.39.1.

Package list with PR-#118

```text
# packages in environment at /fsx/marcverd/awsome-distributed-training/3.test_cases/10.FSDP/pt_fsdp:
#
# Name  Version  Build  Channel
_libgcc_mutex  0.1  conda_forge  conda-forge
_openmp_mutex  4.5  2_kmp_llvm  conda-forge
aiohttp  3.9.3  py310h2372a71_1  conda-forge
aiosignal  1.3.1  pyhd8ed1ab_0  conda-forge
async-timeout  4.0.3  pyhd8ed1ab_0  conda-forge
attrs  23.2.0  pyh71513ae_0  conda-forge
aws-c-auth  0.7.16  haed3651_8  conda-forge
aws-c-cal  0.6.10  ha9bf9b1_2  conda-forge
aws-c-common  0.9.14  hd590300_0  conda-forge
aws-c-compression  0.2.18  h4466546_2  conda-forge
aws-c-event-stream  0.4.2  he635cd5_6  conda-forge
aws-c-http  0.8.1  hbfc29b2_7  conda-forge
aws-c-io  0.14.6  h6b388c4_1  conda-forge
aws-c-mqtt  0.10.3  hffff1cc_2  conda-forge
aws-c-s3  0.5.2  h4893938_2  conda-forge
aws-c-sdkutils  0.1.15  h4466546_2  conda-forge
aws-checksums  0.1.18  h4466546_2  conda-forge
aws-crt-cpp  0.26.3  h137ae52_2  conda-forge
aws-ofi-nccl  1.7.4  aws_0  https://aws-ml-conda.s3.us-west-2.amazonaws.com
aws-sdk-cpp  1.11.267  he0cb598_3  conda-forge
blas  2.116  mkl  conda-forge
blas-devel  3.9.0  16_linux64_mkl  conda-forge
boto3  1.34.67  pyhd8ed1ab_0  conda-forge
botocore  1.34.67  pyge310_1234567_0  conda-forge
brotli-python  1.1.0  py310hc6cd4ac_1  conda-forge
bzip2  1.0.8  h5eee18b_5
c-ares  1.27.0  hd590300_0  conda-forge
ca-certificates  2024.3.11  h06a4308_0
certifi  2024.2.2  pyhd8ed1ab_0  conda-forge
charset-normalizer  3.3.2  pyhd8ed1ab_0  conda-forge
colorama  0.4.6  pyhd8ed1ab_0  conda-forge
cuda-cudart  12.2.140  0  nvidia
cuda-cupti  12.2.142  0  nvidia
cuda-libraries  12.2.2  0  nvidia
cuda-nvrtc  12.2.140  0  nvidia
cuda-nvtx  12.2.140  0  nvidia
cuda-opencl  12.4.99  0  nvidia
cuda-runtime  12.2.2  0  nvidia
datasets  2.18.0  pyhd8ed1ab_0  conda-forge
dill  0.3.8  pyhd8ed1ab_0  conda-forge
ffmpeg  4.2  h3fd9d12_1  https://aws-ml-conda.s3.us-west-2.amazonaws.com
filelock  3.13.1  pyhd8ed1ab_0  conda-forge
freetype  2.12.1  h267a509_2  conda-forge
frozenlist  1.4.1  py310h2372a71_0  conda-forge
fsspec  2023.9.2  pyh1a96a4e_0  conda-forge
gettext  0.21.1  h27087fc_0  conda-forge
gflags  2.2.2  he1b5a44_1004  conda-forge
glog  0.7.0  hed5481d_0  conda-forge
gmp  6.3.0  h59595ed_1  conda-forge
gnutls  3.6.15  he1e5248_0  https://aws-ml-conda.s3.us-west-2.amazonaws.com
huggingface_hub  0.21.4  pyhd8ed1ab_0  conda-forge
hwloc  2.9.2  h2bc3f7f_0  https://aws-ml-conda.s3.us-west-2.amazonaws.com
icu  73.2  h59595ed_0  conda-forge
idna  3.6  pyhd8ed1ab_0  conda-forge
jinja2  3.1.3  pyhd8ed1ab_0  conda-forge
jmespath  1.0.1  pyhd8ed1ab_0  conda-forge
jpeg  9e  h0b41bf4_3  conda-forge
keyutils  1.6.1  h166bdaf_0  conda-forge
krb5  1.21.2  h659d440_0  conda-forge
lame  3.100  h166bdaf_1003  conda-forge
lcms2  2.15  hfd0df8a_0  conda-forge
ld_impl_linux-64  2.38  h1181459_1
lerc  4.0.0  h27087fc_0  conda-forge
libabseil  20240116.1  cxx17_h59595ed_2  conda-forge
libarrow  15.0.2  h6bfc85a_0_cpu  conda-forge
libarrow-acero  15.0.2  h59595ed_0_cpu  conda-forge
libarrow-dataset  15.0.2  h59595ed_0_cpu  conda-forge
libarrow-flight  15.0.2  hc6145d9_0_cpu  conda-forge
libarrow-flight-sql  15.0.2  h757c851_0_cpu  conda-forge
libarrow-gandiva  15.0.2  hb016d2e_0_cpu  conda-forge
libarrow-substrait  15.0.2  h757c851_0_cpu  conda-forge
libblas  3.9.0  16_linux64_mkl  conda-forge
libbrotlicommon  1.1.0  hd590300_1  conda-forge
libbrotlidec  1.1.0  hd590300_1  conda-forge
libbrotlienc  1.1.0  hd590300_1  conda-forge
libcblas  3.9.0  16_linux64_mkl  conda-forge
libcrc32c  1.1.2  h9c3ff4c_0  conda-forge
libcublas  12.2.5.6  0  nvidia
libcufft  11.0.8.103  0  nvidia
libcufile  1.9.0.20  0  nvidia
libcurand  10.3.5.119  0  nvidia
libcurl  8.6.0  hca28451_0  conda-forge
libcusolver  11.5.2.141  0  nvidia
libcusparse  12.1.2.141  0  nvidia
libdeflate  1.17  h0b41bf4_0  conda-forge
libedit  3.1.20191231  he28a2e2_2  conda-forge
libev  4.33  hd590300_2  conda-forge
libevent  2.1.12  hf998b51_1  conda-forge
libffi  3.4.4  h6a678d5_0
libgcc-ng  13.2.0  h807b86a_5  conda-forge
libgfortran-ng  13.2.0  h69a702a_5  conda-forge
libgfortran5  13.2.0  ha4646dd_5  conda-forge
libgomp  13.2.0  h807b86a_5  conda-forge
libgoogle-cloud  2.22.0  h9be4e54_1  conda-forge
libgoogle-cloud-storage  2.22.0  hc7a4891_1  conda-forge
libgrpc  1.62.1  h15f2491_0  conda-forge
libiconv  1.17  hd590300_2  conda-forge
libidn2  2.3.7  hd590300_0  conda-forge
liblapack  3.9.0  16_linux64_mkl  conda-forge
liblapacke  3.9.0  16_linux64_mkl  conda-forge
libllvm16  16.0.6  h5cf9203_2  conda-forge
libnghttp2  1.58.0  h47da74e_1  conda-forge
libnl  3.9.0  hd590300_0  conda-forge
libnpp  12.2.1.4  0  nvidia
libnvjitlink  12.2.140  0  nvidia
libnvjpeg  12.2.2.4  0  nvidia
libparquet  15.0.2  h352af49_0_cpu  conda-forge
libpng  1.6.43  h2797004_0  conda-forge
libprotobuf  4.25.3  h08a7969_0  conda-forge
libre2-11  2023.09.01  h5a48ba9_2  conda-forge
libssh2  1.11.0  h0841786_0  conda-forge
libstdcxx-ng  13.2.0  h7e041cc_5  conda-forge
libtasn1  4.19.0  h166bdaf_0  conda-forge
libthrift  0.19.0  hb90f79a_1  conda-forge
libtiff  4.5.0  h6adf6a1_2  conda-forge
libunistring  0.9.10  h7f98852_0  conda-forge
libutf8proc  2.8.0  h166bdaf_0  conda-forge
libuuid  1.41.5  h5eee18b_0
libwebp-base  1.3.2  hd590300_0  conda-forge
libxcb  1.13  h7f98852_1004  conda-forge
libxml2  2.11.6  h232c23b_0  conda-forge
libzlib  1.2.13  hd590300_5  conda-forge
llvm-openmp  15.0.7  h0cdce71_0  https://aws-ml-conda.s3.us-west-2.amazonaws.com
lz4-c  1.9.4  hcb278e6_0  conda-forge
markupsafe  2.1.5  py310h2372a71_0  conda-forge
mkl  2022.1.0  h84fe81f_915  https://aws-ml-conda.s3.us-west-2.amazonaws.com
mkl-devel  2022.1.0  ha770c72_916  conda-forge
mkl-include  2022.1.0  h84fe81f_915  conda-forge
mpmath  1.3.0  pyhd8ed1ab_0  conda-forge
multidict  6.0.5  py310h2372a71_0  conda-forge
multiprocess  0.70.16  py310h2372a71_0  conda-forge
ncurses  6.4  h6a678d5_0
nettle  3.7.3  hbbd107a_1  https://aws-ml-conda.s3.us-west-2.amazonaws.com
networkx  3.2.1  pyhd8ed1ab_0  conda-forge
numpy  1.26.4  py310hb13e2d6_0  conda-forge
openh264  2.1.1  h780b84a_0  https://aws-ml-conda.s3.us-west-2.amazonaws.com
openjpeg  2.5.0  hfec8fc6_2  conda-forge
openssl  3.2.1  hd590300_1  conda-forge
orc  2.0.0  h1e5e2c1_0  conda-forge
packaging  24.0  pyhd8ed1ab_0  conda-forge
pandas  2.2.1  py310hcc13569_0  conda-forge
pillow  9.4.0  py310h023d228_1  conda-forge
pip  23.3.1  py310h06a4308_0
pthread-stubs  0.4  h36c2ea0_1001  conda-forge
pyarrow  15.0.2  py310hf9e7431_0_cpu  conda-forge
pyarrow-hotfix  0.6  pyhd8ed1ab_0  conda-forge
pysocks  1.7.1  pyha2e5f31_6  conda-forge
python  3.10.13  h955ad1f_0
python-dateutil  2.9.0  pyhd8ed1ab_0  conda-forge
python-tzdata  2024.1  pyhd8ed1ab_0  conda-forge
python-xxhash  3.4.1  py310h2372a71_0  conda-forge
python_abi  3.10  2_cp310  conda-forge
pytorch  2.0.1  aws_py3.10_cuda12.2_cudnn8.9.4_0  https://aws-ml-conda.s3.us-west-2.amazonaws.com
pytorch-cuda  12.2  h5ef38aa_0  https://aws-ml-conda.s3.us-west-2.amazonaws.com
pytorch-mutex  1.0  cuda  https://aws-ml-conda.s3.us-west-2.amazonaws.com
pytz  2024.1  pyhd8ed1ab_0  conda-forge
pyyaml  6.0.1  py310h2372a71_1  conda-forge
rdma-core  50.0  hd3aeb46_1  conda-forge
re2  2023.09.01  h7f4b329_2  conda-forge
readline  8.2  h5eee18b_0
regex  2023.12.25  py310h2372a71_0  conda-forge
requests  2.31.0  pyhd8ed1ab_0  conda-forge
s2n  1.4.7  h06160fa_0  conda-forge
s3transfer  0.10.1  pyhd8ed1ab_0  conda-forge
safetensors  0.4.2  py310hcb5633a_0  conda-forge
setuptools  68.2.2  py310h06a4308_0
six  1.16.0  pyh6c4a22f_0  conda-forge
snappy  1.1.10  h9fff704_0  conda-forge
sqlite  3.41.2  h5eee18b_0
sympy  1.12  pyh04b8f61_3  conda-forge
tbb  2021.8.0  hdb19cb5_0  https://aws-ml-conda.s3.us-west-2.amazonaws.com
tk  8.6.12  h1ccaba5_0
tokenizers  0.15.2  py310h320607d_0  conda-forge
torchaudio  2.0.2  py310_cu122  https://aws-ml-conda.s3.us-west-2.amazonaws.com
torchtriton  2.0.0  py310  https://aws-ml-conda.s3.us-west-2.amazonaws.com
torchvision  0.15.2  py310_cu122  https://aws-ml-conda.s3.us-west-2.amazonaws.com
tqdm  4.66.2  pyhd8ed1ab_0  conda-forge
transformers  4.39.0  pyhd8ed1ab_0  conda-forge
typing-extensions  4.10.0  hd8ed1ab_0  conda-forge
typing_extensions  4.10.0  pyha770c72_0  conda-forge
tzdata  2024a  h04d1e81_0
ucx  1.15.0  h11edf95_7  conda-forge
urllib3  2.2.1  pyhd8ed1ab_0  conda-forge
wheel  0.41.2  py310h06a4308_0
xorg-libxau  1.0.11  hd590300_0  conda-forge
xorg-libxdmcp  1.1.3  h7f98852_0  conda-forge
xxhash  0.8.2  hd590300_0  conda-forge
xz  5.4.6  h5eee18b_0
yaml  0.2.5  h7f98852_2  conda-forge
yarl  1.9.4  py310h2372a71_0  conda-forge
zlib  1.2.13  hd590300_5  conda-forge
zstd  1.5.5  hfc55251_0  conda-forge
```

Package list without PR-#118

```text
# packages in environment at /fsx/marcverd/awsome-distributed-training/3.test_cases/10.FSDP/pt_fsdp_haha:
#
# Name  Version  Build  Channel
_libgcc_mutex  0.1  main
_openmp_mutex  5.1  1_gnu
aiohttp  3.9.3  pypi_0  pypi
aiosignal  1.3.1  pypi_0  pypi
async-timeout  4.0.3  pypi_0  pypi
attrs  23.2.0  pypi_0  pypi
bzip2  1.0.8  h5eee18b_5
ca-certificates  2024.3.11  h06a4308_0
certifi  2024.2.2  pypi_0  pypi
charset-normalizer  3.3.2  pypi_0  pypi
datasets  2.18.0  pypi_0  pypi
dill  0.3.8  pypi_0  pypi
filelock  3.13.1  pypi_0  pypi
frozenlist  1.4.1  pypi_0  pypi
fsspec  2024.2.0  pypi_0  pypi
huggingface-hub  0.21.4  pypi_0  pypi
idna  3.6  pypi_0  pypi
jinja2  3.1.3  pypi_0  pypi
ld_impl_linux-64  2.38  h1181459_1
libffi  3.4.4  h6a678d5_0
libgcc-ng  11.2.0  h1234567_1
libgomp  11.2.0  h1234567_1
libstdcxx-ng  11.2.0  h1234567_1
libuuid  1.41.5  h5eee18b_0
markupsafe  2.1.5  pypi_0  pypi
mpmath  1.3.0  pypi_0  pypi
multidict  6.0.5  pypi_0  pypi
multiprocess  0.70.16  pypi_0  pypi
ncurses  6.4  h6a678d5_0
networkx  3.2.1  pypi_0  pypi
numpy  1.26.4  pypi_0  pypi
nvidia-cublas-cu12  12.1.3.1  pypi_0  pypi
nvidia-cuda-cupti-cu12  12.1.105  pypi_0  pypi
nvidia-cuda-nvrtc-cu12  12.1.105  pypi_0  pypi
nvidia-cuda-runtime-cu12  12.1.105  pypi_0  pypi
nvidia-cudnn-cu12  8.9.2.26  pypi_0  pypi
nvidia-cufft-cu12  11.0.2.54  pypi_0  pypi
nvidia-curand-cu12  10.3.2.106  pypi_0  pypi
nvidia-cusolver-cu12  11.4.5.107  pypi_0  pypi
nvidia-cusparse-cu12  12.1.0.106  pypi_0  pypi
nvidia-nccl-cu12  2.19.3  pypi_0  pypi
nvidia-nvjitlink-cu12  12.4.99  pypi_0  pypi
nvidia-nvtx-cu12  12.1.105  pypi_0  pypi
openssl  3.0.13  h7f8727e_0
packaging  24.0  pypi_0  pypi
pandas  2.2.1  pypi_0  pypi
pillow  10.2.0  pypi_0  pypi
pip  23.3.1  py310h06a4308_0
pyarrow  15.0.2  pypi_0  pypi
pyarrow-hotfix  0.6  pypi_0  pypi
python  3.10.14  h955ad1f_0
python-dateutil  2.9.0.post0  pypi_0  pypi
pytz  2024.1  pypi_0  pypi
pyyaml  6.0.1  pypi_0  pypi
readline  8.2  h5eee18b_0
regex  2023.12.25  pypi_0  pypi
requests  2.31.0  pypi_0  pypi
safetensors  0.4.2  pypi_0  pypi
setuptools  68.2.2  py310h06a4308_0
six  1.16.0  pypi_0  pypi
sqlite  3.41.2  h5eee18b_0
sympy  1.12  pypi_0  pypi
tk  8.6.12  h1ccaba5_0
tokenizers  0.15.2  pypi_0  pypi
torch  2.2.1  pypi_0  pypi
torchaudio  2.2.1  pypi_0  pypi
torchvision  0.17.1  pypi_0  pypi
tqdm  4.66.2  pypi_0  pypi
transformers  4.39.1  pypi_0  pypi
triton  2.2.0  pypi_0  pypi
typing-extensions  4.10.0  pypi_0  pypi
tzdata  2024.1  pypi_0  pypi
urllib3  2.2.1  pypi_0  pypi
wheel  0.41.2  py310h06a4308_0
xxhash  3.4.1  pypi_0  pypi
xz  5.4.6  h5eee18b_0
yarl  1.9.4  pypi_0  pypi
zlib  1.2.13  h5eee18b_0
```

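
Comparing two `conda list` dumps by eye is error-prone, so a small script can surface the version differences. The following is a minimal sketch using only the Python standard library. It assumes the dumps have one package per line, as `conda list` prints them; the helper names `parse_conda_list` and `diff_versions` are hypothetical, and the two sample strings are short excerpts of the lists above rather than the full environments.

```python
def parse_conda_list(text: str) -> dict[str, str]:
    """Map package name -> version from `conda list` output."""
    packages = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and the header comments
        fields = line.split()
        if len(fields) >= 2:
            # Treat '_' and '-' as equivalent so that conda and pip
            # spellings of the same name (huggingface_hub vs
            # huggingface-hub) compare as one package.
            name = fields[0].lower().replace("_", "-")
            packages[name] = fields[1]
    return packages


def diff_versions(old: dict[str, str], new: dict[str, str]) -> dict[str, tuple[str, str]]:
    """Return {name: (old_version, new_version)} for shared packages that differ."""
    shared = sorted(old.keys() & new.keys())
    return {name: (old[name], new[name]) for name in shared if old[name] != new[name]}


# Excerpts from the two environments reported in this issue.
WITH_PR_118 = """
# Name Version Build Channel
fsspec 2023.9.2 pyh1a96a4e_0 conda-forge
numpy 1.26.4 py310hb13e2d6_0 conda-forge
pillow 9.4.0 py310h023d228_1 conda-forge
transformers 4.39.0 pyhd8ed1ab_0 conda-forge
"""

WITHOUT_PR_118 = """
# Name Version Build Channel
fsspec 2024.2.0 pypi_0 pypi
numpy 1.26.4 pypi_0 pypi
pillow 10.2.0 pypi_0 pypi
transformers 4.39.1 pypi_0 pypi
"""

if __name__ == "__main__":
    changed = diff_versions(parse_conda_list(WITH_PR_118), parse_conda_list(WITHOUT_PR_118))
    for name, (old_version, new_version) in changed.items():
        print(f"{name}: {old_version} -> {new_version}")
```

Note that this name-based comparison cannot pair the conda package `pytorch` with the PyPI package `torch`, since the two names differ; those entries would need a manual mapping.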
johnbensnyder commented 5 months ago

Opened a PR that fixes this issue and also updates the checkpointing functions to work with PyTorch>=2.1.