Pax is a Jax-based machine learning framework for training large scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry leading model flop utilization rates.
Pipeline Parallelism: F external/org_tensorflow/tensorflow/compiler/xla/array.h:446] Check failed: n < sizes_size Fatal Python error: Aborted #4
Hello!

I am trying to implement a 126-million-parameter GPT-3 with pipeline parallelism on PAXML. I run into errors when NUM_MICROBATCHES > 1.

System:
8x NVIDIA A100-SXM 80 GB

Gin Configs:

Command:

XLA Compile Time Error:

2022-10-10 16:01:05.537760: F external/org_tensorflow/tensorflow/compiler/xla/array.h:446] Check failed: n < sizes_size
Fatal Python error: Aborted

Current thread 0x00007f5c10b73740 (most recent call first):
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 940 in backend_compile
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py", line 294 in wrapper
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 996 in compile_or_get_cached
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 3048 in from_hlo
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 2890 in compile
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 815 in _pjit_call_impl
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 685 in process_primitive
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 327 in bind_with_trace
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 324 in bind
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 385 in wrapped
  File "/pax/paxml/paxml/train.py", line 1087 in train_and_evaluate_spmd_model
  File "/pax/paxml/paxml/train.py", line 271 in train_and_evaluate
  File "/pax/paxml/paxml/main.py", line 290 in run_experiment
  File "/pax/paxml/paxml/main.py", line 535 in run
  File "/usr/local/lib/python3.8/dist-packages/gin/config.py", line 1582 in gin_wrapper
  File "/pax/paxml/paxml/main.py", line 588 in main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 251 in _run_main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 303 in run
  File "/pax/paxml/paxml/main.py", line 631 in <module>

There is no problem when NUM_MICROBATCHES = 1.
It would be great if someone could look into this to figure out what may be causing XLA to break when using NUM_MICROBATCHES > 1.
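For context on where this class of failure surfaces: the topmost frame in the trace is `backend_compile`, i.e. the abort happens while XLA compiles the pjit'ed train step, before any training step actually runs. A minimal sketch (not the reporter's model or config; `step` and its input shape are made up for illustration) of how JAX reaches that code path, using explicit lower/compile to surface compile-time errors:

```python
import jax
import jax.numpy as jnp

# JAX compiles a jitted (or pjit'ed) function lazily the first time it is
# called with concrete shapes; that compilation is the `backend_compile`
# frame where the Check failure in the trace fires.
@jax.jit
def step(x):
    return jnp.sum(x * 2.0)

# Lowering and compiling explicitly triggers XLA compilation up front,
# which can help isolate compile-time crashes from runtime ones.
compiled = step.lower(jnp.ones((8, 4))).compile()
print(compiled(jnp.ones((8, 4))))  # prints 64.0
```

Since the abort is inside XLA's C++ (`array.h:446`), not Python, a narrowed repro like this (with the real sharding/microbatch settings substituted in) is usually what the XLA folks need to debug it.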