Skip to content

Commit 5710a74

Browse files
noti0na1tgodzik
authored andcommitted
Treat asserted set of terminated NotNullInfo as universal set; fix test
[Cherry-picked 00430c0]
1 parent 578c18b commit 5710a74

File tree

5 files changed

+40
-19
lines changed

5 files changed

+40
-19
lines changed

compiler/src/dotty/tools/dotc/core/Contexts.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -755,13 +755,13 @@ object Contexts {
755755

756756
extension (c: Context)
757757
def addNotNullInfo(info: NotNullInfo) =
758-
c.withNotNullInfos(c.notNullInfos.extendWith(info))
758+
if c.explicitNulls then c.withNotNullInfos(c.notNullInfos.extendWith(info)) else c
759759

760760
def addNotNullRefs(refs: Set[TermRef]) =
761-
c.addNotNullInfo(NotNullInfo(refs, Set()))
761+
if c.explicitNulls then c.addNotNullInfo(NotNullInfo(refs, Set())) else c
762762

763763
def withNotNullInfos(infos: List[NotNullInfo]): Context =
764-
if c.notNullInfos eq infos then c else c.fresh.setNotNullInfos(infos)
764+
if !c.explicitNulls || (c.notNullInfos eq infos) then c else c.fresh.setNotNullInfos(infos)
765765

766766
def relaxedOverrideContext: Context =
767767
c.withModeBits(c.mode &~ Mode.SafeNulls | Mode.RelaxedOverriding)

compiler/src/dotty/tools/dotc/typer/Nullables.scala

+12-10
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ object Nullables:
210210
*/
211211
@tailrec def impliesNotNull(ref: TermRef): Boolean = infos match
212212
case info :: infos1 =>
213-
if info.asserted != null && info.asserted.contains(ref) then true
213+
if info.asserted == null || info.asserted.contains(ref) then true
214214
else if info.retracted.contains(ref) then false
215215
else infos1.impliesNotNull(ref)
216216
case _ =>
@@ -290,8 +290,8 @@ object Nullables:
290290
extension (tree: Tree)
291291

292292
/* The `tree` with added nullability attachment */
293-
def withNotNullInfo(info: NotNullInfo): tree.type =
294-
if !info.isEmpty then tree.putAttachment(NNInfo, info)
293+
def withNotNullInfo(info: NotNullInfo)(using Context): tree.type =
294+
if ctx.explicitNulls && !info.isEmpty then tree.putAttachment(NNInfo, info)
295295
tree
296296

297297
/* Collect the nullability info from parts of `tree` */
@@ -310,13 +310,15 @@ object Nullables:
310310

311311
/* The nullability info of `tree` */
312312
def notNullInfo(using Context): NotNullInfo =
313-
val tree1 = stripInlined(tree)
314-
tree1.getAttachment(NNInfo) match
315-
case Some(info) if !ctx.erasedTypes => info
316-
case _ =>
317-
val nnInfo = tree1.collectNotNullInfo
318-
tree1.withNotNullInfo(nnInfo)
319-
nnInfo
313+
if !ctx.explicitNulls then NotNullInfo.empty
314+
else
315+
val tree1 = stripInlined(tree)
316+
tree1.getAttachment(NNInfo) match
317+
case Some(info) if !ctx.erasedTypes => info
318+
case _ =>
319+
val nnInfo = tree1.collectNotNullInfo
320+
tree1.withNotNullInfo(nnInfo)
321+
nnInfo
320322

321323
/* The nullability info of `tree`, assuming it is a condition that evaluates to `c` */
322324
def notNullInfoIf(c: Boolean)(using Context): NotNullInfo =

compiler/src/dotty/tools/dotc/typer/Typer.scala

+2
Original file line numberDiff line numberDiff line change
@@ -2514,6 +2514,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
25142514
val vdef1 = assignType(cpy.ValDef(vdef)(name, tpt1, rhs1), sym)
25152515
postProcessInfo(vdef1, sym)
25162516
vdef1.setDefTree
2517+
val nnInfo = rhs1.notNullInfo
2518+
vdef1.withNotNullInfo(if sym.is(Lazy) then nnInfo.retractedInfo else nnInfo)
25172519
}
25182520

25192521
private def retractDefDef(sym: Symbol)(using Context): Tree =
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
class C(val x: Int, val next: C | Null)
2+
3+
def test1(x: String | Null, c: C | Null): Int =
4+
return 0
5+
// We know that the following code is unreachable,
6+
// so we can treat `x`, `c`, and any variable/path non-nullable.
7+
x.length + c.next.x
8+
9+
def test2(x: String | Null, c: C | Null): Int =
10+
throw new Exception()
11+
x.length + c.next.x
12+
13+
def fail(): Nothing = ???
14+
15+
def test3(x: String | Null, c: C | Null): Int =
16+
fail()
17+
x.length + c.next.x

tests/explicit-nulls/unsafe-common/unsafe-overload.scala

+6-6
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ class S {
1616
val o: O = ???
1717

1818
locally {
19-
def h1(hh: String => String) = ???
20-
def h2(hh: Array[String] => Array[String]) = ???
19+
def h1(hh: String => String): Unit = ???
20+
def h2(hh: Array[String] => Array[String]): Unit = ???
2121
def f1(x: String | Null): String | Null = ???
2222
def f2(x: Array[String | Null]): Array[String | Null] = ???
2323

@@ -29,10 +29,10 @@ class S {
2929
}
3030

3131
locally {
32-
def h1(hh: String | Null => String | Null) = ???
33-
def h2(hh: Array[String | Null] => Array[String | Null]) = ???
32+
def h1(hh: String | Null => String | Null): Unit = ???
33+
def h2(hh: Array[String | Null] => Array[String | Null]): Unit = ???
3434
def g1(x: String): String = ???
35-
def g2(x: Array[String]): Array[String] = ???
35+
def g2(x: Array[String]): Array[String] = ???
3636

3737
h1(g1) // error
3838
h1(o.g) // error
@@ -51,7 +51,7 @@ class S {
5151

5252
locally {
5353
def g1(x: String): String = ???
54-
def g2(x: Array[String]): Array[String] = ???
54+
def g2(x: Array[String]): Array[String] = ???
5555

5656
o.i(g1) // error
5757
o.i(g2) // error

0 commit comments

Comments
 (0)