Skip to content

Commit 96eceb0

Browse files
committed
Fix issues with operator fusion
1 parent: 771a748 · commit: 96eceb0

File tree

8 files changed

+138
-112
lines changed

8 files changed

+138
-112
lines changed

src/main/scala/se/kth/cda/compiler/Compiler.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ object Compiler {
2222
}
2323
//print(metadata)
2424

25-
IdGenerator.reset()
25+
IdGenerator.resetGlobal()
2626

2727
val inputStream = CharStreams.fromString(metadata.arc_code)
2828
val lexer = new ArcLexer(inputStream)
@@ -33,10 +33,12 @@ object Compiler {
3333
//val expanded = MacroExpansion.expand(ast).get
3434
val typed = TypeInference.solve(ast).get
3535
val dfg = typed.toDFG
36+
//import se.kth.cda.arc.syntaxtree.PrettyPrint._
37+
//println(pretty(typed))
3638
//println(dfg.pretty)
3739

3840
val enriched_dfg = dfg.enrich(metadata)
39-
val optimized_dfg = enriched_dfg.optimize(fusion = false)
41+
val optimized_dfg = enriched_dfg.optimize(fusion = true)
4042
//println(optimized_dfg.pretty)
4143
//println(enriched_dfg.pretty)
4244

src/main/scala/se/kth/cda/compiler/dataflow/DFG.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ final case class DFG(id: String = DFGId.generate,
3737

3838
//case class Scope(depth: Long, parent: Option[Scope]) extends Id
3939

40-
final case class Node(var id: String, parallelism: Long = 1, kind: NodeKind, ord: Int = NodeId.newOrdering)
40+
final case class Node(var id: String, parallelism: Long = 1, kind: NodeKind, ord: Int = NodeId.newGlobalOrd)
4141

4242
sealed trait NodeKind
4343

@@ -55,7 +55,7 @@ object NodeKind {
5555
extends NodeKind
5656
final case class Task(var weldFunc: Expr,
5757
inputType: Type,
58-
outputType: Type,
58+
var outputType: Type,
5959
var predecessor: Node,
6060
var successors: Vector[ChannelKind] = Vector.empty,
6161
channelStrategy: ChannelStrategy = Forward,

src/main/scala/se/kth/cda/compiler/dataflow/encode/EncodeDFG.scala

+70-56
Original file line numberDiff line numberDiff line change
@@ -28,59 +28,73 @@ object EncodeDFG {
2828
Json.obj(
2929
("id", node.id.asJson),
3030
("parallelism", node.parallelism.asJson),
31-
("kind", node.kind.asJson(node.id)),
31+
("kind", node.kind.asJson(node)),
3232
)
3333

34-
implicit def encodeNodeKind(id: String): Encoder[NodeKind] = {
34+
implicit def encodeNodeKind(node: Node): Encoder[NodeKind] = {
3535
case source: Source =>
36+
val outputId = source.successors(0) match {
37+
case Local(succ) => StructId.from(node.id, succ.id)
38+
case Remote(succ, _) => StructId.from(node.id, succ.id)
39+
}
3640
Json.obj(
3741
("Source",
38-
Json.obj(
39-
("source_type", source.sourceType.asJson(encodeType(source.successors(0) match {
40-
case Local(node) => StructId.from(id, node.id)
41-
case Remote(node, _) => StructId.from(id, node.id)
42-
}))),
43-
("format", source.format.asJson),
44-
("channel_strategy", source.channelStrategy.asJson),
45-
("successors", source.successors.asJson),
46-
("kind", source.kind.asJson),
47-
)))
42+
Json.obj(
43+
("source_type", source.sourceType.asJson(encodeType(outputId))),
44+
("format", source.format.asJson),
45+
("channel_strategy", source.channelStrategy.asJson),
46+
("successors", source.successors.asJson),
47+
("kind", source.kind.asJson),
48+
)))
4849
case task: Task =>
50+
val (a, b) = getInputId(task.predecessor, node.id)
51+
val inputId = StructId.from(a, b)
52+
val outputId = task.kind match {
53+
case TaskKind.Filter => inputId
54+
case _ => task.successors(0) match {
55+
case Local(succ) => StructId.from(node.id, succ.id)
56+
case Remote(succ, _) => StructId.from(node.id, succ.id)
57+
}
58+
}
4959
Json.obj(
5060
("Task",
51-
Json.obj(
52-
("weld_code", pretty(task.weldFunc).asJson),
53-
("input_type", task.inputType.asJson(encodeType(StructId.from(task.predecessor.id, id)))),
54-
("output_type", task.outputType.asJson(encodeType(task.successors(0) match {
55-
case Local(node) => StructId.from(id, node.id)
56-
case Remote(node, _) => StructId.from(id, node.id)
57-
}))),
58-
("channel_strategy", task.channelStrategy.asJson),
59-
("predecessor", task.predecessor.id.asJson),
60-
("successors", task.successors.asJson),
61-
("kind", task.kind.asJson),
62-
)))
61+
Json.obj(
62+
("weld_code", pretty(task.weldFunc).asJson),
63+
("input_type", task.inputType.asJson(encodeType(inputId))),
64+
("output_type", task.outputType.asJson(encodeType(outputId))),
65+
("channel_strategy", task.channelStrategy.asJson),
66+
("predecessor", task.predecessor.id.asJson),
67+
("successors", task.successors.asJson),
68+
("kind", task.kind.asJson),
69+
)))
6370
case sink: Sink =>
71+
val inputId = StructId.from(sink.predecessor.id, node.id)
6472
Json.obj(
6573
("Sink",
66-
Json.obj(
67-
("sink_type", sink.sinkType.asJson(encodeType(StructId.from(sink.predecessor.id, id)))),
68-
("format", sink.format.asJson),
69-
("predecessor", sink.predecessor.id.asJson),
70-
("kind", sink.kind.asJson),
71-
)))
74+
Json.obj(
75+
("sink_type", sink.sinkType.asJson(encodeType(inputId))),
76+
("format", sink.format.asJson),
77+
("predecessor", sink.predecessor.id.asJson),
78+
("kind", sink.kind.asJson),
79+
)))
7280
case window: Window =>
81+
val (a, b) = getInputId(window.predecessor, node.id)
82+
val inputId = StructId.from(a, b)
83+
val outputId = window.successors(0) match {
84+
case Local(succ) => StructId.from(node.id, succ.id)
85+
case Remote(succ, _) => StructId.from(node.id, succ.id)
86+
}
7387
Json.obj(
7488
("Window",
75-
Json.obj(
76-
("channel_strategy", window.channelStrategy.asJson),
77-
("predecessor", window.predecessor.id.asJson),
78-
("successors", window.successors.asJson),
79-
("assigner", window.assigner.asJson),
80-
("window_function", window.function.asJson),
81-
("time_kind", window.time.asJson),
82-
("window_kind", window.kind.asJson),
83-
)))
89+
Json.obj(
90+
("channel_strategy", window.channelStrategy.asJson),
91+
("predecessor", window.predecessor.id.asJson),
92+
("successors", window.successors.asJson),
93+
("assigner", window.assigner.asJson),
94+
("window_function", window.function.asJson(encodeWindowFunction(inputId, outputId))),
95+
("time_kind", window.time.asJson),
96+
("window_kind", window.kind.asJson),
97+
)))
8498
}
8599

