Minimal YUL handling

pull/502/head
Josselin 5 years ago committed by samczsun
parent 64ecdbec9a
commit aae3a0e642
  1. 5
      slither/core/declarations/function.py
  2. 9
      slither/core/declarations/solidity_variables.py
  3. 26
      slither/core/expressions/binary_operation.py
  4. 2
      slither/detectors/statements/assembly.py
  5. 2
      slither/printers/summary/slithir.py
  6. 7
      slither/slithir/convert.py
  7. 26
      slither/slithir/operations/binary.py
  8. 24
      slither/slithir/operations/codesize.py
  9. 5
      slither/slithir/utils/ssa.py
  10. 48
      slither/solc_parsing/cfg/node.py
  11. 30
      slither/solc_parsing/declarations/function.py
  12. 0
      slither/solc_parsing/yul/__init__.py
  13. 268
      slither/solc_parsing/yul/evm_functions.py
  14. 525
      slither/solc_parsing/yul/parse_yul.py
  15. 80
      slither/visitors/slithir/expression_to_slithir.py

@ -89,6 +89,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
def __init__(self):
super(Function, self).__init__()
self._scope: List[str] = []
self._name: Optional[str] = None
self._view: bool = False
self._pure: bool = False
@ -207,7 +208,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
Return the function signature without the return values
"""
name, parameters, _ = self.signature
return name + "(" + ",".join(parameters) + ")"
return ".".join(self._scope + [name]) + "(" + ",".join(parameters) + ")"
@property
def canonical_name(self) -> str:
@ -216,7 +217,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
Return the function signature without the return values
"""
name, parameters, _ = self.signature
return self.contract_declarer.name + "." + name + "(" + ",".join(parameters) + ")"
return ".".join([self.contract_declarer.name] + self._scope + [name]) + "(" + ",".join(parameters) + ")"
@property
def contains_assembly(self) -> bool:

@ -90,7 +90,14 @@ class SolidityVariable(Context):
# dev function, will be removed once the code is stable
def _check_name(self, name: str):
assert name in SOLIDITY_VARIABLES
assert name in SOLIDITY_VARIABLES or name.endswith("_slot") or name.endswith("_offset")
@property
def state_variable(self):
if self._name.endswith("_slot"):
return self._name[:-5]
if self._name.endswith("_offset"):
return self._name[:-7]
@property
def name(self) -> str:

@ -31,6 +31,12 @@ class BinaryOperationType(Enum):
ANDAND = 17 # &&
OROR = 18 # ||
DIVISION_SIGNED = 19
MODULO_SIGNED = 20
LESS_SIGNED = 21
GREATER_SIGNED = 22
RIGHT_SHIFT_ARITHMETIC = 23
@staticmethod
def get_type(operation_type: "BinaryOperation"):
if operation_type == "**":
@ -71,6 +77,16 @@ class BinaryOperationType(Enum):
return BinaryOperationType.ANDAND
if operation_type == "||":
return BinaryOperationType.OROR
if operation_type == "/'":
return BinaryOperationType.DIVISION_SIGNED
if operation_type == "%'":
return BinaryOperationType.MODULO_SIGNED
if operation_type == "<'":
return BinaryOperationType.LESS_SIGNED
if operation_type == ">'":
return BinaryOperationType.GREATER_SIGNED
if operation_type == ">>'":
return BinaryOperationType.RIGHT_SHIFT_ARITHMETIC
raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type))
@ -113,6 +129,16 @@ class BinaryOperationType(Enum):
return "&&"
if self == BinaryOperationType.OROR:
return "||"
if self == BinaryOperationType.DIVISION_SIGNED:
return "/'"
if self == BinaryOperationType.MODULO_SIGNED:
return "%'"
if self == BinaryOperationType.LESS_SIGNED:
return "<'"
if self == BinaryOperationType.GREATER_SIGNED:
return ">'"
if self == BinaryOperationType.RIGHT_SHIFT_ARITHMETIC:
return ">>'"
raise SlitherCoreError("str: Unknown operation type {})".format(self))

