apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] Race condition in TIR ComputationCache in transforms common subexpression elimination #17072

Open guillon opened 3 months ago

guillon commented 3 months ago

There is a race in the TVM/TIR optimization passes when several distinct Python threads each compile an operator.

The race occurs in the common subexpression elimination pass, which uses a cache of expression/statement maps that is declared static.

The cache should be attached to the build context, or at least declared thread_local.

The faulty cache declaration is located there: https://github.com/apache/tvm/blob/v0.16.0/src/tir/transforms/common_subexpr_elim_tools.h#L115
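
For context, here is an illustrative stand-in for the shape of that declaration (simplified, hypothetical types standing in for TVM's ComputationCache and its tables, and a placeholder class name; see the link above for the real code). The key point is that the cache is a static class member, so every thread running the CSE pass mutates the same hash tables:

#include <string>
#include <unordered_map>

// Simplified stand-ins for TVM's table/cache types (not the real definitions).
using ComputationTable = std::unordered_map<std::string, int>;

struct ComputationCache {
  // Unsynchronized hash maps from expressions/statements to their tables.
  std::unordered_map<std::string, ComputationTable> expr_table_computations_;
  std::unordered_map<std::string, ComputationTable> stmt_table_computations_;
};

// Stand-in for the CSE analysis class: the cache is a *static* member,
// i.e. a single process-wide object shared by all compiling threads.
class ComputationsDoneBy {
 public:
  static ComputationTable& GetExprTable(const std::string& expr) {
    return cache_.expr_table_computations_[expr];  // may insert: races across threads
  }

 private:
  static ComputationCache cache_;
};

ComputationCache ComputationsDoneBy::cache_;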

Expected behavior

Non-faulty code generation on operator build when several Python threads each compile different modules/operators.

Actual behavior

When launching the creation and build of operators in parallel, for instance in a thread pool, one may encounter a segmentation fault on hash table insertion (a race between iteration and insertion).

This bug is flaky, as it is highly dependent on the machine, the number of threads, and the compiled workload.

Environment

TVM: v0.16.0
Target device: llvm host cpu
LLVM: llvm-config-12
Kernel: Linux 5.10.0-27-amd
Distro: Debian 5.10.205-2 (2023-12-31) x86_64 GNU/Linux
Arch: 52 cores, Intel(R) Xeon(R) Gold 6230R CPU @ 2.10GHz

Steps to reproduce

The problem arises due to a statically declared cache in: https://github.com/apache/tvm/blob/v0.16.0/src/tir/transforms/common_subexpr_elim_tools.h#L115

A simple fix is to declare the cache thread_local, both at this declaration and at the definition point. There may be a more elegant fix, though, such as using a per-compilation-context cache instead of a static one.
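
For illustration, a minimal self-contained sketch of the thread_local variant, using the same hypothetical stand-in types as above rather than TVM's actual classes: each thread gets its own cache instance, so concurrent builds never touch the same hash table.

#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

using ComputationTable = std::unordered_map<std::string, int>;

struct ComputationCache {
  std::unordered_map<std::string, ComputationTable> expr_table_computations_;
};

class ComputationsDoneBy {
 public:
  static ComputationTable& GetExprTable(const std::string& expr) {
    // Inserting into a thread-local map is safe without synchronization.
    return cache_.expr_table_computations_[expr];
  }

 private:
  // thread_local here and at the definition point below.
  static thread_local ComputationCache cache_;
};

thread_local ComputationCache ComputationsDoneBy::cache_;

int main() {
  // Simulate many "builds" running in parallel threads, as in the reproducer.
  std::vector<std::thread> workers;
  for (int t = 0; t < 8; ++t) {
    workers.emplace_back([] {
      for (int i = 0; i < 100000; ++i) {
        ComputationsDoneBy::GetExprTable("expr_" + std::to_string(i % 97))["uses"]++;
      }
    });
  }
  for (auto& w : workers) w.join();
  return 0;
}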

Attached is a test script which reproduces the issue by massively launching parallel builds of a matmul operator: multithreaded-bug.py.gz

Launch it as follows (note that I use setarch -R in an attempt to be more reproducible); this was run on the machine described above (52 cores):

setarch -R python3 ./multithreaded-bug.py
...
Completed build: idx = 1369: built = Module(llvm, 15544c785298)
Completed build: idx = 1367: built = Module(llvm, 15541c73c518)
Completed build: idx = 1368: built = Module(llvm, 1554cc720c38)
Segmentation fault

Note that the bug is flaky; if it does not reproduce, try playing with the number of parallel threads and the total number of tasks, for instance:

# launch 100 parallel threads and execute 100000 total compilations
setarch -R python3 ./multithreaded-bug.py 100 100000
...

Also, one can play with the sys.setswitchinterval(0.00001) call in the file, lowering or increasing the context switch interval.

After applying the simple thread_local fix, the bug is no longer visible. Refer to the attached patch file for the fix: 0001-Bugfix-TIR-Fix-race-on-ComputationCache.patch.gz

Triage

tqchen commented 3 months ago

thanks @guillon, feel free to send a PR

PandaTinker commented 3 months ago

Hi @tqchen , I am trying to add a static mutex and lock the mutex when we write to the cache. If you think this is ok, I can raise a PR.
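
For comparison, a rough sketch of what a mutex-guarded cache could look like (hypothetical, reusing the stand-in types from above; not an actual patch). Note that locking only the writes would not be enough: an insert can rehash the table, so lookups racing with inserts must also take the lock.

#include <mutex>
#include <string>
#include <unordered_map>

using ComputationTable = std::unordered_map<std::string, int>;

class ComputationCache {
 public:
  ComputationTable& GetTable(const std::string& expr) {
    // Serializes every compiling thread through one lock.
    // Caveat: the returned table itself would still need protection if two
    // threads could end up mutating the same entry concurrently.
    std::lock_guard<std::mutex> guard(mutex_);
    return expr_table_computations_[expr];
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, ComputationTable> expr_table_computations_;
};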

guillon commented 2 months ago

Using synchronization may be more costly and error-prone than a simple thread-local cache as proposed above; you may experiment with both.

That said, I did not understand the rationale for making the expression cache global in the first place. Shouldn't the cache be local to some function scope, hence a plain member of some parent object instead of global state? Perhaps some experiments were conducted at the time this was added?
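
To make that suggestion concrete, the per-compilation-context alternative would amount to holding the cache as a plain (non-static) member of whatever object drives the pass, so each build owns its own cache. A hypothetical shape, with the same stand-in types as above:

#include <string>
#include <unordered_map>

using ComputationTable = std::unordered_map<std::string, int>;

struct ComputationCache {
  std::unordered_map<std::string, ComputationTable> expr_table_computations_;
};

// Hypothetical: the cache lives and dies with the pass/visitor instance that
// a given compilation creates, so there is no cross-thread shared state at all.
class ComputationsDoneBy {
 public:
  ComputationTable& GetExprTable(const std::string& expr) {
    return cache_.expr_table_computations_[expr];
  }

 private:
  ComputationCache cache_;  // non-static: one cache per instance / per build
};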