apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Initial inspections for singleton thread safety in MXNet #17495

Open eric-haibin-lin opened 4 years ago

eric-haibin-lin commented 4 years ago

I did a rough skim through the code base and searched for classes with a ::Get method. Some of these class objects are global singletons and have potential thread-safety issues. Here is my initial assessment:

thread safe classes

These classes are thread safe

is a thread_local object or contains thread local variables

Their thread safety depends on the lifecycle of the thread. Are there alternative ways to avoid them?
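
A minimal sketch of the concern, using a dummy flag rather than real MXNet code: anything stored in a thread_local (for example Imperative::is_recording_) is a separate copy per thread, created and destroyed with that thread, so a value set on one thread is simply not visible to another.

```cpp
// Minimal sketch (not MXNet code) of why thread_local state depends on the
// calling thread: a mode flag set on the main thread is not visible from a
// worker thread, which sees its own freshly-initialized copy instead.
#include <iostream>
#include <thread>

thread_local bool is_recording = false;  // each thread gets its own copy

int main() {
  is_recording = true;  // set on the main thread
  std::thread worker([] {
    // The worker's copy was default-initialized, so any per-thread "mode"
    // silently resets here.
    std::cout << "worker sees is_recording = " << is_recording << '\n';  // 0
  });
  worker.join();
  std::cout << "main sees is_recording = " << is_recording << '\n';      // 1
  return 0;
}
```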

classes that are not thread safe

These classes are not thread safe and may cause bugs:

I didn't look into the C APIs. There are also lots of other thread_local objects scattered around the code base (a simplified sketch of this per-thread caching pattern follows the list):

c_api/c_api_profile.cc:static thread_local ProfilingThreadData thread_profiling_data;
imperative/imperative.cc:thread_local bool Imperative::is_train_ = false;
imperative/imperative.cc:thread_local bool Imperative::is_recording_ = false;
imperative/imperative.cc:thread_local bool Imperative::is_np_shape_thread_local_ = false;
operator/nn/activation.cu:  static thread_local CuDNNActivationOp<DType> cudnn_op;
operator/nn/batch_norm.cu:  static thread_local CuDNNBatchNormOp<DType> op;
operator/nn/convolution.cu:  static thread_local std::unordered_map<ConvSignature,
operator/nn/deconvolution.cu:  static thread_local std::unordered_map<DeconvSignature,
operator/nn/mkldnn/mkldnn_act.cc:  static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActForward, OpHash> fwds;
operator/nn/mkldnn/mkldnn_act.cc:  static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
operator/nn/mkldnn/mkldnn_base-inl.h:    static thread_local TmpMemMgr mgr;
operator/nn/mkldnn/mkldnn_base.cc:  static thread_local MKLDNNStream stream;
operator/nn/mkldnn/mkldnn_batch_norm-inl.h:  static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNForward, OpHash> fwds;
operator/nn/mkldnn/mkldnn_batch_norm-inl.h:  static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
operator/nn/mkldnn/mkldnn_concat-inl.h:  static thread_local std::unordered_map<OpSignature, MKLDNNConcatFwd, OpHash> fwds;
operator/nn/mkldnn/mkldnn_convolution.cc:  static thread_local conv_fwd_map fwds;
operator/nn/mkldnn/mkldnn_convolution.cc:  static thread_local mkldnn_conv_bwd_map bwds;
operator/nn/mkldnn/mkldnn_deconvolution.cc:  static thread_local std::unordered_map<DeconvSignature, MKLDNNDeconvForward,
operator/nn/mkldnn/mkldnn_deconvolution.cc:  static thread_local std::unordered_map<MKLDNNDeconvSignature,
operator/nn/mkldnn/mkldnn_deconvolution.cc:  static thread_local std::unordered_map<MKLDNNDeconvSignature,
operator/nn/mkldnn/mkldnn_fully_connected.cc:  static thread_local std::unordered_map<MKLDNNFullyconSignature,
operator/nn/mkldnn/mkldnn_lrn-inl.h:  static thread_local std::unordered_map<MKLDNNLRNSignature,
operator/nn/mkldnn/mkldnn_pooling.cc:  static thread_local std::unordered_map<MKLDNNPoolingSignature,
operator/nn/mkldnn/mkldnn_reshape.cc:  static thread_local std::unordered_map<MKLDNNReshapeSignature,
operator/nn/mkldnn/mkldnn_rnn.cc:  static thread_local std::unordered_map<OpSignature,
operator/nn/mkldnn/mkldnn_slice.cc:  static thread_local std::unordered_map<MKLDNNSliceSignature, MKLDNNSliceFwd, OpHash> fwds;
operator/nn/mkldnn/mkldnn_sum.cc:  static thread_local std::unordered_map<OpSignature, MKLDNNSumFwd, OpHash> fwds;
operator/nn/mkldnn/mkldnn_transpose.cc:  static thread_local std::unordered_map<MKLDNNTransposeSignature,
operator/nn/pooling.cu:  static thread_local CuDNNPoolingOp<DType> op;
operator/nn/softmax_activation.cu:  static thread_local CuDNNSoftmaxActivationOp op;
operator/quantization/quantized_conv.cu:  static thread_local QuantizedConvOpInt8 op;
operator/quantization/quantized_pooling.cu:  static thread_local QuantizedCuDNNPoolingOp<int8_t> op;
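
To make the pattern above concrete, here is a simplified, hypothetical sketch (dummy types, not the real operator classes): each static thread_local cache is built lazily per thread, is never shared, and is destroyed when its thread exits, so expensive primitives get re-created for every thread that runs the operator.

```cpp
// Simplified sketch of the "static thread_local cache" pattern used by the
// cuDNN/MKL-DNN operator caches listed above (dummy types, not MXNet code).
#include <iostream>
#include <string>
#include <thread>
#include <unordered_map>

struct DummyOp {
  explicit DummyOp(const std::string& sig) {
    // In MXNet this would build a cuDNN/MKL-DNN primitive -- expensive work
    // that is repeated once per signature *per thread*.
    std::cout << "constructing op for " << sig << " on thread "
              << std::this_thread::get_id() << '\n';
  }
};

DummyOp& GetCachedOp(const std::string& signature) {
  // One cache per thread: lazily constructed on first use, destroyed when
  // the thread exits, never shared between threads.
  static thread_local std::unordered_map<std::string, DummyOp> cache;
  auto it = cache.find(signature);
  if (it == cache.end()) {
    it = cache.emplace(signature, DummyOp(signature)).first;
  }
  return it->second;
}

int main() {
  GetCachedOp("conv3x3");                          // built on the main thread
  GetCachedOp("conv3x3");                          // cache hit, no rebuild
  std::thread t([] { GetCachedOp("conv3x3"); });   // built again on the worker
  t.join();
  return 0;
}
```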

Related issue: https://github.com/apache/incubator-mxnet/issues/17612

eric-haibin-lin commented 4 years ago

There are a couple of thread-local variables at the Python level, too:

These suggest that if the frontend Python thread is switched, some of these contexts are lost.

leezu commented 4 years ago

Specifically, we'd need to use something like https://docs.python.org/3/library/contextvars.html. But ContextVar in its current form is not sufficient, as it doesn't allow us to hook into C API calls. This feature should be added to the Python standard library, and on the MXNet side we should use a patched version of ContextVar.

leezu commented 4 years ago

@eric-haibin-lin I raised the question on the Python bug tracker. The Python maintainers recommend refactoring MXNet to make state management pluggable/customizable.

They consider adding callbacks to contextvars infeasible:

For extra context: context switches occur on every callback invocation in asyncio and there can be thousands of them per second (or even more). Adding any extra code to context switching code will noticeably degrade the performance.

Reference: https://bugs.python.org/issue39660#msg362370

eric-haibin-lin commented 4 years ago

@leezu thanks for the followup.