diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index 7643b19b7..a740d41b9 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -5,7 +5,6 @@ from enum import Enum from typing import Optional, List, Set, Dict, Tuple, Union, TYPE_CHECKING from slither.all_exceptions import SlitherException -from slither.core.children.child_function import ChildFunction from slither.core.declarations import Contract, Function from slither.core.declarations.solidity_variables import ( SolidityVariable, @@ -106,7 +105,7 @@ class NodeType(Enum): # I am not sure why, but pylint reports a lot of "no-member" issue that are not real (Josselin) # pylint: disable=no-member -class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-methods +class Node(SourceMapping): # pylint: disable=too-many-public-methods """ Node class @@ -189,6 +188,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self.scope: Union["Scope", "Function"] = scope self.file_scope: "FileScope" = file_scope + self._function: Optional["Function"] = None ################################################################################### ################################################################################### @@ -224,6 +224,13 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return True return False + def set_function(self, function: "Function") -> None: + self._function = function + + @property + def function(self) -> "Function": + return self._function + # endregion ################################################################################### ################################################################################### diff --git a/slither/core/children/child_event.py b/slither/core/children/child_event.py deleted file mode 100644 index df91596e3..000000000 --- a/slither/core/children/child_event.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.declarations import Event - - -class ChildEvent: - def __init__(self) -> None: - super().__init__() - self._event = None - - def set_event(self, event: "Event"): - self._event = event - - @property - def event(self) -> "Event": - return self._event diff --git a/slither/core/children/child_expression.py b/slither/core/children/child_expression.py deleted file mode 100644 index 0064658c0..000000000 --- a/slither/core/children/child_expression.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import TYPE_CHECKING, Union - -if TYPE_CHECKING: - from slither.core.expressions.expression import Expression - from slither.slithir.operations import Operation - - -class ChildExpression: - def __init__(self) -> None: - super().__init__() - self._expression = None - - def set_expression(self, expression: Union["Expression", "Operation"]) -> None: - self._expression = expression - - @property - def expression(self) -> Union["Expression", "Operation"]: - return self._expression diff --git a/slither/core/children/child_function.py b/slither/core/children/child_function.py deleted file mode 100644 index 5367320ca..000000000 --- a/slither/core/children/child_function.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.declarations import Function - - -class ChildFunction: - def __init__(self) -> None: - super().__init__() - self._function = None - - def set_function(self, function: "Function") -> None: - self._function = function - - @property - def function(self) -> "Function": - return self._function diff --git a/slither/core/children/child_inheritance.py b/slither/core/children/child_inheritance.py deleted file mode 100644 index 30b32f6c1..000000000 --- a/slither/core/children/child_inheritance.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.declarations import Contract - - -class ChildInheritance: - def __init__(self) -> None: - super().__init__() - self._contract_declarer = None - - def set_contract_declarer(self, contract: "Contract") -> None: - self._contract_declarer = contract - - @property - def contract_declarer(self) -> "Contract": - return self._contract_declarer diff --git a/slither/core/children/child_node.py b/slither/core/children/child_node.py deleted file mode 100644 index 8e6e1f0b5..000000000 --- a/slither/core/children/child_node.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.compilation_unit import SlitherCompilationUnit - from slither.core.cfg.node import Node - from slither.core.declarations import Function, Contract - - -class ChildNode: - def __init__(self) -> None: - super().__init__() - self._node = None - - def set_node(self, node: "Node") -> None: - self._node = node - - @property - def node(self) -> "Node": - return self._node - - @property - def function(self) -> "Function": - return self.node.function - - @property - def contract(self) -> "Contract": - return self.node.function.contract - - @property - def compilation_unit(self) -> "SlitherCompilationUnit": - return self.node.compilation_unit diff --git a/slither/core/children/child_structure.py b/slither/core/children/child_structure.py deleted file mode 100644 index abcb041c2..000000000 --- a/slither/core/children/child_structure.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.declarations import Structure - - -class ChildStructure: - def __init__(self) -> None: - super().__init__() - self._structure = None - - def set_structure(self, structure: "Structure") -> None: - self._structure = structure - - @property - def structure(self) -> "Structure": - return self._structure diff --git a/slither/core/children/child_contract.py b/slither/core/declarations/contract_level.py similarity index 57% rename from slither/core/children/child_contract.py rename to slither/core/declarations/contract_level.py index 86f9dea53..5893a7035 100644 --- a/slither/core/children/child_contract.py +++ b/slither/core/declarations/contract_level.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from slither.core.source_mapping.source_mapping import SourceMapping @@ -6,14 +6,21 @@ if TYPE_CHECKING: from slither.core.declarations import Contract -class ChildContract(SourceMapping): +class ContractLevel(SourceMapping): + """ + This class is used to represent objects that are at the contract level + The opposite is TopLevel + + """ + def __init__(self) -> None: super().__init__() - self._contract = None + self._contract: Optional["Contract"] = None def set_contract(self, contract: "Contract") -> None: self._contract = contract @property def contract(self) -> "Contract": + assert self._contract return self._contract diff --git a/slither/core/declarations/custom_error_contract.py b/slither/core/declarations/custom_error_contract.py index a96f12057..a3839e3f2 100644 --- a/slither/core/declarations/custom_error_contract.py +++ b/slither/core/declarations/custom_error_contract.py @@ -1,8 +1,8 @@ -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations.custom_error import CustomError -class CustomErrorContract(CustomError, ChildContract): +class CustomErrorContract(CustomError, ContractLevel): def is_declared_by(self, contract): """ Check if the element is declared by the contract diff --git a/slither/core/declarations/enum_contract.py b/slither/core/declarations/enum_contract.py index 46168d107..2e51ae511 100644 --- a/slither/core/declarations/enum_contract.py +++ b/slither/core/declarations/enum_contract.py @@ -1,13 +1,13 @@ from typing import TYPE_CHECKING -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations import Enum if TYPE_CHECKING: from slither.core.declarations import Contract -class EnumContract(Enum, ChildContract): +class EnumContract(Enum, ContractLevel): def is_declared_by(self, contract: "Contract") -> bool: """ Check if the element is declared by the contract diff --git a/slither/core/declarations/event.py b/slither/core/declarations/event.py index d616679a2..9d42ac224 100644 --- a/slither/core/declarations/event.py +++ b/slither/core/declarations/event.py @@ -1,6 +1,6 @@ from typing import List, Tuple, TYPE_CHECKING -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.event_variable import EventVariable @@ -8,7 +8,7 @@ if TYPE_CHECKING: from slither.core.declarations import Contract -class Event(ChildContract, SourceMapping): +class Event(ContractLevel, SourceMapping): def __init__(self) -> None: super().__init__() self._name = None diff --git a/slither/core/declarations/function_contract.py b/slither/core/declarations/function_contract.py index 19456bbea..69c50a117 100644 --- a/slither/core/declarations/function_contract.py +++ b/slither/core/declarations/function_contract.py @@ -1,10 +1,9 @@ """ Function module """ -from typing import Dict, TYPE_CHECKING, List, Tuple +from typing import Dict, TYPE_CHECKING, List, Tuple, Optional -from slither.core.children.child_contract import ChildContract -from slither.core.children.child_inheritance import ChildInheritance +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations import Function @@ -14,9 +13,31 @@ if TYPE_CHECKING: from slither.core.declarations import Contract from slither.core.scope.scope import FileScope from slither.slithir.variables.state_variable import StateIRVariable + from slither.core.compilation_unit import SlitherCompilationUnit -class FunctionContract(Function, ChildContract, ChildInheritance): +class FunctionContract(Function, ContractLevel): + def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None: + super().__init__(compilation_unit) + self._contract_declarer: Optional["Contract"] = None + + def set_contract_declarer(self, contract: "Contract") -> None: + self._contract_declarer = contract + + @property + def contract_declarer(self) -> "Contract": + """ + Return the contract where this function was declared. Only functions have both a contract, and contract_declarer + This is because we need to have separate representation of the function depending of the contract's context + For example a function calling super.f() will generate different IR depending on the current contract's inheritance + + Returns: + The contract where this function was declared + """ + + assert self._contract_declarer + return self._contract_declarer + @property def canonical_name(self) -> str: """ diff --git a/slither/core/declarations/structure_contract.py b/slither/core/declarations/structure_contract.py index aaf660e1e..c9d05ce4e 100644 --- a/slither/core/declarations/structure_contract.py +++ b/slither/core/declarations/structure_contract.py @@ -1,8 +1,8 @@ -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations import Structure -class StructureContract(Structure, ChildContract): +class StructureContract(Structure, ContractLevel): def is_declared_by(self, contract): """ Check if the element is declared by the contract diff --git a/slither/core/declarations/top_level.py b/slither/core/declarations/top_level.py index 15facf2f9..01e6f6dfd 100644 --- a/slither/core/declarations/top_level.py +++ b/slither/core/declarations/top_level.py @@ -2,4 +2,8 @@ from slither.core.source_mapping.source_mapping import SourceMapping class TopLevel(SourceMapping): - pass + """ + This class is used to represent objects that are at the top level + The opposite is ContractLevel + + """ diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index e5f4e830a..e55a9cf0b 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -13,7 +13,7 @@ from typing import Optional, Dict, List, Set, Union, Tuple from crytic_compile import CryticCompile from crytic_compile.utils.naming import Filename -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.context.context import Context from slither.core.declarations import Contract, FunctionContract @@ -206,7 +206,10 @@ class SlitherCore(Context): isinstance(thing, FunctionContract) and thing.contract_declarer == thing.contract ) - or (isinstance(thing, ChildContract) and not isinstance(thing, FunctionContract)) + or ( + isinstance(thing, ContractLevel) + and not isinstance(thing, FunctionContract) + ) ): self._offset_to_objects[definition.filename][offset].add(thing) @@ -224,7 +227,8 @@ class SlitherCore(Context): and thing.contract_declarer == thing.contract ) or ( - isinstance(thing, ChildContract) and not isinstance(thing, FunctionContract) + isinstance(thing, ContractLevel) + and not isinstance(thing, FunctionContract) ) ): self._offset_to_objects[definition.filename][offset].add(thing) diff --git a/slither/core/solidity_types/type_alias.py b/slither/core/solidity_types/type_alias.py index 5b9ea0a37..c47d2ee14 100644 --- a/slither/core/solidity_types/type_alias.py +++ b/slither/core/solidity_types/type_alias.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Tuple -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations.top_level import TopLevel from slither.core.solidity_types import Type, ElementaryType @@ -48,7 +48,7 @@ class TypeAliasTopLevel(TypeAlias, TopLevel): return self.name -class TypeAliasContract(TypeAlias, ChildContract): +class TypeAliasContract(TypeAlias, ContractLevel): def __init__(self, underlying_type: Type, name: str, contract: "Contract") -> None: super().__init__(underlying_type, name) self._contract: "Contract" = contract diff --git a/slither/core/variables/event_variable.py b/slither/core/variables/event_variable.py index f3ad60d0b..3b6b6c511 100644 --- a/slither/core/variables/event_variable.py +++ b/slither/core/variables/event_variable.py @@ -1,8 +1,7 @@ from slither.core.variables.variable import Variable -from slither.core.children.child_event import ChildEvent -class EventVariable(ChildEvent, Variable): +class EventVariable(Variable): def __init__(self) -> None: super().__init__() self._indexed = False @@ -16,5 +15,5 @@ class EventVariable(ChildEvent, Variable): return self._indexed @indexed.setter - def indexed(self, is_indexed: bool): + def indexed(self, is_indexed: bool) -> None: self._indexed = is_indexed diff --git a/slither/core/variables/local_variable.py b/slither/core/variables/local_variable.py index 7b7b4f8bc..fc23eeba7 100644 --- a/slither/core/variables/local_variable.py +++ b/slither/core/variables/local_variable.py @@ -1,7 +1,6 @@ -from typing import Optional +from typing import Optional, TYPE_CHECKING from slither.core.variables.variable import Variable -from slither.core.children.child_function import ChildFunction from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.array_type import ArrayType from slither.core.solidity_types.mapping_type import MappingType @@ -9,11 +8,23 @@ from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.declarations.structure import Structure +if TYPE_CHECKING: # type: ignore + from slither.core.declarations import Function -class LocalVariable(ChildFunction, Variable): + +class LocalVariable(Variable): def __init__(self) -> None: super().__init__() self._location: Optional[str] = None + self._function: Optional["Function"] = None + + def set_function(self, function: "Function") -> None: + self._function = function + + @property + def function(self) -> "Function": + assert self._function + return self._function def set_location(self, loc: str) -> None: self._location = loc diff --git a/slither/core/variables/state_variable.py b/slither/core/variables/state_variable.py index 47b7682a4..f2a2d6ee3 100644 --- a/slither/core/variables/state_variable.py +++ b/slither/core/variables/state_variable.py @@ -1,6 +1,6 @@ from typing import Optional, TYPE_CHECKING -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.variables.variable import Variable if TYPE_CHECKING: @@ -8,7 +8,7 @@ if TYPE_CHECKING: from slither.core.declarations import Contract -class StateVariable(ChildContract, Variable): +class StateVariable(ContractLevel, Variable): def __init__(self) -> None: super().__init__() self._node_initialization: Optional["Node"] = None diff --git a/slither/core/variables/structure_variable.py b/slither/core/variables/structure_variable.py index c6034da63..3a001b6a9 100644 --- a/slither/core/variables/structure_variable.py +++ b/slither/core/variables/structure_variable.py @@ -1,6 +1,19 @@ +from typing import TYPE_CHECKING, Optional from slither.core.variables.variable import Variable -from slither.core.children.child_structure import ChildStructure -class StructureVariable(ChildStructure, Variable): - pass +if TYPE_CHECKING: + from slither.core.declarations import Structure + + +class StructureVariable(Variable): + def __init__(self) -> None: + super().__init__() + self._structure: Optional["Structure"] = None + + def set_structure(self, structure: "Structure") -> None: + self._structure = structure + + @property + def structure(self) -> "Structure": + return self._structure diff --git a/slither/detectors/erc/incorrect_erc721_interface.py b/slither/detectors/erc/incorrect_erc721_interface.py index 8327e8b2e..f894fb517 100644 --- a/slither/detectors/erc/incorrect_erc721_interface.py +++ b/slither/detectors/erc/incorrect_erc721_interface.py @@ -89,7 +89,9 @@ contract Token{ return False @staticmethod - def detect_incorrect_erc721_interface(contract: Contract) -> List[Union[FunctionContract, Any]]: + def detect_incorrect_erc721_interface( + contract: Contract, + ) -> List[Union[FunctionContract, Any]]: """Detect incorrect ERC721 interface Returns: diff --git a/slither/detectors/operations/missing_events_arithmetic.py b/slither/detectors/operations/missing_events_arithmetic.py index 6e1d5fbb5..e553e78eb 100644 --- a/slither/detectors/operations/missing_events_arithmetic.py +++ b/slither/detectors/operations/missing_events_arithmetic.py @@ -70,7 +70,9 @@ contract C { def _detect_missing_events( self, contract: Contract - ) -> List[Tuple[FunctionContract, List[Tuple[Node, List[Tuple[Node, FunctionContract]]]]]]: + ) -> List[ + Tuple[FunctionContract, List[Tuple[Node, List[Tuple[Node, FunctionContract]]]]] + ]: """ Detects if critical contract parameters set by owners and used in arithmetic are missing events :param contract: The contract to check diff --git a/slither/detectors/statements/tx_origin.py b/slither/detectors/statements/tx_origin.py index 34f8173d5..e281c1d09 100644 --- a/slither/detectors/statements/tx_origin.py +++ b/slither/detectors/statements/tx_origin.py @@ -57,7 +57,9 @@ Bob is the owner of `TxOrigin`. Bob calls Eve's contract. Eve's contract calls ` ) return False - def detect_tx_origin(self, contract: Contract) -> List[Tuple[FunctionContract, List[Node]]]: + def detect_tx_origin( + self, contract: Contract + ) -> List[Tuple[FunctionContract, List[Node]]]: ret = [] for f in contract.functions: diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index 87a6b075b..aa8dfb4ec 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -731,7 +731,7 @@ def propagate_types( return _convert_type_contract(ir) left = ir.variable_left t = None - ir_func = ir.function + ir_func = ir.node.function # Handling of this.function_name usage if ( left == SolidityVariable("this") diff --git a/slither/slithir/operations/internal_dynamic_call.py b/slither/slithir/operations/internal_dynamic_call.py index a1ad1aa15..ca245167e 100644 --- a/slither/slithir/operations/internal_dynamic_call.py +++ b/slither/slithir/operations/internal_dynamic_call.py @@ -24,7 +24,7 @@ class InternalDynamicCall( assert isinstance(function, Variable) assert is_valid_lvalue(lvalue) or lvalue is None super().__init__() - self._function = function + self._function: Variable = function self._function_type = function_type self._lvalue = lvalue @@ -37,7 +37,7 @@ class InternalDynamicCall( return self._unroll(self.arguments) + [self.function] @property - def function(self) -> Union[LocalVariable, LocalIRVariable]: + def function(self) -> Variable: return self._function @property diff --git a/slither/slithir/operations/new_structure.py b/slither/slithir/operations/new_structure.py index 752de6a3d..f24b3bccd 100644 --- a/slither/slithir/operations/new_structure.py +++ b/slither/slithir/operations/new_structure.py @@ -14,7 +14,9 @@ from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA class NewStructure(Call, OperationWithLValue): def __init__( - self, structure: StructureContract, lvalue: Union[TemporaryVariableSSA, TemporaryVariable] + self, + structure: StructureContract, + lvalue: Union[TemporaryVariableSSA, TemporaryVariable], ) -> None: super().__init__() assert isinstance(structure, Structure) diff --git a/slither/slithir/operations/operation.py b/slither/slithir/operations/operation.py index fcf5f4868..aca3e645b 100644 --- a/slither/slithir/operations/operation.py +++ b/slither/slithir/operations/operation.py @@ -1,11 +1,14 @@ import abc -from typing import Any, List +from typing import Any, List, Optional, TYPE_CHECKING from slither.core.context.context import Context -from slither.core.children.child_expression import ChildExpression -from slither.core.children.child_node import ChildNode +from slither.core.expressions.expression import Expression from slither.core.variables.variable import Variable from slither.utils.utils import unroll +if TYPE_CHECKING: + from slither.core.compilation_unit import SlitherCompilationUnit + from slither.core.cfg.node import Node + class AbstractOperation(abc.ABC): @property @@ -25,7 +28,24 @@ class AbstractOperation(abc.ABC): pass # pylint: disable=unnecessary-pass -class Operation(Context, ChildExpression, ChildNode, AbstractOperation): +class Operation(Context, AbstractOperation): + def __init__(self) -> None: + super().__init__() + self._node: Optional["Node"] = None + self._expression: Optional[Expression] = None + + def set_node(self, node: "Node") -> None: + self._node = node + + @property + def node(self) -> "Node": + assert self._node + return self._node + + @property + def compilation_unit(self) -> "SlitherCompilationUnit": + return self.node.compilation_unit + @property def used(self) -> List[Variable]: """ @@ -37,3 +57,10 @@ class Operation(Context, ChildExpression, ChildNode, AbstractOperation): @staticmethod def _unroll(l: List[Any]) -> List[Any]: return unroll(l) + + def set_expression(self, expression: Expression) -> None: + self._expression = expression + + @property + def expression(self) -> Optional[Expression]: + return self._expression diff --git a/slither/slithir/operations/solidity_call.py b/slither/slithir/operations/solidity_call.py index b059c55a6..c0d8d8404 100644 --- a/slither/slithir/operations/solidity_call.py +++ b/slither/slithir/operations/solidity_call.py @@ -2,7 +2,6 @@ from typing import Any, List, Union from slither.core.declarations.solidity_variables import SolidityCustomRevert, SolidityFunction from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue -from slither.core.children.child_node import ChildNode from slither.core.solidity_types.elementary_type import ElementaryType @@ -11,7 +10,7 @@ class SolidityCall(Call, OperationWithLValue): self, function: Union[SolidityCustomRevert, SolidityFunction], nbr_arguments: int, - result: ChildNode, + result, type_call: Union[str, List[ElementaryType]], ) -> None: assert isinstance(function, SolidityFunction) diff --git a/slither/slithir/operations/type_conversion.py b/slither/slithir/operations/type_conversion.py index f351f1fdd..ce41e3c54 100644 --- a/slither/slithir/operations/type_conversion.py +++ b/slither/slithir/operations/type_conversion.py @@ -17,7 +17,9 @@ class TypeConversion(OperationWithLValue): self, result: Union[TemporaryVariableSSA, TemporaryVariable], variable: SourceMapping, - variable_type: Union[TypeAliasContract, UserDefinedType, ElementaryType, TypeAliasTopLevel], + variable_type: Union[ + TypeAliasContract, UserDefinedType, ElementaryType, TypeAliasTopLevel + ], ) -> None: super().__init__() assert is_valid_rvalue(variable) or isinstance(variable, Contract) diff --git a/slither/slithir/variables/reference.py b/slither/slithir/variables/reference.py index 95802b7e2..9ab51be65 100644 --- a/slither/slithir/variables/reference.py +++ b/slither/slithir/variables/reference.py @@ -1,6 +1,5 @@ from typing import Optional, TYPE_CHECKING -from slither.core.children.child_node import ChildNode from slither.core.declarations import Contract, Enum, SolidityVariable, Function from slither.core.variables.variable import Variable @@ -8,7 +7,7 @@ if TYPE_CHECKING: from slither.core.cfg.node import Node -class ReferenceVariable(ChildNode, Variable): +class ReferenceVariable(Variable): def __init__(self, node: "Node", index: Optional[int] = None) -> None: super().__init__() if index is None: @@ -19,6 +18,10 @@ class ReferenceVariable(ChildNode, Variable): self._points_to = None self._node = node + @property + def node(self) -> "Node": + return self._node + @property def index(self): return self._index diff --git a/slither/slithir/variables/temporary.py b/slither/slithir/variables/temporary.py index 8cb1cf350..5a485f985 100644 --- a/slither/slithir/variables/temporary.py +++ b/slither/slithir/variables/temporary.py @@ -1,13 +1,12 @@ from typing import Optional, TYPE_CHECKING -from slither.core.children.child_node import ChildNode from slither.core.variables.variable import Variable if TYPE_CHECKING: from slither.core.cfg.node import Node -class TemporaryVariable(ChildNode, Variable): +class TemporaryVariable(Variable): def __init__(self, node: "Node", index: Optional[int] = None) -> None: super().__init__() if index is None: @@ -17,6 +16,10 @@ class TemporaryVariable(ChildNode, Variable): self._index = index self._node = node + @property + def node(self) -> "Node": + return self._node + @property def index(self): return self._index diff --git a/slither/slithir/variables/tuple.py b/slither/slithir/variables/tuple.py index dc085347e..9a13b1d5d 100644 --- a/slither/slithir/variables/tuple.py +++ b/slither/slithir/variables/tuple.py @@ -1,13 +1,12 @@ from typing import Optional, TYPE_CHECKING -from slither.core.children.child_node import ChildNode from slither.slithir.variables.variable import SlithIRVariable if TYPE_CHECKING: from slither.core.cfg.node import Node -class TupleVariable(ChildNode, SlithIRVariable): +class TupleVariable(SlithIRVariable): def __init__(self, node: "Node", index: Optional[int] = None) -> None: super().__init__() if index is None: @@ -18,6 +17,10 @@ class TupleVariable(ChildNode, SlithIRVariable): self._node = node + @property + def node(self) -> "Node": + return self._node + @property def index(self): return self._index diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 475c3fab2..47ee7ec10 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -2,7 +2,13 @@ import logging import re from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set -from slither.core.declarations import Modifier, Event, EnumContract, StructureContract, Function +from slither.core.declarations import ( + Modifier, + Event, + EnumContract, + StructureContract, + Function, +) from slither.core.declarations.contract import Contract from slither.core.declarations.custom_error_contract import CustomErrorContract from slither.core.declarations.function_contract import FunctionContract diff --git a/slither/utils/output.py b/slither/utils/output.py index 9dba15e31..84c9ac65a 100644 --- a/slither/utils/output.py +++ b/slither/utils/output.py @@ -10,8 +10,17 @@ from zipfile import ZipFile from pkg_resources import require from slither.core.cfg.node import Node -from slither.core.declarations import Contract, Function, Enum, Event, Structure, Pragma +from slither.core.declarations import ( + Contract, + Function, + Enum, + Event, + Structure, + Pragma, + FunctionContract, +) from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.local_variable import LocalVariable from slither.core.variables.variable import Variable from slither.exceptions import SlitherError from slither.utils.colors import yellow @@ -351,21 +360,19 @@ def _create_parent_element( ], ]: # pylint: disable=import-outside-toplevel - from slither.core.children.child_contract import ChildContract - from slither.core.children.child_function import ChildFunction - from slither.core.children.child_inheritance import ChildInheritance + from slither.core.declarations.contract_level import ContractLevel - if isinstance(element, ChildInheritance): + if isinstance(element, FunctionContract): if element.contract_declarer: contract = Output("") contract.add_contract(element.contract_declarer) return contract.data["elements"][0] - elif isinstance(element, ChildContract): + elif isinstance(element, ContractLevel): if element.contract: contract = Output("") contract.add_contract(element.contract) return contract.data["elements"][0] - elif isinstance(element, ChildFunction): + elif isinstance(element, (LocalVariable, Node)): if element.function: function = Output("") function.add_function(element.function) diff --git a/tests/test_ssa_generation.py b/tests/test_ssa_generation.py index f002ec4e1..9bb008fdf 100644 --- a/tests/test_ssa_generation.py +++ b/tests/test_ssa_generation.py @@ -689,7 +689,7 @@ def test_initial_version_exists_for_state_variables_function_assign(): # temporary variable, that is then assigned to a call = get_ssa_of_type(ctor, InternalCall)[0] - assert call.function == f + assert call.node.function == f assign = get_ssa_of_type(ctor, Assignment)[0] assert assign.rvalue == call.lvalue assert isinstance(assign.lvalue, StateIRVariable)