From 11e95473142ca8eb7b13ff6453e3ecaa49b90756 Mon Sep 17 00:00:00 2001 From: He-Pin Date: Fri, 10 Jan 2025 20:22:04 +0800 Subject: [PATCH 1/2] feat: Add TraversalBuilder.getValuePresentedSource method for further optimization. --- .../stream/impl/TraversalBuilderSpec.scala | 96 ++++++++++++++++++- .../scaladsl/FlowFlattenMergeSpec.scala | 14 --- .../pekko/stream/impl/FailedSource.scala | 2 +- .../pekko/stream/impl/JavaStreamSource.scala | 2 +- .../pekko/stream/impl/TraversalBuilder.scala | 45 ++++++++- 5 files changed, 140 insertions(+), 19 deletions(-) diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/impl/TraversalBuilderSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/impl/TraversalBuilderSpec.scala index 484ac0ed477..4c0f82e7552 100644 --- a/stream-tests/src/test/scala/org/apache/pekko/stream/impl/TraversalBuilderSpec.scala +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/impl/TraversalBuilderSpec.scala @@ -17,9 +17,14 @@ import org.apache.pekko import pekko.NotUsed import pekko.stream._ import pekko.stream.impl.TraversalTestUtils._ -import pekko.stream.scaladsl.Keep +import pekko.stream.impl.fusing.IterableSource +import pekko.stream.impl.fusing.GraphStages.{ FutureSource, SingleSource } +import pekko.stream.scaladsl.{ Keep, Source } +import pekko.util.OptionVal import pekko.testkit.PekkoSpec +import scala.concurrent.Future + class TraversalBuilderSpec extends PekkoSpec { "CompositeTraversalBuilder" must { @@ -447,4 +452,93 @@ class TraversalBuilderSpec extends PekkoSpec { } } + "find Source.single via TraversalBuilder" in { + TraversalBuilder.getSingleSource(Source.single("a")).get.elem should ===("a") + TraversalBuilder.getSingleSource(Source(List("a", "b"))) should be(OptionVal.None) + + val singleSourceA = new SingleSource("a") + TraversalBuilder.getSingleSource(singleSourceA) should be(OptionVal.Some(singleSourceA)) + + TraversalBuilder.getSingleSource(Source.single("c").async) should be(OptionVal.None) + 
TraversalBuilder.getSingleSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be(OptionVal.None) + } + + "find Source.single via TraversalBuilder with getValuePresentedSource" in { + TraversalBuilder.getValuePresentedSource(Source.single("a")).get.asInstanceOf[SingleSource[String]].elem should ===( + "a") + val singleSourceA = new SingleSource("a") + TraversalBuilder.getValuePresentedSource(singleSourceA) should be(OptionVal.Some(singleSourceA)) + + TraversalBuilder.getValuePresentedSource(Source.single("c").async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be( + OptionVal.None) + } + + "find Source.empty via TraversalBuilder with getValuePresentedSource" in { + val emptySource = EmptySource + TraversalBuilder.getValuePresentedSource(emptySource) should be(OptionVal.Some(emptySource)) + + TraversalBuilder.getValuePresentedSource(Source.empty.async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource(Source.empty.mapMaterializedValue(_ => "Mat")) should be(OptionVal.None) + } + + "find javadsl Source.empty via TraversalBuilder with getValuePresentedSource" in { + import pekko.stream.javadsl.Source + val emptySource = Source.empty() + TraversalBuilder.getValuePresentedSource(Source.empty()) should be(OptionVal.Some(emptySource)) + + TraversalBuilder.getValuePresentedSource(Source.empty().async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource(Source.empty().mapMaterializedValue(_ => "Mat")) should be(OptionVal.None) + } + + "find Source.future via TraversalBuilder with getValuePresentedSource" in { + val future = Future.successful("a") + TraversalBuilder.getValuePresentedSource(Source.future(future)).get.asInstanceOf[FutureSource[String]].future should ===( + future) + val futureSourceA = new FutureSource(future) + TraversalBuilder.getValuePresentedSource(futureSourceA) should be(OptionVal.Some(futureSourceA)) + + 
TraversalBuilder.getValuePresentedSource(Source.future(future).async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource(Source.future(future).mapMaterializedValue(_ => "Mat")) should be( + OptionVal.None) + } + + "find Source.iterable via TraversalBuilder with getValuePresentedSource" in { + val iterable = List("a") + TraversalBuilder.getValuePresentedSource(Source(iterable)).get.asInstanceOf[IterableSource[String]].elements should ===( + iterable) + val iterableSource = new IterableSource(iterable) + TraversalBuilder.getValuePresentedSource(iterableSource) should be(OptionVal.Some(iterableSource)) + + TraversalBuilder.getValuePresentedSource(Source(iterable).async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource(Source(iterable).mapMaterializedValue(_ => "Mat")) should be( + OptionVal.None) + } + + "find Source.javaStreamSource via TraversalBuilder with getValuePresentedSource" in { + val javaStream = java.util.stream.Stream.empty[String]() + TraversalBuilder.getValuePresentedSource(Source.fromJavaStream(() => javaStream)).get + .asInstanceOf[JavaStreamSource[String, _]].open() shouldEqual javaStream + val streamSource = new JavaStreamSource(() => javaStream) + TraversalBuilder.getValuePresentedSource(streamSource) should be(OptionVal.Some(streamSource)) + + TraversalBuilder.getValuePresentedSource(Source.fromJavaStream(() => javaStream).async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource( + Source.fromJavaStream(() => javaStream).mapMaterializedValue(_ => "Mat")) should be( + OptionVal.None) + } + + "find Source.failed via TraversalBuilder with getValuePresentedSource" in { + val failure = new RuntimeException("failure") + TraversalBuilder.getValuePresentedSource(Source.failed(failure)).get.asInstanceOf[FailedSource[String]] + .failure should ===( + failure) + val failedSourceA = new FailedSource(failure) + TraversalBuilder.getValuePresentedSource(failedSourceA) should 
be(OptionVal.Some(failedSourceA)) + + TraversalBuilder.getValuePresentedSource(Source.failed(failure).async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource(Source.failed(failure).mapMaterializedValue(_ => "Mat")) should be( + OptionVal.None) + } + } diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlattenMergeSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlattenMergeSpec.scala index 99b342d74ba..7d4cf5de379 100644 --- a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlattenMergeSpec.scala +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlattenMergeSpec.scala @@ -19,8 +19,6 @@ import scala.concurrent.duration._ import org.apache.pekko import pekko.NotUsed import pekko.stream._ -import pekko.stream.impl.TraversalBuilder -import pekko.stream.impl.fusing.GraphStages.SingleSource import pekko.stream.stage.GraphStage import pekko.stream.stage.GraphStageLogic import pekko.stream.stage.OutHandler @@ -29,7 +27,6 @@ import pekko.stream.testkit.TestPublisher import pekko.stream.testkit.Utils.TE import pekko.stream.testkit.scaladsl.TestSink import pekko.testkit.TestLatch -import pekko.util.OptionVal import org.scalatest.exceptions.TestFailedException @@ -283,16 +280,5 @@ class FlowFlattenMergeSpec extends StreamSpec { probe.expectComplete() } - "find Source.single via TraversalBuilder" in { - TraversalBuilder.getSingleSource(Source.single("a")).get.elem should ===("a") - TraversalBuilder.getSingleSource(Source(List("a", "b"))) should be(OptionVal.None) - - val singleSourceA = new SingleSource("a") - TraversalBuilder.getSingleSource(singleSourceA) should be(OptionVal.Some(singleSourceA)) - - TraversalBuilder.getSingleSource(Source.single("c").async) should be(OptionVal.None) - TraversalBuilder.getSingleSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be(OptionVal.None) - } - } } diff --git 
a/stream/src/main/scala/org/apache/pekko/stream/impl/FailedSource.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/FailedSource.scala index 4ab1c253556..b107857f860 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/FailedSource.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/FailedSource.scala @@ -22,7 +22,7 @@ import pekko.stream.stage.{ GraphStage, GraphStageLogic, OutHandler } /** * INTERNAL API */ -@InternalApi private[pekko] final class FailedSource[T](failure: Throwable) extends GraphStage[SourceShape[T]] { +@InternalApi private[pekko] final class FailedSource[T](val failure: Throwable) extends GraphStage[SourceShape[T]] { val out = Outlet[T]("FailedSource.out") override val shape = SourceShape(out) diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/JavaStreamSource.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/JavaStreamSource.scala index 74bba55d0a0..d05625fee13 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/JavaStreamSource.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/JavaStreamSource.scala @@ -23,7 +23,7 @@ import java.util.function.Consumer /** INTERNAL API */ @InternalApi private[stream] final class JavaStreamSource[T, S <: java.util.stream.BaseStream[T, S]]( - open: () => java.util.stream.BaseStream[T, S]) + val open: () => java.util.stream.BaseStream[T, S]) extends GraphStage[SourceShape[T]] { val out: Outlet[T] = Outlet("JavaStreamSource") diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/TraversalBuilder.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/TraversalBuilder.scala index 7ff61a2b3a8..24410e0f6d0 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/TraversalBuilder.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/TraversalBuilder.scala @@ -21,8 +21,8 @@ import pekko.annotation.{ DoNotInherit, InternalApi } import pekko.stream._ import pekko.stream.impl.StreamLayout.AtomicModule import 
pekko.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 } -import pekko.stream.impl.fusing.GraphStageModule -import pekko.stream.impl.fusing.GraphStages.SingleSource +import pekko.stream.impl.fusing.{ GraphStageModule, IterableSource } +import pekko.stream.impl.fusing.GraphStages.{ FutureSource, SingleSource } import pekko.stream.scaladsl.Keep import pekko.util.OptionVal import pekko.util.unused @@ -380,12 +380,53 @@ import pekko.util.unused } } + /** + * Try to find a value-presented source (single, future, iterable, Java stream, failed or empty) + * or wrapped such. This is used as a performance optimization in FlattenConcat and possibly other places. + * @since 1.2.0 + */ + @InternalApi def getValuePresentedSource[A >: Null]( + graph: Graph[SourceShape[A], _]): OptionVal[Graph[SourceShape[A], _]] = { + def isValuePresentedSource(graph: Graph[SourceShape[_ <: A], _]): Boolean = graph match { + case _: SingleSource[_] | _: FutureSource[_] | _: IterableSource[_] | _: JavaStreamSource[_, _] | + _: FailedSource[_] => + true + case maybeEmpty if isEmptySource(maybeEmpty) => true + case _ => false + } + graph match { + case _ if isValuePresentedSource(graph) => OptionVal.Some(graph) + case _ => + graph.traversalBuilder match { + case l: LinearTraversalBuilder => + l.pendingBuilder match { + case OptionVal.Some(a: AtomicTraversalBuilder) => + a.module match { + case m: GraphStageModule[_, _] => + m.stage match { + case _ if isValuePresentedSource(m.stage.asInstanceOf[Graph[SourceShape[A], _]]) => + // It would be != EmptyTraversal if mapMaterializedValue was used and then we can't optimize. + if ((l.traversalSoFar eq EmptyTraversal) && !l.attributes.isAsync) + OptionVal.Some(m.stage.asInstanceOf[Graph[SourceShape[A], _]]) + else OptionVal.None + case _ => OptionVal.None + } + case _ => OptionVal.None + } + case _ => OptionVal.None + } + case _ => OptionVal.None + } + } + } + /** + * Test if a Graph is an empty Source. 
*/ def isEmptySource(graph: Graph[SourceShape[_], _]): Boolean = graph match { case source: scaladsl.Source[_, _] if source eq scaladsl.Source.empty => true case source: javadsl.Source[_, _] if source eq javadsl.Source.empty() => true + case EmptySource => true case _ => false } From 4b92618f4cd10a261234f509b27073f8c9491441 Mon Sep 17 00:00:00 2001 From: He-Pin Date: Fri, 10 Jan 2025 21:12:31 +0800 Subject: [PATCH 2/2] feat: Add flatMapConcat with parallelism support. --- .../pekko/stream/FlatMapConcatBenchmark.scala | 82 ++++- .../FlowFlatMapConcatParallelismSpec.scala | 204 ++++++++++ .../org/apache/pekko/stream/impl/Stages.scala | 1 + .../stream/impl/fusing/FlattenConcat.scala | 347 ++++++++++++++++++ .../apache/pekko/stream/javadsl/Flow.scala | 20 + .../apache/pekko/stream/javadsl/Source.scala | 18 + .../apache/pekko/stream/javadsl/SubFlow.scala | 20 + .../pekko/stream/javadsl/SubSource.scala | 20 + .../apache/pekko/stream/scaladsl/Flow.scala | 18 + 9 files changed, 729 insertions(+), 1 deletion(-) create mode 100644 stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlatMapConcatParallelismSpec.scala create mode 100644 stream/src/main/scala/org/apache/pekko/stream/impl/fusing/FlattenConcat.scala diff --git a/bench-jmh/src/main/scala/org/apache/pekko/stream/FlatMapConcatBenchmark.scala b/bench-jmh/src/main/scala/org/apache/pekko/stream/FlatMapConcatBenchmark.scala index a1f22186b3f..f8c0d49d5d6 100644 --- a/bench-jmh/src/main/scala/org/apache/pekko/stream/FlatMapConcatBenchmark.scala +++ b/bench-jmh/src/main/scala/org/apache/pekko/stream/FlatMapConcatBenchmark.scala @@ -16,7 +16,7 @@ package org.apache.pekko.stream import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit -import scala.concurrent.Await +import scala.concurrent.{ Await, Future } import scala.concurrent.duration._ import com.typesafe.config.ConfigFactory @@ -76,6 +76,16 @@ class FlatMapConcatBenchmark { awaitLatch(latch) } + @Benchmark + 
@OperationsPerInvocation(OperationsPerInvocation) + def sourceDotSingleP1(): Unit = { + val latch = new CountDownLatch(1) + + testSource.flatMapConcat(1, Source.single).runWith(new LatchSink(OperationsPerInvocation, latch)) + + awaitLatch(latch) + } + @Benchmark @OperationsPerInvocation(OperationsPerInvocation) def internalSingleSource(): Unit = { @@ -88,6 +98,18 @@ class FlatMapConcatBenchmark { awaitLatch(latch) } + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def internalSingleSourceP1(): Unit = { + val latch = new CountDownLatch(1) + + testSource + .flatMapConcat(1, elem => new GraphStages.SingleSource(elem)) + .runWith(new LatchSink(OperationsPerInvocation, latch)) + + awaitLatch(latch) + } + @Benchmark @OperationsPerInvocation(OperationsPerInvocation) def oneElementList(): Unit = { @@ -98,6 +120,64 @@ class FlatMapConcatBenchmark { awaitLatch(latch) } + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def oneElementListP1(): Unit = { + val latch = new CountDownLatch(1) + + testSource.flatMapConcat(1, n => Source(n :: Nil)).runWith(new LatchSink(OperationsPerInvocation, latch)) + + awaitLatch(latch) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def completedFuture(): Unit = { + val latch = new CountDownLatch(1) + + testSource + .flatMapConcat(n => Source.future(Future.successful(n))) + .runWith(new LatchSink(OperationsPerInvocation, latch)) + + awaitLatch(latch) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def completedFutureP1(): Unit = { + val latch = new CountDownLatch(1) + + testSource + .flatMapConcat(1, n => Source.future(Future.successful(n))) + .runWith(new LatchSink(OperationsPerInvocation, latch)) + + awaitLatch(latch) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def normalFuture(): Unit = { + val latch = new CountDownLatch(1) + + testSource + .flatMapConcat(n => Source.future(Future(n)(system.dispatcher))) + .runWith(new 
LatchSink(OperationsPerInvocation, latch)) + + awaitLatch(latch) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def normalFutureP1(): Unit = { + val latch = new CountDownLatch(1) + + testSource + .flatMapConcat(1, n => Source.future(Future(n)(system.dispatcher))) + .runWith(new LatchSink(OperationsPerInvocation, latch)) + + awaitLatch(latch) + } + @Benchmark @OperationsPerInvocation(OperationsPerInvocation) def mapBaseline(): Unit = { diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlatMapConcatParallelismSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlatMapConcatParallelismSpec.scala new file mode 100644 index 00000000000..58d8ebe061b --- /dev/null +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlatMapConcatParallelismSpec.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.pekko.stream.scaladsl + +import java.util.concurrent.ThreadLocalRandom +import java.util.concurrent.atomic.AtomicInteger +import org.apache.pekko +import pekko.pattern.FutureTimeoutSupport +import pekko.NotUsed +import pekko.stream._ +import pekko.stream.testkit.{ ScriptedTest, StreamSpec } +import pekko.stream.testkit.scaladsl.TestSink + +import java.util.Collections +import scala.annotation.switch +import scala.concurrent.duration.DurationInt +import scala.concurrent.Future +import scala.util.control.NoStackTrace + +class FlowFlatMapConcatParallelismSpec extends StreamSpec(""" + pekko.stream.materializer.initial-input-buffer-size = 2 + """) with ScriptedTest with FutureTimeoutSupport { + val toSeq = Flow[Int].grouped(1000).toMat(Sink.head)(Keep.right) + + class BoomException extends RuntimeException("BOOM~~") with NoStackTrace + "A flatMapConcat" must { + + for (i <- 1 until 129) { + s"work with value presented sources with parallelism: $i" in { + Source( + List( + Source.empty[Int], + Source.single(1), + Source.empty[Int], + Source(List(2, 3, 4)), + Source.future(Future.successful(5)), + Source.lazyFuture(() => Future.successful(6)), + Source.future(after(1.millis)(Future.successful(7))))) + .flatMapConcat(i, identity) + .runWith(toSeq) + .futureValue should ===(1 to 7) + } + } + + def generateRandomValuePresentedSources(nums: Int): (Int, List[Source[Int, NotUsed]]) = { + val seq = List.tabulate(nums) { _ => + val random = ThreadLocalRandom.current().nextInt(1, 10) + (random: @switch) match { + case 1 => Source.single(1) + case 2 => Source(List(1)) + case 3 => Source.fromJavaStream(() => Collections.singleton(1).stream()) + case 4 => Source.future(Future.successful(1)) + case 5 => Source.future(after(1.millis)(Future.successful(1))) + case _ => Source.empty[Int] + } + } + val sum = seq.filterNot(_.eq(Source.empty[Int])).size + (sum, seq) + } + + def generateSequencedValuePresentedSources(nums: Int): (Int, List[Source[Int, NotUsed]]) = 
{ + val seq = List.tabulate(nums) { index => + val random = ThreadLocalRandom.current().nextInt(1, 6) + (random: @switch) match { + case 1 => Source.single(index) + case 2 => Source(List(index)) + case 3 => Source.fromJavaStream(() => Collections.singleton(index).stream()) + case 4 => Source.future(Future.successful(index)) + case 5 => Source.future(after(1.millis)(Future.successful(index))) + case _ => throw new IllegalStateException("unexpected") + } + } + val sum = (0 until nums).sum + (sum, seq) + } + + for (i <- 1 until 129) { + s"work with generated value presented sources with parallelism: $i " in { + val (sum, sources @ _) = generateRandomValuePresentedSources(100000) + Source(sources) + .flatMapConcat(i, identity(_)) // scala 2.12 can't infer the type of identity + .runWith(Sink.seq) + .map(_.sum)(pekko.dispatch.ExecutionContexts.parasitic) + .futureValue shouldBe sum + } + } + + for (i <- 1 until 129) { + s"work with generated value sequenced sources with parallelism: $i " in { + val (sum, sources @ _) = generateSequencedValuePresentedSources(100000) + Source(sources) + .flatMapConcat(i, identity(_)) // scala 2.12 can't infer the type of identity + // check the order + .statefulMap(() => -1)((pre, current) => { + if (pre + 1 != current) { + throw new IllegalStateException(s"expected $pre + 1 == $current") + } + (current, current) + }, _ => None) + .runWith(Sink.seq) + .map(_.sum)(pekko.dispatch.ExecutionContexts.parasitic) + .futureValue shouldBe sum + } + } + + "work with value presented failed sources" in { + val ex = new BoomException + Source( + List( + Source.empty[Int], + Source.single(1), + Source.empty[Int], + Source(List(2, 3, 4)), + Source.future(Future.failed(ex)), + Source.lazyFuture(() => Future.successful(5)))) + .flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity) + .onErrorComplete[BoomException]() + .runWith(toSeq) + .futureValue should ===(1 to 4) + } + + "work with value presented sources when demands slow" in { + val 
prob = Source( + List(Source.empty[Int], Source.single(1), Source(List(2, 3, 4)), Source.lazyFuture(() => Future.successful(5)))) + .flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity) + .runWith(TestSink()) + + prob.request(1) + prob.expectNext(1) + prob.expectNoMessage(1.seconds) + prob.request(2) + prob.expectNext(2, 3) + prob.expectNoMessage(1.seconds) + prob.request(2) + prob.expectNext(4, 5) + prob.expectComplete() + } + + val parallelism = ThreadLocalRandom.current().nextInt(4, 65) + s"can do pre materialization when parallelism > 1, parallelism is $parallelism" in { + val materializationCounter = new AtomicInteger(0) + val prob = Source(1 to (parallelism * 3)) + .flatMapConcat( + parallelism, + value => { + Source + .lazySingle(() => { + materializationCounter.incrementAndGet() + value + }) + .buffer(1, overflowStrategy = OverflowStrategy.backpressure) + }) + .runWith(TestSink()) + + expectNoMessage(1.seconds) + materializationCounter.get() shouldBe 0 + + prob.request(1) + prob.expectNext(1.seconds, 1) + expectNoMessage(1.seconds) + materializationCounter.get() shouldBe (parallelism + 1) + materializationCounter.set(0) + + prob.request(2) + prob.expectNextN(List(2, 3)) + expectNoMessage(1.seconds) + materializationCounter.get() shouldBe 2 + materializationCounter.set(0) + + prob.request(parallelism - 3) + prob.expectNextN(4 to parallelism) + expectNoMessage(1.seconds) + materializationCounter.get() shouldBe (parallelism - 3) + materializationCounter.set(0) + + prob.request(parallelism) + prob.expectNextN(parallelism + 1 to parallelism * 2) + expectNoMessage(1.seconds) + materializationCounter.get() shouldBe parallelism + materializationCounter.set(0) + + prob.request(parallelism) + prob.expectNextN(parallelism * 2 + 1 to parallelism * 3) + expectNoMessage(1.seconds) + materializationCounter.get() shouldBe 0 + prob.expectComplete() + } + } +} diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/Stages.scala 
b/stream/src/main/scala/org/apache/pekko/stream/impl/Stages.scala index 1559a550dcc..904c76ff2de 100755 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/Stages.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/Stages.scala @@ -93,6 +93,7 @@ import pekko.stream.Attributes._ val mergePreferred = name("mergePreferred") val mergePrioritized = name("mergePrioritized") val flattenMerge = name("flattenMerge") + val flattenConcat = name("flattenConcat") val recoverWith = name("recoverWith") val onErrorComplete = name("onErrorComplete") val broadcast = name("broadcast") diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/FlattenConcat.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/FlattenConcat.scala new file mode 100644 index 00000000000..87a839eaf9c --- /dev/null +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/FlattenConcat.scala @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.pekko.stream.impl.fusing + +import org.apache.pekko +import pekko.annotation.InternalApi +import pekko.stream.scaladsl.Source +import pekko.stream.{ Attributes, FlowShape, Graph, Inlet, Outlet, SourceShape, SubscriptionWithCancelException } +import pekko.stream.impl.Stages.DefaultAttributes +import pekko.stream.impl.{ Buffer => BufferImpl, FailedSource, JavaStreamSource, TraversalBuilder } +import pekko.stream.impl.fusing.GraphStages.{ FutureSource, SingleSource } +import pekko.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } +import pekko.util.OptionVal + +import scala.concurrent.Future +import scala.util.{ Failure, Try } + +/** + * INTERNAL API: the `InflightSource` abstraction and its iterator/future based implementations that the [[FlattenConcat]] stage keeps in its queue. + */ +@InternalApi +private[pekko] object FlattenConcat { + private sealed abstract class InflightSource[T] { + def hasNext: Boolean + def next(): T + def tryPull(): Unit + def cancel(cause: Throwable): Unit + def isClosed: Boolean + def hasFailed: Boolean = failure.isDefined + def failure: Option[Throwable] = None + def materialize(): Unit = () + } + + private final class InflightIteratorSource[T](iterator: Iterator[T]) extends InflightSource[T] { + override def hasNext: Boolean = iterator.hasNext + override def next(): T = iterator.next() + override def tryPull(): Unit = () + override def cancel(cause: Throwable): Unit = () + override def isClosed: Boolean = !hasNext + } + + private final class InflightCompletedFutureSource[T](result: Try[T]) extends InflightSource[T] { + private var _hasNext = result.isSuccess + override def hasNext: Boolean = _hasNext + override def next(): T = { + if (_hasNext) { + _hasNext = false + result.get + } else throw new NoSuchElementException("next called after completion") + } + override def hasFailed: Boolean = result.isFailure + override def failure: Option[Throwable] = result.failed.toOption + override def tryPull(): Unit = () + override def cancel(cause: Throwable): Unit = () + override def isClosed: Boolean = true + } + + private final
class InflightPendingFutureSource[T](cb: InflightSource[T] => Unit) + extends InflightSource[T] + with (Try[T] => Unit) { + private var result: Try[T] = MapAsync.NotYetThere + private var consumed = false + override def apply(result: Try[T]): Unit = { + this.result = result + cb(this) + } + override def hasNext: Boolean = (result ne MapAsync.NotYetThere) && !consumed && result.isSuccess + override def next(): T = { + if (!consumed) { + consumed = true + result.get + } else throw new NoSuchElementException("next called after completion") + } + override def hasFailed: Boolean = (result ne MapAsync.NotYetThere) && result.isFailure + override def failure: Option[Throwable] = if (result eq MapAsync.NotYetThere) None else result.failed.toOption + override def tryPull(): Unit = () + override def cancel(cause: Throwable): Unit = () + override def isClosed: Boolean = consumed || hasFailed + } +} + +/** + * INTERNAL API: stage that emits the elements of the incoming sources one source after the other, keeping at most `parallelism` sources queued (must be >= 1). + */ +@InternalApi +private[pekko] final class FlattenConcat[T, M](parallelism: Int) + extends GraphStage[FlowShape[Graph[SourceShape[T], M], T]] { + require(parallelism >= 1, "parallelism should >= 1") + private val in = Inlet[Graph[SourceShape[T], M]]("flattenConcat.in") + private val out = Outlet[T]("flattenConcat.out") + + override def initialAttributes: Attributes = DefaultAttributes.flattenConcat + override val shape: FlowShape[Graph[SourceShape[T], M], T] = FlowShape(in, out) + override def createLogic(enclosingAttributes: Attributes) = { + object FlattenConcatLogic extends GraphStageLogic(shape) with InHandler with OutHandler { + import FlattenConcat._ + // the queue holds either InflightSource[T] or SingleSource[T] entries + // typed as AnyRef to avoid lifting each SingleSource[T] into an InflightSource[T] + private var queue: BufferImpl[AnyRef] = _ + private val invokeCb: InflightSource[T] => Unit = + getAsyncCallback[InflightSource[T]](futureSourceCompleted).invoke + + override def preStart(): Unit = queue = BufferImpl(parallelism, enclosingAttributes) + + private def
futureSourceCompleted(futureSource: InflightSource[T]): Unit = { + if (queue.peek() eq futureSource) { + if (isAvailable(out) && futureSource.hasNext) { + push(out, futureSource.next()) // TODO should filter out the `null` here? + if (futureSource.isClosed) { + handleCurrentSourceClosed(futureSource) + } + } else if (futureSource.isClosed) { + handleCurrentSourceClosed(futureSource) + } + } // else just ignore, it will be picked up by onPull + } + + override def onPush(): Unit = { + addSource(grab(in)) + // must try pull after addSource to avoid queue overflow + if (!queue.isFull) { // try to keep the maximum parallelism + tryPull(in) + } + } + + override def onUpstreamFinish(): Unit = if (queue.isEmpty) completeStage() + + override def onUpstreamFailure(ex: Throwable): Unit = { + super.onUpstreamFailure(ex) + cancelInflightSources(SubscriptionWithCancelException.NoMoreElementsNeeded) + } + + override def onPull(): Unit = { + // purge if possible + queue.peek() match { + case src: SingleSource[T] @unchecked => + push(out, src.elem) + removeSource() + case src: InflightSource[T] @unchecked => pushOut(src) + case null => // queue is empty + if (!hasBeenPulled(in)) { + tryPull(in) + } else if (isClosed(in)) { + completeStage() + } + case _ => throw new IllegalStateException("Should not reach here.") + } + } + + private def pushOut(src: InflightSource[T]): Unit = { + if (src.hasNext) { + push(out, src.next()) + if (src.isClosed) { + handleCurrentSourceClosed(src) + } + } else if (src.isClosed) { + handleCurrentSourceClosed(src) + } else { + src.tryPull() + } + } + + private def handleCurrentSourceClosed(source: InflightSource[T]): Unit = { + source.failure match { + case Some(cause) => onUpstreamFailure(cause) + case None => removeSource(source) + } + } + + override def onDownstreamFinish(cause: Throwable): Unit = { + super.onDownstreamFinish(cause) + cancelInflightSources(cause) + } + + private def cancelInflightSources(cause: Throwable): Unit = { + // drain the whole queue: a queued SingleSource has nothing to cancel but must not end the loop early + while
(queue.nonEmpty) { + queue.dequeue() match { + case inflight: InflightSource[T] @unchecked => inflight.cancel(cause) + case _ => // SingleSource: it only carries its element, nothing to cancel + } + } + } + + private def addSource(singleSource: SingleSource[T]): Unit = { + if (isAvailable(out) && queue.isEmpty) { + push(out, singleSource.elem) + } else { + queue.enqueue(singleSource) + } + } + + private def addSourceElements(iterator: Iterator[T]): Unit = { + val inflightSource = new InflightIteratorSource[T](iterator) + if (isAvailable(out) && queue.isEmpty) { + if (inflightSource.hasNext) { + push(out, inflightSource.next()) + if (inflightSource.hasNext) { + queue.enqueue(inflightSource) + } + } + } else { + queue.enqueue(inflightSource) + } + } + + private def addCompletedFutureElem(elem: Try[T]): Unit = { + if (isAvailable(out) && queue.isEmpty) { + elem match { + case scala.util.Success(value) => push(out, value) + case scala.util.Failure(ex) => onUpstreamFailure(ex) + } + } else { + queue.enqueue(new InflightCompletedFutureSource(elem)) + } + } + + private def addPendingFutureElem(future: Future[T]): Unit = { + val inflightSource = new InflightPendingFutureSource[T](invokeCb) + future.onComplete(inflightSource)(pekko.dispatch.ExecutionContexts.parasitic) + queue.enqueue(inflightSource) + } + + private def attachAndMaterializeSource(source: Graph[SourceShape[T], M]): Unit = { + object inflightSource extends InflightSource[T] { self => + private val sinkIn = new SubSinkInlet[T]("FlattenConcatSink") + private var upstreamFailure = Option.empty[Throwable] + sinkIn.setHandler(new InHandler { + override def onPush(): Unit = { + if (isAvailable(out) && (queue.peek() eq self)) { + push(out, sinkIn.grab()) + } + } + override def onUpstreamFinish(): Unit = if (!sinkIn.isAvailable) removeSource(self) + override def onUpstreamFailure(ex: Throwable): Unit = { + upstreamFailure = Some(ex) + // if it's the current emitting source, fail the stage + if
(queue.peek() eq self) { + super.onUpstreamFailure(ex) + } // else just mark the source as failed + } + }) + + final override def materialize(): Unit = { + val graph = Source.fromGraph(source).to(sinkIn.sink) + interpreter.subFusingMaterializer.materialize(graph, defaultAttributes = enclosingAttributes) + } + final override def cancel(cause: Throwable): Unit = sinkIn.cancel(cause) + final override def hasNext: Boolean = sinkIn.isAvailable + final override def isClosed: Boolean = sinkIn.isClosed + final override def failure: Option[Throwable] = upstreamFailure + final override def next(): T = sinkIn.grab() + final override def tryPull(): Unit = if (!sinkIn.isClosed && !sinkIn.hasBeenPulled) sinkIn.pull() + } + if (isAvailable(out) && queue.isEmpty) { + // this is the first one, pull + inflightSource.tryPull() + } + queue.enqueue(inflightSource) + inflightSource.materialize() + } + + private def addSource(source: Graph[SourceShape[T], M]): Unit = { + TraversalBuilder.getValuePresentedSource(source) match { + case OptionVal.Some(graph) => + graph match { + case single: SingleSource[T] @unchecked => addSource(single) + case futureSource: FutureSource[T] @unchecked => + val future = futureSource.future + future.value match { + case Some(elem) => addCompletedFutureElem(elem) + case None => addPendingFutureElem(future) + } + case iterable: IterableSource[T] @unchecked => addSourceElements(iterable.elements.iterator) + case javaStream: JavaStreamSource[T, _] @unchecked => + import pekko.util.ccompat.JavaConverters._ + addSourceElements(javaStream.open().iterator.asScala) + case failed: FailedSource[T] @unchecked => addCompletedFutureElem(Failure(failed.failure)) + case maybeEmpty if TraversalBuilder.isEmptySource(maybeEmpty) => // Empty source is discarded + case _ => attachAndMaterializeSource(source) + } + case _ => attachAndMaterializeSource(source) + } + + } + + private def removeSource(): Unit = { + queue.dequeue() + pullIfNeeded() + } + + private def
removeSource(source: InflightSource[T]): Unit = { + if (source eq queue.peek()) { + // only dequeue if it's the current emitting source + queue.dequeue() + pullIfNeeded() + } // not the head source, just ignore + } + + private def pullIfNeeded(): Unit = { + if (isClosed(in)) { + if (queue.isEmpty) { + completeStage() + } else { + tryPullNextSourceInQueue() + } + } else { + if (queue.nonEmpty) { + tryPullNextSourceInQueue() + } + if (!hasBeenPulled(in)) { + tryPull(in) + } + } + } + + private def tryPullNextSourceInQueue(): Unit = { + // the head of the queue changed: when downstream demand is still pending, serve it from the new head now, + // because a SingleSource/iterator/future head never pushes by itself; otherwise only pre-pull the new head + val nextSource = queue.peek() + if (isAvailable(out)) onPull() + else if (nextSource.isInstanceOf[InflightSource[T] @unchecked]) nextSource.asInstanceOf[InflightSource[T]].tryPull() + } + + setHandlers(in, out, this) + } + + FlattenConcatLogic + } + + override def toString: String = s"FlattenConcat($parallelism)" +} diff --git a/stream/src/main/scala/org/apache/pekko/stream/javadsl/Flow.scala b/stream/src/main/scala/org/apache/pekko/stream/javadsl/Flow.scala index 1eb248c0c0e..b9cfcf0f10e 100755 --- a/stream/src/main/scala/org/apache/pekko/stream/javadsl/Flow.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/javadsl/Flow.scala @@ -2742,6 +2742,26 @@ final class Flow[In, Out, Mat](delegate: scaladsl.Flow[In, Out, Mat]) extends Gr def flatMapConcat[T, M](f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): Flow[In, T, Mat] = new Flow(delegate.flatMapConcat[T, M](x => f(x))) + /** + * Transform each input element into a `Source` of output elements that is + * then flattened into the output stream by concatenation, + * fully consuming one Source after the other. + * `parallelism` can be used to configure the maximum number of inflight sources, which will be materialized at the same time.
+ * + * '''Emits when''' a currently consumed substream has an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' upstream completes and all consumed substreams complete + * + * '''Cancels when''' downstream cancels + * @since 1.2.0 + */ + def flatMapConcat[T, M]( + parallelism: Int, + f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): Flow[In, T, Mat] = + new Flow(delegate.flatMapConcat[T, M](parallelism, x => f(x))) + /** * Transform each input element into a `Source` of output elements that is * then flattened into the output stream by merging, where at most `breadth` diff --git a/stream/src/main/scala/org/apache/pekko/stream/javadsl/Source.scala b/stream/src/main/scala/org/apache/pekko/stream/javadsl/Source.scala index 7ad5bd0df73..ce6e2786fd0 100755 --- a/stream/src/main/scala/org/apache/pekko/stream/javadsl/Source.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/javadsl/Source.scala @@ -4246,6 +4246,24 @@ final class Source[Out, Mat](delegate: scaladsl.Source[Out, Mat]) extends Graph[ def flatMapConcat[T, M](f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): Source[T, Mat] = new Source(delegate.flatMapConcat[T, M](x => f(x))) + /** + * Transform each input element into a `Source` of output elements that is + * then flattened into the output stream by concatenation, + * fully consuming one Source after the other. + * `parallelism` can be used to configure the maximum number of inflight sources, which will be materialized at the same time.
+ * + * '''Emits when''' a currently consumed substream has an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' upstream completes and all consumed substreams complete + * + * '''Cancels when''' downstream cancels + * @since 1.2.0 + */ + def flatMapConcat[T, M](parallelism: Int, f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): Source[T, Mat] = + new Source(delegate.flatMapConcat[T, M](parallelism, x => f(x))) + /** * Transform each input element into a `Source` of output elements that is * then flattened into the output stream by merging, where at most `breadth` diff --git a/stream/src/main/scala/org/apache/pekko/stream/javadsl/SubFlow.scala b/stream/src/main/scala/org/apache/pekko/stream/javadsl/SubFlow.scala index 79d7ea4e433..3a177042cb5 100755 --- a/stream/src/main/scala/org/apache/pekko/stream/javadsl/SubFlow.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/javadsl/SubFlow.scala @@ -1783,6 +1783,26 @@ class SubFlow[In, Out, Mat]( def flatMapConcat[T, M](f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): SubFlow[In, T, Mat] = new SubFlow(delegate.flatMapConcat(x => f(x))) + /** + * Transform each input element into a `Source` of output elements that is + * then flattened into the output stream by concatenation, + * fully consuming one Source after the other. + * `parallelism` can be used to configure the maximum number of inflight sources, which will be materialized at the same time.
+ * + * '''Emits when''' a currently consumed substream has an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' upstream completes and all consumed substreams complete + * + * '''Cancels when''' downstream cancels + * @since 1.2.0 + */ + def flatMapConcat[T, M]( + parallelism: Int, + f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): SubFlow[In, T, Mat] = + new SubFlow(delegate.flatMapConcat(parallelism, x => f(x))) + /** * Transform each input element into a `Source` of output elements that is * then flattened into the output stream by merging, where at most `breadth` diff --git a/stream/src/main/scala/org/apache/pekko/stream/javadsl/SubSource.scala b/stream/src/main/scala/org/apache/pekko/stream/javadsl/SubSource.scala index 3092262b7dc..ee574e89c20 100755 --- a/stream/src/main/scala/org/apache/pekko/stream/javadsl/SubSource.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/javadsl/SubSource.scala @@ -1757,6 +1757,26 @@ class SubSource[Out, Mat]( def flatMapConcat[T, M](f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): SubSource[T, Mat] = new SubSource(delegate.flatMapConcat(x => f(x))) + /** + * Transform each input element into a `Source` of output elements that is + * then flattened into the output stream by concatenation, + * fully consuming one Source after the other. + * `parallelism` can be used to configure the maximum number of inflight sources, which will be materialized at the same time.
+ * + * '''Emits when''' a currently consumed substream has an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' upstream completes and all consumed substreams complete + * + * '''Cancels when''' downstream cancels + * @since 1.2.0 + */ + def flatMapConcat[T, M]( + parallelism: Int, + f: function.Function[Out, _ <: Graph[SourceShape[T], M]]): SubSource[T, Mat] = + new SubSource(delegate.flatMapConcat(parallelism, x => f(x))) + /** * Transform each input element into a `Source` of output elements that is * then flattened into the output stream by merging, where at most `breadth` diff --git a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Flow.scala b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Flow.scala index e2db2e8daf9..4275a01ec2c 100755 --- a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Flow.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Flow.scala @@ -2742,6 +2742,24 @@ trait FlowOps[+Out, +Mat] { */ def flatMapConcat[T, M](f: Out => Graph[SourceShape[T], M]): Repr[T] = map(f).via(new FlattenMerge[T, M](1)) + /** + * Transform each input element into a `Source` of output elements that is + * then flattened into the output stream by concatenation, + * fully consuming one Source after the other. + * `parallelism` can be used to configure the maximum number of inflight sources, which will be materialized at the same time. + * + * '''Emits when''' a currently consumed substream has an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' upstream completes and all consumed substreams complete + * + * '''Cancels when''' downstream cancels + * @since 1.2.0 + */ + def flatMapConcat[T, M](parallelism: Int, f: Out => Graph[SourceShape[T], M]): Repr[T] = + map(f).via(new FlattenConcat[T, M](parallelism)) + /** * Alias for [[flatMapConcat]], added to enable for comprehensions. *