openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Segmentation fault during compilation in CollectivePermuteMotion pass #7155

Open hawkinsp opened 11 months ago

hawkinsp commented 11 months ago

https://github.com/google/jax/issues/18384 describes a JAX segfault when running on two or more GPUs; the crash turns out to be inside XLA:

Stack trace:

#0  0x00007fbb63893850 in xla::HloInstruction::operand(long) const ()
   from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#1  0x00007fbb5d39add6 in xla::MoveCollectivePermutes(xla::HloComputation*, xla::HloInstruction*) ()
   from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#2  0x00007fbb5d39c42d in xla::CollectivePermuteMotion::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char> >, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char> > > > const&) ()
   from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#3  0x00007fbb616d7b45 in absl::lts_20230802::StatusOr<bool> xla::HloPassPipeline::RunPassesInternal<xla::HloModule>(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char> >, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char> > > > const&) () from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#4  0x00007fbb616d8937 in xla::HloPassPipeline::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char> >, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char> > > > const&) () from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#5  0x00007fbb5cb22241 in xla::HloPassInterface::Run(xla::HloModule*) ()
   from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#6  0x00007fbb5cb367db in xla::gpu::GpuCompiler::OptimizeHloModule(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&) ()
   from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#7  0x00007fbb5cb3c431 in xla::gpu::GpuCompiler::RunHloPasses(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&) ()
   from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#8  0x00007fbb616daeda in xla::LLVMCompiler::Compile(std::unique_ptr<xla::HloModuleGroup, std::default_delete<xla::HloModuleGroup> >, std::vector<std::vector<stream_executor::StreamExecutor*, std::allocator<stream_executor::StreamExecutor*> >, std::allocator<std::vector<stream_executor::StreamExecutor*, std::allocator<stream_executor::StreamExecutor*> > > >, xla::Compiler::CompileOptions const&) ()
   from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#9  0x00007fbb5ca61a69 in xla::Service::BuildExecutables(std::vector<xla::HloModuleProto const*, std::allocator<xla::HloModuleProto const*> > const&, std::vector<std::unique_ptr<xla::HloModuleConfig, std::default_delete<xla::HloModuleConfig> >, std::allocator<std::unique_ptr<xla::HloModuleConfig, std::default_delete<xla::HloModuleConfig> > > >, xla::Backend*, std::vector<std::vector<stream_executor::StreamExecutor*, std::allocator<stream_executor::StreamExecutor*> >, std::allocator<std::vector<stream_executor::StreamExecutor*, std::allocator<stream_executor::StreamExecutor*> > > >, xla::Compiler::CompileOptions const&, bool) ()
   from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#10 0x00007fbb5c825d0c in xla::LocalService::CompileExecutables(xla::XlaComputation const&, absl::lts_20230802::Span<xla::Shape const* const>, xla::ExecutableBuildOptions const&) ()
   from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#11 0x00007fbb5c820cd2 in xla::LocalClient::Compile(xla::XlaComputation const&, absl::lts_20230802::Span<xla::Shape const* const>, xla::ExecutableBuildOptions const&) ()
   from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#12 0x00007fbb5c7dfacc in xla::PjRtStreamExecutorClient::Compile(xla::XlaComputation const&, xla::CompileOptions) () from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#13 0x00007fbb5c7bb9ef in xla::StreamExecutorGpuClient::Compile(xla::XlaComputation const&, xla::CompileOptions) () from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#14 0x00007fbb5c7f2d5a in xla::PjRtStreamExecutorClient::Compile(mlir::ModuleOp, xla::CompileOptions) ()
   from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#15 0x00007fbb5c70d3bf in xla::ifrt::PjRtLoadedExecutable::Create(xla::ifrt::PjRtCompatibleClient*, mlir::ModuleOp, xla::CompileOptions, std::vector<tsl::RCReference<xla::ifrt::LoadedHostCallback>, std::allocator<tsl::RCReference<xla::ifrt::LoadedHostCallback> > >) ()
   from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#16 0x00007fbb5c703a64 in xla::ifrt::PjRtCompiler::Compile(std::unique_ptr<xla::ifrt::Program, std::default_delete<xla::ifrt::Program> >, std::unique_ptr<xla::ifrt::CompileOptions, std::default_delete<xla::ifrt::CompileOptions> >) () from /home/phawkins/myenv/lib/python3.10/site-packages/jaxlib/xla_extension.so

The problematic HLO appears to be:

HloModule jit_test, entry_computation_layout={(f32[128]{0}, f32[8]{0})->f32[8]{0}}, allow_spmd_sharding_propagation_to_output={true}

region_1.9 {
  Arg_0.10 = f32[] parameter(0)
  Arg_1.11 = f32[] parameter(1)
  ROOT add.12 = f32[] add(Arg_0.10, Arg_1.11), metadata={op_name="jit(test)/jit(main)/while/body/reduce_sum[axes=(0,)]" source_file="<ipython-input-1-49ad07addeb0>" source_line=32}
}

region_0.13 {
  arg_tuple.14 = (s32[], pred[], f32[8]{0}, f32[16,8]{1,0}, f32[8]{0}) parameter(0)
  get-tuple-element.15 = s32[] get-tuple-element(arg_tuple.14), index=0
  constant.21 = s32[] constant(1)
  add.36 = s32[] add(get-tuple-element.15, constant.21), metadata={op_name="jit(test)/jit(main)/while/body/add" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
  constant.20 = pred[] constant(false)
  get-tuple-element.16 = pred[] get-tuple-element(arg_tuple.14), index=1
  broadcast.34 = pred[8]{0} broadcast(get-tuple-element.16), dimensions={}, metadata={op_name="jit(test)/jit(main)/while/body/select_n" source_file="<ipython-input-1-49ad07addeb0>" source_line=23}
  get-tuple-element.19 = f32[8]{0} get-tuple-element(arg_tuple.14), index=4
  get-tuple-element.18 = f32[16,8]{1,0} get-tuple-element(arg_tuple.14), index=3
  constant.24 = s32[] constant(0)
  compare.25 = pred[] compare(get-tuple-element.15, constant.24), direction=LT, metadata={op_name="jit(test)/jit(main)/while/body/lt" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
  constant.23 = s32[] constant(16)
  add.26 = s32[] add(get-tuple-element.15, constant.23), metadata={op_name="jit(test)/jit(main)/while/body/add" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
  select.27 = s32[] select(compare.25, add.26, get-tuple-element.15), metadata={op_name="jit(test)/jit(main)/while/body/select_n" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
  dynamic-slice.28 = f32[1,8]{1,0} dynamic-slice(get-tuple-element.18, select.27, constant.24), dynamic_slice_sizes={1,8}, metadata={op_name="jit(test)/jit(main)/while/body/dynamic_slice[slice_sizes=(1, 8)]" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
  reshape.29 = f32[8]{0} reshape(dynamic-slice.28), metadata={op_name="jit(test)/jit(main)/while/body/squeeze[dimensions=(0,)]" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
  constant.22 = f32[] constant(0)
  reduce.30 = f32[] reduce(reshape.29, constant.22), dimensions={0}, to_apply=region_1.9, metadata={op_name="jit(test)/jit(main)/while/body/reduce_sum[axes=(0,)]" source_file="<ipython-input-1-49ad07addeb0>" source_line=32}
  broadcast.31 = f32[8]{0} broadcast(reduce.30), dimensions={}, metadata={op_name="jit(test)/jit(main)/while/body/add" source_file="<ipython-input-1-49ad07addeb0>" source_line=32}
  add.32 = f32[8]{0} add(get-tuple-element.19, broadcast.31), metadata={op_name="jit(test)/jit(main)/while/body/add" source_file="<ipython-input-1-49ad07addeb0>" source_line=32}
  get-tuple-element.17 = f32[8]{0} get-tuple-element(arg_tuple.14), index=2
  add.33 = f32[8]{0} add(get-tuple-element.17, add.32), metadata={op_name="jit(test)/jit(main)/while/body/add" source_file="<ipython-input-1-49ad07addeb0>" source_line=23}
  select.35 = f32[8]{0} select(broadcast.34, add.32, add.33), metadata={op_name="jit(test)/jit(main)/while/body/select_n" source_file="<ipython-input-1-49ad07addeb0>" source_line=23}
  ROOT tuple.37 = (s32[], pred[], f32[8]{0}, f32[16,8]{1,0}, f32[8]{0}) tuple(add.36, constant.20, select.35, get-tuple-element.18, get-tuple-element.19)
} // region_0.13

region_2.38 {
  arg_tuple.39 = (s32[], pred[], f32[8]{0}, f32[16,8]{1,0}, f32[8]{0}) parameter(0)
  get-tuple-element.41 = pred[] get-tuple-element(arg_tuple.39), index=1
  get-tuple-element.42 = f32[8]{0} get-tuple-element(arg_tuple.39), index=2
  get-tuple-element.43 = f32[16,8]{1,0} get-tuple-element(arg_tuple.39), index=3
  get-tuple-element.44 = f32[8]{0} get-tuple-element(arg_tuple.39), index=4
  get-tuple-element.40 = s32[] get-tuple-element(arg_tuple.39), index=0
  constant.45 = s32[] constant(16)
  ROOT compare.46 = pred[] compare(get-tuple-element.40, constant.45), direction=LT, metadata={op_name="jit(test)/jit(main)/while/cond/lt" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
} // region_2.38

ENTRY main.53 {
  constant.5 = s32[] constant(0)
  constant.6 = pred[] constant(true)
  constant.3 = f32[] constant(0)
  broadcast.4 = f32[8]{0} broadcast(constant.3), dimensions={}
  Arg_0.1 = f32[128]{0} parameter(0), sharding={devices=[4]0,1,2,3}
  reshape.7 = f32[16,8]{1,0} reshape(Arg_0.1), metadata={op_name="jit(test)/jit(main)/reshape[new_sizes=(16, 8) dimensions=None]" source_file="<ipython-input-1-49ad07addeb0>" source_line=35}
  Arg_1.2 = f32[8]{0} parameter(1), sharding={replicated}
  tuple.8 = (s32[], pred[], f32[8]{0}, f32[16,8]{1,0}, f32[8]{0}) tuple(constant.5, constant.6, broadcast.4, reshape.7, Arg_1.2), metadata={op_name="jit(test)/jit(main)/while[cond_nconsts=0 body_nconsts=2]" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
  while.47 = (s32[], pred[], f32[8]{0}, f32[16,8]{1,0}, f32[8]{0}) while(tuple.8), condition=region_2.38, body=region_0.13, metadata={op_name="jit(test)/jit(main)/while[cond_nconsts=0 body_nconsts=2]" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
  get-tuple-element.48 = s32[] get-tuple-element(while.47), index=0, metadata={op_name="jit(test)/jit(main)/while[cond_nconsts=0 body_nconsts=2]" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
  get-tuple-element.49 = pred[] get-tuple-element(while.47), index=1, metadata={op_name="jit(test)/jit(main)/while[cond_nconsts=0 body_nconsts=2]" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
  ROOT get-tuple-element.50 = f32[8]{0} get-tuple-element(while.47), index=2, metadata={op_name="jit(test)/jit(main)/while[cond_nconsts=0 body_nconsts=2]" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
  get-tuple-element.51 = f32[16,8]{1,0} get-tuple-element(while.47), index=3, metadata={op_name="jit(test)/jit(main)/while[cond_nconsts=0 body_nconsts=2]" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
  get-tuple-element.52 = f32[8]{0} get-tuple-element(while.47), index=4, metadata={op_name="jit(test)/jit(main)/while[cond_nconsts=0 body_nconsts=2]" source_file="<ipython-input-1-49ad07addeb0>" source_line=28}
} // main.53
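
To make the HLO easier to follow, here is a plain-Python reading of what `jit_test` appears to compute. It is our interpretation of the instructions above, not the original JAX program from the linked JAX issue (which is not reproduced here), and the function name `reference_jit_test` is hypothetical.

```python
def reference_jit_test(x, y):
    """Plain-Python reading of HloModule jit_test (our interpretation).

    x: 128 floats (Arg_0.1); y: 8 floats (Arg_1.2).
    Returns tuple element 2 of the while result (get-tuple-element.50).
    """
    assert len(x) == 128 and len(y) == 8
    # reshape.7: f32[128] -> f32[16,8], rows taken in row-major order.
    rows = [x[r * 8:(r + 1) * 8] for r in range(16)]
    # tuple.8: initial carry (counter, flag, accumulator); rows and y stay unchanged.
    i, flag, acc = 0, True, [0.0] * 8
    while i < 16:  # region_2.38: compare(i, 16), direction=LT
        idx = i + 16 if i < 0 else i  # select.27: wrap negative indices
        s = sum(rows[idx])  # dynamic-slice.28 + reduce.30: sum of row idx
        t = [yj + s for yj in y]  # add.32: y + broadcast(s)
        # select.35: take t on the first iteration, acc + t (add.33) afterwards.
        acc = t if flag else [a + tj for a, tj in zip(acc, t)]
        flag = False  # constant.20 written back into the carry
        i += 1  # add.36
    return acc


if __name__ == "__main__":
    x = [float(k) for k in range(128)]
    y = [float(j) for j in range(8)]
    print(reference_jit_test(x, y))
```

Because the accumulator starts at zero, both arms of `select.35` yield the same value, so under this reading the result is `16 * y + sum(x)` elementwise; for the inputs above that is `[8128.0, 8144.0, ..., 8240.0]`. Note that the module contains no collective operations at all; the only sharding annotation is on `Arg_0.1`.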

jaro-sevcik commented 7 months ago

This looks like issue #10394 (a while loop without a collective-permute crashes collective-permute-motion), so perhaps this is fixed by #10395?
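
The premise here, that the while loop contains no `collective-permute`, can be checked mechanically. Below is a minimal sketch, assuming the HLO is available as text (for example, from a dump produced with `XLA_FLAGS=--xla_dump_to=<dir>`). The regex is a simplification that assumes one instruction per line and non-nested tuple shapes; `opcode_histogram` and `HLO_EXCERPT` are hypothetical names, and the excerpt is region_2.38 plus the entry computation from the report with `metadata` attributes removed.

```python
import re
from collections import Counter

# One HLO instruction per line: "[ROOT] name = shape opcode(...".
# Tuple shapes "(s32[], pred[], ...)" use the parenthesized branch (no nesting
# supported); array shapes such as "f32[16,8]{1,0}" contain no spaces.
_INSTR = re.compile(
    r"^\s*(?:ROOT\s+)?[\w.\-]+\s*=\s*(?:\([^)]*\)|\S+)\s+([a-z][\w\-]*)\("
)


def opcode_histogram(hlo_text):
    """Counts how often each opcode appears in HLO text."""
    counts = Counter()
    for line in hlo_text.splitlines():
        match = _INSTR.match(line)
        if match:
            counts[match.group(1)] += 1
    return counts


# Excerpt of the module from the report, metadata removed for brevity.
HLO_EXCERPT = """
region_2.38 {
  arg_tuple.39 = (s32[], pred[], f32[8]{0}, f32[16,8]{1,0}, f32[8]{0}) parameter(0)
  get-tuple-element.40 = s32[] get-tuple-element(arg_tuple.39), index=0
  constant.45 = s32[] constant(16)
  ROOT compare.46 = pred[] compare(get-tuple-element.40, constant.45), direction=LT
}

ENTRY main.53 {
  constant.5 = s32[] constant(0)
  constant.6 = pred[] constant(true)
  constant.3 = f32[] constant(0)
  broadcast.4 = f32[8]{0} broadcast(constant.3), dimensions={}
  Arg_0.1 = f32[128]{0} parameter(0), sharding={devices=[4]0,1,2,3}
  reshape.7 = f32[16,8]{1,0} reshape(Arg_0.1)
  Arg_1.2 = f32[8]{0} parameter(1), sharding={replicated}
  tuple.8 = (s32[], pred[], f32[8]{0}, f32[16,8]{1,0}, f32[8]{0}) tuple(constant.5, constant.6, broadcast.4, reshape.7, Arg_1.2)
  while.47 = (s32[], pred[], f32[8]{0}, f32[16,8]{1,0}, f32[8]{0}) while(tuple.8), condition=region_2.38, body=region_0.13
  ROOT get-tuple-element.50 = f32[8]{0} get-tuple-element(while.47), index=2
}
"""

if __name__ == "__main__":
    hist = opcode_histogram(HLO_EXCERPT)
    print(sorted(hist.items()))
    print("collective-permute present:", "collective-permute" in hist)
```

Run on the full module from the report, the histogram contains only `parameter`, `add`, `constant`, `get-tuple-element`, `broadcast`, `compare`, `select`, `dynamic-slice`, `reshape`, `reduce`, `tuple`, and `while`, with no `collective-permute`, which is consistent with the #10394 diagnosis.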

inailuig commented 6 months ago

I think this can be closed; see https://github.com/google/jax/issues/18384.