From aae3a0e64212d2dc7b22a777e52d9c36f778bd30 Mon Sep 17 00:00:00 2001 From: Josselin Date: Sun, 7 Jun 2020 16:15:21 +0200 Subject: [PATCH] Minimal YUL handling --- slither/core/declarations/function.py | 5 +- .../core/declarations/solidity_variables.py | 9 +- slither/core/expressions/binary_operation.py | 26 + slither/detectors/statements/assembly.py | 2 +- slither/printers/summary/slithir.py | 2 +- slither/slithir/convert.py | 7 + slither/slithir/operations/binary.py | 26 + slither/slithir/operations/codesize.py | 24 + slither/slithir/utils/ssa.py | 5 + slither/solc_parsing/cfg/node.py | 48 ++ slither/solc_parsing/declarations/function.py | 30 +- slither/solc_parsing/yul/__init__.py | 0 slither/solc_parsing/yul/evm_functions.py | 268 +++++++++ slither/solc_parsing/yul/parse_yul.py | 525 ++++++++++++++++++ .../visitors/slithir/expression_to_slithir.py | 80 ++- 15 files changed, 1023 insertions(+), 34 deletions(-) create mode 100644 slither/slithir/operations/codesize.py create mode 100644 slither/solc_parsing/yul/__init__.py create mode 100644 slither/solc_parsing/yul/evm_functions.py create mode 100644 slither/solc_parsing/yul/parse_yul.py diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index d28dabb2c..caf06b74d 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -89,6 +89,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): def __init__(self): super(Function, self).__init__() + self._scope: List[str] = [] self._name: Optional[str] = None self._view: bool = False self._pure: bool = False @@ -207,7 +208,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): Return the function signature without the return values """ name, parameters, _ = self.signature - return name + "(" + ",".join(parameters) + ")" + return ".".join(self._scope + [name]) + "(" + ",".join(parameters) + ")" @property def canonical_name(self) -> str: @@ -216,7 +217,7 @@ class 
Function(ChildContract, ChildInheritance, SourceMapping): Return the function signature without the return values """ name, parameters, _ = self.signature - return self.contract_declarer.name + "." + name + "(" + ",".join(parameters) + ")" + return ".".join([self.contract_declarer.name] + self._scope + [name]) + "(" + ",".join(parameters) + ")" @property def contains_assembly(self) -> bool: diff --git a/slither/core/declarations/solidity_variables.py b/slither/core/declarations/solidity_variables.py index ff26dd1c7..3f493147d 100644 --- a/slither/core/declarations/solidity_variables.py +++ b/slither/core/declarations/solidity_variables.py @@ -90,7 +90,14 @@ class SolidityVariable(Context): # dev function, will be removed once the code is stable def _check_name(self, name: str): - assert name in SOLIDITY_VARIABLES + assert name in SOLIDITY_VARIABLES or name.endswith("_slot") or name.endswith("_offset") + + @property + def state_variable(self): + if self._name.endswith("_slot"): + return self._name[:-5] + if self._name.endswith("_offset"): + return self._name[:-7] @property def name(self) -> str: diff --git a/slither/core/expressions/binary_operation.py b/slither/core/expressions/binary_operation.py index 0a1c4b7a2..6f8c0f15f 100644 --- a/slither/core/expressions/binary_operation.py +++ b/slither/core/expressions/binary_operation.py @@ -31,6 +31,12 @@ class BinaryOperationType(Enum): ANDAND = 17 # && OROR = 18 # || + DIVISION_SIGNED = 19 + MODULO_SIGNED = 20 + LESS_SIGNED = 21 + GREATER_SIGNED = 22 + RIGHT_SHIFT_ARITHMETIC = 23 + @staticmethod def get_type(operation_type: "BinaryOperation"): if operation_type == "**": @@ -71,6 +77,16 @@ class BinaryOperationType(Enum): return BinaryOperationType.ANDAND if operation_type == "||": return BinaryOperationType.OROR + if operation_type == "/'": + return BinaryOperationType.DIVISION_SIGNED + if operation_type == "%'": + return BinaryOperationType.MODULO_SIGNED + if operation_type == "<'": + return 
BinaryOperationType.LESS_SIGNED + if operation_type == ">'": + return BinaryOperationType.GREATER_SIGNED + if operation_type == ">>'": + return BinaryOperationType.RIGHT_SHIFT_ARITHMETIC raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type)) @@ -113,6 +129,16 @@ class BinaryOperationType(Enum): return "&&" if self == BinaryOperationType.OROR: return "||" + if self == BinaryOperationType.DIVISION_SIGNED: + return "/'" + if self == BinaryOperationType.MODULO_SIGNED: + return "%'" + if self == BinaryOperationType.LESS_SIGNED: + return "<'" + if self == BinaryOperationType.GREATER_SIGNED: + return ">'" + if self == BinaryOperationType.RIGHT_SHIFT_ARITHMETIC: + return ">>'" raise SlitherCoreError("str: Unknown operation type {})".format(self)) diff --git a/slither/detectors/statements/assembly.py b/slither/detectors/statements/assembly.py index bf84dab7a..ceed61336 100644 --- a/slither/detectors/statements/assembly.py +++ b/slither/detectors/statements/assembly.py @@ -30,7 +30,7 @@ class Assembly(AbstractDetector): Returns: (bool) """ - return node.type == NodeType.ASSEMBLY + return node.type == NodeType.ASSEMBLY and len(node.yul_path) == 2 def detect_assembly(self, contract): ret = [] diff --git a/slither/printers/summary/slithir.py b/slither/printers/summary/slithir.py index a6ba49a17..7326fa0be 100644 --- a/slither/printers/summary/slithir.py +++ b/slither/printers/summary/slithir.py @@ -20,7 +20,7 @@ class PrinterSlithIR(AbstractPrinter): txt = "" for contract in self.contracts: - txt += 'Contract {}'.format(contract.name) + txt += 'Contract {}\n'.format(contract.name) for function in contract.functions: txt += f'\tFunction {function.canonical_name} {"" if function.is_shadowed else "(*)"}\n' for node in function.nodes: diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index c47739e95..ef8c02432 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -10,6 +10,7 @@ from slither.core.solidity_types 
import (ArrayType, ElementaryType, from slither.core.solidity_types.elementary_type import Int as ElementaryTypeInt from slither.core.variables.variable import Variable from slither.core.variables.state_variable import StateVariable +from slither.slithir.operations.codesize import CodeSize from slither.slithir.variables import TupleVariable from slither.slithir.operations import (Assignment, Balance, Binary, BinaryType, Call, Condition, Delete, @@ -512,6 +513,12 @@ def propagate_types(ir, node): b.set_expression(ir.expression) b.set_node(ir.node) return b + if ir.variable_right == 'codesize' and not isinstance(ir.variable_left, Contract) and isinstance( + ir.variable_left.type, ElementaryType): + b = CodeSize(ir.variable_left, ir.lvalue) + b.set_expression(ir.expression) + b.set_node(ir.node) + return b if ir.variable_right == 'selector' and isinstance(ir.variable_left.type, Function): assignment = Assignment(ir.lvalue, Constant(str(get_function_id(ir.variable_left.type.full_name))), diff --git a/slither/slithir/operations/binary.py b/slither/slithir/operations/binary.py index cd886197f..fce83aa6d 100644 --- a/slither/slithir/operations/binary.py +++ b/slither/slithir/operations/binary.py @@ -31,6 +31,12 @@ class BinaryType(Enum): ANDAND = 17 # && OROR = 18 # || + DIVISION_SIGNED = 19 + MODULO_SIGNED = 20 + LESS_SIGNED = 21 + GREATER_SIGNED = 22 + RIGHT_SHIFT_ARITHMETIC = 23 + @staticmethod def return_bool(operation_type): return operation_type in [BinaryType.OROR, @@ -82,6 +88,16 @@ class BinaryType(Enum): return BinaryType.ANDAND if operation_type == '||': return BinaryType.OROR + if operation_type == "/'": + return BinaryType.DIVISION_SIGNED + if operation_type == "%'": + return BinaryType.MODULO_SIGNED + if operation_type == "<'": + return BinaryType.LESS_SIGNED + if operation_type == ">'": + return BinaryType.GREATER_SIGNED + if operation_type == ">>'": + return BinaryType.RIGHT_SHIFT_ARITHMETIC raise SlithIRError('get_type: Unknown operation type 
{})'.format(operation_type)) @@ -124,6 +140,16 @@ class BinaryType(Enum): return "&&" if self == BinaryType.OROR: return "||" + if self == BinaryType.DIVISION_SIGNED: + return "/'" + if self == BinaryType.MODULO_SIGNED: + return "%'" + if self == BinaryType.LESS_SIGNED: + return "<'" + if self == BinaryType.GREATER_SIGNED: + return ">'" + if self == BinaryType.RIGHT_SHIFT_ARITHMETIC: + return ">>'" raise SlithIRError("str: Unknown operation type {} {})".format(self, type(self))) diff --git a/slither/slithir/operations/codesize.py b/slither/slithir/operations/codesize.py new file mode 100644 index 000000000..85289dc38 --- /dev/null +++ b/slither/slithir/operations/codesize.py @@ -0,0 +1,24 @@ +from slither.core.solidity_types import ElementaryType +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue + + +class CodeSize(OperationWithLValue): + + def __init__(self, value, lvalue): + assert is_valid_rvalue(value) + assert is_valid_lvalue(lvalue) + self._value = value + self._lvalue = lvalue + lvalue.set_type(ElementaryType('uint256')) + + @property + def read(self): + return [self._value] + + @property + def value(self): + return self._value + + def __str__(self): + return "{} -> CODESIZE {}".format(self.lvalue, self.value) diff --git a/slither/slithir/utils/ssa.py b/slither/slithir/utils/ssa.py index e05069a59..277fb4ea7 100644 --- a/slither/slithir/utils/ssa.py +++ b/slither/slithir/utils/ssa.py @@ -17,6 +17,7 @@ from slither.slithir.operations import (Assignment, Balance, Binary, Condition, Push, Return, Send, SolidityCall, Transfer, TypeConversion, Unary, Unpack, Nop) +from slither.slithir.operations.codesize import CodeSize from slither.slithir.variables import (Constant, LocalIRVariable, ReferenceVariable, ReferenceVariableSSA, StateIRVariable, TemporaryVariable, @@ -527,6 +528,10 @@ def copy_ir(ir, *instances): variable_right = get_variable(ir, lambda x: x.variable_right, 
*instances) operation_type = ir.type return Binary(lvalue, variable_left, variable_right, operation_type) + elif isinstance(ir, CodeSize): + lvalue = get_variable(ir, lambda x: x.lvalue, *instances) + value = get_variable(ir, lambda x: x.value, *instances) + return CodeSize(value, lvalue) elif isinstance(ir, Condition): val = get_variable(ir, lambda x: x.value, *instances) return Condition(val) diff --git a/slither/solc_parsing/cfg/node.py b/slither/solc_parsing/cfg/node.py index 2d2bec0f3..9ed4c950b 100644 --- a/slither/solc_parsing/cfg/node.py +++ b/slither/solc_parsing/cfg/node.py @@ -9,6 +9,7 @@ from slither.core.expressions.assignment_operation import ( from slither.core.expressions.identifier import Identifier from slither.solc_parsing.expressions.expression_parsing import parse_expression from slither.visitors.expression.find_calls import FindCalls +from slither.solc_parsing.yul.parse_yul import parse_yul from slither.visitors.expression.read_var import ReadVar from slither.visitors.expression.write_var import WriteVar @@ -17,15 +18,57 @@ class NodeSolc: def __init__(self, node: Node): self._unparsed_expression: Optional[Dict] = None self._node = node + self._unparsed_yul_expression = None + + """ + todo this should really go somewhere else, but until + that happens I'm setting it to None for performance + """ + self._yul_local_variables = None + self._yul_local_functions = None + self._yul_path = None @property def underlying_node(self) -> Node: return self._node + def set_yul_root(self, func): + self._yul_path = [func.name, f"asm_{func._counter_asm_nodes}"] + + def set_yul_child(self, parent, cur): + self._yul_path = parent.yul_path + [cur] + + @property + def yul_path(self): + return self._yul_path + + def format_canonical_yul_name(self, name, off=None): + return ".".join(self._yul_path[:off] + [name]) + + def add_yul_local_variable(self, var): + if not self._yul_local_variables: + self._yul_local_variables = [] + self._yul_local_variables.append(var) + + 
def get_yul_local_variable_from_name(self, variable_name): + return next((v for v in self._yul_local_variables if v.name == variable_name), None) + + def add_yul_local_function(self, func): + if not self._yul_local_functions: + self._yul_local_functions = [] + self._yul_local_functions.append(func) + + def get_yul_local_function_from_name(self, func_name): + return next((v for v in self._yul_local_functions if v.name == func_name), None) + def add_unparsed_expression(self, expression: Dict): assert self._unparsed_expression is None self._unparsed_expression = expression + def add_unparsed_yul_expression(self, root, expression): + assert self._unparsed_expression is None + self._unparsed_yul_expression = (root, expression) + def analyze_expressions(self, caller_context): if self._node.type == NodeType.VARIABLE and not self._node.expression: self._node.add_expression(self._node.variable_declaration.expression) @@ -34,6 +77,11 @@ class NodeSolc: self._node.add_expression(expression) # self._unparsed_expression = None + if self._unparsed_yul_expression: + expression = parse_yul(self._unparsed_yul_expression[0], self, self._unparsed_yul_expression[1]) + self._expression = expression + self._unparsed_yul_expression = None + if self._node.expression: if self._node.type == NodeType.VARIABLE: diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 4572c6766..0c0f4d66b 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -18,6 +18,7 @@ from slither.solc_parsing.variables.local_variable_init_from_tuple import ( LocalVariableInitFromTupleSolc, ) from slither.solc_parsing.variables.variable_declaration import MultipleVariablesDeclaration +from slither.solc_parsing.yul.parse_yul import convert_yul from slither.utils.expression_manipulations import SplitTernaryExpression from slither.visitors.expression.export_values import ExportValues from 
slither.visitors.expression.has_conditional import HasConditional @@ -62,6 +63,8 @@ class FunctionSolc: self._functionNotParsed = function_data self._params_was_analyzed = False self._content_was_analyzed = False + self._counter_nodes = 0 + self._counter_asm_nodes = 0 self._counter_scope_local_variables = 0 # variable renamed will map the solc id @@ -312,6 +315,9 @@ class FunctionSolc: self._node_to_nodesolc[node] = node_parser return node_parser + def node_solc(self): + return NodeSolc + # endregion ################################################################################### ################################################################################### @@ -797,13 +803,23 @@ class FunctionSolc: elif name == "Block": node = self._parse_block(statement, node) elif name == "InlineAssembly": - asm_node = self._new_node(NodeType.ASSEMBLY, statement["src"]) - self._function.contains_assembly = True - # Added with solc 0.4.12 - if "operations" in statement: - asm_node.underlying_node.add_inline_asm(statement["operations"]) - link_underlying_nodes(node, asm_node) - node = asm_node + # Added with solc 0.6 - the yul code is an AST + if 'AST' in statement: + self._contains_assembly = True + yul_root = self._new_node(NodeType.ASSEMBLY, statement['src']) + yul_root.set_yul_root(self) + link_underlying_nodes(node, yul_root) + self._counter_asm_nodes += 1 + + node = convert_yul(yul_root, yul_root, statement['AST']) + else: + asm_node = self._new_node(NodeType.ASSEMBLY, statement['src']) + self._function._contains_assembly = True + # Added with solc 0.4.12 + if 'operations' in statement: + asm_node.underlying_node.add_inline_asm(statement['operations']) + link_underlying_nodes(node, asm_node) + node = asm_node elif name == "DoWhileStatement": node = self._parse_dowhile(statement, node) # For Continue / Break / Return / Throw diff --git a/slither/solc_parsing/yul/__init__.py b/slither/solc_parsing/yul/__init__.py new file mode 100644 index 000000000..e69de29bb diff 
--git a/slither/solc_parsing/yul/evm_functions.py b/slither/solc_parsing/yul/evm_functions.py new file mode 100644 index 000000000..aa237428b --- /dev/null +++ b/slither/solc_parsing/yul/evm_functions.py @@ -0,0 +1,268 @@ +from slither.core.declarations.solidity_variables import SOLIDITY_FUNCTIONS +from slither.core.expressions import BinaryOperationType, UnaryOperationType + +# taken from https://github.com/ethereum/solidity/blob/356cc91084114f840da66804b2a9fc1ac2846cff/libevmasm/Instruction.cpp#L180 +evm_opcodes = [ + "STOP", + "ADD", + "SUB", + "MUL", + "DIV", + "SDIV", + "MOD", + "SMOD", + "EXP", + "NOT", + "LT", + "GT", + "SLT", + "SGT", + "EQ", + "ISZERO", + "AND", + "OR", + "XOR", + "BYTE", + "SHL", + "SHR", + "SAR", + "ADDMOD", + "MULMOD", + "SIGNEXTEND", + "KECCAK256", + "ADDRESS", + "BALANCE", + "ORIGIN", + "CALLER", + "CALLVALUE", + "CALLDATALOAD", + "CALLDATASIZE", + "CALLDATACOPY", + "CODESIZE", + "CODECOPY", + "GASPRICE", + "EXTCODESIZE", + "EXTCODECOPY", + "RETURNDATASIZE", + "RETURNDATACOPY", + "EXTCODEHASH", + "BLOCKHASH", + "COINBASE", + "TIMESTAMP", + "NUMBER", + "DIFFICULTY", + "GASLIMIT", + "CHAINID", + "SELFBALANCE", + "POP", + "MLOAD", + "MSTORE", + "MSTORE8", + "SLOAD", + "SSTORE", + "JUMP", + "JUMPI", + "PC", + "MSIZE", + "GAS", + "JUMPDEST", + "PUSH1", + "PUSH2", + "PUSH3", + "PUSH4", + "PUSH5", + "PUSH6", + "PUSH7", + "PUSH8", + "PUSH9", + "PUSH10", + "PUSH11", + "PUSH12", + "PUSH13", + "PUSH14", + "PUSH15", + "PUSH16", + "PUSH17", + "PUSH18", + "PUSH19", + "PUSH20", + "PUSH21", + "PUSH22", + "PUSH23", + "PUSH24", + "PUSH25", + "PUSH26", + "PUSH27", + "PUSH28", + "PUSH29", + "PUSH30", + "PUSH31", + "PUSH32", + "DUP1", + "DUP2", + "DUP3", + "DUP4", + "DUP5", + "DUP6", + "DUP7", + "DUP8", + "DUP9", + "DUP10", + "DUP11", + "DUP12", + "DUP13", + "DUP14", + "DUP15", + "DUP16", + "SWAP1", + "SWAP2", + "SWAP3", + "SWAP4", + "SWAP5", + "SWAP6", + "SWAP7", + "SWAP8", + "SWAP9", + "SWAP10", + "SWAP11", + "SWAP12", + "SWAP13", + "SWAP14", + 
"SWAP15", + "SWAP16", + "LOG0", + "LOG1", + "LOG2", + "LOG3", + "LOG4", + "CREATE", + "CALL", + "CALLCODE", + "STATICCALL", + "RETURN", + "DELEGATECALL", + "CREATE2", + "REVERT", + "INVALID", + "SELFDESTRUCT", +] + +yul_funcs = [ + "datasize", + "dataoffset", + "datacopy", + "setimmutable", + "loadimmutable", +] + +builtins = [x.lower() for x in evm_opcodes if not ( + x.startswith("PUSH") or + x.startswith("SWAP") or + x.startswith("DUP") or + x == "JUMP" or + x == "JUMPI" or + x == "JUMPDEST" +)] + yul_funcs + +function_args = { + 'byte': [2, 1], + 'addmod': [3, 1], + 'mulmod': [3, 1], + 'signextend': [2, 1], + 'keccak256': [2, 1], + 'pc': [0, 1], + 'pop': [1, 0], + 'mload': [1, 1], + 'mstore': [2, 0], + 'mstore8': [2, 0], + 'sload': [1, 1], + 'sstore': [2, 0], + 'msize': [1, 1], + 'gas': [0, 1], + 'address': [0, 1], + 'balance': [1, 1], + 'selfbalance': [0, 1], + 'caller': [0, 1], + 'callvalue': [0, 1], + 'calldataload': [1, 1], + 'calldatasize': [0, 1], + 'calldatacopy': [3, 0], + 'codesize': [0, 1], + 'codecopy': [3, 0], + 'extcodesize': [1, 1], + 'extcodecopy': [4, 0], + 'returndatasize': [0, 1], + 'returndatacopy': [3, 0], + 'extcodehash': [1, 1], + 'create': [3, 1], + 'create2': [4, 1], + 'call': [7, 1], + 'callcode': [7, 1], + 'delegatecall': [6, 1], + 'staticcall': [6, 1], + 'return': [2, 0], + 'revert': [2, 0], + 'selfdestruct': [1, 0], + 'invalid': [0, 0], + 'log0': [2, 0], + 'log1': [3, 0], + 'log2': [4, 0], + 'log3': [5, 0], + 'log4': [6, 0], + 'chainid': [0, 1], + 'origin': [0, 1], + 'gasprice': [0, 1], + 'blockhash': [1, 1], + 'coinbase': [0, 1], + 'timestamp': [0, 1], + 'number': [0, 1], + 'difficulty': [0, 1], + 'gaslimit': [0, 1], +} + + +def format_function_descriptor(name): + if name not in function_args: + return name + "()" + + return name + "(" + ",".join(["uint256"] * function_args[name][0]) + ")" + + +for k, v in function_args.items(): + SOLIDITY_FUNCTIONS[format_function_descriptor(k)] = ['uint256'] * v[1] + +unary_ops = { + 'not': 
UnaryOperationType.TILD, + 'iszero': UnaryOperationType.BANG, +} + +binary_ops = { + 'add': BinaryOperationType.ADDITION, + 'sub': BinaryOperationType.SUBTRACTION, + 'mul': BinaryOperationType.MULTIPLICATION, + 'div': BinaryOperationType.DIVISION, + 'sdiv': BinaryOperationType.DIVISION_SIGNED, + 'mod': BinaryOperationType.MODULO, + 'smod': BinaryOperationType.MODULO_SIGNED, + 'exp': BinaryOperationType.POWER, + 'lt': BinaryOperationType.LESS, + 'gt': BinaryOperationType.GREATER, + 'slt': BinaryOperationType.LESS_SIGNED, + 'sgt': BinaryOperationType.GREATER_SIGNED, + 'eq': BinaryOperationType.EQUAL, + 'and': BinaryOperationType.AND, + 'or': BinaryOperationType.OR, + 'xor': BinaryOperationType.CARET, + 'shl': BinaryOperationType.LEFT_SHIFT, + 'shr': BinaryOperationType.RIGHT_SHIFT, + 'sar': BinaryOperationType.RIGHT_SHIFT_ARITHMETIC, +} + + +class YulBuiltin: + def __init__(self, name): + self._name = name + + @property + def name(self): + return self._name diff --git a/slither/solc_parsing/yul/parse_yul.py b/slither/solc_parsing/yul/parse_yul.py new file mode 100644 index 000000000..4552a57c8 --- /dev/null +++ b/slither/solc_parsing/yul/parse_yul.py @@ -0,0 +1,525 @@ +import json + +from slither.core.cfg.node import NodeType, link_nodes +from slither.core.declarations import Function, SolidityFunction, SolidityVariable +from slither.core.expressions import ( + Literal, + AssignmentOperation, + AssignmentOperationType, + Identifier, CallExpression, TupleExpression, BinaryOperation, UnaryOperation, +) +from slither.core.solidity_types import ElementaryType +from slither.core.variables.local_variable import LocalVariable +from slither.exceptions import SlitherException +from slither.solc_parsing.yul.evm_functions import * + + +class YulLocalVariable(LocalVariable): + + def __init__(self, ast): + super(LocalVariable, self).__init__() + + assert (ast['nodeType'] == 'YulTypedName') + self._name = ast['name'] + self._type = ElementaryType('uint256') + + self._location = 
'memory' + + +class YulFunction(Function): + + def __init__(self, ast, root): + super(YulFunction, self).__init__() + + assert (ast['nodeType'] == 'YulFunctionDefinition') + + self._contract = root.function.contract + self._contract_declarer = root.function.contract_declarer + + self._name = ast['name'] + self._scope = root.yul_path + self._counter_nodes = 0 + + self._is_implemented = True + self._contains_assembly = True + + self._node_solc = root.function.node_solc() + + self._entry_point = self.new_node(NodeType.ASSEMBLY, ast['src']) + self._entry_point.set_yul_child(root, ast['name']) + + self._ast = ast + self.set_offset(ast['src'], root.function.slither) + + def convert_body(self): + node = self.new_node(NodeType.ENTRYPOINT, self._ast['src']) + link_nodes(self.entry_point, node) + + for param in self._ast.get('parameters', []): + node = convert_yul(self.entry_point, node, param) + self._parameters.append(self.entry_point.get_yul_local_variable_from_name(param['name'])) + + for ret in self._ast.get('returnVariables', []): + node = convert_yul(self.entry_point, node, ret) + self._returns.append(self.entry_point.get_yul_local_variable_from_name(ret['name'])) + + convert_yul(self.entry_point, node, self._ast['body']) + + def parse_body(self): + for node in self.nodes: + node.analyze_expressions(self) + + def node_solc(self): + return self._node_solc + + def new_node(self, node_type, src): + node = self._node_solc(node_type, self._counter_nodes) + node.set_offset(src, self.slither) + node.set_function(self) + self._counter_nodes += 1 + self._nodes.append(node) + return node + + +################################################################################### +################################################################################### +# region Block conversion +################################################################################### +################################################################################### + +""" +The functions in 
this region, at a high level, will extract the control flow +structures and metadata from the input AST. These include things like function +definitions and local variables. + +Each function takes three parameters: + 1) root is a NodeSolc of NodeType.ASSEMBLY, and stores information at the + local scope. In Yul, variables are scoped to the function they're + declared in (except for variables outside the assembly block) + 2) parent is the last node in the CFG. new nodes should be linked against + this node + 3) ast is a dictionary and is the current node in the Yul ast being converted + +Each function must return a single parameter: + 1) A NodeSolc representing the new end of the CFG + +The entrypoint is the function at the end of this region, `convert_yul`, which +dispatches to a specialized function based on a lookup dictionary. +""" + + +def convert_yul_block(root, parent, ast): + for statement in ast["statements"]: + parent = convert_yul(root, parent, statement) + return parent + + +def convert_yul_function_definition(root, parent, ast): + f = YulFunction(ast, root) + + root.function.contract._functions[root.format_canonical_yul_name(f.name)] = f + + f.convert_body() + f.parse_body() + + return parent + + +def convert_yul_variable_declaration(root, parent, ast): + for variable_ast in ast['variables']: + parent = convert_yul(root, parent, variable_ast) + + node = parent.function.new_node(NodeType.EXPRESSION, ast["src"]) + node.add_unparsed_yul_expression(root, ast) + link_nodes(parent, node) + + return node + + +def convert_yul_assignment(root, parent, ast): + node = parent.function.new_node(NodeType.EXPRESSION, ast["src"]) + node.add_unparsed_yul_expression(root, ast) + link_nodes(parent, node) + return node + + +def convert_yul_expression_statement(root, parent, ast): + src = ast['src'] + expression_ast = ast['expression'] + + expression = parent.function.new_node(NodeType.EXPRESSION, src) + expression.add_unparsed_yul_expression(root, expression_ast) + 
link_nodes(parent, expression) + + return expression + + +def convert_yul_if(root, parent, ast): + # we're cheating and pretending that yul supports if/else so we can convert switch cleaner + + src = ast['src'] + condition_ast = ast['condition'] + true_body_ast = ast['body'] + false_body_ast = ast['false_body'] if 'false_body' in ast else None + + condition = parent.function.new_node(NodeType.IF, src) + end = parent.function.new_node(NodeType.ENDIF, src) + + condition.add_unparsed_yul_expression(root, condition_ast) + + true_body = convert_yul(root, condition, true_body_ast) + + if false_body_ast: + false_body = convert_yul(root, condition, false_body_ast) + link_nodes(false_body, end) + else: + link_nodes(condition, end) + + link_nodes(parent, condition) + link_nodes(true_body, end) + + return end + + +def convert_yul_switch(root, parent, ast): + """ + This is unfortunate. We don't really want a switch in our IR so we're going to + translate it into a series of if/else statements. + """ + cases_ast = ast['cases'] + expression_ast = ast['expression'] + + # this variable stores the result of the expression so we don't accidentally compute it more than once + switch_expr_var = 'switch_expr_{}'.format(ast['src'].replace(':', '_')) + + rewritten_switch = { + 'nodeType': 'YulBlock', + 'src': ast['src'], + 'statements': [ + { + 'nodeType': 'YulVariableDeclaration', + 'src': expression_ast['src'], + 'variables': [ + { + 'nodeType': 'YulTypedName', + 'src': expression_ast['src'], + 'name': switch_expr_var, + 'type': '', + }, + ], + 'value': expression_ast, + }, + ], + } + + last_if = None + + default_ast = None + + for case_ast in cases_ast: + body_ast = case_ast['body'] + value_ast = case_ast['value'] + + if value_ast == 'default': + default_ast = case_ast + continue + + current_if = { + 'nodeType': 'YulIf', + 'src': case_ast['src'], + 'condition': { + 'nodeType': 'YulFunctionCall', + 'src': case_ast['src'], + 'functionName': { + 'nodeType': 'YulIdentifier', + 'src': 
case_ast['src'], + 'name': 'eq', + }, + 'arguments': [ + { + 'nodeType': 'YulIdentifier', + 'src': case_ast['src'], + 'name': switch_expr_var, + }, + value_ast, + ] + }, + 'body': body_ast, + } + + if last_if: + last_if['false_body'] = current_if + else: + rewritten_switch['statements'].append(current_if) + + last_if = current_if + + if default_ast: + body_ast = default_ast['body'] + + if last_if: + last_if['false_body'] = body_ast + else: + rewritten_switch['statements'].append(body_ast) + + return convert_yul(root, parent, rewritten_switch) + + +def convert_yul_for_loop(root, parent, ast): + pre_ast = ast['pre'] + condition_ast = ast['condition'] + post_ast = ast['post'] + body_ast = ast['body'] + + start_loop = parent.function.new_node(NodeType.STARTLOOP, ast['src']) + end_loop = parent.function.new_node(NodeType.ENDLOOP, ast['src']) + + link_nodes(parent, start_loop) + + pre = convert_yul(root, start_loop, pre_ast) + + condition = parent.function.new_node(NodeType.IFLOOP, condition_ast['src']) + condition.add_unparsed_yul_expression(root, condition_ast) + link_nodes(pre, condition) + + link_nodes(condition, end_loop) + + body = convert_yul(root, condition, body_ast) + + post = convert_yul(root, body, post_ast) + + link_nodes(post, condition) + + return end_loop + + +def convert_yul_break(root, parent, ast): + break_ = parent.function.new_node(NodeType.BREAK, ast['src']) + link_nodes(parent, break_) + return break_ + + +def convert_yul_continue(root, parent, ast): + continue_ = parent.function.new_node(NodeType.CONTINUE, ast['src']) + link_nodes(parent, continue_) + return continue_ + + +def convert_yul_leave(root, parent, ast): + leave = parent.function.new_node(NodeType.RETURN, ast['src']) + link_nodes(parent, leave) + return leave + + +def convert_yul_typed_name(root, parent, ast): + var = YulLocalVariable(ast) + var.set_function(root.function) + var.set_offset(ast['src'], root.slither) + + root.add_yul_local_variable(var) + + node = 
parent.function.new_node(NodeType.VARIABLE, ast['src']) + node.add_variable_declaration(var) + link_nodes(parent, node) + + return node + + +def convert_yul_unsupported(root, parent, ast): + raise SlitherException(f"no converter available for {ast['nodeType']} {json.dumps(ast, indent=2)}") + + +def convert_yul(root, parent, ast): + return converters.get(ast['nodeType'], convert_yul_unsupported)(root, parent, ast) + + +converters = { + 'YulBlock': convert_yul_block, + 'YulFunctionDefinition': convert_yul_function_definition, + 'YulVariableDeclaration': convert_yul_variable_declaration, + 'YulAssignment': convert_yul_assignment, + 'YulExpressionStatement': convert_yul_expression_statement, + 'YulIf': convert_yul_if, + 'YulSwitch': convert_yul_switch, + 'YulForLoop': convert_yul_for_loop, + 'YulBreak': convert_yul_break, + 'YulContinue': convert_yul_continue, + 'YulLeave': convert_yul_leave, + 'YulTypedName': convert_yul_typed_name, +} + +# endregion +################################################################################### +################################################################################### + +################################################################################### +################################################################################### +# region Expression parsing +################################################################################### +################################################################################### + +""" +The functions in this region parse the AST into expressions. + +Each function takes three parameters: + 1) root is the same root as above + 2) node is the CFG node which stores this expression + 3) ast is the same ast as above + +Each function must return a single parameter: + 1) The operation that was parsed, or None + +The entrypoint is the function at the end of this region, `parse_yul`, which +dispatches to a specialized function based on a lookup dictionary. 
def _parse_yul_assignment_common(root, node, ast, key):
    """
    Build the assignment expression shared by declarations and assignments.

    ast[key] lists the assigned targets and ast['value'] holds the assigned
    expression; both are parsed through parse_yul. Several targets are
    wrapped in a tuple by vars_to_val.
    """
    lhs = [parse_yul(root, node, arg) for arg in ast[key]]
    rhs = parse_yul(root, node, ast['value'])

    return AssignmentOperation(vars_to_val(lhs), rhs, AssignmentOperationType.ASSIGN, vars_to_typestr(lhs))


