Merge pull request #1559 from crytic/dev-fix-yul-parsing

Improve yul parsing
pull/1573/head
Feist Josselin 2 years ago committed by GitHub
commit aee2a786e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      slither/core/children/child_function.py
  2. 4
      slither/core/expressions/identifier.py
  3. 4
      slither/core/variables/local_variable.py
  4. 1
      slither/solc_parsing/declarations/function.py
  5. 4
      slither/solc_parsing/yul/evm_functions.py
  6. 86
      slither/solc_parsing/yul/parse_yul.py
  7. BIN
      tests/ast-parsing/compile/yul-top-level-0.8.0.sol-0.8.0-compact.zip
  8. 16
      tests/ast-parsing/yul-top-level-0.8.0.sol
  9. 1
      tests/test_ast_parsing.py

@ -5,11 +5,11 @@ if TYPE_CHECKING:
class ChildFunction: class ChildFunction:
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self._function = None self._function = None
def set_function(self, function: "Function"): def set_function(self, function: "Function") -> None:
self._function = function self._function = function
@property @property

@ -7,7 +7,7 @@ if TYPE_CHECKING:
class Identifier(ExpressionTyped): class Identifier(ExpressionTyped):
def __init__(self, value): def __init__(self, value) -> None:
super().__init__() super().__init__()
self._value: "Variable" = value self._value: "Variable" = value
@ -15,5 +15,5 @@ class Identifier(ExpressionTyped):
def value(self) -> "Variable": def value(self) -> "Variable":
return self._value return self._value
def __str__(self): def __str__(self) -> str:
return str(self._value) return str(self._value)

@ -11,11 +11,11 @@ from slither.core.declarations.structure import Structure
class LocalVariable(ChildFunction, Variable): class LocalVariable(ChildFunction, Variable):
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self._location: Optional[str] = None self._location: Optional[str] = None
def set_location(self, loc: str): def set_location(self, loc: str) -> None:
self._location = loc self._location = loc
@property @property

@ -343,7 +343,6 @@ class FunctionSolc(CallerContextExpression):
node, node,
[self._function.name, f"asm_{len(self._node_to_yulobject)}"], [self._function.name, f"asm_{len(self._node_to_yulobject)}"],
scope, scope,
parent_func=self._function,
) )
self._node_to_yulobject[node] = yul_object self._node_to_yulobject[node] = yul_object
return yul_object return yul_object

