Merge pull request #1684 from crytic/dev-ud-operators

Support user defined operators
Feist Josselin, committed via GitHub, commit 79579f9285
Changed files (lines changed in parentheses):

1. slither/core/solidity_types/type_alias.py (5)
2. slither/solc_parsing/declarations/using_for_top_level.py (48)
3. slither/solc_parsing/expressions/expression_parsing.py (58)
4. tests/e2e/solc_parsing/test_ast_parsing.py (1)
5. tests/e2e/solc_parsing/test_data/compile/user_defined_operators-0.8.19.sol-0.8.19-compact.zip (BIN)
6. tests/e2e/solc_parsing/test_data/expected/user_defined_operators-0.8.19.sol-0.8.19-compact.json (13)
7. tests/e2e/solc_parsing/test_data/user_defined_operators-0.8.19.sol (48)

slither/core/solidity_types/type_alias.py:

@@ -1,10 +1,11 @@
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Tuple, Dict

 from slither.core.declarations.top_level import TopLevel
 from slither.core.declarations.contract_level import ContractLevel
 from slither.core.solidity_types import Type, ElementaryType

 if TYPE_CHECKING:
+    from slither.core.declarations.function_top_level import FunctionTopLevel
     from slither.core.declarations import Contract
     from slither.core.scope.scope import FileScope
@@ -43,6 +44,8 @@ class TypeAliasTopLevel(TypeAlias, TopLevel):
     def __init__(self, underlying_type: ElementaryType, name: str, scope: "FileScope") -> None:
         super().__init__(underlying_type, name)
         self.file_scope: "FileScope" = scope
+        # operators redefined
+        self.operators: Dict[str, "FunctionTopLevel"] = {}

     def __str__(self) -> str:
         return self.name
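With this change, every top level user defined value type carries a mapping from each bound operator token to the top level function that implements it. Below is a minimal sketch of reading that mapping from a Slither script; it assumes the test file added later in this commit, a solc 0.8.19 toolchain, and the `scopes` / `user_defined_types` attributes of SlitherCompilationUnit and FileScope (assumptions, not part of this diff):

# Sketch only: walk the user defined types of each file scope and print the
# operators bound to them by `using ... for ... global` directives.
from slither import Slither

sl = Slither("tests/e2e/solc_parsing/test_data/user_defined_operators-0.8.19.sol")
for unit in sl.compilation_units:
    for scope in unit.scopes.values():                 # FileScope objects (assumed API)
        for alias in scope.user_defined_types.values():
            for op, fn in alias.operators.items():     # the dict added by this commit
                print(f"{alias.name}: {op} -> {fn.canonical_name}")

For the test contract added below, this should report the +, == and - bindings on Int.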

slither/solc_parsing/declarations/using_for_top_level.py:

@@ -55,22 +55,29 @@ class UsingForTopLevelSolc(CallerContextExpression):  # pylint: disable=too-few-
             self._propagate_global(type_name)
         else:
             for f in self._functions:
-                full_name_split = f["function"]["name"].split(".")
-                if len(full_name_split) == 1:
-                    # Top level function
-                    function_name: str = full_name_split[0]
-                    self._analyze_top_level_function(function_name, type_name)
-                elif len(full_name_split) == 2:
-                    # It can be a top level function behind an aliased import
-                    # or a library function
-                    first_part = full_name_split[0]
-                    function_name = full_name_split[1]
-                    self._check_aliased_import(first_part, function_name, type_name)
-                else:
-                    # MyImport.MyLib.a we don't care of the alias
-                    library_name_str = full_name_split[1]
-                    function_name = full_name_split[2]
-                    self._analyze_library_function(library_name_str, function_name, type_name)
+                # User defined operator
+                if "operator" in f:
+                    # Top level function
+                    function_name: str = f["definition"]["name"]
+                    operator: str = f["operator"]
+                    self._analyze_operator(operator, function_name, type_name)
+                else:
+                    full_name_split = f["function"]["name"].split(".")
+                    if len(full_name_split) == 1:
+                        # Top level function
+                        function_name: str = full_name_split[0]
+                        self._analyze_top_level_function(function_name, type_name)
+                    elif len(full_name_split) == 2:
+                        # It can be a top level function behind an aliased import
+                        # or a library function
+                        first_part = full_name_split[0]
+                        function_name = full_name_split[1]
+                        self._check_aliased_import(first_part, function_name, type_name)
+                    else:
+                        # MyImport.MyLib.a we don't care of the alias
+                        library_name_str = full_name_split[1]
+                        function_name = full_name_split[2]
+                        self._analyze_library_function(library_name_str, function_name, type_name)

     def _check_aliased_import(
         self,
@@ -101,6 +108,19 @@ class UsingForTopLevelSolc(CallerContextExpression):  # pylint: disable=too-few-
                 self._propagate_global(type_name)
                 break

+    def _analyze_operator(
+        self, operator: str, function_name: str, type_name: TypeAliasTopLevel
+    ) -> None:
+        for tl_function in self._using_for.file_scope.functions:
+            # The library function is bound to the first parameter's type
+            if (
+                tl_function.name == function_name
+                and tl_function.parameters
+                and type_name == tl_function.parameters[0].type
+            ):
+                type_name.operators[operator] = tl_function
+                break
+
     def _analyze_library_function(
         self,
         library_name: str,
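The new branch keys off the shape of the compact AST: for solc >= 0.8.19, each entry of a UsingForDirective's functionList either wraps a plain function reference or, for an operator binding, a "definition" plus an "operator" string. A hand-trimmed sketch of the two shapes the loop distinguishes (field subset only; the id values are illustrative):

# Operator binding: `add as +` in `using {add as +, ...} for Int global;`
operator_entry = {
    "definition": {"name": "add", "referencedDeclaration": 7},   # id is illustrative
    "operator": "+",
}

# Plain binding: `Lib.f` in the same directive
function_entry = {
    "function": {"name": "Lib.f", "referencedDeclaration": 12},  # id is illustrative
}

for f in (operator_entry, function_entry):
    if "operator" in f:        # new path added by this commit
        print(f'{f["operator"]} is implemented by {f["definition"]["name"]}')
    else:                      # pre-existing path, split on "." to find the library
        print(f'plain using-for binding to {f["function"]["name"]}')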

slither/solc_parsing/expressions/expression_parsing.py:

@@ -1,6 +1,6 @@
 import logging
 import re
-from typing import Union, Dict, TYPE_CHECKING
+from typing import Union, Dict, TYPE_CHECKING, List, Any

 import slither.core.expressions.type_conversion
 from slither.core.declarations.solidity_variables import (
@@ -236,6 +236,24 @@ if TYPE_CHECKING:
     pass


+def _user_defined_op_call(
+    caller_context: CallerContextExpression, src, function_id: int, args: List[Any], type_call: str
+) -> CallExpression:
+    var, was_created = find_variable(None, caller_context, function_id)
+
+    if was_created:
+        var.set_offset(src, caller_context.compilation_unit)
+
+    identifier = Identifier(var)
+    identifier.set_offset(src, caller_context.compilation_unit)
+
+    var.references.append(identifier.source_mapping)
+
+    call = CallExpression(identifier, args, type_call)
+    call.set_offset(src, caller_context.compilation_unit)
+    return call
+
+
 def parse_expression(expression: Dict, caller_context: CallerContextExpression) -> "Expression":
     # pylint: disable=too-many-nested-blocks,too-many-statements
     """
@@ -274,16 +292,24 @@ def parse_expression(expression: Dict, caller_context: CallerContextExpression)
     if name == "UnaryOperation":
         if is_compact_ast:
             attributes = expression
-        else:
-            attributes = expression["attributes"]
-        assert "prefix" in attributes
-        operation_type = UnaryOperationType.get_type(attributes["operator"], attributes["prefix"])
-
-        if is_compact_ast:
             expression = parse_expression(expression["subExpression"], caller_context)
         else:
+            attributes = expression["attributes"]
             assert len(expression["children"]) == 1
             expression = parse_expression(expression["children"][0], caller_context)
+        assert "prefix" in attributes
+
+        # Use of user defined operation
+        if "function" in attributes:
+            return _user_defined_op_call(
+                caller_context,
+                src,
+                attributes["function"],
+                [expression],
+                attributes["typeDescriptions"]["typeString"],
+            )
+
+        operation_type = UnaryOperationType.get_type(attributes["operator"], attributes["prefix"])
         unary_op = UnaryOperation(expression, operation_type)
         unary_op.set_offset(src, caller_context.compilation_unit)
         return unary_op
@@ -291,17 +317,25 @@ def parse_expression(expression: Dict, caller_context: CallerContextExpression)
     if name == "BinaryOperation":
         if is_compact_ast:
             attributes = expression
-        else:
-            attributes = expression["attributes"]
-        operation_type = BinaryOperationType.get_type(attributes["operator"])
-
-        if is_compact_ast:
             left_expression = parse_expression(expression["leftExpression"], caller_context)
             right_expression = parse_expression(expression["rightExpression"], caller_context)
         else:
             assert len(expression["children"]) == 2
+            attributes = expression["attributes"]
             left_expression = parse_expression(expression["children"][0], caller_context)
             right_expression = parse_expression(expression["children"][1], caller_context)
+
+        # Use of user defined operation
+        if "function" in attributes:
+            return _user_defined_op_call(
+                caller_context,
+                src,
+                attributes["function"],
+                [left_expression, right_expression],
+                attributes["typeDescriptions"]["typeString"],
+            )
+
+        operation_type = BinaryOperationType.get_type(attributes["operator"])
         binary_op = BinaryOperation(left_expression, right_expression, operation_type)
         binary_op.set_offset(src, caller_context.compilation_unit)
         return binary_op
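When a user defined operator is used, solc annotates the UnaryOperation or BinaryOperation node with the id of the bound function; the two branches above detect that field and emit a CallExpression instead of an operator expression. A trimmed, hand-written sketch of the node shape the parser now recognizes (values illustrative, subexpressions elided):

# `b + c` where `+` is bound to `add(Int, Int)`; only the fields the parser reads.
binary_node = {
    "nodeType": "BinaryOperation",
    "operator": "+",
    "function": 7,                                   # illustrative AST id of `add`
    "typeDescriptions": {"typeString": "Int"},
    "leftExpression": {},                            # subexpression elided
    "rightExpression": {},                           # subexpression elided
}

# Routing logic mirroring the new branches: a node carrying "function" becomes a call.
if "function" in binary_node:
    target_id = binary_node["function"]              # resolved via find_variable() in the real code
    type_call = binary_node["typeDescriptions"]["typeString"]
    print(f"user defined operator call, declaration id {target_id}, returns {type_call}")
else:
    print("regular BinaryOperation")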

tests/e2e/solc_parsing/test_ast_parsing.py:

@@ -458,6 +458,7 @@ ALL_TESTS = [
         "assembly-functions.sol",
         ["0.6.9", "0.7.6", "0.8.16"],
     ),
+    Test("user_defined_operators-0.8.19.sol", ["0.8.19"]),
 ]
 # create the output folder if needed
 try:

tests/e2e/solc_parsing/test_data/expected/user_defined_operators-0.8.19.sol-0.8.19-compact.json:

@@ -0,0 +1,13 @@
+{
+    "Lib": {
+        "f(Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n}\n"
+    },
+    "T": {
+        "add_function_call(Int,Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n1->2;\n2[label=\"Node Type: RETURN 2\n\"];\n}\n",
+        "add_op(Int,Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n",
+        "lib_call(Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n",
+        "neg_usertype(Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n1->2;\n2[label=\"Node Type: RETURN 2\n\"];\n}\n",
+        "neg_int(int256)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n",
+        "eq_op(Int,Int)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n"
+    }
+}
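Each value in this expected file is a dot digraph whose node labels encode the CFG node type and index. A small helper, using only the standard library, that pulls those labels out of one entry (the file path is the one added by this commit, relative to the repository root):

import json
import re

with open(
    "tests/e2e/solc_parsing/test_data/expected/"
    "user_defined_operators-0.8.19.sol-0.8.19-compact.json"
) as fp:
    expected = json.load(fp)

# e.g. ['ENTRY_POINT 0', 'RETURN 1'] for T.add_op(Int,Int)
labels = re.findall(r"Node Type: ([A-Z_ ]+ \d+)", expected["T"]["add_op(Int,Int)"])
print(labels)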

tests/e2e/solc_parsing/test_data/user_defined_operators-0.8.19.sol:

@@ -0,0 +1,48 @@
+pragma solidity ^0.8.19;
+type Int is int;
+
+using {add as +, eq as ==, add, neg as -, Lib.f} for Int global;
+
+function add(Int a, Int b) pure returns (Int) {
+    return Int.wrap(Int.unwrap(a) + Int.unwrap(b));
+}
+
+function eq(Int a, Int b) pure returns (bool) {
+    return true;
+}
+
+function neg(Int a) pure returns (Int) {
+    return a;
+}
+
+library Lib {
+    function f(Int r) internal {}
+}
+
+contract T {
+    function add_function_call(Int b, Int c) public returns(Int) {
+        Int res = add(b,c);
+        return res;
+    }
+
+    function add_op(Int b, Int c) public returns(Int) {
+        return b + c;
+    }
+
+    function lib_call(Int b) public {
+        return b.f();
+    }
+
+    function neg_usertype(Int b) public returns(Int) {
+        Int res = -b;
+        return res;
+    }
+
+    function neg_int(int b) public returns(int) {
+        return -b;
+    }
+
+    function eq_op(Int b, Int c) public returns(bool) {
+        return b == c;
+    }
+}
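As a quick end-to-end check of the parser change, the operator uses in this contract should now come back as calls to the bound top level functions rather than as raw operator expressions. A minimal sketch, assuming solc 0.8.19 is installed and using the file path added by this commit; the node index relies on the CFG recorded in the expected JSON above:

from slither import Slither
from slither.core.expressions import CallExpression

sl = Slither("tests/e2e/solc_parsing/test_data/user_defined_operators-0.8.19.sol")
t = sl.get_contract_from_name("T")[0]
add_op = t.get_function_from_signature("add_op(Int,Int)")

# `return b + c;` is node 1 (ENTRY_POINT 0 -> RETURN 1 in the expected CFG above)
ret_expr = add_op.nodes[1].expression
assert isinstance(ret_expr, CallExpression)   # user defined `+` resolved to `add`
print(ret_expr)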