Skip to content

Commit 53d79fb

Browse files
authored
Allow multiple spreads in function arguments (#23855)
Implements SIP 70. Currently, only `Seq`s and `Array`s can be unpacked whereas the SIP also specifies `Option` unpacking.
2 parents a4e1309 + 61d0a78 commit 53d79fb

22 files changed

+587
-54
lines changed

compiler/src/dotty/tools/dotc/config/Feature.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ object Feature:
3737
val modularity = experimental("modularity")
3838
val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions")
3939
val packageObjectValues = experimental("packageObjectValues")
40+
val multiSpreads = experimental("multiSpreads")
4041
val subCases = experimental("subCases")
4142

4243
def experimentalAutoEnableFeatures(using Context): List[TermName] =

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,11 @@ class Definitions {
468468
@tu lazy val throwMethod: TermSymbol = enterMethod(OpsPackageClass, nme.THROWkw,
469469
MethodType(List(ThrowableType), NothingType))
470470

471+
@tu lazy val spreadMethod = enterMethod(OpsPackageClass, nme.spread,
472+
PolyType(TypeBounds.empty :: Nil)(
473+
tl => MethodType(AnyType :: Nil, tl.paramRefs(0))
474+
))
475+
471476
@tu lazy val NothingClass: ClassSymbol = enterCompleteClassSymbol(
472477
ScalaPackageClass, tpnme.Nothing, AbstractFinal, List(AnyType))
473478
def NothingType: TypeRef = NothingClass.typeRef
@@ -519,6 +524,8 @@ class Definitions {
519524
@tu lazy val newGenericArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newGenericArray")
520525
@tu lazy val newArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newArray")
521526

527+
@tu lazy val VarArgsBuilderModule: Symbol = requiredModule("scala.runtime.VarArgsBuilder")
528+
522529
def getWrapVarargsArrayModule: Symbol = ScalaRuntimeModule
523530

524531
// The set of all wrap{X, Ref}Array methods, where X is a value type
@@ -563,11 +570,12 @@ class Definitions {
563570
@tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply)
564571
@tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head)
565572
@tu lazy val Seq_drop : Symbol = SeqClass.requiredMethod(nme.drop)
573+
@tu lazy val Seq_dropRight : Symbol = SeqClass.requiredMethod(nme.dropRight)
574+
@tu lazy val Seq_takeRight : Symbol = SeqClass.requiredMethod(nme.takeRight)
566575
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
567576
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
568577
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
569578

570-
571579
@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")
572580
@tu lazy val StringOps_format: Symbol = StringOps.requiredMethod(nme.format)
573581

@@ -2234,7 +2242,7 @@ class Definitions {
22342242

22352243
/** Lists core methods that don't have underlying bytecode, but are synthesized on-the-fly in every reflection universe */
22362244
@tu lazy val syntheticCoreMethods: List[TermSymbol] =
2237-
AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod)
2245+
AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod, spreadMethod)
22382246

22392247
@tu lazy val reservedScalaClassNames: Set[Name] = syntheticScalaClasses.map(_.name).toSet
22402248

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ object StdNames {
470470
val doubleHash: N = "doubleHash"
471471
val dotty: N = "dotty"
472472
val drop: N = "drop"
473+
val dropRight: N = "dropRight"
473474
val dynamics: N = "dynamics"
474475
val elem: N = "elem"
475476
val elems: N = "elems"
@@ -619,6 +620,7 @@ object StdNames {
619620
val setSymbol: N = "setSymbol"
620621
val setType: N = "setType"
621622
val setTypeSignature: N = "setTypeSignature"
623+
val spread: N = "spread"
622624
val standardInterpolator: N = "standardInterpolator"
623625
val staticClass : N = "staticClass"
624626
val staticModule : N = "staticModule"
@@ -801,6 +803,7 @@ object StdNames {
801803
val takeModulo: N = "takeModulo"
802804
val takeNot: N = "takeNot"
803805
val takeOr: N = "takeOr"
806+
val takeRight: N = "takeRight"
804807
val takeXor: N = "takeXor"
805808
val testEqual: N = "testEqual"
806809
val testGreaterOrEqualThan: N = "testGreaterOrEqualThan"

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,17 +1056,22 @@ object Parsers {
10561056
}
10571057

10581058
/** Is current ident a `*`, and is it followed by a `)`, `, )`, `,EOF`? The latter two are not
1059-
syntactically valid, but we need to include them here for error recovery. */
1059+
syntactically valid, but we need to include them here for error recovery.
1060+
Under experimental.multiSpreads we allow `*`` followed by `,` unconditionally.
1061+
*/
10601062
def followingIsVararg(): Boolean =
10611063
in.isIdent(nme.raw.STAR) && {
10621064
val lookahead = in.LookaheadScanner()
10631065
lookahead.nextToken()
10641066
lookahead.token == RPAREN
10651067
|| lookahead.token == COMMA
1066-
&& {
1067-
lookahead.nextToken()
1068-
lookahead.token == RPAREN || lookahead.token == EOF
1069-
}
1068+
&& (
1069+
in.featureEnabled(Feature.multiSpreads)
1070+
|| {
1071+
lookahead.nextToken()
1072+
lookahead.token == RPAREN || lookahead.token == EOF
1073+
}
1074+
)
10701075
}
10711076

10721077
/** When encountering a `:`, is that in the binding of a lambda?
@@ -3347,7 +3352,9 @@ object Parsers {
33473352
if (in.token == RPAREN) Nil else patterns(location)
33483353

33493354
/** ArgumentPatterns ::= ‘(’ [Patterns] ‘)’
3350-
* | ‘(’ [Patterns ‘,’] PatVar ‘*’ ‘)’
3355+
* | ‘(’ [Patterns ‘,’] PatVar ‘*’ [‘,’ Patterns] ‘)’
3356+
*
3357+
* -- It is checked in Typer that there are no repeated PatVar arguments.
33513358
*/
33523359
def argumentPatterns(): List[Tree] =
33533360
inParensWithCommas(patternsOpt(Location.InPatternArgs))

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ object PatternMatcher {
198198
case object NonNullTest extends Test // scrutinee ne null
199199
case object GuardTest extends Test // scrutinee
200200

201+
val noLengthTest = LengthTest(0, exact = false)
202+
201203
// ------- Generating plans from trees ------------------------
202204

203205
/** A set of variabes that are known to be not null */
@@ -291,38 +293,67 @@ object PatternMatcher {
291293
/** Plan for matching the sequence in `seqSym` against sequence elements `args`.
292294
* If `exact` is true, the sequence is not permitted to have any elements following `args`.
293295
*/
294-
def matchElemsPlan(seqSym: Symbol, args: List[Tree], exact: Boolean, onSuccess: Plan) = {
295-
val selectors = args.indices.toList.map(idx =>
296-
ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx))))
297-
TestPlan(LengthTest(args.length, exact), seqSym, seqSym.span,
298-
matchArgsPlan(selectors, args, onSuccess))
299-
}
296+
def matchElemsPlan(seqSym: Symbol, args: List[Tree], lengthTest: LengthTest, onSuccess: Plan) =
297+
val selectors = args.indices.toList.map: idx =>
298+
ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx)))
299+
if lengthTest.len == 0 && lengthTest.exact == false then // redundant test
300+
matchArgsPlan(selectors, args, onSuccess)
301+
else
302+
TestPlan(lengthTest, seqSym, seqSym.span,
303+
matchArgsPlan(selectors, args, onSuccess))
300304

301305
/** Plan for matching the sequence in `getResult` against sequence elements
302-
* and a possible last varargs argument `args`.
306+
* `args`. Sequence elements may contain a varargs argument.
307+
* Example:
308+
*
309+
* lst match case Seq(1, xs*, 2, 3) => ...
310+
*
311+
* generates code which is equivalent to:
312+
*
313+
* if lst != null then
314+
* if lst.lengthCompare >= 3 then
315+
* if lst(0) == 1 then
316+
* val x1 = lst.drop(1)
317+
* val xs = x1.dropRight(2)
318+
* val x2 = lst.takeRight(2)
319+
* if x2(0) == 2 && x2(1) == 3 then
320+
* return[matchResult] ...
303321
*/
304-
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = args.lastOption match {
305-
case Some(VarArgPattern(arg)) =>
306-
val matchRemaining =
307-
if (args.length == 1) {
308-
val toSeq = ref(getResult)
309-
.select(defn.Seq_toSeq.matchingMember(getResult.info))
310-
letAbstract(toSeq) { toSeqResult =>
311-
patternPlan(toSeqResult, arg, onSuccess)
312-
}
313-
}
314-
else {
315-
val dropped = ref(getResult)
316-
.select(defn.Seq_drop.matchingMember(getResult.info))
317-
.appliedTo(Literal(Constant(args.length - 1)))
318-
letAbstract(dropped) { droppedResult =>
319-
patternPlan(droppedResult, arg, onSuccess)
320-
}
321-
}
322-
matchElemsPlan(getResult, args.init, exact = false, matchRemaining)
323-
case _ =>
324-
matchElemsPlan(getResult, args, exact = true, onSuccess)
325-
}
322+
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan =
323+
val (leading, varargAndRest) = args.span:
324+
case VarArgPattern(_) => false
325+
case _ => true
326+
varargAndRest match
327+
case VarArgPattern(arg) :: trailing =>
328+
val remaining =
329+
if leading.isEmpty then
330+
ref(getResult)
331+
.select(defn.Seq_toSeq.matchingMember(getResult.info))
332+
else
333+
ref(getResult)
334+
.select(defn.Seq_drop.matchingMember(getResult.info))
335+
.appliedTo(Literal(Constant(leading.length)))
336+
val matchRemaining =
337+
letAbstract(remaining): remainingResult =>
338+
if trailing.isEmpty then
339+
patternPlan(remainingResult, arg, onSuccess)
340+
else
341+
val seq = ref(remainingResult)
342+
.select(defn.Seq_dropRight.matchingMember(remainingResult.info))
343+
.appliedTo(Literal(Constant(trailing.length)))
344+
letAbstract(seq): seqResult =>
345+
val rest = ref(remainingResult)
346+
.select(defn.Seq_takeRight.matchingMember(remainingResult.info))
347+
.appliedTo(Literal(Constant(trailing.length)))
348+
val matchTrailing =
349+
letAbstract(rest): trailingResult =>
350+
matchElemsPlan(trailingResult, trailing, noLengthTest, onSuccess)
351+
patternPlan(seqResult, arg, matchTrailing)
352+
matchElemsPlan(getResult, leading,
353+
LengthTest(leading.length + trailing.length, exact = false),
354+
matchRemaining)
355+
case _ =>
356+
matchElemsPlan(getResult, args, LengthTest(args.length, exact = true), onSuccess)
326357

327358
/** Plan for matching the sequence in `getResult`
328359
*
@@ -491,7 +522,7 @@ object PatternMatcher {
491522
case WildcardPattern() | This(_) =>
492523
onSuccess
493524
case SeqLiteral(pats, _) =>
494-
matchElemsPlan(scrutinee, pats, exact = true, onSuccess)
525+
matchElemsPlan(scrutinee, pats, LengthTest(pats.length, exact = true), onSuccess)
495526
case _ =>
496527
TestPlan(EqualTest(tree), scrutinee, tree.span, onSuccess)
497528
}

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@ import config.Printers.typr
1818
import config.Feature
1919
import util.{SrcPos, Stats}
2020
import reporting.*
21-
import NameKinds.WildcardParamName
21+
import NameKinds.{WildcardParamName, TempResultName}
22+
import typer.Applications.{spread, HasSpreads}
23+
import typer.Implicits.SearchFailureType
24+
import Constants.Constant
2225
import cc.*
2326
import dotty.tools.dotc.transform.MacroAnnotations.hasMacroAnnotation
2427
import dotty.tools.dotc.core.NameKinds.DefaultGetterName
28+
import ast.TreeInfo
2529

2630
object PostTyper {
2731
val name: String = "posttyper"
@@ -376,6 +380,86 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
376380
case _ =>
377381
tpt
378382

383+
/** If one of `trees` is a spread of an expression that is not idempotent, lift out all
384+
* non-idempotent expressions (not just the spreads) and apply `within` to the resulting
385+
* pure references. Otherwise apply `within` to the original trees.
386+
*/
387+
private def evalSpreadsOnce(trees: List[Tree])(within: List[Tree] => Tree)(using Context): Tree =
388+
if trees.exists:
389+
case spread(elem) => !(exprPurity(elem) >= TreeInfo.Idempotent)
390+
case _ => false
391+
then
392+
val lifted = new mutable.ListBuffer[ValDef]
393+
def liftIfImpure(tree: Tree): Tree = tree match
394+
case tree @ Apply(fn, args) if fn.symbol == defn.spreadMethod =>
395+
cpy.Apply(tree)(fn, args.mapConserve(liftIfImpure))
396+
case _ if tpd.exprPurity(tree) >= TreeInfo.Idempotent =>
397+
tree
398+
case _ =>
399+
val vdef = SyntheticValDef(TempResultName.fresh(), tree).withSpan(tree.span)
400+
lifted += vdef
401+
Ident(vdef.namedType).withSpan(tree.span)
402+
val pureTrees = trees.mapConserve(liftIfImpure)
403+
Block(lifted.toList, within(pureTrees))
404+
else within(trees)
405+
406+
/** Translate sequence literal containing spread operators. Example:
407+
*
408+
* val xs, ys: List[Int]
409+
* [1, xs*, 2, ys*]
410+
*
411+
* Here the sequence literal is translated at typer to
412+
*
413+
* [1, spread(xs), 2, spread(ys)]
414+
*
415+
* This then translates to
416+
*
417+
* scala.runtime.VarArgsBuilder.ofInt(2 + xs.length + ys.length)
418+
* .add(1)
419+
* .addSeq(xs)
420+
* .add(2)
421+
* .addSeq(ys)
422+
*
423+
* The reason for doing a two-step typer/postTyper translation is that
424+
* at typer, we don't have all type variables instantiated yet.
425+
*/
426+
private def flattenSpreads[T](tree: SeqLiteral)(using Context): Tree =
427+
val SeqLiteral(rawElems, elemtpt) = tree
428+
val elemType = elemtpt.tpe
429+
val elemCls = elemType.classSymbol
430+
431+
evalSpreadsOnce(rawElems): elems =>
432+
val lengthCalls = elems.collect:
433+
case spread(elem) => elem.select(nme.length)
434+
val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length))
435+
val totalLength =
436+
lengthCalls.foldLeft(singleElemCount): (acc, len) =>
437+
acc.select(defn.Int_+).appliedTo(len)
438+
439+
def makeBuilder(name: String) =
440+
ref(defn.VarArgsBuilderModule).select(name.toTermName)
441+
442+
val builder =
443+
if defn.ScalaValueClasses().contains(elemCls) then
444+
makeBuilder(s"of${elemCls.name}")
445+
else if elemCls.derivesFrom(defn.ObjectClass) then
446+
makeBuilder("ofRef").appliedToType(elemType)
447+
else
448+
makeBuilder("generic").appliedToType(elemType)
449+
450+
elems.foldLeft(builder.appliedTo(totalLength)): (bldr, elem) =>
451+
elem match
452+
case spread(arg) =>
453+
if arg.tpe.derivesFrom(defn.SeqClass) then
454+
bldr.select("addSeq".toTermName).appliedTo(arg)
455+
else
456+
bldr.select("addArray".toTermName).appliedTo(
457+
arg.ensureConforms(defn.ArrayOf(elemType)))
458+
case _ => bldr.select("add".toTermName).appliedTo(elem)
459+
.select("result".toTermName)
460+
.appliedToNone
461+
end flattenSpreads
462+
379463
override def transform(tree: Tree)(using Context): Tree =
380464
try tree match {
381465
// TODO move CaseDef case lower: keep most probable trees first for performance
@@ -592,6 +676,8 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
592676
case tree: RefinedTypeTree =>
593677
Checking.checkPolyFunctionType(tree)
594678
super.transform(tree)
679+
case tree: SeqLiteral if tree.hasAttachment(HasSpreads) =>
680+
flattenSpreads(tree)
595681
case _: Quote | _: QuotePattern =>
596682
ctx.compilationUnit.needsStaging = true
597683
super.transform(tree)

0 commit comments

Comments
 (0)