diff --git a/py_cppmodel.py b/py_cppmodel.py index a4e60b4..d57746b 100644 --- a/py_cppmodel.py +++ b/py_cppmodel.py @@ -178,15 +178,22 @@ def __repr__(self) -> str: return "".format(self.name) +class ClassTemplate(Class): + def __init__(self, cursor: Cursor, namespaces: List[str]): + Class.__init__(self, cursor, namespaces) + # TODO: replace the count with the actual template args once count works. + self.template_parameter_count = cursor.get_num_template_arguments() + + class Model(object): def __init__(self, translation_unit: TranslationUnit): """Create a model from a translation unit.""" self.filename: str = translation_unit.spelling self.functions: List[Function] = [] - self.classes: List[Class] = [] + self.classes: List[Class|ClassTemplate] = [] self.unmodelled_nodes: List[Unmodelled] = [] # Keep a reference to the translation unit to prevent it from being garbage collected. - self.translation_unit: TranslationUnit = translation_unit + self.translation_unit: TranslationUnit = translation_unit def is_error_in_current_file(diagnostic: Diagnostic) -> bool: if str(diagnostic.location.file) != str(translation_unit.spelling): @@ -251,6 +258,8 @@ def _add_child_nodes(self, cursor: Any, namespaces: List[str] = []): for c in cursor.get_children(): if c.kind == CursorKind.CLASS_DECL or c.kind == CursorKind.STRUCT_DECL: self.classes.append(Class(c, namespaces)) + elif c.kind == CursorKind.CLASS_TEMPLATE: + self.classes.append(ClassTemplate(c, namespaces)) elif ( c.kind == CursorKind.FUNCTION_DECL and c.type.kind == TypeKind.FUNCTIONPROTO diff --git a/test.sh b/test.sh index 81ca10c..86c063d 100755 --- a/test.sh +++ b/test.sh @@ -4,4 +4,3 @@ set -x python -m pytype . python -m unittest discover . - diff --git a/test_py_cppmodel.py b/test_py_cppmodel.py index 715ee3d..7ac8963 100644 --- a/test_py_cppmodel.py +++ b/test_py_cppmodel.py @@ -2,6 +2,7 @@ import py_cppmodel import unittest +# This is a workaround for the current inability to portably find libclang. clang.cindex.Config.set_library_path("/Library/Developer/CommandLineTools/usr/lib/") @@ -25,7 +26,7 @@ def test_functions(self): ) def test_classes(self): - self.assertEqual(len(self.model.classes), 1) + self.assertEqual(len(self.model.classes), 2) self.assertEqual(str(self.model.classes[0]), "") self.assertEqual(len(self.model.classes[0].members), 3) @@ -47,15 +48,18 @@ def test_classes(self): str(self.model.classes[0].methods[0]), "" ) - self.assertEqual(len(self.model.unmodelled_nodes), 2) + self.assertEqual(str(self.model.classes[1]), "") + self.assertIsInstance(self.model.classes[1], py_cppmodel.ClassTemplate) + if isinstance(self.model.classes[1], py_cppmodel.ClassTemplate): + self.assertEqual(self.model.classes[1].template_parameter_count, 1) + + + def test_unmodelled_nodes(self): + self.assertEqual(len(self.model.unmodelled_nodes), 1) self.assertEqual( str(self.model.unmodelled_nodes[0]), ">", ) - self.assertEqual( - str(self.model.unmodelled_nodes[1]), - " >", - ) if __name__ == "__main__":