diff --git a/slither/core/declarations/custom_error.py b/slither/core/declarations/custom_error.py index 234873eaca..6e2cf142ff 100644 --- a/slither/core/declarations/custom_error.py +++ b/slither/core/declarations/custom_error.py @@ -17,6 +17,7 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None: self._solidity_signature: Optional[str] = None self._full_name: Optional[str] = None + self._pattern = "error" @property def name(self) -> str: diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index be215bd1f2..6e8968dfb2 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -94,14 +94,11 @@ class FunctionType(Enum): def _filter_state_variables_written(expressions: List["Expression"]): ret = [] + for expression in expressions: - if isinstance(expression, Identifier): - ret.append(expression) - if isinstance(expression, UnaryOperation): - ret.append(expression.expression) - if isinstance(expression, MemberAccess): + if isinstance(expression, (Identifier, UnaryOperation, MemberAccess)): ret.append(expression.expression) - if isinstance(expression, IndexAccess): + elif isinstance(expression, IndexAccess): ret.append(expression.expression_left) return ret diff --git a/slither/core/declarations/import_directive.py b/slither/core/declarations/import_directive.py index 19ea2cff96..440b09e9cd 100644 --- a/slither/core/declarations/import_directive.py +++ b/slither/core/declarations/import_directive.py @@ -16,6 +16,8 @@ def __init__(self, filename: Path, scope: "FileScope") -> None: # Map local name -> original name self.renaming: Dict[str, str] = {} + self._pattern = "import" + @property def filename(self) -> str: """ diff --git a/slither/core/declarations/pragma_directive.py b/slither/core/declarations/pragma_directive.py index cd790d5a47..90c1da2ddf 100644 --- a/slither/core/declarations/pragma_directive.py +++ b/slither/core/declarations/pragma_directive.py @@ -11,6 +11,7 @@ def __init__(self, directive: List[str], scope: "FileScope") -> None: super().__init__() self._directive = directive self.scope: "FileScope" = scope + self._pattern = "pragma" @property def directive(self) -> List[str]: diff --git a/slither/core/expressions/identifier.py b/slither/core/expressions/identifier.py index 5cd29a9f5d..493620ab18 100644 --- a/slither/core/expressions/identifier.py +++ b/slither/core/expressions/identifier.py @@ -78,3 +78,6 @@ def value( def __str__(self) -> str: return str(self._value) + + def expression(self): + return self diff --git a/slither/core/solidity_types/elementary_type.py b/slither/core/solidity_types/elementary_type.py index a9f45c8d81..61729b06a2 100644 --- a/slither/core/solidity_types/elementary_type.py +++ b/slither/core/solidity_types/elementary_type.py @@ -225,4 +225,4 @@ def __eq__(self, other: Any) -> bool: return self.type == other.type def __hash__(self) -> int: - return hash(str(self)) + return hash(self._type) diff --git a/slither/core/solidity_types/type_alias.py b/slither/core/solidity_types/type_alias.py index ead9b5394f..c22cd257ef 100644 --- a/slither/core/solidity_types/type_alias.py +++ b/slither/core/solidity_types/type_alias.py @@ -15,6 +15,7 @@ def __init__(self, underlying_type: ElementaryType, name: str) -> None: super().__init__() self.name = name self.underlying_type = underlying_type + self._pattern = "type" @property def type(self) -> ElementaryType: diff --git a/slither/core/source_mapping/source_mapping.py b/slither/core/source_mapping/source_mapping.py index fceab78559..8dda25a241 100644 --- a/slither/core/source_mapping/source_mapping.py +++ b/slither/core/source_mapping/source_mapping.py @@ -1,5 +1,4 @@ import re -from abc import ABCMeta from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional, Any from Crypto.Hash import SHA1 @@ -99,21 +98,29 @@ def __str__(self) -> str: return f"{filename_short}{lines}" def __hash__(self) -> int: - return hash(str(self)) + return hash( + ( + self.start, + self.length, + self.filename.relative, + self.end, + ) + ) def __eq__(self, other: Any) -> bool: - if not isinstance(other, type(self)): + try: + return ( + self.start == other.start + and self.length == other.length + and self.filename == other.filename + and self.is_dependency == other.is_dependency + and self.lines == other.lines + and self.starting_column == other.starting_column + and self.ending_column == other.ending_column + and self.end == other.end + ) + except AttributeError: return NotImplemented - return ( - self.start == other.start - and self.length == other.length - and self.filename == other.filename - and self.is_dependency == other.is_dependency - and self.lines == other.lines - and self.starting_column == other.starting_column - and self.ending_column == other.ending_column - and self.end == other.end - ) def _compute_line( @@ -183,12 +190,14 @@ def _convert_source_mapping( return new_source -class SourceMapping(Context, metaclass=ABCMeta): +class SourceMapping(Context): def __init__(self) -> None: super().__init__() self.source_mapping: Optional[Source] = None self.references: List[Source] = [] + self._pattern: Union[str, None] = None + def set_offset( self, offset: Union["Source", str], compilation_unit: "SlitherCompilationUnit" ) -> None: @@ -204,3 +213,11 @@ def add_reference_from_raw_source( ) -> None: s = _convert_source_mapping(offset, compilation_unit) self.references.append(s) + + @property + def pattern(self) -> str: + if self._pattern is None: + # Add " " to look after the first solidity keyword + return f" {self.name}" # pylint: disable=no-member + + return self._pattern diff --git a/slither/utils/source_mapping.py b/slither/utils/source_mapping.py index 9bf772894e..180c842f72 100644 --- a/slither/utils/source_mapping.py +++ b/slither/utils/source_mapping.py @@ -2,37 +2,17 @@ from crytic_compile import CryticCompile from slither.core.declarations import ( Contract, - Function, - Enum, - Event, - Import, - Pragma, - Structure, - CustomError, FunctionContract, ) -from slither.core.solidity_types import Type, TypeAlias from slither.core.source_mapping.source_mapping import Source, SourceMapping -from slither.core.variables.variable import Variable from slither.exceptions import SlitherError def get_definition(target: SourceMapping, crytic_compile: CryticCompile) -> Source: - if isinstance(target, (Contract, Function, Enum, Event, Structure, Variable)): - # Add " " to look after the first solidity keyword - pattern = " " + target.name - elif isinstance(target, Import): - pattern = "import" - elif isinstance(target, Pragma): - pattern = "pragma" # todo maybe return with the while pragma statement - elif isinstance(target, CustomError): - pattern = "error" - elif isinstance(target, TypeAlias): - pattern = "type" - elif isinstance(target, Type): - raise SlitherError("get_definition_generic not implemented for types") - else: - raise SlitherError(f"get_definition_generic not implemented for {type(target)}") + try: + pattern = target.pattern + except AttributeError as exc: + raise SlitherError(f"get_definition_generic not implemented for {type(target)}") from exc file_content = crytic_compile.src_content_for_file(target.source_mapping.filename.absolute) txt = file_content[ diff --git a/slither/visitors/expression/expression.py b/slither/visitors/expression/expression.py index 41886a1023..83dd1be51a 100644 --- a/slither/visitors/expression/expression.py +++ b/slither/visitors/expression/expression.py @@ -1,4 +1,5 @@ import logging +from functools import lru_cache from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.binary_operation import BinaryOperation @@ -16,11 +17,39 @@ from slither.core.expressions.tuple_expression import TupleExpression from slither.core.expressions.type_conversion import TypeConversion from slither.core.expressions.unary_operation import UnaryOperation +from slither.core.expressions.super_call_expression import SuperCallExpression +from slither.core.expressions.super_identifier import SuperIdentifier +from slither.core.expressions.self_identifier import SelfIdentifier from slither.exceptions import SlitherError logger = logging.getLogger("ExpressionVisitor") +@lru_cache() +def get_visitor_mapping(): + """Returns a visitor mapping from expression type to visiting functions.""" + return { + AssignmentOperation: "_visit_assignement_operation", + BinaryOperation: "_visit_binary_operation", + CallExpression: "_visit_call_expression", + ConditionalExpression: "_visit_conditional_expression", + ElementaryTypeNameExpression: "_visit_elementary_type_name_expression", + Identifier: "_visit_identifier", + IndexAccess: "_visit_index_access", + Literal: "_visit_literal", + MemberAccess: "_visit_member_access", + NewArray: "_visit_new_array", + NewContract: "_visit_new_contract", + NewElementaryType: "_visit_new_elementary_type", + TupleExpression: "_visit_tuple_expression", + TypeConversion: "_visit_type_conversion", + UnaryOperation: "_visit_unary_operation", + SelfIdentifier: "_visit_identifier", + SuperIdentifier: "_visit_identifier", + SuperCallExpression: "_visit_call_expression", + } + + # pylint: disable=too-few-public-methods class ExpressionVisitor: def __init__(self, expression: Expression) -> None: @@ -35,60 +64,16 @@ def expression(self) -> Expression: # visit an expression # call pre_visit, visit_expression_name, post_visit - # pylint: disable=too-many-branches def _visit_expression(self, expression: Expression) -> None: self._pre_visit(expression) - if isinstance(expression, AssignmentOperation): - self._visit_assignement_operation(expression) - - elif isinstance(expression, BinaryOperation): - self._visit_binary_operation(expression) - - elif isinstance(expression, CallExpression): - self._visit_call_expression(expression) - - elif isinstance(expression, ConditionalExpression): - self._visit_conditional_expression(expression) - - elif isinstance(expression, ElementaryTypeNameExpression): - self._visit_elementary_type_name_expression(expression) - - elif isinstance(expression, Identifier): - self._visit_identifier(expression) - - elif isinstance(expression, IndexAccess): - self._visit_index_access(expression) - - elif isinstance(expression, Literal): - self._visit_literal(expression) - - elif isinstance(expression, MemberAccess): - self._visit_member_access(expression) - - elif isinstance(expression, NewArray): - self._visit_new_array(expression) - - elif isinstance(expression, NewContract): - self._visit_new_contract(expression) - - elif isinstance(expression, NewElementaryType): - self._visit_new_elementary_type(expression) - - elif isinstance(expression, TupleExpression): - self._visit_tuple_expression(expression) + if expression is not None: + visitor_method = get_visitor_mapping().get(expression.__class__) + if not visitor_method: + raise SlitherError(f"Expression not handled: {expression}") - elif isinstance(expression, TypeConversion): - self._visit_type_conversion(expression) - - elif isinstance(expression, UnaryOperation): - self._visit_unary_operation(expression) - - elif expression is None: - pass - - else: - raise SlitherError(f"Expression not handled: {expression}") + visitor = getattr(self, visitor_method) + visitor(expression) self._post_visit(expression)