Don't allocate threads on every dispatch in Native's thread pools (#3595)

Related to #3576
dkhalanskyjb committed Feb 13, 2023
1 parent 32af157 commit e946cd7
Showing 4 changed files with 168 additions and 18 deletions.
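
At a high level, the change stops allocating a worker on every dispatch and instead hands tasks to already-idle workers: a worker that finds no pending task parks a CancellableContinuation<Runnable> in the availableWorkers channel, and dispatch resumes it with the task rather than starting another thread. Below is a simplified, illustrative sketch of that handshake alone (the class and function names are invented for the sketch; the real implementation in the diff additionally keeps a tasks/workers counter and a closing protocol, which resolve the race this sketch ignores):

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlin.coroutines.resume

// Illustrative sketch only: the idle-worker handshake without the counter bookkeeping.
class HandshakeSketch {
    private val tasksQueue = Channel<Runnable>(Channel.UNLIMITED)
    private val availableWorkers = Channel<CancellableContinuation<Runnable>>(Channel.UNLIMITED)

    // Called by a worker: take a queued task, or park until dispatch() hands one over.
    suspend fun nextTask(): Runnable =
        tasksQueue.tryReceive().getOrNull()
            ?: suspendCancellableCoroutine { cont -> availableWorkers.trySend(cont) }

    fun dispatch(block: Runnable) {
        // Prefer waking an idle worker over queueing the task (and over starting a new thread).
        val idleWorker = availableWorkers.tryReceive().getOrNull()
        if (idleWorker != null) idleWorker.resume(block) else tasksQueue.trySend(block)
    }
}
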
@@ -19,8 +19,8 @@ public expect abstract class CloseableCoroutineDispatcher() : CoroutineDispatcher

    /**
     * Initiate the closing sequence of the coroutine dispatcher.
     * After a successful call to [close], no new tasks will
     * be accepted to be [dispatched][dispatch], but the previously dispatched tasks will be run.
     * After a successful call to [close], no new tasks will be accepted to be [dispatched][dispatch].
     * The previously-submitted tasks will still be run, but [close] is not guaranteed to wait for them to finish.
     *
     * Invocations of `close` are idempotent and thread-safe.
     */
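
A minimal usage-level sketch of this contract (not from the repository; the thread count, task count, and dispatcher name are arbitrary) mirrors how the stress test below exercises close():

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlin.coroutines.EmptyCoroutineContext

// Tasks submitted before close() still run, but close() may return before they finish,
// so the caller has to wait for completion explicitly.
fun closeContractSketch() {
    val done = atomic(0)
    val dispatcher = newFixedThreadPoolContext(4, "closeContractSketch")
    repeat(10) {
        dispatcher.dispatch(EmptyCoroutineContext, Runnable { done.incrementAndGet() })
    }
    dispatcher.close()
    while (done.value < 10) {
        // spin: close() is not guaranteed to wait for the submitted tasks
    }
    // Any further dispatch would now throw IllegalStateException.
}
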
@@ -0,0 +1,35 @@
/*
* Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines

import kotlinx.atomicfu.*
import kotlin.coroutines.*
import kotlin.test.*

class MultithreadedDispatcherStressTest {
    val shared = atomic(0)

    /**
     * Tests that [newFixedThreadPoolContext] will not drop tasks when closed.
     */
    @Test
    fun testClosingNotDroppingTasks() {
        repeat(7) {
            shared.value = 0
            val nThreads = it + 1
            val dispatcher = newFixedThreadPoolContext(nThreads, "testMultiThreadedContext")
            repeat(1_000) {
                dispatcher.dispatch(EmptyCoroutineContext, Runnable {
                    shared.incrementAndGet()
                })
            }
            dispatcher.close()
            while (shared.value < 1_000) {
                // spin.
                // the test will hang here if the dispatcher drops tasks.
            }
        }
    }
}
81 changes: 65 additions & 16 deletions kotlinx-coroutines-core/native/src/MultithreadedDispatchers.kt
@@ -4,6 +4,7 @@

package kotlinx.coroutines

import kotlinx.atomicfu.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.internal.*
import kotlin.coroutines.*
@@ -73,43 +74,91 @@ private class MultiWorkerDispatcher(
    workersCount: Int
) : CloseableCoroutineDispatcher() {
    private val tasksQueue = Channel<Runnable>(Channel.UNLIMITED)
    private val availableWorkers = Channel<CancellableContinuation<Runnable>>(Channel.UNLIMITED)
    private val workerPool = OnDemandAllocatingPool(workersCount) {
        Worker.start(name = "$name-$it").apply {
            executeAfter { workerRunLoop() }
        }
    }

    /**
     * (number of tasks - number of workers) * 2 + (1 if closed)
     */
    private val tasksAndWorkersCounter = atomic(0L)

    private inline fun Long.isClosed() = this and 1L == 1L
    private inline fun Long.hasTasks() = this >= 2
    private inline fun Long.hasWorkers() = this < 0

    private fun workerRunLoop() = runBlocking {
        // NB: we leverage tail-call optimization in this loop, do not replace it with
        // .receive() without proper evaluation
        for (task in tasksQueue) {
            /**
             * Any unhandled exception here will pass through worker's boundary and will be properly reported.
             */
            task.run()
        while (true) {
            val state = tasksAndWorkersCounter.getAndUpdate {
                if (it.isClosed() && !it.hasTasks()) return@runBlocking
                it - 2
            }
            if (state.hasTasks()) {
                // we promised to process a task, and there are some
                tasksQueue.receive().run()
            } else {
                try {
                    suspendCancellableCoroutine {
                        val result = availableWorkers.trySend(it)
                        checkChannelResult(result)
                    }.run()
                } catch (e: CancellationException) {
                    /** we are cancelled from [close] and thus will never get back to this branch of code,
                     but there may still be pending work, so we can't just exit here. */
                }
            }
        }
    }

    // a worker that promised to be here and should actually arrive, so we wait for it in a blocking manner.
    private fun obtainWorker(): CancellableContinuation<Runnable> =
        availableWorkers.tryReceive().getOrNull() ?: runBlocking { availableWorkers.receive() }

    override fun dispatch(context: CoroutineContext, block: Runnable) {
        fun throwClosed(block: Runnable) {
            throw IllegalStateException("Dispatcher $name was closed, attempted to schedule: $block")
        val state = tasksAndWorkersCounter.getAndUpdate {
            if (it.isClosed())
                throw IllegalStateException("Dispatcher $name was closed, attempted to schedule: $block")
            it + 2
        }

        if (!workerPool.allocate()) throwClosed(block) // Do not even try to send to avoid race

        tasksQueue.trySend(block).onClosed {
            throwClosed(block)
        if (state.hasWorkers()) {
            // there are workers that have nothing to do, let's grab one of them
            obtainWorker().resume(block)
        } else {
            workerPool.allocate()
            // no workers are available, we must queue the task
            val result = tasksQueue.trySend(block)
            checkChannelResult(result)
        }
    }

    override fun close() {
        val workers = workerPool.close()
        tasksQueue.close()
        tasksAndWorkersCounter.getAndUpdate { if (it.isClosed()) it else it or 1L }
        val workers = workerPool.close() // no new workers will be created
        while (true) {
            // check if there are workers that await tasks in their personal channels, we need to wake them up
            val state = tasksAndWorkersCounter.getAndUpdate {
                if (it.hasWorkers()) it + 2 else it
            }
            if (!state.hasWorkers())
                break
            obtainWorker().cancel()
        }
        /*
         * Here we cannot avoid waiting on `.result`, otherwise it will lead
         * to a native memory leak, including a pthread handle.
         */
        val requests = workers.map { it.requestTermination() }
        requests.map { it.result }
    }

    private fun checkChannelResult(result: ChannelResult<*>) {
        if (!result.isSuccess)
            throw IllegalStateException(
                "Internal invariants of $this were violated, please file a bug to kotlinx.coroutines",
                result.exceptionOrNull()
            )
    }
}
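
The counter encoding that drives both workerRunLoop and dispatch above, (number of tasks - number of workers) * 2 + (1 if closed), can be exercised in isolation. A standalone sketch of how a few transitions map onto the three predicates (the helpers mirror the ones in the diff; the particular sequence of updates is only an example):

import kotlinx.atomicfu.*

// value = (number of tasks - number of workers) * 2, with the lowest bit set once closed.
fun Long.isClosed() = this and 1L == 1L
fun Long.hasTasks() = this >= 2
fun Long.hasWorkers() = this < 0

fun counterEncodingSketch() {
    val counter = atomic(0L)                  // no tasks, no idle workers, not closed
    counter.getAndUpdate { it + 2 }           // dispatch(): one more task than idle workers
    check(counter.value.hasTasks())
    counter.getAndUpdate { it - 2 }           // a worker claims that task
    counter.getAndUpdate { it - 2 }           // the worker loops again, finds nothing, and will go idle
    check(counter.value.hasWorkers())
    counter.getAndUpdate { if (it.isClosed()) it else it or 1L }  // close(): set the closed bit
    check(counter.value.isClosed())
}
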
@@ -0,0 +1,66 @@
/*
* Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines

import kotlinx.atomicfu.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.internal.*
import kotlin.native.concurrent.*
import kotlin.test.*

private class BlockingBarrier(val n: Int) {
    val counter = atomic(0)
    val wakeUp = Channel<Unit>(n - 1)
    fun await() {
        val count = counter.addAndGet(1)
        if (count == n) {
            repeat(n - 1) {
                runBlocking {
                    wakeUp.send(Unit)
                }
            }
        } else if (count < n) {
            runBlocking {
                wakeUp.receive()
            }
        }
    }
}

class MultithreadedDispatchersTest {
    /**
     * Test that [newFixedThreadPoolContext] does not allocate more dispatchers than it needs to.
     * Incidentally also tests that it will allocate enough workers for its needs. Otherwise, the test will hang.
     */
    @Test
    fun testNotAllocatingExtraDispatchers() {
        val barrier = BlockingBarrier(2)
        val lock = SynchronizedObject()
        suspend fun spin(set: MutableSet<Worker>) {
            repeat(100) {
                synchronized(lock) { set.add(Worker.current) }
                delay(1)
            }
        }
        val dispatcher = newFixedThreadPoolContext(64, "test")
        try {
            runBlocking {
                val encounteredWorkers = mutableSetOf<Worker>()
                val coroutine1 = launch(dispatcher) {
                    barrier.await()
                    spin(encounteredWorkers)
                }
                val coroutine2 = launch(dispatcher) {
                    barrier.await()
                    spin(encounteredWorkers)
                }
                listOf(coroutine1, coroutine2).joinAll()
                assertEquals(2, encounteredWorkers.size)
            }
        } finally {
            dispatcher.close()
        }
    }
}
