diff --git a/kotlinx-coroutines-core/jvm/test/flow/StateFlowStressTest.kt b/kotlinx-coroutines-core/jvm/test/flow/StateFlowStressTest.kt index 3739aef978..679fe3a088 100644 --- a/kotlinx-coroutines-core/jvm/test/flow/StateFlowStressTest.kt +++ b/kotlinx-coroutines-core/jvm/test/flow/StateFlowStressTest.kt @@ -1,9 +1,10 @@ package kotlinx.coroutines.flow -import kotlinx.coroutines.testing.* +import java.util.concurrent.atomic.AtomicLongArray +import kotlin.random.* import kotlinx.coroutines.* +import kotlinx.coroutines.testing.* import org.junit.* -import kotlin.random.* class StateFlowStressTest : TestBase() { private val nSeconds = 3 * stressTestMultiplier @@ -17,7 +18,7 @@ class StateFlowStressTest : TestBase() { fun stress(nEmitters: Int, nCollectors: Int) = runTest { pool = newFixedThreadPoolContext(nEmitters + nCollectors, "StateFlowStressTest") - val collected = Array(nCollectors) { LongArray(nEmitters) } + val collected = Array(nCollectors) { AtomicLongArray(nEmitters) } val collectors = launch { repeat(nCollectors) { collector -> launch(pool) { @@ -37,21 +38,18 @@ class StateFlowStressTest : TestBase() { } c[emitter] = current - }.take(batchSize).map { 1 }.sum() + }.take(batchSize).count() } while (cnt == batchSize) } } } - val emitted = LongArray(nEmitters) + val emitted = AtomicLongArray(nEmitters) val emitters = launch { repeat(nEmitters) { emitter -> launch(pool) { - var current = 1L while (true) { - state.value = current * nEmitters + emitter - emitted[emitter] = current - current++ - if (current % 1000 == 0L) yield() // make it cancellable + state.value = emitted.incrementAndGet(emitter) * nEmitters + emitter + if (emitted[emitter] % 1000 == 0L) yield() // make it cancellable } } } @@ -59,16 +57,20 @@ class StateFlowStressTest : TestBase() { for (second in 1..nSeconds) { delay(1000) val cs = collected.map { it.sum() } - println("$second: emitted=${emitted.sum()}, collected=${cs.minOrNull()}..${cs.maxOrNull()}") + println("$second: emitted=${emitted.sum()}, collected=${cs.min()}..${cs.max()}") } emitters.cancelAndJoin() collectors.cancelAndJoin() // make sure nothing hanged up - require(collected.all { c -> - c.withIndex().all { (emitter, current) -> current > emitted[emitter] / 2 } - }) + for (i in 0.. collected[i][j] > emitted[j] * 0.9 }) { + "collector #$i failed to collect any of the most recently emitted values" + } + } } + private fun AtomicLongArray.sum() = (0..