Skip to content

Commit

Permalink
Add hll_row_merge and hll_intersect_cardinality
Browse files Browse the repository at this point in the history
Also does a bit of cleanup to existing HLL functions
  • Loading branch information
pidge committed Dec 17, 2018
1 parent e076baa commit 14f294d
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ object HLLFunctionRegistration extends NativeFunctionRegistration {
expression[HyperLogLogInitSimpleAgg]("hll_init_agg"),
expression[HyperLogLogInitCollectionAgg]("hll_init_collection_agg"),
expression[HyperLogLogMerge]("hll_merge"),
expression[HyperLogLogCardinality]("hll_cardinality")
expression[HyperLogLogRowMerge]("hll_row_merge"),
expression[HyperLogLogCardinality]("hll_cardinality"),
expression[HyperLogLogIntersectionCardinality]("hll_intersect_cardinality")
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.aggregate.{HyperLogLogPlusPlus, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, ExpressionDescription, UnaryExpression}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -331,41 +331,109 @@ case class HyperLogLogMerge(
override def prettyName: String = "hll_merge"
}

/**
* HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm.
*
* This version merges multiple "sketches" in one row into a single field.
*
* @param children "sketch" row fields to merge
*/
@ExpressionDescription(
usage =
"""
_FUNC_(expr) - Returns the merged HLL++ sketch.
""")
case class HyperLogLogRowMerge(children: Seq[Expression])
extends Expression with ExpectsInputTypes with CodegenFallback {

require(children.nonEmpty, s"$prettyName requires at least one argument.")


/** The 1st child (separator) is str, and rest are either str or array of str. */
override def inputTypes: Seq[DataType] = Seq.fill(children.size)(BinaryType)

override def dataType: DataType = BinaryType

override def nullable: Boolean = children.exists(_.nullable)

override def foldable: Boolean = children.forall(_.foldable)

override def eval(input: InternalRow): Any = {
val flatInputs = children.flatMap(_.eval(input) match {
case null => None
case b: Array[Byte] => Some(HyperLogLogPlus.Builder.build(b))
case _ => throw new IllegalStateException(s"$prettyName only supports Array[Byte]")
})

if (flatInputs.isEmpty) null
else {
val acc = flatInputs.head
flatInputs.tail.foreach(acc.addAll)
acc.getBytes
}
}

override def prettyName: String = "hll_row_merge"
}

/**
* HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm.
*
* Returns the estimated cardinality of an HLL++ "sketch"
*
* @param child HLL+ "sketch"
* @param child HLL++ "sketch"
*/
@ExpressionDescription(
usage =
"""
_FUNC_(expr) - Returns the estimated cardinality of the binary representation produced by HyperLogLog++.
_FUNC_(sketch) - Returns the estimated cardinality of the binary representation produced by HyperLogLog++.
""")
case class HyperLogLogCardinality(override val child: Expression) extends UnaryExpression with ExpectsInputTypes with CodegenFallback {

override def inputTypes: Seq[DataType] = Seq(BinaryType)

override def dataType: DataType = LongType

override def nullable: Boolean = child.nullable

override def checkInputDataTypes(): TypeCheckResult = {
child.dataType match {
case BinaryType => TypeCheckResult.TypeCheckSuccess
case _ => TypeCheckResult.TypeCheckFailure(s"$prettyName only supports binary input")
}
}

override def nullSafeEval(input: Any): Long = {
val data = input.asInstanceOf[Array[Byte]]
HyperLogLogPlus.Builder.build(data).cardinality()
}

override def prettyName: String = "hll_cardinality"
}

/**
* HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm.
*
* Returns the estimated intersection cardinality of two HLL++ "sketches"
*
* @param left HLL++ "sketch"
* @param right HLL++ "sketch"
*/
@ExpressionDescription(
usage =
"""
_FUNC_(sketch, sketch) - Returns the estimated intersection cardinality of the binary representations produced by HyperLogLog++.
""")
case class HyperLogLogIntersectionCardinality(override val left: Expression, override val right: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback {

override def inputTypes: Seq[DataType] = Seq(BinaryType, BinaryType)

override def dataType: DataType = LongType

override def nullSafeEval(input1: Any, input2: Any): Any = {
val leftHLL = HyperLogLogPlus.Builder.build(input1.asInstanceOf[Array[Byte]])
val rightHLL = HyperLogLogPlus.Builder.build(input2.asInstanceOf[Array[Byte]])

val leftCount = leftHLL.cardinality()
val rightCount = rightHLL.cardinality()
leftHLL.addAll(rightHLL)
val unionCount = leftHLL.cardinality()

(leftCount + rightCount) - unionCount
}

override def prettyName: String = "hll_intersect_cardinality"
}

object functions extends HLLFunctions
Expand Down Expand Up @@ -435,10 +503,24 @@ trait HLLFunctions extends WithHelper {
def hll_merge(columnName: String): Column =
hll_merge(col(columnName))

def hll_row_merge(es: Column*): Column = withExpr {
HyperLogLogRowMerge(es.map(_.expr))
}
// split arguments to avoid collision w/ above after erasure
def hll_row_merge(columnName: String, columnNames: String*): Column =
hll_row_merge((columnName +: columnNames).map(col): _*)

def hll_cardinality(e: Column): Column = withExpr {
HyperLogLogCardinality(e.expr)
}

def hll_cardinality(columnName: String): Column =
hll_cardinality(col(columnName))

def hll_intersect_cardinality(l: Column, r: Column): Column = withExpr {
HyperLogLogIntersectionCardinality(l.expr, r.expr)
}

def hll_intersect_cardinality(leftColumnName: String, rightColumnName: String): Column =
hll_intersect_cardinality(col(leftColumnName), col(rightColumnName))
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ object HLLFunctionsTestHelpers {

case class Data2(c1: Array[String], c2: Map[String, String])

case class Data3(c1: String, c2: String, c3: String)

}

class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec {
Expand Down Expand Up @@ -52,14 +54,16 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec {

noException should be thrownBy spark.sql(
"""select
| hll_cardinality(hll_merge(hll_init(array(1,2,3)))),
| hll_cardinality(hll_merge(hll_init(1))),
| hll_cardinality(hll_merge(hll_init_collection(array(1,2,3)))),
| hll_cardinality(hll_init_agg(array(1,2,3))),
| hll_cardinality(hll_init_agg(1)),
| hll_cardinality(hll_init_collection_agg(array(1,2,3))),
| hll_cardinality(hll_merge(hll_init(array(1,2,3), 0.05))),
| hll_cardinality(hll_merge(hll_init(1, 0.05))),
| hll_cardinality(hll_merge(hll_init_collection(array(1,2,3), 0.05))),
| hll_cardinality(hll_init_agg(array(1,2,3), 0.05)),
| hll_cardinality(hll_init_collection_agg(array(1,2,3), 0.05))
| hll_cardinality(hll_init_agg(1, 0.05)),
| hll_cardinality(hll_init_collection_agg(array(1,2,3), 0.05)),
| hll_cardinality(hll_row_merge(hll_init(1),hll_init(1))),
| hll_intersect_cardinality(hll_init(1), hll_init(1))
""".stripMargin
)
}
Expand Down Expand Up @@ -216,4 +220,52 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec {
hll_cardinality(col(name)).as(s"c$idx")
}: _*
).head.toSeq.map(_.asInstanceOf[Long])

"HyperLogLog row merge function" should {
// @todo merge tests with grouping
"estimate cardinality correctly" in {
import spark.implicits._

val df = spark.createDataset[Data3](Seq[Data3](
Data3("a", "a", "a"),
Data3("a", "b", "c"),
Data3("a", "b", null),
Data3("a", null, null),
Data3(null, null, null)
))

val results = df
.select(hll_init('c1).as('c1), hll_init('c2).as('c2), hll_init('c3).as('c3))
.select(hll_cardinality(hll_row_merge('c1, 'c2, 'c3)))
.na.fill(-1)
.as[Long]
.head(5)
.toSeq

results should be(Seq(1, 3, 2, 1, -1)) // nulls skipped
}
}

"HyperLogLog intersection function" should {
// @todo merge tests with grouping
"estimate cardinality correctly" in {
import spark.implicits._

val df = spark.createDataset[Data3](Seq[Data3](
Data3("a", "e", "f"),
Data3("b", "d", "g"),
Data3("c", "c", "h"),
Data3("d", "b", "i"),
Data3("e", "a", "j")
))

val results = df
.select(hll_init_agg('c1).as('c1), hll_init_agg('c2).as('c2), hll_init_agg('c3).as('c3))
.select(hll_intersect_cardinality('c1, 'c2), hll_intersect_cardinality('c2, 'c3))
.as[(Long,Long)]
.head()

results should be((5, 0))
}
}
}

0 comments on commit 14f294d

Please sign in to comment.