diff --git a/slither/core/children/child_function.py b/slither/core/children/child_function.py index cb4f0109f..5367320ca 100644 --- a/slither/core/children/child_function.py +++ b/slither/core/children/child_function.py @@ -5,11 +5,11 @@ if TYPE_CHECKING: class ChildFunction: - def __init__(self): + def __init__(self) -> None: super().__init__() self._function = None - def set_function(self, function: "Function"): + def set_function(self, function: "Function") -> None: self._function = function @property diff --git a/slither/core/expressions/identifier.py b/slither/core/expressions/identifier.py index ab40472a4..0b10c5615 100644 --- a/slither/core/expressions/identifier.py +++ b/slither/core/expressions/identifier.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: class Identifier(ExpressionTyped): - def __init__(self, value): + def __init__(self, value) -> None: super().__init__() self._value: "Variable" = value @@ -15,5 +15,5 @@ class Identifier(ExpressionTyped): def value(self) -> "Variable": return self._value - def __str__(self): + def __str__(self) -> str: return str(self._value) diff --git a/slither/core/variables/local_variable.py b/slither/core/variables/local_variable.py index 5eb641fb4..7b7b4f8bc 100644 --- a/slither/core/variables/local_variable.py +++ b/slither/core/variables/local_variable.py @@ -11,11 +11,11 @@ from slither.core.declarations.structure import Structure class LocalVariable(ChildFunction, Variable): - def __init__(self): + def __init__(self) -> None: super().__init__() self._location: Optional[str] = None - def set_location(self, loc: str): + def set_location(self, loc: str) -> None: self._location = loc @property diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 95f0ce9b1..6b8aca51e 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -343,7 +343,6 @@ class FunctionSolc(CallerContextExpression): node, [self._function.name, f"asm_{len(self._node_to_yulobject)}"], scope, - parent_func=self._function, ) self._node_to_yulobject[node] = yul_object return yul_object diff --git a/slither/solc_parsing/yul/evm_functions.py b/slither/solc_parsing/yul/evm_functions.py index 0276d4bf7..41c150765 100644 --- a/slither/solc_parsing/yul/evm_functions.py +++ b/slither/solc_parsing/yul/evm_functions.py @@ -264,9 +264,9 @@ binary_ops = { class YulBuiltin: # pylint: disable=too-few-public-methods - def __init__(self, name): + def __init__(self, name: str) -> None: self._name = name @property - def name(self): + def name(self) -> str: return self._name diff --git a/slither/solc_parsing/yul/parse_yul.py b/slither/solc_parsing/yul/parse_yul.py index 8c9ee427e..f7c9938fc 100644 --- a/slither/solc_parsing/yul/parse_yul.py +++ b/slither/solc_parsing/yul/parse_yul.py @@ -24,6 +24,7 @@ from slither.core.expressions import ( UnaryOperation, ) from slither.core.expressions.expression import Expression +from slither.core.scope.scope import FileScope from slither.core.solidity_types import ElementaryType from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.local_variable import LocalVariable @@ -51,30 +52,35 @@ class YulNode: def underlying_node(self) -> Node: return self._node - def add_unparsed_expression(self, expression: Dict): + def add_unparsed_expression(self, expression: Dict) -> None: assert self._unparsed_expression is None self._unparsed_expression = expression - def analyze_expressions(self): + def analyze_expressions(self) -> None: if self._node.type == NodeType.VARIABLE and not self._node.expression: - self._node.add_expression(self._node.variable_declaration.expression) + expression = self._node.variable_declaration.expression + if expression: + self._node.add_expression(expression) if self._unparsed_expression: expression = parse_yul(self._scope, self, self._unparsed_expression) - self._node.add_expression(expression) + if expression: + self._node.add_expression(expression) if self._node.expression: if self._node.type == NodeType.VARIABLE: # Update the expression to be an assignement to the variable - _expression = AssignmentOperation( - Identifier(self._node.variable_declaration), - self._node.expression, - AssignmentOperationType.ASSIGN, - self._node.variable_declaration.type, - ) - _expression.set_offset( - self._node.expression.source_mapping, self._node.compilation_unit - ) - self._node.add_expression(_expression, bypass_verif_empty=True) + variable_declaration = self._node.variable_declaration + if variable_declaration: + _expression = AssignmentOperation( + Identifier(self._node.variable_declaration), + self._node.expression, + AssignmentOperationType.ASSIGN, + variable_declaration.type, + ) + _expression.set_offset( + self._node.expression.source_mapping, self._node.compilation_unit + ) + self._node.add_expression(_expression, bypass_verif_empty=True) expression = self._node.expression read_var = ReadVar(expression) @@ -122,13 +128,13 @@ class YulScope(metaclass=abc.ABCMeta): ] def __init__( - self, contract: Optional[Contract], yul_id: List[str], parent_func: Function = None - ): + self, contract: Optional[Contract], yul_id: List[str], parent_func: Function + ) -> None: self._contract = contract self._id: List[str] = yul_id self._yul_local_variables: List[YulLocalVariable] = [] self._yul_local_functions: List[YulFunction] = [] - self._parent_func = parent_func + self._parent_func: Function = parent_func @property def id(self) -> List[str]: @@ -155,10 +161,14 @@ class YulScope(metaclass=abc.ABCMeta): def new_node(self, node_type: NodeType, src: Union[str, Dict]) -> YulNode: pass - def add_yul_local_variable(self, var): + @property + def file_scope(self) -> FileScope: + return self._parent_func.file_scope + + def add_yul_local_variable(self, var: "YulLocalVariable") -> None: self._yul_local_variables.append(var) - def get_yul_local_variable_from_name(self, variable_name): + def get_yul_local_variable_from_name(self, variable_name: str) -> Optional["YulLocalVariable"]: return next( ( v @@ -168,10 +178,10 @@ class YulScope(metaclass=abc.ABCMeta): None, ) - def add_yul_local_function(self, func): + def add_yul_local_function(self, func: "YulFunction") -> None: self._yul_local_functions.append(func) - def get_yul_local_function_from_name(self, func_name): + def get_yul_local_function_from_name(self, func_name: str) -> Optional["YulLocalVariable"]: return next( (v for v in self._yul_local_functions if v.underlying.name == func_name), None, @@ -242,7 +252,7 @@ class YulFunction(YulScope): def function(self) -> Function: return self._function - def convert_body(self): + def convert_body(self) -> None: node = self.new_node(NodeType.ENTRYPOINT, self._ast["src"]) link_underlying_nodes(self._entrypoint, node) @@ -258,7 +268,7 @@ class YulFunction(YulScope): convert_yul(self, node, self._ast["body"], self.node_scope) - def parse_body(self): + def parse_body(self) -> None: for node in self._nodes: node.analyze_expressions() @@ -289,9 +299,8 @@ class YulBlock(YulScope): entrypoint: Node, yul_id: List[str], node_scope: Union[Scope, Function], - **kwargs, ): - super().__init__(contract, yul_id, **kwargs) + super().__init__(contract, yul_id, entrypoint.function) self._entrypoint: YulNode = YulNode(entrypoint, self) self._nodes: List[YulNode] = [] @@ -318,7 +327,7 @@ class YulBlock(YulScope): def convert(self, ast: Dict) -> YulNode: return convert_yul(self, self._entrypoint, ast, self.node_scope) - def analyze_expressions(self): + def analyze_expressions(self) -> None: for node in self._nodes: node.analyze_expressions() @@ -361,18 +370,22 @@ def convert_yul_function_definition( while not isinstance(top_node_scope, Function): top_node_scope = top_node_scope.father + func: Union[FunctionTopLevel, FunctionContract] if isinstance(top_node_scope, FunctionTopLevel): - scope = root.contract.file_scope + scope = root.file_scope func = FunctionTopLevel(root.compilation_unit, scope) # Note: we do not add the function in the scope # While its a top level function, it is not accessible outside of the function definition # In practice we should probably have a specific function type for function defined within a function else: func = FunctionContract(root.compilation_unit) + func.function_language = FunctionLanguage.Yul yul_function = YulFunction(func, root, ast, node_scope) - root.contract.add_function(func) + if root.contract: + root.contract.add_function(func) + root.compilation_unit.add_function(func) root.add_yul_local_function(yul_function) @@ -774,14 +787,15 @@ def parse_yul_identifier(root: YulScope, _node: YulNode, ast: Dict) -> Optional[ # check function-scoped variables parent_func = root.parent_func if parent_func: - variable = parent_func.get_local_variable_from_name(name) - if variable: - return Identifier(variable) + local_variable = parent_func.get_local_variable_from_name(name) + if local_variable: + return Identifier(local_variable) if isinstance(parent_func, FunctionContract): - variable = parent_func.contract.get_state_variable_from_name(name) - if variable: - return Identifier(variable) + assert parent_func.contract + state_variable = parent_func.contract.get_state_variable_from_name(name) + if state_variable: + return Identifier(state_variable) # check yul-scoped variable variable = root.get_yul_local_variable_from_name(name) @@ -798,7 +812,7 @@ def parse_yul_identifier(root: YulScope, _node: YulNode, ast: Dict) -> Optional[ if magic_suffix: return magic_suffix - ret, _ = find_top_level(name, root.contract.file_scope) + ret, _ = find_top_level(name, root.file_scope) if ret: return Identifier(ret) @@ -840,7 +854,7 @@ def parse_yul_unsupported(_root: YulScope, _node: YulNode, ast: Dict) -> Optiona def parse_yul(root: YulScope, node: YulNode, ast: Dict) -> Optional[Expression]: - op = parsers.get(ast["nodeType"], parse_yul_unsupported)(root, node, ast) + op: Expression = parsers.get(ast["nodeType"], parse_yul_unsupported)(root, node, ast) if op: op.set_offset(ast["src"], root.compilation_unit) return op diff --git a/tests/ast-parsing/compile/yul-top-level-0.8.0.sol-0.8.0-compact.zip b/tests/ast-parsing/compile/yul-top-level-0.8.0.sol-0.8.0-compact.zip new file mode 100644 index 000000000..ce81a5c37 Binary files /dev/null and b/tests/ast-parsing/compile/yul-top-level-0.8.0.sol-0.8.0-compact.zip differ diff --git a/tests/ast-parsing/yul-top-level-0.8.0.sol b/tests/ast-parsing/yul-top-level-0.8.0.sol new file mode 100644 index 000000000..214db9cb4 --- /dev/null +++ b/tests/ast-parsing/yul-top-level-0.8.0.sol @@ -0,0 +1,16 @@ +function top_level_yul(int256 c) pure returns (uint result) { + assembly { + function internal_yul(a) -> b { + b := a + } + + result := internal_yul(c) + } +} + + +contract Test { + function test() public{ + top_level_yul(10); + } +} \ No newline at end of file diff --git a/tests/test_ast_parsing.py b/tests/test_ast_parsing.py index 4e452a0fc..2f0e9b12c 100644 --- a/tests/test_ast_parsing.py +++ b/tests/test_ast_parsing.py @@ -438,6 +438,7 @@ ALL_TESTS = [ Test("using-for-global-0.8.0.sol", ["0.8.15"]), Test("library_event-0.8.16.sol", ["0.8.16"]), Test("top-level-struct-0.8.0.sol", ["0.8.0"]), + Test("yul-top-level-0.8.0.sol", ["0.8.0"]), Test("complex_imports/import_aliases_issue_1319/test.sol", ["0.5.12"]), ] # create the output folder if needed