Skip to content

Commit df2e033

Browse files
bishaboshaWojciechMazur
authored andcommitted
Record progress in the current run
Test that the callbacks are called with expected values [Cherry-picked 2b7a09e]
1 parent 767dc68 commit df2e033

File tree

13 files changed

+388
-30
lines changed

13 files changed

+388
-30
lines changed

compiler/src/dotty/tools/dotc/Run.scala

+120-11
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ import typer.Typer
1212
import typer.ImportInfo.withRootImports
1313
import Decorators._
1414
import io.AbstractFile
15-
import Phases.unfusedPhases
15+
import Phases.{unfusedPhases, Phase}
16+
17+
import sbt.interfaces.ProgressCallback
1618

1719
import util._
1820
import reporting.{Suppression, Action, Profile, ActiveProfile, NoProfile}
@@ -32,6 +34,9 @@ import scala.collection.mutable
3234
import scala.util.control.NonFatal
3335
import scala.io.Codec
3436

37+
import Run.Progress
38+
import scala.compiletime.uninitialized
39+
3540
/** A compiler run. Exports various methods to compile source files */
3641
class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with ConstraintRunInfo {
3742

@@ -155,14 +160,51 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
155160
}
156161

157162
/** The source files of all late entered symbols, as a set */
158-
private var lateFiles = mutable.Set[AbstractFile]()
163+
private val lateFiles = mutable.Set[AbstractFile]()
159164

160165
/** A cache for static references to packages and classes */
161166
val staticRefs = util.EqHashMap[Name, Denotation](initialCapacity = 1024)
162167

163168
/** Actions that need to be performed at the end of the current compilation run */
164169
private var finalizeActions = mutable.ListBuffer[() => Unit]()
165170

171+
private var _progress: Progress | Null = null // Set if progress reporting is enabled
172+
173+
/** Only safe to call if progress is being tracked. */
174+
private inline def trackProgress(using Context)(inline op: Context ?=> Progress => Unit): Unit =
175+
val local = _progress
176+
if local != null then
177+
op(using ctx)(local)
178+
179+
def doBeginUnit(unit: CompilationUnit)(using Context): Unit =
180+
trackProgress: progress =>
181+
progress.informUnitStarting(unit)
182+
183+
def doAdvanceUnit()(using Context): Unit =
184+
trackProgress: progress =>
185+
progress.unitc += 1 // trace that we completed a unit in the current phase
186+
progress.refreshProgress()
187+
188+
def doAdvanceLate()(using Context): Unit =
189+
trackProgress: progress =>
190+
progress.latec += 1 // trace that we completed a late compilation
191+
progress.refreshProgress()
192+
193+
private def doEnterPhase(currentPhase: Phase)(using Context): Unit =
194+
trackProgress: progress =>
195+
progress.enterPhase(currentPhase)
196+
197+
private def doAdvancePhase(currentPhase: Phase, wasRan: Boolean)(using Context): Unit =
198+
trackProgress: progress =>
199+
progress.unitc = 0 // reset unit count in current phase
200+
progress.seen += 1 // trace that we've seen a phase
201+
if wasRan then
202+
// add an extra traversal now that we completed a phase
203+
progress.traversalc += 1
204+
else
205+
// no phase was ran, remove a traversal from expected total
206+
progress.runnablePhases -= 1
207+
166208
/** Will be set to true if any of the compiled compilation units contains
167209
* a pureFunctions language import.
168210
*/
@@ -233,13 +275,15 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
233275
if ctx.settings.YnoDoubleBindings.value then
234276
ctx.base.checkNoDoubleBindings = true
235277

