
[Bug] Race condition in TIR ComputationCache in transforms common subexpression elimination #17072

Open
guillon opened this issue Jun 7, 2024 · 3 comments
Labels: needs-triage, type: bug


guillon commented Jun 7, 2024

There is an issue caused by a race in the TVM/TIR optimization passes when several distinct Python threads are each compiling an operator.

The race occurs in the common subexpression elimination pass, which uses an expression/statement cache that is declared static.

The cache should be attached to the build context, or at least declared thread_local.

The faulty cache declaration is located here: https://github.com/apache/tvm/blob/v0.16.0/src/tir/transforms/common_subexpr_elim_tools.h#L115
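For illustration, here is a minimal C++ sketch of the pattern and of the thread_local variant; the names below are hypothetical and only mirror the shape of the declaration linked above, they are not the actual TVM code:

#include <cstddef>
#include <unordered_map>

struct ComputationCacheSketch {
  // Maps a hash of an expression/statement to its table of computations.
  std::unordered_map<std::size_t, int> computations_;
};

class ComputationsVisitorSketch {
 private:
  // Problematic pattern: one process-wide cache shared by every thread
  // running the CSE pass, so concurrent insertion/iteration can corrupt it.
  static ComputationCacheSketch cache_;

  // Simple fix: keep the cache static but make it per-thread. The keyword
  // must also be added at the out-of-line definition below:
  //   static thread_local ComputationCacheSketch cache_;
};

// Out-of-line definition of the static member (add thread_local here too
// if the fix above is applied).
ComputationCacheSketch ComputationsVisitorSketch::cache_;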

Expected behavior

Correct code generation on operator build when several Python threads each compile a different module/operator.

Actual behavior

When the creation and build of an operator is launched in parallel, for instance from a thread pool, one may encounter a segfault on hash-table insertion (a race between iteration and insertion).

This bug is flaky, as it is highly dependent on the machine, the number of threads, and the compiled workload.

Environment

TVM: v0.16.0
Target device: llvm host cpu
LLVM: llvm-config-12
Kernel: Linux 5.10.0-27-amd
Distro: Debian 5.10.205-2 (2023-12-31) x86_64 GNU/Linux
Arch: 52-core Intel(R) Xeon(R) Gold 6230R CPU @ 2.10GHz

Steps to reproduce

The problem arises due to a statically declared cache in: https://github.com/apache/tvm/blob/v0.16.0/src/tir/transforms/common_subexpr_elim_tools.h#L115

A simple fix is to declare the cache thread_local, both at this declaration and at the definition point.
A more elegant fix might avoid the static cache altogether in favour of a per-compilation-context cache.

Attached is a test script that reproduces the issue by launching many parallel builds of a matmul operator: multithreaded-bug.py.gz

Launch it as follows (note that I use setarch -R in an attempt to make runs more reproducible); this was run on the machine described above (52 cores):

setarch -R python3 ./multithreaded-bug.py
...
Completed build: idx = 1369: built = Module(llvm, 15544c785298)
Completed build: idx = 1367: built = Module(llvm, 15541c73c518)
Completed build: idx = 1368: built = Module(llvm, 1554cc720c38)
Segmentation fault

Note that the bug is flaky; if it does not reproduce, try varying the number of parallel threads and the total number of tasks, for instance:

# launch 100 parallel threads and execute 100000 total compilations
setarch -R python3 ./multithreaded-bug.py 100 100000
...

One can also adjust the sys.setswitchinterval(0.00001) call in the script, lowering or raising the context-switch interval.

With the simple thread_local fix applied, the bug is no longer visible. Refer to the attached patch file for the fix:
0001-Bugfix-TIR-Fix-race-on-ComputationCache.patch.gz

Triage

  • flow:tir
  • tir:transform
guillon added the needs-triage and type: bug labels on Jun 7, 2024

tqchen commented Jun 7, 2024

Thanks @guillon, feel free to send a PR.

PandaTinker commented

Hi @tqchen,
I am trying to add a static mutex and lock it when we write to the cache. If you think this is OK, I can raise a PR.
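For illustration, a rough sketch of that idea with hypothetical names (not the actual TVM classes); note that lookups would need the lock as well, since the reported crash is a race between iteration and insertion:

#include <cstddef>
#include <mutex>
#include <unordered_map>

class CachedAnalysisSketch {
 public:
  static int Lookup(std::size_t key) {
    // Reads must hold the lock as well: the reported segfault is a race
    // between iteration over the hash table and insertion into it.
    std::lock_guard<std::mutex> guard(mutex_);
    auto it = cache_.find(key);
    return it == cache_.end() ? -1 : it->second;
  }

  static void Insert(std::size_t key, int value) {
    std::lock_guard<std::mutex> guard(mutex_);
    cache_[key] = value;
  }

 private:
  static std::mutex mutex_;
  static std::unordered_map<std::size_t, int> cache_;
};

std::mutex CachedAnalysisSketch::mutex_;
std::unordered_map<std::size_t, int> CachedAnalysisSketch::cache_;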

@guillon
Copy link
Author

guillon commented Jul 26, 2024

Using synchronization may be more costly and error-prone than the simple thread-local cache proposed above; you may want to experiment with both.

That said, I do not really understand the rationale for making the expression cache global. Shouldn't the cache be local to some function scope, hence a plain member of some parent object rather than global state?
Perhaps some experiments were conducted at the time this was added?
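For illustration, a minimal sketch (again with hypothetical names, not the actual TVM code) of that per-object alternative, where the cache is owned by the pass instance and nothing is shared across threads:

#include <cstddef>
#include <unordered_map>

struct PerPassCacheSketch {
  // Same kind of mapping as before, but owned by a single pass instance.
  std::unordered_map<std::size_t, int> computations_;
};

class CommonSubexprEliminatorSketch {
 public:
  void Run() {
    // The cache lives and dies with this eliminator; each thread constructs
    // its own eliminator for its own module, so no locking is needed.
    cache_.computations_[42] = 1;
  }

 private:
  PerPassCacheSketch cache_;  // per-instance member, not a static
};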
