pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

torchdynamo + XLA crash #7053

Open pritamdamania87 opened 4 months ago

pritamdamania87 commented 4 months ago

🐛 Bug

Running into the following error when using torch.compile(backend="openxla"):

  File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "foo.py", line 1922, in forward
    with no_autocast():
  File "foo.py", line 1922, in torch_dynamo_resume_in_forward_at_1922
    with no_autocast():
  File "torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "torch/_functorch/aot_autograd.py", line 917, in forward
    return compiled_fn(full_args)
  File "torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 88, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 505, in forward
    fw_outs = call_func_at_runtime_with_args(
  File "torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "torch/_dynamo/backends/torchxla.py", line 51, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "torch_xla/core/dynamo_bridge.py", line 621, in extract_compiled_graph
    extract_internal(fused_module), node.args, None)
  File "torch_xla/core/dynamo_bridge.py", line 374, in extract_internal
    dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
  File "torch_xla/core/dynamo_bridge.py", line 339, in extract_graph_helper
    torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])
RuntimeError: Bad StatusOr access: INTERNAL: RET_CHECK failure (external/xla/xla/service/gpu/gpu_compiler.cc:1759) !llvm::verifyModule(*llvm_module, &err_stream) Invalid LLVM IR before optimizations:
Attribute does not match Module context!
memory(none)
ptr @llvm.nvvm.read.ptx.sreg.ctaid.x
Attribute does not match Module context!
memory(none)
ptr @llvm.nvvm.read.ptx.sreg.tid.x
Attribute does not match Module context!
memory(none)
ptr @llvm.nvvm.read.ptx.sreg.ctaid.y
Attribute does not match Module context!
memory(none)
ptr @__nv_rsqrtf

This probably indicates a bug in the HLO -> LLVM IR lowering. Rerun with --xla_dump_to to get the IR and looks for files with name containing: *module_0001.SyncTensorsGraph.68.*
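The rerun that the error message asks for can be scripted. This is a minimal sketch that assumes the flag is passed through the usual `XLA_FLAGS` environment variable. The helper name `rerun_env_with_dump` and the dump path are hypothetical. The helper also turns on the `XLA_DYNAMO_DEBUG` and `PT_XLA_DEBUG` variables that this issue uses.

```python
import os


def rerun_env_with_dump(dump_dir, base_env=None):
    """Return a copy of the environment with --xla_dump_to appended to XLA_FLAGS.

    Hypothetical helper: other XLA_FLAGS entries are kept, an older
    --xla_dump_to is replaced, and the torch_xla debug variables are enabled.
    Flags are joined with spaces, so dump_dir should not contain spaces.
    """
    env = dict(os.environ if base_env is None else base_env)
    flags = env.get("XLA_FLAGS", "").split()
    # Drop any earlier dump destination so the new one wins.
    flags = [f for f in flags if not f.startswith("--xla_dump_to=")]
    flags.append(f"--xla_dump_to={dump_dir}")
    env["XLA_FLAGS"] = " ".join(flags)
    env["XLA_DYNAMO_DEBUG"] = "1"
    env["PT_XLA_DEBUG"] = "1"
    return env
```

The result can be handed to `subprocess.run([sys.executable, "foo.py"], env=...)`. Here `foo.py` stands for whatever script produced the traceback.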

Running with `XLA_DYNAMO_DEBUG=1 PT_XLA_DEBUG=1`, I see the following logs right before the crash:

Graph Module:

def forward(self, primals_10, primals_9, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8):
    slice_5 = torch.ops.aten.slice.Tensor(primals_10, 0, 0, 9223372036854775807)
    slice_1 = torch.ops.aten.slice.Tensor(primals_10, 0, 0, 9223372036854775807);  primals_10 = None
    add = torch.ops.aten.add.Tensor(primals_9, 1);  primals_9 = None
    slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, 9223372036854775807);  slice_5 = None
    slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 2);  slice_1 = None
    slice_7 = torch.ops.aten.slice.Tensor(slice_6, 2, 0, 9223372036854775807);  slice_6 = None
    slice_3 = torch.ops.aten.slice.Tensor(slice_2, 2, 0, 9223372036854775807);  slice_2 = None
    slice_8 = torch.ops.aten.slice.Tensor(slice_7, 3, 0, 9223372036854775807);  slice_7 = None
    slice_4 = torch.ops.aten.slice.Tensor(slice_3, 3, 0, 9223372036854775807);  slice_3 = None
    clone_1 = torch.ops.aten.clone.default(slice_8);  slice_8 = None
    clone = torch.ops.aten.clone.default(slice_4);  slice_4 = None
    mul = torch.ops.aten.mul.Tensor(clone_1, 1);  clone_1 = None
    convolution = torch.ops.aten.convolution.default(clone, primals_1, primals_2, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  clone = primals_1 = primals_2 = None
    relu = torch.ops.aten.relu.default(convolution);  convolution = None
    cat = torch.ops.aten.cat.default([relu, mul], 1);  relu = mul = None
    convolution_1 = torch.ops.aten.convolution.default(cat, primals_3, primals_4, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  cat = primals_3 = primals_4 = None
    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution_1, primals_5, primals_6, primals_7, primals_8, True, 0.1, 1e-05);  primals_5 = primals_6 = primals_7 = primals_8 = None
    getitem = _native_batch_norm_legit_functional[0]
    getitem_1 = _native_batch_norm_legit_functional[1]
    getitem_2 = _native_batch_norm_legit_functional[2]
    getitem_3 = _native_batch_norm_legit_functional[3]
    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
    hardtanh = torch.ops.aten.hardtanh.default(getitem, 0.0, 6.0)
    return (add, convolution_1, getitem, getitem_1, getitem_2, getitem_3, getitem_4, hardtanh)

Number of HLO Input: 11
Number of HLO Output: 8
Number of HLO Input can be aliased with Output: 0
XLA IR Text: 
IR {
  %0 = s64[] prim::Constant(), xla_shape=s64[]
  %1 = s64[] prim::Constant(), xla_shape=s64[]
  %2 = s64[] xla::device_data(), xla_shape=s64[]
  %3 = s64[] aten::add(%2, %1, %0), xla_shape=s64[], ROOT=0
  %4 = f32[3]{0} xla::device_data(), xla_shape=f32[3]{0}
  %5 = f32[3,3,1,1]{3,2,1,0} xla::device_data(), xla_shape=f32[3,3,1,1]{3,2,1,0}
  %6 = f32[] prim::Constant(), xla_shape=f32[]
  %7 = f32[28,4,480,640]{3,2,1,0} xla::device_data(), xla_shape=f32[28,4,480,640]{3,2,1,0}
  %8 = f32[28,4,480,640]{3,2,1,0} xla::select(%7), xla_shape=f32[28,4,480,640]{3,2,1,0}
  %9 = f32[28,2,480,640]{3,2,1,0} xla::select(%8), xla_shape=f32[28,2,480,640]{3,2,1,0}
  %10 = f32[28,2,480,640]{3,2,1,0} xla::select(%9), xla_shape=f32[28,2,480,640]{3,2,1,0}
  %11 = f32[28,2,480,640]{3,2,1,0} xla::select(%10), xla_shape=f32[28,2,480,640]{3,2,1,0}
  %12 = f32[28,2,480,640]{3,2,1,0} aten::mul(%11, %6), xla_shape=f32[28,2,480,640]{3,2,1,0}
  %13 = f32[1]{0} xla::device_data(), xla_shape=f32[1]{0}
  %14 = f32[1,2,1,1]{3,2,1,0} xla::device_data(), xla_shape=f32[1,2,1,1]{3,2,1,0}
  %15 = f32[28,4,480,640]{3,2,1,0} xla::select(%7), xla_shape=f32[28,4,480,640]{3,2,1,0}
  %16 = f32[28,2,480,640]{3,2,1,0} xla::select(%15), xla_shape=f32[28,2,480,640]{3,2,1,0}
  %17 = f32[28,2,480,640]{3,2,1,0} xla::select(%16), xla_shape=f32[28,2,480,640]{3,2,1,0}
  %18 = f32[28,2,480,640]{3,2,1,0} xla::select(%17), xla_shape=f32[28,2,480,640]{3,2,1,0}
  %19 = f32[28,1,480,640]{3,2,1,0} aten::convolution_overrideable(%18, %14, %13), xla_shape=f32[28,1,480,640]{3,2,1,0}
  %20 = f32[28,1,480,640]{3,2,1,0} aten::relu(%19), xla_shape=f32[28,1,480,640]{3,2,1,0}
  %21 = f32[28,3,480,640]{3,2,1,0} aten::cat(%20, %12), xla_shape=f32[28,3,480,640]{3,2,1,0}
  %22 = f32[28,3,480,640]{3,2,1,0} aten::convolution_overrideable(%21, %5, %4), xla_shape=f32[28,3,480,640]{3,2,1,0}, ROOT=1
  %23 = f32[3]{0} xla::device_data(), xla_shape=f32[3]{0}
  %24 = f32[3]{0} xla::device_data(), xla_shape=f32[3]{0}
  %25 = f32[3]{0} xla::device_data(), xla_shape=f32[3]{0}
  %26 = f32[3]{0} xla::device_data(), xla_shape=f32[3]{0}
  %27 = (f32[28,3,480,640]{3,2,1,0}, f32[3]{0}, f32[3]{0}, f32[3]{0}) aten::native_batch_norm(%22, %26, %25, %24, %23), num_outputs=4, xla_shape=(f32[28,3,480,640]{3,2,1,0}, f32[3]{0}, f32[3]{0}, f32[3]{0}), ROOT=4
  %28 = f32[3]{0} xla::moving_average(%27.1, %24), xla_shape=f32[3]{0}, ROOT=5
  %29 = f32[3]{0} xla::moving_average(%27.2, %23), xla_shape=f32[3]{0}, ROOT=6
  %30 = f32[] xla::device_data(), xla_shape=f32[]
  %31 = f32[] prim::Constant(), xla_shape=f32[]
  %32 = f32[28,3,480,640]{3,2,1,0} aten::clamp(%27.0, %31, %30), xla_shape=f32[28,3,480,640]{3,2,1,0}, ROOT=7
}
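The shapes in the IR can be checked by hand. Channels 0:2 of the f32[28,4,480,640] input pass through a 1x1 convolution with a [1,2,1,1] weight. The ReLU result is concatenated with channels 2:4. A second 1x1 convolution with a [3,3,1,1] weight then keeps three channels. Below is a plain-Python sketch of that shape arithmetic. The helper names are hypothetical, and it uses stride 1, no padding and groups=1, as the graph does.

```python
def conv2d_out_shape(x_shape, w_shape, stride=1, padding=0, dilation=1):
    """NCHW output shape of a 2-D convolution with an OIHW weight (groups=1)."""
    n, c_in, h, w = x_shape
    c_out, c_in_w, kh, kw = w_shape
    assert c_in == c_in_w, "input channels must match the weight"

    def out(size, k):
        return (size + 2 * padding - dilation * (k - 1) - 1) // stride + 1

    return (n, c_out, out(h, kh), out(w, kw))


def slice_channels(shape, start, stop):
    """Shape after x[:, start:stop]."""
    n, _, h, w = shape
    return (n, stop - start, h, w)


def cat_channels(a, b):
    """Shape after torch.cat([a, b], dim=1)."""
    assert a[0] == b[0] and a[2:] == b[2:]
    return (a[0], a[1] + b[1], a[2], a[3])


x = (28, 4, 480, 640)
branch = conv2d_out_shape(slice_channels(x, 0, 2), (1, 2, 1, 1))  # (28, 1, 480, 640)
skip = slice_channels(x, 2, 4)                                      # (28, 2, 480, 640)
y = conv2d_out_shape(cat_channels(branch, skip), (3, 3, 1, 1))      # (28, 3, 480, 640)
numel = y[0] * y[1] * y[2] * y[3]
print(y, numel, numel * 4)  # f32 is 4 bytes per element
```

This prints `(28, 3, 480, 640) 25804800 103219200`, which matches the f32[28,3,480,640] outputs marked ROOT in the IR.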

I see the following files in the XLA dump directory:

module_0001.SyncTensorsGraph.68.autotune_results.pbtxt
module_0001.SyncTensorsGraph.68.before_optimizations.txt
module_0001.SyncTensorsGraph.68.gpu_target_config.pbtxt
module_0001.SyncTensorsGraph.68.ir-no-opt.ll
module_0001.SyncTensorsGraph.68.sm_9.0_gpu_after_optimizations-buffer-assignment.txt
module_0001.SyncTensorsGraph.68.sm_9.0_gpu_after_optimizations.txt
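The error message says to look for files whose names contain `module_0001.SyncTensorsGraph.68.`. When one dump directory holds many modules, a small filter narrows the list. The sketch below uses plain `pathlib`, and its helper names are hypothetical. Going by file names alone, `ir-no-opt.ll` looks like the unoptimized LLVM IR that the verifier rejected.

```python
from pathlib import Path


def module_dump_files(dump_dir, module_tag):
    """Sorted names of files in dump_dir whose names contain module_tag."""
    return sorted(
        p.name for p in Path(dump_dir).iterdir()
        if p.is_file() and module_tag in p.name
    )


def llvm_ir_files(dump_dir, module_tag):
    """The subset of module_dump_files() that holds textual LLVM IR (*.ll)."""
    return [n for n in module_dump_files(dump_dir, module_tag) if n.endswith(".ll")]
```

For example, `llvm_ir_files("/tmp/xla_dump", "module_0001.SyncTensorsGraph.68.")` returns the `.ll` file(s) for the failing module. The path is a placeholder.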

The contents of module_0001.SyncTensorsGraph.68.before_optimizations.txt are as follows:

HloModule SyncTensorsGraph.68, entry_computation_layout={(s64[], f32[3]{0}, f32[3,3,1,1]{3,2,1,0}, f32[28,4,480,640]{3,2,1,0}, f32[1]{0}, /*index=5*/f32[1,2,1,1]{3,2,1,0}, f32[3]{0}, f32[3]{0}, f32[3]{0}, f32[3]{0}, /*index=10*/f32[])->(s64[], f32[28,3,480,640]{3,2,1,0}, f32[28,3,480,640]{3,2,1,0}, f32[3]{0}, f32[3]{0}, /*index=5*/f32[3]{0}, f32[3]{0}, f32[28,3,480,640]{3,2,1,0})}

ENTRY SyncTensorsGraph.68 {
  p0.3 = s64[] parameter(0)
  constant.2 = s64[] constant(1)
  constant.1 = s64[] constant(1)
  multiply.4 = s64[] multiply(constant.2, constant.1)
  add.5 = s64[] add(p0.3, multiply.4)
  p3.9 = f32[28,4,480,640]{3,2,1,0} parameter(3)
  slice.18 = f32[28,4,480,640]{3,2,1,0} slice(p3.9), slice={[0:28], [0:4], [0:480], [0:640]}
  slice.19 = f32[28,2,480,640]{3,2,1,0} slice(slice.18), slice={[0:28], [0:2], [0:480], [0:640]}
  slice.20 = f32[28,2,480,640]{3,2,1,0} slice(slice.19), slice={[0:28], [0:2], [0:480], [0:640]}
  slice.21 = f32[28,2,480,640]{3,2,1,0} slice(slice.20), slice={[0:28], [0:2], [0:480], [0:640]}
  p5.17 = f32[1,2,1,1]{3,2,1,0} parameter(5)
  convolution.22 = f32[28,1,480,640]{3,2,1,0} convolution(slice.21, p5.17), window={size=1x1}, dim_labels=bf01_oi01->bf01
  p4.16 = f32[1]{0} parameter(4)
  broadcast.23 = f32[28,480,640,1]{3,2,1,0} broadcast(p4.16), dimensions={3}
  transpose.24 = f32[28,1,480,640]{1,3,2,0} transpose(broadcast.23), dimensions={0,3,1,2}
  add.25 = f32[28,1,480,640]{3,2,1,0} add(convolution.22, transpose.24)
  constant.26 = f32[] constant(0)
  broadcast.27 = f32[28,1,480,640]{3,2,1,0} broadcast(constant.26), dimensions={}
  maximum.28 = f32[28,1,480,640]{3,2,1,0} maximum(add.25, broadcast.27)
  slice.10 = f32[28,4,480,640]{3,2,1,0} slice(p3.9), slice={[0:28], [0:4], [0:480], [0:640]}
  slice.11 = f32[28,2,480,640]{3,2,1,0} slice(slice.10), slice={[0:28], [2:4], [0:480], [0:640]}
  slice.12 = f32[28,2,480,640]{3,2,1,0} slice(slice.11), slice={[0:28], [0:2], [0:480], [0:640]}
  slice.13 = f32[28,2,480,640]{3,2,1,0} slice(slice.12), slice={[0:28], [0:2], [0:480], [0:640]}
  constant.8 = f32[] constant(1)
  broadcast.14 = f32[28,2,480,640]{3,2,1,0} broadcast(constant.8), dimensions={}
  multiply.15 = f32[28,2,480,640]{3,2,1,0} multiply(slice.13, broadcast.14)
  concatenate.29 = f32[28,3,480,640]{3,2,1,0} concatenate(maximum.28, multiply.15), dimensions={1}
  p2.7 = f32[3,3,1,1]{3,2,1,0} parameter(2)
  convolution.30 = f32[28,3,480,640]{3,2,1,0} convolution(concatenate.29, p2.7), window={size=1x1}, dim_labels=bf01_oi01->bf01
  p1.6 = f32[3]{0} parameter(1)
  broadcast.31 = f32[28,480,640,3]{3,2,1,0} broadcast(p1.6), dimensions={3}
  transpose.32 = f32[28,3,480,640]{1,3,2,0} transpose(broadcast.31), dimensions={0,3,1,2}
  add.33 = f32[28,3,480,640]{3,2,1,0} add(convolution.30, transpose.32)
  p9.37 = f32[3]{0} parameter(9)
  p8.36 = f32[3]{0} parameter(8)
  batch-norm-training.38 = (f32[28,3,480,640]{3,2,1,0}, f32[3]{0}, f32[3]{0}) batch-norm-training(add.33, p9.37, p8.36), epsilon=1e-05, feature_index=1
  get-tuple-element.39 = f32[28,3,480,640]{3,2,1,0} get-tuple-element(batch-norm-training.38), index=0
  get-tuple-element.40 = f32[3]{0} get-tuple-element(batch-norm-training.38), index=1
  get-tuple-element.41 = f32[3]{0} get-tuple-element(batch-norm-training.38), index=2
  constant.42 = f32[] constant(1e-05)
  broadcast.43 = f32[3]{0} broadcast(constant.42), dimensions={}
  add.44 = f32[3]{0} add(get-tuple-element.41, broadcast.43)
  rsqrt.45 = f32[3]{0} rsqrt(add.44)
  constant.47 = f32[] constant(0.1)
  broadcast.51 = f32[3]{0} broadcast(constant.47), dimensions={}
  multiply.52 = f32[3]{0} multiply(get-tuple-element.40, broadcast.51)
  p7.35 = f32[3]{0} parameter(7)
  constant.46 = f32[] constant(1)
  subtract.48 = f32[] subtract(constant.46, constant.47)
  broadcast.49 = f32[3]{0} broadcast(subtract.48), dimensions={}
  multiply.50 = f32[3]{0} multiply(p7.35, broadcast.49)
  add.53 = f32[3]{0} add(multiply.52, multiply.50)
  constant.55 = f32[] constant(0.1)
  broadcast.59 = f32[3]{0} broadcast(constant.55), dimensions={}
  multiply.60 = f32[3]{0} multiply(get-tuple-element.41, broadcast.59)
  p6.34 = f32[3]{0} parameter(6)
  constant.54 = f32[] constant(1)
  subtract.56 = f32[] subtract(constant.54, constant.55)
  broadcast.57 = f32[3]{0} broadcast(subtract.56), dimensions={}
  multiply.58 = f32[3]{0} multiply(p6.34, broadcast.57)
  add.61 = f32[3]{0} add(multiply.60, multiply.58)
  constant.63 = f32[] constant(0)
  broadcast.64 = f32[28,3,480,640]{3,2,1,0} broadcast(constant.63), dimensions={}
  p10.62 = f32[] parameter(10)
  broadcast.65 = f32[28,3,480,640]{3,2,1,0} broadcast(p10.62), dimensions={}
  clamp.66 = f32[28,3,480,640]{3,2,1,0} clamp(broadcast.64, get-tuple-element.39, broadcast.65)
  ROOT tuple.67 = (s64[], f32[28,3,480,640]{3,2,1,0}, f32[28,3,480,640]{3,2,1,0}, f32[3]{0}, f32[3]{0}, /*index=5*/f32[3]{0}, f32[3]{0}, f32[28,3,480,640]{3,2,1,0}) tuple(add.5, add.33, get-tuple-element.39, get-tuple-element.40, rsqrt.45, /*index=5*/add.53, add.61, clamp.66)
} // SyncTensorsGraph.68
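The HLO above computes the following arithmetic after `batch-norm-training`:

- `add.53` and `add.61` blend the batch mean and batch variance into `p7.35` and `p6.34` with momentum 0.1.
- `rsqrt.45` is 1/sqrt(var + 1e-05).
- `clamp.66` is the `hardtanh(x, 0.0, 6.0)` from the FX graph.

Below is a plain-Python sketch of those formulas, written only from the dump. The helper names are hypothetical.

```python
import math

MOMENTUM = 0.1  # constant.47 / constant.55
EPS = 1e-05     # constant.42, same as the batch-norm-training epsilon


def moving_average(batch_stat, running_stat, momentum=MOMENTUM):
    """add.53 / add.61: momentum * batch_stat + (1 - momentum) * running_stat."""
    return [momentum * b + (1.0 - momentum) * r for b, r in zip(batch_stat, running_stat)]


def inv_std(batch_var, eps=EPS):
    """rsqrt.45: 1 / sqrt(batch_var + eps), element-wise."""
    return [1.0 / math.sqrt(v + eps) for v in batch_var]


def clamp(values, lo=0.0, hi=6.0):
    """clamp.66: min(max(x, lo), hi), i.e. hardtanh(x, 0.0, 6.0)."""
    return [min(max(v, lo), hi) for v in values]
```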

This is with PyTorch 2.3 and the torch-xla build compatible with PyTorch 2.3. How can I debug this issue further?

JackCaoG commented 4 months ago

Hey @pritamdamania87, nice to see you here. Can you dump the full error message? It should crash with some C++/Python stack traces. From the current log it isn't clear to me where the crash comes from.

pritamdamania87 commented 4 months ago

Nice to see you as well @JackCaoG :) I updated the original issue summary with the stack trace as well.

JackCaoG commented 4 months ago

Hmm, from the log it seems XLA:GPU failed while compiling the HLO above. This is really weird, as the HLO is pretty straightforward. As a sanity check, do you mind setting PJRT_DEVICE=CPU to see whether the same program compiles with XLA:CPU?
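One way to run this check is to launch the same script twice and change only `PJRT_DEVICE`. This is a minimal sketch with a few assumptions. First, it assumes `PJRT_DEVICE` is read when the XLA runtime starts, so the sketch sets it in a fresh child process. Second, `foo.py` stands for the failing script. The helpers `env_for_device` and `run_on` are hypothetical.

```python
import os
import subprocess
import sys


def env_for_device(device, base_env=None):
    """Copy of the environment with PJRT_DEVICE set, e.g. "CPU"."""
    env = dict(os.environ if base_env is None else base_env)
    env["PJRT_DEVICE"] = device
    return env


def run_on(device, script="foo.py"):
    """Run `script` in a fresh interpreter and return its exit code.

    On POSIX a negative code -N means the child was killed by signal N,
    so a segfault shows up as -11 (SIGSEGV).
    """
    return subprocess.run([sys.executable, script], env=env_for_device(device)).returncode
```

Comparing `run_on("CPU")` with a run that uses the GPU setting already in use separates an XLA:GPU-only failure from one shared by both backends.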

pritamdamania87 commented 4 months ago

PJRT_DEVICE=CPU results in a segfault:

  File "torch_xla/core/xla_model.py", line 1056 in mark_step
  File "torch_xla/core/dynamo_bridge.py", line 542 in extract_compiled_graph
  File "torch/_dynamo/backends/torchxla.py", line 51 in fwd
  File "torch/_functorch/_aot_autograd/utils.py", line 89 in g
  File "torch/_functorch/_aot_autograd/utils.py", line 113 in call_func_at_runtime_with_args
  File "torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 505 in forward
  File "torch/autograd/function.py", line 598 in apply
  File "torch/_functorch/_aot_autograd/utils.py", line 89 in g
  File "torch/_functorch/_aot_autograd/utils.py", line 113 in call_func_at_runtime_with_args
  File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 88 in runtime_wrapper
  File "torch/_functorch/_aot_autograd/utils.py", line 89 in g
  File "torch/_functorch/aot_autograd.py", line 917 in forward
  File "torch/_dynamo/external_utils.py", line 36 in inner
  File "torch/_dynamo/eval_frame.py", line 451 in _fn

gdb backtrace:

#0  0x000015547b7bbd28 in llvm::FoldingSet<llvm::SDNode>::NodeEquals(llvm::FoldingSetBase const*, llvm::FoldingSetBase::Node*, llvm::FoldingSetNodeID const&, unsigned int, llvm::FoldingSetNodeID&) ()
   from llvmlite/binding/../../../../libLLVM-14.so
#1  0x000015523b109e88 in llvm::FoldingSetBase::FindNodeOrInsertPos(llvm::FoldingSetNodeID const&, void*&, llvm::FoldingSetBase::FoldingSetInfo const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#2  0x0000155237280ec8 in llvm::SelectionDAG::FindNodeOrInsertPos(llvm::FoldingSetNodeID const&, llvm::SDLoc const&, void*&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#3  0x00001552372b4399 in llvm::SelectionDAG::getConstant(llvm::ConstantInt const&, llvm::SDLoc const&, llvm::EVT, bool, bool) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#4  0x00001552372fba59 in llvm::SelectionDAGBuilder::visitGetElementPtr(llvm::User const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#5  0x0000155237326c20 in llvm::SelectionDAGBuilder::visit(llvm::Instruction const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#6  0x000015523733b953 in llvm::SelectionDAGISel::SelectBasicBlock(llvm::ilist_iterator_w_bits<llvm::ilist_detail::node_options<llvm::Instruction, false, false, void, true>, false, true>, llvm::ilist_iterator_w_bits<llvm::ilist_detail::node_options<llvm::Instruction, false, false, void, true>, false, true>, bool&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#7  0x000015523733db9f in llvm::SelectionDAGISel::SelectAllBasicBlocks(llvm::Function const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#8  0x000015523733fd01 in llvm::SelectionDAGISel::runOnMachineFunction(llvm::MachineFunction&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#9  0x00001552365e05c6 in (anonymous namespace)::X86DAGToDAGISel::runOnMachineFunction(llvm::MachineFunction&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#10 0x0000155236f239b4 in llvm::MachineFunctionPass::runOnFunction(llvm::Function&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#11 0x000015523a1b5b59 in llvm::FPPassManager::runOnFunction(llvm::Function&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#12 0x000015523a1b6003 in llvm::FPPassManager::runOnModule(llvm::Module&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#13 0x000015523a1b5500 in llvm::legacy::PassManagerImpl::run(llvm::Module&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#14 0x0000155234086144 in xla::cpu::CompilerFunctor::operator()(llvm::Module&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#15 0x00001552368d4388 in llvm::orc::IRCompileLayer::emit(std::unique_ptr<llvm::orc::MaterializationResponsibility, std::default_delete<llvm::orc::MaterializationResponsibility> >, llvm::orc::ThreadSafeModule) ()
   from _XLAC.cpython-310-x86_64-linux-gnu.so
#16 0x00001552368faac5 in llvm::orc::BasicIRLayerMaterializationUnit::materialize(std::unique_ptr<llvm::orc::MaterializationResponsibility, std::default_delete<llvm::orc::MaterializationResponsibility> >) ()
   from _XLAC.cpython-310-x86_64-linux-gnu.so
#17 0x000015523687ea9e in llvm::orc::MaterializationTask::run() () from _XLAC.cpython-310-x86_64-linux-gnu.so
#18 0x000015547ceec5c8 in void llvm::detail::UniqueFunctionBase<void, std::unique_ptr<llvm::orc::Task, std::default_delete<llvm::orc::Task> > >::CallImpl<void (*)(std::unique_ptr<llvm::orc::Task, std::default_delete<llvm::orc::Task> >)>(void*, std::unique_ptr<llvm::orc::Task, std::default_delete<llvm::orc::Task> >&) () from llvmlite/binding/../../../../libLLVM-14.so
#19 0x000015523687f5e8 in llvm::orc::ExecutionSession::dispatchOutstandingMUs() () from _XLAC.cpython-310-x86_64-linux-gnu.so
#20 0x0000155236888d99 in llvm::orc::ExecutionSession::OL_completeLookup(std::unique_ptr<llvm::orc::InProgressLookupState, std::default_delete<llvm::orc::InProgressLookupState> >, std::shared_ptr<llvm::orc::AsynchronousSymbolQuery>, std::function<void (llvm::DenseMap<llvm::orc::JITDylib*, llvm::DenseSet<llvm::orc::SymbolStringPtr, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void> >, llvm::DenseMapInfo<llvm::orc::JITDylib*, void>, llvm::detail::DenseMapPair<llvm::orc::JITDylib*, llvm::DenseSet<llvm::orc::SymbolStringPtr, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void> > > > const&)>) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#21 0x000015523688990e in llvm::orc::InProgressFullLookupState::complete(std::unique_ptr<llvm::orc::InProgressLookupState, std::default_delete<llvm::orc::InProgressLookupState> >) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#22 0x000015523687aeb1 in llvm::orc::ExecutionSession::OL_applyQueryPhase1(std::unique_ptr<llvm::orc::InProgressLookupState, std::default_delete<llvm::orc::InProgressLookupState> >, llvm::Error) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#23 0x000015523687fb8a in llvm::orc::ExecutionSession::lookup(llvm::orc::LookupKind, std::vector<std::pair<llvm::orc::JITDylib*, llvm::orc::JITDylibLookupFlags>, std::allocator<std::pair<llvm::orc::JITDylib*, llvm::orc::JITDylibLookupFlags> > > const&, llvm::orc::SymbolLookupSet, llvm::orc::SymbolState, llvm::unique_function<void (llvm::Expected<llvm::DenseMap<llvm::orc::SymbolStringPtr, llvm::orc::ExecutorSymbolDef, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void>, llvm::detail::DenseMapPair<llvm::orc::SymbolStringPtr, llvm::orc::ExecutorSymbolDef> > >)>, std::function<void (llvm::DenseMap<llvm::orc::JITDylib*, llvm::DenseSet<llvm::orc::SymbolStringPtr, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void> >, llvm::DenseMapInfo<llvm::orc::JITDylib*, void>, llvm::detail::DenseMapPair<llvm::orc::JITDylib*, llvm::DenseSet<llvm::orc::SymbolStringPtr, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void> > > > const&)>) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#24 0x000015523688084a in llvm::orc::ExecutionSession::lookup(std::vector<std::pair<llvm::orc::JITDylib*, llvm::orc::JITDylibLookupFlags>, std::allocator<std::pair<llvm::orc::JITDylib*, llvm::orc::JITDylibLookupFlags> > > const&, llvm::orc::SymbolLookupSet, llvm::orc::LookupKind, llvm::orc::SymbolState, std::function<void (llvm::DenseMap<llvm::orc::JITDylib*, llvm::DenseSet<llvm::orc::SymbolStringPtr, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void> >, llvm::DenseMapInfo<llvm::orc::JITDylib*, void>, llvm::detail::DenseMapPair<llvm::orc::JITDylib*, llvm::DenseSet<llvm::orc::SymbolStringPtr, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void> > > > const&)>) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#25 0x0000155236880d8a in llvm::orc::ExecutionSession::lookup(std::vector<std::pair<llvm::orc::JITDylib*, llvm::orc::JITDylibLookupFlags>, std::allocator<std::pair<llvm::orc::JITDylib*, llvm::orc::JITDylibLookupFlags> > > const&, llvm::orc::SymbolStringPtr, llvm::orc::SymbolState) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#26 0x0000155236880fb4 in llvm::orc::ExecutionSession::lookup(llvm::ArrayRef<llvm::orc::JITDylib*>, llvm::orc::SymbolStringPtr, llvm::orc::SymbolState) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#27 0x00001552368811d1 in llvm::orc::ExecutionSession::lookup(llvm::ArrayRef<llvm::orc::JITDylib*>, llvm::StringRef, llvm::orc::SymbolState) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#28 0x0000155234080669 in xla::cpu::SimpleOrcJIT::FindCompiledSymbol(std::string const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#29 0x0000155234078781 in xla::cpu::CpuExecutable::Create(std::unique_ptr<xla::cpu::SimpleOrcJIT, std::default_delete<xla::cpu::SimpleOrcJIT> >, std::unique_ptr<xla::BufferAssignment const, std::default_delete<xla::BufferAssignment const> >, std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, std::string const&, std::unique_ptr<xla::HloProfilePrinterData, std::default_delete<xla::HloProfilePrinterData> >, std::unique_ptr<xla::HloProfileIndexMap, std::default_delete<xla::HloProfileIndexMap> >) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#30 0x0000155233aa2976 in xla::cpu::CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#31 0x0000155233aa303f in xla::cpu::CpuCompiler::RunBackend(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#32 0x0000155233a5d332 in xla::TfrtCpuClient::Compile(xla::XlaComputation const&, xla::CompileOptions) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#33 0x00001552317c1177 in torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#34 0x00001552315b7623 in torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >&, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#35 0x00001552315b953f in torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#36 0x00001552315b9b41 in torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, bool, bool, bool) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#37 0x00001552315b9d78 in torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRef<std::string>, bool) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#38 0x00001552313702ba in torch_xla::(anonymous namespace)::StepMarker(std::string const&, std::vector<std::string, std::allocator<std::string> > const&, bool) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#39 0x0000155231370766 in pybind11::cpp_function::initialize<torch_xla::(anonymous namespace)::InitXlaModuleBindings(pybind11::module_)::{lambda(std::string const&, std::vector<std::string, std::allocator<std::string> > const&, bool)#65}, void, std::string const&, std::vector<std::string, std::allocator<std::string> > const&, bool, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg_v, pybind11::arg, pybind11::arg_v>(torch_xla::(anonymous namespace)::InitXlaModuleBindings(pybind11::module_)::{lambda(std::string const&, std::vector<std::string, std::allocator<std::string> > const&, bool)#65}&&, void (*)(std::string const&, std::vector<std::string, std::allocator<std::string> > const&, bool), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg_v const&, pybind11::arg const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#40 0x0000155231341fc6 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#41 0x0000555555698516 in cfunction_call (func=0x1552476ae2f0, args=<optimized out>, kwargs=<optimized out>) at /usr/local/src/conda/python-3.10.12/Objects/methodobject.c:543
#42 0x0000555555691a6b in _PyObject_MakeTpCall (tstate=0x555555908e60, callable=0x1552476ae2f0, args=<optimized out>, nargs=2, keywords=0x1552476d8490) at /usr/local/src/conda/python-3.10.12/Objects/call.c:215
#43 0x000055555568dc39 in _PyObject_VectorcallTstate (kwnames=0x1552476d8490, nargsf=<optimized out>, args=<optimized out>, callable=0x1552476ae2f0, tstate=<optimized out>) at /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:112

JackCaoG commented 4 months ago

Could you share the model code you tried to torch.compile so I can double check on my end?

pritamdamania87 commented 4 months ago

The model is proprietary. Do you have suggestions for building a minimal repro? There is a GraphModule section in the issue summary, but I'm not sure whether it can be used as-is for a repro.

JackCaoG commented 4 months ago

Well, I have the FX graph, so if you dump the input shapes of xla_args in https://github.com/pytorch/xla/blob/6f0b61e5d782913a0fc7743812f2a8e522189111/torch_xla/core/dynamo_bridge.py#L337-L339 I should be able to repro it. (Otherwise I will have to manually map the HLO input shapes back to tensor input shapes and guess the order, lol.)
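The shape dump in the reply that follows has the format of Python's self-documenting f-strings (`f"{expr=}"`, Python 3.8+). Below is a sketch of a temporary debug print that would produce that output. The placement is an assumption: just before the `_xla_warm_up_cache` call in `extract_graph_helper`, which is line 339 in the traceback. It also assumes `xla_args` is the list of tensors at that point. `dump_arg_shapes` is a hypothetical helper.

```python
def dump_arg_shapes(xla_args, emit=print):
    """Emit type and size for every argument, two lines per arg.

    Works on anything that exposes .size(), such as torch.Tensor or
    torch.nn.Parameter; the emitted lines are also returned.
    """
    lines = []
    for arg in xla_args:
        lines.append(f"{type(arg)=}")
        lines.append(f"{arg.size()=}")
    for line in lines:
        emit(line)
    return lines
```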

pritamdamania87 commented 4 months ago

Thanks for the pointer. These are the shapes for xla_args:

type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([28, 4, 480, 640])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([1, 2, 1, 1])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([1])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([3, 3, 1, 1])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([3])
type(arg)=<class 'torch.nn.parameter.Parameter'>
arg.size()=torch.Size([3])
type(arg)=<class 'torch.nn.parameter.Parameter'>
arg.size()=torch.Size([3])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([3])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([3])

pritamdamania87 commented 4 months ago

I tried a repro.py based on these shapes, but it seems to work without any issues.

JackCaoG commented 4 months ago

Filed an issue with the XLA:GPU team; waiting for their feedback...

cheshire commented 3 months ago

Does it work with xla-gpu-nightly? We couldn't repro the issue.