def parse_yul_variable_declaration(root, node, ast):
    """
    We already created variables in the conversion phase, so just do
    the assignment

    Returns None when the declaration has no initial value, whether the
    'value' entry is empty or absent altogether.
    """

    # .get() so that a declaration without a 'value' key is treated like an
    # empty one instead of raising KeyError
    if not ast.get('value'):
        return None

    return _parse_yul_assignment_common(root, node, ast, 'variables')


def parse_yul_assignment(root, node, ast):
    """
    Parse an assignment; the targets are listed under 'variableNames'.
    """
    return _parse_yul_assignment_common(root, node, ast, 'variableNames')


def parse_yul_function_call(root, node, ast):
    """
    Parse a function call into an operation or a call expression.

    Builtins listed in binary_ops / unary_ops become BinaryOperation /
    UnaryOperation; any other builtin is rewrapped as a SolidityFunction and
    called. Raises SlitherException when the call target is neither a
    Function nor a SolidityFunction.
    """
    args = [parse_yul(root, node, arg) for arg in ast['arguments']]
    ident = parse_yul(root, node, ast['functionName'])

    if isinstance(ident.value, YulBuiltin):
        name = ident.value.name
        if name in binary_ops:
            if name in ['shl', 'shr', 'sar']:
                # Yul shifts take the shift amount first (shl(x, y) shifts y
                # by x bits), so the operands are swapped to put the shifted
                # value on the left of the binary operation
                return BinaryOperation(args[1], args[0], binary_ops[name])

            return BinaryOperation(args[0], args[1], binary_ops[name])

        if name in unary_ops:
            return UnaryOperation(args[0], unary_ops[name])

        ident = Identifier(SolidityFunction(format_function_descriptor(name)))

    if isinstance(ident.value, Function):
        return CallExpression(ident, args, vars_to_typestr(ident.value.returns))
    if isinstance(ident.value, SolidityFunction):
        return CallExpression(ident, args, vars_to_typestr(ident.value.return_type))

    raise SlitherException(f"unexpected function call target type {str(type(ident.value))}")


