From c969255adbd724efcd07cc32a8e20157322d867e Mon Sep 17 00:00:00 2001 From: David Baker Effendi Date: Fri, 5 May 2023 12:22:48 +0200 Subject: [PATCH 1/3] javasrc2cpg: Identifier Decl + Method Type Args * Fixed bug where `MethodReturn` type will be `ANY` for generic types due to type arguments being added to the search * Created a prototype Type->TypeArgument tree for representing nested type arguments * Added `evalType` edges to `Ast` class --- .../joern/javasrc2cpg/passes/AstCreator.scala | 86 ++++++++++++++++--- .../querying/TypeInferenceTests.scala | 37 +++++++- .../src/main/scala/io/joern/x2cpg/Ast.scala | 25 +++++- .../scala/io/joern/x2cpg/AstCreatorBase.scala | 6 +- 4 files changed, 132 insertions(+), 22 deletions(-) diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreator.scala index eeece5a4afca..e9f5ee67f90d 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreator.scala @@ -1,6 +1,6 @@ package io.joern.javasrc2cpg.passes -import com.github.javaparser.ast.`type`.TypeParameter +import com.github.javaparser.ast.`type`.{ClassOrInterfaceType, Type, TypeParameter} import com.github.javaparser.ast.{CompilationUnit, Node, NodeList, PackageDeclaration} import com.github.javaparser.ast.body.{ AnnotationDeclaration, @@ -145,8 +145,11 @@ import io.shiftleft.codepropertygraph.generated.nodes.{ NewNamespaceBlock, NewNode, NewReturn, + NewType, + NewTypeArgument, NewTypeDecl, - NewTypeRef + NewTypeRef, + NewUnknown } import io.joern.x2cpg.{Ast, AstCreatorBase, Defines} import io.joern.x2cpg.datastructures.Global @@ -698,7 +701,7 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa val modifiers = List(newModifierNode(ModifierTypes.CONSTRUCTOR), newModifierNode(ModifierTypes.PUBLIC)) - methodAstWithAnnotations(constructorNode, Seq(thisAst), bodyAst, returnNode, modifiers) + methodAstWithAnnotations(constructorNode, Seq(thisAst), bodyAst, Ast(returnNode), modifiers) } private def astForEnumEntry(entry: EnumConstantDeclaration): Ast = { @@ -820,7 +823,7 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa constructorNode, thisAst :: parameterAsts, bodyAst, - methodReturn, + Ast(methodReturn), modifiers, annotationAsts ) @@ -1003,7 +1006,15 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa val expectedReturnType = Try(symbolSolver.toResolvedType(methodDeclaration.getType, classOf[ResolvedType])).toOption val returnTypeFullName = expectedReturnType .flatMap(typeInfoCalc.fullName) - .orElse(scopeStack.lookupVariableType(methodDeclaration.getTypeAsString, wildcardFallback = true)) + .orElse( + scopeStack.lookupVariableType(methodDeclaration.getTypeAsString.takeWhile(_ != '<'), wildcardFallback = true) + ) + .orElse(Option(s"${Defines.UnresolvedNamespace}.${methodDeclaration.getTypeAsString}")) + val typeNode = methodDeclaration.getType match { + case x: ClassOrInterfaceType if x.getTypeArguments.isPresent => + astForGenericType(x) + case _ => Ast() // This will be created by some TypePass + } scopeStack.pushNewScope(MethodScope(ExpectedType(returnTypeFullName, expectedReturnType))) @@ -1036,6 +1047,10 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa line(methodDeclaration.getType), column(methodDeclaration.getType) ) + val methodReturnAst = typeNode.root match { + case Some(t) => Ast(methodReturn).withEvalTypeEdge(methodReturn, t) + case None => Ast(methodReturn) + } val annotationAsts = methodDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toSeq @@ -1043,7 +1058,7 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa scopeStack.popScope() - methodAstWithAnnotations(methodNode, thisAst ++ parameterAsts, bodyAst, methodReturn, modifiers, annotationAsts) + methodAstWithAnnotations(methodNode, thisAst ++ parameterAsts, bodyAst, methodReturnAst, modifiers, annotationAsts) } private def constructorReturnNode(constructorDeclaration: ConstructorDeclaration): NewMethodReturn = { @@ -2080,6 +2095,45 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa } } + private def typeToTypeArgument(x: Type): Ast = { + val typeWithoutGeneric = x.asString().takeWhile(_ != '<') + val typeFullName = typeInfoCalc + .fullName(x) + .orElse(scopeStack.lookupVariableType(typeWithoutGeneric)) + .orElse(scopeStack.lookupVariableType(typeWithoutGeneric, wildcardFallback = true)) + .getOrElse(typeWithoutGeneric) + x match { + case t: ClassOrInterfaceType if t.getTypeArguments.isPresent => + Ast(NewTypeArgument().code(typeFullName).lineNumber(line(x)).columnNumber(column(x))) + .withChildren(astForTypeArgument(t.getTypeArguments.get().asScala.toList)) + case _ => + Ast(NewTypeArgument().code(typeFullName).lineNumber(line(x)).columnNumber(column(x))) + } + } + + private def astForTypeArgument(xs: List[Type]): Seq[Ast] = xs match { + case head :: next => typeToTypeArgument(head) +: astForTypeArgument(next) + case Nil => Seq.empty + } + + private def astForGenericType(x: ClassOrInterfaceType): Ast = { + val typeArguments = + if (x.getTypeArguments.isPresent) + astForTypeArgument(x.getTypeArguments.get().asScala.toList) + else Seq.empty + val typeWithoutGeneric = x.asString().takeWhile(_ != '<') + val typeFullName = typeInfoCalc + .fullName(x) + .orElse(scopeStack.lookupVariableType(typeWithoutGeneric)) + .orElse(scopeStack.lookupVariableType(typeWithoutGeneric, wildcardFallback = true)) + .getOrElse(typeWithoutGeneric) + val typeName = typeFullName match { + case t if t.contains(".") && !t.endsWith(".") => t.substring(t.lastIndexOf('.') + 1) + case t => t + } + Ast(NewType().name(typeName).fullName(typeFullName)).withChildren(typeArguments) + } + private def assignmentsForVarDecl( variables: Iterable[VariableDeclarator], lineNumber: Option[Integer], @@ -2104,22 +2158,28 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa // Need the actual resolvedType here for when the RHS is a lambda expression. val resolvedExpectedType = Try(symbolSolver.toResolvedType(variable.getType, classOf[ResolvedType])).toOption val initializerAsts = astsForExpression(initializer, ExpectedType(typeFullName, resolvedExpectedType)) - - val typeName = typeFullName - .map(TypeNodePass.fullToShortName) - .getOrElse(s"${Defines.UnresolvedNamespace}.${variable.getTypeAsString}") - val code = s"$typeName $name = ${initializerAsts.rootCodeOrEmpty}" + val code = s"${variable.getTypeAsString} $name = ${initializerAsts.rootCodeOrEmpty}" val callNode = newOperatorCallNode(Operators.assignment, code, typeFullName, lineNumber, columnNumber) + val typeNode = variable.getType match { + case x: ClassOrInterfaceType if x.getTypeArguments.isPresent => + astForGenericType(x) + case _ => Ast() // This will be created by the TypeUsagePass + } + val targetAst = scopeStack.lookupVariable(name) match { case Some(nodeTypeInfo) if nodeTypeInfo.isField && !nodeTypeInfo.isStatic => val thisType = scopeStack.getEnclosingTypeDecl.map(_.fullName) fieldAccessAst(NameConstants.This, thisType, name, typeFullName, line(variable), column(variable)) case maybeCorrespNode => - val identifier = identifierNode(variable, name, name, typeFullName.getOrElse(TypeConstants.Any)) - Ast(identifier).withRefEdges(identifier, maybeCorrespNode.map(_.node).toList) + val identifier = identifierNode(variable, name, name, typeFullName.getOrElse(TypeConstants.Any)) + val identifierAst = Ast(identifier).withRefEdges(identifier, maybeCorrespNode.map(_.node).toList) + typeNode.root match { + case Some(t) => identifierAst.withEvalTypeEdge(identifier, t) + case None => identifierAst + } } // Since all partial constructors will be dealt with here, don't pass them up. diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala index 507351bffa09..a7ed0a3702b0 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala @@ -3,7 +3,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal} +import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal, TypeArgument} import io.shiftleft.semanticcpg.language._ import java.io.File @@ -210,6 +210,10 @@ class JavaTypeRecoveryPassTests extends JavaSrcCode2CpgFixture(enableTypeRecover |package net.javaguides.hibernate; | |import java.util.List; + |import java.util.Map; + |import java.lang.Integer; + |import java.lang.Long; + |import java.lang.String; | |import org.hibernate.Session; |import org.hibernate.Transaction; @@ -235,8 +239,11 @@ class JavaTypeRecoveryPassTests extends JavaSrcCode2CpgFixture(enableTypeRecover | transaction.rollback(); | } | } - | | } + | + | public List> foo() { + | return new List<>(); + | } |} |""".stripMargin, Seq("net", "javaguides", "hibernate", "NamedQueryExample.java").mkString(File.separator) @@ -254,6 +261,32 @@ class JavaTypeRecoveryPassTests extends JavaSrcCode2CpgFixture(enableTypeRecover transaction.typeFullName shouldBe "org.hibernate.Transaction" transaction.dynamicTypeHintFullName.contains("null") } + + "present type arguments to generic types if known" in { + // List + // | Long + val Some(totalStudents) = cpg.identifier.nameExact("totalStudents").headOption + val List(list) = totalStudents.evalTypeOut.l + list.name shouldBe "List" + list.fullName shouldBe "java.util.List" + val List(long) = list.astOut.l + long.code shouldBe "java.lang.Long" + } + + "present (nested) type arguments to method returns" in { + // List + // | Map + // | String | Integer + val Some(fooReturn) = cpg.method("foo").methodReturn.headOption + val List(list) = fooReturn.evalTypeOut.l + list.name shouldBe "List" + list.fullName shouldBe "java.util.List" + val List(map) = list.astOut.collectAll[TypeArgument].l + map.code shouldBe "java.util.Map" + val List(string, integer) = map._astOut.collectAll[TypeArgument].l + string.code shouldBe "java.lang.String" + integer.code shouldBe "java.lang.Integer" + } } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala index e189657a7c69..a66ac850c2fb 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala @@ -54,6 +54,10 @@ object Ast { ast.bindsEdges.foreach { edge => diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS) } + + ast.evalTypeEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.EVAL_TYPE) + } } /** For all `order` fields that are unset, derive the `order` field automatically by determining the position of the @@ -86,7 +90,8 @@ case class Ast( refEdges: collection.Seq[AstEdge] = Vector.empty, bindsEdges: collection.Seq[AstEdge] = Vector.empty, receiverEdges: collection.Seq[AstEdge] = Vector.empty, - argEdges: collection.Seq[AstEdge] = Vector.empty + argEdges: collection.Seq[AstEdge] = Vector.empty, + evalTypeEdges: collection.Seq[AstEdge] = Vector.empty ) { def root: Option[NewNode] = nodes.headOption @@ -107,7 +112,8 @@ case class Ast( argEdges = argEdges ++ other.argEdges, receiverEdges = receiverEdges ++ other.receiverEdges, refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges + bindsEdges = bindsEdges ++ other.bindsEdges, + evalTypeEdges = evalTypeEdges ++ other.evalTypeEdges ) } @@ -119,7 +125,8 @@ case class Ast( argEdges = argEdges ++ other.argEdges, receiverEdges = receiverEdges ++ other.receiverEdges, refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges + bindsEdges = bindsEdges ++ other.bindsEdges, + evalTypeEdges = evalTypeEdges ++ other.evalTypeEdges ) } @@ -154,6 +161,10 @@ case class Ast( this.copy(receiverEdges = receiverEdges ++ List(AstEdge(src, dst))) } + def withEvalTypeEdge(src: NewNode, dst: NewNode): Ast = { + this.copy(evalTypeEdges = evalTypeEdges ++ List(AstEdge(src, dst))) + } + def withArgEdge(src: NewNode, dst: NewNode): Ast = { this.copy(argEdges = argEdges ++ List(AstEdge(src, dst))) } @@ -200,6 +211,10 @@ case class Ast( this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _))) } + def withEvalTypeEdges(src: NewNode, dsts: List[NewNode]): Ast = { + this.copy(evalTypeEdges = evalTypeEdges ++ dsts.map(AstEdge(src, _))) + } + /** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` and `argumentIndex` * fields of the new root node are set to `order`. */ @@ -229,6 +244,7 @@ case class Ast( val newRefEdges = refEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) val newBindsEdges = bindsEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) val newReceiverEdges = receiverEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + val newEvalTypeEdges = evalTypeEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) Ast(newNode) .copy( @@ -236,7 +252,8 @@ case class Ast( conditionEdges = newConditionEdges, refEdges = newRefEdges, bindsEdges = newBindsEdges, - receiverEdges = newReceiverEdges + receiverEdges = newReceiverEdges, + evalTypeEdges = newEvalTypeEdges ) .withChildren(newChildren) } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala index db35c73ac5ef..1fb80d1f8c9f 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala @@ -53,7 +53,7 @@ abstract class AstCreatorBase(filename: String) { methodReturn: NewMethodReturn, modifiers: Seq[NewModifier] = Nil ): Ast = - methodAstWithAnnotations(method, parameters, body, methodReturn, modifiers, annotations = Nil) + methodAstWithAnnotations(method, parameters, body, Ast(methodReturn), modifiers, annotations = Nil) /** Creates an AST that represents an entire method, including its content and with support for both method and * parameter annotations. @@ -62,7 +62,7 @@ abstract class AstCreatorBase(filename: String) { method: NewMethod, parameters: Seq[Ast], body: Ast, - methodReturn: NewMethodReturn, + methodReturn: Ast, modifiers: Seq[NewModifier] = Nil, annotations: Seq[Ast] = Nil ): Ast = @@ -71,7 +71,7 @@ abstract class AstCreatorBase(filename: String) { .withChild(body) .withChildren(modifiers.map(Ast(_))) .withChildren(annotations) - .withChild(Ast(methodReturn)) + .withChild(methodReturn) /** Creates an AST that represents a method stub, containing information about the method, its parameters, and the * return type. From c64c5a9276add7c788e59ed2921fa5a13909d48b Mon Sep 17 00:00:00 2001 From: David Baker Effendi Date: Fri, 5 May 2023 14:21:18 +0200 Subject: [PATCH 2/3] Reverting AST changes --- .../src/main/scala/io/joern/x2cpg/Ast.scala | 23 +++---------------- .../scala/io/joern/x2cpg/AstCreatorBase.scala | 6 ++--- 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala index a66ac850c2fb..549a64e8e576 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala @@ -54,10 +54,6 @@ object Ast { ast.bindsEdges.foreach { edge => diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS) } - - ast.evalTypeEdges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.EVAL_TYPE) - } } /** For all `order` fields that are unset, derive the `order` field automatically by determining the position of the @@ -90,8 +86,7 @@ case class Ast( refEdges: collection.Seq[AstEdge] = Vector.empty, bindsEdges: collection.Seq[AstEdge] = Vector.empty, receiverEdges: collection.Seq[AstEdge] = Vector.empty, - argEdges: collection.Seq[AstEdge] = Vector.empty, - evalTypeEdges: collection.Seq[AstEdge] = Vector.empty + argEdges: collection.Seq[AstEdge] = Vector.empty ) { def root: Option[NewNode] = nodes.headOption @@ -112,8 +107,7 @@ case class Ast( argEdges = argEdges ++ other.argEdges, receiverEdges = receiverEdges ++ other.receiverEdges, refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges, - evalTypeEdges = evalTypeEdges ++ other.evalTypeEdges + bindsEdges = bindsEdges ++ other.bindsEdges ) } @@ -125,8 +119,7 @@ case class Ast( argEdges = argEdges ++ other.argEdges, receiverEdges = receiverEdges ++ other.receiverEdges, refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges, - evalTypeEdges = evalTypeEdges ++ other.evalTypeEdges + bindsEdges = bindsEdges ++ other.bindsEdges ) } @@ -161,10 +154,6 @@ case class Ast( this.copy(receiverEdges = receiverEdges ++ List(AstEdge(src, dst))) } - def withEvalTypeEdge(src: NewNode, dst: NewNode): Ast = { - this.copy(evalTypeEdges = evalTypeEdges ++ List(AstEdge(src, dst))) - } - def withArgEdge(src: NewNode, dst: NewNode): Ast = { this.copy(argEdges = argEdges ++ List(AstEdge(src, dst))) } @@ -211,10 +200,6 @@ case class Ast( this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _))) } - def withEvalTypeEdges(src: NewNode, dsts: List[NewNode]): Ast = { - this.copy(evalTypeEdges = evalTypeEdges ++ dsts.map(AstEdge(src, _))) - } - /** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` and `argumentIndex` * fields of the new root node are set to `order`. */ @@ -244,7 +229,6 @@ case class Ast( val newRefEdges = refEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) val newBindsEdges = bindsEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) val newReceiverEdges = receiverEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) - val newEvalTypeEdges = evalTypeEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) Ast(newNode) .copy( @@ -253,7 +237,6 @@ case class Ast( refEdges = newRefEdges, bindsEdges = newBindsEdges, receiverEdges = newReceiverEdges, - evalTypeEdges = newEvalTypeEdges ) .withChildren(newChildren) } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala index 1fb80d1f8c9f..db35c73ac5ef 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala @@ -53,7 +53,7 @@ abstract class AstCreatorBase(filename: String) { methodReturn: NewMethodReturn, modifiers: Seq[NewModifier] = Nil ): Ast = - methodAstWithAnnotations(method, parameters, body, Ast(methodReturn), modifiers, annotations = Nil) + methodAstWithAnnotations(method, parameters, body, methodReturn, modifiers, annotations = Nil) /** Creates an AST that represents an entire method, including its content and with support for both method and * parameter annotations. @@ -62,7 +62,7 @@ abstract class AstCreatorBase(filename: String) { method: NewMethod, parameters: Seq[Ast], body: Ast, - methodReturn: Ast, + methodReturn: NewMethodReturn, modifiers: Seq[NewModifier] = Nil, annotations: Seq[Ast] = Nil ): Ast = @@ -71,7 +71,7 @@ abstract class AstCreatorBase(filename: String) { .withChild(body) .withChildren(modifiers.map(Ast(_))) .withChildren(annotations) - .withChild(methodReturn) + .withChild(Ast(methodReturn)) /** Creates an AST that represents a method stub, containing information about the method, its parameters, and the * return type. From 24260ba3ceefa2aaec3bdcd007d569072e3da9dd Mon Sep 17 00:00:00 2001 From: David Baker Effendi Date: Fri, 5 May 2023 15:09:03 +0200 Subject: [PATCH 3/3] Redo implementation by adding it to `TypeNode` pass as a demo for now --- .../io/joern/javasrc2cpg/JavaSrc2Cpg.scala | 14 ++-- .../javasrc2cpg/passes/AstCreationPass.scala | 9 ++- .../joern/javasrc2cpg/passes/AstCreator.scala | 65 +++++++------------ .../javasrc2cpg/querying/GenericsTests.scala | 6 +- .../querying/TypeInferenceTests.scala | 12 +--- .../src/main/scala/io/joern/x2cpg/Ast.scala | 2 +- .../joern/x2cpg/datastructures/Global.scala | 40 ++++++++++++ .../x2cpg/passes/frontend/TypeNodePass.scala | 63 ++++++++++++++++-- 8 files changed, 145 insertions(+), 66 deletions(-) diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala index e648d291661d..99e6bda10268 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala @@ -7,13 +7,7 @@ import com.github.javaparser.ast.Node.Parsedness import com.github.javaparser.symbolsolver.JavaSymbolSolver import com.github.javaparser.symbolsolver.resolution.typesolvers.JarTypeSolver import com.github.javaparser.{JavaParser, ParserConfiguration} -import io.joern.javasrc2cpg.passes.{ - AstCreationPass, - ConfigFileCreationPass, - JavaTypeHintCallLinker, - JavaTypeRecoveryPass, - TypeInferencePass -} +import io.joern.javasrc2cpg.passes._ import io.joern.javasrc2cpg.typesolvers.{CachingReflectionTypeSolver, EagerSourceTypeSolver, SimpleCombinedTypeSolver} import io.joern.javasrc2cpg.util.Delombok.DelombokMode import io.joern.javasrc2cpg.util.{Delombok, SourceRootFinder} @@ -85,7 +79,11 @@ class JavaSrc2Cpg extends X2CpgFrontend[Config] { val astCreationPass = new AstCreationPass(javaparserAsts.analysisAsts, config, cpg, symbolSolver) astCreationPass.createAndApply() new ConfigFileCreationPass(config.inputPath, cpg).createAndApply() - new TypeNodePass(astCreationPass.global.usedTypes.keys().asScala.toList, cpg).createAndApply() + new TypeNodePass( + astCreationPass.global.usedTypes.keys().asScala.toList, + cpg, + nodesWithGenericTypes = astCreationPass.global.nodesWithGenericTypes.asScala.toMap + ).createAndApply() new TypeInferencePass(cpg).createAndApply() } } diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreationPass.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreationPass.scala index 49fa7a4e4196..dd952067bda5 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreationPass.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreationPass.scala @@ -1,12 +1,17 @@ package io.joern.javasrc2cpg.passes import com.github.javaparser.symbolsolver.JavaSymbolSolver +import io.joern.javasrc2cpg.{Config, JpAstWithMeta} +import io.joern.x2cpg.datastructures.{CodeTree, Global, TreeNode} import io.shiftleft.codepropertygraph.Cpg +import io.shiftleft.codepropertygraph.generated.EdgeTypes +import io.shiftleft.codepropertygraph.generated.nodes.{NewNode, NewType, NewTypeDecl, NewTypeParameter} import io.shiftleft.passes.ConcurrentWriterCpgPass -import io.joern.javasrc2cpg.{Config, JpAstWithMeta} -import io.joern.x2cpg.datastructures.Global import org.slf4j.LoggerFactory +import scala.collection.mutable +import scala.jdk.CollectionConverters.MapHasAsScala + class AstCreationPass(asts: List[JpAstWithMeta], config: Config, cpg: Cpg, symbolSolver: JavaSymbolSolver) extends ConcurrentWriterCpgPass[JpAstWithMeta](cpg) { diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreator.scala index e9f5ee67f90d..58b73fd21e35 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreator.scala @@ -148,12 +148,10 @@ import io.shiftleft.codepropertygraph.generated.nodes.{ NewType, NewTypeArgument, NewTypeDecl, - NewTypeRef, - NewUnknown + NewTypeRef } import io.joern.x2cpg.{Ast, AstCreatorBase, Defines} -import io.joern.x2cpg.datastructures.Global -import io.joern.x2cpg.passes.frontend.TypeNodePass +import io.joern.x2cpg.datastructures.{Global, JavaTree, TreeNode} import io.joern.x2cpg.utils.AstPropertiesUtil._ import io.joern.x2cpg.utils.NodeBuilders import io.joern.x2cpg.AstNodeBuilder @@ -701,7 +699,7 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa val modifiers = List(newModifierNode(ModifierTypes.CONSTRUCTOR), newModifierNode(ModifierTypes.PUBLIC)) - methodAstWithAnnotations(constructorNode, Seq(thisAst), bodyAst, Ast(returnNode), modifiers) + methodAstWithAnnotations(constructorNode, Seq(thisAst), bodyAst, returnNode, modifiers) } private def astForEnumEntry(entry: EnumConstantDeclaration): Ast = { @@ -823,7 +821,7 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa constructorNode, thisAst :: parameterAsts, bodyAst, - Ast(methodReturn), + methodReturn, modifiers, annotationAsts ) @@ -1010,11 +1008,6 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa scopeStack.lookupVariableType(methodDeclaration.getTypeAsString.takeWhile(_ != '<'), wildcardFallback = true) ) .orElse(Option(s"${Defines.UnresolvedNamespace}.${methodDeclaration.getTypeAsString}")) - val typeNode = methodDeclaration.getType match { - case x: ClassOrInterfaceType if x.getTypeArguments.isPresent => - astForGenericType(x) - case _ => Ast() // This will be created by some TypePass - } scopeStack.pushNewScope(MethodScope(ExpectedType(returnTypeFullName, expectedReturnType))) @@ -1047,9 +1040,10 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa line(methodDeclaration.getType), column(methodDeclaration.getType) ) - val methodReturnAst = typeNode.root match { - case Some(t) => Ast(methodReturn).withEvalTypeEdge(methodReturn, t) - case None => Ast(methodReturn) + methodDeclaration.getType match { + case x: ClassOrInterfaceType if x.getTypeArguments.isPresent => + global.nodesWithGenericTypes.put(methodReturn, astForGenericType(x)) + case _ => } val annotationAsts = methodDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toSeq @@ -1058,7 +1052,7 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa scopeStack.popScope() - methodAstWithAnnotations(methodNode, thisAst ++ parameterAsts, bodyAst, methodReturnAst, modifiers, annotationAsts) + methodAstWithAnnotations(methodNode, thisAst ++ parameterAsts, bodyAst, methodReturn, modifiers, annotationAsts) } private def constructorReturnNode(constructorDeclaration: ConstructorDeclaration): NewMethodReturn = { @@ -2095,43 +2089,39 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa } } - private def typeToTypeArgument(x: Type): Ast = { + private def typeToTypeArgument(x: Type): TreeNode = { val typeWithoutGeneric = x.asString().takeWhile(_ != '<') val typeFullName = typeInfoCalc .fullName(x) .orElse(scopeStack.lookupVariableType(typeWithoutGeneric)) .orElse(scopeStack.lookupVariableType(typeWithoutGeneric, wildcardFallback = true)) - .getOrElse(typeWithoutGeneric) + .getOrElse(s"${Defines.UnresolvedNamespace}.$typeWithoutGeneric") x match { case t: ClassOrInterfaceType if t.getTypeArguments.isPresent => - Ast(NewTypeArgument().code(typeFullName).lineNumber(line(x)).columnNumber(column(x))) + TreeNode(typeFullName) .withChildren(astForTypeArgument(t.getTypeArguments.get().asScala.toList)) case _ => - Ast(NewTypeArgument().code(typeFullName).lineNumber(line(x)).columnNumber(column(x))) + TreeNode(typeFullName) } } - private def astForTypeArgument(xs: List[Type]): Seq[Ast] = xs match { + private def astForTypeArgument(xs: List[Type]): List[TreeNode] = xs match { case head :: next => typeToTypeArgument(head) +: astForTypeArgument(next) - case Nil => Seq.empty + case Nil => List.empty } - private def astForGenericType(x: ClassOrInterfaceType): Ast = { + private def astForGenericType(x: ClassOrInterfaceType): JavaTree = { val typeArguments = if (x.getTypeArguments.isPresent) astForTypeArgument(x.getTypeArguments.get().asScala.toList) - else Seq.empty + else List.empty val typeWithoutGeneric = x.asString().takeWhile(_ != '<') val typeFullName = typeInfoCalc .fullName(x) .orElse(scopeStack.lookupVariableType(typeWithoutGeneric)) .orElse(scopeStack.lookupVariableType(typeWithoutGeneric, wildcardFallback = true)) - .getOrElse(typeWithoutGeneric) - val typeName = typeFullName match { - case t if t.contains(".") && !t.endsWith(".") => t.substring(t.lastIndexOf('.') + 1) - case t => t - } - Ast(NewType().name(typeName).fullName(typeFullName)).withChildren(typeArguments) + .getOrElse(s"${Defines.UnresolvedNamespace}.$typeWithoutGeneric") + new JavaTree(io.joern.x2cpg.datastructures.TreeNode(typeFullName).withChildren(typeArguments)) } private def assignmentsForVarDecl( @@ -2162,24 +2152,19 @@ class AstCreator(filename: String, javaParserAst: CompilationUnit, global: Globa val callNode = newOperatorCallNode(Operators.assignment, code, typeFullName, lineNumber, columnNumber) - val typeNode = variable.getType match { - case x: ClassOrInterfaceType if x.getTypeArguments.isPresent => - astForGenericType(x) - case _ => Ast() // This will be created by the TypeUsagePass - } - val targetAst = scopeStack.lookupVariable(name) match { case Some(nodeTypeInfo) if nodeTypeInfo.isField && !nodeTypeInfo.isStatic => val thisType = scopeStack.getEnclosingTypeDecl.map(_.fullName) fieldAccessAst(NameConstants.This, thisType, name, typeFullName, line(variable), column(variable)) case maybeCorrespNode => - val identifier = identifierNode(variable, name, name, typeFullName.getOrElse(TypeConstants.Any)) - val identifierAst = Ast(identifier).withRefEdges(identifier, maybeCorrespNode.map(_.node).toList) - typeNode.root match { - case Some(t) => identifierAst.withEvalTypeEdge(identifier, t) - case None => identifierAst + val identifier = identifierNode(variable, name, name, typeFullName.getOrElse(TypeConstants.Any)) + variable.getType match { + case x: ClassOrInterfaceType if x.getTypeArguments.isPresent => + global.nodesWithGenericTypes.put(identifier, astForGenericType(x)) + case _ => } + Ast(identifier).withRefEdges(identifier, maybeCorrespNode.map(_.node).toList) } // Since all partial constructors will be dealt with here, don't pass them up. diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/GenericsTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/GenericsTests.scala index 8e7d24934c0b..a37edfc4dd3c 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/GenericsTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/GenericsTests.scala @@ -73,9 +73,11 @@ class GenericsTests extends JavaSrcCode2CpgFixture { |public class Test extends Box {} |""".stripMargin) - "it should create the correct generic typeDecl name" in { + "it should create the correct generic typeDecls, each with a simple name and one with the arguments" in { cpg.typeDecl.nameExact("Box").l match { - case decl :: Nil => decl.fullName shouldBe "Box" + case decl1 :: decl2 :: Nil => + decl1.fullName shouldBe "Box" + decl2.fullName shouldBe "Box" case res => fail(s"Expected typeDecl Box but got $res") } diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala index a7ed0a3702b0..4a54a7ff364b 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala @@ -2,6 +2,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.joern.x2cpg.Defines +import io.joern.x2cpg.datastructures.TreeNode import io.shiftleft.codepropertygraph.generated.DispatchTypes import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Literal, TypeArgument} import io.shiftleft.semanticcpg.language._ @@ -268,9 +269,7 @@ class JavaTypeRecoveryPassTests extends JavaSrcCode2CpgFixture(enableTypeRecover val Some(totalStudents) = cpg.identifier.nameExact("totalStudents").headOption val List(list) = totalStudents.evalTypeOut.l list.name shouldBe "List" - list.fullName shouldBe "java.util.List" - val List(long) = list.astOut.l - long.code shouldBe "java.lang.Long" + list.fullName shouldBe "java.util.List" } "present (nested) type arguments to method returns" in { @@ -280,12 +279,7 @@ class JavaTypeRecoveryPassTests extends JavaSrcCode2CpgFixture(enableTypeRecover val Some(fooReturn) = cpg.method("foo").methodReturn.headOption val List(list) = fooReturn.evalTypeOut.l list.name shouldBe "List" - list.fullName shouldBe "java.util.List" - val List(map) = list.astOut.collectAll[TypeArgument].l - map.code shouldBe "java.util.Map" - val List(string, integer) = map._astOut.collectAll[TypeArgument].l - string.code shouldBe "java.lang.String" - integer.code shouldBe "java.lang.Integer" + list.fullName shouldBe "java.util.List>" } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala index 549a64e8e576..e189657a7c69 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala @@ -236,7 +236,7 @@ case class Ast( conditionEdges = newConditionEdges, refEdges = newRefEdges, bindsEdges = newBindsEdges, - receiverEdges = newReceiverEdges, + receiverEdges = newReceiverEdges ) .withChildren(newChildren) } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/datastructures/Global.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/datastructures/Global.scala index a526aa91169e..d9a12576e782 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/datastructures/Global.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/datastructures/Global.scala @@ -1,9 +1,49 @@ package io.joern.x2cpg.datastructures +import io.shiftleft.codepropertygraph.generated.nodes.NewNode + import java.util.concurrent.ConcurrentHashMap class Global { val usedTypes: ConcurrentHashMap[String, Boolean] = new ConcurrentHashMap() + val nodesWithGenericTypes: ConcurrentHashMap[NewNode, CodeTree] = new ConcurrentHashMap() + +} + +case class TreeNode(value: String, children: List[TreeNode] = List.empty) { + + def withChildren(children: List[TreeNode]): TreeNode = this.copy(children = this.children ++ children) + + override def toString: String = value +} + +abstract class CodeTree(val root: TreeNode) { + + protected val separator: String + protected val lbracket: String + protected val rbracket: String + + // Lazy load the code tree string + private lazy val treeString = _toString(List(root)) + + override def toString: String = treeString + + private def _toString(xs: List[TreeNode]): String = xs match { + case head :: Nil if head.children.nonEmpty => + head.toString + lbracket + _toString(head.children) + rbracket + case head :: next if head.children.nonEmpty => + head.toString + lbracket + _toString(head.children) + rbracket + separator + _toString(next) + case head :: Nil => head.toString + case head :: next => head.toString + separator + _toString(next) + case Nil => "" + } + +} + +final class JavaTree(root: TreeNode) extends CodeTree(root) { + override protected val separator: String = ", " + override protected val lbracket: String = "<" + override protected val rbracket: String = ">" } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/TypeNodePass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/TypeNodePass.scala index cdd183567afb..39609df17871 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/TypeNodePass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/TypeNodePass.scala @@ -1,14 +1,22 @@ package io.joern.x2cpg.passes.frontend +import io.joern.x2cpg.datastructures.{CodeTree, TreeNode} import io.joern.x2cpg.passes.frontend.TypeNodePass.fullToShortName import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.NewType -import io.shiftleft.passes.{KeyPool, CpgPass} +import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes} +import io.shiftleft.codepropertygraph.generated.nodes.{NewNode, NewType, NewTypeDecl, NewTypeParameter} +import io.shiftleft.passes.{CpgPass, KeyPool} + +import scala.collection.mutable /** Creates a `TYPE` node for each type in `usedTypes` */ -class TypeNodePass(usedTypes: List[String], cpg: Cpg, keyPool: Option[KeyPool] = None) - extends CpgPass(cpg, "types", keyPool) { +class TypeNodePass( + usedTypes: List[String], + cpg: Cpg, + keyPool: Option[KeyPool] = None, + nodesWithGenericTypes: Map[NewNode, CodeTree] = Map.empty +) extends CpgPass(cpg, "types", keyPool) { override def run(diffGraph: DiffGraphBuilder): Unit = { @@ -27,7 +35,54 @@ class TypeNodePass(usedTypes: List[String], cpg: Cpg, keyPool: Option[KeyPool] = .typeDeclFullName(typeName) diffGraph.addNode(node) } + + generateGenericTypes(diffGraph, nodesWithGenericTypes) + } + + private def generateGenericTypes(diffGraph: DiffGraphBuilder, nodesWithGenericTypes: Map[NewNode, CodeTree]): Unit = { + + def treeNodeToTypeParameter(x: TreeNode): NewTypeParameter = { + val typeParameter = NewTypeParameter().name(x.value).code(x.value) + diffGraph.addNode(typeParameter) + generateTypeParametersFromChildren(x.children).foreach(tp => diffGraph.addEdge(typeParameter, tp, EdgeTypes.AST)) + typeParameter + } + + def generateTypeParametersFromChildren(xs: List[TreeNode]): List[NewTypeParameter] = xs match { + case head :: Nil => List(treeNodeToTypeParameter(head)) + case head :: next => treeNodeToTypeParameter(head) +: generateTypeParametersFromChildren(next) + case Nil => List.empty + } + + def generateTypeNodeFromTree(tree: CodeTree): NewType = { + val fullName = tree.toString + val shortType = tree.root.value match { + case t if t.contains('.') && !t.endsWith(".") => t.substring(t.lastIndexOf('.') + 1) + case t => t + } + val typeNode = NewType().name(shortType).fullName(fullName).typeDeclFullName(fullName) + val typeDecl = NewTypeDecl() + .name(shortType) + .fullName(fullName) + .astParentType(NodeTypes.NAMESPACE_BLOCK) + .astParentFullName("ANY") + diffGraph.addNode(typeNode).addNode(typeDecl).addEdge(typeNode, typeDecl, EdgeTypes.REF) + // TODO: How to do TYPE->TYPE_ARGUMENT or TYPE_DECL->TYPE_PARAMETER? + // + // generateTypeParametersFromChildren(tree.root.children).foreach(ta => + // diffGraph.addEdge(typeDecl, ta, EdgeTypes.AST) + // ) + typeNode + } + + val typeToNode = mutable.HashMap.empty[String, NewType] + + nodesWithGenericTypes.foreach { case (node, tree) => + val associatedTypeNode = typeToNode.getOrElseUpdate(tree.toString, generateTypeNodeFromTree(tree)) + diffGraph.addEdge(node, associatedTypeNode, EdgeTypes.EVAL_TYPE) + } } + } object TypeNodePass {