@ -30,7 +30,7 @@ class Assembly(AbstractDetector):
Returns:
(bool)
"""
return node.type == NodeType.ASSEMBLY
return node.type == NodeType.ASSEMBLY and len(node.yul_path) == 2
def detect_assembly(self, contract):
ret = []

@ -20,7 +20,7 @@ class PrinterSlithIR(AbstractPrinter):
txt = ""
for contract in self.contracts:
txt += 'Contract {}'.format(contract.name)
txt += 'Contract {}\n'.format(contract.name)
for function in contract.functions:
txt += f'\tFunction {function.canonical_name} {"" if function.is_shadowed else "(*)"}\n'
for node in function.nodes:

@ -10,6 +10,7 @@ from slither.core.solidity_types import (ArrayType, ElementaryType,
from slither.core.solidity_types.elementary_type import Int as ElementaryTypeInt
from slither.core.variables.variable import Variable
from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations.codesize import CodeSize
from slither.slithir.variables import TupleVariable
from slither.slithir.operations import (Assignment, Balance, Binary,
BinaryType, Call, Condition, Delete,
@ -512,6 +513,12 @@ def propagate_types(ir, node):
b.set_expression(ir.expression)
b.set_node(ir.node)
return b
if ir.variable_right == 'codesize' and not isinstance(ir.variable_left, Contract) and isinstance(
ir.variable_left.type, ElementaryType):
b = CodeSize(ir.variable_left, ir.lvalue)
b.set_expression(ir.expression)
b.set_node(ir.node)
return b
if ir.variable_right == 'selector' and isinstance(ir.variable_left.type, Function):
assignment = Assignment(ir.lvalue,
Constant(str(get_function_id(ir.variable_left.type.full_name))),

@ -31,6 +31,12 @@ class BinaryType(Enum):
ANDAND = 17 # &&
OROR = 18 # ||
DIVISION_SIGNED = 19
MODULO_SIGNED = 20
LESS_SIGNED = 21
GREATER_SIGNED = 22
RIGHT_SHIFT_ARITHMETIC = 23
@staticmethod
def return_bool(operation_type):
return operation_type in [BinaryType.OROR,
@ -82,6 +88,16 @@ class BinaryType(Enum):
return BinaryType.ANDAND
if operation_type == '||':
return BinaryType.OROR
if operation_type == "/'":
return BinaryType.DIVISION_SIGNED
if operation_type == "%'":
return BinaryType.MODULO_SIGNED
if operation_type == "<'":
return BinaryType.LESS_SIGNED
if operation_type == ">'":
return BinaryType.GREATER_SIGNED
if operation_type == ">>'":
return BinaryType.RIGHT_SHIFT_ARITHMETIC
raise SlithIRError('get_type: Unknown operation type {})'.format(operation_type))
@ -124,6 +140,16 @@ class BinaryType(Enum):
return "&&"
if self == BinaryType.OROR:
return "||"
if self == BinaryType.DIVISION_SIGNED:
return "/'"
if self == BinaryType.MODULO_SIGNED:
return "%'"
if self == BinaryType.LESS_SIGNED:
return "<'"
if self == BinaryType.GREATER_SIGNED:
return ">'"
if self == BinaryType.RIGHT_SHIFT_ARITHMETIC:
return ">>'"
raise SlithIRError("str: Unknown operation type {} {})".format(self, type(self)))

@ -0,0 +1,24 @@
from slither.core.solidity_types import ElementaryType
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue
class CodeSize(OperationWithLValue):
def __init__(self, value, lvalue):
assert is_valid_rvalue(value)
assert is_valid_lvalue(lvalue)
self._value = value
self._lvalue = lvalue
lvalue.set_type(ElementaryType('uint256'))
@property
def read(self):
return [self._value]
@property
def value(self):
return self._value
def __str__(self):
return "{} -> CODESIZE {}".format(self.lvalue, self.value)

@ -17,6 +17,7 @@ from slither.slithir.operations import (Assignment, Balance, Binary, Condition,
Push, Return, Send, SolidityCall,
Transfer, TypeConversion, Unary,
Unpack, Nop)
from slither.slithir.operations.codesize import CodeSize
from slither.slithir.variables import (Constant, LocalIRVariable,
ReferenceVariable, ReferenceVariableSSA,
StateIRVariable, TemporaryVariable,
@ -527,6 +528,10 @@ def copy_ir(ir, *instances):
variable_right = get_variable(ir, lambda x: x.variable_right, *instances)
operation_type = ir.type
return Binary(lvalue, variable_left, variable_right, operation_type)
elif isinstance(ir, CodeSize):
lvalue = get_variable(ir, lambda x: x.lvalue, *instances)
value = get_variable(ir, lambda x: x.value, *instances)
return CodeSize(value, lvalue)
elif isinstance(ir, Condition):
val = get_variable(ir, lambda x: x.value, *instances)
return Condition(val)

@ -9,6 +9,7 @@ from slither.core.expressions.assignment_operation import (
from slither.core.expressions.identifier import Identifier
from slither.solc_parsing.expressions.expression_parsing import parse_expression
from slither.visitors.expression.find_calls import FindCalls
from slither.solc_parsing.yul.parse_yul import parse_yul
from slither.visitors.expression.read_var import ReadVar
from slither.visitors.expression.write_var import WriteVar
@ -17,15 +18,57 @@ class NodeSolc:
def __init__(self, node: Node):
self._unparsed_expression: Optional[Dict] = None
self._node = node
self._unparsed_yul_expression = None
"""
todo this should really go somewhere else, but until
that happens I'm setting it to None for performance
"""
self._yul_local_variables = None
self._yul_local_functions = None
self._yul_path = None
@property
def underlying_node(self) -> Node:
return self._node
def set_yul_root(self, func):
self._yul_path = [func.name, f"asm_{func._counter_asm_nodes}"]
def set_yul_child(self, parent, cur):
self._yul_path = parent.yul_path + [cur]
@property
def yul_path(self):
return self._yul_path
def format_canonical_yul_name(self, name, off=None):
return ".".join(self._yul_path[:off] + [name])
def add_yul_local_variable(self, var):
if not self._yul_local_variables:
self._yul_local_variables = []
self._yul_local_variables.append(var)
def get_yul_local_variable_from_name(self, variable_name):
return next((v for v in self._yul_local_variables if v.name == variable_name), None)
def add_yul_local_function(self, func):
if not self._yul_local_functions:
self._yul_local_functions = []
self._yul_local_functions.append(func)
def get_yul_local_function_from_name(self, func_name):
return next((v for v in self._yul_local_functions if v.name == func_name), None)
def add_unparsed_expression(self, expression: Dict):
assert self._unparsed_expression is None
self._unparsed_expression = expression
def add_unparsed_yul_expression(self, root, expression):
assert self._unparsed_expression is None
self._unparsed_yul_expression = (root, expression)
def analyze_expressions(self, caller_context):
if self._node.type == NodeType.VARIABLE and not self._node.expression:
self._node.add_expression(self._node.variable_declaration.expression)
@ -34,6 +77,11 @@ class NodeSolc:
self._node.add_expression(expression)
# self._unparsed_expression = None
if self._unparsed_yul_expression:
expression = parse_yul(self._unparsed_yul_expression[0], self, self._unparsed_yul_expression[1])
self._expression = expression
self._unparsed_yul_expression = None
if self._node.expression:
if self._node.type == NodeType.VARIABLE:

@ -18,6 +18,7 @@ from slither.solc_parsing.variables.local_variable_init_from_tuple import (
LocalVariableInitFromTupleSolc,
)
from slither.solc_parsing.variables.variable_declaration import MultipleVariablesDeclaration
from slither.solc_parsing.yul.parse_yul import convert_yul
from slither.utils.expression_manipulations import SplitTernaryExpression
from slither.visitors.expression.export_values import ExportValues
from slither.visitors.expression.has_conditional import HasConditional
@ -62,6 +63,8 @@ class FunctionSolc:
self._functionNotParsed = function_data
self._params_was_analyzed = False
self._content_was_analyzed = False
self._counter_nodes = 0
self._counter_asm_nodes = 0
self._counter_scope_local_variables = 0
# variable renamed will map the solc id
@ -312,6 +315,9 @@ class FunctionSolc:
self._node_to_nodesolc[node] = node_parser
return node_parser
def node_solc(self):
return NodeSolc
# endregion
###################################################################################
###################################################################################
@ -797,13 +803,23 @@ class FunctionSolc:
elif name == "Block":
node = self._parse_block(statement, node)
elif name == "InlineAssembly":
asm_node = self._new_node(NodeType.ASSEMBLY, statement["src"])
self._function.contains_assembly = True
# Added with solc 0.4.12
if "operations" in statement:
asm_node.underlying_node.add_inline_asm(statement["operations"])
link_underlying_nodes(node, asm_node)
node = asm_node
# Added with solc 0.6 - the yul code is an AST
if 'AST' in statement:
self._contains_assembly = True
yul_root = self._new_node(NodeType.ASSEMBLY, statement['src'])
yul_root.set_yul_root(self)
link_underlying_nodes(node, yul_root)
self._counter_asm_nodes += 1
node = convert_yul(yul_root, yul_root, statement['AST'])
else:
asm_node = self._new_node(NodeType.ASSEMBLY, statement['src'])
self._function._contains_assembly = True
# Added with solc 0.4.12
if 'operations' in statement:
asm_node.underlying_node.add_inline_asm(statement['operations'])
link_underlying_nodes(node, asm_node)
node = asm_node
elif name == "DoWhileStatement":
node = self._parse_dowhile(statement, node)
# For Continue / Break / Return / Throw

@ -0,0 +1,268 @@
from slither.core.declarations.solidity_variables import SOLIDITY_FUNCTIONS
from slither.core.expressions import BinaryOperationType, UnaryOperationType
# taken from https://github.com/ethereum/solidity/blob/356cc91084114f840da66804b2a9fc1ac2846cff/libevmasm/Instruction.cpp#L180
evm_opcodes = [
"STOP",
"ADD",
"SUB",
"MUL",
"DIV",
"SDIV",
"MOD",
"SMOD",
"EXP",
"NOT",
"LT",
"GT",
"SLT",
"SGT",
"EQ",
"ISZERO",
"AND",
"OR",
"XOR",
"BYTE",
"SHL",
"SHR",
"SAR",
"ADDMOD",
"MULMOD",
"SIGNEXTEND",
"KECCAK256",
"ADDRESS",
"BALANCE",
"ORIGIN",
"CALLER",
"CALLVALUE",
"CALLDATALOAD",
"CALLDATASIZE",
"CALLDATACOPY",
"CODESIZE",
"CODECOPY",
"GASPRICE",
"EXTCODESIZE",
"EXTCODECOPY",
"RETURNDATASIZE",
"RETURNDATACOPY",
"EXTCODEHASH",
"BLOCKHASH",
"COINBASE",
"TIMESTAMP",
"NUMBER",
"DIFFICULTY",
"GASLIMIT",
"CHAINID",
"SELFBALANCE",
"POP",
"MLOAD",
"MSTORE",
"MSTORE8",
"SLOAD",
"SSTORE",
"JUMP",
"JUMPI",
"PC",
"MSIZE",
"GAS",
"JUMPDEST",
"PUSH1",
"PUSH2",
"PUSH3",
"PUSH4",
"PUSH5",
"PUSH6",
"PUSH7",
"PUSH8",
"PUSH9",
"PUSH10",
"PUSH11",
"PUSH12",
"PUSH13",
"PUSH14",
"PUSH15",
"PUSH16",
"PUSH17",
"PUSH18",
"PUSH19",
"PUSH20",
"PUSH21",
"PUSH22",
"PUSH23",
"PUSH24",
"PUSH25",
"PUSH26",
"PUSH27",
"PUSH28",
"PUSH29",
"PUSH30",
"PUSH31",
"PUSH32",
"DUP1",
"DUP2",
"DUP3",
"DUP4",
"DUP5",
"DUP6",
"DUP7",
"DUP8",
"DUP9",
"DUP10",
"DUP11",
"DUP12",
"DUP13",
"DUP14",
"DUP15",
"DUP16",
"SWAP1",
"SWAP2",
"SWAP3",
"SWAP4",
"SWAP5",
"SWAP6",
"SWAP7",
"SWAP8",
"SWAP9",
"SWAP10",
"SWAP11",
"SWAP12",
"SWAP13",
"SWAP14",
"SWAP15",
"SWAP16",
"LOG0",
"LOG1",
"LOG2",
"LOG3",
"LOG4",
"CREATE",
"CALL",
"CALLCODE",
"STATICCALL",
"RETURN",
"DELEGATECALL",
"CREATE2",
"REVERT",
"INVALID",
"SELFDESTRUCT",
]
yul_funcs = [
"datasize",
"dataoffset",
"datacopy",
"setimmutable",
"loadimmutable",
]
builtins = [x.lower() for x in evm_opcodes if not (
x.startswith("PUSH") or
x.startswith("SWAP") or
x.startswith("DUP") or
x == "JUMP" or
x == "JUMPI" or
x == "JUMPDEST"
)] + yul_funcs
function_args = {
'byte': [2, 1],
'addmod': [3, 1],
'mulmod': [3, 1],
'signextend': [2, 1],
'keccak256': [2, 1],
'pc': [0, 1],
'pop': [1, 0],
'mload': [1, 1],
'mstore': [2, 0],
'mstore8': [2, 0],
'sload': [1, 1],
'sstore': [2, 0],
'msize': [1, 1],
'gas': [0, 1],
'address': [0, 1],
'balance': [1, 1],
'selfbalance': [0, 1],
'caller': [0, 1],
'callvalue': [0, 1],
'calldataload': [1, 1],
'calldatasize': [0, 1],
'calldatacopy': [3, 0],
'codesize': [0, 1],
'codecopy': [3, 0],
'extcodesize': [1, 1],
'extcodecopy': [4, 0],
'returndatasize': [0, 1],
'returndatacopy': [3, 0],
'extcodehash': [1, 1],
'create': [3, 1],
'create2': [4, 1],
'call': [7, 1],
'callcode': [7, 1],
'delegatecall': [6, 1],
'staticcall': [6, 1],
'return': [2, 0],
'revert': [2, 0],
'selfdestruct': [1, 0],
'invalid': [0, 0],
'log0': [2, 0],
'log1': [3, 0],
'log2': [4, 0],
'log3': [5, 0],
'log4': [6, 0],
'chainid': [0, 1],
'origin': [0, 1],
'gasprice': [0, 1],
'blockhash': [1, 1],
'coinbase': [0, 1],
'timestamp': [0, 1],
'number': [0, 1],
'difficulty': [0, 1],
'gaslimit': [0, 1],
}
def format_function_descriptor(name):
if name not in function_args:
return name + "()"
return name + "(" + ",".join(["uint256"] * function_args[name][0]) + ")"
for k, v in function_args.items():
SOLIDITY_FUNCTIONS[format_function_descriptor(k)] = ['uint256'] * v[1]
unary_ops = {
'not': UnaryOperationType.TILD,
'iszero': UnaryOperationType.BANG,
}
binary_ops = {
'add': BinaryOperationType.ADDITION,
'sub': BinaryOperationType.SUBTRACTION,
'mul': BinaryOperationType.MULTIPLICATION,
'div': BinaryOperationType.DIVISION,
'sdiv': BinaryOperationType.DIVISION_SIGNED,
'mod': BinaryOperationType.MODULO,
'smod': BinaryOperationType.MODULO_SIGNED,
'exp': BinaryOperationType.POWER,
'lt': BinaryOperationType.LESS,
'gt': BinaryOperationType.GREATER,
'slt': BinaryOperationType.LESS_SIGNED,
'sgt': BinaryOperationType.GREATER_SIGNED,
'eq': BinaryOperationType.EQUAL,
'and': BinaryOperationType.AND,
'or': BinaryOperationType.OR,
'xor': BinaryOperationType.CARET,
'shl': BinaryOperationType.LEFT_SHIFT,
'shr': BinaryOperationType.RIGHT_SHIFT,
'sar': BinaryOperationType.RIGHT_SHIFT_ARITHMETIC,
}
class YulBuiltin:
def __init__(self, name):
self._name = name
@property
def name(self):
return self._name

@ -0,0 +1,525 @@
import json
from slither.core.cfg.node import NodeType, link_nodes
from slither.core.declarations import Function, SolidityFunction, SolidityVariable
from slither.core.expressions import (
Literal,
AssignmentOperation,
AssignmentOperationType,
Identifier, CallExpression, TupleExpression, BinaryOperation, UnaryOperation,
)
from slither.core.solidity_types import ElementaryType
from slither.core.variables.local_variable import LocalVariable
from slither.exceptions import SlitherException
from slither.solc_parsing.yul.evm_functions import *
class YulLocalVariable(LocalVariable):
def __init__(self, ast):
super(LocalVariable, self).__init__()
assert (ast['nodeType'] == 'YulTypedName')
self._name = ast['name']
self._type = ElementaryType('uint256')
self._location = 'memory'
class YulFunction(Function):
def __init__(self, ast, root):
super(YulFunction, self).__init__()
assert (ast['nodeType'] == 'YulFunctionDefinition')
self._contract = root.function.contract
self._contract_declarer = root.function.contract_declarer
self._name = ast['name']
self._scope = root.yul_path
self._counter_nodes = 0
self._is_implemented = True
self._contains_assembly = True
self._node_solc = root.function.node_solc()
self._entry_point = self.new_node(NodeType.ASSEMBLY, ast['src'])
self._entry_point.set_yul_child(root, ast['name'])
self._ast = ast
self.set_offset(ast['src'], root.function.slither)
def convert_body(self):
node = self.new_node(NodeType.ENTRYPOINT, self._ast['src'])
link_nodes(self.entry_point, node)
for param in self._ast.get('parameters', []):
node = convert_yul(self.entry_point, node, param)
self._parameters.append(self.entry_point.get_yul_local_variable_from_name(param['name']))
for ret in self._ast.get('returnVariables', []):
node = convert_yul(self.entry_point, node, ret)
self._returns.append(self.entry_point.get_yul_local_variable_from_name(ret['name']))
convert_yul(self.entry_point, node, self._ast['body'])
def parse_body(self):
for node in self.nodes:
node.analyze_expressions(self)
def node_solc(self):
return self._node_solc
def new_node(self, node_type, src):
node = self._node_solc(node_type, self._counter_nodes)
node.set_offset(src, self.slither)
node.set_function(self)
self._counter_nodes += 1
self._nodes.append(node)
return node
###################################################################################
###################################################################################
# region Block conversion
###################################################################################
###################################################################################
"""
The functions in this region, at a high level, will extract the control flow
structures and metadata from the input AST. These include things like function
definitions and local variables.
Each function takes three parameters:
1) root is a NodeSolc of NodeType.ASSEMBLY, and stores information at the
local scope. In Yul, variables are scoped to the function they're
declared in (except for variables outside the assembly block)
2) parent is the last node in the CFG. new nodes should be linked against
this node
3) ast is a dictionary and is the current node in the Yul ast being converted
Each function must return a single parameter:
1) A NodeSolc representing the new end of the CFG
The entrypoint is the function at the end of this region, `convert_yul`, which
dispatches to a specialized function based on a lookup dictionary.
"""
def convert_yul_block(root, parent, ast):
for statement in ast["statements"]:
parent = convert_yul(root, parent, statement)
return parent
def convert_yul_function_definition(root, parent, ast):
f = YulFunction(ast, root)
root.function.contract._functions[root.format_canonical_yul_name(f.name)] = f
f.convert_body()
f.parse_body()
return parent
def convert_yul_variable_declaration(root, parent, ast):
for variable_ast in ast['variables']:
parent = convert_yul(root, parent, variable_ast)
node = parent.function.new_node(NodeType.EXPRESSION, ast["src"])
node.add_unparsed_yul_expression(root, ast)
link_nodes(parent, node)
return node
def convert_yul_assignment(root, parent, ast):
node = parent.function.new_node(NodeType.EXPRESSION, ast["src"])
node.add_unparsed_yul_expression(root, ast)
link_nodes(parent, node)
return node
def convert_yul_expression_statement(root, parent, ast):
src = ast['src']
expression_ast = ast['expression']
expression = parent.function.new_node(NodeType.EXPRESSION, src)
expression.add_unparsed_yul_expression(root, expression_ast)
link_nodes(parent, expression)
return expression
def convert_yul_if(root, parent, ast):
# we're cheating and pretending that yul supports if/else so we can convert switch cleaner
src = ast['src']
condition_ast = ast['condition']
true_body_ast = ast['body']
false_body_ast = ast['false_body'] if 'false_body' in ast else None
condition = parent.function.new_node(NodeType.IF, src)
end = parent.function.new_node(NodeType.ENDIF, src)
condition.add_unparsed_yul_expression(root, condition_ast)
true_body = convert_yul(root, condition, true_body_ast)
if false_body_ast:
false_body = convert_yul(root, condition, false_body_ast)
link_nodes(false_body, end)
else:
link_nodes(condition, end)
link_nodes(parent, condition)
link_nodes(true_body, end)
return end
def convert_yul_switch(root, parent, ast):
"""
This is unfortunate. We don't really want a switch in our IR so we're going to
translate it into a series of if/else statements.
"""
cases_ast = ast['cases']
expression_ast = ast['expression']
# this variable stores the result of the expression so we don't accidentally compute it more than once
switch_expr_var = 'switch_expr_{}'.format(ast['src'].replace(':', '_'))
rewritten_switch = {
'nodeType': 'YulBlock',
'src': ast['src'],
'statements': [
{
'nodeType': 'YulVariableDeclaration',
'src': expression_ast['src'],
'variables': [
{
'nodeType': 'YulTypedName',
'src': expression_ast['src'],
'name': switch_expr_var,
'type': '',
},
],
'value': expression_ast,
},
],
}
last_if = None
default_ast = None
for case_ast in cases_ast:
body_ast = case_ast['body']
value_ast = case_ast['value']
if value_ast == 'default':
default_ast = case_ast
continue
current_if = {
'nodeType': 'YulIf',
'src': case_ast['src'],
'condition': {
'nodeType': 'YulFunctionCall',
'src': case_ast['src'],
'functionName': {
'nodeType': 'YulIdentifier',
'src': case_ast['src'],
'name': 'eq',
},
'arguments': [
{
'nodeType': 'YulIdentifier',
'src': case_ast['src'],
'name': switch_expr_var,
},
value_ast,
]
},
'body': body_ast,
}
if last_if:
last_if['false_body'] = current_if
else:
rewritten_switch['statements'].append(current_if)
last_if = current_if
if default_ast:
body_ast = default_ast['body']
if last_if:
last_if['false_body'] = body_ast
else:
rewritten_switch['statements'].append(body_ast)
return convert_yul(root, parent, rewritten_switch)
def convert_yul_for_loop(root, parent, ast):
pre_ast = ast['pre']
condition_ast = ast['condition']
post_ast = ast['post']
body_ast = ast['body']
start_loop = parent.function.new_node(NodeType.STARTLOOP, ast['src'])
end_loop = parent.function.new_node(NodeType.ENDLOOP, ast['src'])
link_nodes(parent, start_loop)
pre = convert_yul(root, start_loop, pre_ast)
condition = parent.function.new_node(NodeType.IFLOOP, condition_ast['src'])
condition.add_unparsed_yul_expression(root, condition_ast)
link_nodes(pre, condition)
link_nodes(condition, end_loop)
body = convert_yul(root, condition, body_ast)
post = convert_yul(root, body, post_ast)
link_nodes(post, condition)
return end_loop
def convert_yul_break(root, parent, ast):
break_ = parent.function.new_node(NodeType.BREAK, ast['src'])
link_nodes(parent, break_)
return break_
def convert_yul_continue(root, parent, ast):
continue_ = parent.function.new_node(NodeType.CONTINUE, ast['src'])
link_nodes(parent, continue_)
return continue_
def convert_yul_leave(root, parent, ast):
leave = parent.function.new_node(NodeType.RETURN, ast['src'])
link_nodes(parent, leave)
return leave
def convert_yul_typed_name(root, parent, ast):
var = YulLocalVariable(ast)
var.set_function(root.function)
var.set_offset(ast['src'], root.slither)
root.add_yul_local_variable(var)
node = parent.function.new_node(NodeType.VARIABLE, ast['src'])
node.add_variable_declaration(var)
link_nodes(parent, node)
return node
def convert_yul_unsupported(root, parent, ast):
raise SlitherException(f"no converter available for {ast['nodeType']} {json.dumps(ast, indent=2)}")
def convert_yul(root, parent, ast):
return converters.get(ast['nodeType'], convert_yul_unsupported)(root, parent, ast)
converters = {
'YulBlock': convert_yul_block,
'YulFunctionDefinition': convert_yul_function_definition,
'YulVariableDeclaration': convert_yul_variable_declaration,
'YulAssignment': convert_yul_assignment,
'YulExpressionStatement': convert_yul_expression_statement,
'YulIf': convert_yul_if,
'YulSwitch': convert_yul_switch,
'YulForLoop': convert_yul_for_loop,
'YulBreak': convert_yul_break,
'YulContinue': convert_yul_continue,
'YulLeave': convert_yul_leave,
'YulTypedName': convert_yul_typed_name,
}
# endregion
###################################################################################
###################################################################################
###################################################################################
###################################################################################
# region Expression parsing
###################################################################################
###################################################################################
"""
The functions in this region parse the AST into expressions.
Each function takes three parameters:
1) root is the same root as above
2) node is the CFG node which stores this expression
3) ast is the same ast as above
Each function must return a single parameter:
1) The operation that was parsed, or None
The entrypoint is the function at the end of this region, `parse_yul`, which
dispatches to a specialized function based on a lookup dictionary.
"""
def _parse_yul_assignment_common(root, node, ast, key):
lhs = [parse_yul(root, node, arg) for arg in ast[key]]
rhs = parse_yul(root, node, ast['value'])
return AssignmentOperation(vars_to_val(lhs), rhs, AssignmentOperationType.ASSIGN, vars_to_typestr(lhs))
def parse_yul_variable_declaration(root, node, ast):
"""
We already created variables in the conversion phase, so just do
the assignment
"""
if not ast['value']:
return None
return _parse_yul_assignment_common(root, node, ast, 'variables')
def parse_yul_assignment(root, node, ast):
return _parse_yul_assignment_common(root, node, ast, 'variableNames')
def parse_yul_function_call(root, node, ast):
args = [parse_yul(root, node, arg) for arg in ast['arguments']]
ident = parse_yul(root, node, ast['functionName'])
if isinstance(ident.value, YulBuiltin):
name = ident.value.name
if name in binary_ops:
if name in ['shl', 'shr', 'sar']:
# lmao ok
return BinaryOperation(args[1], args[0], binary_ops[name])
return BinaryOperation(args[0], args[1], binary_ops[name])
if name in unary_ops:
return UnaryOperation(args[0], unary_ops[name])
ident = Identifier(SolidityFunction(format_function_descriptor(ident.value.name)))
if isinstance(ident.value, Function):
return CallExpression(ident, args, vars_to_typestr(ident.value.returns))
elif isinstance(ident.value, SolidityFunction):
return CallExpression(ident, args, vars_to_typestr(ident.value.return_type))
else:
raise SlitherException(f"unexpected function call target type {str(type(ident.value))}")
def parse_yul_identifier(root, node, ast):
name = ast['name']
if name in builtins:
return Identifier(YulBuiltin(name))
# check function-scoped variables
variable = root.function.get_local_variable_from_name(name)
if variable:
return Identifier(variable)
# check yul-scoped variable
variable = root.get_yul_local_variable_from_name(name)
if variable:
return Identifier(variable)
# check yul-scoped function
# note that a function can recurse into itself, so we have two canonical names
# to check (but only one of them can be valid)
functions = root.function.contract_declarer._functions
canonical_name = root.format_canonical_yul_name(name)
if canonical_name in functions:
return Identifier(functions[canonical_name])
canonical_name = root.format_canonical_yul_name(name, -1)
if canonical_name in functions:
return Identifier(functions[canonical_name])
# check for magic suffixes
if name.endswith("_slot"):
potential_name = name[:-5]
var = root.function.contract.get_state_variable_from_name(potential_name)
if var:
return Identifier(SolidityVariable(name))
if name.endswith("_offset"):
potential_name = name[:-7]
var = root.function.contract.get_state_variable_from_name(potential_name)
if var:
return Identifier(SolidityVariable(name))
raise SlitherException(f"unresolved reference to identifier {name}")
def parse_yul_literal(root, node, ast):
type_ = ast['type']
value = ast['value']
if not type_:
type_ = 'bool' if value in ['true', 'false'] else 'uint256'
return Literal(value, ElementaryType(type_))
def parse_yul_typed_name(root, node, ast):
var = root.get_yul_local_variable_from_name(ast['name'])
i = Identifier(var)
i._type = var.type
return i
def parse_yul_unsupported(root, node, ast):
raise SlitherException(f"no parser available for {ast['nodeType']} {json.dumps(ast, indent=2)}")
def parse_yul(root, node, ast):
op = parsers.get(ast['nodeType'], parse_yul_unsupported)(root, node, ast)
if op:
op.set_offset(ast["src"], root.slither)
return op
parsers = {
'YulVariableDeclaration': parse_yul_variable_declaration,
'YulAssignment': parse_yul_assignment,
'YulFunctionCall': parse_yul_function_call,
'YulIdentifier': parse_yul_identifier,
'YulTypedName': parse_yul_typed_name,
'YulLiteral': parse_yul_literal,
}
# endregion
###################################################################################
###################################################################################
def vars_to_typestr(rets):
if len(rets) == 0:
return ""
if len(rets) == 1:
return str(rets[0].type)
return "tuple({})".format(",".join(str(ret.type) for ret in rets))
def vars_to_val(vars):
if len(vars) == 1:
return vars[0]
return TupleExpression(vars)

@ -1,9 +1,9 @@
import logging
from slither.core.declarations import Function
from slither.core.declarations import Function, SolidityVariable, SolidityVariableComposed
from slither.core.expressions import (AssignmentOperationType,
UnaryOperationType, BinaryOperationType)
from slither.core.solidity_types import ArrayType
from slither.core.solidity_types import ArrayType, ElementaryType
from slither.core.solidity_types.type import Type
from slither.slithir.operations import (Assignment, Binary, BinaryType, Delete,
Index, InitArray, InternalCall, Member,
@ -167,27 +167,63 @@ class ExpressionToSlithIR(ExpressionVisitor):
self._result.append(internal_call)
set_val(expression, val)
else:
# If tuple
if expression.type_call.startswith('tuple(') and expression.type_call != 'tuple()':
val = TupleVariable(self._node)
else:
# yul things
if called.name == 'caller()':
val = TemporaryVariable(self._node)
message_call = TmpCall(called, len(args), val, expression.type_call)
message_call.set_expression(expression)
# Gas/value are only accessible here if the syntax {gas: , value: }
# Is used over .gas().value()
if expression.call_gas:
call_gas = get(expression.call_gas)
message_call.call_gas = call_gas
if expression.call_value:
call_value = get(expression.call_value)
message_call.call_value = call_value
if expression.call_salt:
call_salt = get(expression.call_salt)
message_call.call_salt = call_salt
self._result.append(message_call)
set_val(expression, val)
var = Assignment(val, SolidityVariableComposed('msg.sender'), 'uint256')
self._result.append(var)
set_val(expression, val)
elif called.name == 'origin()':
val = TemporaryVariable(self._node)
var = Assignment(val, SolidityVariableComposed('tx.origin'), 'uint256')
self._result.append(var)
set_val(expression, val)
elif called.name == 'extcodesize(uint256)':
val = ReferenceVariable(self._node)
var = Member(args[0], Constant('codesize'), val)
self._result.append(var)
set_val(expression, val)
elif called.name == 'selfbalance()':
val = TemporaryVariable(self._node)
var = TypeConversion(val, SolidityVariable('this'), ElementaryType('address'))
self._result.append(var)
val1 = ReferenceVariable(self._node)
var1 = Member(val, Constant('balance'), val1)
self._result.append(var1)
set_val(expression, val1)
elif called.name == 'address()':
val = TemporaryVariable(self._node)
var = TypeConversion(val, SolidityVariable('this'), ElementaryType('address'))
self._result.append(var)
set_val(expression, val)
elif called.name == 'callvalue()':
val = TemporaryVariable(self._node)
var = Assignment(val, SolidityVariableComposed('msg.value'), 'uint256')
self._result.append(var)
set_val(expression, val)
else:
# If tuple
if expression.type_call.startswith('tuple(') and expression.type_call != 'tuple()':
val = TupleVariable(self._node)
else:
val = TemporaryVariable(self._node)
message_call = TmpCall(called, len(args), val, expression.type_call)
message_call.set_expression(expression)
# Gas/value are only accessible here if the syntax {gas: , value: }
# Is used over .gas().value()
if expression.call_gas:
call_gas = get(expression.call_gas)
message_call.call_gas = call_gas
if expression.call_value:
call_value = get(expression.call_value)
message_call.call_value = call_value
if expression.call_salt:
call_salt = get(expression.call_salt)
message_call.call_salt = call_salt
self._result.append(message_call)
set_val(expression, val)
def _post_conditional_expression(self, expression):
raise Exception('Ternary operator are not convertible to SlithIR {}'.format(expression))

Loading…
Cancel
Save