236-
def runPhases(using Context) = {
278+
def runPhases(allPhases: Array[Phase])(using Context) = {
237279
var lastPrintedTree: PrintedTree = NoPrintedTree
238280
val profiler = ctx.profiler
239281
var phasesWereAdjusted = false
240282

241-
for (phase <- ctx.base.allPhases)
242-
if (phase.isRunnable)
283+
for phase <- allPhases do
284+
doEnterPhase(phase)
285+
val phaseWillRun = phase.isRunnable
286+
if phaseWillRun then
243287
Stats.trackTime(s"phase time ms/$phase") {
244288
val start = System.currentTimeMillis
245289
val profileBefore = profiler.beforePhase(phase)
@@ -260,14 +304,21 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
260304
if !Feature.ccEnabledSomewhere then
261305
ctx.base.unlinkPhaseAsDenotTransformer(Phases.checkCapturesPhase.prev)
262306
ctx.base.unlinkPhaseAsDenotTransformer(Phases.checkCapturesPhase)
263-
307+
end if
308+
end if
309+
end if
310+
doAdvancePhase(phase, wasRan = phaseWillRun)
311+
end for
264312
profiler.finished()
265313
}
266314

267315
val runCtx = ctx.fresh
268316
runCtx.setProfiler(Profiler())
269317
unfusedPhases.foreach(_.initContext(runCtx))
270-
runPhases(using runCtx)
318+
val fusedPhases = runCtx.base.allPhases
319+
runCtx.withProgressCallback: cb =>
320+
_progress = Progress(cb, this, fusedPhases.length)
321+
runPhases(allPhases = fusedPhases)(using runCtx)
271322
if (!ctx.reporter.hasErrors)
272323
Rewrites.writeBack()
273324
suppressions.runFinished(hasErrors = ctx.reporter.hasErrors)
@@ -293,10 +344,9 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
293344
.withRootImports
294345

295346
def process()(using Context) =
296-
ctx.typer.lateEnterUnit(doTypeCheck =>
297-
if typeCheck then
298-
if compiling then finalizeActions += doTypeCheck
299-
else doTypeCheck()
347+
ctx.typer.lateEnterUnit(typeCheck)(doTypeCheck =>
348+
if compiling then finalizeActions += doTypeCheck
349+
else doTypeCheck()
300350
)
301351

302352
process()(using unitCtx)
@@ -399,7 +449,66 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
399449
}
400450

401451
object Run {
452+
453+
/**Computes the next MegaPhase for the given phase.*/
454+
def nextMegaPhase(phase: Phase)(using Context): Phase = phase.megaPhase.next.megaPhase
455+
456+
private class Progress(cb: ProgressCallback, private val run: Run, val initialPhases: Int):
457+
private[Run] var runnablePhases: Int = initialPhases // track how many phases we expect to run
458+
private[Run] var unitc: Int = 0 // current unit count in the current phase
459+
private[Run] var latec: Int = 0 // current late unit count
460+
private[Run] var traversalc: Int = 0 // completed traversals over all files
461+
private[Run] var seen: Int = 0 // how many phases we've seen so far
462+
463+
private var currPhase: Phase = uninitialized // initialized by enterPhase
464+
private var currPhaseName: String = uninitialized // initialized by enterPhase
465+
private var nextPhaseName: String = uninitialized // initialized by enterPhase
466+
467+
private def phaseNameFor(phase: Phase): String =
468+
if phase.exists then phase.phaseName
469+
else "<end>"
470+
471+
private[Run] def enterPhase(newPhase: Phase)(using Context): Unit =
472+
if newPhase ne currPhase then
473+
currPhase = newPhase
474+
currPhaseName = phaseNameFor(newPhase)
475+
nextPhaseName = phaseNameFor(Run.nextMegaPhase(newPhase))
476+
if seen > 0 then
477+
refreshProgress()
478+
479+
480+
/** Counts the number of completed full traversals over files, plus the number of units in the current phase */
481+
private def currentProgress()(using Context): Int =
482+
traversalc * run.files.size + unitc + latec
483+
484+
/**Total progress is computed as the sum of
485+
* - the number of traversals we expect to make over all files
486+
* - the number of late compilations
487+
*/
488+
private def totalProgress()(using Context): Int =
489+
runnablePhases * run.files.size + run.lateFiles.size
490+
491+
private def requireInitialized(): Unit =
492+
require((currPhase: Phase | Null) != null, "enterPhase was not called")
493+
494+
private[Run] def informUnitStarting(unit: CompilationUnit)(using Context): Unit =
495+
requireInitialized()
496+
cb.informUnitStarting(currPhaseName, unit)
497+
498+
private[Run] def refreshProgress()(using Context): Unit =
499+
requireInitialized()
500+
cb.progress(currentProgress(), totalProgress(), currPhaseName, nextPhaseName)
501+
402502
extension (run: Run | Null)
503+
def beginUnit(unit: CompilationUnit)(using Context): Unit =
504+
if run != null then run.doBeginUnit(unit)
505+
506+
def advanceUnit()(using Context): Unit =
507+
if run != null then run.doAdvanceUnit()
508+
509+
def advanceLate()(using Context): Unit =
510+
if run != null then run.doAdvanceLate()
511+
403512
def enrichedErrorMessage: Boolean = if run == null then false else run.myEnrichedErrorMessage
404513
def enrichErrorMessage(errorMessage: String)(using Context): String =
405514
if run == null then

compiler/src/dotty/tools/dotc/core/Phases.scala

+2
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,12 @@ object Phases {
324324
def runOn(units: List[CompilationUnit])(using runCtx: Context): List[CompilationUnit] =
325325
units.map { unit =>
326326
given unitCtx: Context = runCtx.fresh.setPhase(this.start).setCompilationUnit(unit).withRootImports
327+
ctx.run.beginUnit(unit)
327328
try run
328329
catch case ex: Throwable if !ctx.run.enrichedErrorMessage =>
329330
println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit"))
330331
throw ex
332+
finally ctx.run.advanceUnit()
331333
unitCtx.compilationUnit
332334
}
333335

compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@ class ReadTasty extends Phase {
2222
ctx.settings.fromTasty.value
2323

2424
override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] =
25-
withMode(Mode.ReadPositions)(units.flatMap(readTASTY(_)))
25+
withMode(Mode.ReadPositions)(units.flatMap(applyPhase(_)))
26+
27+
private def applyPhase(unit: CompilationUnit)(using Context): Option[CompilationUnit] =
28+
ctx.run.beginUnit(unit)
29+
try readTASTY(unit)
30+
finally ctx.run.advanceUnit()
2631

2732
def readTASTY(unit: CompilationUnit)(using Context): Option[CompilationUnit] = unit match {
2833
case unit: TASTYCompilationUnit =>

compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala

+4-1
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,13 @@ class Parser extends Phase {
4343
override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = {
4444
val unitContexts =
4545
for unit <- units yield
46+
ctx.run.beginUnit(unit)
4647
report.inform(s"parsing ${unit.source}")
4748
ctx.fresh.setCompilationUnit(unit).withRootImports
4849

49-
unitContexts.foreach(parse(using _))
50+
for given Context <- unitContexts do
51+
try parse
52+
finally ctx.run.advanceUnit()
5053
record("parsedTrees", ast.Trees.ntrees)
5154

5255
unitContexts.map(_.compilationUnit)

compiler/src/dotty/tools/dotc/transform/init/Checker.scala

+6-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import Phases._
1818
import scala.collection.mutable
1919

2020
import Semantic._
21+
import dotty.tools.unsupported
2122

2223
class Checker extends Phase:
2324

@@ -33,16 +34,17 @@ class Checker extends Phase:
3334
override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] =
3435
val checkCtx = ctx.fresh.setPhase(this.start)
3536
val traverser = new InitTreeTraverser()
36-
units.foreach { unit => traverser.traverse(unit.tpdTree) }
37+
for unit <- units do
38+
checkCtx.run.beginUnit(unit)
39+
try traverser.traverse(unit.tpdTree)
40+
finally ctx.run.advanceUnit()
3741
val classes = traverser.getClasses()
3842

3943
Semantic.checkClasses(classes)(using checkCtx)
4044

4145
units
4246

43-
def run(using Context): Unit =
44-
// ignore, we already called `Semantic.check()` in `runOn`
45-
()
47+
def run(using Context): Unit = unsupported("run")
4648

4749
class InitTreeTraverser extends TreeTraverser:
4850
private val classes: mutable.ArrayBuffer[ClassSymbol] = new mutable.ArrayBuffer

compiler/src/dotty/tools/dotc/typer/Namer.scala

+16-8
Original file line numberDiff line numberDiff line change
@@ -722,20 +722,27 @@ class Namer { typer: Typer =>
722722
* Will call the callback with an implementation of type checking
723723
* That will set the tpdTree and root tree for the compilation unit.
724724
*/
725-
def lateEnterUnit(typeCheckCB: (() => Unit) => Unit)(using Context) =
725+
def lateEnterUnit(typeCheck: Boolean)(typeCheckCB: (() => Unit) => Unit)(using Context) =
726726
val unit = ctx.compilationUnit
727727

728728
/** Index symbols in unit.untpdTree with lateCompile flag = true */
729729
def lateEnter()(using Context): Context =
730730
val saved = lateCompile
731731
lateCompile = true
732-
try index(unit.untpdTree :: Nil) finally lateCompile = saved
732+
try
733+
index(unit.untpdTree :: Nil)
734+
finally
735+
lateCompile = saved
736+
if !typeCheck then ctx.run.advanceLate()
733737

734738
/** Set the tpdTree and root tree of the compilation unit */
735739
def lateTypeCheck()(using Context) =
736-
unit.tpdTree = typer.typedExpr(unit.untpdTree)
737-
val phase = new transform.SetRootTree()
738-
phase.run
740+
try
741+
unit.tpdTree = typer.typedExpr(unit.untpdTree)
742+
val phase = new transform.SetRootTree()
743+
phase.run
744+
finally
745+
if typeCheck then ctx.run.advanceLate()
739746

740747
unit.untpdTree =
741748
if (unit.isJava) new JavaParser(unit.source).parse()
@@ -746,9 +753,10 @@ class Namer { typer: Typer =>
746753
// inline body annotations are set in namer, capturing the current context
747754
// we need to prepare the context for inlining.
748755
lateEnter()
749-
typeCheckCB { () =>
750-
lateTypeCheck()
751-
}
756+
if typeCheck then
757+
typeCheckCB { () =>
758+
lateTypeCheck()
759+
}
752760
}
753761
}
754762
end lateEnterUnit

compiler/src/dotty/tools/dotc/typer/TyperPhase.scala

+9-3
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,15 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
6363
for unit <- units yield
6464
val newCtx0 = ctx.fresh.setPhase(this.start).setCompilationUnit(unit)
6565
val newCtx = PrepareInlineable.initContext(newCtx0)
66+
newCtx.run.beginUnit(unit)
6667
report.inform(s"typing ${unit.source}")
6768
if (addRootImports)
6869
newCtx.withRootImports
6970
else
7071
newCtx
7172

72-
unitContexts.foreach(enterSyms(using _))
73+
for given Context <- unitContexts do
74+
enterSyms
7375

7476
ctx.base.parserPhase match {
7577
case p: ParserPhase =>
@@ -81,9 +83,13 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase {
8183
case _ =>
8284
}
8385

84-
unitContexts.foreach(typeCheck(using _))
86+
for given Context <- unitContexts do
87+
typeCheck
88+
8589
record("total trees after typer", ast.Trees.ntrees)
86-
unitContexts.foreach(javaCheck(using _)) // after typechecking to avoid cycles
90+
for given Context <- unitContexts do
91+
try javaCheck // after typechecking to avoid cycles
92+
finally ctx.run.advanceUnit()
8793

8894
val newUnits = unitContexts.map(_.compilationUnit).filterNot(discardAfterTyper)
8995
ctx.run.nn.checkSuspendedUnits(newUnits)

compiler/test/dotty/tools/DottyTest.scala

+15-1
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,14 @@ trait DottyTest extends ContextEscapeDetection {
4444
fc.setProperty(ContextDoc, new ContextDocstrings)
4545
}
4646

47+
protected def defaultCompiler: Compiler = Compiler()
48+
4749
private def compilerWithChecker(phase: String)(assertion: (tpd.Tree, Context) => Unit) = new Compiler {
50+
51+
private val baseCompiler = defaultCompiler
52+
4853
override def phases = {
49-
val allPhases = super.phases
54+
val allPhases = baseCompiler.phases
5055
val targetPhase = allPhases.flatten.find(p => p.phaseName == phase).get
5156
val groupsBefore = allPhases.takeWhile(x => !x.contains(targetPhase))
5257
val lastGroup = allPhases.find(x => x.contains(targetPhase)).get.takeWhile(x => !(x eq targetPhase))
@@ -67,6 +72,15 @@ trait DottyTest extends ContextEscapeDetection {
6772
run.runContext
6873
}
6974

75+
def checkAfterCompile(checkAfterPhase: String, sources: List[String])(assertion: Context => Unit): Context = {
76+
val c = defaultCompiler
77+
val run = c.newRun
78+
run.compileFromStrings(sources)
79+
val rctx = run.runContext
80+
assertion(rctx)
81+
rctx
82+
}
83+
7084
def checkTypes(source: String, typeStrings: String*)(assertion: (List[Type], Context) => Unit): Unit =
7185
checkTypes(source, List(typeStrings.toList)) { (tpess, ctx) => (tpess: @unchecked) match {
7286
case List(tpes) => assertion(tpes, ctx)

0 commit comments

Comments
 (0)