def parse_yul_identifier(root, node, ast):
    """
    Resolve an identifier name and wrap what it refers to in an Identifier.

    Lookup order: builtins, function-scoped variables, yul-scoped variables,
    yul-scoped functions, then state variable names carrying a `_slot` or
    `_offset` suffix. Raises SlitherException when nothing matches.
    """
    name = ast['name']

    if name in builtins:
        return Identifier(YulBuiltin(name))

    # check function-scoped variables
    variable = root.function.get_local_variable_from_name(name)
    if variable:
        return Identifier(variable)

    # check yul-scoped variable
    variable = root.get_yul_local_variable_from_name(name)
    if variable:
        return Identifier(variable)

    # check yul-scoped function
    # note that a function can recurse into itself, so we have two canonical names
    # to check (but only one of them can be valid)

    functions = root.function.contract_declarer._functions

    canonical_name = root.format_canonical_yul_name(name)
    if canonical_name in functions:
        return Identifier(functions[canonical_name])

    canonical_name = root.format_canonical_yul_name(name, -1)
    if canonical_name in functions:
        return Identifier(functions[canonical_name])

    # check for magic suffixes: <state variable>_slot / <state variable>_offset
    for suffix in ("_slot", "_offset"):
        if name.endswith(suffix):
            potential_name = name[:-len(suffix)]
            var = root.function.contract.get_state_variable_from_name(potential_name)
            if var:
                # the full suffixed name is kept on the SolidityVariable
                return Identifier(SolidityVariable(name))

    raise SlitherException(f"unresolved reference to identifier {name}")


