pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Stack trace is symbolized when no exception is thrown #133979

Open ppwwyyxx opened 3 weeks ago

ppwwyyxx commented 3 weeks ago

🐛 Describe the bug

The following code:

import torch
import torch.distributed as dist
if __name__ == '__main__':
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    t = torch.randn(256, device='cuda')
    dist.scatter(t, [t, t] if dist.get_rank() == 0 else [])
    dist.scatter(t, [t, t] if dist.get_rank() == 0 else [])

    torch.cuda.synchronize()
    dist.destroy_process_group()

prints:

$ TORCH_SHOW_CPP_STACKTRACES=1 torchrun --nnodes=1 --nproc-per-node=2 a.py
[rank1]:[W820 06:51:47.822066421 Module.cpp:175] symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1...

[rank1]:[W820 06:51:47.982257538 Module.cpp:175] symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1...

The problem is that no exception reaches the user here, so PyTorch should not try to symbolize the stack at all: the warning is annoying and useless, and symbolization may have performance implications.

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ot @ezyang as a follow-up to https://github.com/pytorch/pytorch/pull/113207#issuecomment-2298072193

Versions

torch 2.4.0

ezyang commented 3 weeks ago

If you're able to patch PyTorch, patch it to dump the stack trace, and let's see who is throwing an exception here.

ppwwyyxx commented 3 weeks ago
#4 std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>::operator()() const from /usr/include/c++/11/bits/std_function.h:590
#5 std::function<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> ()>::operator()() const from /usr/include/c++/11/bits/std_function.h:590
#6 c10::NotImplementedError::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from /mnt/pytorch/c10/util/Exception.h:263
#7 c10::TensorImpl::storage() const from /mnt/pytorch/c10/core/TensorImpl.h:1032
#8 at::TensorBase::storage() const from /mnt/pytorch/aten/src/ATen/core/TensorBase.h:346
#9 c10::ivalue::Future::markCompleted(c10::IValue, std::optional<std::vector<c10::weak_intrusive_ptr<c10::StorageImpl, c10::detail::intrusive_target_default_null_type<c10::StorageImpl> >, std::allocator<c10::weak_intrusive_ptr<c10::StorageImpl, c10::detail::intrusive_target_default_null_type<c10::StorageImpl> > > > >) from /mnt/pytorch/aten/src/ATen/core/ivalue_inl.h:927
#10 c10d::ProcessGroupNCCL::scatter(std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > >&, c10d::ScatterOptions const&) from /mnt/pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2648
#11 c10d::ops::(anonymous namespace)::scatter_CUDA(c10::ArrayRef<at::Tensor> const&, std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > > const&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, long, bool, long) from /mnt/pytorch/torch/csrc/distributed/c10d/Ops.cpp:385
#12 c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<std::tuple<std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > > (*)(c10::ArrayRef<at::Tensor> const&, std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > > const&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, long, bool, long), std::tuple<std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > >, c10::guts::typelist::typelist<c10::ArrayRef<at::Tensor> const&, std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > > const&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, long, bool, long> >::operator()(c10::ArrayRef<at::Tensor> const&, std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > > const&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, long, bool, long) from /mnt/moonfs/users/yuxin/home/projects/pytorch/aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:18
#13 torch::autograd::basicAutogradNotImplementedFallbackImpl(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) from /mnt/pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:148
#14 c10::BoxedKernel::callBoxed(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const from /mnt/pytorch/aten/src/ATen/core/boxing/BoxedKernel_impl.h:41
#15 pybind11::cpp_function::cpp_function<c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> >, c10d::ProcessGroup, std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > >&, c10d::ScatterOptions const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::call_guard<pybind11::gil_scoped_release> >(c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > (c10d::ProcessGroup::*)(std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > >&, c10d::ScatterOptions const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::call_guard<pybind11::gil_scoped_release> const&)::{lambda(c10d::ProcessGroup*, std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > >&, c10d::ScatterOptions const&)#1}::operator()(c10d::ProcessGroup*, std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::vector<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::allocator<std::vector<at::Tensor, std::allocator<at::Tensor> > > >&, c10d::ScatterOptions const&) const from /mnt/moonfs/users/yuxin/home/projects/pytorch/third_party/pybind11/include/pybind11/pybind11.h:154
#16 pybind11::cpp_function::dispatcher(_object*, _object*, _object*) from /mnt/pytorch/third_party/pybind11/include/pybind11/pybind11.h:987
#17 cfunction_call from /tmp/python-build.20240730053852.1383433/Python-3.10.14/Objects/methodobject.c:543
#18 _PyObject_MakeTpCall from /tmp/python-build.20240730053852.1383433/Python-3.10.14/Objects/call.c:215
#19 _PyObject_VectorcallTstate from /tmp/python-build.20240730053852.1383433/Python-3.10.14/./Include/cpython/abstract.h:112
#20 _PyObject_VectorcallTstate from /tmp/python-build.20240730053852.1383433/Python-3.10.14/./Include/cpython/abstract.h:114
#21 _PyEval_EvalFrame from /tmp/python-build.20240730053852.1383433/Python-3.10.14/./Include/internal/pycore_ceval.h:46
#22 do_call_core from /tmp/python-build.20240730053852.1383433/Python-3.10.14/Python/ceval.c:5945
#23 _PyEval_EvalFrame from /tmp/python-build.20240730053852.1383433/Python-3.10.14/./Include/internal/pycore_ceval.h:46
#24 _PyObject_VectorcallTstate from /tmp/python-build.20240730053852.1383433/Python-3.10.14/./Include/cpython/abstract.h:114
#25 _PyEval_EvalFrame from /tmp/python-build.20240730053852.1383433/Python-3.10.14/./Include/internal/pycore_ceval.h:46
#26 run_eval_code_obj from /tmp/python-build.20240730053852.1383433/Python-3.10.14/Python/pythonrun.c:1291
#27 pyrun_file from /tmp/python-build.20240730053852.1383433/Python-3.10.14/Python/pythonrun.c:1208
#28 _PyRun_AnyFileObject from /tmp/python-build.20240730053852.1383433/Python-3.10.14/Python/pythonrun.c:90
#29 pymain_run_file_obj from /tmp/python-build.20240730053852.1383433/Python-3.10.14/Modules/main.c:353
#30 pymain_main from /tmp/python-build.20240730053852.1383433/Python-3.10.14/Modules/main.c:696
#31 __libc_start_call_main from ./csu/../sysdeps/nptl/libc_start_call_main.h:58
#32 __libc_start_main_impl from ./csu/../csu/libc-start.c:392

torch is built from v2.4.0

ezyang commented 2 weeks ago

A simple local fix for Future would be to check whether a Tensor actually has storage before trying to access it (that access is what raises the exception).

ppwwyyxx commented 1 week ago

In addition to fixing the exception in this case, is there a way to avoid symbolizing the stack except when the exception is printed? For example, store only the frame pointers, and symbolize only when the exception is caught in Python or when __str__ is called from Python.

Otherwise there would be more cases like this.

ezyang commented 1 week ago

Yes. In fbcode, someone authored a diff to make the symbolization lazy; something similar could potentially be applied to OSS.

ot commented 1 week ago

@ezyang It should already be lazy since commit https://github.com/pytorch/pytorch/commit/36e6f3b3390cba61a724fd0ede606573a84c7e9d, so maybe something is catching the exception and calling what(), possibly at the C++/Python interface?