Skip to content

Add live variables analysis before type inference #267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Apr 15, 2025
Merged
Original file line number Diff line number Diff line change
@@ -104,6 +104,11 @@ class EtsApplicationGraphImpl(
.flatMap { it.methods }
.groupByTo(hashMapOf()) { it.name }
}
// Lazily built lookup: class signature -> (method name -> methods of that class with this name).
// Each signature is expected to map to exactly one class; `single()` throws otherwise.
private val classMethodsByName by lazy {
    projectClassesBySignature.mapValues { entry ->
        val onlyClass = entry.value.single()
        onlyClass.methods.groupBy { method -> method.name }
    }
}

private val cacheClassWithIdealSignature: MutableMap<EtsClassSignature, Maybe<EtsClass>> = hashMapOf()
private val cacheMethodWithIdealSignature: MutableMap<EtsMethodSignature, Maybe<EtsMethod>> = hashMapOf()
@@ -227,11 +232,8 @@ class EtsApplicationGraphImpl(

// If the complete signature match failed,
// try to find the unique not-the-same neighbour method in the same class:
val neighbors = cls.methods
.asSequence()
.filter { it.name == callee.name }
val neighbors = classMethodsByName[cls.signature].orEmpty()[callee.name].orEmpty()
.filterNot { it.name == node.method.name }
.toList()
if (neighbors.isNotEmpty()) {
val s = neighbors.singleOrNull()
?: error("Multiple methods with the same name: $neighbors")
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.usvm.dataflow.ts.infer

import org.jacodb.ets.base.EtsArrayAccess
import org.jacodb.ets.base.EtsAwaitExpr
import org.jacodb.ets.base.EtsCastExpr
import org.jacodb.ets.base.EtsConstant
import org.jacodb.ets.base.EtsEntity
@@ -129,6 +130,8 @@ fun EtsEntity.toPathOrNull(): AccessPath? = when (this) {

is EtsCastExpr -> arg.toPathOrNull()

is EtsAwaitExpr -> arg.toPathOrNull()

else -> null
}

202 changes: 120 additions & 82 deletions usvm-ts-dataflow/src/main/kotlin/org/usvm/dataflow/ts/infer/Alias.kt

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -14,9 +14,18 @@ class ForwardAnalyzer(
methodInitialTypes: Map<EtsMethod, Map<AccessPathBase, EtsTypeFact>>,
typeInfo: Map<EtsType, EtsTypeFact>,
doAddKnownTypes: Boolean = true,
doAliasAnalysis: Boolean = true,
val doLiveVariablesAnalysis: Boolean = true,
) : Analyzer<ForwardTypeDomainFact, AnalyzerEvent, EtsMethod, EtsStmt> {

override val flowFunctions = ForwardFlowFunctions(graph, methodInitialTypes, typeInfo, doAddKnownTypes)
override val flowFunctions = ForwardFlowFunctions(
graph = graph,
methodInitialTypes = methodInitialTypes,
typeInfo = typeInfo,
doAddKnownTypes = doAddKnownTypes,
doAliasAnalysis = doAliasAnalysis,
doLiveVariablesAnalysis = doLiveVariablesAnalysis,
)

override fun handleCrossUnitCall(
caller: Vertex<ForwardTypeDomainFact, EtsStmt>,
@@ -25,13 +34,27 @@ class ForwardAnalyzer(
error("No cross unit calls")
}

// Per-method cache of liveness information, filled on first request.
private val liveVariablesCache = hashMapOf<EtsMethod, LiveVariables>()

/**
 * Returns the (cached) liveness information for [method].
 *
 * When [doLiveVariablesAnalysis] is disabled, every method is mapped to [AlwaysAlive].
 */
private fun liveVariables(method: EtsMethod): LiveVariables {
    return liveVariablesCache.computeIfAbsent(method) { key ->
        when {
            doLiveVariablesAnalysis -> LiveVariables.from(key)
            else -> AlwaysAlive
        }
    }
}

/**
 * Tells whether [fact] describes a local variable that is no longer alive at [stmt].
 *
 * Only typed facts whose access path is rooted at a local can be "dying";
 * any other fact yields `false`.
 */
private fun variableIsDying(fact: ForwardTypeDomainFact, stmt: EtsStmt): Boolean {
    val typed = fact as? ForwardTypeDomainFact.TypedVariable ?: return false
    val local = typed.variable.base as? AccessPathBase.Local ?: return false
    val stillAlive = liveVariables(stmt.method).isAliveAt(local.name, stmt)
    return !stillAlive
}

override fun handleNewEdge(edge: Edge<ForwardTypeDomainFact, EtsStmt>): List<AnalyzerEvent> {
val (startVertex, currentVertex) = edge
val (current, currentFact) = currentVertex
val method = graph.methodOf(current)
val currentIsExit = current in graph.exitPoints(method) ||
(current is EtsNopStmt && graph.successors(current).none())
if (currentIsExit) {

if (currentIsExit || variableIsDying(currentFact, current)) {
return listOf(
ForwardSummaryAnalyzerEvent(
method = method,
Original file line number Diff line number Diff line change
@@ -5,8 +5,10 @@ import org.jacodb.ets.base.EtsAnyType
import org.jacodb.ets.base.EtsArithmeticExpr
import org.jacodb.ets.base.EtsArrayAccess
import org.jacodb.ets.base.EtsAssignStmt
import org.jacodb.ets.base.EtsAwaitExpr
import org.jacodb.ets.base.EtsBooleanConstant
import org.jacodb.ets.base.EtsCastExpr
import org.jacodb.ets.base.EtsClassType
import org.jacodb.ets.base.EtsFieldRef
import org.jacodb.ets.base.EtsInstanceCallExpr
import org.jacodb.ets.base.EtsLocal
@@ -23,6 +25,7 @@ import org.jacodb.ets.base.EtsStmt
import org.jacodb.ets.base.EtsStringConstant
import org.jacodb.ets.base.EtsThis
import org.jacodb.ets.base.EtsType
import org.jacodb.ets.base.EtsUnclearRefType
import org.jacodb.ets.base.EtsUndefinedConstant
import org.jacodb.ets.base.EtsUnknownType
import org.jacodb.ets.model.EtsMethod
@@ -44,16 +47,33 @@ class ForwardFlowFunctions(
val methodInitialTypes: Map<EtsMethod, Map<AccessPathBase, EtsTypeFact>>,
val typeInfo: Map<EtsType, EtsTypeFact>,
val doAddKnownTypes: Boolean = true,
val doAliasAnalysis: Boolean = true,
val doLiveVariablesAnalysis: Boolean = true,
) : FlowFunctions<ForwardTypeDomainFact, EtsMethod, EtsStmt> {

private val typeProcessor = TypeFactProcessor(graph.cp)

private val aliasesCache: MutableMap<EtsMethod, List<StmtAliasInfo>> = hashMapOf()

private fun getAliases(method: EtsMethod): List<StmtAliasInfo> {
return aliasesCache.computeIfAbsent(method) { MethodAliasInfo(method).computeAliases() }
return aliasesCache.computeIfAbsent(method) {
if (doAliasAnalysis) {
MethodAliasInfoImpl(method).computeAliases()
} else {
NoMethodAliasInfo(method).computeAliases()
}
}
}

// Per-method cache of liveness information, filled on first request.
private val liveVariablesCache = hashMapOf<EtsMethod, LiveVariables>()

// Cached liveness lookup; falls back to AlwaysAlive when the analysis is switched off.
private fun liveVariables(method: EtsMethod) =
    liveVariablesCache.computeIfAbsent(method) { key ->
        if (!doLiveVariablesAnalysis) AlwaysAlive else LiveVariables.from(key)
    }

override fun obtainPossibleStartFacts(method: EtsMethod): Collection<ForwardTypeDomainFact> {
val initialTypes = methodInitialTypes[method] ?: return listOf(Zero)

@@ -147,7 +167,16 @@ class ForwardFlowFunctions(
}
when (fact) {
Zero -> sequentZero(current)
is TypedVariable -> sequentFact(current, fact).myFilter()
is TypedVariable -> {
val liveVars = liveVariables(current.method)
sequentFact(current, fact).myFilter()
.filter {
when (val base = it.variable.base) {
is AccessPathBase.Local -> liveVars.isAliveAt(base.name, current)
else -> true
}
}
}
}
}

@@ -163,7 +192,13 @@ class ForwardFlowFunctions(
if (path.accesses.isNotEmpty()) {
check(path.accesses.size == 1)
val base = AccessPath(path.base, emptyList())
for (alias in preAliases.getAliases(base)) {
val aliases = preAliases.getAliases(base).filter {
when (val b = it.base) {
is AccessPathBase.Local -> liveVariables(current.method).isAliveAt(b.name, current)
else -> true
}
}
for (alias in aliases) {
val newPath = alias + path.accesses.single()
result += TypedVariable(newPath, type)
}
@@ -267,6 +302,7 @@ class ForwardFlowFunctions(
is EtsFieldRef -> r.toPath()
is EtsArrayAccess -> r.toPath()
is EtsCastExpr -> r.toPath()
is EtsAwaitExpr -> r.toPath()
else -> {
// logger.info { "TODO forward assign: $current" }
null
@@ -338,9 +374,30 @@ class ForwardFlowFunctions(
// Using the cast type directly is just a temporary solution to satisfy simple tests.
if (current.rhv is EtsCastExpr) {
val path = AccessPath(lhv.base, fact.variable.accesses)
// val type = EtsTypeFact.from((current.rhv as EtsCastExpr).type).intersect(fact.type) ?: fact.type
val type = EtsTypeFact.from((current.rhv as EtsCastExpr).type)

return listOf(fact, TypedVariable(path, type))
} else if (current.rhv is EtsAwaitExpr) {
val path = AccessPath(lhv.base, fact.variable.accesses)
val promiseType = fact.type

if (promiseType is EtsTypeFact.ObjectEtsTypeFact) {
val promiseClass = promiseType.cls

if (promiseClass is EtsClassType && promiseClass.signature.name == "Promise") {
val type = EtsTypeFact.from(
type = promiseClass.typeParameters.singleOrNull() ?: return listOf(fact)
)
return listOf(fact, TypedVariable(path, type))
}

if (promiseClass is EtsUnclearRefType && promiseClass.name.startsWith("Promise")) {
val type = EtsTypeFact.from(
type = promiseClass.typeParameters.singleOrNull() ?: return listOf(fact)
)
return listOf(fact, TypedVariable(path, type))
}
}
}

val path = AccessPath(lhv.base, fact.variable.accesses)
@@ -426,7 +483,15 @@ class ForwardFlowFunctions(
val path1 = lhv + fact.variable.accesses
result += TypedVariable(path1, fact.type)
// }
for (z in preAliases.getAliases(x)) {

val aliases = preAliases.getAliases(x).filter {
when (val b = it.base) {
is AccessPathBase.Local -> liveVariables(current.method).isAliveAt(b.name, current)
else -> true
}
}

for (z in aliases) {
// skip duplicate fields
// if (z.accesses.firstOrNull() != a) {
// TODO: what about z.accesses.last == a ?
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package org.usvm.dataflow.ts.infer

import org.jacodb.ets.base.EtsArrayAccess
import org.jacodb.ets.base.EtsAssignStmt
import org.jacodb.ets.base.EtsBinaryExpr
import org.jacodb.ets.base.EtsCallExpr
import org.jacodb.ets.base.EtsCallStmt
import org.jacodb.ets.base.EtsCastExpr
import org.jacodb.ets.base.EtsEntity
import org.jacodb.ets.base.EtsIfStmt
import org.jacodb.ets.base.EtsInstanceCallExpr
import org.jacodb.ets.base.EtsInstanceFieldRef
import org.jacodb.ets.base.EtsInstanceOfExpr
import org.jacodb.ets.base.EtsLocal
import org.jacodb.ets.base.EtsReturnStmt
import org.jacodb.ets.base.EtsStmt
import org.jacodb.ets.base.EtsSwitchStmt
import org.jacodb.ets.base.EtsTernaryExpr
import org.jacodb.ets.base.EtsThrowStmt
import org.jacodb.ets.base.EtsUnaryExpr
import org.jacodb.ets.base.EtsValue
import org.jacodb.ets.model.EtsMethod
import java.util.BitSet

/**
 * Answers whether a local variable is considered alive at a given statement.
 */
interface LiveVariables {
    /**
     * @param local name of the local variable.
     * @param stmt statement at which liveness is queried.
     * @return `true` if [local] is considered alive at [stmt].
     */
    fun isAliveAt(local: String, stmt: EtsStmt): Boolean

    companion object {
        // Methods with at most this many statements are not analysed at all.
        private const val THRESHOLD: Int = 20

        /**
         * Builds liveness information for [method].
         *
         * A real analysis ([LiveVariablesImpl]) is run only when the method has more than
         * [THRESHOLD] statements; smaller methods get [AlwaysAlive].
         */
        fun from(method: EtsMethod): LiveVariables {
            val stmtCount = method.cfg.stmts.size
            return when {
                stmtCount > THRESHOLD -> LiveVariablesImpl(method)
                else -> AlwaysAlive
            }
        }
    }
}

/**
 * Trivial [LiveVariables] implementation that treats every local as alive everywhere.
 */
object AlwaysAlive : LiveVariables {
    override fun isAliveAt(local: String, stmt: EtsStmt): Boolean {
        return true
    }
}

/**
 * Live-variables information for a single [method], computed eagerly in `init`
 * by a backward worklist iteration over the method's CFG.
 *
 * Only locals that appear as the left-hand side of some assignment in the method are tracked;
 * each such local gets one bit in the per-statement [BitSet]s.
 *
 * The set stored for a statement `s` is
 * `used(s) ∪ ⋃ { alive(succ) \ defined(succ) | succ ∈ successors(s) }`,
 * where `used(s)` also contains the assigned local of `s` itself
 * (the left-hand side is passed through `used()` as well).
 *
 * NOTE(review): a successor that both reads and redefines the same local (e.g. `x = x + 1`)
 * still removes that local from what is propagated to its predecessors — confirm this is intended.
 */
class LiveVariablesImpl(
    val method: EtsMethod,
) : LiveVariables {
    companion object {
        /** Names of the locals referenced by this entity; unknown entity kinds yield nothing. */
        private fun EtsEntity.used(): List<String> = when (this) {
            is EtsValue -> this.used()
            is EtsUnaryExpr -> arg.used()
            is EtsBinaryExpr -> left.used() + right.used()
            is EtsCallExpr -> this.used()
            is EtsCastExpr -> arg.used()
            is EtsInstanceOfExpr -> arg.used()
            is EtsTernaryExpr -> condition.used() + thenExpr.used() + elseExpr.used()
            else -> emptyList()
        }

        /** Names of the locals referenced by this value (the local itself, a field's instance, array and index). */
        private fun EtsValue.used(): List<String> = when (this) {
            is EtsLocal -> listOf(name)
            is EtsInstanceFieldRef -> listOf(instance.name)
            is EtsArrayAccess -> array.used() + index.used()
            else -> emptyList()
        }

        /** Names of the locals referenced by a call: the receiver (for instance calls) plus all arguments. */
        private fun EtsCallExpr.used(): List<String> = when (this) {
            is EtsInstanceCallExpr -> listOf(instance.name) + args.flatMap { it.used() }
            else -> args.flatMap { it.used() }
        }
    }

    // aliveAtStmt[i] — bit set of tracked locals considered alive at the statement with location index i.
    private val aliveAtStmt: Array<BitSet>

    // Bit index assigned to each tracked local name.
    private val indexOfName = hashMapOf<String, Int>()

    // definedAtStmt[i] — bit index of the local assigned by statement i, or -1 if it assigns no local.
    private val definedAtStmt = IntArray(method.cfg.stmts.size) { -1 }

    private fun emptyBitSet() = BitSet(indexOfName.size)
    private fun BitSet.copy() = clone() as BitSet

    init {
        // Pass 1: give every assigned local a stable bit index and remember where it is defined.
        for (stmt in method.cfg.stmts) {
            if (stmt is EtsAssignStmt) {
                val lhv = stmt.lhv
                if (lhv is EtsLocal) {
                    // Allocate the index once per name. Re-allocating from `indexOfName.size` on
                    // every assignment would make a re-assigned local collide with the next new
                    // local and leave stale indices in `definedAtStmt`.
                    val lhvIndex = indexOfName.getOrPut(lhv.name) { indexOfName.size }
                    definedAtStmt[stmt.location.index] = lhvIndex
                }
            }
        }

        aliveAtStmt = Array(method.cfg.stmts.size) { emptyBitSet() }

        // Pass 2: worklist iteration until no statement's set changes any more.
        val queue = method.cfg.stmts.toHashSet()
        while (queue.isNotEmpty()) {
            val stmt = queue.first()
            queue.remove(stmt)

            // Start from the tracked locals referenced by the statement itself.
            val aliveHere = emptyBitSet().apply {
                val usedLocals = when (stmt) {
                    is EtsAssignStmt -> stmt.lhv.used() + stmt.rhv.used()
                    is EtsCallStmt -> stmt.expr.used()
                    is EtsReturnStmt -> stmt.returnValue?.used().orEmpty()
                    is EtsIfStmt -> stmt.condition.used()
                    is EtsSwitchStmt -> stmt.arg.used()
                    is EtsThrowStmt -> stmt.arg.used()
                    else -> emptyList()
                }

                // Names without an index are not tracked and are simply skipped.
                usedLocals.mapNotNull { indexOfName[it] }.forEach { set(it) }
            }

            // Add what is alive at each successor, minus the local that successor (re)defines.
            for (succ in method.cfg.successors(stmt)) {
                val transferFromSucc = aliveAtStmt[succ.location.index].copy()
                val definedAtSucc = definedAtStmt[succ.location.index]
                if (definedAtSucc != -1) {
                    transferFromSucc.clear(definedAtSucc)
                }

                aliveHere.or(transferFromSucc)
            }

            // On change, store the new set and revisit predecessors (entry statements have none to revisit).
            if (aliveHere != aliveAtStmt[stmt.location.index]) {
                aliveAtStmt[stmt.location.index] = aliveHere
                if (stmt !in method.cfg.entries) {
                    queue.addAll(method.cfg.predecessors(stmt))
                }
            }
        }
    }

    /**
     * Returns whether [local] is considered alive at [stmt].
     *
     * Conservatively returns `true` when [stmt] has a negative location index, or when [local]
     * is not tracked (it is never the left-hand side of an assignment in [method]).
     */
    override fun isAliveAt(local: String, stmt: EtsStmt): Boolean {
        if (stmt.location.index < 0) return true
        val index = indexOfName[local] ?: return true
        return aliveAtStmt[stmt.location.index].get(index)
    }
}
Original file line number Diff line number Diff line change
@@ -12,7 +12,6 @@ import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeout
import mu.KotlinLogging
import org.jacodb.ets.base.ANONYMOUS_CLASS_PREFIX
import org.jacodb.ets.base.CONSTRUCTOR_NAME
import org.jacodb.ets.base.EtsReturnStmt
import org.jacodb.ets.base.EtsStmt
@@ -80,10 +79,12 @@ class TypeInferenceManager(
allMethods: List<EtsMethod> = entrypoints,
doAddKnownTypes: Boolean = true,
doInferAllLocals: Boolean = true,
doAliasAnalysis: Boolean = true,
): TypeInferenceResult = runBlocking {
val methodTypeScheme = collectSummaries(
startMethods = entrypoints,
doAddKnownTypes = doAddKnownTypes,
doAliasAnalysis = doAliasAnalysis,
)
val remainingMethodsForAnalysis = allMethods.filter { it !in methodTypeScheme.keys }

@@ -93,6 +94,7 @@ class TypeInferenceManager(
collectSummaries(
startMethods = remainingMethodsForAnalysis,
doAddKnownTypes = doAddKnownTypes,
doAliasAnalysis = doAliasAnalysis,
)
}

@@ -105,6 +107,7 @@ class TypeInferenceManager(
private suspend fun collectSummaries(
startMethods: List<EtsMethod>,
doAddKnownTypes: Boolean = true,
doAliasAnalysis: Boolean = true,
): Map<EtsMethod, Map<AccessPathBase, EtsTypeFact>> {
logger.info { "Preparing backward analysis" }
val backwardGraph = graph.reversed
@@ -187,7 +190,15 @@ class TypeInferenceManager(

logger.info { "Preparing forward analysis" }
val forwardGraph = graph
val forwardAnalyzer = ForwardAnalyzer(forwardGraph, methodTypeScheme, typeInfo, doAddKnownTypes)
val forwardAnalyzer = ForwardAnalyzer(
forwardGraph,
methodTypeScheme,
typeInfo,
doAddKnownTypes,
doAliasAnalysis = doAliasAnalysis,
doLiveVariablesAnalysis = true,
)

val forwardRunner = UniRunner(
traits = traits,
manager = this@TypeInferenceManager,
@@ -333,21 +344,11 @@ class TypeInferenceManager(
private fun getInferredCombinedThisTypes(
methodTypeScheme: Map<EtsMethod, Map<AccessPathBase, EtsTypeFact>>,
): Map<EtsClassSignature, EtsTypeFact> {
val classBySignature = graph.cp.projectAndSdkClasses
.groupByTo(hashMapOf()) { it.signature }

val allClasses = methodTypeScheme.keys
.map { it.enclosingClass }
.distinct()
.map { sig -> classBySignature[sig].orEmpty().first() }
.filterNot { it.name.startsWith(ANONYMOUS_CLASS_PREFIX) }

val forwardSummariesByClass = forwardSummaries
.entries.groupByTo(hashMapOf()) { (method, _) -> method.enclosingClass }

return allClasses.mapNotNull { cls ->
val clsMethods = (cls.methods + cls.ctor).toHashSet()
val combinedBackwardType = clsMethods
return graph.cp.projectClasses.mapNotNull { cls ->
val combinedBackwardType = (cls.methods + cls.ctor)
.mapNotNull { methodTypeScheme[it] }
.mapNotNull { facts -> facts[AccessPathBase.This] }.reduceOrNull { acc, type ->
typeProcessor.intersect(acc, type) ?: run {
Original file line number Diff line number Diff line change
@@ -70,6 +70,11 @@ class InferTypes : CliktCommand() {
help = "Do take into account the known types in scene"
).flag("--no-use-known-types", default = true)

// CLI flag pair `--alias-analysis` / `--no-alias-analysis`; the default value is `true`.
// The resulting value is passed on as `doAliasAnalysis` when the analysis is started in `run()`.
val enableAliasAnalysis by option(
"--alias-analysis",
help = "Enable alias analysis"
).flag("--no-alias-analysis", default = true)

override fun run() {
logger.info { "Running InferTypes" }
val startTime = System.currentTimeMillis()
@@ -91,6 +96,7 @@ class InferTypes : CliktCommand() {
entrypoints = dummyMains,
allMethods = publicMethods,
doAddKnownTypes = useKnownTypes,
doAliasAnalysis = enableAliasAnalysis,
)
}
logger.info { "Inferred types for ${resultBasic.inferredTypes.size} methods in $timeAnalyze" }
Original file line number Diff line number Diff line change
@@ -28,7 +28,8 @@ import java.io.File
@Disabled
class EtsTypeResolverPerformanceTest {
companion object {
const val RUNS = 5
const val WARMUP_ITERATIONS = 5
const val TEST_ITERATIONS = 5
const val OUTPUT_FILE = "performance_report.md"

@JvmStatic
@@ -38,7 +39,7 @@ class EtsTypeResolverPerformanceTest {
}

private fun runOnAbcProject(projectID: String, abcPath: String): PerformanceReport {
val report = generateReportForProject(projectID, abcPath, RUNS)
val report = generateReportForProject(projectID, abcPath, WARMUP_ITERATIONS, TEST_ITERATIONS)
return report
}

@@ -75,15 +76,14 @@ class EtsTypeResolverPerformanceTest {
)
)

val file = File(OUTPUT_FILE)
file.writeText(
buildString {
appendLine("|project|min time|max time|%|")
appendLine("|:--|:--|:--|:--|")
reports.forEach {
appendLine(it.dumpToString())
}
val reportStr = buildString {
appendLine("|project|min time|max time|avg time|median time|%|")
appendLine("|:--|:--|:--|:--|:--|:--|")
reports.forEach {
appendLine(it.dumpToString())
}
)
}
val file = File(OUTPUT_FILE)
file.writeText(reportStr)
}
}
Original file line number Diff line number Diff line change
@@ -48,6 +48,12 @@ object AbcProjects {
}

/**
 * Runs type inference on [scene] and then computes statistics for the obtained result.
 *
 * @return the inference result paired with the statistics calculated from it.
 */
fun runOnAbcProject(scene: EtsScene): Pair<TypeInferenceResult, TypeInferenceStatistics> {
    val inferred = inferTypes(scene)
    return Pair(inferred, calculateStatistics(scene, inferred))
}

fun inferTypes(scene: EtsScene): TypeInferenceResult {
val abcScene = when (val result = verify(scene)) {
is VerificationResult.Success -> scene
is VerificationResult.Fail -> scene.annotateWithTypes(result.erasureScheme)
@@ -66,13 +72,19 @@ object AbcProjects {
.analyze(entrypoint.mainMethods, allMethods, doAddKnownTypes = true)
.withGuessedTypes(guesser)

return result
}

fun calculateStatistics(scene: EtsScene, result: TypeInferenceResult): TypeInferenceStatistics {
val graphAbc = createApplicationGraph(scene)
val entrypoint = EntryPointsProcessor(scene).extractEntryPoints()
val sceneStatistics = TypeInferenceStatistics()
entrypoint.allMethods
.filter { it.cfg.stmts.isNotEmpty() }
.forEach {
val methodTypeFacts = MethodTypesFacts.from(result, it)
sceneStatistics.compareSingleMethodFactsWithTypesInScene(methodTypeFacts, it, graphAbc)
}
return Pair(result, sceneStatistics)
return sceneStatistics
}
}
Original file line number Diff line number Diff line change
@@ -16,20 +16,27 @@

package org.usvm.dataflow.ts.test.utils

import java.math.RoundingMode
import kotlin.time.Duration
import kotlin.time.DurationUnit
import kotlin.time.measureTimedValue


data class PerformanceReport(
val projectId: String,
val maxTime: Duration,
val avgTime: Duration,
val medianTime: Duration,
val minTime: Duration,
val improvement: Double,
) {
fun dumpToString(): String = listOf<Any>(
projectId,
minTime,
maxTime,
improvement
minTime.toString(unit = DurationUnit.SECONDS, decimals = 3),
maxTime.toString(unit = DurationUnit.SECONDS, decimals = 3),
avgTime.toString(unit = DurationUnit.SECONDS, decimals = 3),
medianTime.toString(unit = DurationUnit.SECONDS, decimals = 3),
improvement.toBigDecimal().setScale(4, RoundingMode.HALF_UP).toDouble()
).joinToString(
separator = "|",
prefix = "|",
@@ -40,24 +47,29 @@ data class PerformanceReport(
fun generateReportForProject(
projectId: String,
abcPath: String,
runCount: Int,
warmupIterationsCount: Int,
runIterationsCount: Int,
): PerformanceReport {
val abcScene = AbcProjects.getAbcProject(abcPath)

val results = List(runCount) {
val (statistics, time) = measureTimedValue {
AbcProjects.runOnAbcProject(abcScene).second
val results = List(warmupIterationsCount + runIterationsCount) {
val (result, time) = measureTimedValue {
AbcProjects.inferTypes(abcScene)
}
val statistics = AbcProjects.calculateStatistics(abcScene, result)
time to statistics.calculateImprovement()
}
}.drop(warmupIterationsCount)

val times = results.map { it.first }
val improvement = results.map { it.second }.distinct().single()
val improvement = results.map { it.second }.distinct().first()
val totalTime = times.reduce { acc, duration -> acc + duration }

return PerformanceReport(
projectId = projectId,
minTime = times.min(),
maxTime = times.max(),
avgTime = totalTime / runIterationsCount,
medianTime = times.sorted()[runIterationsCount / 2],
improvement = improvement
)
}