google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

T5x+jax pretraining longt5xl-3B 4096 input oom #21985

Open robotzheng opened 1 week ago

robotzheng commented 1 week ago

Description

Arguments

T5_SIZE='xl'                 #$1 # Model size (small, base, large)
PREC="bfloat16"              #"$2" # Precision (float32, float16, bfloat16)
NUM_GPUS=8                   #$3 # Number of GPUs (1, 2, 4, 8)
BSIZE_PER_GPU=1              # Size per GPU (varies with model size)
LOG_DIR="./logs/"            #$5 # Output log directory
MODEL_DIR_LOCAL=${6:-"model_dir"}
MODEL_DIR=${PWD}/${MODEL_DIR_LOCAL}
NUM_MICROBATCHES=${7:-0}

echo $MODEL_DIR

echo "Please make sure ${NUM_GPUS} is the number of visible CUDA devices you have"

Setting XLA flags

export XLA_FLAGS="--xla_gpu_simplify_all_fp_conversions --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"

export XLA_FLAGS="--xla_dump_to=/tmp/foo --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"

Global batch size

BSIZE=$(( NUM_GPUS * BSIZE_PER_GPU ))

System info (python version, jaxlib version, accelerator, etc.)

I0620 01:00:21.223034 140533293238080 utils.py:1085] Initializing parameters from scratch.
2024-06-20 01:03:26.100817: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 26.14GiB (28065947629 bytes) by rematerialization; only reduced to 34.30GiB (36831593796 bytes), down from 34.44GiB (36982730932 bytes) originally
Fatal Python error: Segmentation fault
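The byte counts in the rematerialization warning above can be checked with a little arithmetic. This is only a sketch: the variable names are mine, and reading the 26.14GiB figure as the target the pass was trying to reach is an assumption based on the wording of the message, not something confirmed in this thread.

```python
GIB = 1024 ** 3  # bytes per GiB
MIB = 1024 ** 2  # bytes per MiB


def to_gib(num_bytes: int) -> float:
    """Convert a byte count to GiB."""
    return num_bytes / GIB


# Byte counts copied from the hlo_rematerialization warning in the log.
target_bytes = 28065947629   # "Can't reduce memory use below ..." (assumed target)
after_bytes = 36831593796    # "only reduced to ..."
before_bytes = 36982730932   # "down from ... originally"

# How much memory rematerialization actually recovered.
saved_mib = (before_bytes - after_bytes) / MIB
# How far the program still is from the assumed target.
shortfall_gib = to_gib(after_bytes - target_bytes)

print(f"target: {to_gib(target_bytes):.2f} GiB")
print(f"after:  {to_gib(after_bytes):.2f} GiB")
print(f"before: {to_gib(before_bytes):.2f} GiB")
print(f"saved by rematerialization: {saved_mib:.1f} MiB")
print(f"still above target by: {shortfall_gib:.2f} GiB")
```

Under these assumptions, rematerialization recovered only about 144 MiB, leaving the program roughly 8 GiB above the figure quoted in the warning.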

Thread 0x00007f4ca97fa640 (most recent call first):
  File "/opt/conda/envs/EN_LAM/lib/python3.10/concurrent/futures/thread.py", line 81 in _worker
  File "/opt/conda/envs/EN_LAM/lib/python3.10/threading.py", line 953 in run
  File "/opt/conda/envs/EN_LAM/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/opt/conda/envs/EN_LAM/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007fd07505c740 (most recent call first):
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1253 in call
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/jax/_src/profiler.py", line 335 in wrapper
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/jax/_src/pjit.py", line 1568 in _pjit_call_impl_python
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/jax/_src/pjit.py", line 1614 in call_impl_cache_miss
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/jax/_src/pjit.py", line 1635 in _pjit_call_impl
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/jax/_src/core.py", line 921 in process_primitive
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/jax/_src/core.py", line 420 in bind_with_trace
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/jax/_src/core.py", line 2834 in bind
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/jax/_src/pjit.py", line 185 in _python_pjit_helper
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/jax/_src/pjit.py", line 327 in cache_miss
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179 in reraise_with_filtered_traceback
  File "/home/notebook/code/personal/80306170/AGI/LAM/TrainTramwork/html_t5x/t5x/partitioning.py", line 954 in call
  File "/home/notebook/code/personal/80306170/AGI/LAM/TrainTramwork/html_t5x/t5x/utils.py", line 1096 in from_scratch
  File "/home/notebook/code/personal/80306170/AGI/LAM/TrainTramwork/html_t5x/t5x/train.py", line 389 in train
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/gin/config.py", line 1582 in gin_wrapper
  File "/home/notebook/code/personal/80306170/AGI/LAM/TrainTramwork/html_t5x/t5x/train.py", line 961 in _main
  File "/home/notebook/code/personal/80306170/AGI/LAM/TrainTramwork/html_t5x/t5x/train.py", line 900 in main
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/absl/app.py", line 254 in _run_main
  File "/opt/conda/envs/EN_LAM/lib/python3.10/site-packages/absl/app.py", line 308 in run
  File "/home/notebook/code/personal/80306170/AGI/LAM/TrainTramwork/html_t5x/t5x/utils.py", line 2387 in run_main
  File "/home/notebook/code/personal/80306170/AGI/LAM/TrainTramwork/html_t5x/t5x/gin_utils.py", line 135 in run
  File "/home/notebook/code/personal/80306170/AGI/LAM/TrainTramwork/html_t5x/t5x/config_utils.py", line 215 in run
  File "/home/notebook/code/personal/80306170/AGI/LAM/TrainTramwork/html_t5x/t5x/train.py", line 966 in

Extension modules: jaxlib.cpu_feature_guard, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, zstandard.backend_c

import jax; jax.print_environment_info()
jax:    0.4.30.dev20240620
jaxlib: 0.4.30.dev20240619
numpy:  1.26.4
python: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='task-20240618090647-17588', release='3.10.0-957.27.2.el7.x86_64', version='#1 SMP Mon Jul 29 17:46:05 UTC 2019', machine='x86_64')

$ nvidia-smi
Thu Jun 20 01:13:15 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A800-SXM4-80GB          On  | 00000000:1E:00.0 Off |                    0 |
| N/A   50C    P0             384W / 400W |  29056MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A800-SXM4-80GB          On  | 00000000:24:00.0 Off |                    0 |
| N/A   66C    P0             372W / 400W |  29032MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A800-SXM4-80GB          On  | 00000000:4F:00.0 Off |                    0 |
| N/A   64C    P0             395W / 400W |  29136MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A800-SXM4-80GB          On  | 00000000:54:00.0 Off |                    0 |
| N/A   48C    P0             401W / 400W |  29208MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA A800-SXM4-80GB          On  | 00000000:90:00.0 Off |                    0 |
| N/A   52C    P0             365W / 400W |  29204MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA A800-SXM4-80GB          On  | 00000000:95:00.0 Off |                    0 |
| N/A   65C    P0             360W / 400W |  29136MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA A800-SXM4-80GB          On  | 00000000:CB:00.0 Off |                    0 |
| N/A   71C    P0             375W / 400W |  29204MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA A800-SXM4-80GB          On  | 00000000:D1:00.0 Off |                    0 |
| N/A   52C    P0             377W / 400W |  29040MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
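To put the nvidia-smi numbers above in proportion, here is a small sketch that computes how much of each GPU's memory was in use when the snapshot was taken. The figures are copied from the table; the variable names are mine, and the snapshot alone does not say which process held the memory, since the process list is empty.

```python
# Memory-Usage column from the nvidia-smi output, in MiB, GPUs 0 through 7.
used_mib = [29056, 29032, 29136, 29208, 29204, 29136, 29204, 29040]
TOTAL_MIB = 81920  # reported capacity of each A800-SXM4-80GB

# Fraction of each device's memory in use at snapshot time.
fractions = [used / TOTAL_MIB for used in used_mib]

for gpu, (used, frac) in enumerate(zip(used_mib, fractions)):
    free_gib = (TOTAL_MIB - used) / 1024
    print(f"GPU {gpu}: {used} MiB used ({frac:.1%}), about {free_gib:.1f} GiB free")
```

Every device sits at roughly 35% of its capacity in this snapshot.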

justinjfu commented 1 week ago

Could you provide more details on the issue? Did this training script work previously and is this architecture expected to fit in the memory of a single 4090 GPU?

robotzheng commented 1 week ago

I use 8 A100 (80 GB) GPUs. The input length is 4096 and the output length is 910. I set the following:

export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_MEM_FRACTION=.99
export XLA_PYTHON_CLIENT_ALLOCATOR=platform

With these three lines, training runs only with:

BSIZE_PER_GPU=1 # Size per GPU (varies with model size)

When BSIZE_PER_GPU>1, JAX runs out of memory. I use t5x + jax + xla + flaxformer.

Thanks.

justinjfu commented 1 week ago

It looks like JAX is not fully utilizing all of the GPUs' memory. You could try looking into some of the tips at https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html. This could include:

I would also potentially recommend posting this issue in the t5x github repository (https://github.com/google-research/t5x) and you may get a better response there.
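For reference, a minimal sketch of applying the three allocator settings quoted earlier in this thread from Python instead of the shell. The helper name `configure_gpu_memory` is hypothetical; only the environment variable names come from the thread and the linked page. The settings need to be in place before `jax` is first imported in the process, so the sketch deliberately does not import it.

```python
import os
from typing import Dict, Optional


def configure_gpu_memory(
    preallocate: bool,
    mem_fraction: Optional[float] = None,
    allocator: Optional[str] = None,
) -> Dict[str, str]:
    """Hypothetical helper: export JAX GPU allocator settings via os.environ.

    Call this before the first `import jax` in the process.
    """
    settings = {
        "XLA_PYTHON_CLIENT_PREALLOCATE": "true" if preallocate else "false",
    }
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        settings["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    if allocator is not None:
        settings["XLA_PYTHON_CLIENT_ALLOCATOR"] = allocator
    os.environ.update(settings)
    return settings


# The combination reported in this issue.
applied = configure_gpu_memory(preallocate=False, mem_fraction=0.99, allocator="platform")
for name, value in applied.items():
    print(f"{name}={value}")
```

As I read the linked page, the `platform` allocator allocates and frees memory on demand and is described as slow, so it may be worth comparing against a run that leaves `allocator` unset.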