verdimrc closed this 5 months ago.
PR #118 causes the error below, and I had to revert it to get the example to work.

    File "/fsx/marcverd/awsome-distributed-training/3.test_cases/10.FSDP/pt_fsdp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1076, in _update_causal_mask
        causal_mask = torch.triu(causal_mask, diagonal=1)
    RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'

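For context, the failing call `torch.triu(causal_mask, diagonal=1)` keeps the entries strictly above the main diagonal (column > row) and zeroes the rest; in a causal mask those are the future positions a token must not attend to. The error message suggests the CUDA kernel in this PyTorch build simply has no `BFloat16` implementation, not that the operation itself is unusual. The pure-Python sketch below only illustrates the 2-D `triu` semantics on nested lists; it is not the PyTorch kernel, and the 4x4 mask filled with `-inf` is a hypothetical stand-in for the mask transformers builds.

```python
def triu(matrix, diagonal=0):
    """Keep entries where col - row >= diagonal and zero the rest (2-D torch.triu semantics)."""
    return [
        [value if col - row >= diagonal else 0 for col, value in enumerate(row_values)]
        for row, row_values in enumerate(matrix)
    ]


# Hypothetical 4x4 causal mask: start fully masked with -inf, then keep only the
# strictly-upper triangle (future positions) masked, which is what diagonal=1 does.
NEG_INF = float("-inf")
full = [[NEG_INF] * 4 for _ in range(4)]
causal_mask = triu(full, diagonal=1)

for row in causal_mask:
    print(row)
# [0, -inf, -inf, -inf]
# [0, 0, -inf, -inf]
# [0, 0, 0, -inf]
# [0, 0, 0, 0]
```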
Package versions (TL;DR):

- With PR #118: torch-2.0.1, transformers-4.39.0
- Without PR #118 (reverted): torch-2.2.1, transformers-4.39.1
Opened a PR that fixes this issue and also updates the checkpointing functions to work with PyTorch>=2.1.
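Since the fix targets PyTorch>=2.1, and the failing environment had PyTorch 2.0.1 while the working one had 2.2.1, a fail-fast check at startup could turn the deep `RuntimeError` above into a clear message. This is a hypothetical sketch: `check_torch_version` and `parse_major_minor` are illustrative names, not functions from the repository, and in a real training script the argument would be `torch.__version__` (not imported here, so the snippet runs without PyTorch).

```python
import re

MIN_TORCH = (2, 1)  # threshold taken from the PR description (PyTorch>=2.1)


def parse_major_minor(version: str) -> tuple[int, int]:
    """Extract (major, minor) from strings such as '2.0.1', '2.2.1+cu121' or '2.1.0a0+git1234'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def check_torch_version(version: str, minimum: tuple[int, int] = MIN_TORCH) -> None:
    """Raise RuntimeError early instead of failing inside the model code."""
    if parse_major_minor(version) < minimum:
        raise RuntimeError(
            f"PyTorch {version} found; this example expects PyTorch>="
            f"{minimum[0]}.{minimum[1]}"
        )


# Hypothetical usage in the training script:
#   import torch
#   check_torch_version(torch.__version__)
```

Comparing `(major, minor)` tuples keeps the check independent of local build suffixes such as `+cu121`.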
Package list with PR #118:
```text
# packages in environment at /fsx/marcverd/awsome-distributed-training/3.test_cases/10.FSDP/pt_fsdp:
#
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_kmp_llvm conda-forge
aiohttp 3.9.3 py310h2372a71_1 conda-forge
aiosignal 1.3.1 pyhd8ed1ab_0 conda-forge
async-timeout 4.0.3 pyhd8ed1ab_0 conda-forge
attrs 23.2.0 pyh71513ae_0 conda-forge
aws-c-auth 0.7.16 haed3651_8 conda-forge
aws-c-cal 0.6.10 ha9bf9b1_2 conda-forge
aws-c-common 0.9.14 hd590300_0 conda-forge
aws-c-compression 0.2.18 h4466546_2 conda-forge
aws-c-event-stream 0.4.2 he635cd5_6 conda-forge
aws-c-http 0.8.1 hbfc29b2_7 conda-forge
aws-c-io 0.14.6 h6b388c4_1 conda-forge
aws-c-mqtt 0.10.3 hffff1cc_2 conda-forge
aws-c-s3 0.5.2 h4893938_2 conda-forge
aws-c-sdkutils 0.1.15 h4466546_2 conda-forge
aws-checksums 0.1.18 h4466546_2 conda-forge
aws-crt-cpp 0.26.3 h137ae52_2 conda-forge
aws-ofi-nccl 1.7.4 aws_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com
aws-sdk-cpp 1.11.267 he0cb598_3 conda-forge
blas 2.116 mkl conda-forge
blas-devel 3.9.0 16_linux64_mkl conda-forge
boto3 1.34.67 pyhd8ed1ab_0 conda-forge
botocore 1.34.67 pyge310_1234567_0 conda-forge
brotli-python 1.1.0 py310hc6cd4ac_1 conda-forge
bzip2 1.0.8 h5eee18b_5
c-ares 1.27.0 hd590300_0 conda-forge
ca-certificates 2024.3.11 h06a4308_0
certifi 2024.2.2 pyhd8ed1ab_0 conda-forge
charset-normalizer 3.3.2 pyhd8ed1ab_0 conda-forge
colorama 0.4.6 pyhd8ed1ab_0 conda-forge
cuda-cudart 12.2.140 0 nvidia
cuda-cupti 12.2.142 0 nvidia
cuda-libraries 12.2.2 0 nvidia
cuda-nvrtc 12.2.140 0 nvidia
cuda-nvtx 12.2.140 0 nvidia
cuda-opencl 12.4.99 0 nvidia
cuda-runtime 12.2.2 0 nvidia
datasets 2.18.0 pyhd8ed1ab_0 conda-forge
dill 0.3.8 pyhd8ed1ab_0 conda-forge
ffmpeg 4.2 h3fd9d12_1 https://aws-ml-conda.s3.us-west-2.amazonaws.com
filelock 3.13.1 pyhd8ed1ab_0 conda-forge
freetype 2.12.1 h267a509_2 conda-forge
frozenlist 1.4.1 py310h2372a71_0 conda-forge
fsspec 2023.9.2 pyh1a96a4e_0 conda-forge
gettext 0.21.1 h27087fc_0 conda-forge
gflags 2.2.2 he1b5a44_1004 conda-forge
glog 0.7.0 hed5481d_0 conda-forge
gmp 6.3.0 h59595ed_1 conda-forge
gnutls 3.6.15 he1e5248_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com
huggingface_hub 0.21.4 pyhd8ed1ab_0 conda-forge
hwloc 2.9.2 h2bc3f7f_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com
icu 73.2 h59595ed_0 conda-forge
idna 3.6 pyhd8ed1ab_0 conda-forge
jinja2 3.1.3 pyhd8ed1ab_0 conda-forge
jmespath 1.0.1 pyhd8ed1ab_0 conda-forge
jpeg 9e h0b41bf4_3 conda-forge
keyutils 1.6.1 h166bdaf_0 conda-forge
krb5 1.21.2 h659d440_0 conda-forge
lame 3.100 h166bdaf_1003 conda-forge
lcms2 2.15 hfd0df8a_0 conda-forge
ld_impl_linux-64 2.38 h1181459_1
lerc 4.0.0 h27087fc_0 conda-forge
libabseil 20240116.1 cxx17_h59595ed_2 conda-forge
libarrow 15.0.2 h6bfc85a_0_cpu conda-forge
libarrow-acero 15.0.2 h59595ed_0_cpu conda-forge
libarrow-dataset 15.0.2 h59595ed_0_cpu conda-forge
libarrow-flight 15.0.2 hc6145d9_0_cpu conda-forge
libarrow-flight-sql 15.0.2 h757c851_0_cpu conda-forge
libarrow-gandiva 15.0.2 hb016d2e_0_cpu conda-forge
libarrow-substrait 15.0.2 h757c851_0_cpu conda-forge
libblas 3.9.0 16_linux64_mkl conda-forge
libbrotlicommon 1.1.0 hd590300_1 conda-forge
libbrotlidec 1.1.0 hd590300_1 conda-forge
libbrotlienc 1.1.0 hd590300_1 conda-forge
libcblas 3.9.0 16_linux64_mkl conda-forge
libcrc32c 1.1.2 h9c3ff4c_0 conda-forge
libcublas 12.2.5.6 0 nvidia
libcufft 11.0.8.103 0 nvidia
libcufile 1.9.0.20 0 nvidia
libcurand 10.3.5.119 0 nvidia
libcurl 8.6.0 hca28451_0 conda-forge
libcusolver 11.5.2.141 0 nvidia
libcusparse 12.1.2.141 0 nvidia
libdeflate 1.17 h0b41bf4_0 conda-forge
libedit 3.1.20191231 he28a2e2_2 conda-forge
libev 4.33 hd590300_2 conda-forge
libevent 2.1.12 hf998b51_1 conda-forge
libffi 3.4.4 h6a678d5_0
libgcc-ng 13.2.0 h807b86a_5 conda-forge
libgfortran-ng 13.2.0 h69a702a_5 conda-forge
libgfortran5 13.2.0 ha4646dd_5 conda-forge
libgomp 13.2.0 h807b86a_5 conda-forge
libgoogle-cloud 2.22.0 h9be4e54_1 conda-forge
libgoogle-cloud-storage 2.22.0 hc7a4891_1 conda-forge
libgrpc 1.62.1 h15f2491_0 conda-forge
libiconv 1.17 hd590300_2 conda-forge
libidn2 2.3.7 hd590300_0 conda-forge
liblapack 3.9.0 16_linux64_mkl conda-forge
liblapacke 3.9.0 16_linux64_mkl conda-forge
libllvm16 16.0.6 h5cf9203_2 conda-forge
libnghttp2 1.58.0 h47da74e_1 conda-forge
libnl 3.9.0 hd590300_0 conda-forge
libnpp 12.2.1.4 0 nvidia
libnvjitlink 12.2.140 0 nvidia
libnvjpeg 12.2.2.4 0 nvidia
libparquet 15.0.2 h352af49_0_cpu conda-forge
libpng 1.6.43 h2797004_0 conda-forge
libprotobuf 4.25.3 h08a7969_0 conda-forge
libre2-11 2023.09.01 h5a48ba9_2 conda-forge
libssh2 1.11.0 h0841786_0 conda-forge
libstdcxx-ng 13.2.0 h7e041cc_5 conda-forge
libtasn1 4.19.0 h166bdaf_0 conda-forge
libthrift 0.19.0 hb90f79a_1 conda-forge
libtiff 4.5.0 h6adf6a1_2 conda-forge
libunistring 0.9.10 h7f98852_0 conda-forge
libutf8proc 2.8.0 h166bdaf_0 conda-forge
libuuid 1.41.5 h5eee18b_0
libwebp-base 1.3.2 hd590300_0 conda-forge
libxcb 1.13 h7f98852_1004 conda-forge
libxml2 2.11.6 h232c23b_0 conda-forge
libzlib 1.2.13 hd590300_5 conda-forge
llvm-openmp 15.0.7 h0cdce71_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com
lz4-c 1.9.4 hcb278e6_0 conda-forge
markupsafe 2.1.5 py310h2372a71_0 conda-forge
mkl 2022.1.0 h84fe81f_915 https://aws-ml-conda.s3.us-west-2.amazonaws.com
mkl-devel 2022.1.0 ha770c72_916 conda-forge
mkl-include 2022.1.0 h84fe81f_915 conda-forge
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
multidict 6.0.5 py310h2372a71_0 conda-forge
multiprocess 0.70.16 py310h2372a71_0 conda-forge
ncurses 6.4 h6a678d5_0
nettle 3.7.3 hbbd107a_1 https://aws-ml-conda.s3.us-west-2.amazonaws.com
networkx 3.2.1 pyhd8ed1ab_0 conda-forge
numpy 1.26.4 py310hb13e2d6_0 conda-forge
openh264 2.1.1 h780b84a_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com
openjpeg 2.5.0 hfec8fc6_2 conda-forge
openssl 3.2.1 hd590300_1 conda-forge
orc 2.0.0 h1e5e2c1_0 conda-forge
packaging 24.0 pyhd8ed1ab_0 conda-forge
pandas 2.2.1 py310hcc13569_0 conda-forge
pillow 9.4.0 py310h023d228_1 conda-forge
pip 23.3.1 py310h06a4308_0
pthread-stubs 0.4 h36c2ea0_1001 conda-forge
pyarrow 15.0.2 py310hf9e7431_0_cpu conda-forge
pyarrow-hotfix 0.6 pyhd8ed1ab_0 conda-forge
pysocks 1.7.1 pyha2e5f31_6 conda-forge
python 3.10.13 h955ad1f_0
python-dateutil 2.9.0 pyhd8ed1ab_0 conda-forge
python-tzdata 2024.1 pyhd8ed1ab_0 conda-forge
python-xxhash 3.4.1 py310h2372a71_0 conda-forge
python_abi 3.10 2_cp310 conda-forge
pytorch 2.0.1 aws_py3.10_cuda12.2_cudnn8.9.4_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com
pytorch-cuda 12.2 h5ef38aa_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com
pytorch-mutex 1.0 cuda https://aws-ml-conda.s3.us-west-2.amazonaws.com
pytz 2024.1 pyhd8ed1ab_0 conda-forge
pyyaml 6.0.1 py310h2372a71_1 conda-forge
rdma-core 50.0 hd3aeb46_1 conda-forge
re2 2023.09.01 h7f4b329_2 conda-forge
readline 8.2 h5eee18b_0
regex 2023.12.25 py310h2372a71_0 conda-forge
requests 2.31.0 pyhd8ed1ab_0 conda-forge
s2n 1.4.7 h06160fa_0 conda-forge
s3transfer 0.10.1 pyhd8ed1ab_0 conda-forge
safetensors 0.4.2 py310hcb5633a_0 conda-forge
setuptools 68.2.2 py310h06a4308_0
six 1.16.0 pyh6c4a22f_0 conda-forge
snappy 1.1.10 h9fff704_0 conda-forge
sqlite 3.41.2 h5eee18b_0
sympy 1.12 pyh04b8f61_3 conda-forge
tbb 2021.8.0 hdb19cb5_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com
tk 8.6.12 h1ccaba5_0
tokenizers 0.15.2 py310h320607d_0 conda-forge
torchaudio 2.0.2 py310_cu122 https://aws-ml-conda.s3.us-west-2.amazonaws.com
torchtriton 2.0.0 py310 https://aws-ml-conda.s3.us-west-2.amazonaws.com
torchvision 0.15.2 py310_cu122 https://aws-ml-conda.s3.us-west-2.amazonaws.com
tqdm 4.66.2 pyhd8ed1ab_0 conda-forge
transformers 4.39.0 pyhd8ed1ab_0 conda-forge
typing-extensions 4.10.0 hd8ed1ab_0 conda-forge
typing_extensions 4.10.0 pyha770c72_0 conda-forge
tzdata 2024a h04d1e81_0
ucx 1.15.0 h11edf95_7 conda-forge
urllib3 2.2.1 pyhd8ed1ab_0 conda-forge
wheel 0.41.2 py310h06a4308_0
xorg-libxau 1.0.11 hd590300_0 conda-forge
xorg-libxdmcp 1.1.3 h7f98852_0 conda-forge
xxhash 0.8.2 hd590300_0 conda-forge
xz 5.4.6 h5eee18b_0
yaml 0.2.5 h7f98852_2 conda-forge
yarl 1.9.4 py310h2372a71_0 conda-forge
zlib 1.2.13 hd590300_5 conda-forge
zstd 1.5.5 hfc55251_0 conda-forge
```

Package list without PR #118:
```text
# packages in environment at /fsx/marcverd/awsome-distributed-training/3.test_cases/10.FSDP/pt_fsdp_haha:
#
# Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
aiohttp 3.9.3 pypi_0 pypi
aiosignal 1.3.1 pypi_0 pypi
async-timeout 4.0.3 pypi_0 pypi
attrs 23.2.0 pypi_0 pypi
bzip2 1.0.8 h5eee18b_5
ca-certificates 2024.3.11 h06a4308_0
certifi 2024.2.2 pypi_0 pypi
charset-normalizer 3.3.2 pypi_0 pypi
datasets 2.18.0 pypi_0 pypi
dill 0.3.8 pypi_0 pypi
filelock 3.13.1 pypi_0 pypi
frozenlist 1.4.1 pypi_0 pypi
fsspec 2024.2.0 pypi_0 pypi
huggingface-hub 0.21.4 pypi_0 pypi
idna 3.6 pypi_0 pypi
jinja2 3.1.3 pypi_0 pypi
ld_impl_linux-64 2.38 h1181459_1
libffi 3.4.4 h6a678d5_0
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libstdcxx-ng 11.2.0 h1234567_1
libuuid 1.41.5 h5eee18b_0
markupsafe 2.1.5 pypi_0 pypi
mpmath 1.3.0 pypi_0 pypi
multidict 6.0.5 pypi_0 pypi
multiprocess 0.70.16 pypi_0 pypi
ncurses 6.4 h6a678d5_0
networkx 3.2.1 pypi_0 pypi
numpy 1.26.4 pypi_0 pypi
nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
nvidia-cudnn-cu12 8.9.2.26 pypi_0 pypi
nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
nvidia-nccl-cu12 2.19.3 pypi_0 pypi
nvidia-nvjitlink-cu12 12.4.99 pypi_0 pypi
nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
openssl 3.0.13 h7f8727e_0
packaging 24.0 pypi_0 pypi
pandas 2.2.1 pypi_0 pypi
pillow 10.2.0 pypi_0 pypi
pip 23.3.1 py310h06a4308_0
pyarrow 15.0.2 pypi_0 pypi
pyarrow-hotfix 0.6 pypi_0 pypi
python 3.10.14 h955ad1f_0
python-dateutil 2.9.0.post0 pypi_0 pypi
pytz 2024.1 pypi_0 pypi
pyyaml 6.0.1 pypi_0 pypi
readline 8.2 h5eee18b_0
regex 2023.12.25 pypi_0 pypi
requests 2.31.0 pypi_0 pypi
safetensors 0.4.2 pypi_0 pypi
setuptools 68.2.2 py310h06a4308_0
six 1.16.0 pypi_0 pypi
sqlite 3.41.2 h5eee18b_0
sympy 1.12 pypi_0 pypi
tk 8.6.12 h1ccaba5_0
tokenizers 0.15.2 pypi_0 pypi
torch 2.2.1 pypi_0 pypi
torchaudio 2.2.1 pypi_0 pypi
torchvision 0.17.1 pypi_0 pypi
tqdm 4.66.2 pypi_0 pypi
transformers 4.39.1 pypi_0 pypi
triton 2.2.0 pypi_0 pypi
typing-extensions 4.10.0 pypi_0 pypi
tzdata 2024.1 pypi_0 pypi
urllib3 2.2.1 pypi_0 pypi
wheel 0.41.2 py310h06a4308_0
xxhash 3.4.1 pypi_0 pypi
xz 5.4.6 h5eee18b_0
yarl 1.9.4 pypi_0 pypi
zlib 1.2.13 h5eee18b_0
```
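Comparing two `conda list` dumps by eye is error-prone, so here is a minimal sketch for diffing them, assuming the plain `conda list` format shown above (`#` comment lines, then whitespace-separated `name version build [channel]`). The helper names are hypothetical, and the `pytorch` to `torch` alias is an assumption to line up conda's package name with pip's; names are also lowercased and `_` mapped to `-` so that `huggingface_hub` and `huggingface-hub` match. The sample rows are copied from the two lists above.

```python
ALIASES = {"pytorch": "torch"}  # assumption: conda's "pytorch" corresponds to pip's "torch"


def normalize(name: str) -> str:
    """Lowercase, map '_' to '-', then apply known conda/pip aliases."""
    name = name.lower().replace("_", "-")
    return ALIASES.get(name, name)


def parse_conda_list(text: str) -> dict[str, str]:
    """Map normalized package name -> version, skipping blank and '#' lines."""
    packages = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        fields = line.split()
        if len(fields) >= 2:
            packages[normalize(fields[0])] = fields[1]
    return packages


def diff_envs(before: str, after: str) -> dict[str, tuple]:
    """Return {name: (version_before, version_after)} for packages that differ (None if absent)."""
    a, b = parse_conda_list(before), parse_conda_list(after)
    return {
        name: (a.get(name), b.get(name))
        for name in sorted(a.keys() | b.keys())
        if a.get(name) != b.get(name)
    }


with_pr = """
# Name Version Build Channel
pytorch 2.0.1 aws_py3.10_cuda12.2_cudnn8.9.4_0 https://aws-ml-conda.s3.us-west-2.amazonaws.com
transformers 4.39.0 pyhd8ed1ab_0 conda-forge
tokenizers 0.15.2 py310h320607d_0 conda-forge
"""
without_pr = """
# Name Version Build Channel
torch 2.2.1 pypi_0 pypi
transformers 4.39.1 pypi_0 pypi
tokenizers 0.15.2 pypi_0 pypi
"""

for name, (old, new) in diff_envs(with_pr, without_pr).items():
    print(f"{name}: {old} -> {new}")
# torch: 2.0.1 -> 2.2.1
# transformers: 4.39.0 -> 4.39.1
```

Pasting the full dumps into `with_pr` and `without_pr` would also surface packages present in only one environment, shown with `None` on the missing side.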