@@ -12,11 +12,6 @@ import Annotations.Annotation
12
12
13
13
object MainProxies {
14
14
15
- /** Generate proxy classes for @main functions and @myMain functions where myMain <:< MainAnnotation */
16
- def proxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
17
- mainAnnotationProxies(stats) ++ mainProxies(stats)
18
- }
19
-
20
15
/** Generate proxy classes for @main functions.
21
16
* A function like
22
17
*
@@ -35,7 +30,7 @@ object MainProxies {
35
30
* catch case err: ParseError => showError(err)
36
31
* }
37
32
*/
38
- private def mainProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
33
+ def proxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
39
34
import tpd .*
40
35
def mainMethods (stats : List [Tree ]): List [Symbol ] = stats.flatMap {
41
36
case stat : DefDef if stat.symbol.hasAnnotation(defn.MainAnnot ) =>
@@ -127,323 +122,4 @@ object MainProxies {
127
122
result
128
123
}
129
124
130
- private type DefaultValueSymbols = Map [Int , Symbol ]
131
- private type ParameterAnnotationss = Seq [Seq [Annotation ]]
132
-
133
- /**
134
- * Generate proxy classes for main functions.
135
- * A function like
136
- *
137
- * /* *
138
- * * Lorem ipsum dolor sit amet
139
- * * consectetur adipiscing elit.
140
- * *
141
- * * @param x my param x
142
- * * @param ys all my params y
143
- * */
144
- * @myMain(80) def f(
145
- * @myMain.Alias("myX") x: S,
146
- * y: S,
147
- * ys: T*
148
- * ) = ...
149
- *
150
- * would be translated to something like
151
- *
152
- * final class f {
153
- * static def main(args: Array[String]): Unit = {
154
- * val annotation = new myMain(80)
155
- * val info = new Info(
156
- * name = "f",
157
- * documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
158
- * parameters = Seq(
159
- * new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))),
160
- * new scala.annotation.MainAnnotation.Parameter("y", "S", true, false, "", Seq()),
161
- * new scala.annotation.MainAnnotation.Parameter("ys", "T", false, true, "all my params y", Seq())
162
- * )
163
- * ),
164
- * val command = annotation.command(info, args)
165
- * if command.isDefined then
166
- * val cmd = command.get
167
- * val args0: () => S = annotation.argGetter[S](info.parameters(0), cmd(0), None)
168
- * val args1: () => S = annotation.argGetter[S](info.parameters(1), mainArgs(1), Some(() => sum$default$1()))
169
- * val args2: () => Seq[T] = annotation.varargGetter[T](info.parameters(2), cmd.drop(2))
170
- * annotation.run(() => f(args0(), args1(), args2()*))
171
- * }
172
- * }
173
- */
174
- private def mainAnnotationProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
175
- import tpd .*
176
-
177
- /**
178
- * Computes the symbols of the default values of the function. Since they cannot be inferred anymore at this
179
- * point of the compilation, they must be explicitly passed by [[mainProxy ]].
180
- */
181
- def defaultValueSymbols (scope : Tree , funSymbol : Symbol ): DefaultValueSymbols =
182
- scope match {
183
- case TypeDef (_, template : Template ) =>
184
- template.body.flatMap((_ : Tree ) match {
185
- case dd : DefDef if dd.name.is(DefaultGetterName ) && dd.name.firstPart == funSymbol.name =>
186
- val DefaultGetterName .NumberedInfo (index) = dd.name.info: @ unchecked
187
- List (index -> dd.symbol)
188
- case _ => Nil
189
- }).toMap
190
- case _ => Map .empty
191
- }
192
-
193
- /** Computes the list of main methods present in the code. */
194
- def mainMethods (scope : Tree , stats : List [Tree ]): List [(Symbol , ParameterAnnotationss , DefaultValueSymbols , Option [Comment ])] = stats.flatMap {
195
- case stat : DefDef =>
196
- val sym = stat.symbol
197
- sym.annotations.filter(_.matches(defn.MainAnnotationClass )) match {
198
- case Nil =>
199
- Nil
200
- case _ :: Nil =>
201
- val paramAnnotations = stat.paramss.flatMap(_.map(
202
- valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotationParameterAnnotation ))
203
- ))
204
- (sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil
205
- case mainAnnot :: others =>
206
- report.error(em " method cannot have multiple main annotations " , mainAnnot.tree)
207
- Nil
208
- }
209
- case stat @ TypeDef (_, impl : Template ) if stat.symbol.is(Module ) =>
210
- mainMethods(stat, impl.body)
211
- case _ =>
212
- Nil
213
- }
214
-
215
- // Assuming that the top-level object was already generated, all main methods will have a scope
216
- mainMethods(EmptyTree , stats).flatMap(mainAnnotationProxy)
217
- }
218
-
219
- private def mainAnnotationProxy (mainFun : Symbol , paramAnnotations : ParameterAnnotationss , defaultValueSymbols : DefaultValueSymbols , docComment : Option [Comment ])(using Context ): Option [TypeDef ] = {
220
- val mainAnnot = mainFun.getAnnotation(defn.MainAnnotationClass ).get
221
- def pos = mainFun.sourcePos
222
-
223
- val documentation = new Documentation (docComment)
224
-
225
- /** () => value */
226
- def unitToValue (value : Tree ): Tree =
227
- val defDef = DefDef (nme.ANON_FUN , List (Nil ), TypeTree (), value)
228
- Block (defDef, Closure (Nil , Ident (nme.ANON_FUN ), EmptyTree ))
229
-
230
- /** Generate a list of trees containing the ParamInfo instantiations.
231
- *
232
- * A ParamInfo has the following shape
233
- * ```
234
- * new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
235
- * ```
236
- */
237
- def parameterInfos (mt : MethodType ): List [Tree ] =
238
- extension (tree : Tree ) def withProperty (sym : Symbol , args : List [Tree ]) =
239
- Apply (Select (tree, sym.name), args)
240
-
241
- for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
242
- val param = paramName.toString
243
- val paramType0 = if formal.isRepeatedParam then formal.argTypes.head.dealias else formal.dealias
244
- val paramType = paramType0.dealias
245
- val paramTypeOwner = paramType.typeSymbol.owner
246
- val paramTypeStr =
247
- if paramTypeOwner == defn.EmptyPackageClass then paramType.show
248
- else paramTypeOwner.showFullName + " ." + paramType.show
249
- val hasDefault = defaultValueSymbols.contains(idx)
250
- val isRepeated = formal.isRepeatedParam
251
- val paramDoc = documentation.argDocs.getOrElse(param, " " )
252
- val paramAnnots =
253
- val annotationTrees = paramAnnotations(idx).map(instantiateAnnotation).toList
254
- Apply (ref(defn.SeqModule .termRef), annotationTrees)
255
-
256
- val constructorArgs = List (param, paramTypeStr, hasDefault, isRepeated, paramDoc)
257
- .map(value => Literal (Constant (value)))
258
-
259
- New (TypeTree (defn.MainAnnotationParameter .typeRef), List (constructorArgs :+ paramAnnots))
260
-
261
- end parameterInfos
262
-
263
- /**
264
- * Creates a list of references and definitions of arguments.
265
- * The goal is to create the
266
- * `val args0: () => S = annotation.argGetter[S](0, cmd(0), None)`
267
- * part of the code.
268
- */
269
- def argValDefs (mt : MethodType ): List [ValDef ] =
270
- for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
271
- val argName = nme.args ++ idx.toString
272
- val isRepeated = formal.isRepeatedParam
273
- val formalType = if isRepeated then formal.argTypes.head else formal
274
- val getterName = if isRepeated then nme.varargGetter else nme.argGetter
275
- val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
276
- case None => ref(defn.NoneModule .termRef)
277
- case Some (dvSym) =>
278
- val value = unitToValue(ref(dvSym.termRef))
279
- Apply (ref(defn.SomeClass .companionModule.termRef), value)
280
- val argGetter0 = TypeApply (Select (Ident (nme.annotation), getterName), TypeTree (formalType) :: Nil )
281
- val index = Literal (Constant (idx))
282
- val paramInfo = Apply (Select (Ident (nme.info), nme.parameters), index)
283
- val argGetter =
284
- if isRepeated then Apply (argGetter0, List (paramInfo, Apply (Select (Ident (nme.cmd), nme.drop), List (index))))
285
- else Apply (argGetter0, List (paramInfo, Apply (Ident (nme.cmd), List (index)), defaultValueGetterOpt))
286
- ValDef (argName, TypeTree (), argGetter)
287
- end argValDefs
288
-
289
-
290
- /** Create a list of argument references that will be passed as argument to the main method.
291
- * `args0`, ...`argn*`
292
- */
293
- def argRefs (mt : MethodType ): List [Tree ] =
294
- for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
295
- val argRef = Apply (Ident (nme.args ++ idx.toString), Nil )
296
- if formal.isRepeatedParam then repeated(argRef) else argRef
297
- end argRefs
298
-
299
-
300
- /** Turns an annotation (e.g. `@main(40)`) into an instance of the class (e.g. `new scala.main(40)`). */
301
- def instantiateAnnotation (annot : Annotation ): Tree =
302
- val argss = {
303
- def recurse (t : tpd.Tree , acc : List [List [Tree ]]): List [List [Tree ]] = t match {
304
- case Apply (t, args : List [tpd.Tree ]) => recurse(t, extractArgs(args) :: acc)
305
- case _ => acc
306
- }
307
-
308
- def extractArgs (args : List [tpd.Tree ]): List [Tree ] =
309
- args.flatMap {
310
- case Typed (SeqLiteral (varargs, _), _) => varargs.map(arg => TypedSplice (arg))
311
- case arg : Select if arg.name.is(DefaultGetterName ) => Nil // Ignore default values, they will be added later by the compiler
312
- case arg => List (TypedSplice (arg))
313
- }
314
-
315
- recurse(annot.tree, Nil )
316
- }
317
-
318
- New (TypeTree (annot.symbol.typeRef), argss)
319
- end instantiateAnnotation
320
-
321
- def generateMainClass (mainCall : Tree , args : List [Tree ], parameterInfos : List [Tree ]): TypeDef =
322
- val cmdInfo =
323
- val nameTree = Literal (Constant (mainFun.showName))
324
- val docTree = Literal (Constant (documentation.mainDoc))
325
- val paramInfos = Apply (ref(defn.SeqModule .termRef), parameterInfos)
326
- New (TypeTree (defn.MainAnnotationInfo .typeRef), List (List (nameTree, docTree, paramInfos)))
327
-
328
- val annotVal = ValDef (
329
- nme.annotation,
330
- TypeTree (),
331
- instantiateAnnotation(mainAnnot)
332
- )
333
- val infoVal = ValDef (
334
- nme.info,
335
- TypeTree (),
336
- cmdInfo
337
- )
338
- val command = ValDef (
339
- nme.command,
340
- TypeTree (),
341
- Apply (
342
- Select (Ident (nme.annotation), nme.command),
343
- List (Ident (nme.info), Ident (nme.args))
344
- )
345
- )
346
- val argsVal = ValDef (
347
- nme.cmd,
348
- TypeTree (),
349
- Select (Ident (nme.command), nme.get)
350
- )
351
- val run = Apply (Select (Ident (nme.annotation), nme.run), mainCall)
352
- val body0 = If (
353
- Select (Ident (nme.command), nme.isDefined),
354
- Block (argsVal :: args, run),
355
- EmptyTree
356
- )
357
- val body = Block (List (annotVal, infoVal, command), body0) // TODO add `if (cmd.nonEmpty)`
358
-
359
- val mainArg = ValDef (nme.args, TypeTree (defn.ArrayType .appliedTo(defn.StringType )), EmptyTree )
360
- .withFlags(Param )
361
- /** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.
362
- * The annotations will be retype-checked in another scope that may not have the same imports.
363
- */
364
- def insertTypeSplices = new TreeMap {
365
- override def transform (tree : Tree )(using Context ): Tree = tree match
366
- case tree : tpd.Ident @ unchecked => TypedSplice (tree)
367
- case tree => super .transform(tree)
368
- }
369
- val annots = mainFun.annotations
370
- .filterNot(_.matches(defn.MainAnnotationClass ))
371
- .map(annot => insertTypeSplices.transform(annot.tree))
372
- val mainMeth = DefDef (nme.main, (mainArg :: Nil ) :: Nil , TypeTree (defn.UnitType ), body)
373
- .withFlags(JavaStatic )
374
- .withAnnotations(annots)
375
- val mainTempl = Template (emptyConstructor, Nil , Nil , EmptyValDef , mainMeth :: Nil )
376
- val mainCls = TypeDef (mainFun.name.toTypeName, mainTempl)
377
- .withFlags(Final | Invisible )
378
- mainCls.withSpan(mainAnnot.tree.span.toSynthetic)
379
- end generateMainClass
380
-
381
- if (! mainFun.owner.isStaticOwner)
382
- report.error(em " main method is not statically accessible " , pos)
383
- None
384
- else mainFun.info match {
385
- case _ : ExprType =>
386
- Some (generateMainClass(unitToValue(ref(mainFun.termRef)), Nil , Nil ))
387
- case mt : MethodType =>
388
- if (mt.isImplicitMethod)
389
- report.error(em " main method cannot have implicit parameters " , pos)
390
- None
391
- else mt.resType match
392
- case restpe : MethodType =>
393
- report.error(em " main method cannot be curried " , pos)
394
- None
395
- case _ =>
396
- Some (generateMainClass(unitToValue(Apply (ref(mainFun.termRef), argRefs(mt))), argValDefs(mt), parameterInfos(mt)))
397
- case _ : PolyType =>
398
- report.error(em " main method cannot have type parameters " , pos)
399
- None
400
- case _ =>
401
- report.error(em " main can only annotate a method " , pos)
402
- None
403
- }
404
- }
405
-
406
- /** A class responsible for extracting the docstrings of a method. */
407
- private class Documentation (docComment : Option [Comment ]):
408
- import util .CommentParsing .*
409
-
410
- /** The main part of the documentation. */
411
- lazy val mainDoc : String = _mainDoc
412
- /** The parameters identified by @param. Maps from parameter name to its documentation. */
413
- lazy val argDocs : Map [String , String ] = _argDocs
414
-
415
- private var _mainDoc : String = " "
416
- private var _argDocs : Map [String , String ] = Map ()
417
-
418
- docComment match {
419
- case Some (comment) => if comment.isDocComment then parseDocComment(comment.raw) else _mainDoc = comment.raw
420
- case None =>
421
- }
422
-
423
- private def cleanComment (raw : String ): String =
424
- var lines : Seq [String ] = raw.trim.nn.split('\n ' ).nn.toSeq
425
- lines = lines.map(l => l.substring(skipLineLead(l, - 1 ), l.length).nn.trim.nn)
426
- var s = lines.foldLeft(" " ) {
427
- case (" " , s2) => s2
428
- case (s1, " " ) if s1.last == '\n ' => s1 // Multiple newlines are kept as single newlines
429
- case (s1, " " ) => s1 + '\n '
430
- case (s1, s2) if s1.last == '\n ' => s1 + s2
431
- case (s1, s2) => s1 + ' ' + s2
432
- }
433
- s.replaceAll(raw " \[\[ " , " " ).nn.replaceAll(raw " \]\] " , " " ).nn.trim.nn
434
-
435
- private def parseDocComment (raw : String ): Unit =
436
- // Positions of the sections (@) in the docstring
437
- val tidx : List [(Int , Int )] = tagIndex(raw)
438
-
439
- // Parse main comment
440
- var mainComment : String = raw.substring(skipLineLead(raw, 0 ), startTag(raw, tidx)).nn
441
- _mainDoc = cleanComment(mainComment)
442
-
443
- // Parse arguments comments
444
- val argsCommentsSpans : Map [String , (Int , Int )] = paramDocs(raw, " @param" , tidx)
445
- val argsCommentsTextSpans = argsCommentsSpans.view.mapValues(extractSectionText(raw, _))
446
- val argsCommentsTexts = argsCommentsTextSpans.mapValues({ case (beg, end) => raw.substring(beg, end).nn })
447
- _argDocs = argsCommentsTexts.mapValues(cleanComment(_)).toMap
448
- end Documentation
449
125
}
0 commit comments