forked from scala/scala3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLambdaLift.scala
341 lines (297 loc) · 13.7 KB
/
LambdaLift.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
package dotty.tools.dotc
package transform
import MegaPhase.*
import core.Denotations.NonSymSingleDenotation
import core.DenotTransformers.*
import core.Symbols.*
import core.Contexts.*
import core.Types.*
import core.Flags.*
import core.Decorators.*
import core.StdNames.nme
import core.Names.*
import core.NameOps.*
import core.NameKinds.ExpandPrefixName
import ExplicitOuter.outer
import util.Store
import collection.mutable.{HashMap, LinkedHashMap, ListBuffer}
object LambdaLift:
import ast.tpd.*
val name: String = "lambdaLift"
val description: String = "lifts out nested functions to class scope"
/** The core lambda lift functionality. */
class Lifter(thisPhase: MiniPhase & DenotTransformer)(using Context):
/** The outer parameter of a constructor */
private val outerParam = new HashMap[Symbol, Symbol]
/** Buffers for lifted out classes and methods, indexed by owner */
val liftedDefs: HashMap[Symbol, ListBuffer[Tree]] = new HashMap
val deps = new Dependencies(ctx.compilationUnit.tpdTree, ctx.withPhase(thisPhase)):
def isExpr(sym: Symbol)(using Context): Boolean = sym.is(Method)
def enclosure(using Context) = ctx.owner.enclosingMethod
override def process(tree: Tree)(using Context): Unit =
super.process(tree)
tree match
case tree: DefDef if tree.symbol.isConstructor =>
tree.termParamss.head.find(_.name == nme.OUTER) match
case Some(vdef) => outerParam(tree.symbol) = vdef.symbol
case _ =>
case tree: Template =>
liftedDefs(tree.symbol.owner) = new ListBuffer
case _ =>
end deps
/** A map storing the free variable proxies of functions and classes.
* For every function and class, this is a map from the free variables
* of that function or class to the proxy symbols accessing them.
*/
private val proxyMap = new LinkedHashMap[Symbol, Map[Symbol, Symbol]]
def proxyOf(sym: Symbol, fv: Symbol): Symbol = proxyMap.getOrElse(sym, Map.empty)(fv)
def proxies(sym: Symbol): List[Symbol] =
deps.freeVars(sym).toList.map(proxyOf(sym, _))
private def newName(sym: Symbol)(using Context): Name =
if (sym.isAnonymousFunction && sym.owner.is(Method))
sym.name.replace {
case name: SimpleName => ExpandPrefixName(sym.owner.name.asTermName, name)
}.freshened
else sym.name.freshened
private def generateProxies()(using Context): Unit =
for owner <- deps.tracked do
val fvs = deps.freeVars(owner).toList
val newFlags = Synthetic | (if (owner.isClass) PrivateParamAccessor else Param)
report.debuglog(i"free var proxy of ${owner.showLocated}: $fvs%, %")
val freeProxyPairs =
for fv <- fvs yield
val proxyName = newName(fv)
val proxy =
newSymbol(owner, proxyName.asTermName, newFlags, fv.info, coord = fv.coord)
.enteredAfter(thisPhase)
(fv, proxy)
proxyMap(owner) = freeProxyPairs.toMap
private def liftedInfo(local: Symbol)(using Context): Type = local.info match {
case MethodTpe(pnames, ptypes, restpe) =>
val ps = proxies(local)
MethodType(
ps.map(_.name.asTermName) ++ pnames,
ps.map(_.info) ++ ptypes,
restpe)
case info => info
}
private def liftLocals()(using Context): Unit = {
for ((local, lOwner) <- deps.logicalOwner) {
val (newOwner, maybeStatic) =
if lOwner is Package then (local.topLevelClass, JavaStatic)
else (lOwner, EmptyFlags)
// Drop Module because class is no longer a singleton in the lifted context.
var initFlags = local.flags &~ Module | Private | Lifted | maybeStatic
if (local is Method)
if (newOwner is Trait)
// Drop Final when a method is lifted into a trait.
// According to the JVM specification, a method declared inside interface cannot have the final flag.
// "Methods of interfaces may have any of the flags in Table 4.6-A set except ACC_PROTECTED, ACC_FINAL, ..."
// (https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-4.html#jvms-4.6)
initFlags = initFlags &~ Final
else
// Add Final when a method is lifted into a class.
initFlags = initFlags | Final
local.copySymDenotation(
owner = newOwner,
name = newName(local),
initFlags = initFlags,
info = liftedInfo(local)).installAfter(thisPhase)
}
for (local <- deps.tracked)
if (!deps.logicalOwner.contains(local))
local.copySymDenotation(info = liftedInfo(local)).installAfter(thisPhase)
}
def currentEnclosure(using Context): Symbol =
ctx.owner.enclosingMethodOrClass
private def inCurrentOwner(sym: Symbol)(using Context) =
sym.enclosure == currentEnclosure
private def proxy(sym: Symbol)(using Context): Symbol = {
def liftedEnclosure(sym: Symbol) =
deps.logicalOwner.getOrElse(sym, sym.enclosure)
def searchIn(enclosure: Symbol): Symbol = {
if (!enclosure.exists) {
def enclosures(encl: Symbol): List[Symbol] =
if (encl.exists) encl :: enclosures(liftedEnclosure(encl)) else Nil
throw new IllegalArgumentException(i"Could not find proxy for ${sym.showDcl} in ${sym.ownersIterator.toList}, encl = $currentEnclosure, owners = ${currentEnclosure.ownersIterator.toList}%, %; enclosures = ${enclosures(currentEnclosure)}%, %")
}
report.debuglog(i"searching for $sym(${sym.owner}) in $enclosure")
proxyMap get enclosure match {
case Some(pmap) =>
pmap get sym match {
case Some(proxy) => return proxy
case none =>
}
case none =>
}
searchIn(liftedEnclosure(enclosure))
}
if (inCurrentOwner(sym)) sym else searchIn(currentEnclosure)
}
def memberRef(sym: Symbol)(using Context): Tree = {
val clazz = sym.enclosingClass
val qual =
if (clazz.isStaticOwner || ctx.owner.enclosingClass == clazz)
singleton(clazz.thisType)
else if (ctx.owner.isConstructor)
outerParam.get(ctx.owner) match {
case Some(param) => outer.path(start = Ident(param.termRef), toCls = clazz)
case _ => outer.path(toCls = clazz)
}
else outer.path(toCls = clazz)
thisPhase.transformFollowingDeep(qual.select(sym))
}
def proxyRef(sym: Symbol)(using Context): Tree = {
val psym = atPhase(thisPhase)(proxy(sym))
thisPhase.transformFollowingDeep(if (psym.owner.isTerm) ref(psym) else memberRef(psym))
}
def addFreeArgs(sym: Symbol, args: List[Tree])(using Context): List[Tree] =
val fvs = deps.freeVars(sym)
if fvs.nonEmpty then fvs.toList.map(proxyRef(_)) ++ args else args
def addFreeParams(tree: Tree, proxies: List[Symbol])(using Context): Tree = proxies match {
case Nil => tree
case proxies =>
val sym = tree.symbol
val freeParamDefs = proxies.map(proxy =>
thisPhase.transformFollowingDeep(ValDef(proxy.asTerm).withSpan(tree.span)).asInstanceOf[ValDef])
def proxyInit(field: Symbol, param: Symbol) =
thisPhase.transformFollowingDeep(memberRef(field).becomes(ref(param)))
/** Initialize proxy fields from proxy parameters and map `rhs` from fields to parameters */
def copyParams(rhs: Tree) = {
val fvs = deps.freeVars(sym.owner).toList
val classProxies = fvs.map(proxyOf(sym.owner, _))
val constrProxies = fvs.map(proxyOf(sym, _))
report.debuglog(i"copy params ${constrProxies.map(_.showLocated)}%, % to ${classProxies.map(_.showLocated)}%, %}")
seq(classProxies.lazyZip(constrProxies).map(proxyInit), rhs)
}
tree match {
case tree: DefDef =>
cpy.DefDef(tree)(
paramss = tree.termParamss.map(freeParamDefs ++ _),
rhs =
if (sym.isPrimaryConstructor && !sym.owner.is(Trait)) copyParams(tree.rhs)
else tree.rhs)
case tree: Template =>
cpy.Template(tree)(body = freeParamDefs ++ tree.body)
}
}
def liftDef(tree: MemberDef)(using Context): Tree = {
val buf = liftedDefs(tree.symbol.owner)
thisPhase.transformFollowing(rename(tree, tree.symbol.name)).foreachInThicket(buf += _)
EmptyTree
}
def needsLifting(sym: Symbol): Boolean = deps.logicalOwner.contains(sym)
// initialization
atPhase(thisPhase.next) {
generateProxies()
liftLocals()
}
end Lifter
end LambdaLift
/** This phase performs the necessary rewritings to eliminate classes and methods
* nested in other methods. In detail:
* 1. It adds all free variables of local functions as additional parameters (proxies).
* 2. It rebinds references to free variables to the corresponding proxies,
* 3. It lifts all local functions and classes out as far as possible, but at least
* to the enclosing class.
* 4. It stores free variables of non-trait classes as additional fields of the class.
* The fields serve as proxies for methods in the class, which avoids the need
* of passing additional parameters to these methods.
*
* A particularly tricky case are local traits. These cannot store free variables
* as field proxies, because LambdaLift runs after Mixin, so the fields cannot be
* expanded anymore. Instead, methods of local traits get free variables of
* the trait as additional proxy parameters. The difference between local classes
* and local traits is illustrated by the two rewritings below.
*
* def f(x: Int) = { def f(x: Int) = new C(x).f2
* class C { ==> class C(x$1: Int) {
* def f2 = x def f2 = x$1
* } }
* new C().f2
* }
*
* def f(x: Int) = { def f(x: Int) = new C().f2(x)
* trait T { ==> trait T
* def f2 = x def f2(x$1: Int) = x$1
* } }
* class C extends T class C extends T
* new C().f2
* }
*/
class LambdaLift extends MiniPhase with IdentityDenotTransformer { thisPhase =>
import LambdaLift.*
import ast.tpd.*
override def phaseName: String = LambdaLift.name
override def description: String = LambdaLift.description
override def relaxedTypingInGroup: Boolean = true
// Because it adds free vars as additional proxy parameters
override def runsAfterGroupsOf: Set[String] = Set(Constructors.name, HoistSuperArgs.name)
// Constructors has to happen before LambdaLift because the lambda lift logic
// becomes simpler if it can assume that parameter accessors have already been
// converted to parameters in super calls. Without this it is very hard to get
// lambda lift for super calls right. Witness the implementation restrictions to
// this effect in scalac.
private var Lifter: Store.Location[Lifter] = _
private def lifter(using Context) = ctx.store(Lifter)
override def initContext(ctx: FreshContext): Unit =
Lifter = ctx.addLocation[Lifter]()
override def prepareForUnit(tree: Tree)(using Context): Context =
ctx.fresh.updateStore(Lifter, new Lifter(thisPhase))
override def transformIdent(tree: Ident)(using Context): Tree = {
val sym = tree.symbol
tree.tpe match {
case tpe @ TermRef(prefix, _) =>
val lft = lifter
if (prefix eq NoPrefix)
if (sym.enclosure != lft.currentEnclosure && !sym.isStatic)
(if (sym is Method) lft.memberRef(sym) else lft.proxyRef(sym)).withSpan(tree.span)
else if (sym.owner.isClass) // sym was lifted out
ref(sym).withSpan(tree.span)
else
tree
else if (!prefixIsElidable(tpe)) ref(tpe)
else tree
case _ =>
tree
}
}
override def transformSelect(tree: Select)(using Context): Tree =
val denot = tree.denot
val sym = tree.symbol
// The Lifter updates the type of symbols using `installAfter` to give them a
// new `SymDenotation`, but that doesn't affect non-sym denotations, so we
// reload them manually here.
// Note: If you tweak this code, make sure to test your changes with
// `Config.reuseSymDenotations` set to false to exercise this path more.
if denot.isInstanceOf[NonSymSingleDenotation] && lifter.deps.freeVars(sym).nonEmpty then
tree.qualifier.select(sym).withSpan(tree.span)
else tree
override def transformApply(tree: Apply)(using Context): Apply =
cpy.Apply(tree)(tree.fun, lifter.addFreeArgs(tree.symbol, tree.args)).withSpan(tree.span)
override def transformClosure(tree: Closure)(using Context): Closure =
cpy.Closure(tree)(env = lifter.addFreeArgs(tree.meth.symbol, tree.env))
override def transformDefDef(tree: DefDef)(using Context): Tree = {
val sym = tree.symbol
val lft = lifter
val paramsAdded =
if lft.deps.freeVars(sym).nonEmpty then lft.addFreeParams(tree, lft.proxies(sym)).asInstanceOf[DefDef]
else tree
if (lft.needsLifting(sym)) lft.liftDef(paramsAdded)
else paramsAdded
}
override def transformReturn(tree: Return)(using Context): Tree = tree.expr match {
case Block(stats, value) =>
Block(stats, Return(value, tree.from)).withSpan(tree.span)
case _ =>
tree
}
override def transformTemplate(tree: Template)(using Context): Template = {
val cls = ctx.owner
val lft = lifter
val impl = lft.addFreeParams(tree, lft.proxies(cls)).asInstanceOf[Template]
cpy.Template(impl)(body = impl.body ++ lft.liftedDefs.remove(cls).get)
}
override def transformTypeDef(tree: TypeDef)(using Context): Tree =
if (lifter.needsLifting(tree.symbol)) lifter.liftDef(tree) else tree
}