Skip to content

Commit bbe3650

Browse files
committed
Don't prematurely force info of currently defined fields with inferred types
Don't prematurely force info of currently defined fields with inferred types when computing captureSetImpliedByFields. Fixes #24335
1 parent b1a5500 commit bbe3650

File tree

4 files changed

+94
-60
lines changed

4 files changed

+94
-60
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ class CheckCaptures extends Recheck, SymTransformer:
964964
case cls: ClassSymbol =>
965965
var fieldClassifiers =
966966
for
967-
sym <- cls.info.decls.toList
967+
sym <- setup.fieldsWithExplicitTypes.getOrElse(cls, cls.info.decls.toList)
968968
if contributesFreshToClass(sym)
969969
case fresh: FreshCap <- sym.info.spanCaptureSet.elems
970970
.filter(_.isTerminalCapability)

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 75 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ trait SetupAPI:
4040
/** Check to do after the capture checking traversal */
4141
def postCheck()(using Context): Unit
4242

43+
/** A map from currently compiled class symbols to those of their fields
44+
* that have an explicit type given. Used in `captureSetImpliedByFields`
45+
* to avoid forcing fields with inferred types prematurely. The test file
46+
* where this matters is i24335.scala. The precise failure scenario which
47+
* this avoids is described in #24335.
48+
*/
49+
def fieldsWithExplicitTypes: collection.Map[ClassSymbol, List[Symbol]]
50+
4351
/** Used for error reporting:
4452
* Maps mutable variables to the symbols that capture them (in the
4553
* CheckCaptures sense, i.e. symbol is referred to from a different method
@@ -52,6 +60,7 @@ trait SetupAPI:
5260
* the function that is called.
5361
*/
5462
def anonFunCallee: collection.Map[Symbol, Symbol]
63+
5564
end SetupAPI
5665

5766
object Setup:
@@ -489,6 +498,12 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
489498
extension (sym: Symbol) def nextInfo(using Context): Type =
490499
atPhase(thisPhase.next)(sym.info)
491500