def parse_yul_literal(root, node, ast):
    """
    Parse a literal. When the AST gives no type, 'true' / 'false' are typed
    as bool and every other value as uint256.
    """
    type_ = ast['type']
    value = ast['value']

    if not type_:
        type_ = 'bool' if value in ['true', 'false'] else 'uint256'

    return Literal(value, ElementaryType(type_))


def parse_yul_typed_name(root, node, ast):
    """
    Parse a typed name into an Identifier of the matching yul-scoped variable,
    carrying over the variable's type.

    Raises SlitherException when no such variable exists.
    """
    var = root.get_yul_local_variable_from_name(ast['name'])
    if not var:
        # fail with an explicit error rather than an AttributeError on var.type
        raise SlitherException(f"unresolved reference to typed name {ast['name']}")

    i = Identifier(var)
    i._type = var.type
    return i


def parse_yul_unsupported(root, node, ast):
    """
    Fallback parser used when no parser is registered for a node type.

    Always raises SlitherException; the message carries the node type and a
    JSON dump of the offending AST node.
    """
    raise SlitherException(f"no parser available for {ast['nodeType']} {json.dumps(ast, indent=2)}")


def parse_yul(root, node, ast):
    """
    Entrypoint of the expression parsing: dispatch on the AST 'nodeType'
    (parse_yul_unsupported when there is no parser) and, when something was
    parsed, attach the source offset taken from ast['src'].

    Returns the parsed operation, or None.
    """
    op = parsers.get(ast['nodeType'], parse_yul_unsupported)(root, node, ast)
    if op:
        op.set_offset(ast["src"], root.slither)
    return op


