diff --git a/slick/src/main/scala/org/apache/pekko/stream/connectors/slick/scaladsl/SlickWithTryResult.scala b/slick/src/main/scala/org/apache/pekko/stream/connectors/slick/scaladsl/SlickWithTryResult.scala new file mode 100644 index 000000000..36f3ea28f --- /dev/null +++ b/slick/src/main/scala/org/apache/pekko/stream/connectors/slick/scaladsl/SlickWithTryResult.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * license agreements; and to You under the Apache License, version 2.0: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * This file is part of the Apache Pekko project, which was derived from Akka. + */ + +/* + * Copyright (C) since 2016 Lightbend Inc. + */ +package org.apache.pekko.stream.connectors.slick.scaladsl + +import scala.concurrent.Future +import scala.util.Try +import org.apache.pekko +import pekko.NotUsed +import pekko.stream.scaladsl.Flow +import pekko.stream.scaladsl.Keep +import pekko.stream.scaladsl.Sink +import slick.dbio.DBIO + +/** + * Methods for interacting with relational databases using Slick and pekko-stream. + */ +object SlickWithTryResult { + + /** + * Scala API: creates a Flow that takes a stream of elements of + * type T, transforms each element to a SQL statement + * using the specified function, and then try executing + * those statements against the specified Slick database. + * It return Success[Int] or Failure[Throwable] + * if there was an exception during the execution. + * + * @param toStatement A function to produce the SQL statement to + * execute based on the current element. + * @param session The database session to use. + */ + def flowTry[T]( + toStatement: T => DBIO[Int])(implicit session: SlickSession): Flow[T, Try[Int], NotUsed] = flowTry(1, toStatement) + + /** + * Scala API: creates a Flow that takes a stream of elements of + * type T, transforms each element to a SQL statement + * using the specified function, and then executes + * those statements against the specified Slick database. + * It return Success[Int] or Failure[Throwable] + * if there was an exception during the execution. + * + * @param toStatement A function to produce the SQL statement to + * execute based on the current element. + * @param parallelism How many parallel asynchronous streams should be + * used to send statements to the database. Use a + * value of 1 for sequential execution. + * @param session The database session to use. + */ + def flowTry[T]( + parallelism: Int, + toStatement: T => DBIO[Int])(implicit session: SlickSession): Flow[T, Try[Int], NotUsed] = + flowTryWithPassThrough(parallelism, toStatement) + + /** + * Scala API: creates a Flow that takes a stream of elements of + * type T, transforms each element to a SQL statement + * using the specified function, then executes + * those statements against the specified Slick database + * and returns the statement result type Success[R] or + * Failure[Throwable] if there is an exception. + * + * @param toStatement A function to produce the SQL statement to + * execute based on the current element. + * @param session The database session to use. + */ + def flowTryWithPassThrough[T, R]( + toStatement: T => DBIO[R])(implicit session: SlickSession): Flow[T, Try[R], NotUsed] = + flowTryWithPassThrough(1, toStatement) + + /** + * Scala API: creates a Flow that takes a stream of elements of + * type T, transforms each element to a SQL statement + * using the specified function, then executes + * those statements against the specified Slick database + * and returns the statement result type Success[R] or + * Failure[Throwable] if there is an exception. + * + * @param toStatement A function to produce the SQL statement to + * execute based on the current element. + * @param parallelism How many parallel asynchronous streams should be + * used to send statements to the database. Use a + * value of 1 for sequential execution. + * @param session The database session to use. + */ + def flowTryWithPassThrough[T, R]( + parallelism: Int, + toStatement: T => DBIO[R])(implicit session: SlickSession): Flow[T, Try[R], NotUsed] = + Flow[T] + .mapAsync(parallelism) { t => + session.db.run(toStatement(t).asTry) + } + + /** + * Scala API: creates a Sink that takes a stream of elements of + * type T, transforms each element to a SQL statement + * using the specified function, and then executes + * those statements against the specified Slick database. + * + * @param toStatement A function to produce the SQL statement to + * execute based on the current element. + * @param session The database session to use. + */ + def sinkTry[T]( + toStatement: T => DBIO[Int])(implicit session: SlickSession): Sink[T, Future[Try[Int]]] = + flowTry[T](1, toStatement).toMat(Sink.last)(Keep.right) + + /** + * Scala API: creates a Sink that takes a stream of elements of + * type T, transforms each element to a SQL statement + * using the specified function, and then executes + * those statements against the specified Slick database. + * + * @param toStatement A function to produce the SQL statement to + * execute based on the current element. + * @param parallelism How many parallel asynchronous streams should be + * used to send statements to the database. Use a + * value of 1 for sequential execution. + * @param session The database session to use. + */ + def sinkTry[T]( + parallelism: Int, + toStatement: T => DBIO[Int])(implicit session: SlickSession): Sink[T, Future[Try[Int]]] = + flowTry[T](parallelism, toStatement).toMat(Sink.last)(Keep.right) +} diff --git a/slick/src/test/scala/docs/scaladsl/SlickWithTryResultSpec.scala b/slick/src/test/scala/docs/scaladsl/SlickWithTryResultSpec.scala new file mode 100644 index 000000000..9fdf031c7 --- /dev/null +++ b/slick/src/test/scala/docs/scaladsl/SlickWithTryResultSpec.scala @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * license agreements; and to You under the Apache License, version 2.0: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * This file is part of the Apache Pekko project, which was derived from Akka. + */ + +/* + * Copyright (C) since 2016 Lightbend Inc. + */ + +package docs.scaladsl + +import org.apache.pekko +import pekko.Done +import pekko.actor.ActorSystem +import pekko.stream.connectors.slick.scaladsl.{ SlickSession, SlickWithTryResult } +import pekko.stream.connectors.testkit.scaladsl.LogCapturing +import pekko.stream.scaladsl._ +import pekko.testkit.TestKit + +import org.scalatest._ +import org.scalatest.concurrent.ScalaFutures +import org.scalatest.matchers.must.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import slick.dbio.DBIOAction +import slick.jdbc.GetResult + +import scala.concurrent.duration._ +import scala.concurrent.{ Await, ExecutionContext, Future } +import scala.util.{ Failure, Success } + +class SlickWithTryResultSpec extends AnyWordSpec + with ScalaFutures + with BeforeAndAfterEach + with BeforeAndAfterAll + with Matchers + with LogCapturing { + // #init-mat + implicit val system: ActorSystem = ActorSystem() + // #init-mat + + // #init-session + implicit val session: SlickSession = SlickSession.forConfig("slick-h2") + // #init-session + + import session.profile.api._ + + case class User(id: Int, name: String) + class Users(tag: Tag) extends Table[(Int, String)](tag, "PEKKO_CONNECTORS_SLICK_SCALADSL_TEST_USERS") { + def id = column[Int]("ID", O.PrimaryKey) + def name = column[String]("NAME") + def * = (id, name) + } + + implicit val ec: ExecutionContext = system.dispatcher + implicit val defaultPatience: PatienceConfig = PatienceConfig(timeout = 3.seconds, interval = 50.millis) + implicit val getUserResult: GetResult[User] = GetResult(r => User(r.nextInt(), r.nextString())) + + val users = (1 to 40).map(i => User(i, s"Name$i")).toSet + val duplicateUser = scala.collection.immutable.Seq(users.head, users.head) + + val createTable = + sqlu"""CREATE TABLE PEKKO_CONNECTORS_SLICK_SCALADSL_TEST_USERS(ID INTEGER PRIMARY KEY, NAME VARCHAR(50))""" + val dropTable = sqlu"""DROP TABLE PEKKO_CONNECTORS_SLICK_SCALADSL_TEST_USERS""" + val selectAllUsers = sql"SELECT ID, NAME FROM PEKKO_CONNECTORS_SLICK_SCALADSL_TEST_USERS".as[User] + val typedSelectAllUsers = TableQuery[Users].result + + def insertUser(user: User): DBIO[Int] = + sqlu"INSERT INTO PEKKO_CONNECTORS_SLICK_SCALADSL_TEST_USERS VALUES(${user.id}, ${user.name})" + + def getAllUsersFromDb: Future[Set[User]] = + Source.fromPublisher(session.db.stream(selectAllUsers)).runWith(Sink.seq).map(_.toSet) + def populate(): Unit = { + val actions = users.map(insertUser) + + // This uses the standard Slick API exposed by the Slick session + // on purpose, just to double-check that inserting data through + // our Pekko connectors is equivalent to inserting it the Slick way. + session.db.run(DBIO.seq(actions.toList: _*)).futureValue + } + + override def beforeEach(): Unit = session.db.run(createTable).futureValue + override def afterEach(): Unit = session.db.run(dropTable).futureValue + + override def afterAll(): Unit = { + // #close-session + system.registerOnTermination(() => session.close()) + // #close-session + + TestKit.shutdownActorSystem(system) + } + + "SlickWithTryResult.flowTry(..)" must { + "insert 40 records into a table (no parallelism)" in { + val inserted = Source(users) + .via(SlickWithTryResult.flowTry(insertUser)) + .runWith(Sink.seq) + .futureValue + + inserted must have size users.size + inserted.toSet mustBe Set(Success(1)) + + getAllUsersFromDb.futureValue mustBe users + } + + "insert 40 records into a table (parallelism = 4)" in { + val inserted = Source(users) + .via(SlickWithTryResult.flowTry(parallelism = 4, insertUser)) + .runWith(Sink.seq) + .futureValue + + inserted must have size users.size + + getAllUsersFromDb.futureValue mustBe users + } + + "insert 40 records into a table faster using Flow.grouped (n = 10, parallelism = 4)" in { + val inserted = Source(users) + .grouped(10) + .via( + SlickWithTryResult.flowTry(parallelism = 4, + (group: Seq[User]) => group.map(insertUser).reduceLeft(_.andThen(_)))) + .runWith(Sink.seq) + .futureValue + + inserted must have size 4 + + getAllUsersFromDb.futureValue mustBe users + } + } + + "SlickWithTryResult.flowTry(..)" must { + "insert 40 records into a table with try (no parallelism)" in { + val inserted = Source(users) + .via(SlickWithTryResult.flowTry(insertUser)) + .runWith(Sink.last) + .futureValue + + inserted mustBe Success(1) + + getAllUsersFromDb.futureValue mustBe users + } + + "insert 40 records into a table with try (parallelism = 4)" in { + val inserted = Source(users) + .via(SlickWithTryResult.flowTry(parallelism = 4, insertUser)) + .runWith(Sink.last) + .futureValue + + inserted mustBe Success(1) + + getAllUsersFromDb.futureValue mustBe users + } + + "insert 40 records into a table with try faster using Flow.grouped (n = 10, parallelism = 4)" in { + val inserted = Source(users) + .grouped(10) + .via( + SlickWithTryResult.flowTry(parallelism = 4, + (group: Seq[User]) => group.map(insertUser).reduceLeft(_.andThen(_)))) + .runWith(Sink.seq) + .futureValue + + inserted must have size 4 + + getAllUsersFromDb.futureValue mustBe users + } + } + + "SlickWithTryResult.flowTryWithPassThrough(..)" must { + "inserting 40 records into a table with try (no parallelism)" in { + val inserted = Source(users) + .via(SlickWithTryResult.flowTryWithPassThrough { user => + insertUser(user).map(insertCount => (user, insertCount)) + }) + .runWith(Sink.seq) + .futureValue + + inserted must have size (users.size) + inserted.collect { case Success((u, _)) => u }.toSet mustBe users + inserted.collect { case Success((_, i)) => i }.toSet mustBe Set(1) + + getAllUsersFromDb.futureValue mustBe users + } + + "inserting 40 records into a table with try (parallelism = 4)" in { + val inserted = Source(users) + .via(SlickWithTryResult.flowTryWithPassThrough(parallelism = 4, + user => { + insertUser(user).map(insertCount => (user, insertCount)) + })) + .runWith(Sink.seq) + .futureValue + + inserted must have size users.size + + getAllUsersFromDb.futureValue mustBe users + } + + "inserting 40 records into a table with try faster using Flow.grouped (n = 10, parallelism = 4)" in { + val inserted = Source(users) + .grouped(10) + .via( + SlickWithTryResult.flowTryWithPassThrough( + parallelism = 4, + (group: Seq[User]) => { + val groupedDbActions = group.map(user => insertUser(user).map(insertCount => Seq((user, insertCount)))) + DBIOAction.fold(groupedDbActions, Seq.empty[(User, Int)])(_ ++ _) + })) + .collect { case Success(r) => r } + .runWith(Sink.fold(Seq.empty[(User, Int)])((a, b) => a ++ b)) + .futureValue + + inserted must have size users.size + inserted.map(_._1).toSet mustBe users + inserted.map(_._2).toSet mustBe Set(1) + + getAllUsersFromDb.futureValue mustBe users + } + + "not throw an exception, but return `[Failure]` when there is any from the db" in { + val inserted = Source(duplicateUser) + .via(SlickWithTryResult.flowTryWithPassThrough { user => + insertUser(user).map(insertCount => (user, insertCount)) + }) + .runWith(Sink.last) + .futureValue + + inserted mustBe a[Failure[_]] + + getAllUsersFromDb.futureValue mustBe Set(users.head) + } + + "kafka-example - try store documents and pass Responses with passThrough if successful" in { + + // #kafka-example + // We're going to pretend we got messages from kafka. + // After we've written them to a db with Slick, we want + // to commit the offset to Kafka + + case class KafkaOffset(offset: Int) + case class KafkaMessage[A](msg: A, offset: KafkaOffset) { + // map the msg and keep the offset + def map[B](f: A => B): KafkaMessage[B] = KafkaMessage(f(msg), offset) + } + + val messagesFromKafka = users.zipWithIndex.map { case (user, index) => KafkaMessage(user, KafkaOffset(index)) } + + var committedOffsets = List[KafkaOffset]() + + def commitToKafka(offset: KafkaOffset): Future[Done] = { + committedOffsets = committedOffsets :+ offset + Future.successful(Done) + } + + val f1 = Source(messagesFromKafka) // Assume we get this from Kafka + .via( // write to db with Slick + SlickWithTryResult.flowTryWithPassThrough { kafkaMessage => + insertUser(kafkaMessage.msg).map(insertCount => kafkaMessage.map(_ => insertCount)) + }) + .collect { case Success(x) => x } + .mapAsync(1) { kafkaMessage => + if (kafkaMessage.msg == 0) throw new Exception("Failed to write message to db") + // Commit to kafka + commitToKafka(kafkaMessage.offset) + } + .runWith(Sink.seq) + + Await.ready(f1, Duration.Inf) + + // Make sure all messages were committed to kafka + committedOffsets.map(_.offset).sorted mustBe (0 until users.size).toList + + // Assert that all docs were written to db + getAllUsersFromDb.futureValue mustBe users + } + } + + "SlickWithTryResult.sinkTry(..)" must { + "insert 40 records into a table (no parallelism)" in { + val inserted = Source(users) + .runWith(SlickWithTryResult.sinkTry(insertUser)) + .futureValue + + inserted mustBe Success(1) + getAllUsersFromDb.futureValue mustBe users + } + + "insert 40 records into a table (parallelism = 4)" in { + val inserted = Source(users) + .runWith(SlickWithTryResult.sinkTry(parallelism = 4, insertUser)) + .futureValue + + inserted mustBe Success(1) + getAllUsersFromDb.futureValue mustBe users + } + + "insert 40 records into a table faster using Flow.grouped (n = 10, parallelism = 4)" in { + val inserted = Source(users) + .grouped(10) + .runWith( + SlickWithTryResult.sinkTry(parallelism = 4, + (group: Seq[User]) => group.map(insertUser).reduceLeft(_.andThen(_)))) + .futureValue + + inserted mustBe Success(1) + getAllUsersFromDb.futureValue mustBe users + } + + "produce Failure(_) when inserting duplicate record" in { + val inserted = Source(duplicateUser) + .runWith(SlickWithTryResult.sinkTry(insertUser)) + .futureValue + + inserted mustBe a[Failure[_]] + getAllUsersFromDb.futureValue mustBe Set(users.head) + } + + "produce `Failure[_]` when inserting duplicate record (parallelism = 4)" in { + val records = scala.collection.immutable.Seq.empty ++ users :+ users.head + + val inserted = Source(records) + .runWith(SlickWithTryResult.sinkTry(parallelism = 4, insertUser)) + .futureValue + + records.length mustBe 41 + inserted mustBe a[Failure[_]] + getAllUsersFromDb.futureValue mustBe users + } + + } +}