86100
implicit val encodeSourceKind: Encoder[SourceKind] = {
@@ -151,15 +165,16 @@ object EncodeDFG {
151165
case Broadcast => "Broadcast".asJson
152166
}
153167

154-
implicit val encodeWindowFunction: Encoder[WindowFunction] = function =>
155-
Json.obj(
156-
("input_type", function.inputType.asJson(encodeType())),
157-
("output_type", function.outputType.asJson(encodeType())),
158-
("builder_type", function.builderType.asJson(encodeType())),
159-
("builder", pretty(function.init).asJson),
160-
("udf", pretty(function.lift).asJson),
161-
("materialiser", pretty(function.lower).asJson),
162-
)
168+
implicit def encodeWindowFunction(inputId: String, outputId: String): Encoder[WindowFunction] =
169+
function =>
170+
Json.obj(
171+
("input_type", function.inputType.asJson(encodeType(inputId))),
172+
("output_type", function.outputType.asJson(encodeType(outputId))),
173+
("builder_type", function.builderType.asJson(encodeType())),
174+
("builder", pretty(function.init).asJson),
175+
("udf", pretty(function.lift).asJson),
176+
("materialiser", pretty(function.lower).asJson),
177+
)
163178

164179
implicit val encodeWindowAssigner: Encoder[WindowAssigner] = {
165180
case tumbling: Tumbling =>
@@ -179,13 +194,12 @@ object EncodeDFG {
179194
case All => "All".asJson
180195
}
181196

182-
//implicit val encodeType: Encoder[Type] = {
183-
// case _: Appender => "Appender".asJson
184-
// case _: Merger => "Merger".asJson
185-
// case _: VecMerger => "VecMerger".asJson
186-
// case _: DictMerger => "DictMerger".asJson
187-
// case _: GroupMerger => "GroupMerger".asJson
188-
// case _ => ???
189-
//}
197+
// Walks backwards and finds the first output coming from a non-filter
198+
private def getInputId(node: Node, succ_id: String): (String, String) = {
199+
node.kind match {
200+
case Task(_, _, _, pred, _, _, Filter, _) => getInputId(pred, node.id)
201+
case _: Source | _: Task | _: Window | _: Sink => (node.id, succ_id)
202+
}
203+
}
190204

191205
}

src/main/scala/se/kth/cda/compiler/dataflow/encode/EncodeType.scala

+49-46
Original file line numberDiff line numberDiff line change
@@ -9,52 +9,55 @@ import se.kth.cda.compiler.dataflow.IdGenerator.StructId
99

1010
object EncodeType {
1111

12-
implicit def encodeType(name: String = StructId.newId, key: Option[Long] = None): Encoder[Type] = ty => {
13-
val newName = StructId.next(name)
14-
ty match {
15-
case Bool => Json.obj(("Scalar", "Bool".asJson))
16-
case I8 => Json.obj(("Scalar", "I8".asJson))
17-
case I16 => Json.obj(("Scalar", "I16".asJson))
18-
case I32 => Json.obj(("Scalar", "I32".asJson))
19-
case I64 => Json.obj(("Scalar", "I64".asJson))
20-
case U8 => Json.obj(("Scalar", "U8".asJson))
21-
case U16 => Json.obj(("Scalar", "U16".asJson))
22-
case U32 => Json.obj(("Scalar", "U32".asJson))
23-
case U64 => Json.obj(("Scalar", "U64".asJson))
24-
case F32 => Json.obj(("Scalar", "F32".asJson))
25-
case F64 => Json.obj(("Scalar", "F64".asJson))
26-
case UnitT => Json.obj(("Scalar", "Unit".asJson))
27-
case StringT => Json.obj(("Scalar", "String".asJson))
28-
case Simd(elemTy) => Json.obj(("Simd", elemTy.asJson(encodeType(newName, key))))
29-
case Vec(elemTy) => Json.obj(("Vector", Json.obj(("elem_ty", elemTy.asJson(encodeType(newName, key))))))
30-
case Struct(elemTys) =>
31-
Json.obj(
32-
("Struct",
33-
Json.obj(("id", newName.asJson), ("field_tys", elemTys.map(_.asJson(encodeType(newName, key))).asJson))))
34-
case Dict(keyTy, valueTy) =>
35-
Json.obj(
36-
("Dict",
37-
Json.obj(("key_ty", keyTy.asJson(encodeType(newName, key))),
38-
("value_ty", valueTy.asJson(encodeType(newName, key))))))
39-
case Appender(elemTy, _) =>
40-
Json.obj(("Appender", Json.obj(("elem_ty", elemTy.asJson(encodeType(newName, key))))))
41-
case Merger(elemTy, opTy, _) =>
42-
Json.obj(("Merger", Json.obj(("elem_ty", elemTy.asJson(encodeType(newName, key))), ("op_ty", opTy.asJson))))
43-
case VecMerger(elemTy, opTy, _) =>
44-
Json.obj(("VecMerger", Json.obj(("elem_ty", elemTy.asJson(encodeType(newName, key))), ("op_ty", opTy.asJson))))
45-
case GroupMerger(keyTy, valueTy, _) =>
46-
Json.obj(
47-
("GroupMerger",
48-
Json.obj(("key_ty", keyTy.asJson(encodeType(newName, key))),
49-
("value_ty", valueTy.asJson(encodeType(newName, key))))))
50-
case DictMerger(keyTy, valueTy, opTy, _) =>
51-
Json.obj(
52-
("DictMerger",
53-
Json.obj(("key_ty", keyTy.asJson(encodeType(newName, key))),
54-
("value_ty", valueTy.asJson(encodeType(newName, key))),
55-
("op_ty", opTy.asJson))))
56-
case _ => ???
57-
}
12+
implicit def encodeType(nodeId: String = StructId.newGlobalId, key: Option[Long] = None): Encoder[Type] = {
13+
StructId.resetLocal()
14+
encodeTypeRec(nodeId, key)
15+
}
16+
17+
def encodeTypeRec(nodeId: String, key: Option[Long] = None): Encoder[Type] = {
18+
case Bool => Json.obj(("Scalar", "Bool".asJson))
19+
case I8 => Json.obj(("Scalar", "I8".asJson))
20+
case I16 => Json.obj(("Scalar", "I16".asJson))
21+
case I32 => Json.obj(("Scalar", "I32".asJson))
22+
case I64 => Json.obj(("Scalar", "I64".asJson))
23+
case U8 => Json.obj(("Scalar", "U8".asJson))
24+
case U16 => Json.obj(("Scalar", "U16".asJson))
25+
case U32 => Json.obj(("Scalar", "U32".asJson))
26+
case U64 => Json.obj(("Scalar", "U64".asJson))
27+
case F32 => Json.obj(("Scalar", "F32".asJson))
28+
case F64 => Json.obj(("Scalar", "F64".asJson))
29+
case UnitT => Json.obj(("Scalar", "Unit".asJson))
30+
case StringT => Json.obj(("Scalar", "String".asJson))
31+
case Simd(elemTy) => Json.obj(("Simd", elemTy.asJson(encodeTypeRec(nodeId, key))))
32+
case Vec(elemTy) => Json.obj(("Vector", Json.obj(("elem_ty", elemTy.asJson(encodeTypeRec(nodeId, key))))))
33+
case Struct(elemTys) =>
34+
Json.obj(
35+
("Struct",
36+
Json.obj(("id", StructId.nextLocal(nodeId).asJson),
37+
("field_tys", elemTys.map(_.asJson(encodeTypeRec(nodeId, key))).asJson))))
38+
case Dict(keyTy, valueTy) =>
39+
Json.obj(
40+
("Dict",
41+
Json.obj(("key_ty", keyTy.asJson(encodeTypeRec(nodeId, key))),
42+
("value_ty", valueTy.asJson(encodeTypeRec(nodeId, key))))))
43+
case Appender(elemTy, _) =>
44+
Json.obj(("Appender", Json.obj(("elem_ty", elemTy.asJson(encodeTypeRec(nodeId, key))))))
45+
case Merger(elemTy, opTy, _) =>
46+
Json.obj(("Merger", Json.obj(("elem_ty", elemTy.asJson(encodeTypeRec(nodeId, key))), ("op_ty", opTy.asJson))))
47+
case VecMerger(elemTy, opTy, _) =>
48+
Json.obj(("VecMerger", Json.obj(("elem_ty", elemTy.asJson(encodeTypeRec(nodeId, key))), ("op_ty", opTy.asJson))))
49+
case GroupMerger(keyTy, valueTy, _) =>
50+
Json.obj(
51+
("GroupMerger",
52+
Json.obj(("key_ty", keyTy.asJson(encodeTypeRec(nodeId, key))),
53+
("value_ty", valueTy.asJson(encodeTypeRec(nodeId, key))))))
54+
case DictMerger(keyTy, valueTy, opTy, _) =>
55+
Json.obj(
56+
("DictMerger",
57+
Json.obj(("key_ty", keyTy.asJson(encodeTypeRec(nodeId, key))),
58+
("value_ty", valueTy.asJson(encodeTypeRec(nodeId, key))),
59+
("op_ty", opTy.asJson))))
60+
case _ => ???
5861
}
5962

6063
implicit val encodeMergeOp: Encoder[MergeOp] = {

src/main/scala/se/kth/cda/compiler/dataflow/optimize/Fusion.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import se.kth.cda.arc.syntaxtree.Type.Function
88
import se.kth.cda.compiler.Utils._
99
import se.kth.cda.compiler.dataflow.Analyzer._
1010
import se.kth.cda.compiler.dataflow.ChannelKind.{Local, Remote}
11+
import se.kth.cda.compiler.dataflow.IdGenerator.NodeId
1112
import se.kth.cda.compiler.dataflow.Node
1213
import se.kth.cda.compiler.dataflow.NodeKind.{Sink, Source, Task, Window}
1314

@@ -92,7 +93,8 @@ object Fusion {
9293
}
9394
pred.successors = self.successors
9495
pred.weldFunc = fuseWeldFuncs(pred.weldFunc, self.weldFunc)
95-
self.predecessor.id = s"${self.predecessor.id}_${node.id}"
96+
pred.outputType = self.outputType
97+
self.predecessor.id = NodeId.fuse(self.predecessor.id, node.id)
9698
self.removed = true
9799
case _ => ()
98100
}

src/main/scala/se/kth/cda/compiler/dataflow/optimize/OptimizeDFG.scala

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ object OptimizeDFG {
99
import se.kth.cda.compiler.dataflow.optimize.Specialization._
1010
implicit class OptimizeDFG(val dfg: DFG) extends AnyVal {
1111
def optimize(fusion: Boolean = true): DFG = {
12-
//println(encodeDFG(dfg))
12+
//import se.kth.cda.compiler.dataflow.pretty.PrettyPrint._
13+
//import se.kth.cda.compiler.dataflow.deploy.Deploy._
14+
//println(dfg.order.pretty)
1315
if (fusion) {
1416
dfg.nodes
1517
.filter(_.kind match {
@@ -34,6 +36,8 @@ object OptimizeDFG {
3436
})
3537
.foreach(_.specialize())
3638

39+
//println(dfg.order.pretty)
40+
3741
dfg
3842
}
3943
}

src/main/scala/se/kth/cda/compiler/dataflow/transform/ToDFG.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ object ToDFG {
2727
case lambda: Lambda =>
2828
(lambda.params.map {
2929
case Parameter(symbol, StreamAppender(elemTy, _)) =>
30-
symbol.name -> Node(id = SinkId.newId, kind = Sink(sinkType = elemTy))
30+
symbol.name -> Node(id = SinkId.newGlobalId, kind = Sink(sinkType = elemTy))
3131
case Parameter(symbol, Stream(elemTy)) =>
32-
symbol.name -> Node(id = SourceId.newId, kind = Source(sourceType = elemTy))
32+
symbol.name -> Node(id = SourceId.newGlobalId, kind = Source(sourceType = elemTy))
3333
case _ => ???
3434
}.toMap, lambda.body)
3535
case _ => ???
@@ -70,7 +70,7 @@ object ToDFG {
7070
val (inputType, precedessor, _) = transformSource(iter, nodes)
7171
val nodeKind = transformSink(sink, func, inputType, precedessor)
7272
// Add node as successor to predecessor
73-
val newNode = Node(id = TaskId.newId, kind = nodeKind)
73+
val newNode = Node(id = TaskId.newGlobalId, kind = nodeKind)
7474
precedessor.kind match {
7575
case source: Source => source.successors = source.successors :+ Local(node = newNode)
7676
case task: Task => task.successors = task.successors :+ Local(node = newNode)

src/main/scala/se/kth/cda/compiler/dataflow/transform/Utils.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,12 @@ object Utils {
5858
}(self)
5959
}
6060

61-
// Transforms an Arc stream into its element type
61+
// Transforms an Arc stream or streamappender into its element type
6262
def toElemType: Type = {
6363
fix[Type, Type] { f =>
6464
{
6565
case ty: Stream => ty.elemTy
66+
case ty: StreamAppender => ty.elemTy
6667
case ty: Struct => Struct(ty.elemTys.map(f))
6768
case ty @ _ => ty
6869
}

0 commit comments

Comments (0)