@ -264,9 +264,9 @@ binary_ops = {
class YulBuiltin: # pylint: disable=too-few-public-methods class YulBuiltin: # pylint: disable=too-few-public-methods
def __init__(self, name): def __init__(self, name: str) -> None:
self._name = name self._name = name
@property @property
def name(self): def name(self) -> str:
return self._name return self._name

@ -24,6 +24,7 @@ from slither.core.expressions import (
UnaryOperation, UnaryOperation,
) )
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.scope.scope import FileScope
from slither.core.solidity_types import ElementaryType from slither.core.solidity_types import ElementaryType
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
@ -51,30 +52,35 @@ class YulNode:
def underlying_node(self) -> Node: def underlying_node(self) -> Node:
return self._node return self._node
def add_unparsed_expression(self, expression: Dict): def add_unparsed_expression(self, expression: Dict) -> None:
assert self._unparsed_expression is None assert self._unparsed_expression is None
self._unparsed_expression = expression self._unparsed_expression = expression
def analyze_expressions(self): def analyze_expressions(self) -> None:
if self._node.type == NodeType.VARIABLE and not self._node.expression: if self._node.type == NodeType.VARIABLE and not self._node.expression:
self._node.add_expression(self._node.variable_declaration.expression) expression = self._node.variable_declaration.expression
if expression:
self._node.add_expression(expression)
if self._unparsed_expression: if self._unparsed_expression:
expression = parse_yul(self._scope, self, self._unparsed_expression) expression = parse_yul(self._scope, self, self._unparsed_expression)
self._node.add_expression(expression) if expression:
self._node.add_expression(expression)
if self._node.expression: if self._node.expression:
if self._node.type == NodeType.VARIABLE: if self._node.type == NodeType.VARIABLE:
# Update the expression to be an assignement to the variable # Update the expression to be an assignement to the variable
_expression = AssignmentOperation( variable_declaration = self._node.variable_declaration
Identifier(self._node.variable_declaration), if variable_declaration:
self._node.expression, _expression = AssignmentOperation(
AssignmentOperationType.ASSIGN, Identifier(self._node.variable_declaration),
self._node.variable_declaration.type, self._node.expression,
) AssignmentOperationType.ASSIGN,
_expression.set_offset( variable_declaration.type,
self._node.expression.source_mapping, self._node.compilation_unit )
) _expression.set_offset(
self._node.add_expression(_expression, bypass_verif_empty=True) self._node.expression.source_mapping, self._node.compilation_unit
)
self._node.add_expression(_expression, bypass_verif_empty=True)
expression = self._node.expression expression = self._node.expression
read_var = ReadVar(expression) read_var = ReadVar(expression)
@ -122,13 +128,13 @@ class YulScope(metaclass=abc.ABCMeta):
] ]
def __init__( def __init__(
self, contract: Optional[Contract], yul_id: List[str], parent_func: Function = None self, contract: Optional[Contract], yul_id: List[str], parent_func: Function
): ) -> None:
self._contract = contract self._contract = contract
self._id: List[str] = yul_id self._id: List[str] = yul_id
self._yul_local_variables: List[YulLocalVariable] = [] self._yul_local_variables: List[YulLocalVariable] = []
self._yul_local_functions: List[YulFunction] = [] self._yul_local_functions: List[YulFunction] = []
self._parent_func = parent_func self._parent_func: Function = parent_func
@property @property
def id(self) -> List[str]: def id(self) -> List[str]:
@ -155,10 +161,14 @@ class YulScope(metaclass=abc.ABCMeta):
def new_node(self, node_type: NodeType, src: Union[str, Dict]) -> YulNode: def new_node(self, node_type: NodeType, src: Union[str, Dict]) -> YulNode:
pass pass
def add_yul_local_variable(self, var): @property
def file_scope(self) -> FileScope:
return self._parent_func.file_scope
def add_yul_local_variable(self, var: "YulLocalVariable") -> None:
self._yul_local_variables.append(var) self._yul_local_variables.append(var)
def get_yul_local_variable_from_name(self, variable_name): def get_yul_local_variable_from_name(self, variable_name: str) -> Optional["YulLocalVariable"]:
return next( return next(
( (
v v
@ -168,10 +178,10 @@ class YulScope(metaclass=abc.ABCMeta):
None, None,
) )
def add_yul_local_function(self, func): def add_yul_local_function(self, func: "YulFunction") -> None:
self._yul_local_functions.append(func) self._yul_local_functions.append(func)
def get_yul_local_function_from_name(self, func_name): def get_yul_local_function_from_name(self, func_name: str) -> Optional["YulLocalVariable"]:
return next( return next(
(v for v in self._yul_local_functions if v.underlying.name == func_name), (v for v in self._yul_local_functions if v.underlying.name == func_name),
None, None,
@ -242,7 +252,7 @@ class YulFunction(YulScope):
def function(self) -> Function: def function(self) -> Function:
return self._function return self._function
def convert_body(self): def convert_body(self) -> None:
node = self.new_node(NodeType.ENTRYPOINT, self._ast["src"]) node = self.new_node(NodeType.ENTRYPOINT, self._ast["src"])
link_underlying_nodes(self._entrypoint, node) link_underlying_nodes(self._entrypoint, node)
@ -258,7 +268,7 @@ class YulFunction(YulScope):
convert_yul(self, node, self._ast["body"], self.node_scope) convert_yul(self, node, self._ast["body"], self.node_scope)
def parse_body(self): def parse_body(self) -> None:
for node in self._nodes: for node in self._nodes:
node.analyze_expressions() node.analyze_expressions()
@ -289,9 +299,8 @@ class YulBlock(YulScope):
entrypoint: Node, entrypoint: Node,
yul_id: List[str], yul_id: List[str],
node_scope: Union[Scope, Function], node_scope: Union[Scope, Function],
**kwargs,
): ):
super().__init__(contract, yul_id, **kwargs) super().__init__(contract, yul_id, entrypoint.function)
self._entrypoint: YulNode = YulNode(entrypoint, self) self._entrypoint: YulNode = YulNode(entrypoint, self)
self._nodes: List[YulNode] = [] self._nodes: List[YulNode] = []
@ -318,7 +327,7 @@ class YulBlock(YulScope):
def convert(self, ast: Dict) -> YulNode: def convert(self, ast: Dict) -> YulNode:
return convert_yul(self, self._entrypoint, ast, self.node_scope) return convert_yul(self, self._entrypoint, ast, self.node_scope)
def analyze_expressions(self): def analyze_expressions(self) -> None:
for node in self._nodes: for node in self._nodes:
node.analyze_expressions() node.analyze_expressions()
@ -361,18 +370,22 @@ def convert_yul_function_definition(
while not isinstance(top_node_scope, Function): while not isinstance(top_node_scope, Function):
top_node_scope = top_node_scope.father top_node_scope = top_node_scope.father
func: Union[FunctionTopLevel, FunctionContract]
if isinstance(top_node_scope, FunctionTopLevel): if isinstance(top_node_scope, FunctionTopLevel):
scope = root.contract.file_scope scope = root.file_scope
func = FunctionTopLevel(root.compilation_unit, scope) func = FunctionTopLevel(root.compilation_unit, scope)
# Note: we do not add the function in the scope # Note: we do not add the function in the scope
# While its a top level function, it is not accessible outside of the function definition # While its a top level function, it is not accessible outside of the function definition
# In practice we should probably have a specific function type for function defined within a function # In practice we should probably have a specific function type for function defined within a function
else: else:
func = FunctionContract(root.compilation_unit) func = FunctionContract(root.compilation_unit)
func.function_language = FunctionLanguage.Yul func.function_language = FunctionLanguage.Yul
yul_function = YulFunction(func, root, ast, node_scope) yul_function = YulFunction(func, root, ast, node_scope)
root.contract.add_function(func) if root.contract:
root.contract.add_function(func)
root.compilation_unit.add_function(func) root.compilation_unit.add_function(func)
root.add_yul_local_function(yul_function) root.add_yul_local_function(yul_function)
@ -774,14 +787,15 @@ def parse_yul_identifier(root: YulScope, _node: YulNode, ast: Dict) -> Optional[
# check function-scoped variables # check function-scoped variables
parent_func = root.parent_func parent_func = root.parent_func
if parent_func: if parent_func:
variable = parent_func.get_local_variable_from_name(name) local_variable = parent_func.get_local_variable_from_name(name)
if variable: if local_variable:
return Identifier(variable) return Identifier(local_variable)
if isinstance(parent_func, FunctionContract): if isinstance(parent_func, FunctionContract):
variable = parent_func.contract.get_state_variable_from_name(name) assert parent_func.contract
if variable: state_variable = parent_func.contract.get_state_variable_from_name(name)
return Identifier(variable) if state_variable:
return Identifier(state_variable)
# check yul-scoped variable # check yul-scoped variable
variable = root.get_yul_local_variable_from_name(name) variable = root.get_yul_local_variable_from_name(name)
@ -798,7 +812,7 @@ def parse_yul_identifier(root: YulScope, _node: YulNode, ast: Dict) -> Optional[
if magic_suffix: if magic_suffix:
return magic_suffix return magic_suffix
ret, _ = find_top_level(name, root.contract.file_scope) ret, _ = find_top_level(name, root.file_scope)
if ret: if ret:
return Identifier(ret) return Identifier(ret)
@ -840,7 +854,7 @@ def parse_yul_unsupported(_root: YulScope, _node: YulNode, ast: Dict) -> Optiona
def parse_yul(root: YulScope, node: YulNode, ast: Dict) -> Optional[Expression]: def parse_yul(root: YulScope, node: YulNode, ast: Dict) -> Optional[Expression]:
op = parsers.get(ast["nodeType"], parse_yul_unsupported)(root, node, ast) op: Expression = parsers.get(ast["nodeType"], parse_yul_unsupported)(root, node, ast)
if op: if op:
op.set_offset(ast["src"], root.compilation_unit) op.set_offset(ast["src"], root.compilation_unit)
return op return op

@ -0,0 +1,16 @@
function top_level_yul(int256 c) pure returns (uint result) {
assembly {
function internal_yul(a) -> b {
b := a
}
result := internal_yul(c)
}
}
contract Test {
function test() public{
top_level_yul(10);
}
}

@ -438,6 +438,7 @@ ALL_TESTS = [
Test("using-for-global-0.8.0.sol", ["0.8.15"]), Test("using-for-global-0.8.0.sol", ["0.8.15"]),
Test("library_event-0.8.16.sol", ["0.8.16"]), Test("library_event-0.8.16.sol", ["0.8.16"]),
Test("top-level-struct-0.8.0.sol", ["0.8.0"]), Test("top-level-struct-0.8.0.sol", ["0.8.0"]),
Test("yul-top-level-0.8.0.sol", ["0.8.0"]),
Test("complex_imports/import_aliases_issue_1319/test.sol", ["0.5.12"]), Test("complex_imports/import_aliases_issue_1319/test.sol", ["0.5.12"]),
] ]
# create the output folder if needed # create the output folder if needed

Loading…
Cancel
Save