forked from twitter/the-algorithm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathOfflineAggregateInjections.scala
46 lines (42 loc) · 1.87 KB
/
OfflineAggregateInjections.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
package com.twitter.timelines.data_processing.ml_util.aggregation_framework
import com.twitter.dal.personal_data.thriftscala.PersonalDataType
import com.twitter.ml.api.DataRecord
import com.twitter.ml.api.Feature
import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection
import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection.Batched
import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection.JavaCompactThrift
import com.twitter.scalding_internal.multiformat.format.keyval.KeyValInjection.genericInjection
import com.twitter.summingbird.batch.BatchID
import scala.collection.JavaConverters._
/**
 * Key/value injections for offline aggregate data, keyed by [[AggregationKey]] with
 * `(BatchID, DataRecord)` values.
 */
object OfflineAggregateInjections {

  /**
   * Injection built from `genericInjection(AggregationKeyInjection)` for the key and
   * `Batched(JavaCompactThrift[DataRecord])` for the value, with no personal data types
   * supplied.
   */
  val offlineDataRecordAggregateInjection: KeyValInjection[AggregationKey, (BatchID, DataRecord)] =
    KeyValInjection(
      genericInjection(AggregationKeyInjection),
      Batched(JavaCompactThrift[DataRecord])
    )

  /**
   * Scala personal data types declared on a single feature.
   *
   * Each Java PDT is looked up through `PersonalDataType.get(javaPdt.getValue)`; entries
   * for which that lookup returns `None` are dropped.
   */
  private def scalaPdtsOf(feature: Feature[_]): Iterable[PersonalDataType] =
    feature.getPersonalDataTypes
      .asSet()
      .asScala
      .flatMap(javaPdts => javaPdts.asScala)
      .flatMap(javaPdt => PersonalDataType.get(javaPdt.getValue))

  /**
   * Gathers the distinct personal data types of every feature that `featureExtractor`
   * yields for the given groups.
   *
   * @param aggregateGroups  groups to inspect; duplicates are collapsed before extraction
   * @param featureExtractor returns the features of interest for one group
   * @tparam T               type of the aggregate groups
   * @return `Some` with the collected set when at least one PDT was found, otherwise `None`
   */
  private[aggregation_framework] def getPdts[T](
    aggregateGroups: Iterable[T],
    featureExtractor: T => Iterable[Feature[_]]
  ): Option[Set[PersonalDataType]] = {
    val collected: Set[PersonalDataType] =
      aggregateGroups.toSet[T].flatMap { group =>
        featureExtractor(group).flatMap(feature => scalaPdtsOf(feature))
      }

    // An empty result is reported as None rather than Some(Set.empty).
    if (collected.isEmpty) None else Some(collected)
  }

  /**
   * Builds the key/value injection for the given aggregate groups, passing along the
   * personal data types collected from their output keys (key side) and output features
   * (value side).
   *
   * @param aggregateGroups groups whose output keys and features are scanned for PDTs
   * @return injection from [[AggregationKey]] to `(BatchID, DataRecord)`
   */
  def getInjection(
    aggregateGroups: Set[TypedAggregateGroup[_]]
  ): KeyValInjection[AggregationKey, (BatchID, DataRecord)] = {
    val pdtsForKeys =
      getPdts[TypedAggregateGroup[_]](aggregateGroups, group => group.allOutputKeys)
    val pdtsForValues =
      getPdts[TypedAggregateGroup[_]](aggregateGroups, group => group.allOutputFeatures)

    KeyValInjection(
      genericInjection(AggregationKeyInjection, pdtsForKeys),
      genericInjection(Batched(JavaCompactThrift[DataRecord]), pdtsForValues)
    )
  }
}