# Lookup table used by parse_yul: AST 'nodeType' -> parser function
parsers = {
    'YulVariableDeclaration': parse_yul_variable_declaration,
    'YulAssignment': parse_yul_assignment,
    'YulFunctionCall': parse_yul_function_call,
    'YulIdentifier': parse_yul_identifier,
    'YulTypedName': parse_yul_typed_name,
    'YulLiteral': parse_yul_literal,
}
+# endregion +################################################################################### +################################################################################### + +def vars_to_typestr(rets): + if len(rets) == 0: + return "" + if len(rets) == 1: + return str(rets[0].type) + return "tuple({})".format(",".join(str(ret.type) for ret in rets)) + + +def vars_to_val(vars): + if len(vars) == 1: + return vars[0] + return TupleExpression(vars) diff --git a/slither/visitors/slithir/expression_to_slithir.py b/slither/visitors/slithir/expression_to_slithir.py index 69fb2a980..c215084b7 100644 --- a/slither/visitors/slithir/expression_to_slithir.py +++ b/slither/visitors/slithir/expression_to_slithir.py @@ -1,9 +1,9 @@ import logging -from slither.core.declarations import Function +from slither.core.declarations import Function, SolidityVariable, SolidityVariableComposed from slither.core.expressions import (AssignmentOperationType, UnaryOperationType, BinaryOperationType) -from slither.core.solidity_types import ArrayType +from slither.core.solidity_types import ArrayType, ElementaryType from slither.core.solidity_types.type import Type from slither.slithir.operations import (Assignment, Binary, BinaryType, Delete, Index, InitArray, InternalCall, Member, @@ -167,27 +167,63 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(internal_call) set_val(expression, val) else: - # If tuple - if expression.type_call.startswith('tuple(') and expression.type_call != 'tuple()': - val = TupleVariable(self._node) - else: + # yul things + if called.name == 'caller()': val = TemporaryVariable(self._node) - - message_call = TmpCall(called, len(args), val, expression.type_call) - message_call.set_expression(expression) - # Gas/value are only accessible here if the syntax {gas: , value: } - # Is used over .gas().value() - if expression.call_gas: - call_gas = get(expression.call_gas) - message_call.call_gas = call_gas - if expression.call_value: - call_value 
= get(expression.call_value) - message_call.call_value = call_value - if expression.call_salt: - call_salt = get(expression.call_salt) - message_call.call_salt = call_salt - self._result.append(message_call) - set_val(expression, val) + var = Assignment(val, SolidityVariableComposed('msg.sender'), 'uint256') + self._result.append(var) + set_val(expression, val) + elif called.name == 'origin()': + val = TemporaryVariable(self._node) + var = Assignment(val, SolidityVariableComposed('tx.origin'), 'uint256') + self._result.append(var) + set_val(expression, val) + elif called.name == 'extcodesize(uint256)': + val = ReferenceVariable(self._node) + var = Member(args[0], Constant('codesize'), val) + self._result.append(var) + set_val(expression, val) + elif called.name == 'selfbalance()': + val = TemporaryVariable(self._node) + var = TypeConversion(val, SolidityVariable('this'), ElementaryType('address')) + self._result.append(var) + + val1 = ReferenceVariable(self._node) + var1 = Member(val, Constant('balance'), val1) + self._result.append(var1) + set_val(expression, val1) + elif called.name == 'address()': + val = TemporaryVariable(self._node) + var = TypeConversion(val, SolidityVariable('this'), ElementaryType('address')) + self._result.append(var) + set_val(expression, val) + elif called.name == 'callvalue()': + val = TemporaryVariable(self._node) + var = Assignment(val, SolidityVariableComposed('msg.value'), 'uint256') + self._result.append(var) + set_val(expression, val) + else: + # If tuple + if expression.type_call.startswith('tuple(') and expression.type_call != 'tuple()': + val = TupleVariable(self._node) + else: + val = TemporaryVariable(self._node) + + message_call = TmpCall(called, len(args), val, expression.type_call) + message_call.set_expression(expression) + # Gas/value are only accessible here if the syntax {gas: , value: } + # Is used over .gas().value() + if expression.call_gas: + call_gas = get(expression.call_gas) + message_call.call_gas = call_gas + 
if expression.call_value: + call_value = get(expression.call_value) + message_call.call_value = call_value + if expression.call_salt: + call_salt = get(expression.call_salt) + message_call.call_salt = call_salt + self._result.append(message_call) + set_val(expression, val) def _post_conditional_expression(self, expression): raise Exception('Ternary operator are not convertible to SlithIR {}'.format(expression))