diff --git a/slither/core/solidity_types/type_alias.py b/slither/core/solidity_types/type_alias.py index 9387f511a..ead9b5394 100644 --- a/slither/core/solidity_types/type_alias.py +++ b/slither/core/solidity_types/type_alias.py @@ -1,10 +1,11 @@ -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Tuple, Dict from slither.core.declarations.top_level import TopLevel from slither.core.declarations.contract_level import ContractLevel from slither.core.solidity_types import Type, ElementaryType if TYPE_CHECKING: + from slither.core.declarations.function_top_level import FunctionTopLevel from slither.core.declarations import Contract from slither.core.scope.scope import FileScope @@ -43,6 +44,8 @@ class TypeAliasTopLevel(TypeAlias, TopLevel): def __init__(self, underlying_type: ElementaryType, name: str, scope: "FileScope") -> None: super().__init__(underlying_type, name) self.file_scope: "FileScope" = scope + # operators redefined + self.operators: Dict[str, "FunctionTopLevel"] = {} def __str__(self) -> str: return self.name diff --git a/slither/solc_parsing/declarations/using_for_top_level.py b/slither/solc_parsing/declarations/using_for_top_level.py index 707ad83ac..fe72e5780 100644 --- a/slither/solc_parsing/declarations/using_for_top_level.py +++ b/slither/solc_parsing/declarations/using_for_top_level.py @@ -55,22 +55,29 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few- self._propagate_global(type_name) else: for f in self._functions: - full_name_split = f["function"]["name"].split(".") - if len(full_name_split) == 1: + # User defined operator + if "operator" in f: # Top level function - function_name: str = full_name_split[0] - self._analyze_top_level_function(function_name, type_name) - elif len(full_name_split) == 2: - # It can be a top level function behind an aliased import - # or a library function - first_part = full_name_split[0] - function_name = full_name_split[1] - self._check_aliased_import(first_part, function_name, type_name) + function_name: str = f["definition"]["name"] + operator: str = f["operator"] + self._analyze_operator(operator, function_name, type_name) else: - # MyImport.MyLib.a we don't care of the alias - library_name_str = full_name_split[1] - function_name = full_name_split[2] - self._analyze_library_function(library_name_str, function_name, type_name) + full_name_split = f["function"]["name"].split(".") + if len(full_name_split) == 1: + # Top level function + function_name: str = full_name_split[0] + self._analyze_top_level_function(function_name, type_name) + elif len(full_name_split) == 2: + # It can be a top level function behind an aliased import + # or a library function + first_part = full_name_split[0] + function_name = full_name_split[1] + self._check_aliased_import(first_part, function_name, type_name) + else: + # MyImport.MyLib.a we don't care of the alias + library_name_str = full_name_split[1] + function_name = full_name_split[2] + self._analyze_library_function(library_name_str, function_name, type_name) def _check_aliased_import( self, @@ -101,6 +108,19 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few- self._propagate_global(type_name) break + def _analyze_operator( + self, operator: str, function_name: str, type_name: TypeAliasTopLevel + ) -> None: + for tl_function in self._using_for.file_scope.functions: + # The library function is bound to the first parameter's type + if ( + tl_function.name == function_name + and tl_function.parameters + and type_name == tl_function.parameters[0].type + ): + type_name.operators[operator] = tl_function + break + def _analyze_library_function( self, library_name: str, diff --git a/slither/solc_parsing/expressions/expression_parsing.py b/slither/solc_parsing/expressions/expression_parsing.py index 4d2cfc00f..a0bce044c 100644 --- a/slither/solc_parsing/expressions/expression_parsing.py +++ b/slither/solc_parsing/expressions/expression_parsing.py @@ -1,6 +1,6 @@ import logging import re -from typing import Union, Dict, TYPE_CHECKING +from typing import Union, Dict, TYPE_CHECKING, List, Any import slither.core.expressions.type_conversion from slither.core.declarations.solidity_variables import ( @@ -236,6 +236,24 @@ if TYPE_CHECKING: pass +def _user_defined_op_call( + caller_context: CallerContextExpression, src, function_id: int, args: List[Any], type_call: str +) -> CallExpression: + var, was_created = find_variable(None, caller_context, function_id) + + if was_created: + var.set_offset(src, caller_context.compilation_unit) + + identifier = Identifier(var) + identifier.set_offset(src, caller_context.compilation_unit) + + var.references.append(identifier.source_mapping) + + call = CallExpression(identifier, args, type_call) + call.set_offset(src, caller_context.compilation_unit) + return call + + def parse_expression(expression: Dict, caller_context: CallerContextExpression) -> "Expression": # pylint: disable=too-many-nested-blocks,too-many-statements """ @@ -274,16 +292,24 @@ def parse_expression(expression: Dict, caller_context: CallerContextExpression) if name == "UnaryOperation": if is_compact_ast: attributes = expression - else: - attributes = expression["attributes"] - assert "prefix" in attributes - operation_type = UnaryOperationType.get_type(attributes["operator"], attributes["prefix"]) - - if is_compact_ast: expression = parse_expression(expression["subExpression"], caller_context) else: + attributes = expression["attributes"] assert len(expression["children"]) == 1 expression = parse_expression(expression["children"][0], caller_context) + assert "prefix" in attributes + + # Use of user defined operation + if "function" in attributes: + return _user_defined_op_call( + caller_context, + src, + attributes["function"], + [expression], + attributes["typeDescriptions"]["typeString"], + ) + + operation_type = UnaryOperationType.get_type(attributes["operator"], attributes["prefix"]) unary_op = UnaryOperation(expression, operation_type) unary_op.set_offset(src, caller_context.compilation_unit) return unary_op @@ -291,17 +317,25 @@ def parse_expression(expression: Dict, caller_context: CallerContextExpression) if name == "BinaryOperation": if is_compact_ast: attributes = expression - else: - attributes = expression["attributes"] - operation_type = BinaryOperationType.get_type(attributes["operator"]) - - if is_compact_ast: left_expression = parse_expression(expression["leftExpression"], caller_context) right_expression = parse_expression(expression["rightExpression"], caller_context) else: assert len(expression["children"]) == 2 + attributes = expression["attributes"] left_expression = parse_expression(expression["children"][0], caller_context) right_expression = parse_expression(expression["children"][1], caller_context) + + # Use of user defined operation + if "function" in attributes: + return _user_defined_op_call( + caller_context, + src, + attributes["function"], + [left_expression, right_expression], + attributes["typeDescriptions"]["typeString"], + ) + + operation_type = BinaryOperationType.get_type(attributes["operator"]) binary_op = BinaryOperation(left_expression, right_expression, operation_type) binary_op.set_offset(src, caller_context.compilation_unit) return binary_op diff --git a/tests/e2e/solc_parsing/test_ast_parsing.py b/tests/e2e/solc_parsing/test_ast_parsing.py index b694d1044..307e6736f 100644 --- a/tests/e2e/solc_parsing/test_ast_parsing.py +++ b/tests/e2e/solc_parsing/test_ast_parsing.py @@ -458,6 +458,7 @@ ALL_TESTS = [ "assembly-functions.sol", ["0.6.9", "0.7.6", "0.8.16"], ), + Test("user_defined_operators-0.8.19.sol", ["0.8.19"]), ] # create the output folder if needed try: diff --git a/tests/e2e/solc_parsing/test_data/compile/user_defined_operators-0.8.19.sol-0.8.19-compact.zip b/tests/e2e/solc_parsing/test_data/compile/user_defined_operators-0.8.19.sol-0.8.19-compact.zip new file mode 100644 index 000000000..7159a1486 Binary files /dev/null and b/tests/e2e/solc_parsing/test_data/compile/user_defined_operators-0.8.19.sol-0.8.19-compact.zip differ diff --git a/tests/e2e/solc_parsing/test_data/expected/user_defined_operators-0.8.19.sol-0.8.19-compact.json b/tests/e2e/solc_parsing/test_data/expected/user_defined_operators-0.8.19.sol-0.8.19-compact.json new file mode 100644 index 000000000..bee7819a6 --- /dev/null +++ b/tests/e2e/solc_parsing/test_data/expected/user_defined_operators-0.8.19.sol-0.8.19-compact.json @@ -0,0 +1,13 @@ +{ + "Lib": { + "f(Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n}\n" + }, + "T": { + "add_function_call(Int,Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n1->2;\n2[label=\"Node Type: RETURN 2\n\"];\n}\n", + "add_op(Int,Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n", + "lib_call(Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n", + "neg_usertype(Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n1->2;\n2[label=\"Node Type: RETURN 2\n\"];\n}\n", + "neg_int(int256)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n", + "eq_op(Int,Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n" + } +} \ No newline at end of file diff --git a/tests/e2e/solc_parsing/test_data/user_defined_operators-0.8.19.sol b/tests/e2e/solc_parsing/test_data/user_defined_operators-0.8.19.sol new file mode 100644 index 000000000..e4df845fb --- /dev/null +++ b/tests/e2e/solc_parsing/test_data/user_defined_operators-0.8.19.sol @@ -0,0 +1,48 @@ +pragma solidity ^0.8.19; + +type Int is int; +using {add as +, eq as ==, add, neg as -, Lib.f} for Int global; + +function add(Int a, Int b) pure returns (Int) { + return Int.wrap(Int.unwrap(a) + Int.unwrap(b)); +} + +function eq(Int a, Int b) pure returns (bool) { + return true; +} + +function neg(Int a) pure returns (Int) { + return a; +} + +library Lib { + function f(Int r) internal {} +} + +contract T { + function add_function_call(Int b, Int c) public returns(Int) { + Int res = add(b,c); + return res; + } + + function add_op(Int b, Int c) public returns(Int) { + return b + c; + } + + function lib_call(Int b) public { + return b.f(); + } + + function neg_usertype(Int b) public returns(Int) { + Int res = -b; + return res; + } + + function neg_int(int b) public returns(int) { + return -b; + } + + function eq_op(Int b, Int c) public returns(bool) { + return b == c; + } +}