Skip to content

Commit

Permalink
Fix race in Flow.asPublisher (Kotlin#2124)
Browse files Browse the repository at this point in the history
The race was leading to emitting more items via onNext than requested, the corresponding stress-test was added, too

Fixes Kotlin#2109
  • Loading branch information
elizarov authored Jul 16, 2020
1 parent 5e91dc4 commit d718970
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 35 deletions.
9 changes: 4 additions & 5 deletions reactive/kotlinx-coroutines-jdk9/test/FlowAsPublisherTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ class FlowAsPublisherTest : TestBase() {
fun testErrorOnCancellationIsReported() {
expect(1)
flow<Int> {
emit(2)
try {
hang { expect(3) }
emit(2)
} finally {
expect(3)
throw TestException()
}
}.asPublisher().subscribe(object : JFlow.Subscriber<Int> {
Expand Down Expand Up @@ -52,12 +52,11 @@ class FlowAsPublisherTest : TestBase() {
expect(1)
flow<Int> {
emit(2)
hang { expect(3) }
}.asPublisher().subscribe(object : JFlow.Subscriber<Int> {
private lateinit var subscription: JFlow.Subscription

override fun onComplete() {
expect(4)
expect(3)
}

override fun onSubscribe(s: JFlow.Subscription?) {
Expand All @@ -74,6 +73,6 @@ class FlowAsPublisherTest : TestBase() {
expectUnreached()
}
})
finish(5)
finish(4)
}
}
46 changes: 21 additions & 25 deletions reactive/kotlinx-coroutines-reactive/src/ReactiveFlow.kt
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,12 @@ private class FlowAsPublisher<T : Any>(private val flow: Flow<T>) : Publisher<T>
public class FlowSubscription<T>(
@JvmField public val flow: Flow<T>,
@JvmField public val subscriber: Subscriber<in T>
) : Subscription, AbstractCoroutine<Unit>(Dispatchers.Unconfined, false) {
) : Subscription, AbstractCoroutine<Unit>(Dispatchers.Unconfined, true) {
private val requested = atomic(0L)
private val producer = atomic<CancellableContinuation<Unit>?>(null)
private val producer = atomic<Continuation<Unit>?>(createInitialContinuation())

override fun onStart() {
// This code wraps startCoroutineCancellable into continuation
private fun createInitialContinuation(): Continuation<Unit> = Continuation(coroutineContext) {
::flowProcessing.startCoroutineCancellable(this)
}

Expand All @@ -197,19 +198,17 @@ public class FlowSubscription<T>(
*/
private suspend fun consumeFlow() {
flow.collect { value ->
/*
* Flow is scopeless, thus if it's not active, its subscription was cancelled.
* No intermediate "child failed, but flow coroutine is not" states are allowed.
*/
coroutineContext.ensureActive()
if (requested.value <= 0L) {
// Emit the value
subscriber.onNext(value)
// Suspend if needed before requesting the next value
if (requested.decrementAndGet() <= 0) {
suspendCancellableCoroutine<Unit> {
producer.value = it
if (requested.value != 0L) it.resumeSafely()
}
} else {
// check for cancellation if we don't suspend
coroutineContext.ensureActive()
}
requested.decrementAndGet()
subscriber.onNext(value)
}
}

Expand All @@ -218,22 +217,19 @@ public class FlowSubscription<T>(
}

override fun request(n: Long) {
if (n <= 0) {
return
}
start()
requested.update { value ->
if (n <= 0) return
val old = requested.getAndUpdate { value ->
val newValue = value + n
if (newValue <= 0L) Long.MAX_VALUE else newValue
}
val producer = producer.getAndSet(null) ?: return
producer.resumeSafely()
}

private fun CancellableContinuation<Unit>.resumeSafely() {
val token = tryResume(Unit)
if (token != null) {
completeResume(token)
if (old <= 0L) {
assert(old == 0L)
// Emitter is not started yet or has suspended -- spin on race with suspendCancellableCoroutine
while(true) {
val producer = producer.getAndSet(null) ?: continue // spin if not set yet
producer.resume(Unit)
break
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ class FlowAsPublisherTest : TestBase() {
fun testErrorOnCancellationIsReported() {
expect(1)
flow<Int> {
emit(2)
try {
hang { expect(3) }
emit(2)
} finally {
expect(3)
throw TestException()
}
}.asPublisher().subscribe(object : Subscriber<Int> {
Expand Down Expand Up @@ -52,12 +52,11 @@ class FlowAsPublisherTest : TestBase() {
expect(1)
flow<Int> {
emit(2)
hang { expect(3) }
}.asPublisher().subscribe(object : Subscriber<Int> {
private lateinit var subscription: Subscription

override fun onComplete() {
expect(4)
expect(3)
}

override fun onSubscribe(s: Subscription?) {
Expand All @@ -74,6 +73,6 @@ class FlowAsPublisherTest : TestBase() {
expectUnreached()
}
})
finish(5)
finish(4)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.reactive

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.flow.Flow
import org.junit.*
import org.reactivestreams.*
import java.util.concurrent.*
import java.util.concurrent.atomic.*
import kotlin.coroutines.*
import kotlin.random.*

/**
* This stress-test is self-contained reproducer for the race in [Flow.asPublisher] extension
* that was originally reported in the issue
* [#2109](https://github.com/Kotlin/kotlinx.coroutines/issues/2109).
* The original reproducer used a flow that loads a file using AsynchronousFileChannel
* (that issues completion callbacks from multiple threads)
* and uploads it to S3 via Amazon SDK, which internally uses netty for I/O
* (which uses a single thread for connection-related callbacks).
*
* This stress-test essentially mimics the logic in multiple interacting threads: several emitter threads that form
* the flow and a single requesting thread works on the subscriber's side to periodically request more
* values when the number of items requested drops below the threshold.
*/
@Suppress("ReactiveStreamsSubscriberImplementation")
class PublisherRequestStressTest : TestBase() {
private val testDurationSec = 3 * stressTestMultiplier

// Original code in Amazon SDK uses 4 and 16 as low/high watermarks.
// There constants were chosen so that problem reproduces asap with particular this code.
private val minDemand = 8L
private val maxDemand = 16L

private val nEmitThreads = 4

private val emitThreadNo = AtomicInteger()

private val emitPool = Executors.newFixedThreadPool(nEmitThreads) { r ->
Thread(r, "PublisherRequestStressTest-emit-${emitThreadNo.incrementAndGet()}")
}

private val reqPool = Executors.newSingleThreadExecutor { r ->
Thread(r, "PublisherRequestStressTest-req")
}

private val nextValue = AtomicLong(0)

@After
fun tearDown() {
emitPool.shutdown()
reqPool.shutdown()
emitPool.awaitTermination(10, TimeUnit.SECONDS)
reqPool.awaitTermination(10, TimeUnit.SECONDS)
}

private lateinit var subscription: Subscription

@Test
fun testRequestStress() {
val expectedValue = AtomicLong(0)
val requestedTill = AtomicLong(0)
val completionLatch = CountDownLatch(1)
val callingOnNext = AtomicInteger()

val publisher = mtFlow().asPublisher()
var error = false

publisher.subscribe(object : Subscriber<Long> {
private var demand = 0L // only updated from reqPool

override fun onComplete() {
completionLatch.countDown()
}

override fun onSubscribe(sub: Subscription) {
subscription = sub
maybeRequestMore()
}

private fun maybeRequestMore() {
if (demand >= minDemand) return
val nextDemand = Random.nextLong(minDemand + 1..maxDemand)
val more = nextDemand - demand
demand = nextDemand
requestedTill.addAndGet(more)
subscription.request(more)
}

override fun onNext(value: Long) {
check(callingOnNext.getAndIncrement() == 0) // make sure it is not concurrent
// check for expected value
check(value == expectedValue.get())
// check that it does not exceed requested values
check(value < requestedTill.get())
val nextExpected = value + 1
expectedValue.set(nextExpected)
// send more requests from request thread
reqPool.execute {
demand-- // processed an item
maybeRequestMore()
}
callingOnNext.decrementAndGet()
}

override fun onError(ex: Throwable?) {
error = true
error("Failed", ex)
}
})
var prevExpected = -1L
for (second in 1..testDurationSec) {
if (error) break
Thread.sleep(1000)
val expected = expectedValue.get()
println("$second: expectedValue = $expected")
check(expected > prevExpected) // should have progress
prevExpected = expected
}
if (!error) {
subscription.cancel()
completionLatch.await()
}
}

private fun mtFlow(): Flow<Long> = flow {
while (currentCoroutineContext().isActive) {
emit(aWait())
}
}

private suspend fun aWait(): Long = suspendCancellableCoroutine { cont ->
emitPool.execute(Runnable {
cont.resume(nextValue.getAndIncrement())
})
}
}

0 comments on commit d718970

Please sign in to comment.