501+
val fieldsWithExplicitTypes: mutable.HashMap[ClassSymbol, List[Symbol]] = mutable.HashMap()
502+
503+
val capturedBy: mutable.HashMap[Symbol, Symbol] = mutable.HashMap()
504+
505+
val anonFunCallee: mutable.HashMap[Symbol, Symbol] = mutable.HashMap()
506+
492507
/** A traverser that adds knownTypes and updates symbol infos */
493508
def setupTraverser(checker: CheckerAPI) = new TreeTraverserWithPreciseImportContexts:
494509
import checker.*
@@ -693,59 +708,65 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
693708
case tree: Bind =>
694709
val sym = tree.symbol
695710
updateInfo(sym, transformInferredType(sym.info), sym.owner)
696-
case tree: TypeDef =>
697-
tree.symbol match
698-
case cls: ClassSymbol =>
699-
checkClassifiedInheritance(cls)
700-
val cinfo @ ClassInfo(prefix, _, ps, decls, selfInfo) = cls.classInfo
701-
702-
// Compute new self type
703-
def isInnerModule = cls.is(ModuleClass) && !cls.isStatic
704-
val selfInfo1 =
705-
if (selfInfo ne NoType) && !isInnerModule then
706-
// if selfInfo is explicitly given then use that one, except if
707-
// self info applies to non-static modules, these still need to be inferred
708-
selfInfo
709-
else if cls.isPureClass then
710-
// is cls is known to be pure, nothing needs to be added to self type
711-
selfInfo
712-
else if !cls.isEffectivelySealed && !cls.baseClassHasExplicitNonUniversalSelfType then
713-
// assume {cap} for completely unconstrained self types of publicly extensible classes
714-
CapturingType(cinfo.selfType, CaptureSet.universal)
715-
else
716-
// Infer the self type for the rest, which is all classes without explicit
717-
// self types (to which we also add nested module classes), provided they are
718-
// neither pure, nor are publicily extensible with an unconstrained self type.
719-
val cs = CaptureSet.ProperVar(cls, CaptureSet.emptyRefs, nestedOK = false, isRefining = false)
720-
if cls.derivesFrom(defn.Caps_Capability) then
721-
// If cls is a capability class, we need to add a fresh capability to ensure
722-
// we cannot treat the class as pure.
723-
CaptureSet.fresh(cls, cls.thisType, Origin.InDecl(cls)).subCaptures(cs)
724-
CapturingType(cinfo.selfType, cs)
725-
726-
// Compute new parent types
727-
val ps1 = inContext(ctx.withOwner(cls)):
728-
ps.mapConserve(transformExplicitType(_, NoSymbol, freshen = false))
729-
730-
// Install new types and if it is a module class also update module object
731-
if (selfInfo1 ne selfInfo) || (ps1 ne ps) then
732-
val newInfo = ClassInfo(prefix, cls, ps1, decls, selfInfo1)
733-
updateInfo(cls, newInfo, cls.owner)
734-
capt.println(i"update class info of $cls with parents $ps selfinfo $selfInfo to $newInfo")
735-
cls.thisType.asInstanceOf[ThisType].invalidateCaches()
736-
if cls.is(ModuleClass) then
737-
// if it's a module, the capture set of the module reference is the capture set of the self type
738-
val modul = cls.sourceModule
739-
val selfCaptures = selfInfo1 match
740-
case CapturingType(_, refs) => refs
741-
case _ => CaptureSet.empty
742-
// Note: Can't do val selfCaptures = selfInfo1.captureSet here.
743-
// This would potentially give stackoverflows when setup is run repeatedly.
744-
// One test case is pos-custom-args/captures/checkbounds.scala under
745-
// ccConfig.alwaysRepeatRun = true.
746-
updateInfo(modul, CapturingType(modul.info, selfCaptures), modul.owner)
747-
modul.termRef.invalidateCaches()
748-
case _ =>
711+
case tree @ TypeDef(_, impl: Template) =>
712+
val cls: ClassSymbol = tree.symbol.asClass
713+
714+
fieldsWithExplicitTypes(cls) =
715+
for
716+
case vd @ ValDef(_, tpt: TypeTree, _) <- impl.body
717+
if !tpt.isInferred && vd.symbol.exists && !vd.symbol.is(NonMember)
718+
yield
719+
vd.symbol
720+
721+
checkClassifiedInheritance(cls)
722+
val cinfo @ ClassInfo(prefix, _, ps, decls, selfInfo) = cls.classInfo
723+
724+
// Compute new self type
725+
def isInnerModule = cls.is(ModuleClass) && !cls.isStatic
726+
val selfInfo1 =
727+
if (selfInfo ne NoType) && !isInnerModule then
728+
// if selfInfo is explicitly given then use that one, except if
729+
// self info applies to non-static modules, these still need to be inferred
730+
selfInfo
731+
else if cls.isPureClass then
732+
// is cls is known to be pure, nothing needs to be added to self type
733+
selfInfo
734+
else if !cls.isEffectivelySealed && !cls.baseClassHasExplicitNonUniversalSelfType then
735+
// assume {cap} for completely unconstrained self types of publicly extensible classes
736+
CapturingType(cinfo.selfType, CaptureSet.universal)
737+
else
738+
// Infer the self type for the rest, which is all classes without explicit
739+
// self types (to which we also add nested module classes), provided they are
740+
// neither pure, nor are publicily extensible with an unconstrained self type.
741+
val cs = CaptureSet.ProperVar(cls, CaptureSet.emptyRefs, nestedOK = false, isRefining = false)
742+
if cls.derivesFrom(defn.Caps_Capability) then
743+
// If cls is a capability class, we need to add a fresh capability to ensure
744+
// we cannot treat the class as pure.
745+
CaptureSet.fresh(cls, cls.thisType, Origin.InDecl(cls)).subCaptures(cs)
746+
CapturingType(cinfo.selfType, cs)
747+
748+
// Compute new parent types
749+
val ps1 = inContext(ctx.withOwner(cls)):
750+
ps.mapConserve(transformExplicitType(_, NoSymbol, freshen = false))
751+
752+
// Install new types and if it is a module class also update module object
753+
if (selfInfo1 ne selfInfo) || (ps1 ne ps) then
754+
val newInfo = ClassInfo(prefix, cls, ps1, decls, selfInfo1)
755+
updateInfo(cls, newInfo, cls.owner)
756+
capt.println(i"update class info of $cls with parents $ps selfinfo $selfInfo to $newInfo")
757+
cls.thisType.asInstanceOf[ThisType].invalidateCaches()
758+
if cls.is(ModuleClass) then
759+
// if it's a module, the capture set of the module reference is the capture set of the self type
760+
val modul = cls.sourceModule
761+
val selfCaptures = selfInfo1 match
762+
case CapturingType(_, refs) => refs
763+
case _ => CaptureSet.empty
764+
// Note: Can't do val selfCaptures = selfInfo1.captureSet here.
765+
// This would potentially give stackoverflows when setup is run repeatedly.
766+
// One test case is pos-custom-args/captures/checkbounds.scala under
767+
// ccConfig.alwaysRepeatRun = true.
768+
updateInfo(modul, CapturingType(modul.info, selfCaptures), modul.owner)
769+
modul.termRef.invalidateCaches()
749770
case _ =>
750771
end postProcess
751772

@@ -918,16 +939,11 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
918939
else t
919940
case _ => mapFollowingAliases(t)
920941

921-
val capturedBy: mutable.HashMap[Symbol, Symbol] = mutable.HashMap[Symbol, Symbol]()
922-
923-
val anonFunCallee: mutable.HashMap[Symbol, Symbol] = mutable.HashMap[Symbol, Symbol]()
924-
925942
/** Run setup on a compilation unit with given `tree`.
926943
* @param recheckDef the function to run for completing a val or def
927944
*/
928945
def setupUnit(tree: Tree, checker: CheckerAPI)(using Context): Unit =
929-
inContext(ctx.withPhase(thisPhase)):
930-
setupTraverser(checker).traverse(tree)
946+
setupTraverser(checker).traverse(tree)(using ctx.withPhase(thisPhase))
931947

932948
// ------ Checks to run at Setup ----------------------------------------
933949

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i24335.scala:7:22 ----------------------------------------
2+
7 | val _: () -> Unit = l1 // error
3+
| ^^
4+
| Found: (C.this.l1 : () ->{C.this.c.io} Unit)
5+
| Required: () -> Unit
6+
|
7+
| Note that capability C.this.c.io is not included in capture set {}.
8+
|
9+
| longer explanation available when compiling with `-explain`
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
class IO extends caps.SharedCapability:
2+
def write(): Unit = ()
3+
4+
class C(val io: IO):
5+
val c = C(io)
6+
val l1 = () => c.io.write()
7+
val _: () -> Unit = l1 // error
8+
9+

0 commit comments

Comments
 (0)