Adds HyperLogLog++ functions for Spark

Includes native function registration and a hashing strategy that handles all Spark data types.
Showing 12 changed files with 1,014 additions and 8 deletions.
@@ -1 +1 @@
-0.1.0-SNAPSHOT
+0.2.0-SNAPSHOT
alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/FunctionRegistration.scala (7 additions, 0 deletions)
@@ -0,0 +1,7 @@

package com.swoop.alchemy.spark.expressions

import org.apache.spark.sql.SparkSession

trait FunctionRegistration {
  def registerFunctions(spark: SparkSession): Unit
}
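
For context, a registrar that does not need native catalyst expressions could implement this trait directly with plain UDFs. A minimal sketch (the object and UDF below are illustrative, not part of this commit):

import org.apache.spark.sql.SparkSession

// Hypothetical implementor: registers an ordinary UDF instead of a native expression.
object ExampleUdfRegistration extends FunctionRegistration {
  def registerFunctions(spark: SparkSession): Unit =
    spark.udf.register("plus_one", (x: Long) => x + 1)
}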
alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/NativeFunctionRegistration.scala (85 additions, 0 deletions)
@@ -0,0 +1,85 @@

package com.swoop.alchemy.spark.expressions

import org.apache.spark.sql.EncapsulationViolator.createAnalysisException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpressionInfo, RuntimeReplaceable}

import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

// based on Spark's FunctionRegistry @ossSpark
trait NativeFunctionRegistration extends FunctionRegistration {

  type FunctionBuilder = Seq[Expression] => Expression

  def expressions: Map[String, (ExpressionInfo, FunctionBuilder)]

  def registerFunctions(fr: FunctionRegistry): Unit = {
    expressions.foreach { case (name, (info, builder)) =>
      fr.registerFunction(FunctionIdentifier(name), info, builder)
    }
  }

  def registerFunctions(spark: SparkSession): Unit = {
    registerFunctions(spark.sessionState.functionRegistry)
  }

  /** Creates the registry entry (name -> (info, builder)) for the expression type T. */
  protected def expression[T <: Expression](name: String)
    (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {

    // For `RuntimeReplaceable`, skip the constructor with the most arguments, which is the main
    // constructor; it contains the non-parameter `child` and should not be used as a function builder.
    val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(tag.runtimeClass)) {
      val all = tag.runtimeClass.getConstructors
      val maxNumArgs = all.map(_.getParameterCount).max
      all.filterNot(_.getParameterCount == maxNumArgs)
    } else {
      tag.runtimeClass.getConstructors
    }
    // See if we can find a constructor that accepts Seq[Expression]
    val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]]))
    val builder = (expressions: Seq[Expression]) => {
      if (varargCtor.isDefined) {
        // If there is a constructor that accepts Seq[Expression], use it.
        Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match {
          case Success(e) => e
          case Failure(e) =>
            // The exception is an invocation exception; to get a meaningful message, we need the cause.
            throw createAnalysisException(e.getCause.getMessage)
        }
      } else {
        // Otherwise, find a constructor that matches the number of arguments, and use it.
        val params = Seq.fill(expressions.size)(classOf[Expression])
        val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
          throw createAnalysisException(s"Invalid number of arguments for function $name")
        }
        Try(f.newInstance(expressions: _*).asInstanceOf[Expression]) match {
          case Success(e) => e
          case Failure(e) =>
            // The exception is an invocation exception; to get a meaningful message, we need the cause.
            throw createAnalysisException(e.getCause.getMessage)
        }
      }
    }

    (name, (expressionInfo[T](name), builder))
  }

  /**
   * Creates an [[ExpressionInfo]] for the function as defined by expression T using the given name.
   */
  protected def expressionInfo[T <: Expression : ClassTag](name: String): ExpressionInfo = {
    val clazz = scala.reflect.classTag[T].runtimeClass
    val df = clazz.getAnnotation(classOf[ExpressionDescription])
    if (df != null) {
      new ExpressionInfo(clazz.getCanonicalName, null, name, df.usage(), df.extended())
    } else {
      new ExpressionInfo(clazz.getCanonicalName, name)
    }
  }

}
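
As a usage sketch, a downstream registrar only needs to supply the expressions map; everything else comes from the trait. Here MyReverse stands in for any Expression subclass, ideally annotated with ExpressionDescription (both the object and the expression class are hypothetical):

import org.apache.spark.sql.catalyst.expressions.ExpressionInfo

// Hypothetical registrar for a custom native expression.
object MyFunctionRegistration extends NativeFunctionRegistration {
  val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
    expression[MyReverse]("my_reverse") // MyReverse is an illustrative Expression subclass
  )
}

// After MyFunctionRegistration.registerFunctions(spark),
// SELECT my_reverse(...) is usable in Spark SQL.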
alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/WithHelper.scala (15 additions, 0 deletions)
@@ -0,0 +1,15 @@

package com.swoop.alchemy.spark.expressions

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction

trait WithHelper {
  def withExpr(expr: Expression): Column = new Column(expr)

  def withAggregateFunction(
    func: AggregateFunction,
    isDistinct: Boolean = false): Column = {
    new Column(func.toAggregateExpression(isDistinct))
  }
}
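
This mirrors the private withExpr/withAggregateFunction combinators inside Spark's own sql.functions, which are inaccessible outside the org.apache.spark.sql package. A sketch of how the HLL functions in this commit can use it (the constructor shape of the aggregate is assumed here for illustration):

import org.apache.spark.sql.Column

// Sketch: expose a catalyst AggregateFunction as a user-facing Column.
object HllColumns extends WithHelper {
  def hll_init_agg(col: Column): Column =
    withAggregateFunction(HyperLogLogInitSimpleAgg(col.expr)) // constructor shape assumed
}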
alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/BoundHLL.scala (52 additions, 0 deletions)
@@ -0,0 +1,52 @@

package com.swoop.alchemy.spark.expressions.hll

import org.apache.spark.sql
import org.apache.spark.sql.Column

/** Convenience trait to use HyperLogLog functions with the same error consistently.
 * Spark's own [[sql.functions.approx_count_distinct()]] as well as the granular HLL
 * [[HLLFunctions.hll_init()]] and [[HLLFunctions.hll_init_collection()]] will be
 * automatically parameterized by [[BoundHLL.hllError]].
 */
trait BoundHLL extends Serializable {

  def hllError: Double

  def approx_count_distinct(col: Column): Column =
    sql.functions.approx_count_distinct(col, hllError)

  def approx_count_distinct(colName: String): Column =
    sql.functions.approx_count_distinct(colName, hllError)

  def hll_init(col: Column): Column =
    functions.hll_init(col, hllError)

  def hll_init(columnName: String): Column =
    functions.hll_init(columnName, hllError)

  def hll_init_collection(col: Column): Column =
    functions.hll_init_collection(col, hllError)

  def hll_init_collection(columnName: String): Column =
    functions.hll_init_collection(columnName, hllError)

  def hll_init_agg(col: Column): Column =
    functions.hll_init_agg(col, hllError)

  def hll_init_agg(columnName: String): Column =
    functions.hll_init_agg(columnName, hllError)

  def hll_init_collection_agg(col: Column): Column =
    functions.hll_init_collection_agg(col, hllError)

  def hll_init_collection_agg(columnName: String): Column =
    functions.hll_init_collection_agg(columnName, hllError)

}

object BoundHLL {
  def apply(error: Double): BoundHLL = new BoundHLL {
    def hllError: Double = error
  }
}
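
A usage sketch: binding the error once keeps every call site consistent (df and the column name are illustrative):

val hll = BoundHLL(0.02) // target 2% error everywhere
import hll._

df.select(hll_init("user_id"))              // granular HLL sketch at 2% error
df.select(approx_count_distinct("user_id")) // Spark's estimator at the same error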
alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/CardinalityHashFunction.scala (47 additions, 0 deletions)
@@ -0,0 +1,47 @@

package com.swoop.alchemy.spark.expressions.hll

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{InterpretedHashFunction, XXH64}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
 * Hash function for Spark data values that is suitable for cardinality counting. Unlike Spark's
 * built-in hashing, it differentiates between different data types and accounts for nulls.
 */
abstract class CardinalityHashFunction extends InterpretedHashFunction {

  override def hash(value: Any, dataType: DataType, seed: Long): Long = {

    def hashWithTag(typeTag: Long) =
      super.hash(value, dataType, hashLong(typeTag, seed))

    value match {
      // change null handling to differentiate between things like Array.empty and Array(null)
      case null => hashLong(seed, seed)
      // add type tags to differentiate between values on their own or in complex types
      case _: Array[Byte] => hashWithTag(-3698894927619418744L)
      case _: UTF8String => hashWithTag(-8468821688391060513L)
      case _: ArrayData => hashWithTag(-1666055126678331734L)
      case _: MapData => hashWithTag(5587693012926141532L)
      case _: InternalRow => hashWithTag(-891294170547231607L)
      // pass through everything else (simple types)
      case _ => super.hash(value, dataType, seed)
    }
  }

}

object CardinalityXxHash64Function extends CardinalityHashFunction {

  override protected def hashInt(i: Int, seed: Long): Long = XXH64.hashInt(i, seed)

  override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed)

  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
    XXH64.hashUnsafeBytes(base, offset, len, seed)
  }

}
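
The type tags exist because values with identical byte representations must hash differently when their types differ; otherwise a binary column and a string column holding the same bytes would produce colliding sketches. An illustrative check (uses Spark-internal types, so this is a sketch rather than application code):

import org.apache.spark.sql.types.{BinaryType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Same bytes, different types: the type tags keep the hashes apart.
val asString = CardinalityXxHash64Function.hash(UTF8String.fromString("ab"), StringType, 42L)
val asBytes  = CardinalityXxHash64Function.hash("ab".getBytes("UTF-8"), BinaryType, 42L)
assert(asString != asBytes)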
alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionRegistration.scala (17 additions, 0 deletions)
@@ -0,0 +1,17 @@

package com.swoop.alchemy.spark.expressions.hll

import com.swoop.alchemy.spark.expressions.NativeFunctionRegistration
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo

object HLLFunctionRegistration extends NativeFunctionRegistration {

  val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
    expression[HyperLogLogInitSimple]("hll_init"),
    expression[HyperLogLogInitCollection]("hll_init_collection"),
    expression[HyperLogLogInitSimpleAgg]("hll_init_agg"),
    expression[HyperLogLogInitCollectionAgg]("hll_init_collection_agg"),
    expression[HyperLogLogMerge]("hll_merge"),
    expression[HyperLogLogCardinality]("hll_cardinality")
  )

}
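
Once registered, the functions are callable from SQL as well as the DataFrame API; a usage sketch (the events table and user_id column are illustrative):

HLLFunctionRegistration.registerFunctions(spark)

spark.sql("""
  select hll_cardinality(hll_merge(hll_init(user_id))) as approx_distinct_users
  from events
""")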