diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskSolver.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskSolver.scala index 263251775694..f434cc6fe429 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskSolver.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskSolver.scala @@ -3,7 +3,7 @@ package io.joern.dataflowengineoss.queryengine import io.joern.dataflowengineoss.queryengine.QueryEngineStatistics.{PATH_CACHE_HITS, PATH_CACHE_MISSES} import io.joern.dataflowengineoss.semanticsloader.Semantics import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language.{toCfgNodeMethods, toExpressionMethods, _} +import io.shiftleft.semanticcpg.language._ import java.util.concurrent.Callable import scala.collection.mutable @@ -33,7 +33,7 @@ class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[Cfg val table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]] = mutable.Map() results(task.sink, path, table, task.callSiteStack) // TODO why do we update the call depth here? - val finalResults = table.get(task.fingerprint).get.map { r => + val finalResults = table(task.fingerprint).map { r => r.copy( taskStack = r.taskStack.dropRight(1) :+ r.fingerprint.copy(callDepth = task.callDepth), path = r.path ++ task.initialPath @@ -68,20 +68,20 @@ class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[Cfg * @param callSiteStack * This stack holds all call sites we expanded to arrive at the generation of the current task */ - private def results[NodeType <: CfgNode]( + private def results( sink: CfgNode, path: Vector[PathElement], table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]], callSiteStack: List[Call] )(implicit semantics: Semantics): Vector[ReachableByResult] = { - val curNode = path.head.node + val curNode = path.head.node.asInstanceOf[CfgNode] /** For each parent of the current node, determined via `expandIn`, check if results are available in the result * table. If not, determine results recursively. */ def computeResultsForParents() = { - deduplicateWithinTask(expandIn(curNode.asInstanceOf[CfgNode], path, callSiteStack).iterator.flatMap { parent => + deduplicateWithinTask(expandIn(curNode, path, callSiteStack).iterator.flatMap { parent => createResultsFromCacheOrCompute(parent, path) }.toVector) } @@ -117,14 +117,14 @@ class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[Cfg } def createResultsFromCacheOrCompute(elemToPrepend: PathElement, path: Vector[PathElement]) = { - val cachedResult = createFromTable(table, elemToPrepend, task.callSiteStack, path, task.callDepth) - if (cachedResult.isDefined) { - QueryEngineStatistics.incrementBy(PATH_CACHE_HITS, 1L) - cachedResult.get - } else { - QueryEngineStatistics.incrementBy(PATH_CACHE_MISSES, 1L) - val newPath = elemToPrepend +: path - results(sink, newPath, table, callSiteStack) + createFromTable(table, elemToPrepend, task.callSiteStack, path, task.callDepth) match { + case Some(result) => + QueryEngineStatistics.incrementBy(PATH_CACHE_HITS, 1L) + result + case None => + QueryEngineStatistics.incrementBy(PATH_CACHE_MISSES, 1L) + val newPath = elemToPrepend +: path + results(sink, newPath, table, callSiteStack) } } @@ -163,7 +163,7 @@ class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[Cfg */ val res = curNode match { // Case 1: we have reached a source => return result and continue traversing (expand into parents) - case x if sources.contains(x.asInstanceOf[NodeType]) => + case x if sources.contains(x) => if (x.isInstanceOf[MethodParameterIn]) { Vector( ReachableByResult(task.taskStack, path), @@ -172,31 +172,27 @@ class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[Cfg } else { Vector(ReachableByResult(task.taskStack, path)) ++ computeResultsForParents() } - // Case 2: we have reached a method parameter (that isn't a source) => return partial result and stop traversing + // Case 2: we have reached a method parameter (that isn't a source) + // => return partial result and stop traversing case _: MethodParameterIn => Vector(ReachableByResult(task.taskStack, path, partial = true)) - // Case 3: we have reached a call to an internal method without semantic (return value) and - // this isn't the start node => return partial result and stop traversing - case call: Call - if isCallToInternalMethodWithoutSemantic(call) - && !isArgOrRetOfMethodWeCameFrom(call, path) => + // Case 3: we have reached a call to an internal method without semantic (return value) + // => return partial result and stop traversing + case call: Call if isCallToInternalMethodWithoutSemantic(call) => createPartialResultForOutputArgOrRet() - - // Case 4: we have reached an argument to an internal method without semantic (output argument) and - // this isn't the start node nor is it the argument for the parameter we just expanded => return partial result and stop traversing + // Case 4: we have reached an argument to an internal method without semantic (output argument) and this isn't the start node + // => return partial result and stop traversing case arg: Expression if path.size > 1 - && arg.inCall.toList.exists(c => isCallToInternalMethodWithoutSemantic(c)) - && !arg.inCall.headOption.exists(x => isArgOrRetOfMethodWeCameFrom(x, path)) => + && arg.inCall.toList.exists(c => isCallToInternalMethodWithoutSemantic(c)) => createPartialResultForOutputArgOrRet() case _: MethodRef => createPartialResultForOutputArgOrRet() // All other cases: expand into parents - case _ => - computeResultsForParents() + case _ => computeResultsForParents() } - val key = TaskFingerprint(curNode.asInstanceOf[CfgNode], task.callSiteStack, task.callDepth) + val key = TaskFingerprint(curNode, task.callSiteStack, task.callDepth) table.updateWith(key) { case Some(existingValue) => Some(existingValue ++ res) case None => Some(res) @@ -204,11 +200,4 @@ class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[Cfg res } - private def isArgOrRetOfMethodWeCameFrom(call: Call, path: Vector[PathElement]): Boolean = - path match { - case Vector(_, PathElement(x: MethodReturn, _, _, _, _), _*) => methodsForCall(call).contains(x.method) - case Vector(_, PathElement(x: MethodParameterIn, _, _, _, _), _*) => methodsForCall(call).contains(x.method) - case _ => false - } - }