diff --git a/core/api/core.api b/core/api/core.api index c788698ff0..74f31e3dbd 100644 --- a/core/api/core.api +++ b/core/api/core.api @@ -5756,7 +5756,7 @@ public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/Aggregations } public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator : org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler, org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler, org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler { - public fun (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler;Ljava/lang/String;)V + public fun (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler;Ljava/lang/String;Ljava/util/Map;)V public fun aggregateMultipleColumns (Lkotlin/sequences/Sequence;)Ljava/lang/Object; public fun aggregateSequence (Lkotlin/sequences/Sequence;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType;)Ljava/lang/Object; public fun aggregateSingleColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Object; @@ -5769,6 +5769,7 @@ public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ public final fun getInputHandler ()Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler; public final fun getMultipleColumnsHandler ()Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler; public final fun getName ()Ljava/lang/String; + public final fun getStatisticsParameters ()Ljava/util/Map; public fun indexOfAggregationResultSingleSequence (Lkotlin/sequences/Sequence;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType;)I public fun init (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;)V public fun preprocessAggregation (Lkotlin/sequences/Sequence;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType;)Lkotlin/Pair; diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt index a7cb2a8b14..4746df815b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt @@ -42,6 +42,7 @@ public class Aggregator( public val inputHandler: AggregatorInputHandler, public val multipleColumnsHandler: AggregatorMultipleColumnsHandler, public val name: String, + public val statisticsParameters: Map, ) : AggregatorInputHandler by inputHandler, AggregatorMultipleColumnsHandler by multipleColumnsHandler, AggregatorAggregationHandler by aggregationHandler { @@ -75,6 +76,7 @@ public class Aggregator( aggregationHandler: AggregatorAggregationHandler, inputHandler: AggregatorInputHandler, multipleColumnsHandler: AggregatorMultipleColumnsHandler, + statisticsParameters: Map, ): AggregatorProvider = AggregatorProvider { name -> Aggregator( @@ -82,6 +84,7 @@ public class Aggregator( inputHandler = inputHandler, multipleColumnsHandler = multipleColumnsHandler, name = name, + statisticsParameters = statisticsParameters, ) } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt index 7b1b0357eb..d838b1b061 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt @@ -3,6 +3,8 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators import org.jetbrains.kotlinx.dataframe.DataColumn import org.jetbrains.kotlinx.dataframe.api.asSequence import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler +import org.jetbrains.kotlinx.dataframe.impl.columns.StatisticResult +import org.jetbrains.kotlinx.dataframe.impl.columns.ValueColumnInternal import kotlin.reflect.KType /** @@ -26,13 +28,32 @@ public interface AggregatorAggregationHandler /** * Aggregates the data in the given column and computes a single resulting value. - * Calls [aggregateSequence]. + * Calls [aggregateSequence]. It tries to exploit a cache for statistics which is proper of + * [ValueColumnInternal] */ - public fun aggregateSingleColumn(column: DataColumn): Return = - aggregateSequence( + public fun aggregateSingleColumn(column: DataColumn): Return { + if (column is ValueColumnInternal<*>) { + // cache check, cache is dynamically created + val aggregator = this.aggregator ?: throw IllegalStateException("Aggregator is required") + val statisticName = aggregator.name + val parameters = aggregator.statisticsParameters + val desiredStatistic = column.getStatisticCacheOrNull(statisticName, parameters) + // if desiredStatistic is null, statistic was never calculated. + if (desiredStatistic != null) { + return desiredStatistic.value as Return + } + val statisticValue = aggregateSequence( + values = column.asSequence(), + valueType = column.type().toValueType(), + ) + column.putStatisticCache(statisticName, parameters, StatisticResult(statisticValue)) + return aggregateSingleColumn(column) + } + return aggregateSequence( values = column.asSequence(), valueType = column.type().toValueType(), ) + } /** * Function that can give the return type of [aggregateSequence] as [KType], given the type of the input. diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt index 9648fed3ad..47b0f79eb7 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt @@ -35,20 +35,24 @@ public object Aggregators { getReturnType: CalculateReturnType, indexOfResult: IndexOfResult, stepOneSelector: Selector, + statisticsParameters: Map, ) = Aggregator( aggregationHandler = SelectingAggregationHandler(stepOneSelector, indexOfResult, getReturnType), inputHandler = AnyInputHandler(), multipleColumnsHandler = TwoStepMultipleColumnsHandler(), + statisticsParameters = statisticsParameters, ) private fun flattenHybridForAny( getReturnType: CalculateReturnType, indexOfResult: IndexOfResult, reducer: Reducer, + statisticsParameters: Map, ) = Aggregator( aggregationHandler = HybridAggregationHandler(reducer, indexOfResult, getReturnType), inputHandler = AnyInputHandler(), multipleColumnsHandler = FlatteningMultipleColumnsHandler(), + statisticsParameters = statisticsParameters, ) private fun twoStepReducingForAny( @@ -63,6 +67,7 @@ public object Aggregators { ReducingAggregationHandler(stepTwoReducer, getReturnType) }, ), + statisticsParameters = emptyMap(), ) private fun flattenReducingForAny(reducer: Reducer) = @@ -70,6 +75,7 @@ public object Aggregators { aggregationHandler = ReducingAggregationHandler(reducer, preserveReturnTypeNullIfEmpty), inputHandler = AnyInputHandler(), multipleColumnsHandler = FlatteningMultipleColumnsHandler(), + statisticsParameters = emptyMap(), ) private fun flattenReducingForAny( @@ -79,24 +85,29 @@ public object Aggregators { aggregationHandler = ReducingAggregationHandler(reducer, getReturnType), inputHandler = AnyInputHandler(), multipleColumnsHandler = FlatteningMultipleColumnsHandler(), + statisticsParameters = emptyMap(), ) private fun flattenReducingForNumbers( getReturnType: CalculateReturnType, + statisticsParameters: Map, reducer: Reducer, ) = Aggregator( aggregationHandler = ReducingAggregationHandler(reducer, getReturnType), inputHandler = NumberInputHandler(), multipleColumnsHandler = FlatteningMultipleColumnsHandler(), + statisticsParameters = statisticsParameters, ) private fun twoStepReducingForNumbers( getReturnType: CalculateReturnType, + statisticsParameters: Map, reducer: Reducer, ) = Aggregator( aggregationHandler = ReducingAggregationHandler(reducer, getReturnType), inputHandler = NumberInputHandler(), multipleColumnsHandler = TwoStepMultipleColumnsHandler(), + statisticsParameters = statisticsParameters, ) /** @include [AggregatorOptionSwitch1] */ @@ -117,8 +128,9 @@ public object Aggregators { by withOneOption { skipNaN: Boolean -> twoStepSelectingForAny, Comparable?>( getReturnType = minTypeConversion, - stepOneSelector = { type -> minOrNull(type, skipNaN) }, indexOfResult = { type -> indexOfMin(type, skipNaN) }, + stepOneSelector = { type -> minOrNull(type, skipNaN) }, + statisticsParameters = mapOf("skipNaN" to skipNaN), ) } @@ -132,6 +144,7 @@ public object Aggregators { getReturnType = maxTypeConversion, stepOneSelector = { type -> maxOrNull(type, skipNaN) }, indexOfResult = { type -> indexOfMax(type, skipNaN) }, + statisticsParameters = mapOf("skipNaN" to skipNaN), ) } @@ -140,17 +153,30 @@ public object Aggregators { skipNaN: Boolean, ddof: Int, -> - flattenReducingForNumbers(stdTypeConversion) { type -> - std(type, skipNaN, ddof) - } + flattenReducingForNumbers( + getReturnType = stdTypeConversion, + statisticsParameters = mapOf( + ("skipNaN" to skipNaN), + ("ddof" to ddof), + ), + reducer = { type -> + std(type, skipNaN, ddof) + }, + ) } // step one: T: Number? -> Double // step two: Double -> Double public val mean: AggregatorOptionSwitch1 by withOneOption { skipNaN: Boolean -> - twoStepReducingForNumbers(meanTypeConversion) { type -> - mean(type, skipNaN) - } + twoStepReducingForNumbers( + getReturnType = meanTypeConversion, + statisticsParameters = mapOf( + ("skipNaN" to skipNaN), + ), + reducer = { type -> + mean(type, skipNaN) + }, + ) } // T: primitive Number? -> Double? @@ -187,6 +213,10 @@ public object Aggregators { getReturnType = percentileConversion, reducer = { type -> percentileOrNull(percentile, type, skipNaN) as Comparable? }, indexOfResult = { type -> indexOfPercentile(percentile, type, skipNaN) }, + statisticsParameters = mapOf( + ("skipNaN" to skipNaN), + ("percentile" to percentile), + ), ) } @@ -215,6 +245,7 @@ public object Aggregators { getReturnType = medianConversion, reducer = { type -> medianOrNull(type, skipNaN) as Comparable? }, indexOfResult = { type -> indexOfMedian(type, skipNaN) }, + statisticsParameters = mapOf("skipNaN" to skipNaN), ) } @@ -223,8 +254,12 @@ public object Aggregators { // Short -> Int // Nothing -> Double public val sum: AggregatorOptionSwitch1 by withOneOption { skipNaN: Boolean -> - twoStepReducingForNumbers(sumTypeConversion) { type -> - sum(type, skipNaN) - } + twoStepReducingForNumbers( + getReturnType = sumTypeConversion, + statisticsParameters = mapOf("skipNaN" to skipNaN), + reducer = { type -> + sum(type, skipNaN) + }, + ) } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt index c24a3e34a8..bb8fa1356d 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt @@ -47,6 +47,7 @@ internal class TwoStepMultipleColumnsHandler( ?: aggregator as AggregatorAggregationHandler, inputHandler = stepTwoInputHandler ?: aggregator as AggregatorInputHandler, multipleColumnsHandler = NoMultipleColumnsHandler(), + statisticsParameters = emptyMap(), ).create(aggregator!!.name) } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/ValueColumnImpl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/ValueColumnImpl.kt index f758360d1f..947b082552 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/ValueColumnImpl.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/ValueColumnImpl.kt @@ -8,6 +8,15 @@ import org.jetbrains.kotlinx.dataframe.columns.ValueColumn import kotlin.reflect.KType import kotlin.reflect.full.withNullability +@JvmInline +internal value class StatisticResult(val value: Any?) + +internal interface ValueColumnInternal : ValueColumn { + fun putStatisticCache(statName: String, arguments: Map, value: StatisticResult) + + fun getStatisticCacheOrNull(statName: String, arguments: Map): StatisticResult? +} + internal open class ValueColumnImpl( values: List, name: String, @@ -15,7 +24,8 @@ internal open class ValueColumnImpl( val defaultValue: T? = null, distinct: Lazy>? = null, ) : DataColumnImpl(values, name, type, distinct), - ValueColumn { + ValueColumn, + ValueColumnInternal { override fun distinct() = ValueColumnImpl(toSet().toList(), name, type, defaultValue, distinct) @@ -48,10 +58,22 @@ internal open class ValueColumnImpl( override fun defaultValue() = defaultValue override fun forceResolve() = ResolvingValueColumn(this) + + private val statisticsCache = mutableMapOf, StatisticResult>>() + + override fun putStatisticCache(statName: String, arguments: Map, value: StatisticResult) { + statisticsCache.getOrPut(statName) { + mutableMapOf, StatisticResult>() + }[arguments] = value + } + + override fun getStatisticCacheOrNull(statName: String, arguments: Map): StatisticResult? = + statisticsCache[statName]?.get(arguments) } internal class ResolvingValueColumn(override val source: ValueColumn) : ValueColumn by source, + ValueColumnInternal, ForceResolvedColumn { override fun resolve(context: ColumnResolutionContext) = super.resolve(context) @@ -70,4 +92,15 @@ internal class ResolvingValueColumn(override val source: ValueColumn) : override fun equals(other: Any?) = source.checkEquals(other) override fun hashCode(): Int = source.hashCode() + + private val statisticsCache = mutableMapOf, StatisticResult>>() + + override fun putStatisticCache(statName: String, arguments: Map, value: StatisticResult) { + statisticsCache.getOrPut(statName) { + mutableMapOf, StatisticResult>() + }[arguments] = value + } + + override fun getStatisticCacheOrNull(statName: String, arguments: Map): StatisticResult? = + statisticsCache[statName]?.get(arguments) }