Add support for user defined types (#1135)

* Add support for user defined types
- Create a new core object TypeAlias (top level or contract)
- Add support for wrap/unwrap
- Add tests
pull/746/merge
Feist Josselin 3 years ago committed by GitHub
parent 496c8e1910
commit 5863c30747
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 8
      slither/core/scope/scope.py
  2. 1
      slither/core/solidity_types/__init__.py
  3. 41
      slither/core/solidity_types/type_alias.py
  4. 26
      slither/solc_parsing/declarations/contract.py
  5. 5
      slither/solc_parsing/expressions/find_variable.py
  6. 18
      slither/solc_parsing/slither_compilation_unit_solc.py
  7. 14
      slither/solc_parsing/solidity_types/type_parsing.py
  8. 165
      slither/visitors/slithir/expression_to_slithir.py
  9. BIN
      tests/ast-parsing/compile/user_defined_types.sol-0.8.10-compact.zip
  10. BIN
      tests/ast-parsing/compile/user_defined_types.sol-0.8.11-compact.zip
  11. BIN
      tests/ast-parsing/compile/user_defined_types.sol-0.8.12-compact.zip
  12. BIN
      tests/ast-parsing/compile/user_defined_types.sol-0.8.8-compact.zip
  13. 10
      tests/ast-parsing/expected/user_defined_types.sol-0.8.10-compact.json
  14. 10
      tests/ast-parsing/expected/user_defined_types.sol-0.8.11-compact.json
  15. 10
      tests/ast-parsing/expected/user_defined_types.sol-0.8.12-compact.json
  16. 10
      tests/ast-parsing/expected/user_defined_types.sol-0.8.8-compact.json
  17. 30
      tests/ast-parsing/user_defined_types.sol
  18. 2
      tests/test_ast_parsing.py

@ -6,6 +6,7 @@ from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.core.solidity_types import TypeAlias
from slither.core.variables.top_level_variable import TopLevelVariable
from slither.slithir.variables import Constant
@ -44,6 +45,10 @@ class FileScope:
# local name -> original name (A -> B)
self.renaming: Dict[str, str] = {}
# User defined types
# Name -> type alias
self.user_defined_types: Dict[str, TypeAlias] = {}
def add_accesible_scopes(self) -> bool:
"""
Add information from accessible scopes. Return true if new information was obtained
@ -82,6 +87,9 @@ class FileScope:
if not _dict_contain(new_scope.renaming, self.renaming):
self.renaming.update(new_scope.renaming)
learn_something = True
if not _dict_contain(new_scope.user_defined_types, self.user_defined_types):
self.user_defined_types.update(new_scope.user_defined_types)
learn_something = True
return learn_something

@ -5,3 +5,4 @@ from .mapping_type import MappingType
from .user_defined_type import UserDefinedType
from .type import Type
from .type_information import TypeInformation
from .type_alias import TypeAlias, TypeAliasTopLevel, TypeAliasContract

@ -0,0 +1,41 @@
from typing import TYPE_CHECKING, Tuple
from slither.core.children.child_contract import ChildContract
from slither.core.declarations.top_level import TopLevel
from slither.core.solidity_types import Type
if TYPE_CHECKING:
from slither.core.declarations import Contract
from slither.core.scope.scope import FileScope
class TypeAlias(Type):
def __init__(self, underlying_type: Type, name: str):
super().__init__()
self.name = name
self.underlying_type = underlying_type
@property
def storage_size(self) -> Tuple[int, bool]:
return self.underlying_type.storage_size
def __hash__(self):
return hash(str(self))
class TypeAliasTopLevel(TypeAlias, TopLevel):
def __init__(self, underlying_type: Type, name: str, scope: "FileScope"):
super().__init__(underlying_type, name)
self.file_scope: "FileScope" = scope
def __str__(self):
return self.name
class TypeAliasContract(TypeAlias, ChildContract):
def __init__(self, underlying_type: Type, name: str, contract: "Contract"):
super().__init__(underlying_type, name)
self._contract: "Contract" = contract
def __str__(self):
return self.contract.name + "." + self.name

@ -5,6 +5,7 @@ from slither.core.declarations import Modifier, Event, EnumContract, StructureCo
from slither.core.declarations.contract import Contract
from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.declarations.function_contract import FunctionContract
from slither.core.solidity_types import ElementaryType, TypeAliasContract
from slither.core.variables.state_variable import StateVariable
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
@ -230,6 +231,7 @@ class ContractSolc(CallerContextExpression):
self.baseConstructorContractsCalled.append(referencedDeclaration)
def _parse_contract_items(self):
# pylint: disable=too-many-branches
if not self.get_children() in self._data: # empty contract
return
for item in self._data[self.get_children()]:
@ -253,10 +255,34 @@ class ContractSolc(CallerContextExpression):
self._usingForNotParsed.append(item)
elif item[self.get_key()] == "ErrorDefinition":
self._customErrorParsed.append(item)
elif item[self.get_key()] == "UserDefinedValueTypeDefinition":
self._parse_type_alias(item)
else:
raise ParsingError("Unknown contract item: " + item[self.get_key()])
return
def _parse_type_alias(self, item: Dict) -> None:
assert "name" in item
assert "underlyingType" in item
underlying_type = item["underlyingType"]
assert "nodeType" in underlying_type and underlying_type["nodeType"] == "ElementaryTypeName"
assert "name" in underlying_type
original_type = ElementaryType(underlying_type["name"])
# For user defined types defined at the contract level the lookup can be done
# Using the name or the canonical name
# For example during the type parsing the canonical name
# Note that Solidity allows shadowing of user defined types
# Between top level and contract definitions
alias = item["name"]
alias_canonical = self._contract.name + "." + item["name"]
user_defined_type = TypeAliasContract(original_type, alias, self.underlying_contract)
user_defined_type.set_offset(item["src"], self.compilation_unit)
self._contract.file_scope.user_defined_types[alias] = user_defined_type
self._contract.file_scope.user_defined_types[alias_canonical] = user_defined_type
def _parse_struct(self, struct: Dict):
st = StructureContract(self._contract.compilation_unit)

@ -18,6 +18,7 @@ from slither.core.solidity_types import (
ArrayType,
FunctionType,
MappingType,
TypeAlias,
)
from slither.core.variables.top_level_variable import TopLevelVariable
from slither.core.variables.variable import Variable
@ -292,6 +293,7 @@ def find_variable(
Enum,
Structure,
CustomError,
TypeAlias,
],
bool,
]:
@ -337,6 +339,9 @@ def find_variable(
if var_name in current_scope.renaming:
var_name = current_scope.renaming[var_name]
if var_name in current_scope.user_defined_types:
return current_scope.user_defined_types[var_name], False
# Use ret0/ret1 to help mypy
ret0 = _find_variable_from_ref_declaration(
referenced_declaration, direct_contracts, direct_functions

@ -15,6 +15,7 @@ from slither.core.declarations.import_directive import Import
from slither.core.declarations.pragma_directive import Pragma
from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.core.scope.scope import FileScope
from slither.core.solidity_types import ElementaryType, TypeAliasTopLevel
from slither.core.variables.top_level_variable import TopLevelVariable
from slither.exceptions import SlitherException
from slither.solc_parsing.declarations.contract import ContractSolc
@ -298,6 +299,23 @@ class SlitherCompilationUnitSolc:
self._compilation_unit.custom_errors.append(custom_error)
self._custom_error_parser.append(custom_error_parser)
elif top_level_data[self.get_key()] == "UserDefinedValueTypeDefinition":
assert "name" in top_level_data
alias = top_level_data["name"]
assert "underlyingType" in top_level_data
underlying_type = top_level_data["underlyingType"]
assert (
"nodeType" in underlying_type
and underlying_type["nodeType"] == "ElementaryTypeName"
)
assert "name" in underlying_type
original_type = ElementaryType(underlying_type["name"])
user_defined_type = TypeAliasTopLevel(original_type, alias, scope)
user_defined_type.set_offset(top_level_data["src"], self._compilation_unit)
scope.user_defined_types[alias] = user_defined_type
else:
raise SlitherException(f"Top level {top_level_data[self.get_key()]} not supported")

@ -6,6 +6,7 @@ from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.declarations.function_contract import FunctionContract
from slither.core.expressions.literal import Literal
from slither.core.solidity_types import TypeAlias
from slither.core.solidity_types.array_type import ArrayType
from slither.core.solidity_types.elementary_type import (
ElementaryType,
@ -224,6 +225,7 @@ def parse_type(
sl: "SlitherCompilationUnit"
renaming: Dict[str, str]
user_defined_types: Dict[str, TypeAlias]
# Note: for convenicence top level functions use the same parser than function in contract
# but contract_parser is set to None
if isinstance(caller_context, SlitherCompilationUnitSolc) or (
@ -234,11 +236,13 @@ def parse_type(
sl = caller_context.compilation_unit
next_context = caller_context
renaming = {}
user_defined_types = {}
else:
assert isinstance(caller_context, FunctionSolc)
sl = caller_context.underlying_function.compilation_unit
next_context = caller_context.slither_parser
renaming = caller_context.underlying_function.file_scope.renaming
user_defined_types = caller_context.underlying_function.file_scope.user_defined_types
structures_direct_access = sl.structures_top_level
all_structuress = [c.structures for c in sl.contracts]
all_structures = [item for sublist in all_structuress for item in sublist]
@ -274,6 +278,7 @@ def parse_type(
functions = list(scope.functions)
renaming = scope.renaming
user_defined_types = scope.user_defined_types
elif isinstance(caller_context, (ContractSolc, FunctionSolc)):
if isinstance(caller_context, FunctionSolc):
underlying_func = caller_context.underlying_function
@ -302,6 +307,7 @@ def parse_type(
functions = contract.functions + contract.modifiers
renaming = scope.renaming
user_defined_types = scope.user_defined_types
else:
raise ParsingError(f"Incorrect caller context: {type(caller_context)}")
@ -315,6 +321,8 @@ def parse_type(
name = t.name
if name in renaming:
name = renaming[name]
if name in user_defined_types:
return user_defined_types[name]
return _find_from_type_name(
name,
functions,
@ -335,6 +343,8 @@ def parse_type(
name = t["typeDescriptions"]["typeString"]
if name in renaming:
name = renaming[name]
if name in user_defined_types:
return user_defined_types[name]
return _find_from_type_name(
name,
functions,
@ -351,6 +361,8 @@ def parse_type(
name = t["attributes"][type_name_key]
if name in renaming:
name = renaming[name]
if name in user_defined_types:
return user_defined_types[name]
return _find_from_type_name(
name,
functions,
@ -367,6 +379,8 @@ def parse_type(
name = t["name"]
if name in renaming:
name = renaming[name]
if name in user_defined_types:
return user_defined_types[name]
return _find_from_type_name(
name,
functions,

@ -1,10 +1,13 @@
import logging
from typing import List
from slither.core.declarations import (
Function,
SolidityVariable,
SolidityVariableComposed,
SolidityFunction,
Contract,
)
from slither.core.expressions import (
AssignmentOperationType,
@ -13,8 +16,9 @@ from slither.core.expressions import (
ElementaryTypeNameExpression,
CallExpression,
Identifier,
MemberAccess,
)
from slither.core.solidity_types import ArrayType, ElementaryType
from slither.core.solidity_types import ArrayType, ElementaryType, TypeAlias
from slither.core.solidity_types.type import Type
from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple
from slither.core.variables.variable import Variable
@ -33,6 +37,7 @@ from slither.slithir.operations import (
Unpack,
Return,
SolidityCall,
Operation,
)
from slither.slithir.tmp_operations.argument import Argument
from slither.slithir.tmp_operations.tmp_call import TmpCall
@ -59,6 +64,10 @@ def get(expression):
return val
def get_without_removing(expression):
return expression.context[key]
def set_val(expression, val):
expression.context[key] = val
@ -127,7 +136,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
self._expression = expression
self._node = node
self._result = []
self._result: List[Operation] = []
self._visit_expression(self.expression)
if node.type == NodeType.RETURN:
r = Return(get(self.expression))
@ -240,8 +249,13 @@ class ExpressionToSlithIR(ExpressionVisitor):
def _post_call_expression(
self, expression
): # pylint: disable=too-many-branches,too-many-statements
called = get(expression.called)
): # pylint: disable=too-many-branches,too-many-statements,too-many-locals
assert isinstance(expression, CallExpression)
expression_called = expression.called
called = get(expression_called)
args = [get(a) for a in expression.arguments if a]
for arg in args:
arg_ = Argument(arg)
@ -259,66 +273,81 @@ class ExpressionToSlithIR(ExpressionVisitor):
internal_call.set_expression(expression)
self._result.append(internal_call)
set_val(expression, val)
# User defined types
elif (
isinstance(called, TypeAlias)
and isinstance(expression_called, MemberAccess)
and expression_called.member_name in ["wrap", "unwrap"]
and len(args) == 1
):
val = TemporaryVariable(self._node)
var = TypeConversion(val, args[0], called)
var.set_expression(expression)
val.set_type(called)
self._result.append(var)
set_val(expression, val)
# yul things
elif called.name == "caller()":
val = TemporaryVariable(self._node)
var = Assignment(val, SolidityVariableComposed("msg.sender"), "uint256")
self._result.append(var)
set_val(expression, val)
elif called.name == "origin()":
val = TemporaryVariable(self._node)
var = Assignment(val, SolidityVariableComposed("tx.origin"), "uint256")
self._result.append(var)
set_val(expression, val)
elif called.name == "extcodesize(uint256)":
val = ReferenceVariable(self._node)
var = Member(args[0], Constant("codesize"), val)
self._result.append(var)
set_val(expression, val)
elif called.name == "selfbalance()":
val = TemporaryVariable(self._node)
var = TypeConversion(val, SolidityVariable("this"), ElementaryType("address"))
val.set_type(ElementaryType("address"))
self._result.append(var)
val1 = ReferenceVariable(self._node)
var1 = Member(val, Constant("balance"), val1)
self._result.append(var1)
set_val(expression, val1)
elif called.name == "address()":
val = TemporaryVariable(self._node)
var = TypeConversion(val, SolidityVariable("this"), ElementaryType("address"))
val.set_type(ElementaryType("address"))
self._result.append(var)
set_val(expression, val)
elif called.name == "callvalue()":
val = TemporaryVariable(self._node)
var = Assignment(val, SolidityVariableComposed("msg.value"), "uint256")
self._result.append(var)
set_val(expression, val)
else:
# yul things
if called.name == "caller()":
val = TemporaryVariable(self._node)
var = Assignment(val, SolidityVariableComposed("msg.sender"), "uint256")
self._result.append(var)
set_val(expression, val)
elif called.name == "origin()":
val = TemporaryVariable(self._node)
var = Assignment(val, SolidityVariableComposed("tx.origin"), "uint256")
self._result.append(var)
set_val(expression, val)
elif called.name == "extcodesize(uint256)":
val = ReferenceVariable(self._node)
var = Member(args[0], Constant("codesize"), val)
self._result.append(var)
set_val(expression, val)
elif called.name == "selfbalance()":
val = TemporaryVariable(self._node)
var = TypeConversion(val, SolidityVariable("this"), ElementaryType("address"))
val.set_type(ElementaryType("address"))
self._result.append(var)
val1 = ReferenceVariable(self._node)
var1 = Member(val, Constant("balance"), val1)
self._result.append(var1)
set_val(expression, val1)
elif called.name == "address()":
val = TemporaryVariable(self._node)
var = TypeConversion(val, SolidityVariable("this"), ElementaryType("address"))
val.set_type(ElementaryType("address"))
self._result.append(var)
set_val(expression, val)
elif called.name == "callvalue()":
val = TemporaryVariable(self._node)
var = Assignment(val, SolidityVariableComposed("msg.value"), "uint256")
self._result.append(var)
set_val(expression, val)
# If tuple
if expression.type_call.startswith("tuple(") and expression.type_call != "tuple()":
val = TupleVariable(self._node)
else:
# If tuple
if expression.type_call.startswith("tuple(") and expression.type_call != "tuple()":
val = TupleVariable(self._node)
else:
val = TemporaryVariable(self._node)
val = TemporaryVariable(self._node)
message_call = TmpCall(called, len(args), val, expression.type_call)
message_call.set_expression(expression)
# Gas/value are only accessible here if the syntax {gas: , value: }
# Is used over .gas().value()
if expression.call_gas:
call_gas = get(expression.call_gas)
message_call.call_gas = call_gas
if expression.call_value:
call_value = get(expression.call_value)
message_call.call_value = call_value
if expression.call_salt:
call_salt = get(expression.call_salt)
message_call.call_salt = call_salt
self._result.append(message_call)
set_val(expression, val)
message_call = TmpCall(called, len(args), val, expression.type_call)
message_call.set_expression(expression)
# Gas/value are only accessible here if the syntax {gas: , value: }
# Is used over .gas().value()
if expression.call_gas:
call_gas = get(expression.call_gas)
message_call.call_gas = call_gas
if expression.call_value:
call_value = get(expression.call_value)
message_call.call_value = call_value
if expression.call_salt:
call_salt = get(expression.call_salt)
message_call.call_salt = call_salt
self._result.append(message_call)
set_val(expression, val)
def _post_conditional_expression(self, expression):
raise Exception(f"Ternary operator are not convertible to SlithIR {expression}")
@ -413,6 +442,20 @@ class ExpressionToSlithIR(ExpressionVisitor):
set_val(expression, val)
return
if isinstance(expr, TypeAlias) and expression.member_name in ["wrap", "unwrap"]:
# The logic is be handled by _post_call_expression
set_val(expression, expr)
return
# Early lookup to detect user defined types from other contracts definitions
# contract A { type MyInt is int}
# contract B { function f() public{ A.MyInt test = A.MyInt.wrap(1);}}
# The logic is handled by _post_call_expression
if isinstance(expr, Contract):
if expression.member_name in expr.file_scope.user_defined_types:
set_val(expression, expr.file_scope.user_defined_types[expression.member_name])
return
val = ReferenceVariable(self._node)
member = Member(expr, Constant(expression.member_name), val)
member.set_expression(expression)

@ -0,0 +1,10 @@
{
"B": {
"u()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n}\n",
"f()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n}\n"
},
"D": {},
"C": {
"f(Left[])": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n"
}
}

@ -0,0 +1,10 @@
{
"B": {
"u()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n}\n",
"f()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n}\n"
},
"D": {},
"C": {
"f(Left[])": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n"
}
}

@ -0,0 +1,10 @@
{
"B": {
"u()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n}\n",
"f()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n}\n"
},
"D": {},
"C": {
"f(Left[])": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n"
}
}

@ -0,0 +1,10 @@
{
"B": {
"u()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n}\n",
"f()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n}\n"
},
"D": {},
"C": {
"f(Left[])": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n"
}
}

@ -0,0 +1,30 @@
type MyInt is uint;
contract B {
type MyInt is int;
function u() internal returns(int) {}
function f() public{
MyInt mi = MyInt.wrap(u());
}
}
function f(MyInt a) pure returns (MyInt b) {
b = MyInt(a);
}
contract D
{
B.MyInt x = B.MyInt.wrap(int(1));
}
contract C {
function f(Left[] memory a) internal returns(Left){
return a[0];
}
}
type Left is bytes2;
MyInt constant x = MyInt.wrap(20);

@ -401,6 +401,8 @@ ALL_TESTS = [
),
Test("custom_error_with_state_variable.sol", make_version(8, 4, 12)),
Test("complex_imports/import_aliases/test.sol", VERSIONS_08),
# 0.8.9 crashes on our testcase
Test("user_defined_types.sol", ["0.8.8"] + make_version(8, 10, 12)),
]
# create the output folder if needed
try:

Loading…
Cancel
Save