diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcConvertToEnum.scala b/presentation-compiler/src/main/dotty/tools/pc/PcConvertToEnum.scala new file mode 100644 index 000000000000..b27f3092d662 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/PcConvertToEnum.scala @@ -0,0 +1,125 @@ +package main.dotty.tools.pc + +import java.nio.file.Paths +import java.util as ju + +import scala.jdk.CollectionConverters.* +import scala.meta.pc.OffsetParams + +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.Names.TypeName +import dotty.tools.dotc.core.Symbols.ClassSymbol +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile + +import dotty.tools.pc.utils.InteractiveEnrichments.* + +import org.eclipse.lsp4j as l +import org.eclipse.lsp4j.TextEdit + +class PcConvertToEnum( + driver: InteractiveDriver, + params: OffsetParams +) { + + given Context = driver.currentCtx + + def convertToEnum: ju.List[TextEdit] = + val uri = params.uri + val filePath = Paths.get(uri) + driver.run(uri, SourceFile.virtual(filePath.toString, params.text)) + val pos = driver.sourcePosition(params) + Interactive.pathTo(driver.openedTrees(uri), pos) match + case (t @ TypeDef(name, rhs: Template)) :: tail if t.symbol.exists && t.symbol.is(Flags.Sealed) => + val sealedClassSymbol = t.symbol.asClass + val (implementations, companionTree) = collectImplementations(sealedClassSymbol) + val useFewerBraces: Boolean = + val checkForBracePosition = companionTree.map(_.span.end).getOrElse(t.span.end) - 1 + if(checkForBracePosition >= 0 && params.text().charAt(checkForBracePosition) == '}') false else true + val indentString = " " * detectIndentation(t) + val toReplace = new TextEdit(t.sourcePos.toLsp, makeReplaceText(name, rhs, implementations, indentString, useFewerBraces)) + val deleteTextEdits = toDelete(sealedClassSymbol, implementations, companionTree) + (toReplace :: deleteTextEdits).asJava + case _ => List.empty.asJava + + private def collectImplementations(sym: ClassSymbol)(using Context) = { + val collector = new DeepFolder[(List[TypeDef], Option[TypeDef])]({ + case ((acc, companion), t @ TypeDef(_, _: Template)) if t.symbol.isClass && !t.symbol.is(Flags.Synthetic) && t.symbol.info.parents.map(_.typeSymbol).exists(_ == sym) => + (t :: acc, companion) + case ((acc, None), t @TypeDef(_, _: Template)) if t.symbol.isClass && t.symbol.companionClass == sym => + (acc, Some(t)) + case (acc, _) => acc + }) + + val (trees, companion) = collector.apply((Nil, None), driver.compilationUnits(params.uri()).tpdTree) + (trees.reverse, companion) + } + + private def toDelete(sealedClassSymbol: ClassSymbol, implementations: List[TypeDef], companionTree: Option[TypeDef])(using Context) = + val moduleClass = sealedClassSymbol.companionClass + val deleteCompanionObject = companionTree.collect: + case ct @ TypeDef(name, t: Template) if getRelevantBodyParts(t).forall(implementations.contains(_)) => + expandedPosition(ct) + val toReexport = + implementations.groupBy(_.symbol.owner).collect: + case (symbol, trees) if symbol != moduleClass => trees.head + .toList + val toDelete = implementations.filterNot(toReexport.contains(_)) + deleteCompanionObject + .map(_ :: toDelete.filter(_.symbol.owner != moduleClass).map(expandedPosition(_))) + .getOrElse(toDelete.map(expandedPosition(_))).map(new TextEdit(_, "")) + ++ toReexport.map(_.sourcePos.toLsp).map(new TextEdit(_, s"export ${sealedClassSymbol.name}.*")) + + private def makeReplaceText(name: TypeName, rhs: Template, implementations: List[TypeDef], baseIndent: String, useFewerBraces: Boolean): String = + val constrString = showFromSource(rhs.constr) + val (simpleCases, complexCases) = + if constrString == "" + then implementations.partition(_.rhs.asInstanceOf[Template].constr.paramss == List(Nil)) + else (Nil, implementations) + val simpleCasesString = if simpleCases.isEmpty then "" else s"\n$baseIndent case " + simpleCases.map(tdefName).mkString("", ", ", "") + val complexCasesString = + complexCases.map: + case tdef @ TypeDef(_, t @ Template(constr, _, _, _)) => + val parentConstructorString = + t.parents.filterNot(_.span.isZeroExtent).map(showFromSource).mkString(", ") + s"\n$baseIndent case ${tdefName(tdef)}${showFromSource(constr)} extends $parentConstructorString" + case _ => "" + .mkString + val newRhs = getRelevantBodyParts(rhs).map(showFromSource).map(stat => s"\n$baseIndent $stat").mkString + val (begMarker, endMarker) = if useFewerBraces then (":", "") else (" {", s"\n$baseIndent}" ) + s"enum $name${constrString}$begMarker$newRhs$simpleCasesString$complexCasesString$endMarker" + + private def showFromSource(t: Tree): String = + params.text().substring(t.sourcePos.start, t.sourcePos.end) + + private def getRelevantBodyParts(rhs: Template): List[Tree] = + def isParamOrTypeParam(stat: Tree): Boolean = stat match + case stat: ValDef => stat.symbol.is(Flags.ParamAccessor) + case stat: TypeDef => stat.symbol.is(Flags.Param) + case _ => false + rhs.body.filterNot(stat => stat.span.isZeroExtent || stat.symbol.is(Flags.Synthetic) || isParamOrTypeParam(stat)) + + private def expandedPosition(tree: Tree): l.Range = + extendRangeToIncludeWhiteCharsAndTheFollowingNewLine(params.text().toCharArray())(tree.span.start, tree.span.end) match + case (start, end) => + tree.source.atSpan(tree.span.withStart(start).withEnd(end)).toLsp + + private def detectIndentation(tree: TypeDef): Int = + val text = params.text() + var curr = tree.span.start + var indent = 0 + while(curr >= 0 && text(curr) != '\n') + if text(curr).isWhitespace then indent += 1 + else indent = 0 + curr -= 1 + indent + + private def tdefName(tdef: TypeDef) = + if tdef.symbol.is(Flags.ModuleClass) then tdef.symbol.companionModule.name else tdef.name +} + +object PcConvertToEnum: + val codeActionId = "ConvertToEnum" diff --git a/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala b/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala index dc53525480c3..33c15f2c4107 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala @@ -40,6 +40,7 @@ import dotty.tools.dotc.interactive.InteractiveDriver import org.eclipse.lsp4j.DocumentHighlight import org.eclipse.lsp4j.TextEdit import org.eclipse.lsp4j as l +import main.dotty.tools.pc.PcConvertToEnum case class ScalaPresentationCompiler( @@ -61,7 +62,8 @@ case class ScalaPresentationCompiler( CodeActionId.ImplementAbstractMembers, CodeActionId.ExtractMethod, CodeActionId.InlineValue, - CodeActionId.InsertInferredType + CodeActionId.InsertInferredType, + PcConvertToEnum.codeActionId ).asJava def this() = this("", None, Nil, Nil) @@ -80,28 +82,32 @@ case class ScalaPresentationCompiler( params: OffsetParams, codeActionId: String, codeActionPayload: Optional[T] - ): CompletableFuture[ju.List[TextEdit]] = - (codeActionId, codeActionPayload.asScala) match - case ( - CodeActionId.ConvertToNamedArguments, - Some(argIndices: ju.List[_]) - ) => - val payload = - argIndices.asScala.collect { case i: Integer => i.toInt }.toSet - convertToNamedArguments(params, payload) - case (CodeActionId.ImplementAbstractMembers, _) => - implementAbstractMembers(params) - case (CodeActionId.InsertInferredType, _) => - insertInferredType(params) - case (CodeActionId.InlineValue, _) => - inlineValue(params) - case (CodeActionId.ExtractMethod, Some(extractionPos: OffsetParams)) => - params match { - case range: RangeParams => - extractMethod(range, extractionPos) - case _ => failedFuture(new IllegalArgumentException(s"Expected range parameters")) - } - case (id, _) => failedFuture(new IllegalArgumentException(s"Unsupported action id $id")) + ): CompletableFuture[ju.List[TextEdit]] = + (codeActionId, codeActionPayload.asScala) match + case ( + CodeActionId.ConvertToNamedArguments, + Some(argIndices: ju.List[_]) + ) => + val payload = + argIndices.asScala.collect { case i: Integer => i.toInt }.toSet + convertToNamedArguments(params, payload) + case (CodeActionId.ImplementAbstractMembers, _) => + implementAbstractMembers(params) + case (CodeActionId.InsertInferredType, _) => + insertInferredType(params) + case (CodeActionId.InlineValue, _) => + inlineValue(params) + case (CodeActionId.ExtractMethod, Some(extractionPos: OffsetParams)) => + params match { + case range: RangeParams => + extractMethod(range, extractionPos) + case _ => failedFuture(new IllegalArgumentException(s"Expected range parameters")) + } + case (PcConvertToEnum.codeActionId, _) => + compilerAccess.withNonInterruptableCompiler(List.empty[l.TextEdit].asJava, params.token) { + access => PcConvertToEnum(access.compiler(), params).convertToEnum + }(params.toQueryContext) + case (id, _) => failedFuture(new IllegalArgumentException(s"Unsupported action id $id")) private def failedFuture[T](e: Throwable): CompletableFuture[T] = val f = new CompletableFuture[T]() diff --git a/presentation-compiler/test/dotty/tools/pc/tests/edit/ConvertToEnumSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/edit/ConvertToEnumSuite.scala new file mode 100644 index 000000000000..2f0b853b447b --- /dev/null +++ b/presentation-compiler/test/dotty/tools/pc/tests/edit/ConvertToEnumSuite.scala @@ -0,0 +1,162 @@ +package dotty.tools.pc.tests.edit + +import java.net.URI +import java.util.Optional + +import scala.util.Failure +import scala.util.Success +import scala.util.Try + +import scala.meta.internal.jdk.CollectionConverters._ +import scala.meta.internal.metals.CompilerOffsetParams +import scala.meta.pc.CodeActionId + +import org.eclipse.lsp4j.TextEdit +import dotty.tools.pc.base.BaseCodeActionSuite +import main.dotty.tools.pc.PcConvertToEnum +import dotty.tools.pc.utils.TextEdits + +import org.junit.Test + +class ConvertToEnumSuite extends BaseCodeActionSuite: + + @Test def basic = + checkEdit( + """|sealed trait <>ow + |object Cow: + | class HolsteinFriesian extends Cow + | class Highland extends Cow + | class BrownSwiss extends Cow + |""".stripMargin, + """|enum Cow: + | case HolsteinFriesian, Highland, BrownSwiss + |""".stripMargin + ) + + @Test def `basic-with-params` = + checkEdit( + """|sealed class <>ow[T](val i: Int, j: Int) + |object Cow: + | class HolsteinFriesian extends Cow[1](1, 1) + | class Highland extends Cow[2](2, 2) + | class BrownSwiss extends Cow[3](3, 3) + |""".stripMargin, + """|enum Cow[T](val i: Int, j: Int): + | case HolsteinFriesian extends Cow[1](1, 1) + | case Highland extends Cow[2](2, 2) + | case BrownSwiss extends Cow[3](3, 3) + |""".stripMargin + ) + + @Test def `class-with-body` = + checkEdit( + """|trait Spotted + | + |sealed trait <>ow: + | def moo = "Mooo!" + | + |object Cow: + | def of(name: String) = HolsteinFriesian(name) + | case class HolsteinFriesian(name: String) extends Cow, Spotted + | class Highland extends Cow + | class BrownSwiss extends Cow + |""".stripMargin, + """|trait Spotted + | + |enum Cow: + | def moo = "Mooo!" + | case Highland, BrownSwiss + | case HolsteinFriesian(name: String) extends Cow, Spotted + | + |object Cow: + | def of(name: String) = HolsteinFriesian(name) + |""".stripMargin + ) + + @Test def `with-indentation` = + checkEdit( + """|object O { + | sealed class <>ow { + | def moo = "Mooo!" + | def mooooo = "Mooooooo!" + | } + | object Cow { + | case class HolsteinFriesian(name: String) extends Cow + | class Highland extends Cow + | class BrownSwiss extends Cow + | } + |} + |""".stripMargin, + """|object O { + | enum Cow { + | def moo = "Mooo!" + | def mooooo = "Mooooooo!" + | case Highland, BrownSwiss + | case HolsteinFriesian(name: String) extends Cow + | } + |} + |""".stripMargin + ) + + @Test def `case-objects` = + checkEdit( + """|sealed trait <>ow + |case object HolsteinFriesian extends Cow + |case object Highland extends Cow + |case object BrownSwiss extends Cow + |""".stripMargin, + """|enum Cow: + | case HolsteinFriesian, Highland, BrownSwiss + |export Cow.* + |""".stripMargin + ) + + @Test def `no-companion-object` = + checkEdit( + """|sealed trait <>ow + |class HolsteinFriesian extends Cow + |class Highland extends Cow + |class BrownSwiss extends Cow + |""".stripMargin, + """|enum Cow: + | case HolsteinFriesian, Highland, BrownSwiss + |export Cow.* + |""".stripMargin + ) + + def checkError( + original: String, + expectedError: String + ): Unit = + Try(getConversionToEnum(original)) match + case Failure(exception: Throwable) => + assertNoDiff( + exception.getCause().getMessage().replaceAll("\\[.*\\]", ""), + expectedError + ) + case Success(_) => + fail("Expected an error but got a result") + + def checkEdit( + original: String, + expected: String, + ): Unit = + val edits = getConversionToEnum(original) + val (code, _, _) = params(original) + val obtained = TextEdits.applyEdits(code, edits) + assertNoDiff(obtained, expected) + + def getConversionToEnum( + original: String, + filename: String = "file:/A.scala" + ): List[TextEdit] = { + val (code, _, offset) = params(original) + val result = presentationCompiler + .codeAction[Boolean]( + CompilerOffsetParams(URI.create(filename), code, offset, cancelToken), + PcConvertToEnum.codeActionId, + Optional.empty() + ) + .get() + result.asScala.toList + }