qGentry opened 6 months ago
Here is a toy repro, tested on JAX 0.4.34.
Latency hiding scheduler enabled, XLA_FLAGS=--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_command_buffer= --xla_gpu_enable_latency_hiding_scheduler=true
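For reference, an invocation with this flag set might look like the following sketch. The script name `repro.py` and the `SCAN` environment variable are illustrative assumptions; the issue does not show the actual launch command:

```shell
# Hypothetical launch command; flag set copied verbatim from above,
# script name and SCAN env var are assumptions for illustration.
XLA_FLAGS="--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false \
--xla_gpu_enable_command_buffer= --xla_gpu_enable_latency_hiding_scheduler=true" \
SCAN=False python repro.py
```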
SCAN=False (hits OOM):
2024-10-17 10:47:03.275923: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 53.46GiB (57397316392 bytes) by rematerialization; only reduced to 70.25GiB (75430461976 bytes), down from 75.50GiB (81067606600 bytes) originally
Total size device = 54.500054121017456 GB, weights = 16.500030532479286 GB, total: 71.00008465349674 GB
SCAN=True:
Total size device = 34.5000324845314 GB, weights = 16.500030532479286 GB, total: 51.00006301701069 GB
Latency hiding scheduler disabled, XLA_FLAGS=--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_command_buffer=
SCAN=False:
Total size device = 37.50002360343933 GB, weights = 16.500030532479286 GB, total: 54.00005413591862 GB
SCAN=True:
Total size device = 35.00000220537186 GB, weights = 16.500030532479286 GB, total: 51.50003273785114 GB
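The exact reporting code behind the "Total size device" numbers above is not shown in the issue, but one way to get comparable per-executable memory estimates from XLA is `Compiled.memory_analysis()`. A minimal sketch (the toy function `f` is an illustrative stand-in, not the repro model):

```python
import jax
import jax.numpy as jnp

# Toy function standing in for the real model; shapes are arbitrary.
def f(x):
    return jnp.sin(x) @ x.T

x = jnp.ones((256, 256))
compiled = jax.jit(f).lower(x).compile()

# memory_analysis() exposes XLA's CompiledMemoryStats;
# it may return None on backends that don't support it.
stats = compiled.memory_analysis()
if stats is not None:
    total = (stats.temp_size_in_bytes
             + stats.argument_size_in_bytes
             + stats.output_size_in_bytes)
    print(f"device total ~ {total / 2**30:.3f} GiB")

out = compiled(x)
```

Comparing these stats between the scan and non-scan variants (and with the scheduler flag on/off) is one way to attribute the memory growth to temp buffers rather than arguments or outputs.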
@qGentry Using JAX 0.4.35 with XLA_FLAGS="--xla_gpu_graph_level=0 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_command_buffer= " and SCAN=False, I'm seeing a failure.
*** Check failure stack trace: ***
@ 0x7f26b3b96dc4 absl::lts_20230802::log_internal::LogMessage::SendToLog()
@ 0x7f26b3b96c34 absl::lts_20230802::log_internal::LogMessage::Flush()
@ 0x7f26b3b971e9 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7f26ac056141 xla::PjRtStreamExecutorLoadedExecutable::Execute()
@ 0x7f26abfa5d71 pjrt::PJRT_LoadedExecutable_Execute()
@ 0x7f26bbd699fc xla::PjRtCApiLoadedExecutable::Execute()
@ 0x7f26c1ab0a25 xla::ifrt::PjRtLoadedExecutable::Execute()
@ 0x7f26c1250c49 xla::(anonymous namespace)::ExecuteShardedOnLocalDevicesInternal<>()
@ 0x7f26c12528ee xla::PyLoadedExecutable::ExecuteSharded()
@ 0x7f26bbc2fc55 xla::ValueOrThrowWrapper<>::operator()()
@ 0x7f26bbc2fabd nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
@ 0x7f26c1a86eb8 nanobind::detail::nb_func_vectorcall_complex()
@ 0x56227ff6aabb _PyEval_EvalFrameDefault
Aborted (core dumped)
Any chance you have other flags or env variables set?
Description
Hi, we're training a large (300B-parameter, 60-layer) mixture-of-experts transformer on 1000+ GPUs. The layers are not uniform, so we can't use jax.lax.scan directly to stack them; instead, we call each layer independently. The model's structure isn't completely irregular: it is (3 layers with one structure, 1 with another) repeated 15 times, for 60 layers in total. We would benefit a lot from overlapping computation and communication, but when we enable the latency hiding scheduler (--xla_gpu_enable_latency_hiding_scheduler), memory usage grows by a factor of 4-5 (from 50 GB per GPU to 200-250 GB per GPU), which is completely unusable. My guess is that the compiler doesn't reuse buffers for async comms across different layers in this case. We also tested a variant with jax.lax.scan and uniform layers; from a memory-usage point of view it seemed fine, with only 20-25% overhead from the latency hiding scheduler.
Is this a known problem? Is there any workaround?
System info (python version, jaxlib version, accelerator, etc.)
Tested on 0.4.25/0.4.26, 1000+ H100 GPUs.