Skip to content

Commit

Permalink
Finish HLL changes
Browse files Browse the repository at this point in the history
Correct null handling for interesection cardinality
Guarantee non-negative intersection cardinality and clarify documentation
  • Loading branch information
pidge committed Jan 2, 2019
1 parent 14f294d commit 293bc91
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ case class HyperLogLogInitCollectionAgg(
/**
* HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm.
*
* This version merges the "sketches" into a combined binary composable representation.
* This version aggregates the "sketches" into a single merged "sketch" that represents the union of the constituents.
*
* @param child "sketch" to merge
*/
Expand Down Expand Up @@ -335,6 +335,7 @@ case class HyperLogLogMerge(
* HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm.
*
* This version merges multiple "sketches" in one row into a single field.
* @see HyperLogLogMerge
*
* @param children "sketch" row fields to merge
*/
Expand All @@ -348,13 +349,11 @@ case class HyperLogLogRowMerge(children: Seq[Expression])

require(children.nonEmpty, s"$prettyName requires at least one argument.")


/** The 1st child (separator) is str, and rest are either str or array of str. */
override def inputTypes: Seq[DataType] = Seq.fill(children.size)(BinaryType)

override def dataType: DataType = BinaryType

override def nullable: Boolean = children.exists(_.nullable)
override def nullable: Boolean = children.forall(_.nullable)

override def foldable: Boolean = children.forall(_.foldable)

Expand Down Expand Up @@ -405,32 +404,54 @@ case class HyperLogLogCardinality(override val child: Expression) extends UnaryE
/**
* HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm.
*
* Returns the estimated intersection cardinality of two HLL++ "sketches"
* Computes a merged (unioned) sketch and uses the fact that |A intersect B| = (|A| + |B|) - |A union B| to estimate
* the intersection cardinality of two HLL++ "sketches".
*
* @see HyperLogLogRowMerge
* @see HyperLogLogCardinality
*
* @param left HLL++ "sketch"
* @param right HLL++ "sketch"
*
* @return the estimated intersection cardinality (0 if one sketch is null, but null if both are)
*/
@ExpressionDescription(
usage =
"""
_FUNC_(sketch, sketch) - Returns the estimated intersection cardinality of the binary representations produced by HyperLogLog++.
_FUNC_(sketchL, sketchR) - Returns the estimated intersection cardinality of the binary representations produced by
HyperLogLog++. Computes a merged (unioned) sketch and uses the fact that |A intersect B| = (|A| + |B|) - |A union B|.
Returns null if both sketches are null, but 0 if only one is
""")
case class HyperLogLogIntersectionCardinality(override val left: Expression, override val right: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback {

override def inputTypes: Seq[DataType] = Seq(BinaryType, BinaryType)

override def dataType: DataType = LongType

override def nullSafeEval(input1: Any, input2: Any): Any = {
val leftHLL = HyperLogLogPlus.Builder.build(input1.asInstanceOf[Array[Byte]])
val rightHLL = HyperLogLogPlus.Builder.build(input2.asInstanceOf[Array[Byte]])
override def nullable: Boolean = left.nullable && right.nullable

override def eval(input: InternalRow): Any = {
val leftValue = left.eval(input)
val rightValue = right.eval(input)

if (leftValue != null && rightValue != null) {
val leftHLL = HyperLogLogPlus.Builder.build(leftValue.asInstanceOf[Array[Byte]])
val rightHLL = HyperLogLogPlus.Builder.build(rightValue.asInstanceOf[Array[Byte]])

val leftCount = leftHLL.cardinality()
val rightCount = rightHLL.cardinality()
leftHLL.addAll(rightHLL)
val unionCount = leftHLL.cardinality()
val leftCount = leftHLL.cardinality()
val rightCount = rightHLL.cardinality()
leftHLL.addAll(rightHLL)
val unionCount = leftHLL.cardinality()

(leftCount + rightCount) - unionCount
// guarantee a non-negative result despite the approximate nature of the counts
math.max((leftCount + rightCount) - unionCount, 0L)
} else {
if (leftValue != null || rightValue != null) {
0L
} else {
null
}
}
}

override def prettyName: String = "hll_intersect_cardinality"
Expand Down Expand Up @@ -506,6 +527,7 @@ trait HLLFunctions extends WithHelper {
def hll_row_merge(es: Column*): Column = withExpr {
HyperLogLogRowMerge(es.map(_.expr))
}

// split arguments to avoid collision w/ above after erasure
def hll_row_merge(columnName: String, columnNames: String*): Column =
hll_row_merge((columnName +: columnNames).map(col): _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec {

"HyperLogLog row merge function" should {
// @todo merge tests with grouping
"estimate cardinality correctly" in {
"estimate cardinality correctly, with nulls" in {
import spark.implicits._

val df = spark.createDataset[Data3](Seq[Data3](
Expand All @@ -237,7 +237,7 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec {
val results = df
.select(hll_init('c1).as('c1), hll_init('c2).as('c2), hll_init('c3).as('c3))
.select(hll_cardinality(hll_row_merge('c1, 'c2, 'c3)))
.na.fill(-1)
.na.fill(-1L)
.as[Long]
.head(5)
.toSeq
Expand All @@ -262,10 +262,32 @@ class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec {
val results = df
.select(hll_init_agg('c1).as('c1), hll_init_agg('c2).as('c2), hll_init_agg('c3).as('c3))
.select(hll_intersect_cardinality('c1, 'c2), hll_intersect_cardinality('c2, 'c3))
.as[(Long,Long)]
.as[(Long, Long)]
.head()

results should be((5, 0))
}

"handle nulls correctly" in {
import spark.implicits._

val df = spark.createDataset[Data3](Seq[Data3](
Data3("a", null, null),
Data3("b", null, null),
Data3("c", null, null),
Data3("d", null, null),
Data3("e", null, null)
))

val results = df
.select(hll_init_agg('c1).as('c1), hll_init_agg('c2).as('c2), hll_init_agg('c3).as('c3))
.select(hll_intersect_cardinality('c1, 'c2), hll_intersect_cardinality('c2, 'c3))
.na.fill(-1L)
.as[(Long, Long)]
.head()

println(results)
results should be((0, -1))
}
}
}

0 comments on commit 293bc91

Please sign in to comment.