diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..cc0d1e9 Binary files /dev/null and b/.DS_Store differ diff --git a/.circleci/config.yml b/.circleci/config.yml index aab78c7..8db5019 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -9,6 +9,7 @@ jobs: docker: - image: swoopinc/spark-alchemy:adoptopenjdk-8u222-alpine-circleci-201909251709 + - image: swoopinc/postgres-hll:11 working_directory: ~/spark-alchemy @@ -34,6 +35,10 @@ jobs: keys: - v1-dependencies-{{ checksum "build.sbt" }} + - run: + name: Wait for postgres + command: dockerize -wait tcp://localhost:5432 -timeout 1m + # "cat /dev/null |" prevents sbt from running in interactive mode. One of many amazing # hacks get sbt working in a sane manner. - run: diff --git a/.gitignore b/.gitignore index 55fa90a..ff706da 100644 --- a/.gitignore +++ b/.gitignore @@ -19,5 +19,8 @@ project/plugins/project/ #Markdown editing .Ulysses-favorites.plist +#IntelliJ +.idea + metastore_db/ tmp/ diff --git a/dev/release-process.md b/DEVELOPMENT.md similarity index 80% rename from dev/release-process.md rename to DEVELOPMENT.md index 6ccec62..c1ead92 100644 --- a/dev/release-process.md +++ b/DEVELOPMENT.md @@ -1,4 +1,13 @@ -# Release Process +# Development + +## Local tests + +To run the `PostgresInteropTest` you need to have a working Docker +environment. On Mac, that means having Docker Desktop installed and +running. Then run `docker-compose up` in the repository root to start a +Postgres server. + +## Release Process 1. Develop new code on feature branches. @@ -12,7 +21,7 @@ * Publish the microsite to Github Pages * Create a new release on the [Github Project Release Page](https://github.com/swoop-inc/spark-alchemy/releases) -## Project Version Numbers +### Project Version Numbers * The `VERSION` file in the root of the project contains the version number that SBT will use for the `spark-alchemy` project. * The format should follow [Semantic Versioning](https://semver.org/) with the patch number matching the Travis CI build number when deploying new releases. diff --git a/VERSION b/VERSION index 10145c2..f0334e9 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.4.0-SNAPSHOT +0.5.0-SNAPSHOT diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/NativeFunctionRegistration.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/NativeFunctionRegistration.scala index a86eeb7..9ab81ac 100644 --- a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/NativeFunctionRegistration.scala +++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/NativeFunctionRegistration.scala @@ -70,8 +70,8 @@ trait NativeFunctionRegistration extends FunctionRegistration { } /** - * Creates an [[ExpressionInfo]] for the function as defined by expression T using the given name. - */ + * Creates an [[ExpressionInfo]] for the function as defined by expression T using the given name. 
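+   *
+   * For illustration only (editor's sketch, not part of this change), the HLL registration later in
+   * this diff builds on this helper via `expression`, e.g.:
+   * {{{
+   *   expression[HyperLogLogMerge]("hll_merge")
+   * }}}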
+ */ protected def expressionInfo[T <: Expression : ClassTag](name: String): ExpressionInfo = { val clazz = scala.reflect.classTag[T].runtimeClass val df = clazz.getAnnotation(classOf[ExpressionDescription]) diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/BoundHLL.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/BoundHLL.scala index 48621ab..b95c45f 100644 --- a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/BoundHLL.scala +++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/BoundHLL.scala @@ -5,14 +5,16 @@ import org.apache.spark.sql.Column /** Convenience trait to use HyperLogLog functions with the same error consistently. - * Spark's own [[sql.functions.approx_count_distinct()]] as well as the granular HLL - * [[HLLFunctions.hll_init()]] and [[HLLFunctions.hll_init_collection()]] will be - * automatically parameterized by [[BoundHLL.hllError]]. - */ + * Spark's own [[sql.functions.approx_count_distinct()]] as well as the granular HLL + * [[HLLFunctions.hll_init()]] and [[HLLFunctions.hll_init_collection()]] will be + * automatically parameterized by [[BoundHLL.hllError]]. + */ trait BoundHLL extends Serializable { def hllError: Double + def functions: HLLFunctions + def approx_count_distinct(col: Column): Column = sql.functions.approx_count_distinct(col, hllError) @@ -42,11 +44,16 @@ trait BoundHLL extends Serializable { def hll_init_collection_agg(columnName: String): Column = functions.hll_init_collection_agg(columnName, hllError) - } object BoundHLL { - def apply(error: Double): BoundHLL = new BoundHLL { + /** + * @param error maximum estimation error allowed + * @param impl only affects the hll_* functions, not Spark's built-ins + */ + def apply(error: Double)(implicit impl: Implementation = null): BoundHLL = new BoundHLL { def hllError: Double = error + + val functions = HLLFunctions.withImpl(impl) } } diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/CardinalityHashFunction.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/CardinalityHashFunction.scala index 82cb008..ca865b6 100644 --- a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/CardinalityHashFunction.scala +++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/CardinalityHashFunction.scala @@ -7,9 +7,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** - * Hash function for Spark data values that is suitable for cardinality counting. Unlike Spark's built-in hashing, - * it differentiates between different data types and accounts for nulls. - */ + * Hash function for Spark data values that is suitable for cardinality counting. Unlike Spark's built-in hashing, + * it differentiates between different data types and accounts for nulls. 
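+ *
+ * For example (editor's illustration), `1` as an Int and `1L` as a Long, or a `null` versus an
+ * `Array(null)`, hash to different values here, whereas hashing that ignores type and null
+ * structure would treat such inputs as identical when counting distinct values.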
+ */ abstract class CardinalityHashFunction extends InterpretedHashFunction { override def hash(value: Any, dataType: DataType, seed: Long): Long = { diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionRegistration.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionRegistration.scala index 74f5866..7bbde1a 100644 --- a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionRegistration.scala +++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionRegistration.scala @@ -13,7 +13,7 @@ object HLLFunctionRegistration extends NativeFunctionRegistration { expression[HyperLogLogMerge]("hll_merge"), expression[HyperLogLogRowMerge]("hll_row_merge"), expression[HyperLogLogCardinality]("hll_cardinality"), - expression[HyperLogLogIntersectionCardinality]("hll_intersect_cardinality") + expression[HyperLogLogIntersectionCardinality]("hll_intersect_cardinality"), + expression[HyperLogLogConvert]("hll_convert") ) - } diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctions.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctions.scala index f61ad91..7b58d0b 100644 --- a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctions.scala +++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctions.scala @@ -1,32 +1,80 @@ package com.swoop.alchemy.spark.expressions.hll -import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import com.swoop.alchemy.spark.expressions.WithHelper -import org.apache.spark.sql.Column +import com.swoop.alchemy.spark.expressions.hll.HyperLogLogBase.{nameToImpl, resolveImplementation} +import com.swoop.alchemy.spark.expressions.hll.Implementation.{AGGREGATE_KNOWLEDGE, AGKN, STREAM_LIB, STRM} +import org.apache.spark.sql.EncapsulationViolator.createAnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.aggregate.{HyperLogLogPlusPlus, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus.validateDoubleLiteral +import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, ExpressionDescription, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, ExpressionDescription, Literal, UnaryExpression} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Column, SparkSession} -trait HyperLogLogInit extends Expression { +trait HyperLogLogBase { + def impl: Implementation +} + +object HyperLogLogBase { + def resolveImplementation(exp: Expression): Implementation = exp match { + case null => resolveImplementation + case _ => nameToImpl(exp, "last argument") + } + + def resolveImplementation(exp: String): Implementation = exp match { + case null => resolveImplementation + case s => nameToImpl(s.toString) + } + + def resolveImplementation(implicit impl: Implementation = null): Implementation = + if (impl != null) + impl + else + SparkSession.getActiveSession + .flatMap(_.conf.getOption(IMPLEMENTATION_CONFIG_KEY)) + .map(nameToImpl) + 
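+      // Resolution order (editor's note): an explicit implName argument wins, then an implicit
+      // Implementation in scope, then this session configuration, e.g.
+      //   spark.conf.set(IMPLEMENTATION_CONFIG_KEY, "AGKN")
+      // and only then the StreamLib default below.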
.getOrElse(StreamLib) + + def nameToImpl(name: Expression, argName: String = "argument"): Implementation = name match { + case Literal(s: Any, StringType) => + nameToImpl(s.toString) + case _ => + throw createAnalysisException( + s"The $argName must be a string argument (${Implementation.OPTIONS.mkString("/")}) designating one of the implementation options." + ) + } + + + def nameToImpl(name: String): Implementation = name match { + case STRM => StreamLib + case STREAM_LIB => StreamLib + case AGKN => AgKn + case AGGREGATE_KNOWLEDGE => AgKn + case s => throw createAnalysisException( + s"The HLL implementation choice '$s' is not one of the valid options: ${Implementation.OPTIONS.mkString(", ")}" + ) + } +} + +trait HyperLogLogInit extends Expression with HyperLogLogBase { def relativeSD: Double // This formula for `p` came from org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus:93 protected[this] val p: Int = Math.ceil(2.0d * Math.log(1.106d / relativeSD) / Math.log(2.0d)).toInt - require(p >= 4, "HLL++ requires at least 4 bits for addressing. Use a lower error, at most 39%.") + require(p >= 4, "HLL requires at least 4 bits for addressing. Use a lower error, at most 39%.") override def dataType: DataType = BinaryType def child: Expression - def offer(value: Any, buffer: HyperLogLogPlus): HyperLogLogPlus + def offer(value: Any, buffer: Instance): Instance - def createHll = new HyperLogLogPlus(p, 0) + def createHll: Instance = impl.createHll(p) def hash(value: Any, dataType: DataType, seed: Long): Long = CardinalityXxHash64Function.hash(value, dataType, seed) @@ -37,15 +85,12 @@ trait HyperLogLogInit extends Expression { } } - trait HyperLogLogSimple extends HyperLogLogInit { - def offer(value: Any, buffer: HyperLogLogPlus): HyperLogLogPlus = { - buffer.offerHashed(hash(value, child.dataType)) - buffer + def offer(value: Any, buffer: Instance): Instance = { + buffer.offer(hash(value, child.dataType)) } } - trait HyperLogLogCollection extends HyperLogLogInit { override def checkInputDataTypes(): TypeCheckResult = @@ -54,19 +99,19 @@ trait HyperLogLogCollection extends HyperLogLogInit { case _ => TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array and map input.") } - def offer(value: Any, buffer: HyperLogLogPlus): HyperLogLogPlus = { + def offer(value: Any, buffer: Instance): Instance = { value match { case arr: ArrayData => child.dataType match { case ArrayType(et, _) => arr.foreach(et, (_, v) => { - if (v != null) buffer.offerHashed(hash(v, et)) + if (v != null) buffer.offer(hash(v, et)) }) case dt => throw new UnsupportedOperationException(s"Unknown DataType for ArrayData: $dt") } case map: MapData => child.dataType match { case MapType(kt, vt, _) => map.foreach(kt, vt, (k, v) => { - buffer.offerHashed(hash(v, vt, hash(k, kt))) // chain key and value hash + buffer.offer(hash(v, vt, hash(k, kt))) // chain key and value hash }) case dt => throw new UnsupportedOperationException(s"Unknown DataType for MapData: $dt") } @@ -77,17 +122,16 @@ trait HyperLogLogCollection extends HyperLogLogInit { } } - trait HyperLogLogInitSingle extends UnaryExpression with HyperLogLogInit with CodegenFallback { override def nullable: Boolean = child.nullable override def nullSafeEval(value: Any): Any = - offer(value, createHll).getBytes + offer(value, createHll).serialize } trait HyperLogLogInitAgg extends NullableSketchAggregation with HyperLogLogInit { - override def update(buffer: Option[HyperLogLogPlus], inputRow: InternalRow): Option[HyperLogLogPlus] = { + override def 
update(buffer: Option[Instance], inputRow: InternalRow): Option[Instance] = { val value = child.eval(inputRow) if (value != null) { Some(offer(value, buffer.getOrElse(createHll))) @@ -97,22 +141,21 @@ trait HyperLogLogInitAgg extends NullableSketchAggregation with HyperLogLogInit } } -trait NullableSketchAggregation extends TypedImperativeAggregate[Option[HyperLogLogPlus]] { +trait NullableSketchAggregation extends TypedImperativeAggregate[Option[Instance]] with HyperLogLogBase { - override def createAggregationBuffer(): Option[HyperLogLogPlus] = None + override def createAggregationBuffer(): Option[Instance] = None - override def merge(buffer: Option[HyperLogLogPlus], other: Option[HyperLogLogPlus]): Option[HyperLogLogPlus] = + override def merge(buffer: Option[Instance], other: Option[Instance]): Option[Instance] = (buffer, other) match { case (Some(a), Some(b)) => - a.addAll(b) - Some(a) + Some(a.merge(b)) case (a, None) => a case (None, b) => b case _ => None } - override def eval(buffer: Option[HyperLogLogPlus]): Any = - buffer.map(_.getBytes).orNull + override def eval(buffer: Option[Instance]): Any = + buffer.map(_.serialize).orNull def child: Expression @@ -120,32 +163,37 @@ trait NullableSketchAggregation extends TypedImperativeAggregate[Option[HyperLog override def nullable: Boolean = child.nullable - override def serialize(hll: Option[HyperLogLogPlus]): Array[Byte] = - hll.map(_.getBytes).orNull + override def serialize(hll: Option[Instance]): Array[Byte] = + hll.map(_.serialize).orNull - override def deserialize(bytes: Array[Byte]): Option[HyperLogLogPlus] = - if (bytes == null) None else Option(HyperLogLogPlus.Builder.build(bytes)) + override def deserialize(bytes: Array[Byte]): Option[Instance] = + if (bytes == null) None else Option(impl.deserialize(bytes)) } /** - * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. - * - * This version creates a composable "sketch" for each input row. - * All expression values treated as simple values. - * - * @param child to estimate the cardinality of. - * @param relativeSD defines the maximum estimation error allowed - */ + * HyperLogLog (HLL) is a state of the art cardinality estimation algorithm with multiple implementations available. + * + * The underlying [[Implementation]] can be changed by setting a [[IMPLEMENTATION_CONFIG_KEY configuration value]] + * in the [[SparkSession]] to the implementation name, or passing it as an argument. + * + * This function creates a composable "sketch" for each input row. + * All expression values are treated as simple values. + * + * @param child to estimate the cardinality of. + * @param relativeSD defines the maximum estimation error allowed + * @param impl HLL implementation to use + */ @ExpressionDescription( usage = """ - _FUNC_(expr[, relativeSD]) - Returns the composable "sketch" by HyperLogLog++. + _FUNC_(expr[, relativeSD[, implName]]) - Returns the composable "sketch" by HyperLogLog++. `relativeSD` defines the maximum estimation error allowed. 
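+
+      Example (editor's illustration; `user_id` is a hypothetical column):
+        > SELECT hll_cardinality(_FUNC_(user_id, 0.05, 'AGKN'), 'AGKN');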
""") case class HyperLogLogInitSimple( override val child: Expression, - override val relativeSD: Double = 0.05) + override val relativeSD: Double = 0.05, + override val impl: Implementation = resolveImplementation) extends HyperLogLogInitSingle with HyperLogLogSimple { def this(child: Expression) = this(child, relativeSD = 0.05) @@ -153,34 +201,45 @@ case class HyperLogLogInitSimple( def this(child: Expression, relativeSD: Expression) = { this( child = child, - relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD) + relativeSD = validateDoubleLiteral(relativeSD) ) } - override def prettyName: String + def this(child: Expression, relativeSD: Expression, implName: Expression) = { + this( + child = child, + relativeSD = validateDoubleLiteral(relativeSD), + impl = resolveImplementation(implName) + ) + } - = "hll_init" + override def prettyName: String = "hll_init" } /** - * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. - * - * This version combines all input in each aggregate group into a single "sketch". - * All expression values treated as simple values. - * - * @param child to estimate the cardinality of - * @param relativeSD defines the maximum estimation error allowed - */ + * HyperLogLog (HLL) is a state of the art cardinality estimation algorithm with multiple implementations available. + * + * The underlying [[Implementation]] can be changed by setting a [[IMPLEMENTATION_CONFIG_KEY configuration value]] + * in the [[SparkSession]] to the implementation name, or passing it as an argument. + * + * This version combines all input in each aggregate group into a single "sketch". + * All expression values treated as simple values. + * + * @param child to estimate the cardinality of + * @param relativeSD defines the maximum estimation error allowed + * @param impl HLL implementation to use + */ @ExpressionDescription( usage = """ - _FUNC_(expr[, relativeSD]) - Returns the composable "sketch" by HyperLogLog++. + _FUNC_(expr[, relativeSD[, implName]]) - Returns the composable "sketch" by HyperLogLog++. `relativeSD` defines the maximum estimation error allowed. """) case class HyperLogLogInitSimpleAgg( override val child: Expression, override val relativeSD: Double = 0.05, + override val impl: Implementation = resolveImplementation, override val mutableAggBufferOffset: Int = 0, override val inputAggBufferOffset: Int = 0) extends HyperLogLogInitAgg with HyperLogLogSimple { @@ -190,9 +249,15 @@ case class HyperLogLogInitSimpleAgg( def this(child: Expression, relativeSD: Expression) = { this( child = child, - relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD), - mutableAggBufferOffset = 0, - inputAggBufferOffset = 0) + relativeSD = validateDoubleLiteral(relativeSD)) + } + + def this(child: Expression, relativeSD: Expression, implName: Expression) = { + this( + child = child, + relativeSD = validateDoubleLiteral(relativeSD), + impl = resolveImplementation(implName) + ) } override def withNewMutableAggBufferOffset(newOffset: Int): HyperLogLogInitSimpleAgg = @@ -205,23 +270,28 @@ case class HyperLogLogInitSimpleAgg( } /** - * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. - * - * This version creates a composable "sketch" for each input row. - * Expression must be is a collection (Array, Map), and collection elements are treated as individual values. - * - * @param child to estimate the cardinality of. 
- * @param relativeSD defines the maximum estimation error allowed
- */
+ * HyperLogLog (HLL) is a state of the art cardinality estimation algorithm with multiple implementations available.
+ *
+ * The underlying [[Implementation]] can be changed by setting a [[IMPLEMENTATION_CONFIG_KEY configuration value]]
+ * in the [[SparkSession]] to the implementation name, or passing it as an argument.
+ *
+ * This version creates a composable "sketch" for each input row.
+ * Expression must be a collection (Array, Map), and collection elements are treated as individual values.
+ *
+ * @param child to estimate the cardinality of.
+ * @param relativeSD defines the maximum estimation error allowed
+ * @param impl HLL implementation to use
+ */
@ExpressionDescription(
  usage =
    """
-    _FUNC_(expr[, relativeSD]) - Returns the composable "sketch" by HyperLogLog++.
+    _FUNC_(expr[, relativeSD[, implName]]) - Returns the composable "sketch" by HyperLogLog++.
    `relativeSD` defines the maximum estimation error allowed.
  """)
case class HyperLogLogInitCollection(
  override val child: Expression,
-  override val relativeSD: Double = 0.05)
+  override val relativeSD: Double = 0.05,
+  override val impl: Implementation = resolveImplementation)
  extends HyperLogLogInitSingle with HyperLogLogCollection {

  def this(child: Expression) = this(child, relativeSD = 0.05)

@@ -229,32 +299,46 @@ case class HyperLogLogInitCollection(
  def this(child: Expression, relativeSD: Expression) = {
    this(
      child = child,
-      relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD)
+      relativeSD = validateDoubleLiteral(relativeSD)
+    )
+  }
+
+  def this(child: Expression, relativeSD: Expression, implName: Expression) = {
+    this(
+      child = child,
+      relativeSD = validateDoubleLiteral(relativeSD),
+      impl = resolveImplementation(implName)
    )
  }

+
  override def prettyName: String = "hll_init_collection"
}

/**
- * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm.
- *
- * This version combines all input in each aggregate group into a a single "sketch".
- * If `expr` is a collection (Array, Map), collection elements are treated as individual values.
- *
- * @param child to estimate the cardinality of
- * @param relativeSD defines the maximum estimation error allowed
- */
+ * HyperLogLog (HLL) is a state of the art cardinality estimation algorithm with multiple implementations available.
+ *
+ * The underlying [[Implementation]] can be changed by setting a [[IMPLEMENTATION_CONFIG_KEY configuration value]]
+ * in the [[SparkSession]] to the implementation name, or passing it as an argument.
+ *
+ * This version combines all input in each aggregate group into a single "sketch".
+ * If `expr` is a collection (Array, Map), collection elements are treated as individual values.
+ *
+ * @param child to estimate the cardinality of
+ * @param relativeSD defines the maximum estimation error allowed
+ * @param impl HLL implementation to use
+ */
@ExpressionDescription(
  usage =
    """
-    _FUNC_(expr[, relativeSD]) - Returns the composable "sketch" by HyperLogLog++.
+    _FUNC_(expr[, relativeSD[, implName]]) - Returns the composable "sketch" by HyperLogLog++.
    `relativeSD` defines the maximum estimation error allowed.
""") case class HyperLogLogInitCollectionAgg( - child: Expression, - relativeSD: Double = 0.05, + override val child: Expression, + override val relativeSD: Double = 0.05, + override val impl: Implementation = resolveImplementation, override val mutableAggBufferOffset: Int = 0, override val inputAggBufferOffset: Int = 0) extends HyperLogLogInitAgg with HyperLogLogCollection { @@ -263,10 +347,17 @@ case class HyperLogLogInitCollectionAgg( def this(child: Expression, relativeSD: Expression) = { this( - child = child, - relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD), - mutableAggBufferOffset = 0, - inputAggBufferOffset = 0) + child, + validateDoubleLiteral(relativeSD) + ) + } + + def this(child: Expression, relativeSD: Expression, implName: Expression) = { + this( + child, + validateDoubleLiteral(relativeSD), + resolveImplementation(implName) + ) } override def withNewMutableAggBufferOffset(newOffset: Int): HyperLogLogInitCollectionAgg = @@ -280,33 +371,40 @@ case class HyperLogLogInitCollectionAgg( /** - * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. - * - * This version aggregates the "sketches" into a single merged "sketch" that represents the union of the constituents. - * - * @param child "sketch" to merge - */ + * HyperLogLog (HLL) is a state of the art cardinality estimation algorithm with multiple implementations available. + * + * The underlying [[Implementation]] can be changed by setting a [[IMPLEMENTATION_CONFIG_KEY configuration value]] + * in the [[SparkSession]] to the implementation name, or passing it as an argument. + * + * This version aggregates the "sketches" into a single merged "sketch" that represents the union of the constituents. + * + * @param child "sketch" to merge + * @param impl HLL implementation to use + */ @ExpressionDescription( usage = """ - _FUNC_(expr) - Returns the merged HLL++ sketch. + _FUNC_(expr[, implName]) - Returns the merged HLL sketch. """) case class HyperLogLogMerge( child: Expression, - override val mutableAggBufferOffset: Int, - override val inputAggBufferOffset: Int) + override val impl: Implementation = resolveImplementation, + override val mutableAggBufferOffset: Int = 0, + override val inputAggBufferOffset: Int = 0) extends NullableSketchAggregation { - def this(child: Expression) = this(child, 0, 0) + def this(child: Expression) = this(child, resolveImplementation) + + def this(child: Expression, implName: Expression) = this(child, resolveImplementation(implName)) - override def update(buffer: Option[HyperLogLogPlus], inputRow: InternalRow): Option[HyperLogLogPlus] = { + override def update(buffer: Option[Instance], inputRow: InternalRow): Option[Instance] = { val value = child.eval(inputRow) if (value != null) { val hll = value match { - case b: Array[Byte] => HyperLogLogPlus.Builder.build(b) + case b: Array[Byte] => impl.deserialize(b) case _ => throw new IllegalStateException(s"$prettyName only supports Array[Byte]") } - buffer.map(_.merge(hll).asInstanceOf[HyperLogLogPlus]) + buffer.map(_.merge(hll)) .orElse(Option(hll)) } else { buffer @@ -332,20 +430,39 @@ case class HyperLogLogMerge( } /** - * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. - * - * This version merges multiple "sketches" in one row into a single field. - * @see HyperLogLogMerge - * - * @param children "sketch" row fields to merge - */ + * HyperLogLog (HLL) is a state of the art cardinality estimation algorithm with multiple implementations available. 
+ * + * The underlying [[Implementation]] can be changed by setting a [[IMPLEMENTATION_CONFIG_KEY configuration value]] + * in the [[SparkSession]] to the implementation name, or passing it as an argument. + * + * This version merges multiple "sketches" in one row into a single field. + * + * @see [[HyperLogLogMerge]] + * @param children "sketch" row fields to merge + * @param impl HLL implementation to use + */ @ExpressionDescription( usage = """ -_FUNC_(expr) - Returns the merged HLL++ sketch. -""") -case class HyperLogLogRowMerge(children: Seq[Expression]) - extends Expression with ExpectsInputTypes with CodegenFallback { + _FUNC_(expr[, implName]) - Returns the merged HLL sketch. + """) +case class HyperLogLogRowMerge( + override val children: Seq[Expression], + override val impl: Implementation = resolveImplementation) + extends Expression with ExpectsInputTypes with CodegenFallback with HyperLogLogBase { + + def this(children: Seq[Expression]) = this({ + assert(children.nonEmpty, s"function requires at least one argument") + children + }.last match { + case Literal(s: Any, StringType) => children.init + case _ => children + }, + children.last match { + case Literal(s: Any, StringType) => resolveImplementation(s.toString) + case _ => resolveImplementation + } + ) require(children.nonEmpty, s"$prettyName requires at least one argument.") @@ -360,15 +477,15 @@ case class HyperLogLogRowMerge(children: Seq[Expression]) override def eval(input: InternalRow): Any = { val flatInputs = children.flatMap(_.eval(input) match { case null => None - case b: Array[Byte] => Some(HyperLogLogPlus.Builder.build(b)) + case b: Array[Byte] => Some(impl.deserialize(b)) case _ => throw new IllegalStateException(s"$prettyName only supports Array[Byte]") }) if (flatInputs.isEmpty) null else { val acc = flatInputs.head - flatInputs.tail.foreach(acc.addAll) - acc.getBytes + flatInputs.tail.foreach(acc.merge) + acc.serialize } } @@ -376,18 +493,29 @@ case class HyperLogLogRowMerge(children: Seq[Expression]) } /** - * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. - * - * Returns the estimated cardinality of an HLL++ "sketch" - * - * @param child HLL++ "sketch" - */ + * HyperLogLog (HLL) is a state of the art cardinality estimation algorithm with multiple implementations available. + * + * The underlying [[Implementation]] can be changed by setting a [[IMPLEMENTATION_CONFIG_KEY configuration value]] + * in the [[SparkSession]] to the implementation name, or passing it as an argument. + * + * Returns the estimated cardinality of an HLL "sketch" + * + * @param child HLL "sketch" + * @param impl HLL implementation to use + */ @ExpressionDescription( usage = """ - _FUNC_(sketch) - Returns the estimated cardinality of the binary representation produced by HyperLogLog++. + _FUNC_(sketch[, implName]) - Returns the estimated cardinality of the binary representation produced by HyperLogLog++. 
""") -case class HyperLogLogCardinality(override val child: Expression) extends UnaryExpression with ExpectsInputTypes with CodegenFallback { +case class HyperLogLogCardinality( + override val child: Expression, + override val impl: Implementation = resolveImplementation +) extends UnaryExpression with ExpectsInputTypes with CodegenFallback with HyperLogLogBase { + + def this(child: Expression) = this(child, resolveImplementation) + + def this(child: Expression, implName: Expression) = this(child, resolveImplementation(implName)) override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -395,34 +523,49 @@ case class HyperLogLogCardinality(override val child: Expression) extends UnaryE override def nullSafeEval(input: Any): Long = { val data = input.asInstanceOf[Array[Byte]] - HyperLogLogPlus.Builder.build(data).cardinality() + impl.deserialize(data).cardinality } override def prettyName: String = "hll_cardinality" } /** - * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. - * - * Computes a merged (unioned) sketch and uses the fact that |A intersect B| = (|A| + |B|) - |A union B| to estimate - * the intersection cardinality of two HLL++ "sketches". - * - * @see HyperLogLogRowMerge - * @see HyperLogLogCardinality - * - * @param left HLL++ "sketch" - * @param right HLL++ "sketch" - * - * @return the estimated intersection cardinality (0 if one sketch is null, but null if both are) - */ + * HyperLogLog (HLL) is a state of the art cardinality estimation algorithm with multiple implementations available. + * + * The underlying [[Implementation]] can be changed by setting a [[IMPLEMENTATION_CONFIG_KEY configuration value]] + * in the [[SparkSession]] to the implementation name, or passing it as an argument. + * + * Computes a merged (unioned) sketch and uses the fact that |A intersect B| = (|A| + |B|) - |A union B| to estimate + * the intersection cardinality of two HLL "sketches". + * + * @note Error in the cardinality of the intersection is determined by the cardinality of the constituent sketches, not + * the cardinality of the intersection itself (i.e. it may be much larger than naively expected) - + * https://research.neustar.biz/2012/12/17/hll-intersections-2/ + * + * @see HyperLogLogRowMerge + * @see HyperLogLogCardinality + * @param left HLL "sketch" + * @param right HLL "sketch" + * @param impl HLL implementation to use + * @return the estimated intersection cardinality (0 if one sketch is null, but null if both are) + */ @ExpressionDescription( usage = """ - _FUNC_(sketchL, sketchR) - Returns the estimated intersection cardinality of the binary representations produced by - HyperLogLog++. Computes a merged (unioned) sketch and uses the fact that |A intersect B| = (|A| + |B|) - |A union B|. + _FUNC_(sketchL, sketchR[, implName]) - Returns the estimated intersection cardinality of the binary representations produced by + HyperLogLog. Computes a merged (unioned) sketch and uses the fact that |A intersect B| = (|A| + |B|) - |A union B|. 
Returns null if both sketches are null, but 0 if only one is
  """)
-case class HyperLogLogIntersectionCardinality(override val left: Expression, override val right: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+case class HyperLogLogIntersectionCardinality(
+  override val left: Expression,
+  override val right: Expression,
+  override val impl: Implementation = resolveImplementation
+) extends BinaryExpression with ExpectsInputTypes with CodegenFallback with HyperLogLogBase {
+
+  def this(left: Expression, right: Expression) = this(left, right, resolveImplementation)
+
+  def this(left: Expression, right: Expression, implName: Expression) =
+    this(left, right, resolveImplementation(implName))

  override def inputTypes: Seq[DataType] = Seq(BinaryType, BinaryType)

@@ -435,13 +578,13 @@ case class HyperLogLogIntersectionCardinality(override val left: Expression, ove
    val rightValue = right.eval(input)

    if (leftValue != null && rightValue != null) {
-      val leftHLL = HyperLogLogPlus.Builder.build(leftValue.asInstanceOf[Array[Byte]])
-      val rightHLL = HyperLogLogPlus.Builder.build(rightValue.asInstanceOf[Array[Byte]])
+      val leftHLL = impl.deserialize(leftValue.asInstanceOf[Array[Byte]])
+      val rightHLL = impl.deserialize(rightValue.asInstanceOf[Array[Byte]])

-      val leftCount = leftHLL.cardinality()
-      val rightCount = rightHLL.cardinality()
-      leftHLL.addAll(rightHLL)
-      val unionCount = leftHLL.cardinality()
+      val leftCount = leftHLL.cardinality
+      val rightCount = rightHLL.cardinality
+      leftHLL.merge(rightHLL)
+      val unionCount = leftHLL.cardinality

      // guarantee a non-negative result despite the approximate nature of the counts
      math.max((leftCount + rightCount) - unionCount, 0L)
@@ -457,92 +600,180 @@ case class HyperLogLogIntersectionCardinality(ove
  override def prettyName: String = "hll_intersect_cardinality"
}

-object functions extends HLLFunctions
+
+/**
+ * HyperLogLog (HLL) is a state of the art cardinality estimation algorithm with multiple implementations available.
+ *
+ * This function converts between implementations. Currently the only conversion supported is from the StreamLib
+ * implementation (`"STRM"` or `"STREAM_LIB"`) to the Aggregate Knowledge implementation (`"AGKN"` or
+ * `"AGGREGATE_KNOWLEDGE"`).
+ *
+ * @note Converted values CANNOT be merged with unconverted ("native") values of the type they've been converted to.
+ *       This is because the different implementations use different parts of the hashed value to construct the HLL
+ *       (effectively equivalent to using different hash functions).
+ * @param child HLL "sketch"
+ * @param from string name of the implementation type of the given sketch
+ * @param to string name of the implementation type to convert the given sketch to
+ */
+
+@ExpressionDescription(
+  usage =
+    """
+    _FUNC_(sketch, implNameFrom, implNameTo) - Converts between implementations.
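+
+      Example (editor's illustration; `sketch` is a hypothetical binary HLL column):
+        > SELECT hll_cardinality(_FUNC_(sketch, 'STRM', 'AGKN'), 'AGKN');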
+ """) +case class HyperLogLogConvert( + override val child: Expression, + from: Implementation, + to: Implementation +) extends UnaryExpression with CodegenFallback with ExpectsInputTypes { + + def this(hll: Expression, fromName: Expression, toName: Expression) = { + this(hll, nameToImpl(fromName, "second argument"), nameToImpl(toName, "third argument")) + } + + + def this(hll: Expression, fromName: String, toName: String) = { + this(hll, nameToImpl(fromName), nameToImpl(toName)) + } + + override def dataType: DataType = BinaryType + + override def inputTypes: Seq[DataType] = Seq(BinaryType) + + override def nullSafeEval(hll: Any): Any = (from, to) match { + case (StreamLib, AgKn) => strmToAgkn(hll.asInstanceOf[Array[Byte]]) + case _ => throw new IllegalArgumentException( + "HLL conversion is currently only supported from STREAM_LIB to AGGREGATE_KNOWLEDGE" + ) + } + + override def prettyName: String = "hll_convert" +} + +object functions extends HLLFunctions { + val impl: Implementation = null +} trait HLLFunctions extends WithHelper { - def hll_init(e: Column, relativeSD: Double): Column = withExpr { - HyperLogLogInitSimple(e.expr, relativeSD) + implicit def impl: Implementation + + def hll_init(e: Column, relativeSD: Double, implName: String = null): Column = withExpr { + HyperLogLogInitSimple(e.expr, relativeSD, resolveImplementation(implName)) } def hll_init(columnName: String, relativeSD: Double): Column = hll_init(col(columnName), relativeSD) + def hll_init(columnName: String, relativeSD: Double, implName: String): Column = + hll_init(col(columnName), relativeSD, implName) + def hll_init(e: Column): Column = withExpr { - HyperLogLogInitSimple(e.expr) + HyperLogLogInitSimple(e.expr, impl = resolveImplementation) } def hll_init(columnName: String): Column = hll_init(col(columnName)) - def hll_init_collection(e: Column, relativeSD: Double): Column = withExpr { - HyperLogLogInitCollection(e.expr, relativeSD) + def hll_init_collection(e: Column, relativeSD: Double, implName: String = null): Column = withExpr { + HyperLogLogInitCollection(e.expr, relativeSD, resolveImplementation(implName)) } def hll_init_collection(columnName: String, relativeSD: Double): Column = hll_init_collection(col(columnName), relativeSD) + def hll_init_collection(columnName: String, relativeSD: Double, implName: String): Column = + hll_init_collection(col(columnName), relativeSD, implName) + def hll_init_collection(e: Column): Column = withExpr { - HyperLogLogInitCollection(e.expr) + HyperLogLogInitCollection(e.expr, impl = resolveImplementation) } def hll_init_collection(columnName: String): Column = hll_init_collection(col(columnName)) - def hll_init_agg(e: Column, relativeSD: Double): Column = withAggregateFunction { - HyperLogLogInitSimpleAgg(e.expr, relativeSD) + def hll_init_agg(e: Column, relativeSD: Double, implName: String = null): Column = withAggregateFunction { + HyperLogLogInitSimpleAgg(e.expr, relativeSD, resolveImplementation(implName)) } def hll_init_agg(columnName: String, relativeSD: Double): Column = hll_init_agg(col(columnName), relativeSD) + def hll_init_agg(columnName: String, relativeSD: Double, implName: String): Column = + hll_init_agg(col(columnName), relativeSD, implName) + def hll_init_agg(e: Column): Column = withAggregateFunction { - HyperLogLogInitSimpleAgg(e.expr) + HyperLogLogInitSimpleAgg(e.expr, impl = resolveImplementation) } def hll_init_agg(columnName: String): Column = hll_init_agg(col(columnName)) - def hll_init_collection_agg(e: Column, relativeSD: Double): Column = 
withAggregateFunction { - HyperLogLogInitCollectionAgg(e.expr, relativeSD) + def hll_init_collection_agg(e: Column, relativeSD: Double, implName: String = null): Column = withAggregateFunction { + HyperLogLogInitCollectionAgg(e.expr, relativeSD, resolveImplementation(implName)) } def hll_init_collection_agg(columnName: String, relativeSD: Double): Column = hll_init_collection_agg(col(columnName), relativeSD) + def hll_init_collection_agg(columnName: String, relativeSD: Double, implName: String): Column = + hll_init_collection_agg(col(columnName), relativeSD, implName) + def hll_init_collection_agg(e: Column): Column = withAggregateFunction { - HyperLogLogInitCollectionAgg(e.expr) + HyperLogLogInitCollectionAgg(e.expr, impl = resolveImplementation) } def hll_init_collection_agg(columnName: String): Column = hll_init_collection_agg(col(columnName)) - def hll_merge(e: Column): Column = withAggregateFunction { - HyperLogLogMerge(e.expr, 0, 0) + def hll_merge(e: Column, implName: String = null): Column = withAggregateFunction { + HyperLogLogMerge(e.expr, resolveImplementation(implName)) } def hll_merge(columnName: String): Column = hll_merge(col(columnName)) + def hll_merge(columnName: String, implName: String): Column = + hll_merge(col(columnName), implName) + def hll_row_merge(es: Column*): Column = withExpr { - HyperLogLogRowMerge(es.map(_.expr)) + HyperLogLogRowMerge(es.map(_.expr), resolveImplementation) } - // split arguments to avoid collision w/ above after erasure - def hll_row_merge(columnName: String, columnNames: String*): Column = - hll_row_merge((columnName +: columnNames).map(col): _*) + def hll_row_merge(implName: String, es: Column*): Column = withExpr { + HyperLogLogRowMerge(es.map(_.expr), resolveImplementation(implName)) + } - def hll_cardinality(e: Column): Column = withExpr { - HyperLogLogCardinality(e.expr) + def hll_cardinality(e: Column, implName: String = null): Column = withExpr { + HyperLogLogCardinality(e.expr, resolveImplementation(implName)) } def hll_cardinality(columnName: String): Column = hll_cardinality(col(columnName)) - def hll_intersect_cardinality(l: Column, r: Column): Column = withExpr { - HyperLogLogIntersectionCardinality(l.expr, r.expr) + def hll_cardinality(columnName: String, implName: String): Column = + hll_cardinality(col(columnName), implName) + + def hll_intersect_cardinality(l: Column, r: Column, implName: String = null): Column = withExpr { + HyperLogLogIntersectionCardinality(l.expr, r.expr, resolveImplementation(implName)) } def hll_intersect_cardinality(leftColumnName: String, rightColumnName: String): Column = hll_intersect_cardinality(col(leftColumnName), col(rightColumnName)) + + def hll_intersect_cardinality(leftColumnName: String, rightColumnName: String, implName: String): Column = + hll_intersect_cardinality(col(leftColumnName), col(rightColumnName), implName) + + def hll_convert(hll: Column, from: String, to: String): Column = withExpr { + HyperLogLogConvert(hll.expr, nameToImpl(from), nameToImpl(to)) + } + + def hll_convert(columnName: String, from: String, to: String): Column = + hll_convert(col(columnName), from, to) +} + +object HLLFunctions { + def withImpl(hllImpl: Implementation): HLLFunctions = new HLLFunctions { + override implicit def impl: Implementation = hllImpl + } } diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/Implementation.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/Implementation.scala new file mode 100644 index 0000000..05f7bde --- /dev/null +++ 
b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/Implementation.scala @@ -0,0 +1,103 @@ +package com.swoop.alchemy.spark.expressions.hll + +import com.clearspring.analytics.stream +import com.clearspring.analytics.stream.cardinality.{HyperLogLogPlus, RegisterSet} +import net.agkn.hll.HLL +import net.agkn.hll.util.BitVector + +/** + * Wrapper for instances of different HLL implementations + * + * @note `offer`` and `merge`` may just mutate and return the same underlying HLL instance + */ +sealed trait Instance { + def offer(hashedValue: Long): Instance + + def merge(other: Instance): Instance + + def serialize: Array[Byte] + + def cardinality: Long +} + +class AgKnInstance(val hll: net.agkn.hll.HLL) extends Instance { + override def offer(hashedValue: Long): Instance = { + hll.addRaw(hashedValue) + this + } + + override def merge(other: Instance): Instance = { + if (other.isInstanceOf[AgKnInstance]) { + hll.union(other.asInstanceOf[AgKnInstance].hll) + this + } else + throw new IllegalArgumentException(s"Type of HLL to merge does not match this HLL (${hll.getClass.getName})") + } + + override def serialize: Array[Byte] = hll.toBytes + + def cardinality: Long = hll.cardinality() +} + +class StreamLibInstance(val hll: stream.cardinality.HyperLogLogPlus) extends Instance { + override def offer(hashedValue: Long): Instance = { + hll.offerHashed(hashedValue) + this + } + + override def merge(other: Instance): Instance = { + if (other.isInstanceOf[StreamLibInstance]) { + hll.addAll(other.asInstanceOf[StreamLibInstance].hll) + this + } else + throw new IllegalArgumentException(s"Type of HLL to merge does not match this HLL (${hll.getClass.getName})") + } + + override def serialize: Array[Byte] = hll.getBytes + + def cardinality: Long = hll.cardinality() +} + +/** + * Option for the underlying HLL implementation used by all functions + */ +trait Implementation { + def createHll(p: Int): Instance + + def deserialize(bytes: Array[Byte]): Instance +} + +object Implementation { + val AGKN = "AGKN" + val STRM = "STRM" + val AGGREGATE_KNOWLEDGE = "AGGREGATE_KNOWLEDGE" + val STREAM_LIB = "STREAM_LIB" + val OPTIONS = Seq(AGKN, STRM, AGGREGATE_KNOWLEDGE, STREAM_LIB) + + + // TODO @peter debugging tools, remove: + def registerSetToSeq(r: RegisterSet): Seq[Int] = + for (i <- 0 until r.count) yield r.get(i) + + def bitVectorToSeq(b: BitVector): Seq[Long] = { + val i = b.registerIterator() + new Iterator[Long] { + def hasNext = i.hasNext + + def next = i.next() + }.toArray + } +} + +case object AgKn extends Implementation { + override def createHll(p: Int) = new AgKnInstance(new HLL(p, 5)) + + override def deserialize(bytes: Array[Byte]) = new AgKnInstance(HLL.fromBytes(bytes)) +} + +case object StreamLib extends Implementation { + override def createHll(p: Int) = new StreamLibInstance(new HyperLogLogPlus(p, 0)) + + override def deserialize(bytes: Array[Byte]) = new StreamLibInstance(HyperLogLogPlus.Builder.build(bytes)) +} + diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/package.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/package.scala new file mode 100644 index 0000000..77bc1df --- /dev/null +++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/package.scala @@ -0,0 +1,62 @@ +package com.swoop.alchemy.spark.expressions + +import java.io.{ByteArrayInputStream, DataInputStream} + +import com.clearspring.analytics.stream +import com.clearspring.analytics.stream.cardinality.RegisterSet +import 
com.clearspring.analytics.util.{Bits, Varint} +import net.agkn.hll.HLL +import net.agkn.hll.serialization.{HLLMetadata, SchemaVersionOne} +import net.agkn.hll.util.BitVector + +package object hll { + val IMPLEMENTATION_CONFIG_KEY = "com.swoop.alchemy.hll.implementation" + + def strmToAgkn(from: stream.cardinality.HyperLogLogPlus): net.agkn.hll.HLL = { + HLL.fromBytes(strmToAgkn(from.getBytes)) + } + + def strmToAgkn(from: Array[Byte]): Array[Byte] = { + var bais = new ByteArrayInputStream(from) + var oi = new DataInputStream(bais) + val version = oi.readInt + // the new encoding scheme includes a version field + // that is always negative. + if (version >= 0) { + throw new UnsupportedOperationException("conversion is only supported for the new style encoding scheme") + } + + val p = Varint.readUnsignedVarInt(oi) + val sp = Varint.readUnsignedVarInt(oi) + val formatType = Varint.readUnsignedVarInt(oi) + if (formatType != 0) { + throw new UnsupportedOperationException("conversion is only supported for non-sparse representation") + } + + val size = Varint.readUnsignedVarInt(oi) + val longArrayBytes = new Array[Byte](size) + oi.readFully(longArrayBytes) + val registerSet = new RegisterSet(Math.pow(2, p).toInt, Bits.getBits(longArrayBytes)) + val bitVector = new BitVector(RegisterSet.REGISTER_SIZE, registerSet.count) + + for (i <- 0 until registerSet.count) bitVector.setRegister(i, registerSet.get(i)) + val schemaVersion = new SchemaVersionOne + val serializer = + schemaVersion.getSerializer(net.agkn.hll.HLLType.FULL, RegisterSet.REGISTER_SIZE, registerSet.count) + bitVector.getRegisterContents(serializer) + var outBytes = serializer.getBytes + + val metadata = new HLLMetadata( + schemaVersion.schemaVersionNumber(), + net.agkn.hll.HLLType.FULL, + p, + RegisterSet.REGISTER_SIZE, + 0, + true, + false, + false + ) + schemaVersion.writeMetadata(outBytes, metadata) + outBytes + } +} diff --git a/alchemy/src/main/scala/com/swoop/alchemy/utils/AnyExtensions.scala b/alchemy/src/main/scala/com/swoop/alchemy/utils/AnyExtensions.scala index 448d996..d90d562 100644 --- a/alchemy/src/main/scala/com/swoop/alchemy/utils/AnyExtensions.scala +++ b/alchemy/src/main/scala/com/swoop/alchemy/utils/AnyExtensions.scala @@ -1,41 +1,41 @@ package com.swoop.alchemy.utils /** - * Convenience methods for all types - */ + * Convenience methods for all types + */ object AnyExtensions { /** Sugar for applying functions in a method chain. */ implicit class TransformOps[A](val underlying: A) extends AnyVal { /** Applies a transformer function in a method chain. - * - * @param f function to apply - * @tparam B the return type - * @return the result of applying `f` on `underlying`. - */ + * + * @param f function to apply + * @tparam B the return type + * @return the result of applying `f` on `underlying`. + */ @inline def transform[B](f: A => B): B = f(underlying) /** Conditionally applies a transformer function in a method chain. - * Use this instead of [[transformWhen()]] when the predicate requires the value of `underlying`. - * - * @param predicate predicate to evaluate to determine if the function should be applied - * @tparam B the return type of the function - * @return `underlying` if the predicate evaluates to `false` or the result of function application. - */ + * Use this instead of [[transformWhen()]] when the predicate requires the value of `underlying`. 
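+     * For example (editor's illustration), `"spark".transformIf(_.length > 3)(_.toUpperCase)` yields
+     * `"SPARK"`, while `"ok".transformIf(_.length > 3)(_.toUpperCase)` returns `"ok"` unchanged.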
+ * + * @param predicate predicate to evaluate to determine if the function should be applied + * @tparam B the return type of the function + * @return `underlying` if the predicate evaluates to `false` or the result of function application. + */ @inline def transformIf[B <: A](predicate: A => Boolean)(f: A => B): A = if (predicate(underlying)) f(underlying) else underlying /** Conditionally applies a transformer function in a method chain. - * Use this instead of [[transformIf()]] when the condition does not require the value of `underlying`. - * - * @param condition condition to evaluate to determine if the function should be applied - * @tparam B the return type of the function - * @return `underlying` if the expression evaluates to `false` or the result of function application. - */ + * Use this instead of [[transformIf()]] when the condition does not require the value of `underlying`. + * + * @param condition condition to evaluate to determine if the function should be applied + * @tparam B the return type of the function + * @return `underlying` if the expression evaluates to `false` or the result of function application. + */ @inline def transformWhen[B <: A](condition: => Boolean)(f: A => B): A = if (condition) f(underlying) @@ -47,25 +47,25 @@ object AnyExtensions { implicit class TapOps[A](val underlying: A) extends AnyVal { /** Applies a function for its side-effect as part of a method chain. - * Inspired by Ruby's `Object#tap`. - * - * @param f side-effect function to call - * @tparam B the return type of the function; ignored - * @return `this` - */ + * Inspired by Ruby's `Object#tap`. + * + * @param f side-effect function to call + * @tparam B the return type of the function; ignored + * @return `this` + */ @inline def tap[B](f: A => B): A = { f(underlying) underlying } /** Conditionally applies a function for its side-effect as part of a method chain. - * Use this instead of [[tapWhen()]] when the predicate requires the value of `underlying`. - * - * @param predicate predicate to evaluate to determine if the side-effect should be invoked - * @param f side-effect function to call - * @tparam B the return type of the function; ignored - * @return `this` - */ + * Use this instead of [[tapWhen()]] when the predicate requires the value of `underlying`. + * + * @param predicate predicate to evaluate to determine if the side-effect should be invoked + * @param f side-effect function to call + * @tparam B the return type of the function; ignored + * @return `this` + */ @inline def tapIf[B](predicate: A => Boolean)(f: A => B): A = { if (predicate(underlying)) f(underlying) @@ -73,13 +73,13 @@ object AnyExtensions { } /** Conditionally applies a function for its side-effect as part of a method chain. - * Use this instead of [[tapIf()]] when the condition does not require the value of `underlying`. - * - * @param condition condition to evaluate to determine if the side-effect should be invoked - * @param f side-effect function to call - * @tparam B the return type of the function; ignored - * @return `this` - */ + * Use this instead of [[tapIf()]] when the condition does not require the value of `underlying`. 
+ * + * @param condition condition to evaluate to determine if the side-effect should be invoked + * @param f side-effect function to call + * @tparam B the return type of the function; ignored + * @return `this` + */ @inline def tapWhen[B](condition: => Boolean)(f: A => B): A = { if (condition) f(underlying) @@ -91,57 +91,57 @@ object AnyExtensions { implicit class PrintOps[A](val underlying: A) extends AnyVal { /** Taps and prints the object in a method chain. - * Shorthand for `.tap(println)`. - * - * @return `underlying` - */ + * Shorthand for `.tap(println)`. + * + * @return `underlying` + */ def tapp: A = underlying.tap(println) /** Prints a value as a side effect. - * - * @param v the value to print - * @return `underlying` - */ + * + * @param v the value to print + * @return `underlying` + */ def print[B](v: B): A = underlying.tap((_: A) => println(v)) /** Conditionally taps and prints the object in a method chain. - * Use this instead of [[printWhen()]] when the predicate requires the value of `underlying`. - * - * @param predicate predicate to evaluate to determine if the underlying value should be printed - * @return `this` - */ + * Use this instead of [[printWhen()]] when the predicate requires the value of `underlying`. + * + * @param predicate predicate to evaluate to determine if the underlying value should be printed + * @return `this` + */ def printIf(predicate: A => Boolean): A = underlying.tapIf(predicate)(println) /** Conditionally prints a value as a side effect. - * Use this instead of [[printWhen()]] when the predicate requires the value of `underlying`. - * - * @param predicate predicate to evaluate to determine if the value should be printed - * @param v side-effect function to call - * @tparam B the value type; ignored - * @return `this` - */ + * Use this instead of [[printWhen()]] when the predicate requires the value of `underlying`. + * + * @param predicate predicate to evaluate to determine if the value should be printed + * @param v side-effect function to call + * @tparam B the value type; ignored + * @return `this` + */ def printIf[B](predicate: A => Boolean, v: B): A = underlying.tapIf(predicate)((_: A) => println(v)) /** Conditionally taps and prints the object in a method chain. - * Use this instead of [[printIf()]] when the condition does not require the value of `underlying`. - * - * @param condition condition to evaluate to determine if the underlying value should be printed - * @return `underlying` - */ + * Use this instead of [[printIf()]] when the condition does not require the value of `underlying`. + * + * @param condition condition to evaluate to determine if the underlying value should be printed + * @return `underlying` + */ def printWhen(condition: => Boolean): A = underlying.tapWhen(condition)(println) /** Conditionally prints a value as a side effect. - * Use this instead of [[printIf()]] when the condition does not require the value of `underlying`. - * - * @param condition condition to evaluate to determine if the value should be printed - * @tparam B the value type; ignored - * @return `underlying` - */ + * Use this instead of [[printIf()]] when the condition does not require the value of `underlying`. 
+ * + * @param condition condition to evaluate to determine if the value should be printed + * @tparam B the value type; ignored + * @return `underlying` + */ def printWhen[B](condition: => Boolean, v: B): A = underlying.tapWhen(condition)((_: A) => println(v)) @@ -151,14 +151,14 @@ object AnyExtensions { implicit class ThrowOps[A](val underlying: A) extends AnyVal { /** Raises an exception if a predicate is satisfied. - * Use this instead of [[throwWhen()]] when the predicate requires the value of `underlying`. - * - * @param predicate predicate to evaluate to determine if the exception should be thrown - * @param e expression that will return an exception - * @tparam B the exception type - * @return `underlying` if the predicate evaluates to `false`. - * @throws B - */ + * Use this instead of [[throwWhen()]] when the predicate requires the value of `underlying`. + * + * @param predicate predicate to evaluate to determine if the exception should be thrown + * @param e expression that will return an exception + * @tparam B the exception type + * @return `underlying` if the predicate evaluates to `false`. + * @throws B + */ def throwIf[B <: Throwable](predicate: A => Boolean)(e: => B): A = { if (predicate(underlying)) throw e @@ -166,14 +166,14 @@ object AnyExtensions { } /** Raises an exception if a condition is satisfied. - * Use this instead of [[throwIf()]] when the condition does not require the value of `underlying`. - * - * @param condition condition to evaluate to determine if the exception should be thrown - * @param e expression that will return an exception - * @tparam B the exception type - * @return `underlying` if the predicate evaluates to `false`. - * @throws B - */ + * Use this instead of [[throwIf()]] when the condition does not require the value of `underlying`. + * + * @param condition condition to evaluate to determine if the exception should be thrown + * @param e expression that will return an exception + * @tparam B the exception type + * @return `underlying` if the predicate evaluates to `false`. 
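+     * For example (editor's illustration), `results.throwWhen(strict, new IllegalStateException("failed"))`
+     * returns `results` unchanged when `strict` is false and throws when it is true.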
+ * @throws B + */ def throwWhen[B <: Throwable](condition: => Boolean, e: => B): A = { if (condition) throw e diff --git a/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionsTest.scala b/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionsTest.scala index a22212e..8e31af9 100644 --- a/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionsTest.scala +++ b/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionsTest.scala @@ -1,8 +1,13 @@ package com.swoop.alchemy.spark.expressions.hll +import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus +import com.swoop.alchemy.spark.expressions.hll.Implementation.{AGKN, STRM} import com.swoop.alchemy.spark.expressions.hll.functions.{hll_init_collection, hll_init_collection_agg, _} import com.swoop.spark.test.HiveSqlSpec +import net.agkn.hll.HLL +import net.agkn.hll.HLLType.FULL import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.expressions.XXH64 import org.apache.spark.sql.functions.{array, col, lit, map} import org.apache.spark.sql.types._ import org.scalatest.{Matchers, WordSpec} @@ -20,7 +25,6 @@ object HLLFunctionsTestHelpers { case class Data2(c1: Array[String], c2: Map[String, String]) case class Data3(c1: String, c2: String, c3: String) - } class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec { @@ -29,10 +33,29 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec { lazy val spark = sqlc.sparkSession - "HyperLogLog functions" should { + "HyperLogLog functions" when { + "config key unset" should { + behave like hllImplementation(StreamLib, spark.conf.unset(IMPLEMENTATION_CONFIG_KEY)) + } + + "config key AGKN" should { + behave like hllImplementation(AgKn, spark.conf.set(IMPLEMENTATION_CONFIG_KEY, "AGKN")) + } + + "config key STRM" should { + behave like hllImplementation(StreamLib, spark.conf.set(IMPLEMENTATION_CONFIG_KEY, "STRM")) + } + } + + def hllImplementation(impl: Implementation, setup: => Unit) = { + "use the right implementation" in { + setup + hll_init(lit(null), 0.39).expr.asInstanceOf[HyperLogLogInitSimple].impl should be(impl) + } "not allow relativeSD > 39%" in { - val err = "requirement failed: HLL++ requires at least 4 bits for addressing. Use a lower error, at most 39%." + setup + val err = "requirement failed: HLL requires at least 4 bits for addressing. Use a lower error, at most 39%."
val c = lit(null) noException should be thrownBy hll_init(c, 0.39) @@ -46,10 +69,10 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec { the[IllegalArgumentException] thrownBy { hll_init_collection(c, 0.40) } should have message err - } "register native org.apache.spark.sql.ext.functions" in { + setup HLLFunctionRegistration.registerFunctions(spark) noException should be thrownBy spark.sql( @@ -63,13 +86,15 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec { | hll_cardinality(hll_init_agg(1, 0.05)), | hll_cardinality(hll_init_collection_agg(array(1,2,3), 0.05)), | hll_cardinality(hll_row_merge(hll_init(1),hll_init(1))), - | hll_intersect_cardinality(hll_init(1), hll_init(1)) - """.stripMargin + | hll_intersect_cardinality(hll_init(1), hll_init(1)), + | hll_cardinality(hll_convert(hll_init(1),"STRM","AGKN")) + """.stripMargin // last line will error if evaluated, but is valid under static analysis ) } - "estimate cardinality of simple types and collections" in { + setup + val a123 = array(lit(1), lit(2), lit(3)) val simpleValues = Seq( @@ -92,8 +117,11 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec { /* collections */ 0, 0, 0, 3 )) } + // @todo merge tests with grouping "estimate cardinality correctly" in { + setup + import spark.implicits._ val df = spark.createDataset[Data](Seq[Data]( @@ -131,7 +159,10 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec { 0 // 0 unique values across all arrays, nulls not counted )) } + "estimate multiples correctly" in { + setup + import spark.implicits._ val createSampleData = @@ -149,9 +180,26 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec { } } - "HyperLogLog aggregate functions" should { + + "HyperLogLog aggregate functions" when { + "config key unset" should { + behave like aggregateFunctions(spark.conf.unset(IMPLEMENTATION_CONFIG_KEY)) + } + + "config key AGKN" should { + behave like aggregateFunctions(spark.conf.set(IMPLEMENTATION_CONFIG_KEY, "AGKN")) + + } + + "config key STRM" should { + behave like aggregateFunctions(spark.conf.set(IMPLEMENTATION_CONFIG_KEY, "STRM")) + } + } + + def aggregateFunctions(setup: => Unit): Unit = { // @todo merge tests with grouping "estimate cardinality correctly" in { + setup import spark.implicits._ val df = spark.createDataset[Data](Seq[Data]( @@ -190,6 +238,7 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec { )) } "estimate multiples correctly" in { + setup import spark.implicits._ val createSampleData = @@ -205,6 +254,7 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec { results should be(Seq(4, 4)) } + } def merge(df: DataFrame): DataFrame = @@ -246,9 +296,24 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec { } } - "HyperLogLog intersection function" should { + "HyperLogLog intersection function" when { + "config key unset" should { + behave like intersectionFunction(spark.conf.unset(IMPLEMENTATION_CONFIG_KEY)) + } + + "config key AGKN" should { + behave like intersectionFunction(spark.conf.set(IMPLEMENTATION_CONFIG_KEY, "AGKN")) + } + + "config key STRM" should { + behave like intersectionFunction(spark.conf.set(IMPLEMENTATION_CONFIG_KEY, "STRM")) + } + } + + def intersectionFunction(setup: => Unit): Unit = { // @todo merge tests with grouping "estimate cardinality correctly" in { + setup import spark.implicits._ val df = spark.createDataset[Data3](Seq[Data3]( @@ -269,6 +334,7 @@ class HLLFunctionsTest extends
WordSpec with Matchers with HiveSqlSpec { } "handle nulls correctly" in { + setup import spark.implicits._ val df = spark.createDataset[Data3](Seq[Data3]( @@ -290,4 +356,55 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec { results should be((0, -1)) } } + + "Spark SQL functions" should { + "accept HLL implementation by name in signature" in { + HLLFunctionRegistration.registerFunctions(spark) + noException should be thrownBy spark.sql( + """select + | hll_cardinality(hll_merge(hll_init(1, 0.05, "AGKN"), "AGKN"), "AGKN"), + | hll_cardinality(hll_merge(hll_init_collection(array(1,2,3), 0.05, "STRM"), "STRM"), "STRM"), + | hll_cardinality(hll_init_agg(1, 0.05, "AGKN"), "AGKN"), + | hll_cardinality(hll_init_collection_agg(array(1,2,3), 0.05, "STRM"), "STRM"), + | hll_cardinality(hll_row_merge(hll_init(1, 0.05, "AGKN"),hll_init(1, 0.05, "AGKN"), "AGKN"), "AGKN"), + | hll_intersect_cardinality(hll_init(1, 0.05, "STRM"), hll_init(1, 0.05, "STRM"), "STRM") + """.stripMargin + ) + } + } + + "Conversion function" should { + "estimate similarly to the original" in { + + def randomize(callable: (Long) => Unit, n: Int): Unit = { + val rand = new scala.util.Random(42) + for (i <- 0 until n) { + callable(XXH64.hashInt(rand.nextInt(n), 0)) + } + } + + val p = 20 + + val strm = new HyperLogLogPlus(p, 0) + val agkn = new HLL(p, 5, 0, false, FULL) + + val n = 10000 + randomize(strm.offerHashed(_: Long), n) + randomize(agkn.addRaw, n) + + val converted = strmToAgkn(strm) + + converted.cardinality() should be(agkn.cardinality() +- 1) + } + } + + "error on unsupported conversion" in { + the[IllegalArgumentException] thrownBy { + import spark.implicits._ + + spark.emptyDataset[(Int)].toDF() + .withColumn("foo", hll_convert(hll_init(lit(1)), AGKN, STRM)) + .collect() + } should have message "HLL conversion is currently only supported from STREAM_LIB to AGGREGATE_KNOWLEDGE" + } } diff --git a/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/PostgresInteropTest.scala b/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/PostgresInteropTest.scala new file mode 100644 index 0000000..c0bed31 --- /dev/null +++ b/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/PostgresInteropTest.scala @@ -0,0 +1,104 @@ +package com.swoop.alchemy.spark.expressions.hll.agkn + +import java.sql.{DriverManager, ResultSet} + +import com.swoop.alchemy.spark.expressions.hll.IMPLEMENTATION_CONFIG_KEY +import com.swoop.alchemy.spark.expressions.hll.functions._ +import com.swoop.spark.test.HiveSqlSpec +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.scalatest.{Matchers, WordSpec} + +case class Postgres(user: String, database: String, port: Int) { + val con_str = s"jdbc:postgresql://localhost:${port}/${database}?user=${user}" + + def execute[T](query: String, handler: ResultSet => T): T = { + val conn = DriverManager.getConnection(con_str) + try { + val stm = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + handler(stm.executeQuery(query)) + } finally { + conn.close() + } + } + + def update(query: String): Unit = { + val conn = DriverManager.getConnection(con_str) + try { + val stm = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + stm.executeUpdate(query) + } finally { + conn.close() + } + } + + def sparkRead(schema: String, table: String)(implicit spark: SparkSession): DataFrame = + spark.read + .format("jdbc") + .option("url", s"jdbc:postgresql:${database}") + .option("dbtable",
s"${schema}.${table}") + .option("user", user) + .load() + + def sparkWrite(schema: String, table: String)(df: DataFrame) = + df.write + .format("jdbc") + .option("url", s"jdbc:postgresql:${database}") + .option("dbtable", s"${schema}.${table}") + .option("user", user) + .save() +} + +class PostgresInteropTest extends WordSpec with Matchers with HiveSqlSpec { + lazy val spark = sqlc.sparkSession + lazy val pg = Postgres("postgres", "postgres", 5432) + + "Postgres interop" should { + "calculate same results" in { + import spark.implicits._ + + // use Aggregate Knowledge (Postgres-compatible) HLL implementation + spark.conf.set(IMPLEMENTATION_CONFIG_KEY, "AGKN") + + // init Postgres extension for database + pg.update("CREATE EXTENSION IF NOT EXISTS hll;") + + + // create some random not-entirely distinct rows + val rand = new scala.util.Random(42) + val n = 100000 + val randomDF = sc.parallelize( + Seq.fill(n) { + (rand.nextInt(24), rand.nextInt(n)) + } + ).toDF("hour", "id").cache + + // create hll aggregates (by hour) + val byHourDF = randomDF.groupBy("hour").agg(hll_init_agg("id", .39).as("hll_id")).cache + + // send hlls to postgres + pg.update("DROP TABLE IF EXISTS spark_hlls CASCADE;") + pg.sparkWrite("public", "spark_hlls")(byHourDF) + + // convert hll column from `bytea` to `hll` type + pg.update( + """ + |ALTER TABLE spark_hlls + |ALTER COLUMN hll_id TYPE hll USING CAST (hll_id AS hll); + |""".stripMargin + ) + + // re-aggregate all hours in Spark + val distinctSpark = byHourDF.select(hll_cardinality(hll_merge(byHourDF("hll_id")))).as[Long].first() + // re-aggregate all hours in Postgres + val distinctPostgres = pg.execute( + "SELECT CAST (hll_cardinality(hll_union_agg(hll_id)) as Integer) AS approx FROM spark_hlls", + (rs) => { + rs.next; + rs.getInt("approx") + } + ) + + distinctSpark should be(distinctPostgres) + } + } +} diff --git a/build.sbt b/build.sbt index a4989fb..5d75283 100644 --- a/build.sbt +++ b/build.sbt @@ -23,7 +23,9 @@ lazy val alchemy = (project in file(".")) resourceDirectory in Test := baseDirectory.value / "alchemy/src/test/resources", libraryDependencies ++= Seq( scalaTest % Test withSources(), - "com.swoop" %% "spark-test-sugar" % "1.5.0" % Test withSources() + "com.swoop" %% "spark-test-sugar" % "1.5.0" % Test withSources(), + "net.agkn" % "hll" % "1.6.0" withSources(), + "org.postgresql" % "postgresql" % "42.2.8" % Test withSources() ), libraryDependencies ++= sparkDependencies, fork in Test := true // required for Spark diff --git a/codeStyleSettings.xml b/codeStyleSettings.xml new file mode 100644 index 0000000..3194b56 --- /dev/null +++ b/codeStyleSettings.xml @@ -0,0 +1,19 @@ + + + + + + diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..884251a --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,13 @@ +version: '3' + +volumes: + postgres-data: + driver: local + +services: + postgres: + image: swoopinc/postgres-hll:11 + ports: + - 5432:5432 + volumes: + - postgres-data:/var/lib/postgresql/data:cached