diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 731c14e92..b4f079c8d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - type: ["4", "5", "cli", "dapp", "data_dependency", "embark", "erc", "etherlime", "find_paths", "kspec", "printers", "simil", "slither_config", "truffle", "upgradability", "prop"] + type: ["4", "5", "cli", "data_dependency", "embark", "erc", "etherlime", "find_paths", "kspec", "printers", "simil", "slither_config", "truffle", "upgradability", "prop"] steps: - uses: actions/checkout@v1 - name: Set up Python 3.6 diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index 2c98b4c79..f5ccef2db 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -1,29 +1,59 @@ """ Node module """ -import logging -from typing import Optional +from enum import Enum +from typing import Optional, List, Set, Dict, Tuple, Union, TYPE_CHECKING from slither.core.children.child_function import ChildFunction -from slither.core.declarations import Contract -from slither.core.declarations.solidity_variables import SolidityVariable +from slither.core.declarations.solidity_variables import ( + SolidityVariable, + SolidityFunction, +) from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable from slither.core.solidity_types import ElementaryType from slither.slithir.convert import convert_expression -from slither.slithir.operations import (Balance, HighLevelCall, Index, - InternalCall, Length, LibraryCall, - LowLevelCall, Member, - OperationWithLValue, Phi, PhiCallback, - SolidityCall, Return) -from slither.slithir.variables import (Constant, LocalIRVariable, - ReferenceVariable, StateIRVariable, - TemporaryVariable, TupleVariable) +from slither.slithir.operations import ( + Balance, + HighLevelCall, + Index, + InternalCall, + Length, + LibraryCall, + LowLevelCall, + Member, + OperationWithLValue, + Phi, + PhiCallback, + SolidityCall, + Return, + Operation, +) +from slither.slithir.variables import ( + Constant, + LocalIRVariable, + ReferenceVariable, + StateIRVariable, + TemporaryVariable, + TupleVariable, +) from slither.all_exceptions import SlitherException +from slither.core.declarations import Contract -logger = logging.getLogger("Node") +from slither.core.expressions.expression import Expression + +if TYPE_CHECKING: + from slither.core.declarations import Function + from slither.slithir.variables.variable import SlithIRVariable + from slither.core.slither_core import SlitherCore + from slither.utils.type_helpers import ( + InternalCallType, + HighLevelCallType, + LibraryCallType, + LowLevelCallType, + ) ################################################################################### @@ -32,7 +62,8 @@ logger = logging.getLogger("Node") ################################################################################### ################################################################################### -class NodeType: + +class NodeType(Enum): ENTRYPOINT = 0x0 # no expression # Node with expression @@ -71,158 +102,118 @@ class NodeType: OTHER_ENTRYPOINT = 0x50 # @staticmethod - def str(t): - if t == NodeType.ENTRYPOINT: - return 'ENTRY_POINT' - if t == NodeType.EXPRESSION: - return 'EXPRESSION' - if t == NodeType.RETURN: - return 'RETURN' - if t == NodeType.IF: - return 'IF' - if t == NodeType.VARIABLE: - return 'NEW VARIABLE' - if t == NodeType.ASSEMBLY: - return 'INLINE ASM' - if t == NodeType.IFLOOP: - return 'IF_LOOP' - if t == NodeType.THROW: - return 'THROW' - if t == NodeType.BREAK: - return 'BREAK' - if t == NodeType.CONTINUE: - return 'CONTINUE' - if t == NodeType.PLACEHOLDER: - return '_' - if t == NodeType.TRY: - return 'TRY' - if t == NodeType.CATCH: - return 'CATCH' - if t == NodeType.ENDIF: - return 'END_IF' - if t == NodeType.STARTLOOP: - return 'BEGIN_LOOP' - if t == NodeType.ENDLOOP: - return 'END_LOOP' - return 'Unknown type {}'.format(hex(t)) + def __str__(self): + if self == NodeType.ENTRYPOINT: + return "ENTRY_POINT" + if self == NodeType.EXPRESSION: + return "EXPRESSION" + if self == NodeType.RETURN: + return "RETURN" + if self == NodeType.IF: + return "IF" + if self == NodeType.VARIABLE: + return "NEW VARIABLE" + if self == NodeType.ASSEMBLY: + return "INLINE ASM" + if self == NodeType.IFLOOP: + return "IF_LOOP" + if self == NodeType.THROW: + return "THROW" + if self == NodeType.BREAK: + return "BREAK" + if self == NodeType.CONTINUE: + return "CONTINUE" + if self == NodeType.PLACEHOLDER: + return "_" + if self == NodeType.TRY: + return "TRY" + if self == NodeType.CATCH: + return "CATCH" + if self == NodeType.ENDIF: + return "END_IF" + if self == NodeType.STARTLOOP: + return "BEGIN_LOOP" + if self == NodeType.ENDLOOP: + return "END_LOOP" + return "Unknown type {}".format(hex(self.value)) # endregion -################################################################################### -################################################################################### -# region Utils -################################################################################### -################################################################################### - -def link_nodes(n1, n2): - n1.add_son(n2) - n2.add_father(n1) - - -def insert_node(origin, node_inserted): - sons = origin.sons - link_nodes(origin, node_inserted) - for son in sons: - son.remove_father(origin) - origin.remove_son(son) - - link_nodes(node_inserted, son) -def recheable(node): - ''' - Return the set of nodes reacheable from the node - :param node: - :return: set(Node) - ''' - nodes = node.sons - visited = set() - while nodes: - next = nodes[0] - nodes = nodes[1:] - if not next in visited: - visited.add(next) - for son in next.sons: - if not son in visited: - nodes.append(son) - return visited - - -# endregion - class Node(SourceMapping, ChildFunction): """ Node class """ - def __init__(self, node_type, node_id): + def __init__(self, node_type: NodeType, node_id: int): super(Node, self).__init__() self._node_type = node_type - # TODO: rename to explicit CFG - self._sons = [] - self._fathers = [] + # TODO: rename to explicit CFG + self._sons: List["Node"] = [] + self._fathers: List["Node"] = [] ## Dominators info # Dominators nodes - self._dominators = set() - self._immediate_dominator = None + self._dominators: Set["Node"] = set() + self._immediate_dominator: Optional["Node"] = None ## Nodes of the dominators tree # self._dom_predecessors = set() - self._dom_successors = set() + self._dom_successors: Set["Node"] = set() # Dominance frontier - self._dominance_frontier = set() + self._dominance_frontier: Set["Node"] = set() # Phi origin # key are variable name - # values are list of Node - self._phi_origins_state_variables = {} - self._phi_origins_local_variables = {} + self._phi_origins_state_variables: Dict[str, Tuple[StateVariable, Set["Node"]]] = {} + self._phi_origins_local_variables: Dict[str, Tuple[LocalVariable, Set["Node"]]] = {} + #self._phi_origins_member_variables: Dict[str, Tuple[MemberVariable, Set["Node"]]] = {} - self._expression = None - self._variable_declaration = None - self._node_id = node_id + self._expression: Optional[Expression] = None + self._variable_declaration: Optional[LocalVariable] = None + self._node_id: int = node_id - self._vars_written = [] - self._vars_read = [] + self._vars_written: List[Variable] = [] + self._vars_read: List[Variable] = [] - self._ssa_vars_written = [] - self._ssa_vars_read = [] + self._ssa_vars_written: List["SlithIRVariable"] = [] + self._ssa_vars_read: List["SlithIRVariable"] = [] - self._internal_calls = [] - self._solidity_calls = [] - self._high_level_calls = [] # contains library calls - self._library_calls = [] - self._low_level_calls = [] - self._external_calls_as_expressions = [] - self._internal_calls_as_expressions = [] - self._irs = [] - self._irs_ssa = [] + self._internal_calls: List[Function] = [] + self._solidity_calls: List[SolidityFunction] = [] + self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls + self._library_calls: List["LibraryCallType"] = [] + self._low_level_calls: List["LowLevelCallType"] = [] + self._external_calls_as_expressions: List[Expression] = [] + self._internal_calls_as_expressions: List[Expression] = [] + self._irs: List[Operation] = [] + self._irs_ssa: List[Operation] = [] - self._state_vars_written = [] - self._state_vars_read = [] - self._solidity_vars_read = [] + self._state_vars_written: List[StateVariable] = [] + self._state_vars_read: List[StateVariable] = [] + self._solidity_vars_read: List[SolidityVariable] = [] - self._ssa_state_vars_written = [] - self._ssa_state_vars_read = [] + self._ssa_state_vars_written: List[StateIRVariable] = [] + self._ssa_state_vars_read: List[StateIRVariable] = [] - self._local_vars_read = [] - self._local_vars_written = [] + self._local_vars_read: List[LocalVariable] = [] + self._local_vars_written: List[LocalVariable] = [] - self._slithir_vars = set() # non SSA + self._slithir_vars: Set["SlithIRVariable"] = set() # non SSA - self._ssa_local_vars_read = [] - self._ssa_local_vars_written = [] + self._ssa_local_vars_read: List[LocalIRVariable] = [] + self._ssa_local_vars_written: List[LocalIRVariable] = [] - self._expression_vars_written = [] - self._expression_vars_read = [] - self._expression_calls = [] + self._expression_vars_written: List[Expression] = [] + self._expression_vars_read: List[Expression] = [] + self._expression_calls: List[Expression] = [] # Computed on the fly, can be True of False - self._can_reenter = None - self._can_send_eth = None + self._can_reenter: Optional[bool] = None + self._can_send_eth: Optional[bool] = None - self._asm_source_code = None + self._asm_source_code: Optional[Union[str, Dict]] = None ################################################################################### ################################################################################### @@ -231,24 +222,32 @@ class Node(SourceMapping, ChildFunction): ################################################################################### @property - def slither(self): + def slither(self) -> "SlitherCore": return self.function.slither @property - def node_id(self): + def node_id(self) -> int: """Unique node id.""" return self._node_id @property - def type(self): + def type(self) -> NodeType: """ NodeType: type of the node """ return self._node_type @type.setter - def type(self, t): - self._node_type = t + def type(self, new_type: NodeType): + self._node_type = new_type + + @property + def will_return(self) -> bool: + if not self.sons and self.type != NodeType.THROW: + if SolidityFunction("revert()") not in self.solidity_calls: + if SolidityFunction("revert(string)") not in self.solidity_calls: + return True + return False # endregion ################################################################################### @@ -258,108 +257,116 @@ class Node(SourceMapping, ChildFunction): ################################################################################### @property - def variables_read(self): + def variables_read(self) -> List[Variable]: """ list(Variable): Variables read (local/state/solidity) """ return list(self._vars_read) @property - def state_variables_read(self): + def state_variables_read(self) -> List[StateVariable]: """ list(StateVariable): State variables read """ return list(self._state_vars_read) @property - def local_variables_read(self): + def local_variables_read(self) -> List[LocalVariable]: """ list(LocalVariable): Local variables read """ return list(self._local_vars_read) @property - def solidity_variables_read(self): + def solidity_variables_read(self) -> List[SolidityVariable]: """ list(SolidityVariable): State variables read """ return list(self._solidity_vars_read) @property - def ssa_variables_read(self): + def ssa_variables_read(self) -> List["SlithIRVariable"]: """ list(Variable): Variables read (local/state/solidity) """ return list(self._ssa_vars_read) @property - def ssa_state_variables_read(self): + def ssa_state_variables_read(self) -> List[StateIRVariable]: """ list(StateVariable): State variables read """ return list(self._ssa_state_vars_read) @property - def ssa_local_variables_read(self): + def ssa_local_variables_read(self) -> List[LocalIRVariable]: """ list(LocalVariable): Local variables read """ return list(self._ssa_local_vars_read) @property - def variables_read_as_expression(self): + def variables_read_as_expression(self) -> List[Expression]: return self._expression_vars_read + @variables_read_as_expression.setter + def variables_read_as_expression(self, exprs: List[Expression]): + self._expression_vars_read = exprs + @property - def slithir_variables(self): + def slithir_variables(self) -> List["SlithIRVariable"]: return list(self._slithir_vars) @property - def variables_written(self): + def variables_written(self) -> List[Variable]: """ list(Variable): Variables written (local/state/solidity) """ return list(self._vars_written) @property - def state_variables_written(self): + def state_variables_written(self) -> List[StateVariable]: """ list(StateVariable): State variables written """ return list(self._state_vars_written) @property - def local_variables_written(self): + def local_variables_written(self) -> List[LocalVariable]: """ list(LocalVariable): Local variables written """ return list(self._local_vars_written) @property - def ssa_variables_written(self): + def ssa_variables_written(self) -> List["SlithIRVariable"]: """ list(Variable): Variables written (local/state/solidity) """ return list(self._ssa_vars_written) @property - def ssa_state_variables_written(self): + def ssa_state_variables_written(self) -> List[StateIRVariable]: """ list(StateVariable): State variables written """ return list(self._ssa_state_vars_written) @property - def ssa_local_variables_written(self): + def ssa_local_variables_written(self) -> List[LocalIRVariable]: """ list(LocalVariable): Local variables written """ return list(self._ssa_local_vars_written) @property - def variables_written_as_expression(self): + def variables_written_as_expression(self) -> List[Expression]: return self._expression_vars_written + @variables_written_as_expression.setter + def variables_written_as_expression(self, exprs: List[Expression]): + self._expression_vars_written = exprs + # endregion ################################################################################### ################################################################################### @@ -368,21 +375,21 @@ class Node(SourceMapping, ChildFunction): ################################################################################### @property - def internal_calls(self): + def internal_calls(self) -> List["InternalCallType"]: """ list(Function or SolidityFunction): List of internal/soldiity function calls """ return list(self._internal_calls) @property - def solidity_calls(self): + def solidity_calls(self) -> List[SolidityFunction]: """ list(SolidityFunction): List of Soldity calls """ - return list(self._internal_calls) + return list(self._solidity_calls) @property - def high_level_calls(self): + def high_level_calls(self) -> List["HighLevelCallType"]: """ list((Contract, Function|Variable)): List of high level calls (external calls). @@ -392,7 +399,7 @@ class Node(SourceMapping, ChildFunction): return list(self._high_level_calls) @property - def library_calls(self): + def library_calls(self) -> List["LibraryCallType"]: """ list((Contract, Function)): Include library calls @@ -400,7 +407,7 @@ class Node(SourceMapping, ChildFunction): return list(self._library_calls) @property - def low_level_calls(self): + def low_level_calls(self) -> List["LowLevelCallType"]: """ list((Variable|SolidityVariable, str)): List of low_level call A low level call is defined by @@ -410,25 +417,37 @@ class Node(SourceMapping, ChildFunction): return list(self._low_level_calls) @property - def external_calls_as_expressions(self): + def external_calls_as_expressions(self) -> List[Expression]: """ list(CallExpression): List of message calls (that creates a transaction) """ return self._external_calls_as_expressions + @external_calls_as_expressions.setter + def external_calls_as_expressions(self, exprs: List[Expression]): + self._external_calls_as_expressions = exprs + @property - def internal_calls_as_expressions(self): + def internal_calls_as_expressions(self) -> List[Expression]: """ list(CallExpression): List of internal calls (that dont create a transaction) """ return self._internal_calls_as_expressions + @internal_calls_as_expressions.setter + def internal_calls_as_expressions(self, exprs: List[Expression]): + self._internal_calls_as_expressions = exprs + @property - def calls_as_expression(self): + def calls_as_expression(self) -> List[Expression]: return list(self._expression_calls) - def can_reenter(self, callstack=None): - ''' + @calls_as_expression.setter + def calls_as_expression(self, exprs: List[Expression]): + self._expression_calls = exprs + + def can_reenter(self, callstack=None) -> bool: + """ Check if the node can re-enter Do not consider CREATE as potential re-enter, but check if the destination's constructor can contain a call (recurs. follow nested CREATE) @@ -437,8 +456,9 @@ class Node(SourceMapping, ChildFunction): Do not consider Send/Transfer as there is not enough gas :param callstack: used internally to check for recursion :return bool: - ''' + """ from slither.slithir.operations import Call + if self._can_reenter is None: self._can_reenter = False for ir in self.irs: @@ -447,18 +467,20 @@ class Node(SourceMapping, ChildFunction): return True return self._can_reenter - def can_send_eth(self): - ''' + def can_send_eth(self) -> bool: + """ Check if the node can send eth :return bool: - ''' + """ from slither.slithir.operations import Call + if self._can_send_eth is None: + self._can_send_eth = False for ir in self.all_slithir_operations(): if isinstance(ir, Call) and ir.can_send_eth(): self._can_send_eth = True return True - return self._can_reenter + return self._can_send_eth # endregion ################################################################################### @@ -468,17 +490,17 @@ class Node(SourceMapping, ChildFunction): ################################################################################### @property - def expression(self): + def expression(self) -> Optional[Expression]: """ Expression: Expression of the node """ return self._expression - def add_expression(self, expression): - assert self._expression is None + def add_expression(self, expression: Expression, bypass_verif_empty: bool = False): + assert self._expression is None or bypass_verif_empty self._expression = expression - def add_variable_declaration(self, var): + def add_variable_declaration(self, var: LocalVariable): assert self._variable_declaration is None self._variable_declaration = var if var.expression: @@ -486,7 +508,7 @@ class Node(SourceMapping, ChildFunction): self._local_vars_written += [var] @property - def variable_declaration(self): + def variable_declaration(self) -> Optional[LocalVariable]: """ Returns: LocalVariable @@ -500,15 +522,18 @@ class Node(SourceMapping, ChildFunction): ################################################################################### ################################################################################### - def contains_require_or_assert(self): + def contains_require_or_assert(self) -> bool: """ Check if the node has a require or assert call Returns: bool: True if the node has a require or assert call """ - return any(c.name in ['require(bool)', 'require(bool,string)', 'assert(bool)'] for c in self.internal_calls) + return any( + c.name in ["require(bool)", "require(bool,string)", "assert(bool)"] + for c in self.internal_calls + ) - def contains_if(self, include_loop=True): + def contains_if(self, include_loop=True) -> bool: """ Check if the node is a IF node Returns: @@ -518,7 +543,7 @@ class Node(SourceMapping, ChildFunction): return self.type in [NodeType.IF, NodeType.IFLOOP] return self.type == NodeType.IF - def is_conditional(self, include_loop=True): + def is_conditional(self, include_loop=True) -> bool: """ Check if the node is a conditional node A conditional node is either a IF or a require/assert or a RETURN bool @@ -532,7 +557,7 @@ class Node(SourceMapping, ChildFunction): if last_ir: if isinstance(last_ir, Return): for r in last_ir.read: - if r.type == ElementaryType('bool'): + if r.type == ElementaryType("bool"): return True return False @@ -544,10 +569,10 @@ class Node(SourceMapping, ChildFunction): ################################################################################### @property - def inline_asm(self): + def inline_asm(self) -> Optional[Union[str, Dict]]: return self._asm_source_code - def add_inline_asm(self, asm): + def add_inline_asm(self, asm: Union[str, Dict]): self._asm_source_code = asm # endregion @@ -557,7 +582,7 @@ class Node(SourceMapping, ChildFunction): ################################################################################### ################################################################################### - def add_father(self, father): + def add_father(self, father: "Node"): """ Add a father node Args: @@ -565,7 +590,7 @@ class Node(SourceMapping, ChildFunction): """ self._fathers.append(father) - def set_fathers(self, fathers): + def set_fathers(self, fathers: List["Node"]): """ Set the father nodes Args: @@ -574,7 +599,7 @@ class Node(SourceMapping, ChildFunction): self._fathers = fathers @property - def fathers(self): + def fathers(self) -> List["Node"]: """ Returns the father nodes Returns: @@ -582,23 +607,23 @@ class Node(SourceMapping, ChildFunction): """ return list(self._fathers) - def remove_father(self, father): + def remove_father(self, father: "Node"): """ Remove the father node. Do nothing if the node is not a father Args: - fathers: list of fathers to add + :param father: """ self._fathers = [x for x in self._fathers if x.node_id != father.node_id] - def remove_son(self, son): + def remove_son(self, son: "Node"): """ Remove the son node. Do nothing if the node is not a son Args: - fathers: list of fathers to add + :param son: """ self._sons = [x for x in self._sons if x.node_id != son.node_id] - def add_son(self, son): + def add_son(self, son: "Node"): """ Add a son node Args: @@ -606,7 +631,7 @@ class Node(SourceMapping, ChildFunction): """ self._sons.append(son) - def set_sons(self, sons): + def set_sons(self, sons: List["Node"]): """ Set the son nodes Args: @@ -615,7 +640,7 @@ class Node(SourceMapping, ChildFunction): self._sons = sons @property - def sons(self): + def sons(self) -> List["Node"]: """ Returns the son nodes Returns: @@ -627,15 +652,13 @@ class Node(SourceMapping, ChildFunction): def son_true(self) -> Optional["Node"]: if self.type == NodeType.IF: return self._sons[0] - else: - return None + return None @property def son_false(self) -> Optional["Node"]: if self.type == NodeType.IF and len(self._sons) >= 1: return self._sons[1] - else: - return None + return None # endregion ################################################################################### @@ -645,7 +668,7 @@ class Node(SourceMapping, ChildFunction): ################################################################################### @property - def irs(self): + def irs(self) -> List[Operation]: """ Returns the slithIR representation return @@ -654,7 +677,7 @@ class Node(SourceMapping, ChildFunction): return self._irs @property - def irs_ssa(self): + def irs_ssa(self) -> List[Operation]: """ Returns the slithIR representation with SSA return @@ -666,10 +689,10 @@ class Node(SourceMapping, ChildFunction): def irs_ssa(self, irs): self._irs_ssa = irs - def add_ssa_ir(self, ir): - ''' + def add_ssa_ir(self, ir: Operation): + """ Use to place phi operation - ''' + """ ir.set_node(self) self._irs_ssa.append(ir) @@ -680,12 +703,19 @@ class Node(SourceMapping, ChildFunction): self._find_read_write_call() + def all_slithir_operations(self) -> List[Operation]: + irs = self.irs + for ir in irs: + if isinstance(ir, InternalCall): + irs += ir.function.all_slithir_operations() + return irs + @staticmethod - def _is_non_slithir_var(var): + def _is_non_slithir_var(var: Variable): return not isinstance(var, (Constant, ReferenceVariable, TemporaryVariable, TupleVariable)) @staticmethod - def _is_valid_slithir_var(var): + def _is_valid_slithir_var(var: Variable): return isinstance(var, (ReferenceVariable, TemporaryVariable, TupleVariable)) # endregion @@ -696,44 +726,64 @@ class Node(SourceMapping, ChildFunction): ################################################################################### @property - def dominators(self): - ''' + def dominators(self) -> Set["Node"]: + """ Returns: set(Node) - ''' + """ return self._dominators + @dominators.setter + def dominators(self, dom: Set["Node"]): + self._dominators = dom + @property - def immediate_dominator(self): - ''' + def immediate_dominator(self) -> Optional["Node"]: + """ Returns: Node or None - ''' + """ return self._immediate_dominator + @immediate_dominator.setter + def immediate_dominator(self, idom: "Node"): + self._immediate_dominator = idom + @property - def dominance_frontier(self): - ''' + def dominance_frontier(self) -> Set["Node"]: + """ Returns: set(Node) - ''' + """ return self._dominance_frontier + @dominance_frontier.setter + def dominance_frontier(self, doms: Set["Node"]): + """ + Returns: + set(Node) + """ + self._dominance_frontier = doms + @property def dominator_successors(self): return self._dom_successors - @dominators.setter - def dominators(self, dom): - self._dominators = dom - - @immediate_dominator.setter - def immediate_dominator(self, idom): - self._immediate_dominator = idom + @property + def dominance_exploration_ordered(self) -> List["Node"]: + """ + Sorted list of all the nodes to explore to follow the dom + :return: list(nodes) + """ + # Explore direct dominance + to_explore = sorted(list(self.dominator_successors), key=lambda x: x.node_id) - @dominance_frontier.setter - def dominance_frontier(self, dom): - self._dominance_frontier = dom + # Explore dominance frontier + # The frontier is the limit where this node dominates + # We need to explore it because the sub of the direct dominance + # Might not be dominator of their own sub + to_explore += sorted(list(self.dominance_frontier), key=lambda x: x.node_id) + return to_explore # endregion ################################################################################### @@ -743,27 +793,41 @@ class Node(SourceMapping, ChildFunction): ################################################################################### @property - def phi_origins_local_variables(self): + def phi_origins_local_variables(self) -> Dict[str, Tuple[LocalVariable, Set["Node"]]]: return self._phi_origins_local_variables @property - def phi_origins_state_variables(self): + def phi_origins_state_variables(self) -> Dict[str, Tuple[StateVariable, Set["Node"]]]: return self._phi_origins_state_variables - def add_phi_origin_local_variable(self, variable, node): + # @property + # def phi_origin_member_variables(self) -> Dict[str, Tuple[MemberVariable, Set["Node"]]]: + # return self._phi_origins_member_variables + + def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node"): if variable.name not in self._phi_origins_local_variables: self._phi_origins_local_variables[variable.name] = (variable, set()) (v, nodes) = self._phi_origins_local_variables[variable.name] assert v == variable nodes.add(node) - def add_phi_origin_state_variable(self, variable, node): + def add_phi_origin_state_variable(self, variable: StateVariable, node: "Node"): if variable.canonical_name not in self._phi_origins_state_variables: - self._phi_origins_state_variables[variable.canonical_name] = (variable, set()) + self._phi_origins_state_variables[variable.canonical_name] = ( + variable, + set(), + ) (v, nodes) = self._phi_origins_state_variables[variable.canonical_name] assert v == variable nodes.add(node) + # def add_phi_origin_member_variable(self, variable: MemberVariable, node: "Node"): + # if variable.name not in self._phi_origins_member_variables: + # self._phi_origins_member_variables[variable.name] = (variable, set()) + # (v, nodes) = self._phi_origins_member_variables[variable.name] + # assert v == variable + # nodes.add(node) + # endregion ################################################################################### ################################################################################### @@ -784,13 +848,13 @@ class Node(SourceMapping, ChildFunction): if not isinstance(ir, (Phi, Index, Member)): self._vars_read += [v for v in ir.read if self._is_non_slithir_var(v)] for var in ir.read: - if isinstance(var, (ReferenceVariable)): + if isinstance(var, ReferenceVariable): self._vars_read.append(var.points_to_origin) elif isinstance(ir, (Member, Index)): var = ir.variable_left if isinstance(ir, Member) else ir.variable_right if self._is_non_slithir_var(var): self._vars_read.append(var) - if isinstance(var, (ReferenceVariable)): + if isinstance(var, ReferenceVariable): origin = var.points_to_origin if self._is_non_slithir_var(origin): self._vars_read.append(origin) @@ -799,7 +863,7 @@ class Node(SourceMapping, ChildFunction): if isinstance(ir, (Index, Member, Length, Balance)): continue # Don't consider Member and Index operations -> ReferenceVariable var = ir.lvalue - if isinstance(var, (ReferenceVariable)): + if isinstance(var, ReferenceVariable): var = var.points_to_origin if var and self._is_non_slithir_var(var): self._vars_written.append(var) @@ -816,14 +880,15 @@ class Node(SourceMapping, ChildFunction): elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall): if isinstance(ir.destination.type, Contract): self._high_level_calls.append((ir.destination.type, ir.function)) - elif ir.destination == SolidityVariable('this'): + elif ir.destination == SolidityVariable("this"): self._high_level_calls.append((self.function.contract, ir.function)) else: try: self._high_level_calls.append((ir.destination.type.type, ir.function)) except AttributeError: raise SlitherException( - f'Function not found on {ir}. Please try compiling with a recent Solidity version.') + f"Function not found on {ir}. Please try compiling with a recent Solidity version." + ) elif isinstance(ir, LibraryCall): assert isinstance(ir.destination, Contract) self._high_level_calls.append((ir.destination, ir.function)) @@ -843,7 +908,7 @@ class Node(SourceMapping, ChildFunction): self._low_level_calls = list(set(self._low_level_calls)) @staticmethod - def _convert_ssa(v): + def _convert_ssa(v: Variable): if isinstance(v, StateIRVariable): contract = v.contract non_ssa_var = contract.get_state_variable_from_name(v.name) @@ -857,14 +922,14 @@ class Node(SourceMapping, ChildFunction): if not self.expression: return for ir in self.irs_ssa: - if isinstance(ir, (PhiCallback)): + if isinstance(ir, PhiCallback): continue if not isinstance(ir, (Phi, Index, Member)): - self._ssa_vars_read += [v for v in ir.read if isinstance(v, - (StateIRVariable, - LocalIRVariable))] + self._ssa_vars_read += [ + v for v in ir.read if isinstance(v, (StateIRVariable, LocalIRVariable)) + ] for var in ir.read: - if isinstance(var, (ReferenceVariable)): + if isinstance(var, ReferenceVariable): origin = var.points_to_origin if isinstance(origin, (StateIRVariable, LocalIRVariable)): self._ssa_vars_read.append(origin) @@ -872,7 +937,7 @@ class Node(SourceMapping, ChildFunction): elif isinstance(ir, (Member, Index)): if isinstance(ir.variable_right, (StateIRVariable, LocalIRVariable)): self._ssa_vars_read.append(ir.variable_right) - if isinstance(ir.variable_right, (ReferenceVariable)): + if isinstance(ir.variable_right, ReferenceVariable): origin = ir.variable_right.points_to_origin if isinstance(origin, (StateIRVariable, LocalIRVariable)): self._ssa_vars_read.append(origin) @@ -881,19 +946,23 @@ class Node(SourceMapping, ChildFunction): if isinstance(ir, (Index, Member, Length, Balance)): continue # Don't consider Member and Index operations -> ReferenceVariable var = ir.lvalue - if isinstance(var, (ReferenceVariable)): + if isinstance(var, ReferenceVariable): var = var.points_to_origin # Only store non-slithIR variables if var and isinstance(var, (StateIRVariable, LocalIRVariable)): - if isinstance(ir, (PhiCallback)): + if isinstance(ir, PhiCallback): continue self._ssa_vars_written.append(var) self._ssa_vars_read = list(set(self._ssa_vars_read)) self._ssa_state_vars_read = [v for v in self._ssa_vars_read if isinstance(v, StateVariable)] self._ssa_local_vars_read = [v for v in self._ssa_vars_read if isinstance(v, LocalVariable)] self._ssa_vars_written = list(set(self._ssa_vars_written)) - self._ssa_state_vars_written = [v for v in self._ssa_vars_written if isinstance(v, StateVariable)] - self._ssa_local_vars_written = [v for v in self._ssa_vars_written if isinstance(v, LocalVariable)] + self._ssa_state_vars_written = [ + v for v in self._ssa_vars_written if isinstance(v, StateVariable) + ] + self._ssa_local_vars_written = [ + v for v in self._ssa_vars_written if isinstance(v, LocalVariable) + ] vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read] vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written] @@ -919,7 +988,50 @@ class Node(SourceMapping, ChildFunction): additional_info += ' ' + str(self.expression) elif self.variable_declaration: additional_info += ' ' + str(self.variable_declaration) - txt = NodeType.str(self._node_type) + additional_info + txt = str(self._node_type) + additional_info return txt - # endregion + +# endregion +################################################################################### +################################################################################### +# region Utils +################################################################################### +################################################################################### + + +def link_nodes(node1: Node, node2: Node): + node1.add_son(node2) + node2.add_father(node1) + + +def insert_node(origin: Node, node_inserted: Node): + sons = origin.sons + link_nodes(origin, node_inserted) + for son in sons: + son.remove_father(origin) + origin.remove_son(son) + + link_nodes(node_inserted, son) + + +def recheable(node: Node) -> Set[Node]: + """ + Return the set of nodes reacheable from the node + :param node: + :return: set(Node) + """ + nodes = node.sons + visited = set() + while nodes: + next = nodes[0] + nodes = nodes[1:] + if next not in visited: + visited.add(next) + for son in next.sons: + if son not in visited: + nodes.append(son) + return visited + + +# endregion diff --git a/slither/core/children/child_contract.py b/slither/core/children/child_contract.py index 9ca39af8e..7d6ce0a86 100644 --- a/slither/core/children/child_contract.py +++ b/slither/core/children/child_contract.py @@ -1,14 +1,17 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from slither.core.declarations import Contract -class ChildContract: +class ChildContract: def __init__(self): super(ChildContract, self).__init__() self._contract = None - def set_contract(self, contract): + def set_contract(self, contract: "Contract"): self._contract = contract @property - def contract(self): + def contract(self) -> "Contract": return self._contract - diff --git a/slither/core/children/child_event.py b/slither/core/children/child_event.py index 184c2c779..21a35319b 100644 --- a/slither/core/children/child_event.py +++ b/slither/core/children/child_event.py @@ -1,12 +1,17 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from slither.core.declarations import Event + class ChildEvent: def __init__(self): super(ChildEvent, self).__init__() self._event = None - def set_event(self, event): + def set_event(self, event: "Event"): self._event = event @property - def event(self): + def event(self) -> "Event": return self._event diff --git a/slither/core/children/child_expression.py b/slither/core/children/child_expression.py index f918e7a52..bbfcadd22 100644 --- a/slither/core/children/child_expression.py +++ b/slither/core/children/child_expression.py @@ -1,12 +1,17 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from slither.core.expressions.expression import Expression + class ChildExpression: def __init__(self): super(ChildExpression, self).__init__() self._expression = None - def set_expression(self, expression): + def set_expression(self, expression: "Expression"): self._expression = expression @property - def expression(self): + def expression(self) -> "Expression": return self._expression diff --git a/slither/core/children/child_function.py b/slither/core/children/child_function.py index 2c5dc72ac..3ace993f3 100644 --- a/slither/core/children/child_function.py +++ b/slither/core/children/child_function.py @@ -1,12 +1,17 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from slither.core.declarations import Function + class ChildFunction: def __init__(self): super(ChildFunction, self).__init__() self._function = None - def set_function(self, function): + def set_function(self, function: "Function"): self._function = function @property - def function(self): + def function(self) -> "Function": return self._function diff --git a/slither/core/children/child_inheritance.py b/slither/core/children/child_inheritance.py index cc9c4065f..089401442 100644 --- a/slither/core/children/child_inheritance.py +++ b/slither/core/children/child_inheritance.py @@ -1,13 +1,17 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from slither.core.declarations import Contract -class ChildInheritance: +class ChildInheritance: def __init__(self): super(ChildInheritance, self).__init__() self._contract_declarer = None - def set_contract_declarer(self, contract): + def set_contract_declarer(self, contract: "Contract"): self._contract_declarer = contract @property - def contract_declarer(self): + def contract_declarer(self) -> "Contract": return self._contract_declarer diff --git a/slither/core/children/child_node.py b/slither/core/children/child_node.py index bd6fd4e6f..a90acfe48 100644 --- a/slither/core/children/child_node.py +++ b/slither/core/children/child_node.py @@ -1,24 +1,31 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from slither import Slither + from slither.core.cfg.node import Node + from slither.core.declarations import Function, Contract + class ChildNode(object): def __init__(self): super(ChildNode, self).__init__() self._node = None - def set_node(self, node): + def set_node(self, node: "Node"): self._node = node @property - def node(self): + def node(self) -> "Node": return self._node @property - def function(self): + def function(self) -> "Function": return self.node.function @property - def contract(self): + def contract(self) -> "Contract": return self.node.function.contract @property - def slither(self): - return self.contract.slither \ No newline at end of file + def slither(self) -> "Slither": + return self.contract.slither diff --git a/slither/core/children/child_slither.py b/slither/core/children/child_slither.py index 5c11664c1..d7bffc889 100644 --- a/slither/core/children/child_slither.py +++ b/slither/core/children/child_slither.py @@ -1,13 +1,17 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from slither import Slither -class ChildSlither: +class ChildSlither: def __init__(self): super(ChildSlither, self).__init__() self._slither = None - def set_slither(self, slither): + def set_slither(self, slither: "Slither"): self._slither = slither @property - def slither(self): + def slither(self) -> "Slither": return self._slither diff --git a/slither/core/children/child_structure.py b/slither/core/children/child_structure.py index f5fc34cc4..6fd3aa7ad 100644 --- a/slither/core/children/child_structure.py +++ b/slither/core/children/child_structure.py @@ -1,13 +1,17 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from slither.core.declarations import Structure -class ChildStructure: +class ChildStructure: def __init__(self): super(ChildStructure, self).__init__() self._structure = None - def set_structure(self, structure): + def set_structure(self, structure: "Structure"): self._structure = structure @property - def structure(self): + def structure(self) -> "Structure": return self._structure diff --git a/slither/core/context/context.py b/slither/core/context/context.py index e5966ed0a..d16178a58 100644 --- a/slither/core/context/context.py +++ b/slither/core/context/context.py @@ -1,14 +1,15 @@ -class Context: +from collections import defaultdict +from typing import Dict + +class Context: def __init__(self): super(Context, self).__init__() - self._context = {} + self._context = {"MEMBERS": defaultdict(None)} @property - def context(self): - ''' + def context(self) -> Dict: + """ Dict used by analysis - ''' + """ return self._context - - diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index f7ad861cd..992e2d853 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -5,16 +5,33 @@ import logging from pathlib import Path from crytic_compile.platform import Type as PlatformType +from typing import Optional, List, Dict, Callable, Tuple, TYPE_CHECKING, Union from slither.core.children.child_slither import ChildSlither +from slither.core.solidity_types.type import Type from slither.core.source_mapping.source_mapping import SourceMapping -from slither.core.declarations.function import Function -from slither.utils.erc import ERC20_signatures, \ - ERC165_signatures, ERC223_signatures, ERC721_signatures, \ - ERC1820_signatures, ERC777_signatures + +from slither.core.declarations.function import Function, FunctionType +from slither.utils.erc import ( + ERC20_signatures, + ERC165_signatures, + ERC223_signatures, + ERC721_signatures, + ERC1820_signatures, + ERC777_signatures, +) from slither.utils.tests_pattern import is_test_contract -logger = logging.getLogger("Contract") +if TYPE_CHECKING: + from slither.utils.type_helpers import LibraryCallType, HighLevelCallType + from slither.core.declarations import Enum, Event, Modifier + from slither.core.declarations import Structure + from slither.slithir.variables.variable import SlithIRVariable + from slither.core.variables.variable import Variable + from slither.core.variables.state_variable import StateVariable + +LOGGER = logging.getLogger("Contract") + class Contract(ChildSlither, SourceMapping): """ @@ -23,40 +40,41 @@ class Contract(ChildSlither, SourceMapping): def __init__(self): super(Contract, self).__init__() - self._name = None - self._name = None - self._id = None - self._inheritance = [] # all contract inherited, c3 linearization - self._immediate_inheritance = [] # immediate inheritance + self._name: Optional[str] = None + self._id: Optional[int] = None + self._inheritance: List["Contract"] = [] # all contract inherited, c3 linearization + self._immediate_inheritance: List["Contract"] = [] # immediate inheritance # Constructors called on contract's definition # contract B is A(1) { .. - self._explicit_base_constructor_calls = [] + self._explicit_base_constructor_calls: List["Contract"] = [] - self._enums = {} - self._structures = {} - self._events = {} - self._variables = {} - self._variables_ordered = [] # contain also shadowed variables - self._modifiers = {} - self._functions = {} + self._enums: Dict[str, "Enum"] = {} + self._structures: Dict[str, "Structure"] = {} + self._events: Dict[str, "Event"] = {} + self._variables: Dict[str, "StateVariable"] = {} + self._variables_ordered: List["StateVariable"] = [] # contain also shadowed variables + self._modifiers: Dict[str, "Modifier"] = {} + self._functions: Dict[str, "Function"] = {} + self._linearizedBaseContracts = List[int] - self._using_for = {} - self._kind = None + # The only str is "*" + self._using_for: Dict[Union[str, Type], List[str]] = {} + self._kind: Optional[str] = None + self._is_interface: bool = False - self._signatures = None - self._signatures_declared = None + self._signatures: Optional[List[str]] = None + self._signatures_declared: Optional[List[str]] = None - self._is_upgradeable = None - self._is_upgradeable_proxy = None + self._is_upgradeable: Optional[bool] = None + self._is_upgradeable_proxy: Optional[bool] = None self._is_top_level = False + self._initial_state_variables: List["StateVariable"] = [] # ssa - self._initial_state_variables = [] # ssa - - self._is_incorrectly_parsed = False + self._is_incorrectly_parsed: bool = False ################################################################################### ################################################################################### @@ -65,19 +83,46 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### @property - def name(self): + def name(self) -> str: """str: Name of the contract.""" + assert self._name return self._name + @name.setter + def name(self, name: str): + self._name = name + @property - def id(self): + def id(self) -> int: """Unique id.""" + assert self._id return self._id + @id.setter + def id(self, new_id): + """Unique id.""" + self._id = new_id + @property - def contract_kind(self): + def contract_kind(self) -> Optional[str]: + """ + contract_kind can be None if the legacy ast format is used + :return: + """ return self._kind + @contract_kind.setter + def contract_kind(self, kind): + self._kind = kind + + @property + def is_interface(self) -> bool: + return self._is_interface + + @is_interface.setter + def is_interface(self, is_interface: bool): + self._is_interface = is_interface + # endregion ################################################################################### ################################################################################### @@ -86,27 +131,28 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### @property - def structures(self): - ''' + def structures(self) -> List["Structure"]: + """ list(Structure): List of the structures - ''' + """ return list(self._structures.values()) @property - def structures_inherited(self): - ''' + def structures_inherited(self) -> List["Structure"]: + """ list(Structure): List of the inherited structures - ''' + """ return [s for s in self.structures if s.contract != self] @property - def structures_declared(self): - ''' + def structures_declared(self) -> List["Structure"]: + """ list(Structues): List of the structures declared within the contract (not inherited) - ''' + """ return [s for s in self.structures if s.contract == self] - def structures_as_dict(self): + @property + def structures_as_dict(self) -> Dict[str, "Structure"]: return self._structures # endregion @@ -117,24 +163,25 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### @property - def enums(self): + def enums(self) -> List["Enum"]: return list(self._enums.values()) @property - def enums_inherited(self): - ''' + def enums_inherited(self) -> List["Enum"]: + """ list(Enum): List of the inherited enums - ''' + """ return [e for e in self.enums if e.contract != self] @property - def enums_declared(self): - ''' + def enums_declared(self) -> List["Enum"]: + """ list(Enum): List of the enums declared within the contract (not inherited) - ''' + """ return [e for e in self.enums if e.contract == self] - def enums_as_dict(self): + @property + def enums_as_dict(self) -> Dict[str, "Enum"]: return self._enums # endregion @@ -145,27 +192,28 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### @property - def events(self): - ''' + def events(self) -> List["Event"]: + """ list(Event): List of the events - ''' + """ return list(self._events.values()) @property - def events_inherited(self): - ''' + def events_inherited(self) -> List["Event"]: + """ list(Event): List of the inherited events - ''' + """ return [e for e in self.events if e.contract != self] @property - def events_declared(self): - ''' + def events_declared(self) -> List["Event"]: + """ list(Event): List of the events declared within the contract (not inherited) - ''' + """ return [e for e in self.events if e.contract == self] - def events_as_dict(self): + @property + def events_as_dict(self) -> Dict[str, "Event"]: return self._events # endregion @@ -176,16 +224,9 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### @property - def using_for(self): + def using_for(self) -> Dict[Union[str, Type], List[str]]: return self._using_for - def reverse_using_for(self, name): - ''' - Returns: - (list) - ''' - return self._using_for[name] - # endregion ################################################################################### ################################################################################### @@ -194,49 +235,53 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### @property - def variables(self): - ''' + def variables(self) -> List["StateVariable"]: + """ list(StateVariable): List of the state variables. Alias to self.state_variables - ''' + """ return list(self.state_variables) - def variables_as_dict(self): + @property + def variables_as_dict(self) -> Dict[str, "StateVariable"]: return self._variables @property - def state_variables(self): - ''' + def state_variables(self) -> List["StateVariable"]: + """ list(StateVariable): List of the state variables. - ''' + """ return list(self._variables.values()) @property - def state_variables_ordered(self): - ''' + def state_variables_ordered(self) -> List["StateVariable"]: + """ list(StateVariable): List of the state variables by order of declaration. Contains also shadowed variables - ''' + """ return list(self._variables_ordered) + def add_variables_ordered(self, new_vars: List["StateVariable"]): + self._variables_ordered += new_vars + @property - def state_variables_inherited(self): - ''' + def state_variables_inherited(self) -> List["StateVariable"]: + """ list(StateVariable): List of the inherited state variables - ''' + """ return [s for s in self.state_variables if s.contract != self] @property - def state_variables_declared(self): - ''' + def state_variables_declared(self) -> List["StateVariable"]: + """ list(StateVariable): List of the state variables declared within the contract (not inherited) - ''' + """ return [s for s in self.state_variables if s.contract == self] @property - def slithir_variables(self): - ''' + def slithir_variables(self) -> List["SlithIRVariable"]: + """ List all of the slithir variables (non SSA) - ''' - slithir_variables = [f.slithir_variables for f in self.functions + self.modifiers] + """ + slithir_variables = [f.slithir_variables for f in self.functions + self.modifiers] # type: ignore slithir_variables = [item for sublist in slithir_variables for item in sublist] return list(set(slithir_variables)) @@ -248,13 +293,13 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### @property - def constructor(self): - ''' + def constructor(self) -> Optional["Function"]: + """ Return the contract's immediate constructor. If there is no immediate constructor, returns the first constructor executed, following the c3 linearization Return None if there is no constructor. - ''' + """ cst = self.constructors_declared if cst: return cst @@ -265,18 +310,25 @@ class Contract(ChildSlither, SourceMapping): return None @property - def constructors_declared(self): - return next((func for func in self.functions if func.is_constructor and func.contract_declarer == self), None) + def constructors_declared(self) -> Optional["Function"]: + return next( + ( + func + for func in self.functions + if func.is_constructor and func.contract_declarer == self + ), + None, + ) @property - def constructors(self): - ''' + def constructors(self) -> List["Function"]: + """ Return the list of constructors (including inherited) - ''' + """ return [func for func in self.functions if func.is_constructor] @property - def explicit_base_constructor_calls(self): + def explicit_base_constructor_calls(self) -> List["Function"]: """ list(Function): List of the base constructors called explicitly by this contract definition. @@ -296,110 +348,148 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### @property - def functions_signatures(self): + def functions_signatures(self) -> List[str]: """ Return the signatures of all the public/eterxnal functions/state variables :return: list(string) the signatures of all the functions that can be called """ if self._signatures is None: - sigs = [v.full_name for v in self.state_variables if v.visibility in ['public', - 'external']] + sigs = [ + v.full_name for v in self.state_variables if v.visibility in ["public", "external"] + ] - sigs += set([f.full_name for f in self.functions if f.visibility in ['public', 'external']]) + sigs += set( + [f.full_name for f in self.functions if f.visibility in ["public", "external"]] + ) self._signatures = list(set(sigs)) return self._signatures @property - def functions_signatures_declared(self): + def functions_signatures_declared(self) -> List[str]: """ Return the signatures of the public/eterxnal functions/state variables that are declared by this contract :return: list(string) the signatures of all the functions that can be called and are declared by this contract """ if self._signatures_declared is None: - sigs = [v.full_name for v in self.state_variables_declared if v.visibility in ['public', - 'external']] - - sigs += set([f.full_name for f in self.functions_declared if f.visibility in ['public', 'external']]) + sigs = [ + v.full_name + for v in self.state_variables_declared + if v.visibility in ["public", "external"] + ] + + sigs += set( + [ + f.full_name + for f in self.functions_declared + if f.visibility in ["public", "external"] + ] + ) self._signatures_declared = list(set(sigs)) return self._signatures_declared @property - def functions(self): - ''' + def functions(self) -> List["Function"]: + """ list(Function): List of the functions - ''' + """ return list(self._functions.values()) - def available_functions_as_dict(self): + def available_functions_as_dict(self) -> Dict[str, "Function"]: return {f.full_name: f for f in self._functions.values() if not f.is_shadowed} + def set_functions(self, functions: Dict[str, "Function"]): + """ + Set the functions + + :param functions: dict full_name -> function + :return: + """ + self._functions = functions + @property - def functions_inherited(self): - ''' + def functions_inherited(self) -> List["Function"]: + """ list(Function): List of the inherited functions - ''' + """ return [f for f in self.functions if f.contract_declarer != self] @property - def functions_declared(self): - ''' + def functions_declared(self) -> List["Function"]: + """ list(Function): List of the functions defined within the contract (not inherited) - ''' + """ return [f for f in self.functions if f.contract_declarer == self] @property - def functions_entry_points(self): - ''' + def functions_entry_points(self) -> List["Function"]: + """ list(Functions): List of public and external functions - ''' - return [f for f in self.functions if f.visibility in ['public', 'external'] and not f.is_shadowed] + """ + return [ + f + for f in self.functions + if f.visibility in ["public", "external"] and not f.is_shadowed + ] @property - def modifiers(self): - ''' + def modifiers(self) -> List["Modifier"]: + """ list(Modifier): List of the modifiers - ''' + """ return list(self._modifiers.values()) - def available_modifiers_as_dict(self): + def available_modifiers_as_dict(self) -> Dict[str, "Modifier"]: return {m.full_name: m for m in self._modifiers.values() if not m.is_shadowed} + def set_modifiers(self, modifiers: Dict[str, "Modifier"]): + """ + Set the modifiers + + :param modifiers: dict full_name -> modifier + :return: + """ + self._modifiers = modifiers + @property - def modifiers_inherited(self): - ''' + def modifiers_inherited(self) -> List["Modifier"]: + """ list(Modifier): List of the inherited modifiers - ''' + """ return [m for m in self.modifiers if m.contract_declarer != self] @property - def modifiers_declared(self): - ''' + def modifiers_declared(self) -> List["Modifier"]: + """ list(Modifier): List of the modifiers defined within the contract (not inherited) - ''' + """ return [m for m in self.modifiers if m.contract_declarer == self] @property - def functions_and_modifiers(self): - ''' + def functions_and_modifiers(self) -> List["Function"]: + """ list(Function|Modifier): List of the functions and modifiers - ''' - return self.functions + self.modifiers + """ + return self.functions + self.modifiers # type: ignore @property - def functions_and_modifiers_inherited(self): - ''' + def functions_and_modifiers_inherited(self) -> List["Function"]: + """ list(Function|Modifier): List of the inherited functions and modifiers - ''' - return self.functions_inherited + self.modifiers_inherited + """ + return self.functions_inherited + self.modifiers_inherited # type: ignore @property - def functions_and_modifiers_declared(self): - ''' + def functions_and_modifiers_declared(self) -> List["Function"]: + """ list(Function|Modifier): List of the functions and modifiers defined within the contract (not inherited) - ''' - return self.functions_declared + self.modifiers_declared + """ + return self.functions_declared + self.modifiers_declared # type: ignore - def available_elements_from_inheritances(self, elements, getter_available): + def available_elements_from_inheritances( + self, + elements: Dict[str, "Function"], + getter_available: Callable[["Contract"], List["Function"]], + ) -> Dict[str, "Function"]: """ :param elements: dict(canonical_name -> elements) @@ -409,12 +499,15 @@ class Contract(ChildSlither, SourceMapping): # keep track of the contracts visited # to prevent an ovveride due to multiple inheritance of the same contract # A is B, C, D is C, --> the second C was already seen - inherited_elements = {} + inherited_elements: Dict[str, "Function"] = {} accessible_elements = {} contracts_visited = [] for father in self.inheritance_reverse: - functions = {v.full_name: v for (v) in getter_available(father) - if not v.contract in contracts_visited} + functions: Dict[str, "Function"] = { + v.full_name: v + for v in getter_available(father) + if v.contract not in contracts_visited + } contracts_visited.append(father) inherited_elements.update(functions) @@ -423,7 +516,6 @@ class Contract(ChildSlither, SourceMapping): return accessible_elements - # endregion ################################################################################### ################################################################################### @@ -432,36 +524,41 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### @property - def inheritance(self): - ''' + def inheritance(self) -> List["Contract"]: + """ list(Contract): Inheritance list. Order: the first elem is the first father to be executed - ''' + """ return list(self._inheritance) @property - def immediate_inheritance(self): - ''' + def immediate_inheritance(self) -> List["Contract"]: + """ list(Contract): List of contracts immediately inherited from (fathers). Order: order of declaration. - ''' + """ return list(self._immediate_inheritance) @property - def inheritance_reverse(self): - ''' + def inheritance_reverse(self) -> List["Contract"]: + """ list(Contract): Inheritance list. Order: the last elem is the first father to be executed - ''' - return reversed(self._inheritance) - - def setInheritance(self, inheritance, immediate_inheritance, called_base_constructor_contracts): + """ + return list(reversed(self._inheritance)) + + def set_inheritance( + self, + inheritance: List["Contract"], + immediate_inheritance: List["Contract"], + called_base_constructor_contracts: List["Contract"], + ): self._inheritance = inheritance self._immediate_inheritance = immediate_inheritance self._explicit_base_constructor_calls = called_base_constructor_contracts @property - def derived_contracts(self): - ''' + def derived_contracts(self) -> List["Contract"]: + """ list(Contract): Return the list of contracts derived from self - ''' + """ candidates = self.slither.contracts return [c for c in candidates if self in c.inheritance] @@ -472,19 +569,19 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### ################################################################################### - def get_functions_reading_from_variable(self, variable): - ''' + def get_functions_reading_from_variable(self, variable: "Variable") -> List["Function"]: + """ Return the functions reading the variable - ''' + """ return [f for f in self.functions if f.is_reading(variable)] - def get_functions_writing_to_variable(self, variable): - ''' + def get_functions_writing_to_variable(self, variable: "Variable") -> List["Function"]: + """ Return the functions writting the variable - ''' + """ return [f for f in self.functions if f.is_writing(variable)] - def get_function_from_signature(self, function_signature): + def get_function_from_signature(self, function_signature: str) -> Optional["Function"]: """ Return a function from a signature Args: @@ -492,19 +589,23 @@ class Contract(ChildSlither, SourceMapping): Returns: Function """ - return next((f for f in self.functions if f.full_name == function_signature and not f.is_shadowed), None) + return next( + (f for f in self.functions if f.full_name == function_signature and not f.is_shadowed), + None, + ) - def get_modifier_from_signature(self, modifier_signature): + def get_modifier_from_signature(self, modifier_signature: str) -> Optional["Modifier"]: """ Return a modifier from a signature - Args: - modifier_name (str): signature of the modifier - Returns: - Modifier + + :param modifier_signature: """ - return next((m for m in self.modifiers if m.full_name == modifier_signature and not m.is_shadowed), None) + return next( + (m for m in self.modifiers if m.full_name == modifier_signature and not m.is_shadowed), + None, + ) - def get_function_from_canonical_name(self, canonical_name): + def get_function_from_canonical_name(self, canonical_name: str) -> Optional["Function"]: """ Return a function from a a canonical name (contract.signature()) Args: @@ -514,7 +615,7 @@ class Contract(ChildSlither, SourceMapping): """ return next((f for f in self.functions if f.canonical_name == canonical_name), None) - def get_modifier_from_canonical_name(self, canonical_name): + def get_modifier_from_canonical_name(self, canonical_name: str) -> Optional["Modifier"]: """ Return a modifier from a canonical name (contract.signature()) Args: @@ -524,18 +625,17 @@ class Contract(ChildSlither, SourceMapping): """ return next((m for m in self.modifiers if m.canonical_name == canonical_name), None) - - def get_state_variable_from_name(self, variable_name): + def get_state_variable_from_name(self, variable_name: str) -> Optional["StateVariable"]: """ Return a state variable from a name - Args: - varible_name (str): name of the variable - Returns: - StateVariable + + :param variable_name: """ return next((v for v in self.state_variables if v.name == variable_name), None) - def get_state_variable_from_canonical_name(self, canonical_name): + def get_state_variable_from_canonical_name( + self, canonical_name: str + ) -> Optional["StateVariable"]: """ Return a state variable from a canonical_name Args: @@ -545,7 +645,7 @@ class Contract(ChildSlither, SourceMapping): """ return next((v for v in self.state_variables if v.name == canonical_name), None) - def get_structure_from_name(self, structure_name): + def get_structure_from_name(self, structure_name: str) -> Optional["Structure"]: """ Return a structure from a name Args: @@ -555,7 +655,7 @@ class Contract(ChildSlither, SourceMapping): """ return next((st for st in self.structures if st.name == structure_name), None) - def get_structure_from_canonical_name(self, structure_name): + def get_structure_from_canonical_name(self, structure_name: str) -> Optional["Structure"]: """ Return a structure from a canonical name Args: @@ -565,7 +665,7 @@ class Contract(ChildSlither, SourceMapping): """ return next((st for st in self.structures if st.canonical_name == structure_name), None) - def get_event_from_signature(self, event_signature): + def get_event_from_signature(self, event_signature: str) -> Optional["Event"]: """ Return an event from a signature Args: @@ -575,7 +675,7 @@ class Contract(ChildSlither, SourceMapping): """ return next((e for e in self.events if e.full_name == event_signature), None) - def get_event_from_canonical_name(self, event_canonical_name): + def get_event_from_canonical_name(self, event_canonical_name: str) -> Optional["Event"]: """ Return an event from a canonical name Args: @@ -585,7 +685,7 @@ class Contract(ChildSlither, SourceMapping): """ return next((e for e in self.events if e.canonical_name == event_canonical_name), None) - def get_enum_from_name(self, enum_name): + def get_enum_from_name(self, enum_name: str) -> Optional["Enum"]: """ Return an enum from a name Args: @@ -595,7 +695,7 @@ class Contract(ChildSlither, SourceMapping): """ return next((e for e in self.enums if e.name == enum_name), None) - def get_enum_from_canonical_name(self, enum_name): + def get_enum_from_canonical_name(self, enum_name) -> Optional["Enum"]: """ Return an enum from a canonical name Args: @@ -605,17 +705,17 @@ class Contract(ChildSlither, SourceMapping): """ return next((e for e in self.enums if e.canonical_name == enum_name), None) - def get_functions_overridden_by(self, function): - ''' + def get_functions_overridden_by(self, function: "Function") -> List["Function"]: + """ Return the list of functions overriden by the function Args: (core.Function) Returns: list(core.Function) - ''' - candidates = [c.functions_declared for c in self.inheritance] - candidates = [candidate for sublist in candidates for candidate in sublist] + """ + candidatess = [c.functions_declared for c in self.inheritance] + candidates = [candidate for sublist in candidatess for candidate in sublist] return [f for f in candidates if f.full_name == function.full_name] # endregion @@ -626,60 +726,67 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### @property - def all_functions_called(self): - ''' + def all_functions_called(self) -> List["Function"]: + """ list(Function): List of functions reachable from the contract Includes super, and private/internal functions not shadowed - ''' - all_calls = [f for f in self.functions + self.modifiers if not f.is_shadowed] - all_calls = [f.all_internal_calls() for f in all_calls] + [all_calls] - all_calls = [item for sublist in all_calls for item in sublist] + """ + all_calls = [f for f in self.functions + self.modifiers if not f.is_shadowed] # type: ignore + all_callss = [f.all_internal_calls() for f in all_calls] + [all_calls] + all_calls = [item for sublist in all_callss for item in sublist] all_calls = list(set(all_calls)) - all_constructors = [c.constructor for c in self.inheritance] - all_constructors = list(set([c for c in all_constructors if c])) + all_constructors = [c.constructor for c in self.inheritance if c.constructor] + all_constructors = list(set(all_constructors)) - all_calls = set(all_calls+all_constructors) + set_all_calls = set(all_calls + all_constructors) - return [c for c in all_calls if isinstance(c, Function)] + return [c for c in set_all_calls if isinstance(c, Function)] @property - def all_state_variables_written(self): - ''' + def all_state_variables_written(self) -> List["StateVariable"]: + """ list(StateVariable): List all of the state variables written - ''' - all_state_variables_written = [f.all_state_variables_written() for f in self.functions + self.modifiers] - all_state_variables_written = [item for sublist in all_state_variables_written for item in sublist] + """ + all_state_variables_written = [ + f.all_state_variables_written() for f in self.functions + self.modifiers # type: ignore + ] + all_state_variables_written = [ + item for sublist in all_state_variables_written for item in sublist + ] return list(set(all_state_variables_written)) @property - def all_state_variables_read(self): - ''' + def all_state_variables_read(self) -> List["StateVariable"]: + """ list(StateVariable): List all of the state variables read - ''' - all_state_variables_read = [f.all_state_variables_read() for f in self.functions + self.modifiers] - all_state_variables_read = [item for sublist in all_state_variables_read for item in sublist] + """ + all_state_variables_read = [ + f.all_state_variables_read() for f in self.functions + self.modifiers # type: ignore + ] + all_state_variables_read = [ + item for sublist in all_state_variables_read for item in sublist + ] return list(set(all_state_variables_read)) @property - def all_library_calls(self): - ''' + def all_library_calls(self) -> List["LibraryCallType"]: + """ list((Contract, Function): List all of the libraries func called - ''' - all_high_level_calls = [f.all_library_calls() for f in self.functions + self.modifiers] + """ + all_high_level_calls = [f.all_library_calls() for f in self.functions + self.modifiers] # type: ignore all_high_level_calls = [item for sublist in all_high_level_calls for item in sublist] return list(set(all_high_level_calls)) @property - def all_high_level_calls(self): - ''' + def all_high_level_calls(self) -> List["HighLevelCallType"]: + """ list((Contract, Function|Variable)): List all of the external high level calls - ''' - all_high_level_calls = [f.all_high_level_calls() for f in self.functions + self.modifiers] + """ + all_high_level_calls = [f.all_high_level_calls() for f in self.functions + self.modifiers] # type: ignore all_high_level_calls = [item for sublist in all_high_level_calls for item in sublist] return list(set(all_high_level_calls)) - # endregion ################################################################################### ################################################################################### @@ -687,18 +794,30 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### ################################################################################### - def get_summary(self, include_shadowed=True): + def get_summary( + self, include_shadowed=True + ) -> Tuple[str, List[str], List[str], List[str], List[str]]: """ Return the function summary :param include_shadowed: boolean to indicate if shadowed functions should be included (default True) Returns: (str, list, list, list, list): (name, inheritance, variables, fuction summaries, modifier summaries) """ - func_summaries = [f.get_summary() for f in self.functions if (not f.is_shadowed or include_shadowed)] - modif_summaries = [f.get_summary() for f in self.modifiers if (not f.is_shadowed or include_shadowed)] - return (self.name, [str(x) for x in self.inheritance], [str(x) for x in self.variables], func_summaries, modif_summaries) - - def is_signature_only(self): + func_summaries = [ + f.get_summary() for f in self.functions if (not f.is_shadowed or include_shadowed) + ] + modif_summaries = [ + f.get_summary() for f in self.modifiers if (not f.is_shadowed or include_shadowed) + ] + return ( + self.name, + [str(x) for x in self.inheritance], + [str(x) for x in self.variables], + func_summaries, + modif_summaries, + ) + + def is_signature_only(self) -> bool: """ Detect if the contract has only abstract functions Returns: @@ -713,21 +832,23 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### ################################################################################### - def ercs(self): + def ercs(self) -> List[str]: """ Return the ERC implemented :return: list of string """ - all = [('ERC20', lambda x: x.is_erc20()), - ('ERC165', lambda x: x.is_erc165()), - ('ERC1820', lambda x: x.is_erc1820()), - ('ERC223', lambda x: x.is_erc223()), - ('ERC721', lambda x: x.is_erc721()), - ('ERC777', lambda x: x.is_erc777())] - - return [erc[0] for erc in all if erc[1](self)] - - def is_erc20(self): + all_erc = [ + ("ERC20", lambda x: x.is_erc20()), + ("ERC165", lambda x: x.is_erc165()), + ("ERC1820", lambda x: x.is_erc1820()), + ("ERC223", lambda x: x.is_erc223()), + ("ERC721", lambda x: x.is_erc721()), + ("ERC777", lambda x: x.is_erc777()), + ] + + return [erc[0] for erc in all_erc if erc[1](self)] + + def is_erc20(self) -> bool: """ Check if the contract is an erc20 token @@ -737,7 +858,7 @@ class Contract(ChildSlither, SourceMapping): full_names = self.functions_signatures return all((s in full_names for s in ERC20_signatures)) - def is_erc165(self): + def is_erc165(self) -> bool: """ Check if the contract is an erc165 token @@ -747,7 +868,7 @@ class Contract(ChildSlither, SourceMapping): full_names = self.functions_signatures return all((s in full_names for s in ERC165_signatures)) - def is_erc1820(self): + def is_erc1820(self) -> bool: """ Check if the contract is an erc1820 @@ -757,7 +878,7 @@ class Contract(ChildSlither, SourceMapping): full_names = self.functions_signatures return all((s in full_names for s in ERC1820_signatures)) - def is_erc223(self): + def is_erc223(self) -> bool: """ Check if the contract is an erc223 token @@ -767,7 +888,7 @@ class Contract(ChildSlither, SourceMapping): full_names = self.functions_signatures return all((s in full_names for s in ERC223_signatures)) - def is_erc721(self): + def is_erc721(self) -> bool: """ Check if the contract is an erc721 token @@ -777,7 +898,7 @@ class Contract(ChildSlither, SourceMapping): full_names = self.functions_signatures return all((s in full_names for s in ERC721_signatures)) - def is_erc777(self): + def is_erc777(self) -> bool: """ Check if the contract is an erc777 @@ -793,34 +914,44 @@ class Contract(ChildSlither, SourceMapping): Check if the contract follows one of the standard ERC token :return: """ - return self.is_erc20() or self.is_erc721() or self.is_erc165() or self.is_erc223() or self.is_erc777() - - def is_possible_erc20(self): + return ( + self.is_erc20() + or self.is_erc721() + or self.is_erc165() + or self.is_erc223() + or self.is_erc777() + ) + + def is_possible_erc20(self) -> bool: """ Checks if the provided contract could be attempting to implement ERC20 standards. - :param contract: The contract to check for token compatibility. + :return: Returns a boolean indicating if the provided contract met the token standard. """ # We do not check for all the functions, as name(), symbol(), might give too many FPs full_names = self.functions_signatures - return 'transfer(address,uint256)' in full_names or \ - 'transferFrom(address,address,uint256)' in full_names or \ - 'approve(address,uint256)' in full_names + return ( + "transfer(address,uint256)" in full_names + or "transferFrom(address,address,uint256)" in full_names + or "approve(address,uint256)" in full_names + ) - def is_possible_erc721(self): + def is_possible_erc721(self) -> bool: """ Checks if the provided contract could be attempting to implement ERC721 standards. - :param contract: The contract to check for token compatibility. + :return: Returns a boolean indicating if the provided contract met the token standard. """ # We do not check for all the functions, as name(), symbol(), might give too many FPs full_names = self.functions_signatures - return ('ownerOf(uint256)' in full_names or - 'safeTransferFrom(address,address,uint256,bytes)' in full_names or - 'safeTransferFrom(address,address,uint256)' in full_names or - 'setApprovalForAll(address,bool)' in full_names or - 'getApproved(uint256)' in full_names or - 'isApprovedForAll(address,address)' in full_names) + return ( + "ownerOf(uint256)" in full_names + or "safeTransferFrom(address,address,uint256,bytes)" in full_names + or "safeTransferFrom(address,address,uint256)" in full_names + or "setApprovalForAll(address,bool)" in full_names + or "getApproved(uint256)" in full_names + or "isApprovedForAll(address,address)" in full_names + ) @property def is_possible_token(self) -> bool: @@ -837,10 +968,10 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### ################################################################################### - def is_from_dependency(self): + def is_from_dependency(self) -> bool: if self.slither.crytic_compile is None: return False - return self.slither.crytic_compile.is_dependency(self.source_mapping['filename_absolute']) + return self.slither.crytic_compile.is_dependency(self.source_mapping["filename_absolute"]) # endregion ################################################################################### @@ -857,10 +988,10 @@ class Contract(ChildSlither, SourceMapping): """ if self.slither.crytic_compile: if self.slither.crytic_compile.platform == PlatformType.TRUFFLE: - if self.name == 'Migrations': - paths = Path(self.source_mapping['filename_absolute']).parts + if self.name == "Migrations": + paths = Path(self.source_mapping["filename_absolute"]).parts if len(paths) >= 2: - return paths[-2] == 'contracts' and paths[-1] == 'migrations.sol' + return paths[-2] == "contracts" and paths[-1] == "migrations.sol" return False @property @@ -886,38 +1017,39 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### @property - def is_upgradeable(self): + def is_upgradeable(self) -> bool: if self._is_upgradeable is None: self._is_upgradeable = False - initializable = self.slither.get_contract_from_name('Initializable') + initializable = self.slither.get_contract_from_name("Initializable") if initializable: if initializable in self.inheritance: self._is_upgradeable = True else: for c in self.inheritance + [self]: # This might lead to false positive - if 'upgradeable' in c.name.lower() or 'upgradable' in c.name.lower(): + if "upgradeable" in c.name.lower() or "upgradable" in c.name.lower(): self._is_upgradeable = True break return self._is_upgradeable @property - def is_upgradeable_proxy(self): + def is_upgradeable_proxy(self) -> bool: from slither.core.cfg.node import NodeType from slither.slithir.operations import LowLevelCall + if self._is_upgradeable_proxy is None: self._is_upgradeable_proxy = False for f in self.functions: if f.is_fallback: for node in f.all_nodes(): for ir in node.irs: - if isinstance(ir, LowLevelCall) and ir.function_name == 'delegatecall': + if isinstance(ir, LowLevelCall) and ir.function_name == "delegatecall": self._is_upgradeable_proxy = True return self._is_upgradeable_proxy if node.type == NodeType.ASSEMBLY: inline_asm = node.inline_asm if inline_asm: - if 'delegatecall' in inline_asm: + if "delegatecall" in inline_asm: self._is_upgradeable_proxy = True return self._is_upgradeable_proxy return self._is_upgradeable_proxy @@ -930,13 +1062,140 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### @property - def is_incorrectly_constructed(self): + def is_incorrectly_constructed(self) -> bool: """ Return true if there was an internal Slither's issue when analyzing the contract :return: """ return self._is_incorrectly_parsed + @is_incorrectly_constructed.setter + def is_incorrectly_constructed(self, incorrect: bool): + self._is_incorrectly_parsed = incorrect + + def add_constructor_variables(self): + if self.state_variables: + for (idx, variable_candidate) in enumerate(self.state_variables): + if variable_candidate.expression and not variable_candidate.is_constant: + + constructor_variable = Function() + constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES) + constructor_variable.set_contract(self) + constructor_variable.set_contract_declarer(self) + constructor_variable.set_visibility("internal") + # For now, source mapping of the constructor variable is the whole contract + # Could be improved with a targeted source mapping + constructor_variable.set_offset(self.source_mapping, self.slither) + self._functions[constructor_variable.canonical_name] = constructor_variable + + prev_node = self._create_node(constructor_variable, 0, variable_candidate) + variable_candidate.node_initialization = prev_node + counter = 1 + for v in self.state_variables[idx + 1 :]: + if v.expression and not v.is_constant: + next_node = self._create_node(constructor_variable, counter, v) + v.node_initialization = next_node + prev_node.add_son(next_node) + next_node.add_father(prev_node) + counter += 1 + break + + for (idx, variable_candidate) in enumerate(self.state_variables): + if variable_candidate.expression and variable_candidate.is_constant: + + constructor_variable = Function() + constructor_variable.set_function_type( + FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES + ) + constructor_variable.set_contract(self) + constructor_variable.set_contract_declarer(self) + constructor_variable.set_visibility("internal") + # For now, source mapping of the constructor variable is the whole contract + # Could be improved with a targeted source mapping + constructor_variable.set_offset(self.source_mapping, self.slither) + self._functions[constructor_variable.canonical_name] = constructor_variable + + prev_node = self._create_node(constructor_variable, 0, variable_candidate) + variable_candidate.node_initialization = prev_node + counter = 1 + for v in self.state_variables[idx + 1 :]: + if v.expression and v.is_constant: + next_node = self._create_node(constructor_variable, counter, v) + v.node_initialization = next_node + prev_node.add_son(next_node) + next_node.add_father(prev_node) + counter += 1 + + break + + def _create_node(self, func: Function, counter: int, variable: "Variable"): + from slither.core.cfg.node import Node, NodeType + from slither.core.expressions import ( + AssignmentOperationType, + AssignmentOperation, + Identifier, + ) + + # Function uses to create node for state variable declaration statements + node = Node(NodeType.OTHER_ENTRYPOINT, counter) + node.set_offset(variable.source_mapping, self.slither) + node.set_function(func) + func.add_node(node) + expression = AssignmentOperation( + Identifier(variable), variable.expression, AssignmentOperationType.ASSIGN, variable.type + ) + + expression.set_offset(variable.source_mapping, self.slither) + node.add_expression(expression) + return node + + # endregion + ################################################################################### + ################################################################################### + # region SlithIR + ################################################################################### + ################################################################################### + + def convert_expression_to_slithir_ssa(self): + """ + Assume generate_slithir_and_analyze was called on all functions + + :return: + """ + from slither.slithir.variables import StateIRVariable + + all_ssa_state_variables_instances = dict() + + for contract in self.inheritance: + for v in contract.state_variables_declared: + new_var = StateIRVariable(v) + all_ssa_state_variables_instances[v.canonical_name] = new_var + self._initial_state_variables.append(new_var) + + for v in self.variables: + if v.contract == self: + new_var = StateIRVariable(v) + all_ssa_state_variables_instances[v.canonical_name] = new_var + self._initial_state_variables.append(new_var) + + for func in self.functions + self.modifiers: + func.generate_slithir_ssa(all_ssa_state_variables_instances) + + def fix_phi(self): + last_state_variables_instances = dict() + initial_state_variables_instances = dict() + for v in self._initial_state_variables: + last_state_variables_instances[v.canonical_name] = [] + initial_state_variables_instances[v.canonical_name] = v + + for func in self.functions + self.modifiers: + result = func.get_last_ssa_state_variables_instances() + for variable_name, instances in result.items(): + last_state_variables_instances[variable_name] += instances + + for func in self.functions + self.modifiers: + func.fix_phi(last_state_variables_instances, initial_state_variables_instances) + @property def is_top_level(self) -> bool: """ @@ -970,4 +1229,7 @@ class Contract(ChildSlither, SourceMapping): def __str__(self): return self.name + def __hash__(self): + return self._id + # endregion diff --git a/slither/core/declarations/enum.py b/slither/core/declarations/enum.py index ace9b2095..ae306f266 100644 --- a/slither/core/declarations/enum.py +++ b/slither/core/declarations/enum.py @@ -1,25 +1,32 @@ +from typing import List, TYPE_CHECKING + from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.children.child_contract import ChildContract +if TYPE_CHECKING: + from slither.core.declarations import Contract + + class Enum(ChildContract, SourceMapping): - def __init__(self, name, canonical_name, values): + def __init__(self, name: str, canonical_name: str, values: List[str]): + super().__init__() self._name = name self._canonical_name = canonical_name self._values = values @property - def canonical_name(self): + def canonical_name(self) -> str: return self._canonical_name @property - def name(self): + def name(self) -> str: return self._name @property - def values(self): + def values(self) -> List[str]: return self._values - def is_declared_by(self, contract): + def is_declared_by(self, contract: "Contract") -> bool: """ Check if the element is declared by the contract :param contract: diff --git a/slither/core/declarations/event.py b/slither/core/declarations/event.py index 7d4eeeaf7..99ab268dd 100644 --- a/slither/core/declarations/event.py +++ b/slither/core/declarations/event.py @@ -1,47 +1,57 @@ +from typing import List, Tuple, TYPE_CHECKING + from slither.core.children.child_contract import ChildContract from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.event_variable import EventVariable -class Event(ChildContract, SourceMapping): +if TYPE_CHECKING: + from slither.core.declarations import Contract + +class Event(ChildContract, SourceMapping): def __init__(self): super(Event, self).__init__() self._name = None - self._elems = [] + self._elems: List[EventVariable] = [] @property - def name(self): + def name(self) -> str: return self._name + @name.setter + def name(self, name: str): + self._name = name + @property - def signature(self): - ''' Return the function signature + def signature(self) -> Tuple[str, List[str]]: + """ Return the function signature Returns: (str, list(str)): name, list parameters type - ''' + """ return self.name, [str(x.type) for x in self.elems] @property - def full_name(self): - ''' Return the function signature as a str + def full_name(self) -> str: + """ Return the function signature as a str Returns: str: func_name(type1,type2) - ''' + """ name, parameters = self.signature - return name+'('+','.join(parameters)+')' + return name + "(" + ",".join(parameters) + ")" @property - def canonical_name(self): - ''' Return the function signature as a str + def canonical_name(self) -> str: + """ Return the function signature as a str Returns: str: contract.func_name(type1,type2) - ''' + """ return self.contract.name + self.full_name @property - def elems(self): + def elems(self) -> List["EventVariable"]: return self._elems - def is_declared_by(self, contract): + def is_declared_by(self, contract: "Contract") -> bool: """ Check if the element is declared by the contract :param contract: diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index f9e0b9eda..d28dabb2c 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -5,50 +5,71 @@ import logging from collections import namedtuple from enum import Enum from itertools import groupby +from typing import Dict, TYPE_CHECKING, List, Optional, Set, Union, Callable, Tuple from slither.core.children.child_contract import ChildContract from slither.core.children.child_inheritance import ChildInheritance -from slither.core.declarations.solidity_variables import (SolidityFunction, - SolidityVariable, - SolidityVariableComposed) -from slither.core.expressions import (Identifier, IndexAccess, MemberAccess, - UnaryOperation) +from slither.core.declarations.solidity_variables import ( + SolidityFunction, + SolidityVariable, + SolidityVariableComposed, +) +from slither.core.expressions import Identifier, IndexAccess, MemberAccess, UnaryOperation from slither.core.solidity_types import UserDefinedType from slither.core.solidity_types.type import Type from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.local_variable import LocalVariable + from slither.core.variables.state_variable import StateVariable from slither.utils.utils import unroll -logger = logging.getLogger("Function") - -ReacheableNode = namedtuple('ReacheableNode', ['node', 'ir']) +if TYPE_CHECKING: + from slither.utils.type_helpers import ( + InternalCallType, + LowLevelCallType, + HighLevelCallType, + LibraryCallType, + ) + from slither.core.declarations import Contract + from slither.core.cfg.node import Node, NodeType + from slither.core.variables.variable import Variable + from slither.slithir.variables.variable import SlithIRVariable + from slither.slithir.variables import LocalIRVariable + from slither.core.expressions.expression import Expression + from slither.slithir.operations import Operation + from slither.slither import Slither + from slither.core.cfg.node import NodeType + +LOGGER = logging.getLogger("Function") +ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"]) class ModifierStatements: - - def __init__(self, modifier, entry_point, nodes): + def __init__( + self, modifier: Union["Contract", "Function"], entry_point: "Node", nodes: List["Node"] + ): self._modifier = modifier self._entry_point = entry_point self._nodes = nodes @property - def modifier(self): + def modifier(self) -> Union["Contract", "Function"]: return self._modifier @property - def entry_point(self): + def entry_point(self) -> "Node": return self._entry_point @entry_point.setter - def entry_point(self, entry_point): + def entry_point(self, entry_point: "Node"): self._entry_point = entry_point @property - def nodes(self): + def nodes(self) -> List["Node"]: return self._nodes @nodes.setter - def nodes(self, nodes): + def nodes(self, nodes: List["Node"]): self._nodes = nodes @@ -68,83 +89,89 @@ class Function(ChildContract, ChildInheritance, SourceMapping): def __init__(self): super(Function, self).__init__() - self._name = None - self._view = None - self._pure = None - self._payable = None - self._visibility = None - - self._is_implemented = None - self._is_empty = None - self._entry_point = None - self._nodes = [] - self._variables = {} - self._slithir_variables = set() # slithir Temporary and references variables (but not SSA) - self._parameters = [] - self._parameters_ssa = [] - self._parameters_src = None - self._returns = [] - self._returns_ssa = [] - self._returns_src = None - self._return_values = None - self._return_values_ssa = None - self._vars_read = [] - self._vars_written = [] - self._state_vars_read = [] - self._vars_read_or_written = [] - self._solidity_vars_read = [] - self._state_vars_written = [] - self._internal_calls = [] - self._solidity_calls = [] - self._low_level_calls = [] - self._high_level_calls = [] - self._library_calls = [] - self._external_calls_as_expressions = [] - self._expression_vars_read = [] - self._expression_vars_written = [] - self._expression_calls = [] - self._expression_modifiers = [] - self._modifiers = [] - self._explicit_base_constructor_calls = [] - self._payable = False - self._contains_assembly = False - - self._expressions = None - self._slithir_operations = None - self._slithir_ssa_operations = None - - self._all_expressions = None - self._all_slithir_operations = None - self._all_internals_calls = None - self._all_high_level_calls = None - self._all_library_calls = None - self._all_low_level_calls = None - self._all_solidity_calls = None - self._all_state_variables_read = None - self._all_solidity_variables_read = None - self._all_state_variables_written = None - self._all_slithir_variables = None - self._all_nodes = None - self._all_conditional_state_variables_read = None - self._all_conditional_state_variables_read_with_loop = None - self._all_conditional_solidity_variables_read = None - self._all_conditional_solidity_variables_read_with_loop = None - self._all_solidity_variables_used_as_args = None - - self._is_shadowed = False - self._shadows = False + self._name: Optional[str] = None + self._view: bool = False + self._pure: bool = False + self._payable: bool = False + self._visibility: Optional[str] = None + + self._is_implemented: Optional[bool] = None + self._is_empty: Optional[bool] = None + self._entry_point: Optional["Node"] = None + self._nodes: List["Node"] = [] + self._variables: Dict[str, "LocalVariable"] = {} + # slithir Temporary and references variables (but not SSA) + self._slithir_variables: Set["SlithIRVariable"] = set() + self._parameters: List["LocalVariable"] = [] + self._parameters_ssa: List["LocalIRVariable"] = [] + self._parameters_src: Optional[SourceMapping] = None + self._returns: List["LocalVariable"] = [] + self._returns_ssa: List["LocalIRVariable"] = [] + self._returns_src: Optional[SourceMapping] = None + self._return_values: Optional[List["SlithIRVariable"]] = None + self._return_values_ssa: Optional[List["SlithIRVariable"]] = None + self._vars_read: List["Variable"] = [] + self._vars_written: List["Variable"] = [] + self._state_vars_read: List["StateVariable"] = [] + self._vars_read_or_written: List["Variable"] = [] + self._solidity_vars_read: List["SolidityVariable"] = [] + self._state_vars_written: List["StateVariable"] = [] + self._internal_calls: List["InternalCallType"] = [] + self._solidity_calls: List["SolidityFunction"] = [] + self._low_level_calls: List["LowLevelCallType"] = [] + self._high_level_calls: List["HighLevelCallType"] = [] + self._library_calls: List["LibraryCallType"] = [] + self._external_calls_as_expressions: List["Expression"] = [] + self._expression_vars_read: List["Expression"] = [] + self._expression_vars_written: List["Expression"] = [] + self._expression_calls: List["Expression"] = [] + # self._expression_modifiers: List["Expression"] = [] + self._modifiers: List[ModifierStatements] = [] + self._explicit_base_constructor_calls: List[ModifierStatements] = [] + self._contains_assembly: bool = False + + self._expressions: Optional[List["Expression"]] = None + self._slithir_operations: Optional[List["Operation"]] = None + self._slithir_ssa_operations: Optional[List["Operation"]] = None + + self._all_expressions: Optional[List["Expression"]] = None + self._all_slithir_operations: Optional[List["Operation"]] = None + self._all_internals_calls: Optional[List["InternalCallType"]] = None + self._all_high_level_calls: Optional[List["HighLevelCallType"]] = None + self._all_library_calls: Optional[List["LibraryCallType"]] = None + self._all_low_level_calls: Optional[List["LowLevelCallType"]] = None + self._all_solidity_calls: Optional[List["SolidityFunction"]] = None + self._all_state_variables_read: Optional[List["StateVariable"]] = None + self._all_solidity_variables_read: Optional[List["SolidityVariable"]] = None + self._all_state_variables_written: Optional[List["StateVariable"]] = None + self._all_slithir_variables: Optional[List["SlithIRVariable"]] = None + self._all_nodes: Optional[List["Node"]] = None + self._all_conditional_state_variables_read: Optional[List["StateVariable"]] = None + self._all_conditional_state_variables_read_with_loop: Optional[List["StateVariable"]] = None + self._all_conditional_solidity_variables_read: Optional[List["SolidityVariable"]] = None + self._all_conditional_solidity_variables_read_with_loop: Optional[ + List["SolidityVariable"] + ] = None + self._all_solidity_variables_used_as_args: Optional[List["SolidityVariable"]] = None + + self._is_shadowed: bool = False + self._shadows: bool = False # set(ReacheableNode) - self._reachable_from_nodes = set() - self._reachable_from_functions = set() + self._reachable_from_nodes: Set[ReacheableNode] = set() + self._reachable_from_functions: Set[ReacheableNode] = set() # Constructor, fallback, State variable constructor - self._function_type = None - self._is_constructor = None + self._function_type: Optional[FunctionType] = None + self._is_constructor: Optional[bool] = None # Computed on the fly, can be True of False - self._can_reenter = None - self._can_send_eth = None + self._can_reenter: Optional[bool] = None + self._can_send_eth: Optional[bool] = None + + self._nodes_ordered_dominators: Optional[List["Node"]] = None + + self._counter_nodes = 0 ################################################################################### ################################################################################### @@ -153,46 +180,54 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def name(self): + def name(self) -> str: """ str: function name """ - if self._name == '' and self._function_type == FunctionType.CONSTRUCTOR: - return 'constructor' + if self._name == "" and self._function_type == FunctionType.CONSTRUCTOR: + return "constructor" elif self._function_type == FunctionType.FALLBACK: - return 'fallback' + return "fallback" elif self._function_type == FunctionType.RECEIVE: - return 'receive' + return "receive" elif self._function_type == FunctionType.CONSTRUCTOR_VARIABLES: - return 'slitherConstructorVariables' + return "slitherConstructorVariables" elif self._function_type == FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES: - return 'slitherConstructorConstantVariables' + return "slitherConstructorConstantVariables" return self._name + @name.setter + def name(self, new_name: str): + self._name = new_name + @property - def full_name(self): + def full_name(self) -> str: """ str: func_name(type1,type2) Return the function signature without the return values """ name, parameters, _ = self.signature - return name + '(' + ','.join(parameters) + ')' + return name + "(" + ",".join(parameters) + ")" @property - def canonical_name(self): + def canonical_name(self) -> str: """ str: contract.func_name(type1,type2) Return the function signature without the return values """ name, parameters, _ = self.signature - return self.contract_declarer.name + '.' + name + '(' + ','.join(parameters) + ')' + return self.contract_declarer.name + "." + name + "(" + ",".join(parameters) + ")" @property - def contains_assembly(self): + def contains_assembly(self) -> bool: return self._contains_assembly - def can_reenter(self, callstack=None): - ''' + @contains_assembly.setter + def contains_assembly(self, c: bool): + self._contains_assembly = c + + def can_reenter(self, callstack=None) -> bool: + """ Check if the function can re-enter Follow internal calls. Do not consider CREATE as potential re-enter, but check if the @@ -202,8 +237,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): Do not consider Send/Transfer as there is not enough gas :param callstack: used internally to check for recursion :return bool: - ''' + """ from slither.slithir.operations import Call + if self._can_reenter is None: self._can_reenter = False for ir in self.all_slithir_operations(): @@ -212,12 +248,13 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return True return self._can_reenter - def can_send_eth(self): - ''' + def can_send_eth(self) -> bool: + """ Check if the function can send eth :return bool: - ''' + """ from slither.slithir.operations import Call + if self._can_send_eth is None: for ir in self.all_slithir_operations(): if isinstance(ir, Call) and ir.can_send_eth(): @@ -226,10 +263,10 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return self._can_reenter @property - def slither(self): + def slither(self) -> "Slither": return self.contract.slither - def is_declared_by(self, contract): + def is_declared_by(self, contract: "Contract") -> bool: """ Check if the element is declared by the contract :param contract: @@ -244,27 +281,38 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### ################################################################################### - def set_function_type(self, t): + def set_function_type(self, t: FunctionType): assert isinstance(t, FunctionType) self._function_type = t @property - def is_constructor(self): + def function_type(self) -> Optional[FunctionType]: + return self._function_type + + @function_type.setter + def function_type(self, t: FunctionType): + self._function_type = t + + @property + def is_constructor(self) -> bool: """ bool: True if the function is the constructor """ return self._function_type == FunctionType.CONSTRUCTOR @property - def is_constructor_variables(self): + def is_constructor_variables(self) -> bool: """ bool: True if the function is the constructor of the variables Slither has inbuilt functions to hold the state variables initialization """ - return self._function_type in [FunctionType.CONSTRUCTOR_VARIABLES, FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES] + return self._function_type in [ + FunctionType.CONSTRUCTOR_VARIABLES, + FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES, + ] @property - def is_fallback(self): + def is_fallback(self) -> bool: """ Determine if the function is the fallback function for the contract Returns @@ -273,7 +321,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return self._function_type == FunctionType.FALLBACK @property - def is_receive(self): + def is_receive(self) -> bool: """ Determine if the function is the receive function for the contract Returns @@ -289,12 +337,16 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def payable(self): + def payable(self) -> bool: """ bool: True if the function is payable """ return self._payable + @payable.setter + def payable(self, p: bool): + self._payable = p + # endregion ################################################################################### ################################################################################### @@ -303,31 +355,44 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def visibility(self): + def visibility(self) -> str: """ str: Function visibility """ + assert self._visibility is not None return self._visibility - def set_visibility(self, v): + @visibility.setter + def visibility(self, v: str): + self._visibility = v + + def set_visibility(self, v: str): self._visibility = v @property - def view(self): + def view(self) -> bool: """ bool: True if the function is declared as view """ return self._view + @view.setter + def view(self, v: bool): + self._view = v + @property - def pure(self): + def pure(self) -> bool: """ bool: True if the function is declared as pure """ return self._pure + @pure.setter + def pure(self, p: bool): + self._pure = p + @property - def is_shadowed(self): + def is_shadowed(self) -> bool: return self._is_shadowed @is_shadowed.setter @@ -335,11 +400,11 @@ class Function(ChildContract, ChildInheritance, SourceMapping): self._is_shadowed = is_shadowed @property - def shadows(self): + def shadows(self) -> bool: return self._shadows @shadows.setter - def shadows(self, _shadows): + def shadows(self, _shadows: bool): self._shadows = _shadows # endregion @@ -350,19 +415,27 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def is_implemented(self): + def is_implemented(self) -> bool: """ bool: True if the function is implemented """ return self._is_implemented + @is_implemented.setter + def is_implemented(self, is_impl: bool): + self._is_implemented = is_impl + @property - def is_empty(self): + def is_empty(self) -> bool: """ bool: True if the function is empty, None if the function is an interface """ return self._is_empty + @is_empty.setter + def is_empty(self, empty: bool): + self._is_empty = empty + # endregion ################################################################################### ################################################################################### @@ -371,24 +444,57 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def nodes(self): + def nodes(self) -> List["Node"]: """ list(Node): List of the nodes """ return list(self._nodes) + @nodes.setter + def nodes(self, nodes: List["Node"]): + self._nodes = nodes + @property - def entry_point(self): + def entry_point(self) -> "Node": """ Node: Entry point of the function """ return self._entry_point - def add_node(self, node): + @entry_point.setter + def entry_point(self, node: "Node"): + self._entry_point = node + + def add_node(self, node: "Node"): if not self._entry_point: self._entry_point = node self._nodes.append(node) + @property + def nodes_ordered_dominators(self) -> List["Node"]: + # TODO: does not work properly; most likely due to modifier call + # This will not work for modifier call that lead to multiple nodes + # from slither.core.cfg.node import NodeType + if self._nodes_ordered_dominators is None: + self._nodes_ordered_dominators = [] + if self.entry_point: + self._compute_nodes_ordered_dominators(self.entry_point) + + for node in self.nodes: + # if node.type == NodeType.OTHER_ENTRYPOINT: + if not node in self._nodes_ordered_dominators: + self._compute_nodes_ordered_dominators(node) + + return self._nodes_ordered_dominators + + def _compute_nodes_ordered_dominators(self, node: "Node"): + assert self._nodes_ordered_dominators is not None + if node in self._nodes_ordered_dominators: + return + self._nodes_ordered_dominators.append(node) + for dom in node.dominance_exploration_ordered: + self._compute_nodes_ordered_dominators(dom) + # endregion ################################################################################### ################################################################################### @@ -397,20 +503,23 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def parameters(self): + def parameters(self) -> List["LocalVariable"]: """ list(LocalVariable): List of the parameters """ return list(self._parameters) + def add_parameters(self, p: "LocalVariable"): + self._parameters.append(p) + @property - def parameters_ssa(self): + def parameters_ssa(self) -> List["LocalIRVariable"]: """ list(LocalIRVariable): List of the parameters (SSA form) """ return list(self._parameters_ssa) - def add_parameter_ssa(self, var): + def add_parameter_ssa(self, var: "LocalIRVariable"): self._parameters_ssa.append(var) # endregion @@ -421,7 +530,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def return_type(self): + def return_type(self) -> Optional[List[Type]]: """ Return the list of return type If no return, return None @@ -432,28 +541,32 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return None @property - def type(self): + def type(self) -> Optional[List[Type]]: """ Return the list of return type If no return, return None + Alias of return_type """ return self.return_type @property - def returns(self): + def returns(self) -> List["LocalVariable"]: """ list(LocalVariable): List of the return variables """ return list(self._returns) + def add_return(self, r: "LocalVariable"): + self._returns.append(r) + @property - def returns_ssa(self): + def returns_ssa(self) -> List["LocalIRVariable"]: """ list(LocalIRVariable): List of the return variables (SSA form) """ return list(self._returns_ssa) - def add_return_ssa(self, var): + def add_return_ssa(self, var: "LocalIRVariable"): self._returns_ssa.append(var) # endregion @@ -464,21 +577,26 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def modifiers(self): + def modifiers(self) -> List[Union["Contract", "Function"]]: """ list(Modifier): List of the modifiers + Can be contract for constructor's calls + """ return [c.modifier for c in self._modifiers] + def add_modifier(self, modif: "ModifierStatements"): + self._modifiers.append(modif) + @property - def modifiers_statements(self): + def modifiers_statements(self) -> List[ModifierStatements]: """ list(ModifierCall): List of the modifiers call (include expression and irs) """ return list(self._modifiers) @property - def explicit_base_constructor_calls(self): + def explicit_base_constructor_calls(self) -> List["Function"]: """ list(Function): List of the base constructors called explicitly by this presumed constructor definition. @@ -486,11 +604,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping): included. """ # This is a list of contracts internally, so we convert it to a list of constructor functions. - return [c.modifier.constructors_declared for c in self._explicit_base_constructor_calls if - c.modifier.constructors_declared] + return [ + c.modifier.constructors_declared + for c in self._explicit_base_constructor_calls + if c.modifier.constructors_declared + ] @property - def explicit_base_constructor_calls_statements(self): + def explicit_base_constructor_calls_statements(self) -> List[ModifierStatements]: """ list(ModifierCall): List of the base constructors called explicitly by this presumed constructor definition. @@ -498,6 +619,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): # This is a list of contracts internally, so we convert it to a list of constructor functions. return list(self._explicit_base_constructor_calls) + def add_explicit_base_constructor_calls_statements(self, modif: ModifierStatements): + self._explicit_base_constructor_calls.append(modif) + # endregion ################################################################################### ################################################################################### @@ -506,7 +630,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def variables(self): + def variables(self) -> List[LocalVariable]: """ Return all local variables Include paramters and return values @@ -514,70 +638,71 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return list(self._variables.values()) @property - def local_variables(self): + def local_variables(self) -> List[LocalVariable]: """ Return all local variables (dont include paramters and return values) """ return list(set(self.variables) - set(self.returns) - set(self.parameters)) - def variables_as_dict(self): + @property + def variables_as_dict(self) -> Dict[str, LocalVariable]: return self._variables @property - def variables_read(self): + def variables_read(self) -> List["Variable"]: """ list(Variable): Variables read (local/state/solidity) """ return list(self._vars_read) @property - def variables_written(self): + def variables_written(self) -> List["Variable"]: """ list(Variable): Variables written (local/state/solidity) """ return list(self._vars_written) @property - def state_variables_read(self): + def state_variables_read(self) -> List["StateVariable"]: """ list(StateVariable): State variables read """ return list(self._state_vars_read) @property - def solidity_variables_read(self): + def solidity_variables_read(self) -> List["SolidityVariable"]: """ list(SolidityVariable): Solidity variables read """ return list(self._solidity_vars_read) @property - def state_variables_written(self): + def state_variables_written(self) -> List["StateVariable"]: """ list(StateVariable): State variables written """ return list(self._state_vars_written) @property - def variables_read_or_written(self): + def variables_read_or_written(self) -> List["Variable"]: """ list(Variable): Variables read or written (local/state/solidity) """ return list(self._vars_read_or_written) @property - def variables_read_as_expression(self): + def variables_read_as_expression(self) -> List["Expression"]: return self._expression_vars_read @property - def variables_written_as_expression(self): + def variables_written_as_expression(self) -> List["Expression"]: return self._expression_vars_written @property - def slithir_variables(self): - ''' + def slithir_variables(self) -> List["SlithIRVariable"]: + """ Temporary and Reference Variables (not SSA form) - ''' + """ return list(self._slithir_variables) @@ -589,21 +714,21 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def internal_calls(self): + def internal_calls(self) -> List["InternalCallType"]: """ list(Function or SolidityFunction): List of function calls (that does not create a transaction) """ return list(self._internal_calls) @property - def solidity_calls(self): + def solidity_calls(self) -> List[SolidityFunction]: """ list(SolidityFunction): List of Soldity calls """ return list(self._solidity_calls) @property - def high_level_calls(self): + def high_level_calls(self) -> List["HighLevelCallType"]: """ list((Contract, Function|Variable)): List of high level calls (external calls). @@ -613,14 +738,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return list(self._high_level_calls) @property - def library_calls(self): + def library_calls(self) -> List["LibraryCallType"]: """ list((Contract, Function)): """ return list(self._library_calls) @property - def low_level_calls(self): + def low_level_calls(self) -> List["LowLevelCallType"]: """ list((Variable|SolidityVariable, str)): List of low_level call A low level call is defined by @@ -630,7 +755,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return list(self._low_level_calls) @property - def external_calls_as_expressions(self): + def external_calls_as_expressions(self) -> List["Expression"]: """ list(ExpressionCall): List of message calls (that creates a transaction) """ @@ -644,22 +769,22 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def calls_as_expressions(self): + def calls_as_expressions(self) -> List["Expression"]: return self._expression_calls @property - def expressions(self): + def expressions(self) -> List["Expression"]: """ list(Expression): List of the expressions """ if self._expressions is None: - expressions = [n.expression for n in self.nodes] - expressions = [e for e in expressions if e] + expressionss = [n.expression for n in self.nodes] + expressions = [e for e in expressionss if e] self._expressions = expressions return self._expressions @property - def return_values(self): + def return_values(self) -> List["SlithIRVariable"]: """ list(Return Values): List of the return values """ @@ -670,12 +795,19 @@ class Function(ChildContract, ChildInheritance, SourceMapping): if self._return_values is None: return_values = list() returns = [n for n in self.nodes if n.type == NodeType.RETURN] - [return_values.extend(ir.values) for node in returns for ir in node.irs if isinstance(ir, Return)] - self._return_values = list(set([x for x in return_values if not isinstance(x, Constant)])) + [ + return_values.extend(ir.values) + for node in returns + for ir in node.irs + if isinstance(ir, Return) + ] + self._return_values = list( + set([x for x in return_values if not isinstance(x, Constant)]) + ) return self._return_values @property - def return_values_ssa(self): + def return_values_ssa(self) -> List["SlithIRVariable"]: """ list(Return Values in SSA form): List of the return values in ssa form """ @@ -686,8 +818,15 @@ class Function(ChildContract, ChildInheritance, SourceMapping): if self._return_values_ssa is None: return_values_ssa = list() returns = [n for n in self.nodes if n.type == NodeType.RETURN] - [return_values_ssa.extend(ir.values) for node in returns for ir in node.irs_ssa if isinstance(ir, Return)] - self._return_values_ssa = list(set([x for x in return_values_ssa if not isinstance(x, Constant)])) + [ + return_values_ssa.extend(ir.values) + for node in returns + for ir in node.irs_ssa + if isinstance(ir, Return) + ] + self._return_values_ssa = list( + set([x for x in return_values_ssa if not isinstance(x, Constant)]) + ) return self._return_values_ssa # endregion @@ -698,24 +837,24 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def slithir_operations(self): + def slithir_operations(self) -> List["Operation"]: """ list(Operation): List of the slithir operations """ if self._slithir_operations is None: - operations = [n.irs for n in self.nodes] - operations = [item for sublist in operations for item in sublist if item] + operationss = [n.irs for n in self.nodes] + operations = [item for sublist in operationss for item in sublist if item] self._slithir_operations = operations return self._slithir_operations @property - def slithir_ssa_operations(self): + def slithir_ssa_operations(self) -> List["Operation"]: """ list(Operation): List of the slithir operations (SSA) """ if self._slithir_ssa_operations is None: - operations = [n.irs_ssa for n in self.nodes] - operations = [item for sublist in operations for item in sublist if item] + operationss = [n.irs_ssa for n in self.nodes] + operations = [item for sublist in operationss for item in sublist if item] self._slithir_ssa_operations = operations return self._slithir_ssa_operations @@ -729,6 +868,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): @staticmethod def _convert_type_for_solidity_signature(t: Type): from slither.core.declarations import Contract + if isinstance(t, UserDefinedType) and isinstance(t.type, Contract): return "address" return str(t) @@ -741,27 +881,28 @@ class Function(ChildContract, ChildInheritance, SourceMapping): :return: the solidity signature """ parameters = [self._convert_type_for_solidity_signature(x.type) for x in self.parameters] - return self.name + '(' + ','.join(parameters) + ')' - + return self.name + "(" + ",".join(parameters) + ")" @property - def signature(self): + def signature(self) -> Tuple[str, List[str], List[str]]: """ (str, list(str), list(str)): Function signature as (name, list parameters type, list return values type) """ - return (self.name, - [str(x.type) for x in self.parameters], - [str(x.type) for x in self.returns]) + return ( + self.name, + [str(x.type) for x in self.parameters], + [str(x.type) for x in self.returns], + ) @property - def signature_str(self): + def signature_str(self) -> str: """ str: func_name(type1,type2) returns (type3) Return the function signature as a str (contains the return values) """ name, parameters, returnVars = self.signature - return name + '(' + ','.join(parameters) + ') returns(' + ','.join(returnVars) + ')' + return name + "(" + ",".join(parameters) + ") returns(" + ",".join(returnVars) + ")" # endregion ################################################################################### @@ -771,13 +912,13 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def functions_shadowed(self): - ''' + def functions_shadowed(self) -> List["Function"]: + """ Return the list of functions shadowed Returns: list(core.Function) - ''' + """ candidates = [c.functions_declared for c in self.contract.inheritance] candidates = [candidate for sublist in candidates for candidate in sublist] return [f for f in candidates if f.full_name == self.full_name] @@ -790,18 +931,18 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### @property - def reachable_from_nodes(self): - ''' + def reachable_from_nodes(self) -> Set[ReacheableNode]: + """ Return ReacheableNode - ''' + """ return self._reachable_from_nodes @property - def reachable_from_functions(self): + def reachable_from_functions(self) -> Set[ReacheableNode]: return self._reachable_from_functions - def add_reachable_from_node(self, n, ir): + def add_reachable_from_node(self, n: "Node", ir: "Operation"): self._reachable_from_nodes.add(ReacheableNode(n, ir)) self._reachable_from_functions.add(n.function) @@ -812,13 +953,15 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### ################################################################################### - def _explore_functions(self, f_new_values): + def _explore_functions(self, f_new_values: Callable[["Function"], List]): values = f_new_values(self) explored = [self] - to_explore = [c for c in self.internal_calls if - isinstance(c, Function) and c not in explored] - to_explore += [c for (_, c) in self.library_calls if - isinstance(c, Function) and c not in explored] + to_explore = [ + c for c in self.internal_calls if isinstance(c, Function) and c not in explored + ] + to_explore += [ + c for (_, c) in self.library_calls if isinstance(c, Function) and c not in explored + ] to_explore += [m for m in self.modifiers if m not in explored] while to_explore: @@ -830,96 +973,104 @@ class Function(ChildContract, ChildInheritance, SourceMapping): values += f_new_values(f) - to_explore += [c for c in f.internal_calls if \ - isinstance(c, Function) and c not in explored and c not in to_explore] - to_explore += [c for (_, c) in f.library_calls if - isinstance(c, Function) and c not in explored and c not in to_explore] + to_explore += [ + c + for c in f.internal_calls + if isinstance(c, Function) and c not in explored and c not in to_explore + ] + to_explore += [ + c + for (_, c) in f.library_calls + if isinstance(c, Function) and c not in explored and c not in to_explore + ] to_explore += [m for m in f.modifiers if m not in explored and m not in to_explore] return list(set(values)) - def all_state_variables_read(self): + def all_state_variables_read(self) -> List["StateVariable"]: """ recursive version of variables_read """ if self._all_state_variables_read is None: self._all_state_variables_read = self._explore_functions( - lambda x: x.state_variables_read) + lambda x: x.state_variables_read + ) return self._all_state_variables_read - def all_solidity_variables_read(self): + def all_solidity_variables_read(self) -> List[SolidityVariable]: """ recursive version of solidity_read """ if self._all_solidity_variables_read is None: self._all_solidity_variables_read = self._explore_functions( - lambda x: x.solidity_variables_read) + lambda x: x.solidity_variables_read + ) return self._all_solidity_variables_read - def all_slithir_variables(self): + def all_slithir_variables(self) -> List["SlithIRVariable"]: """ recursive version of slithir_variables """ if self._all_slithir_variables is None: - self._all_slithir_variables = self._explore_functions( - lambda x: x.slithir_variable) + self._all_slithir_variables = self._explore_functions(lambda x: x.slithir_variables) return self._all_slithir_variables - def all_nodes(self): + def all_nodes(self) -> List["Node"]: """ recursive version of nodes """ if self._all_nodes is None: self._all_nodes = self._explore_functions(lambda x: x.nodes) return self._all_nodes - def all_expressions(self): + def all_expressions(self) -> List["Expression"]: """ recursive version of variables_read """ if self._all_expressions is None: self._all_expressions = self._explore_functions(lambda x: x.expressions) return self._all_expressions - def all_slithir_operations(self): + def all_slithir_operations(self) -> List["Operation"]: """ """ if self._all_slithir_operations is None: self._all_slithir_operations = self._explore_functions(lambda x: x.slithir_operations) return self._all_slithir_operations - def all_state_variables_written(self): + def all_state_variables_written(self) -> List[StateVariable]: """ recursive version of variables_written """ if self._all_state_variables_written is None: self._all_state_variables_written = self._explore_functions( - lambda x: x.state_variables_written) + lambda x: x.state_variables_written + ) return self._all_state_variables_written - def all_internal_calls(self): + def all_internal_calls(self) -> List["InternalCallType"]: """ recursive version of internal_calls """ if self._all_internals_calls is None: self._all_internals_calls = self._explore_functions(lambda x: x.internal_calls) return self._all_internals_calls - def all_low_level_calls(self): + def all_low_level_calls(self) -> List["LowLevelCallType"]: """ recursive version of low_level calls """ if self._all_low_level_calls is None: self._all_low_level_calls = self._explore_functions(lambda x: x.low_level_calls) return self._all_low_level_calls - def all_high_level_calls(self): + def all_high_level_calls(self) -> List["HighLevelCallType"]: """ recursive version of high_level calls """ if self._all_high_level_calls is None: self._all_high_level_calls = self._explore_functions(lambda x: x.high_level_calls) return self._all_high_level_calls - def all_library_calls(self): + def all_library_calls(self) -> List["LibraryCallType"]: """ recursive version of library calls """ if self._all_library_calls is None: self._all_library_calls = self._explore_functions(lambda x: x.library_calls) return self._all_library_calls - def all_solidity_calls(self): + def all_solidity_calls(self) -> List[SolidityFunction]: """ recursive version of solidity calls """ if self._all_solidity_calls is None: @@ -927,11 +1078,11 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return self._all_solidity_calls @staticmethod - def _explore_func_cond_read(func, include_loop): + def _explore_func_cond_read(func: "Function", include_loop: bool) -> List["StateVariable"]: ret = [n.state_variables_read for n in func.nodes if n.is_conditional(include_loop)] return [item for sublist in ret for item in sublist] - def all_conditional_state_variables_read(self, include_loop=True): + def all_conditional_state_variables_read(self, include_loop=True) -> List["StateVariable"]: """ Return the state variable used in a condition @@ -941,19 +1092,19 @@ class Function(ChildContract, ChildInheritance, SourceMapping): if include_loop: if self._all_conditional_state_variables_read_with_loop is None: self._all_conditional_state_variables_read_with_loop = self._explore_functions( - lambda x: self._explore_func_cond_read(x, - include_loop)) + lambda x: self._explore_func_cond_read(x, include_loop) + ) return self._all_conditional_state_variables_read_with_loop - else: - if self._all_conditional_state_variables_read is None: - self._all_conditional_state_variables_read = self._explore_functions( - lambda x: self._explore_func_cond_read(x, - include_loop)) - return self._all_conditional_state_variables_read + if self._all_conditional_state_variables_read is None: + self._all_conditional_state_variables_read = self._explore_functions( + lambda x: self._explore_func_cond_read(x, include_loop) + ) + return self._all_conditional_state_variables_read @staticmethod - def _solidity_variable_in_binary(node): + def _solidity_variable_in_binary(node: "Node") -> List[SolidityVariable]: from slither.slithir.operations.binary import Binary + ret = [] for ir in node.irs: if isinstance(ir, Binary): @@ -961,11 +1112,13 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return [var for var in ret if isinstance(var, SolidityVariable)] @staticmethod - def _explore_func_conditional(func, f, include_loop): + def _explore_func_conditional( + func: "Function", f: Callable[["Node"], List[SolidityVariable]], include_loop: bool + ): ret = [f(n) for n in func.nodes if n.is_conditional(include_loop)] return [item for sublist in ret for item in sublist] - def all_conditional_solidity_variables_read(self, include_loop=True): + def all_conditional_solidity_variables_read(self, include_loop=True) -> List[SolidityVariable]: """ Return the Soldiity variables directly used in a condtion @@ -976,21 +1129,24 @@ class Function(ChildContract, ChildInheritance, SourceMapping): if include_loop: if self._all_conditional_solidity_variables_read_with_loop is None: self._all_conditional_solidity_variables_read_with_loop = self._explore_functions( - lambda x: self._explore_func_conditional(x, - self._solidity_variable_in_binary, - include_loop)) + lambda x: self._explore_func_conditional( + x, self._solidity_variable_in_binary, include_loop + ) + ) return self._all_conditional_solidity_variables_read_with_loop - else: - if self._all_conditional_solidity_variables_read is None: - self._all_conditional_solidity_variables_read = self._explore_functions( - lambda x: self._explore_func_conditional(x, - self._solidity_variable_in_binary, - include_loop)) - return self._all_conditional_solidity_variables_read + + if self._all_conditional_solidity_variables_read is None: + self._all_conditional_solidity_variables_read = self._explore_functions( + lambda x: self._explore_func_conditional( + x, self._solidity_variable_in_binary, include_loop + ) + ) + return self._all_conditional_solidity_variables_read @staticmethod - def _solidity_variable_in_internal_calls(node): + def _solidity_variable_in_internal_calls(node: "Node") -> List[SolidityVariable]: from slither.slithir.operations.internal_call import InternalCall + ret = [] for ir in node.irs: if isinstance(ir, InternalCall): @@ -998,11 +1154,11 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return [var for var in ret if isinstance(var, SolidityVariable)] @staticmethod - def _explore_func_nodes(func, f): + def _explore_func_nodes(func: "Function", f: Callable[["Node"], List[SolidityVariable]]): ret = [f(n) for n in func.nodes] return [item for sublist in ret for item in sublist] - def all_solidity_variables_used_as_args(self): + def all_solidity_variables_used_as_args(self) -> List[SolidityVariable]: """ Return the Soldiity variables directly used in a call @@ -1011,7 +1167,8 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ if self._all_solidity_variables_used_as_args is None: self._all_solidity_variables_used_as_args = self._explore_functions( - lambda x: self._explore_func_nodes(x, self._solidity_variable_in_internal_calls)) + lambda x: self._explore_func_nodes(x, self._solidity_variable_in_internal_calls) + ) return self._all_solidity_variables_used_as_args # endregion @@ -1021,7 +1178,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### ################################################################################### - def apply_visitor(self, Visitor): + def apply_visitor(self, Visitor: Callable) -> List: """ Apply a visitor to all the function expressions Args: @@ -1040,11 +1197,12 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### ################################################################################### - def get_local_variable_from_name(self, variable_name): + def get_local_variable_from_name(self, variable_name: str) -> Optional[LocalVariable]: """ Return a local variable from a name + Args: - varible_name (str): name of the variable + variable_name (str): name of the variable Returns: LocalVariable """ @@ -1057,22 +1215,22 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### ################################################################################### - def cfg_to_dot(self, filename): + def cfg_to_dot(self, filename: str): """ Export the function to a dot file Args: filename (str) """ - with open(filename, 'w', encoding='utf8') as f: - f.write('digraph{\n') + with open(filename, "w", encoding="utf8") as f: + f.write("digraph{\n") for node in self.nodes: f.write('{}[label="{}"];\n'.format(node.node_id, str(node))) for son in node.sons: - f.write('{}->{};\n'.format(node.node_id, son.node_id)) + f.write("{}->{};\n".format(node.node_id, son.node_id)) f.write("}\n") - def dominator_tree_to_dot(self, filename): + def dominator_tree_to_dot(self, filename: str): """ Export the dominator tree of the function to a dot file Args: @@ -1080,29 +1238,31 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ def description(node): - desc = '{}\n'.format(node) - desc += 'id: {}'.format(node.node_id) + desc = "{}\n".format(node) + desc += "id: {}".format(node.node_id) if node.dominance_frontier: - desc += '\ndominance frontier: {}'.format([n.node_id for n in node.dominance_frontier]) + desc += "\ndominance frontier: {}".format( + [n.node_id for n in node.dominance_frontier] + ) return desc - with open(filename, 'w', encoding='utf8') as f: - f.write('digraph{\n') + with open(filename, "w", encoding="utf8") as f: + f.write("digraph{\n") for node in self.nodes: f.write('{}[label="{}"];\n'.format(node.node_id, description(node))) if node.immediate_dominator: - f.write('{}->{};\n'.format(node.immediate_dominator.node_id, node.node_id)) + f.write("{}->{};\n".format(node.immediate_dominator.node_id, node.node_id)) f.write("}\n") - def slithir_cfg_to_dot(self, filename): + def slithir_cfg_to_dot(self, filename: str): """ Export the CFG to a DOT file. The nodes includes the Solidity expressions and the IRs :param filename: :return: """ content = self.slithir_cfg_to_dot_str() - with open(filename, 'w', encoding='utf8') as f: + with open(filename, "w", encoding="utf8") as f: f.write(content) def slithir_cfg_to_dot_str(self) -> str: @@ -1112,14 +1272,15 @@ class Function(ChildContract, ChildInheritance, SourceMapping): :rtype: str """ from slither.core.cfg.node import NodeType - content = '' - content += 'digraph{\n' + + content = "" + content += "digraph{\n" for node in self.nodes: - label = 'Node Type: {} {}\n'.format(NodeType.str(node.type), node.node_id) + label = "Node Type: {} {}\n".format(str(node.type), node.node_id) if node.expression: - label += '\nEXPRESSION:\n{}\n'.format(node.expression) + label += "\nEXPRESSION:\n{}\n".format(node.expression) if node.irs: - label += '\nIRs:\n' + '\n'.join([str(ir) for ir in node.irs]) + label += "\nIRs:\n" + "\n".join([str(ir) for ir in node.irs]) content += '{}[label="{}"];\n'.format(node.node_id, label) if node.type == NodeType.IF: true_node = node.son_true @@ -1130,7 +1291,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): content += '{}->{}[label="False"];\n'.format(node.node_id, false_node.node_id) else: for son in node.sons: - content += '{}->{};\n'.format(node.node_id, son.node_id) + content += "{}->{};\n".format(node.node_id, son.node_id) content += "}\n" return content @@ -1142,7 +1303,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### ################################################################################### - def is_reading(self, variable): + def is_reading(self, variable: "Variable") -> bool: """ Check if the function reads the variable Args: @@ -1152,7 +1313,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ return variable in self.variables_read - def is_reading_in_conditional_node(self, variable): + def is_reading_in_conditional_node(self, variable: "Variable") -> bool: """ Check if the function reads the variable in a IF node Args: @@ -1160,11 +1321,11 @@ class Function(ChildContract, ChildInheritance, SourceMapping): Returns: bool: True if the variable is read """ - variables_read = [n.variables_read for n in self.nodes if n.contains_if()] - variables_read = [item for sublist in variables_read for item in sublist] + variables_reads = [n.variables_read for n in self.nodes if n.contains_if()] + variables_read = [item for sublist in variables_reads for item in sublist] return variable in variables_read - def is_reading_in_require_or_assert(self, variable): + def is_reading_in_require_or_assert(self, variable: "Variable") -> bool: """ Check if the function reads the variable in an require or assert Args: @@ -1172,11 +1333,11 @@ class Function(ChildContract, ChildInheritance, SourceMapping): Returns: bool: True if the variable is read """ - variables_read = [n.variables_read for n in self.nodes if n.contains_require_or_assert()] - variables_read = [item for sublist in variables_read for item in sublist] + variables_reads = [n.variables_read for n in self.nodes if n.contains_require_or_assert()] + variables_read = [item for sublist in variables_reads for item in sublist] return variable in variables_read - def is_writing(self, variable): + def is_writing(self, variable: "Variable") -> bool: """ Check if the function writes the variable Args: @@ -1186,21 +1347,27 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ return variable in self.variables_written - def get_summary(self): + def get_summary( + self, + ) -> Tuple[str, str, str, List[str], List[str], List[str], List[str], List[str]]: """ Return the function summary Returns: (str, str, str, list(str), list(str), listr(str), list(str), list(str); contract_name, name, visibility, modifiers, vars read, vars written, internal_calls, external_calls_as_expressions """ - return (self.contract_declarer.name, self.full_name, self.visibility, - [str(x) for x in self.modifiers], - [str(x) for x in self.state_variables_read + self.solidity_variables_read], - [str(x) for x in self.state_variables_written], - [str(x) for x in self.internal_calls], - [str(x) for x in self.external_calls_as_expressions]) + return ( + self.contract_declarer.name, + self.full_name, + self.visibility, + [str(x) for x in self.modifiers], + [str(x) for x in self.state_variables_read + self.solidity_variables_read], + [str(x) for x in self.state_variables_written], + [str(x) for x in self.internal_calls], + [str(x) for x in self.external_calls_as_expressions], + ) - def is_protected(self): + def is_protected(self) -> bool: """ Determine if the function is protected using a check on msg.sender @@ -1216,7 +1383,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return True conditional_vars = self.all_conditional_solidity_variables_read(include_loop=False) args_vars = self.all_solidity_variables_used_as_args() - return SolidityVariableComposed('msg.sender') in conditional_vars + args_vars + return SolidityVariableComposed("msg.sender") in conditional_vars + args_vars # endregion ################################################################################### @@ -1225,7 +1392,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### ################################################################################### - def _filter_state_variables_written(self, expressions): + def _filter_state_variables_written(self, expressions: List["Expression"]): ret = [] for expression in expressions: if isinstance(expression, Identifier): @@ -1247,7 +1414,10 @@ class Function(ChildContract, ChildInheritance, SourceMapping): write_var = [item for sublist in write_var for item in sublist] write_var = list(set(write_var)) # Remove dupplicate if they share the same string representation - write_var = [next(obj) for i, obj in groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x))] + write_var = [ + next(obj) + for i, obj in groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x)) + ] self._expression_vars_written = write_var write_var = [x.variables_written for x in self.nodes] @@ -1255,32 +1425,39 @@ class Function(ChildContract, ChildInheritance, SourceMapping): write_var = [item for sublist in write_var for item in sublist] write_var = list(set(write_var)) # Remove dupplicate if they share the same string representation - write_var = [next(obj) for i, obj in \ - groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x))] + write_var = [ + next(obj) + for i, obj in groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x)) + ] self._vars_written = write_var read_var = [x.variables_read_as_expression for x in self.nodes] read_var = [x for x in read_var if x] read_var = [item for sublist in read_var for item in sublist] # Remove dupplicate if they share the same string representation - read_var = [next(obj) for i, obj in \ - groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x))] + read_var = [ + next(obj) + for i, obj in groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x)) + ] self._expression_vars_read = read_var read_var = [x.variables_read for x in self.nodes] read_var = [x for x in read_var if x] read_var = [item for sublist in read_var for item in sublist] # Remove dupplicate if they share the same string representation - read_var = [next(obj) for i, obj in \ - groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x))] + read_var = [ + next(obj) + for i, obj in groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x)) + ] self._vars_read = read_var - self._state_vars_written = [x for x in self.variables_written if \ - isinstance(x, StateVariable)] - self._state_vars_read = [x for x in self.variables_read if \ - isinstance(x, (StateVariable))] - self._solidity_vars_read = [x for x in self.variables_read if \ - isinstance(x, (SolidityVariable))] + self._state_vars_written = [ + x for x in self.variables_written if isinstance(x, StateVariable) + ] + self._state_vars_read = [x for x in self.variables_read if isinstance(x, StateVariable)] + self._solidity_vars_read = [ + x for x in self.variables_read if isinstance(x, SolidityVariable) + ] self._vars_read_or_written = self._vars_written + self._vars_read @@ -1318,9 +1495,29 @@ class Function(ChildContract, ChildInheritance, SourceMapping): external_calls_as_expressions = [x.external_calls_as_expressions for x in self.nodes] external_calls_as_expressions = [x for x in external_calls_as_expressions if x] - external_calls_as_expressions = [item for sublist in external_calls_as_expressions for item in sublist] + external_calls_as_expressions = [ + item for sublist in external_calls_as_expressions for item in sublist + ] self._external_calls_as_expressions = list(set(external_calls_as_expressions)) + # endregion + ################################################################################### + ################################################################################### + # region Nodes + ################################################################################### + ################################################################################### + + def new_node(self, node_type: "NodeType", src: Union[str, Dict]) -> "Node": + from slither.core.cfg.node import Node + + node = Node(node_type, self._counter_nodes) + node.set_offset(src, self.slither) + self._counter_nodes += 1 + node.set_function(self) + self._nodes.append(node) + + return node + # endregion ################################################################################### ################################################################################### @@ -1328,7 +1525,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### ################################################################################### - def get_last_ssa_state_variables_instances(self): + def _get_last_ssa_variable_instances( + self, target_state: bool, target_local: bool + ) -> Dict[str, Set["SlithIRVariable"]]: from slither.slithir.variables import ReferenceVariable from slither.slithir.operations import OperationWithLValue from slither.core.cfg.node import NodeType @@ -1336,12 +1535,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping): if not self.is_implemented: return dict() + if self._entry_point is None: + return dict() # node, values - to_explore = [(self._entry_point, dict())] + to_explore: List[Tuple[Node, Dict]] = [(self._entry_point, dict())] # node -> values - explored = dict() + explored: Dict = dict() # name -> instances - ret = dict() + ret: Dict = dict() while to_explore: node, values = to_explore[0] @@ -1353,7 +1554,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): lvalue = ir_ssa.lvalue if isinstance(lvalue, ReferenceVariable): lvalue = lvalue.points_to_origin - if isinstance(lvalue, StateVariable): + if isinstance(lvalue, StateVariable) and target_state: + values[lvalue.canonical_name] = {lvalue} + if isinstance(lvalue, LocalVariable) and target_local: values[lvalue.canonical_name] = {lvalue} # Check for fixpoint @@ -1361,7 +1564,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): if values == explored[node]: continue for k, instances in values.items(): - if not k in explored[node]: + if k not in explored[node]: explored[node][k] = set() explored[node][k] |= instances values = explored[node] @@ -1369,7 +1572,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): explored[node] = values # Return condition - if not node.sons and node.type != NodeType.THROW: + if node.will_return: for name, instances in values.items(): if name not in ret: ret[name] = set() @@ -1380,9 +1583,16 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return ret + def get_last_ssa_state_variables_instances(self) -> Dict[str, Set["SlithIRVariable"]]: + return self._get_last_ssa_variable_instances(target_state=True, target_local=False) + + def get_last_ssa_local_variables_instances(self) -> Dict[str, Set["SlithIRVariable"]]: + return self._get_last_ssa_variable_instances(target_state=False, target_local=True) + @staticmethod - def _unchange_phi(ir): - from slither.slithir.operations import (Phi, PhiCallback) + def _unchange_phi(ir: "Operation"): + from slither.slithir.operations import Phi, PhiCallback + if not isinstance(ir, (Phi, PhiCallback)) or len(ir.rvalues) > 1: return False if not ir.rvalues: @@ -1431,8 +1641,8 @@ class Function(ChildContract, ChildInheritance, SourceMapping): def generate_slithir_ssa(self, all_ssa_state_variables_instances): from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa - from slither.core.dominators.utils import (compute_dominance_frontier, - compute_dominators) + from slither.core.dominators.utils import compute_dominance_frontier, compute_dominators + compute_dominators(self.nodes) compute_dominance_frontier(self.nodes) transform_slithir_vars_to_ssa(self) diff --git a/slither/core/declarations/import_directive.py b/slither/core/declarations/import_directive.py index 8e22eb8c9..ec235c953 100644 --- a/slither/core/declarations/import_directive.py +++ b/slither/core/declarations/import_directive.py @@ -1,13 +1,13 @@ from slither.core.source_mapping.source_mapping import SourceMapping -class Import(SourceMapping): - def __init__(self, filename): +class Import(SourceMapping): + def __init__(self, filename: str): super(Import, self).__init__() self._filename = filename @property - def filename(self): + def filename(self) -> str: return self._filename def __str__(self): diff --git a/slither/core/declarations/modifier.py b/slither/core/declarations/modifier.py index 3530305b3..9dd61a353 100644 --- a/slither/core/declarations/modifier.py +++ b/slither/core/declarations/modifier.py @@ -3,5 +3,6 @@ """ from .function import Function -class Modifier(Function): pass +class Modifier(Function): + pass diff --git a/slither/core/declarations/pragma_directive.py b/slither/core/declarations/pragma_directive.py index d8ef80b7d..d43a62200 100644 --- a/slither/core/declarations/pragma_directive.py +++ b/slither/core/declarations/pragma_directive.py @@ -1,37 +1,39 @@ +from typing import List + from slither.core.source_mapping.source_mapping import SourceMapping -class Pragma(SourceMapping): - def __init__(self, directive): +class Pragma(SourceMapping): + def __init__(self, directive: List[str]): super(Pragma, self).__init__() self._directive = directive @property - def directive(self): - ''' + def directive(self) -> List[str]: + """ list(str) - ''' + """ return self._directive @property - def version(self): - return ''.join(self.directive[1:]) + def version(self) -> str: + return "".join(self.directive[1:]) @property - def name(self): + def name(self) -> str: return self.version @property - def is_solidity_version(self): + def is_solidity_version(self) -> bool: if len(self._directive) > 0: - return self._directive[0].lower() == 'solidity' + return self._directive[0].lower() == "solidity" return False @property - def is_abi_encoder_v2(self): + def is_abi_encoder_v2(self) -> bool: if len(self._directive) == 2: - return self._directive[0] == 'experimental' and self._directive[1] == 'ABIEncoderV2' + return self._directive[0] == "experimental" and self._directive[1] == "ABIEncoderV2" return False def __str__(self): - return 'pragma '+''.join(self.directive) + return "pragma " + "".join(self.directive) diff --git a/slither/core/declarations/solidity_variables.py b/slither/core/declarations/solidity_variables.py index da85e9b12..ff26dd1c7 100644 --- a/slither/core/declarations/solidity_variables.py +++ b/slither/core/declarations/solidity_variables.py @@ -1,64 +1,73 @@ # https://solidity.readthedocs.io/en/v0.4.24/units-and-global-variables.html +from typing import List, Dict, Union + from slither.core.context.context import Context from slither.core.solidity_types import ElementaryType, TypeInformation -SOLIDITY_VARIABLES = {"now":'uint256', - "this":'address', - 'abi':'address', # to simplify the conversion, assume that abi return an address - 'msg':'', - 'tx':'', - 'block':'', - 'super':''} - -SOLIDITY_VARIABLES_COMPOSED = {"block.coinbase":"address", - "block.difficulty":"uint256", - "block.gaslimit":"uint256", - "block.number":"uint256", - "block.timestamp":"uint256", - "block.blockhash":"uint256", # alias for blockhash. It's a call - "msg.data":"bytes", - "msg.gas":"uint256", - "msg.sender":"address", - "msg.sig":"bytes4", - "msg.value":"uint256", - "tx.gasprice":"uint256", - "tx.origin":"address"} - - -SOLIDITY_FUNCTIONS = {"gasleft()":['uint256'], - "assert(bool)":[], - "require(bool)":[], - "require(bool,string)":[], - "revert()":[], - "revert(string)":[], - "addmod(uint256,uint256,uint256)":['uint256'], - "mulmod(uint256,uint256,uint256)":['uint256'], - "keccak256()":['bytes32'], - "keccak256(bytes)":['bytes32'], # Solidity 0.5 - "sha256()":['bytes32'], - "sha256(bytes)":['bytes32'], # Solidity 0.5 - "sha3()":['bytes32'], - "ripemd160()":['bytes32'], - "ripemd160(bytes)":['bytes32'], # Solidity 0.5 - "ecrecover(bytes32,uint8,bytes32,bytes32)":['address'], - "selfdestruct(address)":[], - "suicide(address)":[], - "log0(bytes32)":[], - "log1(bytes32,bytes32)":[], - "log2(bytes32,bytes32,bytes32)":[], - "log3(bytes32,bytes32,bytes32,bytes32)":[], - "blockhash(uint256)":['bytes32'], - # the following need a special handling - # as they are recognized as a SolidityVariableComposed - # and converted to a SolidityFunction by SlithIR - "this.balance()":['uint256'], - "abi.encode()":['bytes'], - "abi.encodePacked()":['bytes'], - "abi.encodeWithSelector()":["bytes"], - "abi.encodeWithSignature()":["bytes"], - # abi.decode returns an a list arbitrary types - "abi.decode()":[], - "type(address)":[]} +SOLIDITY_VARIABLES = { + "now": "uint256", + "this": "address", + "abi": "address", # to simplify the conversion, assume that abi return an address + "msg": "", + "tx": "", + "block": "", + "super": "", +} + +SOLIDITY_VARIABLES_COMPOSED = { + "block.coinbase": "address", + "block.difficulty": "uint256", + "block.gaslimit": "uint256", + "block.number": "uint256", + "block.timestamp": "uint256", + "block.blockhash": "uint256", # alias for blockhash. It's a call + "msg.data": "bytes", + "msg.gas": "uint256", + "msg.sender": "address", + "msg.sig": "bytes4", + "msg.value": "uint256", + "tx.gasprice": "uint256", + "tx.origin": "address", +} + + +SOLIDITY_FUNCTIONS: Dict[str, List[str]] = { + "gasleft()": ["uint256"], + "assert(bool)": [], + "require(bool)": [], + "require(bool,string)": [], + "revert()": [], + "revert(string)": [], + "addmod(uint256,uint256,uint256)": ["uint256"], + "mulmod(uint256,uint256,uint256)": ["uint256"], + "keccak256()": ["bytes32"], + "keccak256(bytes)": ["bytes32"], # Solidity 0.5 + "sha256()": ["bytes32"], + "sha256(bytes)": ["bytes32"], # Solidity 0.5 + "sha3()": ["bytes32"], + "ripemd160()": ["bytes32"], + "ripemd160(bytes)": ["bytes32"], # Solidity 0.5 + "ecrecover(bytes32,uint8,bytes32,bytes32)": ["address"], + "selfdestruct(address)": [], + "suicide(address)": [], + "log0(bytes32)": [], + "log1(bytes32,bytes32)": [], + "log2(bytes32,bytes32,bytes32)": [], + "log3(bytes32,bytes32,bytes32,bytes32)": [], + "blockhash(uint256)": ["bytes32"], + # the following need a special handling + # as they are recognized as a SolidityVariableComposed + # and converted to a SolidityFunction by SlithIR + "this.balance()": ["uint256"], + "abi.encode()": ["bytes"], + "abi.encodePacked()": ["bytes"], + "abi.encodeWithSelector()": ["bytes"], + "abi.encodeWithSignature()": ["bytes"], + # abi.decode returns an a list arbitrary types + "abi.decode()": [], + "type(address)": [], +} + def solidity_function_signature(name): """ @@ -70,25 +79,25 @@ def solidity_function_signature(name): Returns: str """ - return name+' returns({})'.format(','.join(SOLIDITY_FUNCTIONS[name])) + return name + " returns({})".format(",".join(SOLIDITY_FUNCTIONS[name])) -class SolidityVariable(Context): - def __init__(self, name): +class SolidityVariable(Context): + def __init__(self, name: str): super(SolidityVariable, self).__init__() self._check_name(name) self._name = name # dev function, will be removed once the code is stable - def _check_name(self, name): + def _check_name(self, name: str): assert name in SOLIDITY_VARIABLES @property - def name(self): + def name(self) -> str: return self._name @property - def type(self): + def type(self) -> ElementaryType: return ElementaryType(SOLIDITY_VARIABLES[self.name]) def __str__(self): @@ -100,19 +109,20 @@ class SolidityVariable(Context): def __hash__(self): return hash(self.name) + class SolidityVariableComposed(SolidityVariable): - def __init__(self, name): + def __init__(self, name: str): super(SolidityVariableComposed, self).__init__(name) - def _check_name(self, name): + def _check_name(self, name: str): assert name in SOLIDITY_VARIABLES_COMPOSED @property - def name(self): + def name(self) -> str: return self._name @property - def type(self): + def type(self) -> ElementaryType: return ElementaryType(SOLIDITY_VARIABLES_COMPOSED[self.name]) def __str__(self): @@ -131,25 +141,28 @@ class SolidityFunction: # https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information # As a result, we set return_type during the Ir conversion - def __init__(self, name): + def __init__(self, name: str): assert name in SOLIDITY_FUNCTIONS self._name = name - self._return_type = [ElementaryType(x) for x in SOLIDITY_FUNCTIONS[self.name]] + # Can be TypeInformation if type(address) is used + self._return_type: List[Union[TypeInformation, ElementaryType]] = [ + ElementaryType(x) for x in SOLIDITY_FUNCTIONS[self.name] + ] @property - def name(self): + def name(self) -> str: return self._name @property - def full_name(self): + def full_name(self) -> str: return self.name @property - def return_type(self): + def return_type(self) -> List[Union[TypeInformation, ElementaryType]]: return self._return_type @return_type.setter - def return_type(self, r): + def return_type(self, r: List[Union[TypeInformation, ElementaryType]]): self._return_type = r def __str__(self): diff --git a/slither/core/declarations/structure.py b/slither/core/declarations/structure.py index 087a06a1c..22bc96ac0 100644 --- a/slither/core/declarations/structure.py +++ b/slither/core/declarations/structure.py @@ -1,30 +1,43 @@ -from slither.core.source_mapping.source_mapping import SourceMapping +from typing import List, TYPE_CHECKING, Dict + from slither.core.children.child_contract import ChildContract +from slither.core.source_mapping.source_mapping import SourceMapping -from slither.core.variables.variable import Variable +if TYPE_CHECKING: + from slither.core.variables.structure_variable import StructureVariable -class Structure(ChildContract, SourceMapping): +class Structure(ChildContract, SourceMapping): def __init__(self): super(Structure, self).__init__() self._name = None self._canonical_name = None - self._elems = None + self._elems: Dict[str, "StructureVariable"] = dict() # Name of the elements in the order of declaration - self._elems_ordered = None + self._elems_ordered: List[str] = [] @property - def canonical_name(self): + def canonical_name(self) -> str: return self._canonical_name + @canonical_name.setter + def canonical_name(self, name: str): + self._canonical_name = name + @property - def name(self): + def name(self) -> str: return self._name + @name.setter + def name(self, new_name: str): + self._name = new_name + @property - def elems(self): + def elems(self) -> Dict[str, "StructureVariable"]: return self._elems + def add_elem_in_order(self, s: str): + self._elems_ordered.append(s) def is_declared_by(self, contract): """ @@ -35,12 +48,11 @@ class Structure(ChildContract, SourceMapping): return self.contract == contract @property - def elems_ordered(self): + def elems_ordered(self) -> List["StructureVariable"]: ret = [] for e in self._elems_ordered: ret.append(self._elems[e]) return ret - def __str__(self): return self.name diff --git a/slither/core/dominators/node_dominator_tree.py b/slither/core/dominators/node_dominator_tree.py index 046acc439..b97279065 100644 --- a/slither/core/dominators/node_dominator_tree.py +++ b/slither/core/dominators/node_dominator_tree.py @@ -1,37 +1,27 @@ -''' +""" Nodes of the dominator tree -''' +""" +from typing import TYPE_CHECKING, Set, List -from slither.core.children.child_function import ChildFunction +if TYPE_CHECKING: + from slither.core.cfg.node import Node -class DominatorNode(object): +class DominatorNode(object): def __init__(self): - self._succ = set() - self._nodes = [] + self._succ: Set["Node"] = set() + self._nodes: List["Node"] = [] - def add_node(self, node): + def add_node(self, node: "Node"): self._nodes.append(node) - def add_successor(self, succ): + def add_successor(self, succ: "Node"): self._succ.add(succ) @property - def cfg_nodes(self): + def cfg_nodes(self) -> List["Node"]: return self._nodes @property - def sucessors(self): - ''' - Returns: - dict(Node) - ''' + def sucessors(self) -> Set["Node"]: return self._succ - -class DominatorTree(ChildFunction): - - def __init__(self, entry_point): - super(DominatorTree, self).__init__() - - - diff --git a/slither/core/dominators/utils.py b/slither/core/dominators/utils.py index 137af5dcf..b41409a58 100644 --- a/slither/core/dominators/utils.py +++ b/slither/core/dominators/utils.py @@ -1,6 +1,12 @@ +from typing import List, TYPE_CHECKING + from slither.core.cfg.node import NodeType -def intersection_predecessor(node): +if TYPE_CHECKING: + from slither.core.cfg.node import Node + + +def intersection_predecessor(node: "Node"): if not node.fathers: return set() ret = node.fathers[0].dominators @@ -8,13 +14,14 @@ def intersection_predecessor(node): ret = ret.intersection(pred.dominators) return ret -def compute_dominators(nodes): - ''' + +def compute_dominators(nodes: List["Node"]): + """ Naive implementation of Cooper, Harvey, Kennedy algo See 'A Simple,Fast Dominance Algorithm' Compute strict domniators - ''' + """ changed = True for n in nodes: @@ -36,33 +43,38 @@ def compute_dominators(nodes): for dominator in node.dominators: if dominator != node: - [idom_candidates.remove(d) for d in dominator.dominators if d in idom_candidates and d!=dominator] + [ + idom_candidates.remove(d) + for d in dominator.dominators + if d in idom_candidates and d != dominator + ] - assert len(idom_candidates)<=1 + assert len(idom_candidates) <= 1 if idom_candidates: idom = idom_candidates.pop() node.immediate_dominator = idom idom.dominator_successors.add(node) - -def compute_dominance_frontier(nodes): - ''' +def compute_dominance_frontier(nodes: List["Node"]): + """ Naive implementation of Cooper, Harvey, Kennedy algo See 'A Simple,Fast Dominance Algorithm' Compute dominance frontier - ''' + """ for node in nodes: if len(node.fathers) >= 2: for father in node.fathers: runner = father # Corner case: if there is a if without else - # we need to add update the conditional node - if runner == node.immediate_dominator and runner.type == NodeType.IF and node.type == NodeType.ENDIF: + # we need to add update the conditional node + if ( + runner == node.immediate_dominator + and runner.type == NodeType.IF + and node.type == NodeType.ENDIF + ): runner.dominance_frontier = runner.dominance_frontier.union({node}) while runner != node.immediate_dominator: runner.dominance_frontier = runner.dominance_frontier.union({node}) runner = runner.immediate_dominator - - diff --git a/slither/core/exceptions.py b/slither/core/exceptions.py index 3f4db4bbf..c928689a5 100644 --- a/slither/core/exceptions.py +++ b/slither/core/exceptions.py @@ -1,3 +1,5 @@ from slither.exceptions import SlitherException -class SlitherCoreError(SlitherException): pass + +class SlitherCoreError(SlitherException): + pass diff --git a/slither/core/expressions/assignment_operation.py b/slither/core/expressions/assignment_operation.py index 749e3843a..c1677d9b5 100644 --- a/slither/core/expressions/assignment_operation.py +++ b/slither/core/expressions/assignment_operation.py @@ -1,111 +1,118 @@ import logging +from enum import Enum +from typing import Optional, TYPE_CHECKING, List + from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.exceptions import SlitherCoreError +if TYPE_CHECKING: + from slither.core.solidity_types.type import Type + logger = logging.getLogger("AssignmentOperation") -class AssignmentOperationType: - ASSIGN = 0 # = - ASSIGN_OR = 1 # |= - ASSIGN_CARET = 2 # ^= - ASSIGN_AND = 3 # &= - ASSIGN_LEFT_SHIFT = 4 # <<= - ASSIGN_RIGHT_SHIFT = 5 # >>= - ASSIGN_ADDITION = 6 # += - ASSIGN_SUBTRACTION = 7 # -= - ASSIGN_MULTIPLICATION = 8 # *= - ASSIGN_DIVISION = 9 # /= - ASSIGN_MODULO = 10 # %= + +class AssignmentOperationType(Enum): + ASSIGN = 0 # = + ASSIGN_OR = 1 # |= + ASSIGN_CARET = 2 # ^= + ASSIGN_AND = 3 # &= + ASSIGN_LEFT_SHIFT = 4 # <<= + ASSIGN_RIGHT_SHIFT = 5 # >>= + ASSIGN_ADDITION = 6 # += + ASSIGN_SUBTRACTION = 7 # -= + ASSIGN_MULTIPLICATION = 8 # *= + ASSIGN_DIVISION = 9 # /= + ASSIGN_MODULO = 10 # %= @staticmethod - def get_type(operation_type): - if operation_type == '=': + def get_type(operation_type: "AssignmentOperationType"): + if operation_type == "=": return AssignmentOperationType.ASSIGN - if operation_type == '|=': + if operation_type == "|=": return AssignmentOperationType.ASSIGN_OR - if operation_type == '^=': + if operation_type == "^=": return AssignmentOperationType.ASSIGN_CARET - if operation_type == '&=': + if operation_type == "&=": return AssignmentOperationType.ASSIGN_AND - if operation_type == '<<=': + if operation_type == "<<=": return AssignmentOperationType.ASSIGN_LEFT_SHIFT - if operation_type == '>>=': + if operation_type == ">>=": return AssignmentOperationType.ASSIGN_RIGHT_SHIFT - if operation_type == '+=': + if operation_type == "+=": return AssignmentOperationType.ASSIGN_ADDITION - if operation_type == '-=': + if operation_type == "-=": return AssignmentOperationType.ASSIGN_SUBTRACTION - if operation_type == '*=': + if operation_type == "*=": return AssignmentOperationType.ASSIGN_MULTIPLICATION - if operation_type == '/=': + if operation_type == "/=": return AssignmentOperationType.ASSIGN_DIVISION - if operation_type == '%=': + if operation_type == "%=": return AssignmentOperationType.ASSIGN_MODULO - raise SlitherCoreError('get_type: Unknown operation type {})'.format(operation_type)) + raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type)) - @staticmethod - def str(operation_type): - if operation_type == AssignmentOperationType.ASSIGN: - return '=' - if operation_type == AssignmentOperationType.ASSIGN_OR: - return '|=' - if operation_type == AssignmentOperationType.ASSIGN_CARET: - return '^=' - if operation_type == AssignmentOperationType.ASSIGN_AND: - return '&=' - if operation_type == AssignmentOperationType.ASSIGN_LEFT_SHIFT: - return '<<=' - if operation_type == AssignmentOperationType.ASSIGN_RIGHT_SHIFT: - return '>>=' - if operation_type == AssignmentOperationType.ASSIGN_ADDITION: - return '+=' - if operation_type == AssignmentOperationType.ASSIGN_SUBTRACTION: - return '-=' - if operation_type == AssignmentOperationType.ASSIGN_MULTIPLICATION: - return '*=' - if operation_type == AssignmentOperationType.ASSIGN_DIVISION: - return '/=' - if operation_type == AssignmentOperationType.ASSIGN_MODULO: - return '%=' - - raise SlitherCoreError('str: Unknown operation type {})'.format(operation_type)) + def __str__(self): + if self == AssignmentOperationType.ASSIGN: + return "=" + if self == AssignmentOperationType.ASSIGN_OR: + return "|=" + if self == AssignmentOperationType.ASSIGN_CARET: + return "^=" + if self == AssignmentOperationType.ASSIGN_AND: + return "&=" + if self == AssignmentOperationType.ASSIGN_LEFT_SHIFT: + return "<<=" + if self == AssignmentOperationType.ASSIGN_RIGHT_SHIFT: + return ">>=" + if self == AssignmentOperationType.ASSIGN_ADDITION: + return "+=" + if self == AssignmentOperationType.ASSIGN_SUBTRACTION: + return "-=" + if self == AssignmentOperationType.ASSIGN_MULTIPLICATION: + return "*=" + if self == AssignmentOperationType.ASSIGN_DIVISION: + return "/=" + if self == AssignmentOperationType.ASSIGN_MODULO: + return "%=" + raise SlitherCoreError("str: Unknown operation type {})".format(self)) -class AssignmentOperation(ExpressionTyped): - def __init__(self, left_expression, right_expression, expression_type, expression_return_type): +class AssignmentOperation(ExpressionTyped): + def __init__( + self, + left_expression: Expression, + right_expression: Expression, + expression_type: AssignmentOperationType, + expression_return_type: Optional["Type"], + ): assert isinstance(left_expression, Expression) assert isinstance(right_expression, Expression) super(AssignmentOperation, self).__init__() left_expression.set_lvalue() self._expressions = [left_expression, right_expression] self._type = expression_type - self._expression_return_type = expression_return_type + self._expression_return_type: Optional["Type"] = expression_return_type @property - def expressions(self): + def expressions(self) -> List[Expression]: return self._expressions @property - def expression_return_type(self): + def expression_return_type(self) -> Optional["Type"]: return self._expression_return_type @property - def expression_left(self): + def expression_left(self) -> Expression: return self._expressions[0] @property - def expression_right(self): + def expression_right(self) -> Expression: return self._expressions[1] @property - def type(self): + def type(self) -> AssignmentOperationType: return self._type - @property - def type_str(self): - return AssignmentOperationType.str(self._type) - def __str__(self): - return str(self.expression_left) + " "+ self.type_str + " " + str(self.expression_right) + return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right) diff --git a/slither/core/expressions/binary_operation.py b/slither/core/expressions/binary_operation.py index 8ede21931..0a1c4b7a2 100644 --- a/slither/core/expressions/binary_operation.py +++ b/slither/core/expressions/binary_operation.py @@ -1,145 +1,144 @@ import logging +from enum import Enum +from typing import List + from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.exceptions import SlitherCoreError + logger = logging.getLogger("BinaryOperation") -class BinaryOperationType: - POWER = 0 # ** - MULTIPLICATION = 1 # * - DIVISION = 2 # / - MODULO = 3 # % - ADDITION = 4 # + - SUBTRACTION = 5 # - - LEFT_SHIFT = 6 # << - RIGHT_SHIFT = 7 # >>> - AND = 8 # & - CARET = 9 # ^ - OR = 10 # | - LESS = 11 # < - GREATER = 12 # > - LESS_EQUAL = 13 # <= - GREATER_EQUAL = 14 # >= - EQUAL = 15 # == - NOT_EQUAL = 16 # != - ANDAND = 17 # && - OROR = 18 # || + +class BinaryOperationType(Enum): + POWER = 0 # ** + MULTIPLICATION = 1 # * + DIVISION = 2 # / + MODULO = 3 # % + ADDITION = 4 # + + SUBTRACTION = 5 # - + LEFT_SHIFT = 6 # << + RIGHT_SHIFT = 7 # >>> + AND = 8 # & + CARET = 9 # ^ + OR = 10 # | + LESS = 11 # < + GREATER = 12 # > + LESS_EQUAL = 13 # <= + GREATER_EQUAL = 14 # >= + EQUAL = 15 # == + NOT_EQUAL = 16 # != + ANDAND = 17 # && + OROR = 18 # || @staticmethod - def get_type(operation_type): - if operation_type == '**': + def get_type(operation_type: "BinaryOperation"): + if operation_type == "**": return BinaryOperationType.POWER - if operation_type == '*': + if operation_type == "*": return BinaryOperationType.MULTIPLICATION - if operation_type == '/': + if operation_type == "/": return BinaryOperationType.DIVISION - if operation_type == '%': + if operation_type == "%": return BinaryOperationType.MODULO - if operation_type == '+': + if operation_type == "+": return BinaryOperationType.ADDITION - if operation_type == '-': + if operation_type == "-": return BinaryOperationType.SUBTRACTION - if operation_type == '<<': + if operation_type == "<<": return BinaryOperationType.LEFT_SHIFT - if operation_type == '>>': + if operation_type == ">>": return BinaryOperationType.RIGHT_SHIFT - if operation_type == '&': + if operation_type == "&": return BinaryOperationType.AND - if operation_type == '^': + if operation_type == "^": return BinaryOperationType.CARET - if operation_type == '|': + if operation_type == "|": return BinaryOperationType.OR - if operation_type == '<': + if operation_type == "<": return BinaryOperationType.LESS - if operation_type == '>': + if operation_type == ">": return BinaryOperationType.GREATER - if operation_type == '<=': + if operation_type == "<=": return BinaryOperationType.LESS_EQUAL - if operation_type == '>=': + if operation_type == ">=": return BinaryOperationType.GREATER_EQUAL - if operation_type == '==': + if operation_type == "==": return BinaryOperationType.EQUAL - if operation_type == '!=': + if operation_type == "!=": return BinaryOperationType.NOT_EQUAL - if operation_type == '&&': + if operation_type == "&&": return BinaryOperationType.ANDAND - if operation_type == '||': + if operation_type == "||": return BinaryOperationType.OROR - raise SlitherCoreError('get_type: Unknown operation type {})'.format(operation_type)) + raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type)) - @staticmethod - def str(operation_type): - if operation_type == BinaryOperationType.POWER: - return '**' - if operation_type == BinaryOperationType.MULTIPLICATION: - return '*' - if operation_type == BinaryOperationType.DIVISION: - return '/' - if operation_type == BinaryOperationType.MODULO: - return '%' - if operation_type == BinaryOperationType.ADDITION: - return '+' - if operation_type == BinaryOperationType.SUBTRACTION: - return '-' - if operation_type == BinaryOperationType.LEFT_SHIFT: - return '<<' - if operation_type == BinaryOperationType.RIGHT_SHIFT: - return '>>' - if operation_type == BinaryOperationType.AND: - return '&' - if operation_type == BinaryOperationType.CARET: - return '^' - if operation_type == BinaryOperationType.OR: - return '|' - if operation_type == BinaryOperationType.LESS: - return '<' - if operation_type == BinaryOperationType.GREATER: - return '>' - if operation_type == BinaryOperationType.LESS_EQUAL: - return '<=' - if operation_type == BinaryOperationType.GREATER_EQUAL: - return '>=' - if operation_type == BinaryOperationType.EQUAL: - return '==' - if operation_type == BinaryOperationType.NOT_EQUAL: - return '!=' - if operation_type == BinaryOperationType.ANDAND: - return '&&' - if operation_type == BinaryOperationType.OROR: - return '||' - raise SlitherCoreError('str: Unknown operation type {})'.format(operation_type)) + def __str__(self): + if self == BinaryOperationType.POWER: + return "**" + if self == BinaryOperationType.MULTIPLICATION: + return "*" + if self == BinaryOperationType.DIVISION: + return "/" + if self == BinaryOperationType.MODULO: + return "%" + if self == BinaryOperationType.ADDITION: + return "+" + if self == BinaryOperationType.SUBTRACTION: + return "-" + if self == BinaryOperationType.LEFT_SHIFT: + return "<<" + if self == BinaryOperationType.RIGHT_SHIFT: + return ">>" + if self == BinaryOperationType.AND: + return "&" + if self == BinaryOperationType.CARET: + return "^" + if self == BinaryOperationType.OR: + return "|" + if self == BinaryOperationType.LESS: + return "<" + if self == BinaryOperationType.GREATER: + return ">" + if self == BinaryOperationType.LESS_EQUAL: + return "<=" + if self == BinaryOperationType.GREATER_EQUAL: + return ">=" + if self == BinaryOperationType.EQUAL: + return "==" + if self == BinaryOperationType.NOT_EQUAL: + return "!=" + if self == BinaryOperationType.ANDAND: + return "&&" + if self == BinaryOperationType.OROR: + return "||" + raise SlitherCoreError("str: Unknown operation type {})".format(self)) -class BinaryOperation(ExpressionTyped): +class BinaryOperation(ExpressionTyped): def __init__(self, left_expression, right_expression, expression_type): assert isinstance(left_expression, Expression) assert isinstance(right_expression, Expression) super(BinaryOperation, self).__init__() self._expressions = [left_expression, right_expression] - self._type = expression_type + self._type: BinaryOperationType = expression_type @property - def expressions(self): + def expressions(self) -> List[Expression]: return self._expressions @property - def expression_left(self): + def expression_left(self) -> Expression: return self._expressions[0] @property - def expression_right(self): + def expression_right(self) -> Expression: return self._expressions[1] @property - def type(self): + def type(self) -> BinaryOperationType: return self._type - @property - def type_str(self): - return BinaryOperationType.str(self._type) - def __str__(self): - return str(self.expression_left) + ' ' + self.type_str + ' ' + str(self.expression_right) - + return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right) diff --git a/slither/core/expressions/call_expression.py b/slither/core/expressions/call_expression.py index 955cac8b7..229b8daf1 100644 --- a/slither/core/expressions/call_expression.py +++ b/slither/core/expressions/call_expression.py @@ -1,23 +1,24 @@ +from typing import Optional, List + from slither.core.expressions.expression import Expression class CallExpression(Expression): - def __init__(self, called, arguments, type_call): assert isinstance(called, Expression) super(CallExpression, self).__init__() - self._called = called - self._arguments = arguments - self._type_call = type_call + self._called: Expression = called + self._arguments: List[Expression] = arguments + self._type_call: str = type_call # gas and value are only available if the syntax is {gas: , value: } # For the .gas().value(), the member are considered as function call # And converted later to the correct info (convert.py) - self._gas = None - self._value = None - self._salt = None + self._gas: Optional[Expression] = None + self._value: Optional[Expression] = None + self._salt: Optional[Expression] = None @property - def call_value(self): + def call_value(self) -> Optional[Expression]: return self._value @call_value.setter @@ -25,7 +26,7 @@ class CallExpression(Expression): self._value = v @property - def call_gas(self): + def call_gas(self) -> Optional[Expression]: return self._gas @call_gas.setter @@ -41,24 +42,32 @@ class CallExpression(Expression): self._salt = salt @property - def called(self): + def call_salt(self): + return self._salt + + @call_salt.setter + def call_salt(self, salt): + self._salt = salt + + @property + def called(self) -> Expression: return self._called @property - def arguments(self): + def arguments(self) -> List[Expression]: return self._arguments @property - def type_call(self): + def type_call(self) -> str: return self._type_call def __str__(self): txt = str(self._called) if self.call_gas or self.call_value: - gas = f'gas: {self.call_gas}' if self.call_gas else '' - value = f'value: {self.call_value}' if self.call_value else '' - salt = f'salt: {self.call_salt}' if self.call_salt else '' + gas = f"gas: {self.call_gas}" if self.call_gas else "" + value = f"value: {self.call_value}" if self.call_value else "" + salt = f"salt: {self.call_salt}" if self.call_salt else "" if gas or value or salt: options = [gas, value, salt] - txt += '{' + ','.join([o for o in options if o != '']) + '}' - return txt + '(' + ','.join([str(a) for a in self._arguments]) + ')' + txt += "{" + ",".join([o for o in options if o != ""]) + "}" + return txt + "(" + ",".join([str(a) for a in self._arguments]) + ")" diff --git a/slither/core/expressions/conditional_expression.py b/slither/core/expressions/conditional_expression.py index b26573ce6..c4e15f14f 100644 --- a/slither/core/expressions/conditional_expression.py +++ b/slither/core/expressions/conditional_expression.py @@ -1,32 +1,40 @@ +from typing import List + from .expression import Expression -class ConditionalExpression(Expression): +class ConditionalExpression(Expression): def __init__(self, if_expression, then_expression, else_expression): assert isinstance(if_expression, Expression) assert isinstance(then_expression, Expression) assert isinstance(else_expression, Expression) super(ConditionalExpression, self).__init__() - self._if_expression = if_expression - self._then_expression = then_expression - self._else_expression = else_expression + self._if_expression: Expression = if_expression + self._then_expression: Expression = then_expression + self._else_expression: Expression = else_expression @property - def expressions(self): + def expressions(self) -> List[Expression]: return [self._if_expression, self._then_expression, self._else_expression] @property - def if_expression(self): + def if_expression(self) -> Expression: return self._if_expression @property - def else_expression(self): + def else_expression(self) -> Expression: return self._else_expression @property - def then_expression(self): + def then_expression(self) -> Expression: return self._then_expression def __str__(self): - return 'if ' + str(self._if_expression) + ' then ' + str(self._then_expression) + ' else ' + str(self._else_expression) - + return ( + "if " + + str(self._if_expression) + + " then " + + str(self._then_expression) + + " else " + + str(self._else_expression) + ) diff --git a/slither/core/expressions/elementary_type_name_expression.py b/slither/core/expressions/elementary_type_name_expression.py index cf8c07c17..6c10acd5b 100644 --- a/slither/core/expressions/elementary_type_name_expression.py +++ b/slither/core/expressions/elementary_type_name_expression.py @@ -4,17 +4,16 @@ from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type -class ElementaryTypeNameExpression(Expression): +class ElementaryTypeNameExpression(Expression): def __init__(self, t): assert isinstance(t, Type) super(ElementaryTypeNameExpression, self).__init__() self._type = t @property - def type(self): + def type(self) -> Type: return self._type def __str__(self): return str(self._type) - diff --git a/slither/core/expressions/expression.py b/slither/core/expressions/expression.py index 6e9bd7aba..792430d63 100644 --- a/slither/core/expressions/expression.py +++ b/slither/core/expressions/expression.py @@ -1,15 +1,14 @@ from slither.core.source_mapping.source_mapping import SourceMapping -class Expression( SourceMapping): +class Expression(SourceMapping): def __init__(self): super(Expression, self).__init__() self._is_lvalue = False @property - def is_lvalue(self): + def is_lvalue(self) -> bool: return self._is_lvalue def set_lvalue(self): self._is_lvalue = True - diff --git a/slither/core/expressions/expression_typed.py b/slither/core/expressions/expression_typed.py index 3df80a6f7..8aaf2b756 100644 --- a/slither/core/expressions/expression_typed.py +++ b/slither/core/expressions/expression_typed.py @@ -1,13 +1,16 @@ +from typing import Optional, TYPE_CHECKING from .expression import Expression -class ExpressionTyped(Expression): +if TYPE_CHECKING: + from ..solidity_types.type import Type + +class ExpressionTyped(Expression): def __init__(self): super(ExpressionTyped, self).__init__() - self._type = None + self._type: Optional["Type"] = None @property def type(self): return self._type - diff --git a/slither/core/expressions/identifier.py b/slither/core/expressions/identifier.py index cb92d8329..0efdc22b1 100644 --- a/slither/core/expressions/identifier.py +++ b/slither/core/expressions/identifier.py @@ -1,15 +1,19 @@ +from typing import TYPE_CHECKING + from slither.core.expressions.expression_typed import ExpressionTyped -class Identifier(ExpressionTyped): +if TYPE_CHECKING: + from slither.core.variables.variable import Variable + +class Identifier(ExpressionTyped): def __init__(self, value): super(Identifier, self).__init__() - self._value = value + self._value: "Variable" = value @property - def value(self): + def value(self) -> "Variable": return self._value def __str__(self): return str(self._value) - diff --git a/slither/core/expressions/index_access.py b/slither/core/expressions/index_access.py index 5e02bc6fa..e60836164 100644 --- a/slither/core/expressions/index_access.py +++ b/slither/core/expressions/index_access.py @@ -1,32 +1,36 @@ +from typing import List, TYPE_CHECKING + from slither.core.expressions.expression_typed import ExpressionTyped -from slither.core.solidity_types.type import Type -class IndexAccess(ExpressionTyped): +if TYPE_CHECKING: + from slither.core.expressions.expression import Expression + from slither.core.solidity_types.type import Type + +class IndexAccess(ExpressionTyped): def __init__(self, left_expression, right_expression, index_type): super(IndexAccess, self).__init__() self._expressions = [left_expression, right_expression] # TODO type of undexAccess is not always a Type -# assert isinstance(index_type, Type) - self._type = index_type + # assert isinstance(index_type, Type) + self._type: "Type" = index_type @property - def expressions(self): + def expressions(self) -> List["Expression"]: return self._expressions @property - def expression_left(self): + def expression_left(self) -> "Expression": return self._expressions[0] @property - def expression_right(self): + def expression_right(self) -> "Expression": return self._expressions[1] @property - def type(self): + def type(self) -> "Type": return self._type def __str__(self): - return str(self.expression_left) + '[' + str(self.expression_right) + ']' - + return str(self.expression_left) + "[" + str(self.expression_right) + "]" diff --git a/slither/core/expressions/literal.py b/slither/core/expressions/literal.py index 1e9b96ec3..620dc83f2 100644 --- a/slither/core/expressions/literal.py +++ b/slither/core/expressions/literal.py @@ -1,24 +1,29 @@ +from typing import Optional, Union, TYPE_CHECKING + from slither.core.expressions.expression import Expression from slither.utils.arithmetic import convert_subdenomination -class Literal(Expression): +if TYPE_CHECKING: + from slither.core.solidity_types.type import Type + +class Literal(Expression): def __init__(self, value, type, subdenomination=None): super(Literal, self).__init__() - self._value = value + self._value: Union[int, str] = value self._type = type - self._subdenomination = subdenomination + self._subdenomination: Optional[str] = subdenomination @property - def value(self): + def value(self) -> Union[int, str]: return self._value @property - def type(self): + def type(self) -> "Type": return self._type @property - def subdenomination(self): + def subdenomination(self) -> Optional[str]: return self._subdenomination def __str__(self): diff --git a/slither/core/expressions/member_access.py b/slither/core/expressions/member_access.py index d93583142..5096785bd 100644 --- a/slither/core/expressions/member_access.py +++ b/slither/core/expressions/member_access.py @@ -1,29 +1,33 @@ +from typing import TYPE_CHECKING + from slither.core.expressions.expression import Expression from slither.core.expressions.expression_typed import ExpressionTyped +if TYPE_CHECKING: + from slither.core.solidity_types.type import Type -class MemberAccess(ExpressionTyped): +class MemberAccess(ExpressionTyped): def __init__(self, member_name, member_type, expression): # assert isinstance(member_type, Type) # TODO member_type is not always a Type assert isinstance(expression, Expression) super(MemberAccess, self).__init__() - self._type = member_type - self._member_name = member_name - self._expression = expression + self._type: "Type" = member_type + self._member_name: str = member_name + self._expression: Expression = expression @property - def expression(self): + def expression(self) -> Expression: return self._expression @property - def member_name(self): + def member_name(self) -> str: return self._member_name @property - def type(self): + def type(self) -> "Type": return self._type def __str__(self): - return str(self.expression) + '.' + self.member_name + return str(self.expression) + "." + self.member_name diff --git a/slither/core/expressions/new_array.py b/slither/core/expressions/new_array.py index 9c59a75a9..c4e126519 100644 --- a/slither/core/expressions/new_array.py +++ b/slither/core/expressions/new_array.py @@ -1,23 +1,23 @@ from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type + class NewArray(Expression): # note: dont conserve the size of the array if provided def __init__(self, depth, array_type): super(NewArray, self).__init__() assert isinstance(array_type, Type) - self._depth = depth - self._array_type = array_type + self._depth: int = depth + self._array_type: Type = array_type @property - def array_type(self): + def array_type(self) -> Type: return self._array_type @property - def depth(self): + def depth(self) -> int: return self._depth def __str__(self): - return 'new ' + str(self._array_type) + '[]'* self._depth - + return "new " + str(self._array_type) + "[]" * self._depth diff --git a/slither/core/expressions/new_contract.py b/slither/core/expressions/new_contract.py index 8bfc6de58..92021a3d2 100644 --- a/slither/core/expressions/new_contract.py +++ b/slither/core/expressions/new_contract.py @@ -1,17 +1,16 @@ from .expression import Expression -class NewContract(Expression): +class NewContract(Expression): def __init__(self, contract_name): super(NewContract, self).__init__() - self._contract_name = contract_name + self._contract_name: str = contract_name self._gas = None self._value = None self._salt = None - @property - def contract_name(self): + def contract_name(self) -> str: return self._contract_name @property @@ -30,7 +29,5 @@ class NewContract(Expression): def call_salt(self, salt): self._salt = salt - def __str__(self): - return 'new ' + str(self._contract_name) - + return "new " + str(self._contract_name) diff --git a/slither/core/expressions/new_elementary_type.py b/slither/core/expressions/new_elementary_type.py index b099bed9e..c3a24f086 100644 --- a/slither/core/expressions/new_elementary_type.py +++ b/slither/core/expressions/new_elementary_type.py @@ -1,17 +1,16 @@ from slither.core.expressions.expression import Expression from slither.core.solidity_types.elementary_type import ElementaryType -class NewElementaryType(Expression): +class NewElementaryType(Expression): def __init__(self, new_type): assert isinstance(new_type, ElementaryType) super(NewElementaryType, self).__init__() self._type = new_type @property - def type(self): + def type(self) -> ElementaryType: return self._type def __str__(self): - return 'new ' + str(self._type) - + return "new " + str(self._type) diff --git a/slither/core/expressions/super_call_expression.py b/slither/core/expressions/super_call_expression.py index 420e324f9..c8b0dd9f8 100644 --- a/slither/core/expressions/super_call_expression.py +++ b/slither/core/expressions/super_call_expression.py @@ -1,4 +1,6 @@ from slither.core.expressions.expression import Expression from slither.core.expressions.call_expression import CallExpression -class SuperCallExpression(CallExpression): pass + +class SuperCallExpression(CallExpression): + pass diff --git a/slither/core/expressions/super_identifier.py b/slither/core/expressions/super_identifier.py index 33299b9a9..8c60d6d91 100644 --- a/slither/core/expressions/super_identifier.py +++ b/slither/core/expressions/super_identifier.py @@ -1,8 +1,6 @@ -from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.identifier import Identifier -class SuperIdentifier(Identifier): +class SuperIdentifier(Identifier): def __str__(self): - return 'super.' + str(self._value) - + return "super." + str(self._value) diff --git a/slither/core/expressions/tuple_expression.py b/slither/core/expressions/tuple_expression.py index 90c7f91c0..3bb4f7c3d 100644 --- a/slither/core/expressions/tuple_expression.py +++ b/slither/core/expressions/tuple_expression.py @@ -1,17 +1,18 @@ +from typing import List + from slither.core.expressions.expression import Expression -class TupleExpression(Expression): +class TupleExpression(Expression): def __init__(self, expressions): assert all(isinstance(x, Expression) for x in expressions if x) super(TupleExpression, self).__init__() self._expressions = expressions @property - def expressions(self): + def expressions(self) -> List[Expression]: return self._expressions def __str__(self): expressions_str = [str(e) for e in self.expressions] - return '(' + ','.join(expressions_str) + ')' - + return "(" + ",".join(expressions_str) + ")" diff --git a/slither/core/expressions/type_conversion.py b/slither/core/expressions/type_conversion.py index 7fe165754..599af5fd7 100644 --- a/slither/core/expressions/type_conversion.py +++ b/slither/core/expressions/type_conversion.py @@ -4,18 +4,16 @@ from slither.core.solidity_types.type import Type class TypeConversion(ExpressionTyped): - def __init__(self, expression, expression_type): super(TypeConversion, self).__init__() assert isinstance(expression, Expression) assert isinstance(expression_type, Type) - self._expression = expression - self._type = expression_type + self._expression: Expression = expression + self._type: Type = expression_type @property - def expression(self): + def expression(self) -> Expression: return self._expression def __str__(self): - return str(self.type) + '(' + str(self.expression) + ')' - + return str(self.type) + "(" + str(self.expression) + ")" diff --git a/slither/core/expressions/unary_operation.py b/slither/core/expressions/unary_operation.py index 82cb52eb5..72d2f8410 100644 --- a/slither/core/expressions/unary_operation.py +++ b/slither/core/expressions/unary_operation.py @@ -1,114 +1,121 @@ import logging +from enum import Enum + from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.exceptions import SlitherCoreError logger = logging.getLogger("UnaryOperation") -class UnaryOperationType: - BANG = 0 # ! - TILD = 1 # ~ - DELETE = 2 # delete - PLUSPLUS_PRE = 3 # ++ - MINUSMINUS_PRE = 4 # -- - PLUSPLUS_POST = 5 # ++ - MINUSMINUS_POST = 6 # -- - PLUS_PRE = 7 # for stuff like uint(+1) - MINUS_PRE = 8 # for stuff like uint(-1) + +class UnaryOperationType(Enum): + BANG = 0 # ! + TILD = 1 # ~ + DELETE = 2 # delete + PLUSPLUS_PRE = 3 # ++ + MINUSMINUS_PRE = 4 # -- + PLUSPLUS_POST = 5 # ++ + MINUSMINUS_POST = 6 # -- + PLUS_PRE = 7 # for stuff like uint(+1) + MINUS_PRE = 8 # for stuff like uint(-1) @staticmethod def get_type(operation_type, isprefix): if isprefix: - if operation_type == '!': + if operation_type == "!": return UnaryOperationType.BANG - if operation_type == '~': + if operation_type == "~": return UnaryOperationType.TILD - if operation_type == 'delete': + if operation_type == "delete": return UnaryOperationType.DELETE - if operation_type == '++': + if operation_type == "++": return UnaryOperationType.PLUSPLUS_PRE - if operation_type == '--': + if operation_type == "--": return UnaryOperationType.MINUSMINUS_PRE - if operation_type == '+': + if operation_type == "+": return UnaryOperationType.PLUS_PRE - if operation_type == '-': + if operation_type == "-": return UnaryOperationType.MINUS_PRE else: - if operation_type == '++': + if operation_type == "++": return UnaryOperationType.PLUSPLUS_POST - if operation_type == '--': + if operation_type == "--": return UnaryOperationType.MINUSMINUS_POST - raise SlitherCoreError('get_type: Unknown operation type {}'.format(operation_type)) + raise SlitherCoreError("get_type: Unknown operation type {}".format(operation_type)) - @staticmethod - def str(operation_type): - if operation_type == UnaryOperationType.BANG: - return '!' - if operation_type == UnaryOperationType.TILD: - return '~' - if operation_type == UnaryOperationType.DELETE: - return 'delete' - if operation_type == UnaryOperationType.PLUS_PRE: - return '+' - if operation_type == UnaryOperationType.MINUS_PRE: - return '-' - if operation_type in [UnaryOperationType.PLUSPLUS_PRE, UnaryOperationType.PLUSPLUS_POST]: - return '++' - if operation_type in [UnaryOperationType.MINUSMINUS_PRE, UnaryOperationType.MINUSMINUS_POST]: - return '--' + def __str__(self): + if self == UnaryOperationType.BANG: + return "!" + if self == UnaryOperationType.TILD: + return "~" + if self == UnaryOperationType.DELETE: + return "delete" + if self == UnaryOperationType.PLUS_PRE: + return "+" + if self == UnaryOperationType.MINUS_PRE: + return "-" + if self in [UnaryOperationType.PLUSPLUS_PRE, UnaryOperationType.PLUSPLUS_POST]: + return "++" + if self in [ + UnaryOperationType.MINUSMINUS_PRE, + UnaryOperationType.MINUSMINUS_POST, + ]: + return "--" - raise SlitherCoreError('str: Unknown operation type {}'.format(operation_type)) + raise SlitherCoreError("str: Unknown operation type {}".format(self)) @staticmethod def is_prefix(operation_type): - if operation_type in [UnaryOperationType.BANG, - UnaryOperationType.TILD, - UnaryOperationType.DELETE, - UnaryOperationType.PLUSPLUS_PRE, - UnaryOperationType.MINUSMINUS_PRE, - UnaryOperationType.PLUS_PRE, - UnaryOperationType.MINUS_PRE]: + if operation_type in [ + UnaryOperationType.BANG, + UnaryOperationType.TILD, + UnaryOperationType.DELETE, + UnaryOperationType.PLUSPLUS_PRE, + UnaryOperationType.MINUSMINUS_PRE, + UnaryOperationType.PLUS_PRE, + UnaryOperationType.MINUS_PRE, + ]: return True - elif operation_type in [UnaryOperationType.PLUSPLUS_POST, UnaryOperationType.MINUSMINUS_POST]: + elif operation_type in [ + UnaryOperationType.PLUSPLUS_POST, + UnaryOperationType.MINUSMINUS_POST, + ]: return False - raise SlitherCoreError('is_prefix: Unknown operation type {}'.format(operation_type)) + raise SlitherCoreError("is_prefix: Unknown operation type {}".format(operation_type)) -class UnaryOperation(ExpressionTyped): +class UnaryOperation(ExpressionTyped): def __init__(self, expression, expression_type): assert isinstance(expression, Expression) super(UnaryOperation, self).__init__() - self._expression = expression - self._type = expression_type - if expression_type in [UnaryOperationType.DELETE, - UnaryOperationType.PLUSPLUS_PRE, - UnaryOperationType.MINUSMINUS_PRE, - UnaryOperationType.PLUSPLUS_POST, - UnaryOperationType.MINUSMINUS_POST, - UnaryOperationType.PLUS_PRE, - UnaryOperationType.MINUS_PRE]: + self._expression: Expression = expression + self._type: UnaryOperationType = expression_type + if expression_type in [ + UnaryOperationType.DELETE, + UnaryOperationType.PLUSPLUS_PRE, + UnaryOperationType.MINUSMINUS_PRE, + UnaryOperationType.PLUSPLUS_POST, + UnaryOperationType.MINUSMINUS_POST, + UnaryOperationType.PLUS_PRE, + UnaryOperationType.MINUS_PRE, + ]: expression.set_lvalue() @property - def expression(self): + def expression(self) -> Expression: return self._expression @property - def type_str(self): - return UnaryOperationType.str(self._type) - - @property - def type(self): + def type(self) -> UnaryOperationType: return self._type @property - def is_prefix(self): + def is_prefix(self) -> bool: return UnaryOperationType.is_prefix(self._type) def __str__(self): if self.is_prefix: - return self.type_str + ' ' + str(self._expression) + return str(self.type) + " " + str(self._expression) else: - return str(self._expression) + ' ' + self.type_str - + return str(self._expression) + " " + str(self.type) diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index e73132e9e..cb4a5fd04 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -6,39 +6,47 @@ import logging import json import re from collections import defaultdict +from typing import Optional, Dict, List, Set, Union + +from crytic_compile import CryticCompile from slither.core.context.context import Context +from slither.core.declarations import Contract, Pragma, Import, Function, Modifier +from slither.core.variables.state_variable import StateVariable from slither.slithir.operations import InternalCall +from slither.slithir.variables import Constant from slither.utils.colors import red logger = logging.getLogger("Slither") logging.basicConfig() -class Slither(Context): + +class SlitherCore(Context): """ Slither static analyzer """ def __init__(self): - super(Slither, self).__init__() - self._contracts = {} - self._filename = None - self._source_units = {} - self._solc_version = None # '0.3' or '0.4':! - self._pragma_directives = [] - self._import_directives = [] - self._raw_source_code = {} - self._all_functions = set() - self._all_modifiers = set() - self._all_state_variables = None - - self._previous_results_filename = 'slither.db.json' - self._results_to_hide = [] - self._previous_results = [] - self._previous_results_ids = set() - self._paths_to_filter = set() - - self._crytic_compile = None + super(SlitherCore, self).__init__() + self._contracts: Dict[str, Contract] = {} + self._filename: Optional[str] = None + self._source_units: Dict[int, str] = {} + self._solc_version: Optional[str] = None # '0.3' or '0.4':! + self._pragma_directives: List[Pragma] = [] + self._import_directives: List[Import] = [] + self._raw_source_code: Dict[str, str] = {} + self._all_functions: Set[Function] = set() + self._all_modifiers: Set[Modifier] = set() + # Memoize + self._all_state_variables: Optional[Set[StateVariable]] = None + + self._previous_results_filename: str = "slither.db.json" + self._results_to_hide: List = [] + self._previous_results: List = [] + self._previous_results_ids: Set[str] = set() + self._paths_to_filter: Set[str] = set() + + self._crytic_compile: Optional[CryticCompile] = None self._generate_patches = False self._exclude_dependencies = False @@ -55,20 +63,24 @@ class Slither(Context): ################################################################################### @property - def source_code(self): + def source_code(self) -> Dict[str, str]: """ {filename: source_code (str)}: source code """ return self._raw_source_code @property - def source_units(self): + def source_units(self) -> Dict[int, str]: return self._source_units @property - def filename(self): + def filename(self) -> Optional[str]: """str: Filename.""" return self._filename - def _add_source_code(self, path): + @filename.setter + def filename(self, filename: str): + self._filename = filename + + def add_source_code(self, path): """ :param path: :return: @@ -76,11 +88,11 @@ class Slither(Context): if self.crytic_compile and path in self.crytic_compile.src_content: self.source_code[path] = self.crytic_compile.src_content[path] else: - with open(path, encoding='utf8', newline='') as f: + with open(path, encoding="utf8", newline="") as f: self.source_code[path] = f.read() @property - def markdown_root(self): + def markdown_root(self) -> str: return self._markdown_root # endregion @@ -91,19 +103,19 @@ class Slither(Context): ################################################################################### @property - def solc_version(self): + def solc_version(self) -> str: """str: Solidity version.""" if self.crytic_compile: return self.crytic_compile.compiler_version.version return self._solc_version @property - def pragma_directives(self): + def pragma_directives(self) -> List[Pragma]: """ list(core.declarations.Pragma): Pragma directives.""" return self._pragma_directives @property - def import_directives(self): + def import_directives(self) -> List[Import]: """ list(core.declarations.Import): Import directives""" return self._import_directives @@ -115,22 +127,23 @@ class Slither(Context): ################################################################################### @property - def contracts(self): + def contracts(self) -> List[Contract]: """list(Contract): List of contracts.""" return list(self._contracts.values()) @property - def contracts_derived(self): + def contracts_derived(self) -> List[Contract]: """list(Contract): List of contracts that are derived and not inherited.""" inheritance = (x.inheritance for x in self.contracts) inheritance = [item for sublist in inheritance for item in sublist] return [c for c in self._contracts.values() if c not in inheritance] - def contracts_as_dict(self): + @property + def contracts_as_dict(self) -> Dict[str, Contract]: """list(dict(str: Contract): List of contracts as dict: name -> Contract.""" return self._contracts - def get_contract_from_name(self, contract_name): + def get_contract_from_name(self, contract_name: Union[str, Constant]) -> Optional[Contract]: """ Return a contract from a name Args: @@ -148,24 +161,24 @@ class Slither(Context): ################################################################################### @property - def functions(self): + def functions(self) -> List[Function]: return list(self._all_functions) - def add_function(self, func): + def add_function(self, func: Function): self._all_functions.add(func) @property - def modifiers(self): + def modifiers(self) -> List[Modifier]: return list(self._all_modifiers) - def add_modifier(self, modif): + def add_modifier(self, modif: Modifier): self._all_modifiers.add(modif) @property - def functions_and_modifiers(self): + def functions_and_modifiers(self) -> List[Function]: return self.functions + self.modifiers - def _propagate_function_calls(self): + def propagate_function_calls(self): for f in self.functions_and_modifiers: for node in f.nodes: for ir in node.irs_ssa: @@ -180,7 +193,7 @@ class Slither(Context): ################################################################################### @property - def state_variables(self): + def state_variables(self) -> List[StateVariable]: if self._all_state_variables is None: state_variables = [c.state_variables for c in self.contracts] state_variables = [item for sublist in state_variables for item in sublist] @@ -194,13 +207,13 @@ class Slither(Context): ################################################################################### ################################################################################### - def print_functions(self, d): + def print_functions(self, d: str): """ Export all the functions to dot files """ for c in self.contracts: for f in c.functions: - f.cfg_to_dot(os.path.join(d, '{}.{}.dot'.format(c.name, f.name))) + f.cfg_to_dot(os.path.join(d, "{}.{}.dot".format(c.name, f.name))) # endregion ################################################################################### @@ -209,44 +222,53 @@ class Slither(Context): ################################################################################### ################################################################################### - def relative_path_format(self, path): + def relative_path_format(self, path: str) -> str: """ Strip relative paths of "." and ".." """ - return path.split('..')[-1].strip('.').strip('/') + return path.split("..")[-1].strip(".").strip("/") - def valid_result(self, r): - ''' + def valid_result(self, r: Dict) -> bool: + """ Check if the result is valid A result is invalid if: - All its source paths belong to the source path filtered - Or a similar result was reported and saved during a previous run - The --exclude-dependencies flag is set and results are only related to dependencies - ''' - source_mapping_elements = [elem['source_mapping']['filename_absolute'] - for elem in r['elements'] if 'source_mapping' in elem] - source_mapping_elements = map(lambda x: os.path.normpath(x) if x else x, source_mapping_elements) + """ + source_mapping_elements = [ + elem["source_mapping"]["filename_absolute"] + for elem in r["elements"] + if "source_mapping" in elem + ] + source_mapping_elements = map( + lambda x: os.path.normpath(x) if x else x, source_mapping_elements + ) matching = False for path in self._paths_to_filter: try: - if any(bool(re.search(self.relative_path_format(path), src_mapping)) - for src_mapping in source_mapping_elements): + if any( + bool(re.search(self.relative_path_format(path), src_mapping)) + for src_mapping in source_mapping_elements + ): matching = True break except re.error: - logger.error(f'Incorrect regular expression for --filter-paths {path}.' - '\nSlither supports the Python re format' - ': https://docs.python.org/3/library/re.html') + logger.error( + f"Incorrect regular expression for --filter-paths {path}." + "\nSlither supports the Python re format" + ": https://docs.python.org/3/library/re.html" + ) - if r['elements'] and matching: + if r["elements"] and matching: return False - if r['elements'] and self._exclude_dependencies: - return not all(element['source_mapping']['is_dependency'] for element in r['elements']) - if r['id'] in self._previous_results_ids: + if r["elements"] and self._exclude_dependencies: + return not all(element["source_mapping"]["is_dependency"] for element in r["elements"]) + if r["id"] in self._previous_results_ids: return False # Conserve previous result filtering. This is conserved for compatibility, but is meant to be removed - return not r['description'] in [pr['description'] for pr in self._previous_results] + return not r["description"] in [pr["description"] for pr in self._previous_results] def load_previous_results(self): filename = self._previous_results_filename @@ -256,27 +278,29 @@ class Slither(Context): self._previous_results = json.load(f) if self._previous_results: for r in self._previous_results: - if 'id' in r: - self._previous_results_ids.add(r['id']) + if "id" in r: + self._previous_results_ids.add(r["id"]) except json.decoder.JSONDecodeError: - logger.error(red('Impossible to decode {}. Consider removing the file'.format(filename))) + logger.error( + red("Impossible to decode {}. Consider removing the file".format(filename)) + ) def write_results_to_hide(self): if not self._results_to_hide: return filename = self._previous_results_filename - with open(filename, 'w', encoding='utf8') as f: + with open(filename, "w", encoding="utf8") as f: results = self._results_to_hide + self._previous_results json.dump(results, f) - def save_results_to_hide(self, results): + def save_results_to_hide(self, results: List[Dict]): self._results_to_hide += results - def add_path_to_filter(self, path): - ''' + def add_path_to_filter(self, path: str): + """ Add path to filter Path are used through direct comparison (no regex) - ''' + """ self._paths_to_filter.add(path) # endregion @@ -287,7 +311,7 @@ class Slither(Context): ################################################################################### @property - def crytic_compile(self): + def crytic_compile(self) -> Optional[CryticCompile]: return self._crytic_compile # endregion @@ -298,14 +322,13 @@ class Slither(Context): ################################################################################### @property - def generate_patches(self): + def generate_patches(self) -> bool: return self._generate_patches @generate_patches.setter - def generate_patches(self, p): + def generate_patches(self, p: bool): self._generate_patches = p - # endregion ################################################################################### ################################################################################### @@ -314,10 +337,11 @@ class Slither(Context): ################################################################################### @property - def contract_name_collisions(self): + def contract_name_collisions(self) -> Dict: return self._contract_name_collisions @property - def contracts_with_missing_inheritance(self): + def contracts_with_missing_inheritance(self) -> Set: return self._contract_with_missing_inheritance + # endregion diff --git a/slither/core/solidity_types/__init__.py b/slither/core/solidity_types/__init__.py index 24288488a..4a6e8e5df 100644 --- a/slither/core/solidity_types/__init__.py +++ b/slither/core/solidity_types/__init__.py @@ -3,4 +3,4 @@ from .elementary_type import ElementaryType from .function_type import FunctionType from .mapping_type import MappingType from .user_defined_type import UserDefinedType -from .type_information import TypeInformation \ No newline at end of file +from .type_information import TypeInformation diff --git a/slither/core/solidity_types/array_type.py b/slither/core/solidity_types/array_type.py index ad8061b74..7022539a7 100644 --- a/slither/core/solidity_types/array_type.py +++ b/slither/core/solidity_types/array_type.py @@ -1,20 +1,21 @@ -from slither.core.variables.variable import Variable -from slither.core.solidity_types.type import Type -from slither.core.expressions.expression import Expression +from typing import Optional + from slither.core.expressions import Literal +from slither.core.expressions.expression import Expression +from slither.core.solidity_types.type import Type from slither.visitors.expression.constants_folding import ConstantFolding -class ArrayType(Type): +class ArrayType(Type): def __init__(self, t, length): assert isinstance(t, Type) if length: if isinstance(length, int): - length = Literal(length, 'uint256') + length = Literal(length, "uint256") assert isinstance(length, Expression) super(ArrayType, self).__init__() - self._type = t - self._length = length + self._type: Type = t + self._length: Optional[Expression] = length if length: if not isinstance(length, Literal): @@ -25,18 +26,21 @@ class ArrayType(Type): self._length_value = None @property - def type(self): + def type(self) -> Type: return self._type @property - def length(self): + def length(self) -> Optional[Expression]: return self._length + @property + def lenght_value(self) -> Optional[Literal]: + return self._length_value + def __str__(self): if self._length: - return str(self._type)+'[{}]'.format(str(self._length_value)) - return str(self._type)+'[]' - + return str(self._type) + "[{}]".format(str(self._length_value)) + return str(self._type) + "[]" def __eq__(self, other): if not isinstance(other, ArrayType): diff --git a/slither/core/solidity_types/elementary_type.py b/slither/core/solidity_types/elementary_type.py index ccfb13222..474f2a452 100644 --- a/slither/core/solidity_types/elementary_type.py +++ b/slither/core/solidity_types/elementary_type.py @@ -1,69 +1,175 @@ import itertools +from typing import Optional from slither.core.solidity_types.type import Type # see https://solidity.readthedocs.io/en/v0.4.24/miscellaneous.html?highlight=grammar -Int = ['int', 'int8', 'int16', 'int24', 'int32', 'int40', 'int48', 'int56', 'int64', 'int72', 'int80', 'int88', 'int96', 'int104', 'int112', 'int120', 'int128', 'int136', 'int144', 'int152', 'int160', 'int168', 'int176', 'int184', 'int192', 'int200', 'int208', 'int216', 'int224', 'int232', 'int240', 'int248', 'int256'] - -Uint = ['uint', 'uint8', 'uint16', 'uint24', 'uint32', 'uint40', 'uint48', 'uint56', 'uint64', 'uint72', 'uint80', 'uint88', 'uint96', 'uint104', 'uint112', 'uint120', 'uint128', 'uint136', 'uint144', 'uint152', 'uint160', 'uint168', 'uint176', 'uint184', 'uint192', 'uint200', 'uint208', 'uint216', 'uint224', 'uint232', 'uint240', 'uint248', 'uint256'] - -Byte = ['byte', 'bytes', 'bytes1', 'bytes2', 'bytes3', 'bytes4', 'bytes5', 'bytes6', 'bytes7', 'bytes8', 'bytes9', 'bytes10', 'bytes11', 'bytes12', 'bytes13', 'bytes14', 'bytes15', 'bytes16', 'bytes17', 'bytes18', 'bytes19', 'bytes20', 'bytes21', 'bytes22', 'bytes23', 'bytes24', 'bytes25', 'bytes26', 'bytes27', 'bytes28', 'bytes29', 'bytes30', 'bytes31', 'bytes32'] +Int = [ + "int", + "int8", + "int16", + "int24", + "int32", + "int40", + "int48", + "int56", + "int64", + "int72", + "int80", + "int88", + "int96", + "int104", + "int112", + "int120", + "int128", + "int136", + "int144", + "int152", + "int160", + "int168", + "int176", + "int184", + "int192", + "int200", + "int208", + "int216", + "int224", + "int232", + "int240", + "int248", + "int256", +] + +Uint = [ + "uint", + "uint8", + "uint16", + "uint24", + "uint32", + "uint40", + "uint48", + "uint56", + "uint64", + "uint72", + "uint80", + "uint88", + "uint96", + "uint104", + "uint112", + "uint120", + "uint128", + "uint136", + "uint144", + "uint152", + "uint160", + "uint168", + "uint176", + "uint184", + "uint192", + "uint200", + "uint208", + "uint216", + "uint224", + "uint232", + "uint240", + "uint248", + "uint256", +] + +Byte = [ + "byte", + "bytes", + "bytes1", + "bytes2", + "bytes3", + "bytes4", + "bytes5", + "bytes6", + "bytes7", + "bytes8", + "bytes9", + "bytes10", + "bytes11", + "bytes12", + "bytes13", + "bytes14", + "bytes15", + "bytes16", + "bytes17", + "bytes18", + "bytes19", + "bytes20", + "bytes21", + "bytes22", + "bytes23", + "bytes24", + "bytes25", + "bytes26", + "bytes27", + "bytes28", + "bytes29", + "bytes30", + "bytes31", + "bytes32", +] # https://solidity.readthedocs.io/en/v0.4.24/types.html#fixed-point-numbers M = list(range(8, 257, 8)) N = list(range(0, 81)) -MN = list(itertools.product(M,N)) +MN = list(itertools.product(M, N)) -Fixed = ['fixed{}x{}'.format(m,n) for (m,n) in MN] + ['fixed'] -Ufixed = ['ufixed{}x{}'.format(m,n) for (m,n) in MN] + ['ufixed'] +Fixed = ["fixed{}x{}".format(m, n) for (m, n) in MN] + ["fixed"] +Ufixed = ["ufixed{}x{}".format(m, n) for (m, n) in MN] + ["ufixed"] -ElementaryTypeName = ['address', 'bool', 'string', 'var'] + Int + Uint + Byte + Fixed + Ufixed +ElementaryTypeName = ["address", "bool", "string", "var"] + Int + Uint + Byte + Fixed + Ufixed -class NonElementaryType(Exception): pass -class ElementaryType(Type): +class NonElementaryType(Exception): + pass + +class ElementaryType(Type): def __init__(self, t): if t not in ElementaryTypeName: raise NonElementaryType super(ElementaryType, self).__init__() - if t == 'uint': - t = 'uint256' - elif t == 'int': - t = 'int256' - elif t == 'byte': - t = 'bytes1' + if t == "uint": + t = "uint256" + elif t == "int": + t = "int256" + elif t == "byte": + t = "bytes1" self._type = t @property - def type(self): + def type(self) -> str: return self._type @property - def name(self): + def name(self) -> str: return self.type @property - def size(self): - ''' + def size(self) -> Optional[int]: + """ Return the size in bits Return None if the size is not known Returns: int - ''' + """ t = self._type - if t.startswith('uint'): - return int(t[len('uint'):]) - if t.startswith('int'): - return int(t[len('int'):]) - if t == 'bool': + if t.startswith("uint"): + return int(t[len("uint") :]) + if t.startswith("int"): + return int(t[len("int") :]) + if t == "bool": return int(8) - if t == 'address': + if t == "address": return int(160) - if t.startswith('bytes'): - return int(t[len('bytes'):]) + if t.startswith("bytes"): + return int(t[len("bytes") :]) return None def __str__(self): @@ -76,4 +182,3 @@ class ElementaryType(Type): def __hash__(self): return hash(str(self)) - diff --git a/slither/core/solidity_types/function_type.py b/slither/core/solidity_types/function_type.py index b5cdb0f63..e1d20d68a 100644 --- a/slither/core/solidity_types/function_type.py +++ b/slither/core/solidity_types/function_type.py @@ -1,25 +1,29 @@ +from typing import List + from slither.core.solidity_types.type import Type from slither.core.variables.function_type_variable import FunctionTypeVariable -class FunctionType(Type): - def __init__(self, params, return_values): +class FunctionType(Type): + def __init__( + self, params: List[FunctionTypeVariable], return_values: List[FunctionTypeVariable] + ): assert all(isinstance(x, FunctionTypeVariable) for x in params) assert all(isinstance(x, FunctionTypeVariable) for x in return_values) super(FunctionType, self).__init__() - self._params = params - self._return_values = return_values + self._params: List[FunctionTypeVariable] = params + self._return_values: List[FunctionTypeVariable] = return_values @property - def params(self): + def params(self) -> List[FunctionTypeVariable]: return self._params @property - def return_values(self): + def return_values(self) -> List[FunctionTypeVariable]: return self._return_values @property - def return_type(self): + def return_type(self) -> List[Type]: return [x.type for x in self.return_values] def __str__(self): @@ -28,33 +32,31 @@ class FunctionType(Type): params = ",".join([str(x.type) for x in self._params]) return_values = ",".join([str(x.type) for x in self._return_values]) if return_values: - return 'function({}) returns({})'.format(params, return_values) - return 'function({})'.format(params) + return "function({}) returns({})".format(params, return_values) + return "function({})".format(params) @property - def parameters_signature(self): - ''' + def parameters_signature(self) -> str: + """ Return the parameters signature(without the return statetement) - ''' + """ # Use x.type # x.name may be empty params = ",".join([str(x.type) for x in self._params]) - return '({})'.format(params) + return "({})".format(params) @property - def signature(self): - ''' + def signature(self) -> str: + """ Return the signature(with the return statetement if it exists) - ''' + """ # Use x.type # x.name may be empty params = ",".join([str(x.type) for x in self._params]) return_values = ",".join([str(x.type) for x in self._return_values]) if return_values: - return '({}) returns({})'.format(params, return_values) - return '({})'.format(params) - - + return "({}) returns({})".format(params, return_values) + return "({})".format(params) def __eq__(self, other): if not isinstance(other, FunctionType): diff --git a/slither/core/solidity_types/mapping_type.py b/slither/core/solidity_types/mapping_type.py index cf7e0794a..7c4b8c2bb 100644 --- a/slither/core/solidity_types/mapping_type.py +++ b/slither/core/solidity_types/mapping_type.py @@ -1,7 +1,7 @@ from slither.core.solidity_types.type import Type -class MappingType(Type): +class MappingType(Type): def __init__(self, type_from, type_to): assert isinstance(type_from, Type) assert isinstance(type_to, Type) @@ -10,15 +10,15 @@ class MappingType(Type): self._to = type_to @property - def type_from(self): + def type_from(self) -> Type: return self._from @property - def type_to(self): + def type_to(self) -> Type: return self._to def __str__(self): - return 'mapping({} => {})'.format(str(self._from), str(self._to)) + return "mapping({} => {})".format(str(self._from), str(self._to)) def __eq__(self, other): if not isinstance(other, MappingType): @@ -27,4 +27,3 @@ class MappingType(Type): def __hash__(self): return hash(str(self)) - diff --git a/slither/core/solidity_types/type.py b/slither/core/solidity_types/type.py index 1c2794bec..e06c7bf0d 100644 --- a/slither/core/solidity_types/type.py +++ b/slither/core/solidity_types/type.py @@ -1,3 +1,5 @@ from slither.core.source_mapping.source_mapping import SourceMapping -class Type(SourceMapping): pass + +class Type(SourceMapping): + pass diff --git a/slither/core/solidity_types/type_information.py b/slither/core/solidity_types/type_information.py index ee6e71ee8..26dbda55a 100644 --- a/slither/core/solidity_types/type_information.py +++ b/slither/core/solidity_types/type_information.py @@ -1,23 +1,29 @@ +from typing import TYPE_CHECKING + from slither.core.solidity_types.type import Type +if TYPE_CHECKING: + from slither.core.declarations.contract import Contract + + # Use to model the Type(X) function, which returns an undefined type # https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information class TypeInformation(Type): def __init__(self, c): from slither.core.declarations.contract import Contract - assert isinstance(c, (Contract)) + assert isinstance(c, Contract) super(TypeInformation, self).__init__() self._type = c @property - def type(self): + def type(self) -> "Contract": return self._type def __str__(self): - return f'type({self.type.name})' + return f"type({self.type.name})" def __eq__(self, other): if not isinstance(other, TypeInformation): return False - return self.type == other.type \ No newline at end of file + return self.type == other.type diff --git a/slither/core/solidity_types/user_defined_type.py b/slither/core/solidity_types/user_defined_type.py index bfabb3c85..eeb38b670 100644 --- a/slither/core/solidity_types/user_defined_type.py +++ b/slither/core/solidity_types/user_defined_type.py @@ -1,8 +1,14 @@ +from typing import Union, TYPE_CHECKING + from slither.core.solidity_types.type import Type +if TYPE_CHECKING: + from slither.core.declarations.structure import Structure + from slither.core.declarations.enum import Enum + from slither.core.declarations.contract import Contract -class UserDefinedType(Type): +class UserDefinedType(Type): def __init__(self, t): from slither.core.declarations.structure import Structure from slither.core.declarations.enum import Enum @@ -13,7 +19,7 @@ class UserDefinedType(Type): self._type = t @property - def type(self): + def type(self) -> Union["Contract", "Enum", "Structure"]: return self._type def __str__(self): @@ -21,7 +27,7 @@ class UserDefinedType(Type): from slither.core.declarations.enum import Enum if isinstance(self.type, (Enum, Structure)): - return str(self.type.contract)+'.'+str(self.type.name) + return str(self.type.contract) + "." + str(self.type.name) return str(self.type.name) def __eq__(self, other): @@ -29,7 +35,5 @@ class UserDefinedType(Type): return False return self.type == other.type - def __hash__(self): return hash(str(self)) - diff --git a/slither/core/source_mapping/source_mapping.py b/slither/core/source_mapping/source_mapping.py index 50832f721..31e1ff34e 100644 --- a/slither/core/source_mapping/source_mapping.py +++ b/slither/core/source_mapping/source_mapping.py @@ -1,16 +1,17 @@ import re +from typing import Dict, Union, Optional from slither.core.context.context import Context class SourceMapping(Context): - def __init__(self): super(SourceMapping, self).__init__() - self._source_mapping = None + # TODO create a namedtuple for the source mapping rather than a dict + self._source_mapping: Optional[Dict] = None @property - def source_mapping(self): + def source_mapping(self) -> Optional[Dict]: return self._source_mapping @staticmethod @@ -21,7 +22,7 @@ class SourceMapping(Context): Not done in an efficient way """ - source_code = source_code.encode('utf-8') + source_code = source_code.encode("utf-8") total_length = len(source_code) source_code = source_code.splitlines(True) counter = 0 @@ -38,7 +39,11 @@ class SourceMapping(Context): # Determine our column numbers. if starting_column is None and counter + line_length > start: starting_column = (start - counter) + 1 - if starting_column is not None and ending_column is None and counter + line_length > start + length: + if ( + starting_column is not None + and ending_column is None + and counter + line_length > start + length + ): ending_column = ((start + length) - counter) + 1 # Advance the current position counter, and determine line numbers. @@ -50,19 +55,19 @@ class SourceMapping(Context): if counter > start + length: break - return (lines, starting_column, ending_column) + return lines, starting_column, ending_column @staticmethod - def _convert_source_mapping(offset, slither): - ''' + def _convert_source_mapping(offset: str, slither): + """ Convert a text offset to a real offset see https://solidity.readthedocs.io/en/develop/miscellaneous.html#source-mappings Returns: (dict): {'start':0, 'length':0, 'filename': 'file.sol'} - ''' + """ sourceUnits = slither.source_units - position = re.findall('([0-9]*):([0-9]*):([-]?[0-9]*)', offset) + position = re.findall("([0-9]*):([0-9]*):([-]?[0-9]*)", offset) if len(position) != 1: return {} @@ -72,7 +77,7 @@ class SourceMapping(Context): f = int(f) if f not in sourceUnits: - return {'start':s, 'length':l} + return {"start": s, "length": l} filename_used = sourceUnits[f] filename_absolute = None filename_relative = None @@ -91,7 +96,10 @@ class SourceMapping(Context): is_dependency = slither.crytic_compile.is_dependency(filename_absolute) - if filename_absolute in slither.source_code or filename_absolute in slither.crytic_compile.src_content: + if ( + filename_absolute in slither.source_code + or filename_absolute in slither.crytic_compile.src_content + ): filename = filename_absolute elif filename_relative in slither.source_code: filename = filename_relative @@ -104,51 +112,47 @@ class SourceMapping(Context): if slither.crytic_compile and filename in slither.crytic_compile.src_content: source_code = slither.crytic_compile.src_content[filename] - (lines, starting_column, ending_column) = SourceMapping._compute_line(source_code, - s, - l) + (lines, starting_column, ending_column) = SourceMapping._compute_line(source_code, s, l) elif filename in slither.source_code: source_code = slither.source_code[filename] - (lines, starting_column, ending_column) = SourceMapping._compute_line(source_code, - s, - l) + (lines, starting_column, ending_column) = SourceMapping._compute_line(source_code, s, l) else: (lines, starting_column, ending_column) = ([], None, None) - return {'start':s, - 'length':l, - 'filename_used': filename_used, - 'filename_relative': filename_relative, - 'filename_absolute': filename_absolute, - 'filename_short': filename_short, - 'is_dependency': is_dependency, - 'lines' : lines, - 'starting_column': starting_column, - 'ending_column': ending_column - } - - def set_offset(self, offset, slither): + return { + "start": s, + "length": l, + "filename_used": filename_used, + "filename_relative": filename_relative, + "filename_absolute": filename_absolute, + "filename_short": filename_short, + "is_dependency": is_dependency, + "lines": lines, + "starting_column": starting_column, + "ending_column": ending_column, + } + + def set_offset(self, offset: Union[Dict, str], slither): if isinstance(offset, dict): self._source_mapping = offset else: self._source_mapping = self._convert_source_mapping(offset, slither) def _get_lines_str(self, line_descr=""): - lines = self.source_mapping.get('lines', None) + lines = self.source_mapping.get("lines", None) if not lines: - lines = '' + lines = "" elif len(lines) == 1: - lines = '#{}{}'.format(line_descr, lines[0]) + lines = "#{}{}".format(line_descr, lines[0]) else: - lines = '#{}{}-{}{}'.format(line_descr, lines[0], line_descr, lines[-1]) + lines = "#{}{}-{}{}".format(line_descr, lines[0], line_descr, lines[-1]) return lines - def source_mapping_to_markdown(self, markdown_root): + def source_mapping_to_markdown(self, markdown_root: str) -> str: lines = self._get_lines_str(line_descr="L") return f'{markdown_root}{self.source_mapping["filename_relative"]}{lines}' @property - def source_mapping_str(self): + def source_mapping_str(self) -> str: lines = self._get_lines_str() return f'{self.source_mapping["filename_short"]}{lines}' - diff --git a/slither/core/variables/event_variable.py b/slither/core/variables/event_variable.py index 293a3bdae..3cf000273 100644 --- a/slither/core/variables/event_variable.py +++ b/slither/core/variables/event_variable.py @@ -1,16 +1,20 @@ from .variable import Variable from slither.core.children.child_event import ChildEvent + class EventVariable(ChildEvent, Variable): def __init__(self): super(EventVariable, self).__init__() self._indexed = False @property - def indexed(self): + def indexed(self) -> bool: """ Indicates whether the event variable is indexed in the bloom filter. :return: Returns True if the variable is indexed in bloom filter, False otherwise. """ return self._indexed + @indexed.setter + def indexed(self, is_indexed: bool): + self._indexed = is_indexed diff --git a/slither/core/variables/function_type_variable.py b/slither/core/variables/function_type_variable.py index 9efab7f36..450f17877 100644 --- a/slither/core/variables/function_type_variable.py +++ b/slither/core/variables/function_type_variable.py @@ -8,5 +8,6 @@ from .variable import Variable -class FunctionTypeVariable(Variable): pass +class FunctionTypeVariable(Variable): + pass diff --git a/slither/core/variables/local_variable.py b/slither/core/variables/local_variable.py index 8b353d530..2a0143f64 100644 --- a/slither/core/variables/local_variable.py +++ b/slither/core/variables/local_variable.py @@ -1,45 +1,51 @@ +from typing import Optional + from .variable import Variable from slither.core.children.child_function import ChildFunction from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.array_type import ArrayType from slither.core.solidity_types.mapping_type import MappingType +from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.declarations.structure import Structure class LocalVariable(ChildFunction, Variable): - def __init__(self): super(LocalVariable, self).__init__() - self._location = None + self._location: Optional[str] = None - def set_location(self, loc): + def set_location(self, loc: str): self._location = loc @property - def location(self): - ''' + def location(self) -> Optional[str]: + """ Variable Location Can be storage/memory or default Returns: (str) - ''' + """ return self._location @property - def is_storage(self): + def is_scalar(self) -> bool: + return isinstance(self.type, ElementaryType) and not self.is_storage + + @property + def is_storage(self) -> bool: """ Return true if the variable is located in storage See https://solidity.readthedocs.io/en/v0.4.24/types.html?highlight=storage%20location#data-location Returns: (bool) """ - if self.location == 'memory': + if self.location == "memory": return False # Use by slithIR SSA - if self.location == 'reference_to_storage': + if self.location == "reference_to_storage": return False - if self.location == 'storage': + if self.location == "storage": return True if isinstance(self.type, (ArrayType, MappingType)): @@ -51,7 +57,5 @@ class LocalVariable(ChildFunction, Variable): return False @property - def canonical_name(self): - return '{}.{}'.format(self.function.canonical_name, self.name) - - + def canonical_name(self) -> str: + return "{}.{}".format(self.function.canonical_name, self.name) diff --git a/slither/core/variables/local_variable_init_from_tuple.py b/slither/core/variables/local_variable_init_from_tuple.py index 09ca7a361..3271225b8 100644 --- a/slither/core/variables/local_variable_init_from_tuple.py +++ b/slither/core/variables/local_variable_init_from_tuple.py @@ -1,5 +1,8 @@ +from typing import Optional + from slither.core.variables.local_variable import LocalVariable + class LocalVariableInitFromTuple(LocalVariable): """ Use on this pattern: @@ -12,8 +15,12 @@ class LocalVariableInitFromTuple(LocalVariable): def __init__(self): super(LocalVariableInitFromTuple, self).__init__() - self._tuple_index = None + self._tuple_index: Optional[int] = None @property - def tuple_index(self): + def tuple_index(self) -> Optional[int]: return self._tuple_index + + @tuple_index.setter + def tuple_index(self, idx: int): + self._tuple_index = idx diff --git a/slither/core/variables/state_variable.py b/slither/core/variables/state_variable.py index af5421c53..5bb3f1a1c 100644 --- a/slither/core/variables/state_variable.py +++ b/slither/core/variables/state_variable.py @@ -1,14 +1,20 @@ +from typing import Optional, TYPE_CHECKING, Tuple, List + from .variable import Variable from slither.core.children.child_contract import ChildContract from slither.utils.type import export_nested_types_from_variable -class StateVariable(ChildContract, Variable): +if TYPE_CHECKING: + from ..cfg.node import Node + from ..declarations import Contract + +class StateVariable(ChildContract, Variable): def __init__(self): super(StateVariable, self).__init__() - self._node_initialization = None + self._node_initialization: Optional["Node"] = None - def is_declared_by(self, contract): + def is_declared_by(self, contract: "Contract") -> bool: """ Check if the element is declared by the contract :param contract: @@ -16,7 +22,6 @@ class StateVariable(ChildContract, Variable): """ return self.contract == contract - ################################################################################### ################################################################################### # region Signature @@ -24,21 +29,21 @@ class StateVariable(ChildContract, Variable): ################################################################################### @property - def signature(self): + def signature(self) -> Tuple[str, List[str], str]: """ Return the signature of the state variable as a function signature :return: (str, list(str), list(str)), as (name, list parameters type, list return values type) """ - return self.name, [str(x) for x in export_nested_types_from_variable(self)], self.type + return self.name, [str(x) for x in export_nested_types_from_variable(self)], str(self.type) @property - def signature_str(self): + def signature_str(self) -> str: """ Return the signature of the state variable as a function signature :return: str: func_name(type1,type2) returns(type3) """ name, parameters, returnVars = self.signature - return name+'('+','.join(parameters)+') returns('+','.join(returnVars)+')' + return name + "(" + ",".join(parameters) + ") returns(" + ",".join(returnVars) + ")" # endregion ################################################################################### @@ -48,18 +53,18 @@ class StateVariable(ChildContract, Variable): ################################################################################### @property - def canonical_name(self): - return '{}.{}'.format(self.contract.name, self.name) + def canonical_name(self) -> str: + return "{}.{}".format(self.contract.name, self.name) @property - def full_name(self): + def full_name(self) -> str: """ Return the name of the state variable as a function signaure str: func_name(type1,type2) :return: the function signature without the return values """ name, parameters, _ = self.signature - return name+'('+','.join(parameters)+')' + return name + "(" + ",".join(parameters) + ")" # endregion ################################################################################### @@ -69,7 +74,7 @@ class StateVariable(ChildContract, Variable): ################################################################################### @property - def node_initialization(self): + def node_initialization(self) -> Optional["Node"]: """ Node for the state variable initalization :return: @@ -80,8 +85,6 @@ class StateVariable(ChildContract, Variable): def node_initialization(self, node_initialization): self._node_initialization = node_initialization - # endregion ################################################################################### ################################################################################### - diff --git a/slither/core/variables/structure_variable.py b/slither/core/variables/structure_variable.py index 1c0c188e1..c1b34d4f1 100644 --- a/slither/core/variables/structure_variable.py +++ b/slither/core/variables/structure_variable.py @@ -1,5 +1,6 @@ from .variable import Variable from slither.core.children.child_structure import ChildStructure -class StructureVariable(ChildStructure, Variable): pass +class StructureVariable(ChildStructure, Variable): + pass diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index 4d7f26e03..bd3413dc6 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -1,24 +1,32 @@ """ Variable module """ +from typing import Optional, TYPE_CHECKING, List, Union from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.solidity_types.type import Type from slither.core.solidity_types.elementary_type import ElementaryType -class Variable(SourceMapping): +if TYPE_CHECKING: + from slither.core.expressions.expression import Expression + +class Variable(SourceMapping): def __init__(self): super(Variable, self).__init__() - self._name = None - self._initial_expression = None - self._type = None - self._initialized = None - self._visibility = None + self._name: Optional[str] = None + self._initial_expression: Optional["Expression"] = None + self._type: Optional[Type] = None + self._initialized: Optional[bool] = None + self._visibility: Optional[str] = None self._is_constant = False @property - def expression(self): + def is_scalar(self) -> bool: + return isinstance(self.type, ElementaryType) + + @property + def expression(self) -> Optional["Expression"]: """ Expression: Expression of the node (if initialized) Initial expression may be different than the expression of the node @@ -32,25 +40,33 @@ class Variable(SourceMapping): """ return self._initial_expression + @expression.setter + def expression(self, expr: "Expression"): + self._initial_expression = expr + @property - def initialized(self): + def initialized(self) -> Optional[bool]: """ boolean: True if the variable is initialized at construction """ return self._initialized + @initialized.setter + def initialized(self, is_init: bool): + self._initialized = is_init + @property - def uninitialized(self): + def uninitialized(self) -> bool: """ boolean: True if the variable is not initialized """ return not self._initialized @property - def name(self): - ''' + def name(self) -> str: + """ str: variable name - ''' + """ return self._name @name.setter @@ -58,20 +74,32 @@ class Variable(SourceMapping): self._name = name @property - def type(self): + def type(self) -> Optional[Union[Type, List[Type]]]: return self._type + @type.setter + def type(self, types: Union[Type, List[Type]]): + self._type = types + @property - def is_constant(self): + def is_constant(self) -> bool: return self._is_constant + @is_constant.setter + def is_constant(self, is_cst: bool): + self._is_constant = is_cst + @property - def visibility(self): - ''' + def visibility(self) -> Optional[str]: + """ str: variable visibility - ''' + """ return self._visibility + @visibility.setter + def visibility(self, v: str): + self._visibility = v + def set_type(self, t): if isinstance(t, str): t = ElementaryType(t) @@ -80,25 +108,21 @@ class Variable(SourceMapping): @property def function_name(self): - ''' + """ Return the name of the variable as a function signature :return: - ''' + """ from slither.core.solidity_types import ArrayType, MappingType + from slither.utils.type import export_nested_types_from_variable + variable_getter_args = "" - if type(self.type) is ArrayType: - length = 0 - v = self - while type(v.type) is ArrayType: - length += 1 - v = v.type - variable_getter_args = ','.join(["uint256"] * length) - elif type(self.type) is MappingType: - variable_getter_args = self.type.type_from + return_type = self.type + assert return_type + + if isinstance(return_type, (ArrayType, MappingType)): + variable_getter_args = ",".join(map(str, export_nested_types_from_variable(self))) return f"{self.name}({variable_getter_args})" def __str__(self): return self._name - - diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index 26f3221c3..9036c09ce 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -10,7 +10,7 @@ from slither.core.cfg.node import Node from slither.core.declarations import Function from slither.core.declarations.solidity_variables import SolidityVariableComposed, SolidityFunction, SolidityVariable from slither.core.expressions import NewContract -from slither.core.slither_core import Slither +from slither.core.slither_core import SlitherCore from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable from slither.printers.abstract_printer import AbstractPrinter @@ -26,7 +26,7 @@ def _get_name(f: Function) -> str: return f.solidity_signature -def _extract_payable(slither: Slither) -> Dict[str, List[str]]: +def _extract_payable(slither: SlitherCore) -> Dict[str, List[str]]: ret: Dict[str, List[str]] = {} for contract in slither.contracts: payable_functions = [_get_name(f) for f in contract.functions_entry_points if f.payable] @@ -35,7 +35,7 @@ def _extract_payable(slither: Slither) -> Dict[str, List[str]]: return ret -def _extract_solidity_variable_usage(slither: Slither, sol_var: SolidityVariable) -> Dict[str, List[str]]: +def _extract_solidity_variable_usage(slither: SlitherCore, sol_var: SolidityVariable) -> Dict[str, List[str]]: ret: Dict[str, List[str]] = {} for contract in slither.contracts: functions_using_sol_var = [] @@ -53,7 +53,7 @@ def _is_constant(f: Function) -> bool: """ Heuristic: - If view/pure with Solidity >= 0.4 -> Return true - - If it contains assembly -> Return false (Slither doesn't analyze asm) + - If it contains assembly -> Return false (SlitherCore doesn't analyze asm) - Otherwise check for the rules from https://solidity.readthedocs.io/en/v0.5.0/contracts.html?highlight=pure#view-functions with an exception: internal dynamic call are not correctly handled, so we consider them as non-constant @@ -93,7 +93,7 @@ def _is_constant(f: Function) -> bool: return True -def _extract_constant_functions(slither: Slither) -> Dict[str, List[str]]: +def _extract_constant_functions(slither: SlitherCore) -> Dict[str, List[str]]: ret: Dict[str, List[str]] = {} for contract in slither.contracts: cst_functions = [_get_name(f) for f in contract.functions_entry_points if _is_constant(f)] @@ -103,7 +103,7 @@ def _extract_constant_functions(slither: Slither) -> Dict[str, List[str]]: return ret -def _extract_assert(slither: Slither) -> Dict[str, List[str]]: +def _extract_assert(slither: SlitherCore) -> Dict[str, List[str]]: ret: Dict[str, List[str]] = {} for contract in slither.contracts: functions_using_assert = [] @@ -145,7 +145,7 @@ def _extract_constants_from_irs(irs: List[Operation], if isinstance(ir, Binary): for r in ir.read: if isinstance(r, Constant): - all_cst_used_in_binary[BinaryType.str(ir.type)].append(ConstantValue(str(r.value), str(r.type))) + all_cst_used_in_binary[str(ir.type)].append(ConstantValue(str(r.value), str(r.type))) if isinstance(ir, TypeConversion): if isinstance(ir.variable, Constant): all_cst_used.append(ConstantValue(str(ir.variable.value), str(ir.type))) @@ -169,7 +169,7 @@ def _extract_constants_from_irs(irs: List[Operation], context_explored) -def _extract_constants(slither: Slither) -> Tuple[Dict[str, Dict[str, List]], Dict[str, Dict[str, Dict]]]: +def _extract_constants(slither: SlitherCore) -> Tuple[Dict[str, Dict[str, List]], Dict[str, Dict[str, Dict]]]: # contract -> function -> [ {"value": value, "type": type} ] ret_cst_used: Dict[str, Dict[str, List[ConstantValue]]] = defaultdict(dict) # contract -> function -> binary_operand -> [ {"value": value, "type": type ] @@ -196,7 +196,7 @@ def _extract_constants(slither: Slither) -> Tuple[Dict[str, Dict[str, List]], Di return ret_cst_used, ret_cst_used_in_binary -def _extract_function_relations(slither: Slither) -> Dict[str, Dict[str, Dict[str, List[str]]]]: +def _extract_function_relations(slither: SlitherCore) -> Dict[str, Dict[str, Dict[str, List[str]]]]: # contract -> function -> [functions] ret: Dict[str, Dict[str, Dict[str, List[str]]]] = defaultdict(dict) for contract in slither.contracts: @@ -217,7 +217,7 @@ def _extract_function_relations(slither: Slither) -> Dict[str, Dict[str, Dict[st return ret -def _have_external_calls(slither: Slither) -> Dict[str, List[str]]: +def _have_external_calls(slither: SlitherCore) -> Dict[str, List[str]]: """ Detect the functions with external calls :param slither: @@ -233,7 +233,7 @@ def _have_external_calls(slither: Slither) -> Dict[str, List[str]]: return ret -def _use_balance(slither: Slither) -> Dict[str, List[str]]: +def _use_balance(slither: SlitherCore) -> Dict[str, List[str]]: """ Detect the functions with external calls :param slither: @@ -250,7 +250,7 @@ def _use_balance(slither: Slither) -> Dict[str, List[str]]: return ret -def _call_a_parameter(slither: Slither) -> Dict[str, List[Dict]]: +def _call_a_parameter(slither: SlitherCore) -> Dict[str, List[Dict]]: """ Detect the functions with external calls :param slither: diff --git a/slither/slither.py b/slither/slither.py index 514c1ae1f..2eff55fa7 100644 --- a/slither/slither.py +++ b/slither/slither.py @@ -1,18 +1,13 @@ import logging import os -import subprocess -import sys -import glob -import json -import platform from crytic_compile import CryticCompile, InvalidCompilation from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.printers.abstract_printer import AbstractPrinter -from .solc_parsing.exceptions import VariableNotFound -from .solc_parsing.slitherSolc import SlitherSolc +from .core.slither_core import SlitherCore from .exceptions import SlitherError +from .solc_parsing.slitherSolc import SlitherSolc logger = logging.getLogger("Slither") logging.basicConfig() @@ -21,10 +16,9 @@ logger_detector = logging.getLogger("Detectors") logger_printer = logging.getLogger("Printers") -class Slither(SlitherSolc): - +class Slither(SlitherCore): def __init__(self, target, **kwargs): - ''' + """ Args: target (str | list(json) | CryticCompile) Keyword Args: @@ -46,14 +40,16 @@ class Slither(SlitherSolc): embark_ignore_compile (bool): do not run embark build (default False) embark_overwrite_config (bool): overwrite original config file (default false) - ''' + """ + super().__init__() + self._parser: SlitherSolc # This could be another parser, like SlitherVyper, interface needs to be determined # list of files provided (see --splitted option) if isinstance(target, list): self._init_from_list(target) - elif isinstance(target, str) and target.endswith('.json'): + elif isinstance(target, str) and target.endswith(".json"): self._init_from_raw_json(target) else: - super(Slither, self).__init__('') + self._parser = SlitherSolc("", self) try: if isinstance(target, CryticCompile): crytic_compile = target @@ -61,54 +57,55 @@ class Slither(SlitherSolc): crytic_compile = CryticCompile(target, **kwargs) self._crytic_compile = crytic_compile except InvalidCompilation as e: - raise SlitherError('Invalid compilation: \n'+str(e)) + raise SlitherError("Invalid compilation: \n" + str(e)) for path, ast in crytic_compile.asts.items(): - self._parse_contracts_from_loaded_json(ast, path) - self._add_source_code(path) + self._parser.parse_contracts_from_loaded_json(ast, path) + self.add_source_code(path) - if kwargs.get('generate_patches', False): + if kwargs.get("generate_patches", False): self.generate_patches = True - self._markdown_root = kwargs.get('markdown_root', "") + self._markdown_root = kwargs.get("markdown_root", "") self._detectors = [] self._printers = [] - filter_paths = kwargs.get('filter_paths', []) + filter_paths = kwargs.get("filter_paths", []) for p in filter_paths: self.add_path_to_filter(p) - self._exclude_dependencies = kwargs.get('exclude_dependencies', False) + self._exclude_dependencies = kwargs.get("exclude_dependencies", False) - triage_mode = kwargs.get('triage_mode', False) + triage_mode = kwargs.get("triage_mode", False) self._triage_mode = triage_mode - self._analyze_contracts() - + self._parser.analyze_contracts() def _init_from_raw_json(self, filename): if not os.path.isfile(filename): - raise SlitherError('{} does not exist (are you in the correct directory?)'.format(filename)) - assert filename.endswith('json') - with open(filename, encoding='utf8') as astFile: + raise SlitherError( + "{} does not exist (are you in the correct directory?)".format(filename) + ) + assert filename.endswith("json") + with open(filename, encoding="utf8") as astFile: stdout = astFile.read() if not stdout: - raise SlitherError('Empty AST file: %s', filename) - contracts_json = stdout.split('\n=') + raise SlitherError("Empty AST file: %s", filename) + contracts_json = stdout.split("\n=") - super(Slither, self).__init__(filename) + self._parser = SlitherSolc(filename, self) for c in contracts_json: - self._parse_contracts_from_json(c) + self._parser.parse_contracts_from_json(c) def _init_from_list(self, contract): - super(Slither, self).__init__('') + self._parser = SlitherSolc("", self) for c in contract: - if 'absolutePath' in c: - path = c['absolutePath'] + if "absolutePath" in c: + path = c["absolutePath"] else: - path = c['attributes']['absolutePath'] - self._parse_contracts_from_loaded_json(c, path) + path = c["attributes"]["absolutePath"] + self._parser.parse_contracts_from_loaded_json(c, path) @property def detectors(self): @@ -138,7 +135,7 @@ class Slither(SlitherSolc): """ :param detector_class: Class inheriting from `AbstractDetector`. """ - self._check_common_things('detector', detector_class, AbstractDetector, self._detectors) + self._check_common_things("detector", detector_class, AbstractDetector, self._detectors) instance = detector_class(self, logger_detector) self._detectors.append(instance) @@ -147,7 +144,7 @@ class Slither(SlitherSolc): """ :param printer_class: Class inheriting from `AbstractPrinter`. """ - self._check_common_things('printer', printer_class, AbstractPrinter, self._printers) + self._check_common_things("printer", printer_class, AbstractPrinter, self._printers) instance = printer_class(self, logger_printer) self._printers.append(instance) @@ -179,19 +176,19 @@ class Slither(SlitherSolc): ) if any(type(obj) == cls for obj in instances_list): - raise Exception( - "You can't register {!r} twice.".format(cls) - ) + raise Exception("You can't register {!r} twice.".format(cls)) def _run_solc(self, filename, solc, disable_solc_warnings, solc_arguments, ast_format): if not os.path.isfile(filename): - raise SlitherError('{} does not exist (are you in the correct directory?)'.format(filename)) - assert filename.endswith('json') - with open(filename, encoding='utf8') as astFile: + raise SlitherError( + "{} does not exist (are you in the correct directory?)".format(filename) + ) + assert filename.endswith("json") + with open(filename, encoding="utf8") as astFile: stdout = astFile.read() if not stdout: - raise SlitherError('Empty AST file: %s', filename) - stdout = stdout.split('\n=') + raise SlitherError("Empty AST file: %s", filename) + stdout = stdout.split("\n=") return stdout diff --git a/slither/slithir/operations/binary.py b/slither/slithir/operations/binary.py index 4c033bddd..cd886197f 100644 --- a/slither/slithir/operations/binary.py +++ b/slither/slithir/operations/binary.py @@ -1,4 +1,5 @@ import logging +from enum import Enum from slither.core.solidity_types import ElementaryType from slither.slithir.exceptions import SlithIRError @@ -8,26 +9,27 @@ from slither.slithir.variables import ReferenceVariable logger = logging.getLogger("BinaryOperationIR") -class BinaryType(object): - POWER = 0 # ** - MULTIPLICATION = 1 # * - DIVISION = 2 # / - MODULO = 3 # % - ADDITION = 4 # + - SUBTRACTION = 5 # - - LEFT_SHIFT = 6 # << - RIGHT_SHIFT = 7 # >> - AND = 8 # & - CARET = 9 # ^ - OR = 10 # | - LESS = 11 # < - GREATER = 12 # > - LESS_EQUAL = 13 # <= - GREATER_EQUAL = 14 # >= - EQUAL = 15 # == - NOT_EQUAL = 16 # != - ANDAND = 17 # && - OROR = 18 # || + +class BinaryType(Enum): + POWER = 0 # ** + MULTIPLICATION = 1 # * + DIVISION = 2 # / + MODULO = 3 # % + ADDITION = 4 # + + SUBTRACTION = 5 # - + LEFT_SHIFT = 6 # << + RIGHT_SHIFT = 7 # >> + AND = 8 # & + CARET = 9 # ^ + OR = 10 # | + LESS = 11 # < + GREATER = 12 # > + LESS_EQUAL = 13 # <= + GREATER_EQUAL = 14 # >= + EQUAL = 15 # == + NOT_EQUAL = 16 # != + ANDAND = 17 # && + OROR = 18 # || @staticmethod def return_bool(operation_type): @@ -83,47 +85,47 @@ class BinaryType(object): raise SlithIRError('get_type: Unknown operation type {})'.format(operation_type)) - @staticmethod - def str(operation_type): - if operation_type == BinaryType.POWER: - return '**' - if operation_type == BinaryType.MULTIPLICATION: - return '*' - if operation_type == BinaryType.DIVISION: - return '/' - if operation_type == BinaryType.MODULO: - return '%' - if operation_type == BinaryType.ADDITION: - return '+' - if operation_type == BinaryType.SUBTRACTION: - return '-' - if operation_type == BinaryType.LEFT_SHIFT: - return '<<' - if operation_type == BinaryType.RIGHT_SHIFT: - return '>>' - if operation_type == BinaryType.AND: - return '&' - if operation_type == BinaryType.CARET: - return '^' - if operation_type == BinaryType.OR: - return '|' - if operation_type == BinaryType.LESS: - return '<' - if operation_type == BinaryType.GREATER: - return '>' - if operation_type == BinaryType.LESS_EQUAL: - return '<=' - if operation_type == BinaryType.GREATER_EQUAL: - return '>=' - if operation_type == BinaryType.EQUAL: - return '==' - if operation_type == BinaryType.NOT_EQUAL: - return '!=' - if operation_type == BinaryType.ANDAND: - return '&&' - if operation_type == BinaryType.OROR: - return '||' - raise SlithIRError('str: Unknown operation type {})'.format(operation_type)) + def __str__(self): + if self == BinaryType.POWER: + return "**" + if self == BinaryType.MULTIPLICATION: + return "*" + if self == BinaryType.DIVISION: + return "/" + if self == BinaryType.MODULO: + return "%" + if self == BinaryType.ADDITION: + return "+" + if self == BinaryType.SUBTRACTION: + return "-" + if self == BinaryType.LEFT_SHIFT: + return "<<" + if self == BinaryType.RIGHT_SHIFT: + return ">>" + if self == BinaryType.AND: + return "&" + if self == BinaryType.CARET: + return "^" + if self == BinaryType.OR: + return "|" + if self == BinaryType.LESS: + return "<" + if self == BinaryType.GREATER: + return ">" + if self == BinaryType.LESS_EQUAL: + return "<=" + if self == BinaryType.GREATER_EQUAL: + return ">=" + if self == BinaryType.EQUAL: + return "==" + if self == BinaryType.NOT_EQUAL: + return "!=" + if self == BinaryType.ANDAND: + return "&&" + if self == BinaryType.OROR: + return "||" + raise SlithIRError("str: Unknown operation type {} {})".format(self, type(self))) + class Binary(OperationWithLValue): @@ -131,6 +133,7 @@ class Binary(OperationWithLValue): assert is_valid_rvalue(left_variable) assert is_valid_rvalue(right_variable) assert is_valid_lvalue(result) + assert isinstance(operation_type, BinaryType) super(Binary, self).__init__() self._variables = [left_variable, right_variable] self._type = operation_type @@ -143,7 +146,7 @@ class Binary(OperationWithLValue): @property def read(self): return [self.variable_left, self.variable_right] - + @property def get_variable(self): return self._variables @@ -162,7 +165,7 @@ class Binary(OperationWithLValue): @property def type_str(self): - return BinaryType.str(self._type) + return str(self._type) def __str__(self): if isinstance(self.lvalue, ReferenceVariable): @@ -170,10 +173,10 @@ class Binary(OperationWithLValue): while isinstance(points, ReferenceVariable): points = points.points_to return '{}(-> {}) = {} {} {}'.format(str(self.lvalue), - points, - self.variable_left, - self.type_str, - self.variable_right) + points, + self.variable_left, + self.type_str, + self.variable_right) return '{}({}) = {} {} {}'.format(str(self.lvalue), self.lvalue.type, self.variable_left, diff --git a/slither/slithir/operations/unary.py b/slither/slithir/operations/unary.py index 60cbfd4e3..99ca109e7 100644 --- a/slither/slithir/operations/unary.py +++ b/slither/slithir/operations/unary.py @@ -1,4 +1,6 @@ import logging +from enum import Enum + from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue from slither.slithir.exceptions import SlithIRError @@ -6,7 +8,7 @@ from slither.slithir.exceptions import SlithIRError logger = logging.getLogger("BinaryOperationIR") -class UnaryType: +class UnaryType(Enum): BANG = 0 # ! TILD = 1 # ~ @@ -19,14 +21,13 @@ class UnaryType: return UnaryType.TILD raise SlithIRError('get_type: Unknown operation type {}'.format(operation_type)) - @staticmethod - def str(operation_type): - if operation_type == UnaryType.BANG: - return '!' - if operation_type == UnaryType.TILD: - return '~' + def __str__(self): + if self == UnaryType.BANG: + return "!" + if self == UnaryType.TILD: + return "~" - raise SlithIRError('str: Unknown operation type {}'.format(operation_type)) + raise SlithIRError("str: Unknown operation type {}".format(self)) class Unary(OperationWithLValue): @@ -53,7 +54,7 @@ class Unary(OperationWithLValue): @property def type_str(self): - return UnaryType.str(self._type) + return str(self._type) def __str__(self): return "{} = {} {} ".format(self.lvalue, self.type_str, self.rvalue) diff --git a/slither/solc_parsing/cfg/node.py b/slither/solc_parsing/cfg/node.py index 60b7e3f2c..2d2bec0f3 100644 --- a/slither/solc_parsing/cfg/node.py +++ b/slither/solc_parsing/cfg/node.py @@ -1,68 +1,64 @@ +from typing import Optional, Dict + from slither.core.cfg.node import Node from slither.core.cfg.node import NodeType +from slither.core.expressions.assignment_operation import ( + AssignmentOperation, + AssignmentOperationType, +) +from slither.core.expressions.identifier import Identifier from slither.solc_parsing.expressions.expression_parsing import parse_expression +from slither.visitors.expression.find_calls import FindCalls from slither.visitors.expression.read_var import ReadVar from slither.visitors.expression.write_var import WriteVar -from slither.visitors.expression.find_calls import FindCalls - -from slither.visitors.expression.export_values import ExportValues -from slither.core.declarations.solidity_variables import SolidityVariable, SolidityFunction -from slither.core.declarations.function import Function -from slither.core.variables.state_variable import StateVariable -from slither.core.expressions.identifier import Identifier -from slither.core.expressions.assignment_operation import AssignmentOperation, AssignmentOperationType - -class NodeSolc(Node): +class NodeSolc: + def __init__(self, node: Node): + self._unparsed_expression: Optional[Dict] = None + self._node = node - def __init__(self, nodeType, nodeId): - super(NodeSolc, self).__init__(nodeType, nodeId) - self._unparsed_expression = None + @property + def underlying_node(self) -> Node: + return self._node - def add_unparsed_expression(self, expression): + def add_unparsed_expression(self, expression: Dict): assert self._unparsed_expression is None self._unparsed_expression = expression def analyze_expressions(self, caller_context): - if self.type == NodeType.VARIABLE and not self._expression: - self._expression = self.variable_declaration.expression + if self._node.type == NodeType.VARIABLE and not self._node.expression: + self._node.add_expression(self._node.variable_declaration.expression) if self._unparsed_expression: expression = parse_expression(self._unparsed_expression, caller_context) - self._expression = expression - self._unparsed_expression = None + self._node.add_expression(expression) + # self._unparsed_expression = None - if self.expression: + if self._node.expression: - if self.type == NodeType.VARIABLE: + if self._node.type == NodeType.VARIABLE: # Update the expression to be an assignement to the variable - #print(self.variable_declaration) - _expression = AssignmentOperation(Identifier(self.variable_declaration), - self.expression, - AssignmentOperationType.ASSIGN, - self.variable_declaration.type) - _expression.set_offset(self.expression.source_mapping, self.slither) - self._expression = _expression - - expression = self.expression - pp = ReadVar(expression) - self._expression_vars_read = pp.result() - -# self._vars_read = [item for sublist in vars_read for item in sublist] -# self._state_vars_read = [x for x in self.variables_read if\ -# isinstance(x, (StateVariable))] -# self._solidity_vars_read = [x for x in self.variables_read if\ -# isinstance(x, (SolidityVariable))] - - pp = WriteVar(expression) - self._expression_vars_written = pp.result() + _expression = AssignmentOperation( + Identifier(self._node.variable_declaration), + self._node.expression, + AssignmentOperationType.ASSIGN, + self._node.variable_declaration.type, + ) + _expression.set_offset(self._node.expression.source_mapping, self._node.slither) + self._node.add_expression(_expression, bypass_verif_empty=True) -# self._vars_written = [item for sublist in vars_written for item in sublist] -# self._state_vars_written = [x for x in self.variables_written if\ -# isinstance(x, StateVariable)] + expression = self._node.expression + read_var = ReadVar(expression) + self._node.variables_read_as_expression = read_var.result() - pp = FindCalls(expression) - self._expression_calls = pp.result() - self._external_calls_as_expressions = [c for c in self.calls_as_expression if not isinstance(c.called, Identifier)] - self._internal_calls_as_expressions = [c for c in self.calls_as_expression if isinstance(c.called, Identifier)] + write_var = WriteVar(expression) + self._node.variables_written_as_expression = write_var.result() + find_call = FindCalls(expression) + self._node.calls_as_expression = find_call.result() + self._node.external_calls_as_expressions = [ + c for c in self._node.calls_as_expression if not isinstance(c.called, Identifier) + ] + self._node.internal_calls_as_expressions = [ + c for c in self._node.calls_as_expression if isinstance(c.called, Identifier) + ] diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 896150476..f7d001fe8 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -1,55 +1,67 @@ import logging +from typing import List, Dict, Callable, TYPE_CHECKING, Union +from slither.core.declarations import Modifier, Structure, Event from slither.core.declarations.contract import Contract -from slither.core.declarations.function import Function, FunctionType from slither.core.declarations.enum import Enum -from slither.core.cfg.node import Node, NodeType -from slither.core.expressions import AssignmentOperation, Identifier, AssignmentOperationType -from slither.slithir.variables import StateIRVariable +from slither.core.declarations.function import Function +from slither.core.variables.state_variable import StateVariable from slither.solc_parsing.declarations.event import EventSolc from slither.solc_parsing.declarations.function import FunctionSolc from slither.solc_parsing.declarations.modifier import ModifierSolc from slither.solc_parsing.declarations.structure import StructureSolc +from slither.solc_parsing.exceptions import ParsingError, VariableNotFound from slither.solc_parsing.solidity_types.type_parsing import parse_type from slither.solc_parsing.variables.state_variable import StateVariableSolc -from slither.solc_parsing.exceptions import ParsingError, VariableNotFound -logger = logging.getLogger("ContractSolcParsing") +LOGGER = logging.getLogger("ContractSolcParsing") +if TYPE_CHECKING: + from slither.solc_parsing.slitherSolc import SlitherSolc + from slither.core.slither_core import SlitherCore -class ContractSolc04(Contract): - def __init__(self, slitherSolc, data): +class ContractSolc: + def __init__(self, slither_parser: "SlitherSolc", contract: Contract, data): # assert slitherSolc.solc_version.startswith('0.4') - super(ContractSolc04, self).__init__() - self.set_slither(slitherSolc) + self._contract = contract + self._contract.set_slither(slither_parser.core) + self._slither_parser = slither_parser self._data = data - self._functionsNotParsed = [] - self._modifiersNotParsed = [] - self._functions_no_params = [] - self._modifiers_no_params = [] - self._eventsNotParsed = [] - self._variablesNotParsed = [] - self._enumsNotParsed = [] - self._structuresNotParsed = [] - self._usingForNotParsed = [] + self._functionsNotParsed: List[Dict] = [] + self._modifiersNotParsed: List[Dict] = [] + self._functions_no_params: List[FunctionSolc] = [] + self._modifiers_no_params: List[ModifierSolc] = [] + self._eventsNotParsed: List[EventSolc] = [] + self._variablesNotParsed: List[Dict] = [] + self._enumsNotParsed: List[Dict] = [] + self._structuresNotParsed: List[Dict] = [] + self._usingForNotParsed: List[Dict] = [] + + self._functions_parser: List[FunctionSolc] = [] + self._modifiers_parser: List[ModifierSolc] = [] + self._structures_parser: List[StructureSolc] = [] - self._is_analyzed = False + self._is_analyzed: bool = False # use to remap inheritance id - self._remapping = {} + self._remapping: Dict[str, str] = {} + + self.baseContracts = [] + self.baseConstructorContractsCalled = [] + self._linearized_base_contracts: List[int] + + self._variables_parser: List[StateVariableSolc] = [] # Export info if self.is_compact_ast: - self._name = self._data['name'] + self._contract.name = self._data["name"] else: - self._name = self._data['attributes'][self.get_key()] + self._contract.name = self._data["attributes"][self.get_key()] - self._id = self._data['id'] - - self._inheritance = [] + self._contract.id = self._data["id"] self._parse_contract_info() self._parse_contract_items() @@ -61,33 +73,57 @@ class ContractSolc04(Contract): ################################################################################### @property - def is_analyzed(self): + def is_analyzed(self) -> bool: return self._is_analyzed - def set_is_analyzed(self, is_analyzed): + def set_is_analyzed(self, is_analyzed: bool): self._is_analyzed = is_analyzed + @property + def underlying_contract(self) -> Contract: + return self._contract + + @property + def linearized_base_contracts(self) -> List[int]: + return self._linearized_base_contracts + + @property + def slither(self) -> "SlitherCore": + return self._contract.slither + + @property + def slither_parser(self) -> "SlitherSolc": + return self._slither_parser + + @property + def functions_parser(self) -> List["FunctionSolc"]: + return self._functions_parser + + @property + def modifiers_parser(self) -> List["ModifierSolc"]: + return self._modifiers_parser + ################################################################################### ################################################################################### # region AST ################################################################################### ################################################################################### - def get_key(self): - return self.slither.get_key() + def get_key(self) -> str: + return self._slither_parser.get_key() - def get_children(self, key='nodes'): + def get_children(self, key="nodes") -> str: if self.is_compact_ast: return key - return 'children' + return "children" @property - def remapping(self): + def remapping(self) -> Dict[str, str]: return self._remapping @property - def is_compact_ast(self): - return self.slither.is_compact_ast + def is_compact_ast(self) -> bool: + return self._slither_parser.is_compact_ast # endregion ################################################################################### @@ -100,162 +136,190 @@ class ContractSolc04(Contract): if self.is_compact_ast: attributes = self._data else: - attributes = self._data['attributes'] + attributes = self._data["attributes"] - self.isInterface = False - if 'contractKind' in attributes: - if attributes['contractKind'] == 'interface': - self.isInterface = True - self._kind = attributes['contractKind'] - self.linearizedBaseContracts = attributes['linearizedBaseContracts'] - self.fullyImplemented = attributes['fullyImplemented'] + self._contract.is_interface = False + if "contractKind" in attributes: + if attributes["contractKind"] == "interface": + self._contract.is_interface = True + self._contract.kind = attributes["contractKind"] + self._linearized_base_contracts = attributes["linearizedBaseContracts"] + # self._contract.fullyImplemented = attributes["fullyImplemented"] # Parse base contract information self._parse_base_contract_info() # trufle does some re-mapping of id - if 'baseContracts' in self._data: - for elem in self._data['baseContracts']: - if elem['nodeType'] == 'InheritanceSpecifier': - self._remapping[elem['baseName']['referencedDeclaration']] = elem['baseName']['name'] + if "baseContracts" in self._data: + for elem in self._data["baseContracts"]: + if elem["nodeType"] == "InheritanceSpecifier": + self._remapping[elem["baseName"]["referencedDeclaration"]] = elem["baseName"][ + "name" + ] def _parse_base_contract_info(self): # Parse base contracts (immediate, non-linearized) - self.baseContracts = [] - self.baseConstructorContractsCalled = [] if self.is_compact_ast: # Parse base contracts + constructors in compact-ast - if 'baseContracts' in self._data: - for base_contract in self._data['baseContracts']: - if base_contract['nodeType'] != 'InheritanceSpecifier': + if "baseContracts" in self._data: + for base_contract in self._data["baseContracts"]: + if base_contract["nodeType"] != "InheritanceSpecifier": continue - if 'baseName' not in base_contract or 'referencedDeclaration' not in base_contract['baseName']: + if ( + "baseName" not in base_contract + or "referencedDeclaration" not in base_contract["baseName"] + ): continue # Obtain our contract reference and add it to our base contract list - referencedDeclaration = base_contract['baseName']['referencedDeclaration'] + referencedDeclaration = base_contract["baseName"]["referencedDeclaration"] self.baseContracts.append(referencedDeclaration) # If we have defined arguments in our arguments object, this is a constructor invocation. # (note: 'arguments' can be [], which is not the same as None. [] implies a constructor was # called with no arguments, while None implies no constructor was called). - if 'arguments' in base_contract and base_contract['arguments'] is not None: + if "arguments" in base_contract and base_contract["arguments"] is not None: self.baseConstructorContractsCalled.append(referencedDeclaration) else: # Parse base contracts + constructors in legacy-ast - if 'children' in self._data: - for base_contract in self._data['children']: - if base_contract['name'] != 'InheritanceSpecifier': + if "children" in self._data: + for base_contract in self._data["children"]: + if base_contract["name"] != "InheritanceSpecifier": continue - if 'children' not in base_contract or len(base_contract['children']) == 0: + if "children" not in base_contract or len(base_contract["children"]) == 0: continue # Obtain all items for this base contract specification (base contract, followed by arguments) - base_contract_items = base_contract['children'] - if 'name' not in base_contract_items[0] or base_contract_items[0]['name'] != 'UserDefinedTypeName': + base_contract_items = base_contract["children"] + if ( + "name" not in base_contract_items[0] + or base_contract_items[0]["name"] != "UserDefinedTypeName" + ): continue - if 'attributes' not in base_contract_items[0] or 'referencedDeclaration' not in \ - base_contract_items[0]['attributes']: + if ( + "attributes" not in base_contract_items[0] + or "referencedDeclaration" not in base_contract_items[0]["attributes"] + ): continue # Obtain our contract reference and add it to our base contract list - referencedDeclaration = base_contract_items[0]['attributes']['referencedDeclaration'] + referencedDeclaration = base_contract_items[0]["attributes"][ + "referencedDeclaration" + ] self.baseContracts.append(referencedDeclaration) # If we have an 'attributes'->'arguments' which is None, this is not a constructor call. - if 'attributes' not in base_contract or 'arguments' not in base_contract['attributes'] or \ - base_contract['attributes']['arguments'] is not None: + if ( + "attributes" not in base_contract + or "arguments" not in base_contract["attributes"] + or base_contract["attributes"]["arguments"] is not None + ): self.baseConstructorContractsCalled.append(referencedDeclaration) def _parse_contract_items(self): if not self.get_children() in self._data: # empty contract return for item in self._data[self.get_children()]: - if item[self.get_key()] == 'FunctionDefinition': + if item[self.get_key()] == "FunctionDefinition": self._functionsNotParsed.append(item) - elif item[self.get_key()] == 'EventDefinition': + elif item[self.get_key()] == "EventDefinition": self._eventsNotParsed.append(item) - elif item[self.get_key()] == 'InheritanceSpecifier': + elif item[self.get_key()] == "InheritanceSpecifier": # we dont need to parse it as it is redundant # with self.linearizedBaseContracts continue - elif item[self.get_key()] == 'VariableDeclaration': + elif item[self.get_key()] == "VariableDeclaration": self._variablesNotParsed.append(item) - elif item[self.get_key()] == 'EnumDefinition': + elif item[self.get_key()] == "EnumDefinition": self._enumsNotParsed.append(item) - elif item[self.get_key()] == 'ModifierDefinition': + elif item[self.get_key()] == "ModifierDefinition": self._modifiersNotParsed.append(item) - elif item[self.get_key()] == 'StructDefinition': + elif item[self.get_key()] == "StructDefinition": self._structuresNotParsed.append(item) - elif item[self.get_key()] == 'UsingForDirective': + elif item[self.get_key()] == "UsingForDirective": self._usingForNotParsed.append(item) else: - raise ParsingError('Unknown contract item: ' + item[self.get_key()]) + raise ParsingError("Unknown contract item: " + item[self.get_key()]) return - def _parse_struct(self, struct): + def _parse_struct(self, struct: Dict): if self.is_compact_ast: - name = struct['name'] + name = struct["name"] attributes = struct else: - name = struct['attributes'][self.get_key()] - attributes = struct['attributes'] - if 'canonicalName' in attributes: - canonicalName = attributes['canonicalName'] + name = struct["attributes"][self.get_key()] + attributes = struct["attributes"] + if "canonicalName" in attributes: + canonicalName = attributes["canonicalName"] else: - canonicalName = self.name + '.' + name + canonicalName = self._contract.name + "." + name - if self.get_children('members') in struct: - children = struct[self.get_children('members')] + if self.get_children("members") in struct: + children = struct[self.get_children("members")] else: children = [] # empty struct - st = StructureSolc(name, canonicalName, children) - st.set_contract(self) - st.set_offset(struct['src'], self.slither) - self._structures[name] = st + + st = Structure() + st.set_contract(self._contract) + st.set_offset(struct["src"], self._contract.slither) + + st_parser = StructureSolc(st, name, canonicalName, children, self) + self._contract.structures_as_dict[name] = st + self._structures_parser.append(st_parser) def parse_structs(self): - for father in self.inheritance_reverse: - self._structures.update(father.structures_as_dict()) + for father in self._contract.inheritance_reverse: + self._contract.structures_as_dict.update(father.structures_as_dict) for struct in self._structuresNotParsed: self._parse_struct(struct) self._structuresNotParsed = None def parse_state_variables(self): - for father in self.inheritance_reverse: - self._variables.update(father.variables_as_dict()) - self._variables_ordered += father.state_variables_ordered + for father in self._contract.inheritance_reverse: + self._contract.variables_as_dict.update(father.variables_as_dict) + self._contract.add_variables_ordered(father.state_variables_ordered) for varNotParsed in self._variablesNotParsed: - var = StateVariableSolc(varNotParsed) - var.set_offset(varNotParsed['src'], self.slither) - var.set_contract(self) + var = StateVariable() + var.set_offset(varNotParsed["src"], self._contract.slither) + var.set_contract(self._contract) - self._variables[var.name] = var - self._variables_ordered.append(var) + var_parser = StateVariableSolc(var, varNotParsed) + self._variables_parser.append(var_parser) - def _parse_modifier(self, modifier): + self._contract.variables_as_dict[var.name] = var + self._contract.add_variables_ordered([var]) - modif = ModifierSolc(modifier, self, self) - modif.set_contract(self) - modif.set_contract_declarer(self) - modif.set_offset(modifier['src'], self.slither) - self.slither.add_modifier(modif) - self._modifiers_no_params.append(modif) + def _parse_modifier(self, modifier_data: Dict): + modif = Modifier() + modif.set_offset(modifier_data["src"], self._contract.slither) + modif.set_contract(self._contract) + modif.set_contract_declarer(self._contract) - def parse_modifiers(self): + modif_parser = ModifierSolc(modif, modifier_data, self) + self._contract.slither.add_modifier(modif) + self._modifiers_no_params.append(modif_parser) + self._modifiers_parser.append(modif_parser) + + self._slither_parser.add_functions_parser(modif_parser) + def parse_modifiers(self): for modifier in self._modifiersNotParsed: self._parse_modifier(modifier) self._modifiersNotParsed = None - return + def _parse_function(self, function_data: Dict): + func = Function() + func.set_offset(function_data["src"], self._contract.slither) + func.set_contract(self._contract) + func.set_contract_declarer(self._contract) - def _parse_function(self, function): - func = FunctionSolc(function, self, self) - func.set_offset(function['src'], self.slither) - self.slither.add_function(func) - self._functions_no_params.append(func) + func_parser = FunctionSolc(func, function_data, self) + self._contract.slither.add_function(func) + self._functions_no_params.append(func_parser) + self._functions_parser.append(func_parser) + + self._slither_parser.add_functions_parser(func_parser) def parse_functions(self): @@ -264,8 +328,6 @@ class ContractSolc04(Contract): self._functionsNotParsed = None - return - # endregion ################################################################################### ################################################################################### @@ -274,56 +336,78 @@ class ContractSolc04(Contract): ################################################################################### def log_incorrect_parsing(self, error): - logger.error(error) - self._is_incorrectly_parsed = True + LOGGER.error(error) + self._contract.is_incorrectly_parsed = True def analyze_content_modifiers(self): try: - for modifier in self.modifiers: - modifier.analyze_content() + for modifier_parser in self._modifiers_parser: + modifier_parser.analyze_content() except (VariableNotFound, KeyError) as e: - self.log_incorrect_parsing(f'Missing modifier {e}') - return + self.log_incorrect_parsing(f"Missing modifier {e}") def analyze_content_functions(self): try: - for function in self.functions: - function.analyze_content() + for function_parser in self._functions_parser: + function_parser.analyze_content() except (VariableNotFound, KeyError, ParsingError) as e: - self.log_incorrect_parsing(f'Missing function {e}') + self.log_incorrect_parsing(f"Missing function {e}") return def analyze_params_modifiers(self): - try: elements_no_params = self._modifiers_no_params - getter = lambda f: f.modifiers - getter_available = lambda f: f.modifiers_declared - Cls = ModifierSolc - self._modifiers = self._analyze_params_elements(elements_no_params, getter, getter_available, Cls) + getter = lambda c: c.modifiers_parser + getter_available = lambda c: c.modifiers_declared + Cls = Modifier + Cls_parser = ModifierSolc + modifiers = self._analyze_params_elements( + elements_no_params, + getter, + getter_available, + Cls, + Cls_parser, + self._modifiers_parser, + ) + self._contract.set_modifiers(modifiers) except (VariableNotFound, KeyError) as e: - self.log_incorrect_parsing(f'Missing params {e}') + self.log_incorrect_parsing(f"Missing params {e}") self._modifiers_no_params = [] - return - def analyze_params_functions(self): try: elements_no_params = self._functions_no_params - getter = lambda f: f.functions - getter_available = lambda f: f.functions_declared - Cls = FunctionSolc - self._functions = self._analyze_params_elements(elements_no_params, getter, getter_available, Cls) + getter = lambda c: c.functions_parser + getter_available = lambda c: c.functions_declared + Cls = Function + Cls_parser = FunctionSolc + functions = self._analyze_params_elements( + elements_no_params, + getter, + getter_available, + Cls, + Cls_parser, + self._functions_parser, + ) + self._contract.set_functions(functions) except (VariableNotFound, KeyError) as e: - self.log_incorrect_parsing(f'Missing params {e}') + self.log_incorrect_parsing(f"Missing params {e}") self._functions_no_params = [] - return - def _analyze_params_elements(self, elements_no_params, getter, getter_available, Cls): + def _analyze_params_elements( + self, + elements_no_params: List[FunctionSolc], + getter: Callable[["ContractSolc"], List[FunctionSolc]], + getter_available: Callable[[Contract], List[Function]], + Cls: Callable, + Cls_parser: Callable, + parser: List[FunctionSolc], + ) -> Dict[str, Union[Function, Modifier]]: """ Analyze the parameters of the given elements (Function or Modifier). The function iterates over the inheritance to create an instance or inherited elements (Function or Modifier) If the element is shadowed, set is_shadowed to True + :param elements_no_params: list of elements to analyzer :param getter: fun x :param getter_available: fun x @@ -333,15 +417,31 @@ class ContractSolc04(Contract): all_elements = {} try: - for father in self.inheritance: - for element in getter(father): - elem = Cls(element._functionNotParsed, self, element.contract_declarer) - elem.set_offset(element._functionNotParsed['src'], self.slither) - elem.analyze_params() - self.slither.add_function(elem) + for father in self._contract.inheritance: + father_parser = self._slither_parser.underlying_contract_to_parser[father] + for element_parser in getter(father_parser): + elem = Cls() + elem.set_contract(self._contract) + elem.set_contract_declarer(element_parser.underlying_function.contract_declarer) + elem.set_offset( + element_parser.function_not_parsed["src"], self._contract.slither + ) + + elem_parser = Cls_parser(elem, element_parser.function_not_parsed, self,) + elem_parser.analyze_params() + if isinstance(elem, Modifier): + self._contract.slither.add_modifier(elem) + else: + self._contract.slither.add_function(elem) + + self._slither_parser.add_functions_parser(elem_parser) + all_elements[elem.canonical_name] = elem + parser.append(elem_parser) - accessible_elements = self.available_elements_from_inheritances(all_elements, getter_available) + accessible_elements = self._contract.available_elements_from_inheritances( + all_elements, getter_available + ) # If there is a constructor in the functions # We remove the previous constructor @@ -349,128 +449,65 @@ class ContractSolc04(Contract): # # Note: contract.all_functions_called returns the constructors of the base contracts has_constructor = False - for element in elements_no_params: - element.analyze_params() - if element.is_constructor: + for element_parser in elements_no_params: + element_parser.analyze_params() + if element_parser.underlying_function.is_constructor: has_constructor = True if has_constructor: - _accessible_functions = {k: v for (k, v) in accessible_elements.items() if not v.is_constructor} - - for element in elements_no_params: - accessible_elements[element.full_name] = element - all_elements[element.canonical_name] = element + _accessible_functions = { + k: v for (k, v) in accessible_elements.items() if not v.is_constructor + } + + for element_parser in elements_no_params: + accessible_elements[ + element_parser.underlying_function.full_name + ] = element_parser.underlying_function + all_elements[ + element_parser.underlying_function.canonical_name + ] = element_parser.underlying_function for element in all_elements.values(): if accessible_elements[element.full_name] != all_elements[element.canonical_name]: element.is_shadowed = True accessible_elements[element.full_name].shadows = True except (VariableNotFound, KeyError) as e: - self.log_incorrect_parsing(f'Missing params {e}') + self.log_incorrect_parsing(f"Missing params {e}") return all_elements def analyze_constant_state_variables(self): - for var in self.variables: - if var.is_constant: + for var_parser in self._variables_parser: + if var_parser.underlying_variable.is_constant: # cant parse constant expression based on function calls try: - var.analyze(self) + var_parser.analyze(self) except (VariableNotFound, KeyError) as e: - logger.error(e) + LOGGER.error(e) pass - return - - def _create_node(self, func, counter, variable): - # Function uses to create node for state variable declaration statements - node = Node(NodeType.OTHER_ENTRYPOINT, counter) - node.set_offset(variable.source_mapping, self.slither) - node.set_function(func) - func.add_node(node) - expression = AssignmentOperation(Identifier(variable), - variable.expression, - AssignmentOperationType.ASSIGN, - variable.type) - - expression.set_offset(variable.source_mapping, self.slither) - node.add_expression(expression) - return node - - def add_constructor_variables(self): - if self.state_variables: - for (idx, variable_candidate) in enumerate(self.state_variables): - if variable_candidate.expression and not variable_candidate.is_constant: - - constructor_variable = Function() - constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES) - constructor_variable.set_contract(self) - constructor_variable.set_contract_declarer(self) - constructor_variable.set_visibility('internal') - # For now, source mapping of the constructor variable is the whole contract - # Could be improved with a targeted source mapping - constructor_variable.set_offset(self.source_mapping, self.slither) - self._functions[constructor_variable.canonical_name] = constructor_variable - - prev_node = self._create_node(constructor_variable, 0, variable_candidate) - variable_candidate.node_initialization = prev_node - counter = 1 - for v in self.state_variables[idx + 1:]: - if v.expression and not v.is_constant: - next_node = self._create_node(constructor_variable, counter, v) - v.node_initialization = next_node - prev_node.add_son(next_node) - next_node.add_father(prev_node) - counter += 1 - break - - for (idx, variable_candidate) in enumerate(self.state_variables): - if variable_candidate.expression and variable_candidate.is_constant: - - constructor_variable = Function() - constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES) - constructor_variable.set_contract(self) - constructor_variable.set_contract_declarer(self) - constructor_variable.set_visibility('internal') - # For now, source mapping of the constructor variable is the whole contract - # Could be improved with a targeted source mapping - constructor_variable.set_offset(self.source_mapping, self.slither) - self._functions[constructor_variable.canonical_name] = constructor_variable - - prev_node = self._create_node(constructor_variable, 0, variable_candidate) - variable_candidate.node_initialization = prev_node - counter = 1 - for v in self.state_variables[idx + 1:]: - if v.expression and v.is_constant: - next_node = self._create_node(constructor_variable, counter, v) - v.node_initialization = next_node - prev_node.add_son(next_node) - next_node.add_father(prev_node) - counter += 1 - - break def analyze_state_variables(self): try: - for var in self.variables: - var.analyze(self) + for var_parser in self._variables_parser: + var_parser.analyze(self) return except (VariableNotFound, KeyError) as e: - self.log_incorrect_parsing(f'Missing state variable {e}') + self.log_incorrect_parsing(f"Missing state variable {e}") def analyze_using_for(self): try: - for father in self.inheritance: - self._using_for.update(father.using_for) + for father in self._contract.inheritance: + self._contract.using_for.update(father.using_for) if self.is_compact_ast: for using_for in self._usingForNotParsed: - lib_name = parse_type(using_for['libraryName'], self) - if 'typeName' in using_for and using_for['typeName']: - type_name = parse_type(using_for['typeName'], self) + lib_name = parse_type(using_for["libraryName"], self) + if "typeName" in using_for and using_for["typeName"]: + type_name = parse_type(using_for["typeName"], self) else: - type_name = '*' - if not type_name in self._using_for: - self.using_for[type_name] = [] - self._using_for[type_name].append(lib_name) + type_name = "*" + if type_name not in self._contract.using_for: + self._contract.using_for[type_name] = [] + self._contract.using_for[type_name].append(lib_name) else: for using_for in self._usingForNotParsed: children = using_for[self.get_children()] @@ -480,18 +517,18 @@ class ContractSolc04(Contract): old = parse_type(children[1], self) else: new = parse_type(children[0], self) - old = '*' - if not old in self._using_for: - self.using_for[old] = [] - self._using_for[old].append(new) + old = "*" + if old not in self._contract.using_for: + self._contract.using_for[old] = [] + self._contract.using_for[old].append(new) self._usingForNotParsed = [] except (VariableNotFound, KeyError) as e: - self.log_incorrect_parsing(f'Missing using for {e}') + self.log_incorrect_parsing(f"Missing using for {e}") def analyze_enums(self): try: - for father in self.inheritance: - self._enums.update(father.enums_as_dict()) + for father in self._contract.inheritance: + self._contract.enums_as_dict.update(father.enums_as_dict) for enum in self._enumsNotParsed: # for enum, we can parse and analyze it @@ -499,107 +536,60 @@ class ContractSolc04(Contract): self._analyze_enum(enum) self._enumsNotParsed = None except (VariableNotFound, KeyError) as e: - self.log_incorrect_parsing(f'Missing enum {e}') + self.log_incorrect_parsing(f"Missing enum {e}") def _analyze_enum(self, enum): # Enum can be parsed in one pass if self.is_compact_ast: - name = enum['name'] - canonicalName = enum['canonicalName'] + name = enum["name"] + canonicalName = enum["canonicalName"] else: - name = enum['attributes'][self.get_key()] - if 'canonicalName' in enum['attributes']: - canonicalName = enum['attributes']['canonicalName'] + name = enum["attributes"][self.get_key()] + if "canonicalName" in enum["attributes"]: + canonicalName = enum["attributes"]["canonicalName"] else: - canonicalName = self.name + '.' + name + canonicalName = self._contract.name + "." + name values = [] - for child in enum[self.get_children('members')]: - assert child[self.get_key()] == 'EnumValue' + for child in enum[self.get_children("members")]: + assert child[self.get_key()] == "EnumValue" if self.is_compact_ast: - values.append(child['name']) + values.append(child["name"]) else: - values.append(child['attributes'][self.get_key()]) + values.append(child["attributes"][self.get_key()]) new_enum = Enum(name, canonicalName, values) - new_enum.set_contract(self) - new_enum.set_offset(enum['src'], self.slither) - self._enums[canonicalName] = new_enum + new_enum.set_contract(self._contract) + new_enum.set_offset(enum["src"], self._contract.slither) + self._contract.enums_as_dict[canonicalName] = new_enum - def _analyze_struct(self, struct): + def _analyze_struct(self, struct: StructureSolc): struct.analyze() def analyze_structs(self): try: - for struct in self.structures: + for struct in self._structures_parser: self._analyze_struct(struct) except (VariableNotFound, KeyError) as e: - self.log_incorrect_parsing(f'Missing struct {e}') + self.log_incorrect_parsing(f"Missing struct {e}") def analyze_events(self): try: - for father in self.inheritance_reverse: - self._events.update(father.events_as_dict()) + for father in self._contract.inheritance_reverse: + self._contract.events_as_dict.update(father.events_as_dict) for event_to_parse in self._eventsNotParsed: - event = EventSolc(event_to_parse, self) - event.analyze(self) - event.set_contract(self) - event.set_offset(event_to_parse['src'], self.slither) - self._events[event.full_name] = event + event = Event() + event.set_contract(self._contract) + event.set_offset(event_to_parse["src"], self._contract.slither) + + event_parser = EventSolc(event, event_to_parse, self) + event_parser.analyze(self) + self._contract.events_as_dict[event.full_name] = event except (VariableNotFound, KeyError) as e: - self.log_incorrect_parsing(f'Missing event {e}') + self.log_incorrect_parsing(f"Missing event {e}") self._eventsNotParsed = None - # endregion - ################################################################################### - ################################################################################### - # region SlithIR - ################################################################################### - ################################################################################### - - def convert_expression_to_slithir(self): - for func in self.functions + self.modifiers: - try: - func.generate_slithir_and_analyze() - except AttributeError: - # This can happens for example if there is a call to an interface - # And the interface is redefined due to contract's name reuse - # But the available version misses some functions - self.log_incorrect_parsing(f'Impossible to generate IR for {self.name}.{func.name}') - - all_ssa_state_variables_instances = dict() - - for contract in self.inheritance: - for v in contract.state_variables_declared: - new_var = StateIRVariable(v) - all_ssa_state_variables_instances[v.canonical_name] = new_var - self._initial_state_variables.append(new_var) - - for v in self.variables: - if v.contract == self: - new_var = StateIRVariable(v) - all_ssa_state_variables_instances[v.canonical_name] = new_var - self._initial_state_variables.append(new_var) - - for func in self.functions + self.modifiers: - func.generate_slithir_ssa(all_ssa_state_variables_instances) - - def fix_phi(self): - last_state_variables_instances = dict() - initial_state_variables_instances = dict() - for v in self._initial_state_variables: - last_state_variables_instances[v.canonical_name] = [] - initial_state_variables_instances[v.canonical_name] = v - - for func in self.functions + self.modifiers: - result = func.get_last_ssa_state_variables_instances() - for variable_name, instances in result.items(): - last_state_variables_instances[variable_name] += instances - - for func in self.functions + self.modifiers: - func.fix_phi(last_state_variables_instances, initial_state_variables_instances) - # endregion ################################################################################### ################################################################################### @@ -631,6 +621,6 @@ class ContractSolc04(Contract): ################################################################################### def __hash__(self): - return self._id + return self._contract.id # endregion diff --git a/slither/solc_parsing/declarations/event.py b/slither/solc_parsing/declarations/event.py index d37a478f2..f480e3974 100644 --- a/slither/solc_parsing/declarations/event.py +++ b/slither/solc_parsing/declarations/event.py @@ -1,43 +1,55 @@ """ Event module """ +from typing import TYPE_CHECKING, Dict + +from slither.core.variables.event_variable import EventVariable from slither.solc_parsing.variables.event_variable import EventVariableSolc from slither.core.declarations.event import Event -class EventSolc(Event): +if TYPE_CHECKING: + from slither.solc_parsing.declarations.contract import ContractSolc + + +class EventSolc: """ Event class """ - def __init__(self, event, contract): - super(EventSolc, self).__init__() - self._contract = contract + def __init__(self, event: Event, event_data: Dict, contract_parser: "ContractSolc"): + + self._event = event + event.set_contract(contract_parser.underlying_contract) + self._parser_contract = contract_parser - self._elems = [] if self.is_compact_ast: - self._name = event['name'] - elems = event['parameters'] - assert elems['nodeType'] == 'ParameterList' - self._elemsNotParsed = elems['parameters'] + self._event.name = event_data["name"] + elems = event_data["parameters"] + assert elems["nodeType"] == "ParameterList" + self._elemsNotParsed = elems["parameters"] else: - self._name = event['attributes']['name'] - elems = event['children'][0] + self._event.name = event_data["attributes"]["name"] + elems = event_data["children"][0] - assert elems['name'] == 'ParameterList' - if 'children' in elems: - self._elemsNotParsed = elems['children'] + assert elems["name"] == "ParameterList" + if "children" in elems: + self._elemsNotParsed = elems["children"] else: self._elemsNotParsed = [] @property - def is_compact_ast(self): - return self.contract.is_compact_ast + def is_compact_ast(self) -> bool: + return self._parser_contract.is_compact_ast - def analyze(self, contract): + def analyze(self, contract: "ContractSolc"): for elem_to_parse in self._elemsNotParsed: - elem = EventVariableSolc(elem_to_parse) - elem.analyze(contract) - self._elems.append(elem) + elem = EventVariable() + # Todo: check if the source offset is always here + if "src" in elem_to_parse: + elem.set_offset(elem_to_parse["src"], self._parser_contract.slither) + elem_parser = EventVariableSolc(elem, elem_to_parse) + elem_parser.analyze(contract) - self._elemsNotParsed = [] + self._event.elems.append(elem) + self._elemsNotParsed = [] diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 40f8fc387..403127444 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -1,51 +1,67 @@ """ """ import logging +from typing import Dict, Optional, Union, List, TYPE_CHECKING -from slither.core.cfg.node import NodeType, link_nodes, insert_node +from slither.core.cfg.node import NodeType, link_nodes, insert_node, Node from slither.core.declarations.contract import Contract from slither.core.declarations.function import Function, ModifierStatements, FunctionType from slither.core.expressions import AssignmentOperation +from slither.core.variables.local_variable import LocalVariable +from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple from slither.solc_parsing.cfg.node import NodeSolc -from slither.solc_parsing.expressions.expression_parsing import \ - parse_expression +from slither.solc_parsing.expressions.expression_parsing import parse_expression from slither.solc_parsing.variables.local_variable import LocalVariableSolc -from slither.solc_parsing.variables.local_variable_init_from_tuple import \ - LocalVariableInitFromTupleSolc -from slither.solc_parsing.variables.variable_declaration import \ - MultipleVariablesDeclaration +from slither.solc_parsing.variables.local_variable_init_from_tuple import ( + LocalVariableInitFromTupleSolc, +) +from slither.solc_parsing.variables.variable_declaration import MultipleVariablesDeclaration from slither.utils.expression_manipulations import SplitTernaryExpression from slither.visitors.expression.export_values import ExportValues from slither.visitors.expression.has_conditional import HasConditional from slither.solc_parsing.exceptions import ParsingError from slither.core.source_mapping.source_mapping import SourceMapping -logger = logging.getLogger("FunctionSolc") +if TYPE_CHECKING: + from slither.core.expressions.expression import Expression + from slither.solc_parsing.declarations.contract import ContractSolc + from slither.solc_parsing.slitherSolc import SlitherSolc + from slither.core.slither_core import SlitherCore -class FunctionSolc(Function): +LOGGER = logging.getLogger("FunctionSolc") + + +def link_underlying_nodes(node1: NodeSolc, node2: NodeSolc): + link_nodes(node1.underlying_node, node2.underlying_node) + + +class FunctionSolc: """ """ + # elems = [(type, name)] - def __init__(self, function, contract, contract_declarer): - super(FunctionSolc, self).__init__() - self._contract = contract - self._contract_declarer = contract_declarer + def __init__( + self, function: Function, function_data: Dict, contract_parser: "ContractSolc", + ): + self._slither_parser: "SlitherSolc" = contract_parser.slither_parser + self._contract_parser = contract_parser + self._function = function # Only present if compact AST - self._referenced_declaration = None + self._referenced_declaration: Optional[int] = None if self.is_compact_ast: - self._name = function['name'] - if 'id' in function: - self._referenced_declaration = function['id'] + self._function.name = function_data["name"] + if "id" in function_data: + self._referenced_declaration = function_data["id"] + self._function.id = function_data["id"] else: - self._name = function['attributes'][self.get_key()] - self._functionNotParsed = function + self._function.name = function_data["attributes"][self.get_key()] + self._functionNotParsed = function_data self._params_was_analyzed = False self._content_was_analyzed = False - self._counter_nodes = 0 self._counter_scope_local_variables = 0 # variable renamed will map the solc id @@ -54,33 +70,60 @@ class FunctionSolc(Function): # we can retrieve the variable # It only matters if two variables have the same name in the function # which is only possible with solc > 0.5 - self._variables_renamed = {} + self._variables_renamed: Dict[ + int, Union[LocalVariableSolc, LocalVariableInitFromTupleSolc] + ] = {} self._analyze_type() + self.parameters_src = SourceMapping() + self.returns_src = SourceMapping() + + self._node_to_nodesolc: Dict[Node, NodeSolc] = dict() + + self._local_variables_parser: List[ + Union[LocalVariableSolc, LocalVariableInitFromTupleSolc] + ] = [] + + @property + def underlying_function(self) -> Function: + return self._function + + @property + def contract_parser(self) -> "ContractSolc": + return self._contract_parser + + @property + def slither_parser(self) -> "SlitherSolc": + return self._slither_parser + + @property + def slither(self) -> "SlitherCore": + return self._function.slither + ################################################################################### ################################################################################### # region AST format ################################################################################### ################################################################################### - def get_key(self): - return self.slither.get_key() + def get_key(self) -> str: + return self._slither_parser.get_key() - def get_children(self, key): + def get_children(self, key: str) -> str: if self.is_compact_ast: return key - return 'children' + return "children" @property def is_compact_ast(self): - return self.slither.is_compact_ast + return self._slither_parser.is_compact_ast @property - def referenced_declaration(self): - ''' + def referenced_declaration(self) -> Optional[str]: + """ Return the compact AST referenced declaration id (None for legacy AST) - ''' + """ return self._referenced_declaration # endregion @@ -91,22 +134,31 @@ class FunctionSolc(Function): ################################################################################### @property - def variables_renamed(self): + def variables_renamed( + self, + ) -> Dict[int, Union[LocalVariableSolc, LocalVariableInitFromTupleSolc]]: return self._variables_renamed - def _add_local_variable(self, local_var): + def _add_local_variable( + self, local_var_parser: Union[LocalVariableSolc, LocalVariableInitFromTupleSolc] + ): # If two local variables have the same name # We add a suffix to the new variable # This is done to prevent collision during SSA translation # Use of while in case of collision # In the worst case, the name will be really long - if local_var.name: - while local_var.name in self._variables: - local_var.name += "_scope_{}".format(self._counter_scope_local_variables) + if local_var_parser.underlying_variable.name: + while local_var_parser.underlying_variable.name in self._function.variables: + local_var_parser.underlying_variable.name += "_scope_{}".format( + self._counter_scope_local_variables + ) self._counter_scope_local_variables += 1 - if not local_var.reference_id is None: - self._variables_renamed[local_var.reference_id] = local_var - self._variables[local_var.name] = local_var + if local_var_parser.reference_id is not None: + self._variables_renamed[local_var_parser.reference_id] = local_var_parser + self._function.variables_as_dict[ + local_var_parser.underlying_variable.name + ] = local_var_parser.underlying_variable + self._local_variables_parser.append(local_var_parser) # endregion ################################################################################### @@ -115,6 +167,10 @@ class FunctionSolc(Function): ################################################################################### ################################################################################### + @property + def function_not_parsed(self) -> Dict: + return self._functionNotParsed + def _analyze_type(self): """ Analyz the type of the function @@ -125,61 +181,61 @@ class FunctionSolc(Function): if self.is_compact_ast: attributes = self._functionNotParsed else: - attributes = self._functionNotParsed['attributes'] + attributes = self._functionNotParsed["attributes"] - if self._name == '': - self._function_type = FunctionType.FALLBACK + if self._function.name == "": + self._function.function_type = FunctionType.FALLBACK # 0.6.x introduced the receiver function # It has also an empty name, so we need to check the kind attribute - if 'kind' in attributes: - if attributes['kind'] == 'receive': - self._function_type = FunctionType.RECEIVE + if "kind" in attributes: + if attributes["kind"] == "receive": + self._function.function_type = FunctionType.RECEIVE else: - self._function_type = FunctionType.NORMAL + self._function.function_type = FunctionType.NORMAL - if self._name == self.contract_declarer.name: - self._function_type = FunctionType.CONSTRUCTOR + if self._function.name == self._function.contract_declarer.name: + self._function.function_type = FunctionType.CONSTRUCTOR def _analyze_attributes(self): if self.is_compact_ast: attributes = self._functionNotParsed else: - attributes = self._functionNotParsed['attributes'] - - if 'payable' in attributes: - self._payable = attributes['payable'] - if 'stateMutability' in attributes: - if attributes['stateMutability'] == 'payable': - self._payable = True - elif attributes['stateMutability'] == 'pure': - self._pure = True - self._view = True - elif attributes['stateMutability'] == 'view': - self._view = True - - if 'constant' in attributes: - self._view = attributes['constant'] - - if 'isConstructor' in attributes and attributes['isConstructor']: - self._function_type = FunctionType.CONSTRUCTOR - - if 'kind' in attributes: - if attributes['kind'] == 'constructor': - self._function_type = FunctionType.CONSTRUCTOR - - if 'visibility' in attributes: - self._visibility = attributes['visibility'] + attributes = self._functionNotParsed["attributes"] + + if "payable" in attributes: + self._function.payable = attributes["payable"] + if "stateMutability" in attributes: + if attributes["stateMutability"] == "payable": + self._function.payable = True + elif attributes["stateMutability"] == "pure": + self._function.pure = True + self._function.view = True + elif attributes["stateMutability"] == "view": + self._function.view = True + + if "constant" in attributes: + self._function.view = attributes["constant"] + + if "isConstructor" in attributes and attributes["isConstructor"]: + self._function.function_type = FunctionType.CONSTRUCTOR + + if "kind" in attributes: + if attributes["kind"] == "constructor": + self._function.function_type = FunctionType.CONSTRUCTOR + + if "visibility" in attributes: + self._function.visibility = attributes["visibility"] # old solc - elif 'public' in attributes: - if attributes['public']: - self._visibility = 'public' + elif "public" in attributes: + if attributes["public"]: + self._function.visibility = "public" else: - self._visibility = 'private' + self._function.visibility = "private" else: - self._visibility = 'public' + self._function.visibility = "public" - if 'payable' in attributes: - self._payable = attributes['payable'] + if "payable" in attributes: + self._function.payable = attributes["payable"] def analyze_params(self): # Can be re-analyzed due to inheritance @@ -191,10 +247,10 @@ class FunctionSolc(Function): self._analyze_attributes() if self.is_compact_ast: - params = self._functionNotParsed['parameters'] - returns = self._functionNotParsed['returnParameters'] + params = self._functionNotParsed["parameters"] + returns = self._functionNotParsed["returnParameters"] else: - children = self._functionNotParsed[self.get_children('children')] + children = self._functionNotParsed[self.get_children("children")] params = children[0] returns = children[1] @@ -210,42 +266,39 @@ class FunctionSolc(Function): self._content_was_analyzed = True if self.is_compact_ast: - body = self._functionNotParsed['body'] + body = self._functionNotParsed["body"] - if body and body[self.get_key()] == 'Block': - self._is_implemented = True + if body and body[self.get_key()] == "Block": + self._function.is_implemented = True self._parse_cfg(body) - for modifier in self._functionNotParsed['modifiers']: + for modifier in self._functionNotParsed["modifiers"]: self._parse_modifier(modifier) else: - children = self._functionNotParsed[self.get_children('children')] - self._is_implemented = False + children = self._functionNotParsed[self.get_children("children")] + self._function.is_implemented = False for child in children[2:]: - if child[self.get_key()] == 'Block': - self._is_implemented = True + if child[self.get_key()] == "Block": + self._function.is_implemented = True self._parse_cfg(child) # Parse modifier after parsing all the block # In the case a local variable is used in the modifier for child in children[2:]: - if child[self.get_key()] == 'ModifierInvocation': + if child[self.get_key()] == "ModifierInvocation": self._parse_modifier(child) - for local_vars in self.variables: - local_vars.analyze(self) + for local_var_parser in self._local_variables_parser: + local_var_parser.analyze(self) - for node in self.nodes: - node.analyze_expressions(self) + for node_parser in self._node_to_nodesolc.values(): + node_parser.analyze_expressions(self) self._filter_ternary() self._remove_alone_endif() - - - # endregion ################################################################################### ################################################################################### @@ -253,13 +306,11 @@ class FunctionSolc(Function): ################################################################################### ################################################################################### - def _new_node(self, node_type, src): - node = NodeSolc(node_type, self._counter_nodes) - node.set_offset(src, self.slither) - self._counter_nodes += 1 - node.set_function(self) - self._nodes.append(node) - return node + def _new_node(self, node_type: NodeType, src: Union[str, Dict]) -> NodeSolc: + node = self._function.new_node(node_type, src) + node_parser = NodeSolc(node) + self._node_to_nodesolc[node] = node_parser + return node_parser # endregion ################################################################################### @@ -268,108 +319,108 @@ class FunctionSolc(Function): ################################################################################### ################################################################################### - def _parse_if(self, ifStatement, node): + def _parse_if(self, if_statement: Dict, node: NodeSolc) -> NodeSolc: # IfStatement = 'if' '(' Expression ')' Statement ( 'else' Statement )? falseStatement = None if self.is_compact_ast: - condition = ifStatement['condition'] + condition = if_statement["condition"] # Note: check if the expression could be directly # parsed here - condition_node = self._new_node(NodeType.IF, condition['src']) + condition_node = self._new_node(NodeType.IF, condition["src"]) condition_node.add_unparsed_expression(condition) - link_nodes(node, condition_node) - trueStatement = self._parse_statement(ifStatement['trueBody'], condition_node) - if ifStatement['falseBody']: - falseStatement = self._parse_statement(ifStatement['falseBody'], condition_node) + link_underlying_nodes(node, condition_node) + trueStatement = self._parse_statement(if_statement["trueBody"], condition_node) + if if_statement["falseBody"]: + falseStatement = self._parse_statement(if_statement["falseBody"], condition_node) else: - children = ifStatement[self.get_children('children')] + children = if_statement[self.get_children("children")] condition = children[0] # Note: check if the expression could be directly # parsed here - condition_node = self._new_node(NodeType.IF, condition['src']) + condition_node = self._new_node(NodeType.IF, condition["src"]) condition_node.add_unparsed_expression(condition) - link_nodes(node, condition_node) + link_underlying_nodes(node, condition_node) trueStatement = self._parse_statement(children[1], condition_node) if len(children) == 3: falseStatement = self._parse_statement(children[2], condition_node) - endIf_node = self._new_node(NodeType.ENDIF, ifStatement['src']) - link_nodes(trueStatement, endIf_node) + endIf_node = self._new_node(NodeType.ENDIF, if_statement["src"]) + link_underlying_nodes(trueStatement, endIf_node) if falseStatement: - link_nodes(falseStatement, endIf_node) + link_underlying_nodes(falseStatement, endIf_node) else: - link_nodes(condition_node, endIf_node) + link_underlying_nodes(condition_node, endIf_node) return endIf_node - def _parse_while(self, whileStatement, node): + def _parse_while(self, whilte_statement: Dict, node: NodeSolc) -> NodeSolc: # WhileStatement = 'while' '(' Expression ')' Statement - node_startWhile = self._new_node(NodeType.STARTLOOP, whileStatement['src']) + node_startWhile = self._new_node(NodeType.STARTLOOP, whilte_statement["src"]) if self.is_compact_ast: - node_condition = self._new_node(NodeType.IFLOOP, whileStatement['condition']['src']) - node_condition.add_unparsed_expression(whileStatement['condition']) - statement = self._parse_statement(whileStatement['body'], node_condition) + node_condition = self._new_node(NodeType.IFLOOP, whilte_statement["condition"]["src"]) + node_condition.add_unparsed_expression(whilte_statement["condition"]) + statement = self._parse_statement(whilte_statement["body"], node_condition) else: - children = whileStatement[self.get_children('children')] + children = whilte_statement[self.get_children("children")] expression = children[0] - node_condition = self._new_node(NodeType.IFLOOP, expression['src']) + node_condition = self._new_node(NodeType.IFLOOP, expression["src"]) node_condition.add_unparsed_expression(expression) statement = self._parse_statement(children[1], node_condition) - node_endWhile = self._new_node(NodeType.ENDLOOP, whileStatement['src']) + node_endWhile = self._new_node(NodeType.ENDLOOP, whilte_statement["src"]) - link_nodes(node, node_startWhile) - link_nodes(node_startWhile, node_condition) - link_nodes(statement, node_condition) - link_nodes(node_condition, node_endWhile) + link_underlying_nodes(node, node_startWhile) + link_underlying_nodes(node_startWhile, node_condition) + link_underlying_nodes(statement, node_condition) + link_underlying_nodes(node_condition, node_endWhile) return node_endWhile - def _parse_for_compact_ast(self, statement, node): - body = statement['body'] - init_expression = statement['initializationExpression'] - condition = statement['condition'] - loop_expression = statement['loopExpression'] + def _parse_for_compact_ast(self, statement: Dict, node: NodeSolc) -> NodeSolc: + body = statement["body"] + init_expression = statement["initializationExpression"] + condition = statement["condition"] + loop_expression = statement["loopExpression"] - node_startLoop = self._new_node(NodeType.STARTLOOP, statement['src']) - node_endLoop = self._new_node(NodeType.ENDLOOP, statement['src']) + node_startLoop = self._new_node(NodeType.STARTLOOP, statement["src"]) + node_endLoop = self._new_node(NodeType.ENDLOOP, statement["src"]) if init_expression: node_init_expression = self._parse_statement(init_expression, node) - link_nodes(node_init_expression, node_startLoop) + link_underlying_nodes(node_init_expression, node_startLoop) else: - link_nodes(node, node_startLoop) + link_underlying_nodes(node, node_startLoop) if condition: - node_condition = self._new_node(NodeType.IFLOOP, condition['src']) + node_condition = self._new_node(NodeType.IFLOOP, condition["src"]) node_condition.add_unparsed_expression(condition) - link_nodes(node_startLoop, node_condition) - link_nodes(node_condition, node_endLoop) + link_underlying_nodes(node_startLoop, node_condition) + link_underlying_nodes(node_condition, node_endLoop) else: node_condition = node_startLoop node_body = self._parse_statement(body, node_condition) + node_LoopExpression = None if loop_expression: node_LoopExpression = self._parse_statement(loop_expression, node_body) - link_nodes(node_LoopExpression, node_condition) + link_underlying_nodes(node_LoopExpression, node_condition) else: - link_nodes(node_body, node_condition) + link_underlying_nodes(node_body, node_condition) if not condition: if not loop_expression: # TODO: fix case where loop has no expression - link_nodes(node_startLoop, node_endLoop) - else: - link_nodes(node_LoopExpression, node_endLoop) + link_underlying_nodes(node_startLoop, node_endLoop) + elif node_LoopExpression: + link_underlying_nodes(node_LoopExpression, node_endLoop) return node_endLoop - - def _parse_for(self, statement, node): + def _parse_for(self, statement: Dict, node: NodeSolc) -> NodeSolc: # ForStatement = 'for' '(' (SimpleStatement)? ';' (Expression)? ';' (ExpressionStatement)? ')' Statement # the handling of loop in the legacy ast is too complex @@ -385,49 +436,50 @@ class FunctionSolc(Function): # Old solc version do not prevent in the attributes # if the loop has a init value /condition or expression # There is no way to determine that for(a;;) and for(;a;) are different with old solc - if 'attributes' in statement: - attributes = statement['attributes'] - if 'initializationExpression' in statement: - if not statement['initializationExpression']: + if "attributes" in statement: + attributes = statement["attributes"] + if "initializationExpression" in statement: + if not statement["initializationExpression"]: hasInitExession = False - elif 'initializationExpression' in attributes: - if not attributes['initializationExpression']: + elif "initializationExpression" in attributes: + if not attributes["initializationExpression"]: hasInitExession = False - if 'condition' in statement: - if not statement['condition']: + if "condition" in statement: + if not statement["condition"]: hasCondition = False - elif 'condition' in attributes: - if not attributes['condition']: + elif "condition" in attributes: + if not attributes["condition"]: hasCondition = False - if 'loopExpression' in statement: - if not statement['loopExpression']: + if "loopExpression" in statement: + if not statement["loopExpression"]: hasLoopExpression = False - elif 'loopExpression' in attributes: - if not attributes['loopExpression']: + elif "loopExpression" in attributes: + if not attributes["loopExpression"]: hasLoopExpression = False + node_startLoop = self._new_node(NodeType.STARTLOOP, statement["src"]) + node_endLoop = self._new_node(NodeType.ENDLOOP, statement["src"]) - node_startLoop = self._new_node(NodeType.STARTLOOP, statement['src']) - node_endLoop = self._new_node(NodeType.ENDLOOP, statement['src']) - - children = statement[self.get_children('children')] + children = statement[self.get_children("children")] if hasInitExession: if len(children) >= 2: - if children[0][self.get_key()] in ['VariableDefinitionStatement', - 'VariableDeclarationStatement', - 'ExpressionStatement']: + if children[0][self.get_key()] in [ + "VariableDefinitionStatement", + "VariableDeclarationStatement", + "ExpressionStatement", + ]: node_initExpression = self._parse_statement(children[0], node) - link_nodes(node_initExpression, node_startLoop) + link_underlying_nodes(node_initExpression, node_startLoop) else: hasInitExession = False else: hasInitExession = False if not hasInitExession: - link_nodes(node, node_startLoop) + link_underlying_nodes(node, node_startLoop) node_condition = node_startLoop if hasCondition: @@ -435,214 +487,222 @@ class FunctionSolc(Function): candidate = children[1] else: candidate = children[0] - if candidate[self.get_key()] not in ['VariableDefinitionStatement', - 'VariableDeclarationStatement', - 'ExpressionStatement']: + if candidate[self.get_key()] not in [ + "VariableDefinitionStatement", + "VariableDeclarationStatement", + "ExpressionStatement", + ]: expression = candidate - node_condition = self._new_node(NodeType.IFLOOP, expression['src']) - #expression = parse_expression(candidate, self) + node_condition = self._new_node(NodeType.IFLOOP, expression["src"]) + # expression = parse_expression(candidate, self) node_condition.add_unparsed_expression(expression) - link_nodes(node_startLoop, node_condition) - link_nodes(node_condition, node_endLoop) + link_underlying_nodes(node_startLoop, node_condition) + link_underlying_nodes(node_condition, node_endLoop) hasCondition = True else: hasCondition = False - node_statement = self._parse_statement(children[-1], node_condition) node_LoopExpression = node_statement if hasLoopExpression: if len(children) > 2: - if children[-2][self.get_key()] == 'ExpressionStatement': + if children[-2][self.get_key()] == "ExpressionStatement": node_LoopExpression = self._parse_statement(children[-2], node_statement) if not hasCondition: - link_nodes(node_LoopExpression, node_endLoop) + link_underlying_nodes(node_LoopExpression, node_endLoop) if not hasCondition and not hasLoopExpression: - link_nodes(node, node_endLoop) + link_underlying_nodes(node, node_endLoop) - link_nodes(node_LoopExpression, node_condition) + link_underlying_nodes(node_LoopExpression, node_condition) return node_endLoop - def _parse_dowhile(self, doWhilestatement, node): + def _parse_dowhile(self, do_while_statement: Dict, node: NodeSolc) -> NodeSolc: - node_startDoWhile = self._new_node(NodeType.STARTLOOP, doWhilestatement['src']) + node_startDoWhile = self._new_node(NodeType.STARTLOOP, do_while_statement["src"]) if self.is_compact_ast: - node_condition = self._new_node(NodeType.IFLOOP, doWhilestatement['condition']['src']) - node_condition.add_unparsed_expression(doWhilestatement['condition']) - statement = self._parse_statement(doWhilestatement['body'], node_condition) + node_condition = self._new_node(NodeType.IFLOOP, do_while_statement["condition"]["src"]) + node_condition.add_unparsed_expression(do_while_statement["condition"]) + statement = self._parse_statement(do_while_statement["body"], node_condition) else: - children = doWhilestatement[self.get_children('children')] + children = do_while_statement[self.get_children("children")] # same order in the AST as while expression = children[0] - node_condition = self._new_node(NodeType.IFLOOP, expression['src']) + node_condition = self._new_node(NodeType.IFLOOP, expression["src"]) node_condition.add_unparsed_expression(expression) statement = self._parse_statement(children[1], node_condition) - node_endDoWhile = self._new_node(NodeType.ENDLOOP, doWhilestatement['src']) + node_endDoWhile = self._new_node(NodeType.ENDLOOP, do_while_statement["src"]) - link_nodes(node, node_startDoWhile) + link_underlying_nodes(node, node_startDoWhile) # empty block, loop from the start to the condition - if not node_condition.sons: - link_nodes(node_startDoWhile, node_condition) + if not node_condition.underlying_node.sons: + link_underlying_nodes(node_startDoWhile, node_condition) else: - link_nodes(node_startDoWhile, node_condition.sons[0]) - link_nodes(statement, node_condition) - link_nodes(node_condition, node_endDoWhile) + link_nodes(node_startDoWhile.underlying_node, node_condition.underlying_node.sons[0]) + link_underlying_nodes(statement, node_condition) + link_underlying_nodes(node_condition, node_endDoWhile) return node_endDoWhile - def _parse_try_catch(self, statement, node): - externalCall = statement.get('externalCall', None) + def _parse_try_catch(self, statement: Dict, node: NodeSolc) -> NodeSolc: + externalCall = statement.get("externalCall", None) if externalCall is None: - raise ParsingError('Try/Catch not correctly parsed by Slither %s' % statement) + raise ParsingError("Try/Catch not correctly parsed by Slither %s" % statement) - new_node = self._new_node(NodeType.TRY, statement['src']) + new_node = self._new_node(NodeType.TRY, statement["src"]) new_node.add_unparsed_expression(externalCall) - link_nodes(node, new_node) + link_underlying_nodes(node, new_node) node = new_node - for clause in statement.get('clauses', []): + for clause in statement.get("clauses", []): self._parse_catch(clause, node) return node - def _parse_catch(self, statement, node): - block = statement.get('block', None) + def _parse_catch(self, statement: Dict, node: NodeSolc) -> NodeSolc: + block = statement.get("block", None) if block is None: - raise ParsingError('Catch not correctly parsed by Slither %s' % statement) - try_node = self._new_node(NodeType.CATCH, statement['src']) - link_nodes(node, try_node) + raise ParsingError("Catch not correctly parsed by Slither %s" % statement) + try_node = self._new_node(NodeType.CATCH, statement["src"]) + link_underlying_nodes(node, try_node) if self.is_compact_ast: - params = statement['parameters'] + params = statement["parameters"] else: - params = statement[self.get_children('children')] + params = statement[self.get_children("children")] if params: - for param in params.get('parameters', []): - assert param[self.get_key()] == 'VariableDeclaration' + for param in params.get("parameters", []): + assert param[self.get_key()] == "VariableDeclaration" self._add_param(param) return self._parse_statement(block, try_node) - def _parse_variable_definition(self, statement, node): + def _parse_variable_definition(self, statement: Dict, node: NodeSolc) -> NodeSolc: try: - local_var = LocalVariableSolc(statement) - local_var.set_function(self) - local_var.set_offset(statement['src'], self.contract.slither) + local_var = LocalVariable() + local_var.set_function(self._function) + local_var.set_offset(statement["src"], self._function.slither) - self._add_local_variable(local_var) - #local_var.analyze(self) + local_var_parser = LocalVariableSolc(local_var, statement) + self._add_local_variable(local_var_parser) + # local_var.analyze(self) - new_node = self._new_node(NodeType.VARIABLE, statement['src']) - new_node.add_variable_declaration(local_var) - link_nodes(node, new_node) + new_node = self._new_node(NodeType.VARIABLE, statement["src"]) + new_node.underlying_node.add_variable_declaration(local_var) + link_underlying_nodes(node, new_node) return new_node except MultipleVariablesDeclaration: # Custom handling of var (a,b) = .. style declaration if self.is_compact_ast: - variables = statement['declarations'] + variables = statement["declarations"] count = len(variables) - if statement['initialValue']['nodeType'] == 'TupleExpression' and \ - len(statement['initialValue']['components']) == count: - inits = statement['initialValue']['components'] + if ( + statement["initialValue"]["nodeType"] == "TupleExpression" + and len(statement["initialValue"]["components"]) == count + ): + inits = statement["initialValue"]["components"] i = 0 new_node = node for variable in variables: init = inits[i] - src = variable['src'] - i = i+1 - - new_statement = {'nodeType':'VariableDefinitionStatement', - 'src': src, - 'declarations':[variable], - 'initialValue':init} + src = variable["src"] + i = i + 1 + + new_statement = { + "nodeType": "VariableDefinitionStatement", + "src": src, + "declarations": [variable], + "initialValue": init, + } new_node = self._parse_variable_definition(new_statement, new_node) else: # If we have # var (a, b) = f() # we can split in multiple declarations, without init - # Then we craft one expression that does the assignment + # Then we craft one expression that does the assignment variables = [] i = 0 new_node = node - for variable in statement['declarations']: - i = i+1 + for variable in statement["declarations"]: + i = i + 1 if variable: - src = variable['src'] + src = variable["src"] # Create a fake statement to be consistent - new_statement = {'nodeType':'VariableDefinitionStatement', - 'src': src, - 'declarations':[variable]} + new_statement = { + "nodeType": "VariableDefinitionStatement", + "src": src, + "declarations": [variable], + } variables.append(variable) - new_node = self._parse_variable_definition_init_tuple(new_statement, - i, - new_node) + new_node = self._parse_variable_definition_init_tuple( + new_statement, i, new_node + ) var_identifiers = [] # craft of the expression doing the assignement for v in variables: identifier = { - 'nodeType':'Identifier', - 'src': v['src'], - 'name': v['name'], - 'typeDescriptions': { - 'typeString':v['typeDescriptions']['typeString'] - } + "nodeType": "Identifier", + "src": v["src"], + "name": v["name"], + "typeDescriptions": {"typeString": v["typeDescriptions"]["typeString"]}, } var_identifiers.append(identifier) - tuple_expression = {'nodeType':'TupleExpression', - 'src': statement['src'], - 'components':var_identifiers} + tuple_expression = { + "nodeType": "TupleExpression", + "src": statement["src"], + "components": var_identifiers, + } expression = { - 'nodeType' : 'Assignment', - 'src':statement['src'], - 'operator': '=', - 'type':'tuple()', - 'leftHandSide': tuple_expression, - 'rightHandSide': statement['initialValue'], - 'typeDescriptions': {'typeString':'tuple()'} - } + "nodeType": "Assignment", + "src": statement["src"], + "operator": "=", + "type": "tuple()", + "leftHandSide": tuple_expression, + "rightHandSide": statement["initialValue"], + "typeDescriptions": {"typeString": "tuple()"}, + } node = new_node - new_node = self._new_node(NodeType.EXPRESSION, statement['src']) + new_node = self._new_node(NodeType.EXPRESSION, statement["src"]) new_node.add_unparsed_expression(expression) - link_nodes(node, new_node) - + link_underlying_nodes(node, new_node) else: count = 0 - children = statement[self.get_children('children')] + children = statement[self.get_children("children")] child = children[0] - while child[self.get_key()] == 'VariableDeclaration': - count = count +1 + while child[self.get_key()] == "VariableDeclaration": + count = count + 1 child = children[count] assert len(children) == (count + 1) tuple_vars = children[count] - variables_declaration = children[0:count] i = 0 new_node = node - if tuple_vars[self.get_key()] == 'TupleExpression': - assert len(tuple_vars[self.get_children('children')]) == count + if tuple_vars[self.get_key()] == "TupleExpression": + assert len(tuple_vars[self.get_children("children")]) == count for variable in variables_declaration: - init = tuple_vars[self.get_children('children')][i] - src = variable['src'] - i = i+1 + init = tuple_vars[self.get_children("children")][i] + src = variable["src"] + i = i + 1 # Create a fake statement to be consistent - new_statement = {self.get_key():'VariableDefinitionStatement', - 'src': src, - self.get_children('children'):[variable, init]} + new_statement = { + self.get_key(): "VariableDefinitionStatement", + "src": src, + self.get_children("children"): [variable, init], + } new_node = self._parse_variable_definition(new_statement, new_node) else: @@ -650,64 +710,72 @@ class FunctionSolc(Function): # var (a, b) = f() # we can split in multiple declarations, without init # Then we craft one expression that does the assignment - assert tuple_vars[self.get_key()] in ['FunctionCall', 'Conditional'] + assert tuple_vars[self.get_key()] in ["FunctionCall", "Conditional"] variables = [] for variable in variables_declaration: - src = variable['src'] - i = i+1 + src = variable["src"] + i = i + 1 # Create a fake statement to be consistent - new_statement = {self.get_key():'VariableDefinitionStatement', - 'src': src, - self.get_children('children'):[variable]} + new_statement = { + self.get_key(): "VariableDefinitionStatement", + "src": src, + self.get_children("children"): [variable], + } variables.append(variable) - new_node = self._parse_variable_definition_init_tuple(new_statement, i, new_node) + new_node = self._parse_variable_definition_init_tuple( + new_statement, i, new_node + ) var_identifiers = [] # craft of the expression doing the assignement for v in variables: identifier = { - self.get_key() : 'Identifier', - 'src': v['src'], - 'attributes': { - 'value': v['attributes'][self.get_key()], - 'type': v['attributes']['type']} + self.get_key(): "Identifier", + "src": v["src"], + "attributes": { + "value": v["attributes"][self.get_key()], + "type": v["attributes"]["type"], + }, } var_identifiers.append(identifier) expression = { - self.get_key() : 'Assignment', - 'src':statement['src'], - 'attributes': {'operator': '=', - 'type':'tuple()'}, - self.get_children('children'): - [{self.get_key(): 'TupleExpression', - 'src': statement['src'], - self.get_children('children'): var_identifiers}, - tuple_vars]} + self.get_key(): "Assignment", + "src": statement["src"], + "attributes": {"operator": "=", "type": "tuple()"}, + self.get_children("children"): [ + { + self.get_key(): "TupleExpression", + "src": statement["src"], + self.get_children("children"): var_identifiers, + }, + tuple_vars, + ], + } node = new_node - new_node = self._new_node(NodeType.EXPRESSION, statement['src']) + new_node = self._new_node(NodeType.EXPRESSION, statement["src"]) new_node.add_unparsed_expression(expression) - link_nodes(node, new_node) - + link_underlying_nodes(node, new_node) return new_node - def _parse_variable_definition_init_tuple(self, statement, index, node): - local_var = LocalVariableInitFromTupleSolc(statement, index) - #local_var = LocalVariableSolc(statement[self.get_children('children')][0], statement[self.get_children('children')][1::]) - local_var.set_function(self) - local_var.set_offset(statement['src'], self.contract.slither) + def _parse_variable_definition_init_tuple( + self, statement: Dict, index: int, node: NodeSolc + ) -> NodeSolc: + local_var = LocalVariableInitFromTuple() + local_var.set_function(self._function) + local_var.set_offset(statement["src"], self._function.slither) - self._add_local_variable(local_var) -# local_var.analyze(self) + local_var_parser = LocalVariableInitFromTupleSolc(local_var, statement, index) - new_node = self._new_node(NodeType.VARIABLE, statement['src']) - new_node.add_variable_declaration(local_var) - link_nodes(node, new_node) - return new_node + self._add_local_variable(local_var_parser) + new_node = self._new_node(NodeType.VARIABLE, statement["src"]) + new_node.underlying_node.add_variable_declaration(local_var) + link_underlying_nodes(node, new_node) + return new_node - def _parse_statement(self, statement, node): + def _parse_statement(self, statement: Dict, node: NodeSolc) -> NodeSolc: """ Return: @@ -720,115 +788,118 @@ class FunctionSolc(Function): name = statement[self.get_key()] # SimpleStatement = VariableDefinition | ExpressionStatement - if name == 'IfStatement': + if name == "IfStatement": node = self._parse_if(statement, node) - elif name == 'WhileStatement': + elif name == "WhileStatement": node = self._parse_while(statement, node) - elif name == 'ForStatement': + elif name == "ForStatement": node = self._parse_for(statement, node) - elif name == 'Block': + elif name == "Block": node = self._parse_block(statement, node) - elif name == 'InlineAssembly': - asm_node = self._new_node(NodeType.ASSEMBLY, statement['src']) - self._contains_assembly = True + elif name == "InlineAssembly": + asm_node = self._new_node(NodeType.ASSEMBLY, statement["src"]) + self._function.contains_assembly = True # Added with solc 0.4.12 - if 'operations' in statement: - asm_node.add_inline_asm(statement['operations']) - link_nodes(node, asm_node) + if "operations" in statement: + asm_node.underlying_node.add_inline_asm(statement["operations"]) + link_underlying_nodes(node, asm_node) node = asm_node - elif name == 'DoWhileStatement': + elif name == "DoWhileStatement": node = self._parse_dowhile(statement, node) # For Continue / Break / Return / Throw # The is fixed later - elif name == 'Continue': - continue_node = self._new_node(NodeType.CONTINUE, statement['src']) - link_nodes(node, continue_node) + elif name == "Continue": + continue_node = self._new_node(NodeType.CONTINUE, statement["src"]) + link_underlying_nodes(node, continue_node) node = continue_node - elif name == 'Break': - break_node = self._new_node(NodeType.BREAK, statement['src']) - link_nodes(node, break_node) + elif name == "Break": + break_node = self._new_node(NodeType.BREAK, statement["src"]) + link_underlying_nodes(node, break_node) node = break_node - elif name == 'Return': - return_node = self._new_node(NodeType.RETURN, statement['src']) - link_nodes(node, return_node) + elif name == "Return": + return_node = self._new_node(NodeType.RETURN, statement["src"]) + link_underlying_nodes(node, return_node) if self.is_compact_ast: - if statement['expression']: - return_node.add_unparsed_expression(statement['expression']) + if statement["expression"]: + return_node.add_unparsed_expression(statement["expression"]) else: - if self.get_children('children') in statement and statement[self.get_children('children')]: - assert len(statement[self.get_children('children')]) == 1 - expression = statement[self.get_children('children')][0] + if ( + self.get_children("children") in statement + and statement[self.get_children("children")] + ): + assert len(statement[self.get_children("children")]) == 1 + expression = statement[self.get_children("children")][0] return_node.add_unparsed_expression(expression) node = return_node - elif name == 'Throw': - throw_node = self._new_node(NodeType.THROW, statement['src']) - link_nodes(node, throw_node) + elif name == "Throw": + throw_node = self._new_node(NodeType.THROW, statement["src"]) + link_underlying_nodes(node, throw_node) node = throw_node - elif name == 'EmitStatement': - #expression = parse_expression(statement[self.get_children('children')][0], self) + elif name == "EmitStatement": + # expression = parse_expression(statement[self.get_children('children')][0], self) if self.is_compact_ast: - expression = statement['eventCall'] + expression = statement["eventCall"] else: - expression = statement[self.get_children('children')][0] - new_node = self._new_node(NodeType.EXPRESSION, statement['src']) + expression = statement[self.get_children("children")][0] + new_node = self._new_node(NodeType.EXPRESSION, statement["src"]) new_node.add_unparsed_expression(expression) - link_nodes(node, new_node) + link_underlying_nodes(node, new_node) node = new_node - elif name in ['VariableDefinitionStatement', 'VariableDeclarationStatement']: + elif name in ["VariableDefinitionStatement", "VariableDeclarationStatement"]: node = self._parse_variable_definition(statement, node) - elif name == 'ExpressionStatement': - #assert len(statement[self.get_children('expression')]) == 1 - #assert not 'attributes' in statement - #expression = parse_expression(statement[self.get_children('children')][0], self) + elif name == "ExpressionStatement": + # assert len(statement[self.get_children('expression')]) == 1 + # assert not 'attributes' in statement + # expression = parse_expression(statement[self.get_children('children')][0], self) if self.is_compact_ast: - expression = statement[self.get_children('expression')] + expression = statement[self.get_children("expression")] else: - expression = statement[self.get_children('expression')][0] - new_node = self._new_node(NodeType.EXPRESSION, statement['src']) + expression = statement[self.get_children("expression")][0] + new_node = self._new_node(NodeType.EXPRESSION, statement["src"]) new_node.add_unparsed_expression(expression) - link_nodes(node, new_node) + link_underlying_nodes(node, new_node) node = new_node - elif name == 'TryStatement': + elif name == "TryStatement": node = self._parse_try_catch(statement, node) # elif name == 'TryCatchClause': # self._parse_catch(statement, node) else: - raise ParsingError('Statement not parsed %s' % name) + raise ParsingError("Statement not parsed %s" % name) return node - def _parse_block(self, block, node): - ''' + def _parse_block(self, block: Dict, node: NodeSolc): + """ Return: Node - ''' - assert block[self.get_key()] == 'Block' + """ + assert block[self.get_key()] == "Block" if self.is_compact_ast: - statements = block['statements'] + statements = block["statements"] else: - statements = block[self.get_children('children')] + statements = block[self.get_children("children")] for statement in statements: node = self._parse_statement(statement, node) return node - def _parse_cfg(self, cfg): + def _parse_cfg(self, cfg: Dict): - assert cfg[self.get_key()] == 'Block' + assert cfg[self.get_key()] == "Block" - node = self._new_node(NodeType.ENTRYPOINT, cfg['src']) - self._entry_point = node + node = self._new_node(NodeType.ENTRYPOINT, cfg["src"]) + self._function.entry_point = node.underlying_node if self.is_compact_ast: - statements = cfg['statements'] + statements = cfg["statements"] else: - statements = cfg[self.get_children('children')] + statements = cfg[self.get_children("children")] if not statements: - self._is_empty = True + self._function.is_empty = True else: - self._is_empty = False + self._function.is_empty = False self._parse_block(cfg, node) self._remove_incorrect_edges() self._remove_alone_endif() @@ -840,7 +911,7 @@ class FunctionSolc(Function): ################################################################################### ################################################################################### - def _find_end_loop(self, node, visited, counter): + def _find_end_loop(self, node: Node, visited: List[Node], counter: int) -> Optional[Node]: # counter allows to explore nested loop if node in visited: return None @@ -862,7 +933,7 @@ class FunctionSolc(Function): return None - def _find_start_loop(self, node, visited): + def _find_start_loop(self, node: Node, visited: List[Node]) -> Optional[Node]: if node in visited: return None @@ -877,7 +948,7 @@ class FunctionSolc(Function): return None - def _fix_break_node(self, node): + def _fix_break_node(self, node: Node): end_node = self._find_end_loop(node, [], 0) if not end_node: @@ -886,122 +957,130 @@ class FunctionSolc(Function): # We start with -1 as counter to catch this corner case end_node = self._find_end_loop(node, [], -1) if not end_node: - raise ParsingError('Break in no-loop context {}'.format(node.function)) + raise ParsingError("Break in no-loop context {}".format(node.function)) for son in node.sons: son.remove_father(node) node.set_sons([end_node]) end_node.add_father(node) - def _fix_continue_node(self, node): + def _fix_continue_node(self, node: Node): start_node = self._find_start_loop(node, []) if not start_node: - raise ParsingError('Continue in no-loop context {}'.format(node.nodeId())) + raise ParsingError("Continue in no-loop context {}".format(node.node_id)) for son in node.sons: son.remove_father(node) node.set_sons([start_node]) start_node.add_father(node) - def _fix_try(self, node): + def _fix_try(self, node: Node): end_node = next((son for son in node.sons if son.type != NodeType.CATCH), None) if end_node: for son in node.sons: if son.type == NodeType.CATCH: self._fix_catch(son, end_node) - def _fix_catch(self, node, end_node): + def _fix_catch(self, node: Node, end_node: Node): if not node.sons: link_nodes(node, end_node) else: for son in node.sons: self._fix_catch(son, end_node) - def _add_param(self, param): - local_var = LocalVariableSolc(param) + def _add_param(self, param: Dict) -> LocalVariableSolc: + + local_var = LocalVariable() + local_var.set_function(self._function) + local_var.set_offset(param["src"], self._function.slither) - local_var.set_function(self) - local_var.set_offset(param['src'], self.contract.slither) - local_var.analyze(self) + local_var_parser = LocalVariableSolc(local_var, param) + + local_var_parser.analyze(self) # see https://solidity.readthedocs.io/en/v0.4.24/types.html?highlight=storage%20location#data-location - if local_var.location == 'default': - local_var.set_location('memory') + if local_var.location == "default": + local_var.set_location("memory") - self._add_local_variable(local_var) - return local_var + self._add_local_variable(local_var_parser) + return local_var_parser + def _parse_params(self, params: Dict): + assert params[self.get_key()] == "ParameterList" - def _parse_params(self, params): - assert params[self.get_key()] == 'ParameterList' + self.parameters_src.set_offset(params["src"], self._function.slither) - self.parameters_src = SourceMapping() - self.parameters_src.set_offset(params['src'], self.contract.slither) - if self.is_compact_ast: - params = params['parameters'] + params = params["parameters"] else: - params = params[self.get_children('children')] + params = params[self.get_children("children")] for param in params: - assert param[self.get_key()] == 'VariableDeclaration' + assert param[self.get_key()] == "VariableDeclaration" local_var = self._add_param(param) - self._parameters.append(local_var) + self._function.add_parameters(local_var.underlying_variable) + def _parse_returns(self, returns: Dict): - def _parse_returns(self, returns): + assert returns[self.get_key()] == "ParameterList" - assert returns[self.get_key()] == 'ParameterList' + self.returns_src.set_offset(returns["src"], self._function.slither) - self.returns_src = SourceMapping() - self.returns_src.set_offset(returns['src'], self.contract.slither) - if self.is_compact_ast: - returns = returns['parameters'] + returns = returns["parameters"] else: - returns = returns[self.get_children('children')] + returns = returns[self.get_children("children")] for ret in returns: - assert ret[self.get_key()] == 'VariableDeclaration' + assert ret[self.get_key()] == "VariableDeclaration" local_var = self._add_param(ret) - self._returns.append(local_var) + self._function.add_return(local_var.underlying_variable) - - def _parse_modifier(self, modifier): + def _parse_modifier(self, modifier: Dict): m = parse_expression(modifier, self) - self._expression_modifiers.append(m) + # self._expression_modifiers.append(m) # Do not parse modifier nodes for interfaces - if not self._is_implemented: + if not self._function.is_implemented: return for m in ExportValues(m).result(): if isinstance(m, Function): - node = self._new_node(NodeType.EXPRESSION, modifier['src']) - node.add_unparsed_expression(modifier) + node_parser = self._new_node(NodeType.EXPRESSION, modifier["src"]) + node_parser.add_unparsed_expression(modifier) # The latest entry point is the entry point, or the latest modifier call - if self._modifiers: - latest_entry_point = self._modifiers[-1].nodes[-1] + if self._function.modifiers: + latest_entry_point = self._function.modifiers[-1].nodes[-1] else: - latest_entry_point = self.entry_point - insert_node(latest_entry_point, node) - self._modifiers.append(ModifierStatements(modifier=m, - entry_point=latest_entry_point, - nodes=[latest_entry_point, node])) + latest_entry_point = self._function.entry_point + insert_node(latest_entry_point, node_parser.underlying_node) + self._function.add_modifier( + ModifierStatements( + modifier=m, + entry_point=latest_entry_point, + nodes=[latest_entry_point, node_parser.underlying_node], + ) + ) elif isinstance(m, Contract): - node = self._new_node(NodeType.EXPRESSION, modifier['src']) - node.add_unparsed_expression(modifier) + node_parser = self._new_node(NodeType.EXPRESSION, modifier["src"]) + node_parser.add_unparsed_expression(modifier) # The latest entry point is the entry point, or the latest constructor call - if self._explicit_base_constructor_calls: - latest_entry_point = self._explicit_base_constructor_calls[-1].nodes[-1] + if self._function.explicit_base_constructor_calls_statements: + latest_entry_point = self._function.explicit_base_constructor_calls_statements[ + -1 + ].nodes[-1] else: - latest_entry_point = self.entry_point - insert_node(latest_entry_point, node) - self._explicit_base_constructor_calls.append(ModifierStatements(modifier=m, - entry_point=latest_entry_point, - nodes=[latest_entry_point, node])) + latest_entry_point = self._function.entry_point + insert_node(latest_entry_point, node_parser.underlying_node) + self._function.add_explicit_base_constructor_calls_statements( + ModifierStatements( + modifier=m, + entry_point=latest_entry_point, + nodes=[latest_entry_point, node_parser.underlying_node], + ) + ) # endregion ################################################################################### @@ -1011,7 +1090,7 @@ class FunctionSolc(Function): ################################################################################### def _remove_incorrect_edges(self): - for node in self._nodes: + for node in self._node_to_nodesolc.keys(): if node.type in [NodeType.RETURN, NodeType.THROW]: for son in node.sons: son.remove_father(node) @@ -1043,16 +1122,16 @@ class FunctionSolc(Function): } """ prev_nodes = [] - while set(prev_nodes) != set(self.nodes): - prev_nodes = self.nodes - to_remove = [] - for node in self.nodes: + while set(prev_nodes) != set(self._node_to_nodesolc.keys()): + prev_nodes = self._node_to_nodesolc.keys() + to_remove: List[Node] = [] + for node in self._node_to_nodesolc.keys(): if node.type == NodeType.ENDIF and not node.fathers: for son in node.sons: son.remove_father(node) node.set_sons([]) to_remove.append(node) - self._nodes = [n for n in self.nodes if not n in to_remove] + self._function.nodes = [n for n in self._function.nodes if n not in to_remove] # endregion ################################################################################### @@ -1061,18 +1140,20 @@ class FunctionSolc(Function): ################################################################################### ################################################################################### - def _filter_ternary(self): + def _filter_ternary(self) -> bool: ternary_found = True updated = False while ternary_found: ternary_found = False - for node in self._nodes: + for node in self._node_to_nodesolc.keys(): has_cond = HasConditional(node.expression) if has_cond.result(): st = SplitTernaryExpression(node.expression) condition = st.condition if not condition: - raise ParsingError(f'Incorrect ternary conversion {node.expression} {node.source_mapping_str}') + raise ParsingError( + f"Incorrect ternary conversion {node.expression} {node.source_mapping_str}" + ) true_expr = st.true_expression false_expr = st.false_expression self._split_ternary_node(node, condition, true_expr, false_expr) @@ -1081,21 +1162,27 @@ class FunctionSolc(Function): break return updated - def _split_ternary_node(self, node, condition, true_expr, false_expr): + def _split_ternary_node( + self, + node: Node, + condition: "Expression", + true_expr: "Expression", + false_expr: "Expression", + ): condition_node = self._new_node(NodeType.IF, node.source_mapping) - condition_node.add_expression(condition) + condition_node.underlying_node.add_expression(condition) condition_node.analyze_expressions(self) if node.type == NodeType.VARIABLE: - condition_node.add_variable_declaration(node.variable_declaration) + condition_node.underlying_node.add_variable_declaration(node.variable_declaration) true_node = self._new_node(NodeType.EXPRESSION, node.source_mapping) if node.type == NodeType.VARIABLE: assert isinstance(true_expr, AssignmentOperation) - #true_expr = true_expr.expression_right + # true_expr = true_expr.expression_right elif node.type == NodeType.RETURN: true_node.type = NodeType.RETURN - true_node.add_expression(true_expr) + true_node.underlying_node.add_expression(true_expr) true_node.analyze_expressions(self) false_node = self._new_node(NodeType.EXPRESSION, node.source_mapping) @@ -1103,35 +1190,30 @@ class FunctionSolc(Function): assert isinstance(false_expr, AssignmentOperation) elif node.type == NodeType.RETURN: false_node.type = NodeType.RETURN - #false_expr = false_expr.expression_right - false_node.add_expression(false_expr) + # false_expr = false_expr.expression_right + false_node.underlying_node.add_expression(false_expr) false_node.analyze_expressions(self) endif_node = self._new_node(NodeType.ENDIF, node.source_mapping) for father in node.fathers: father.remove_son(node) - father.add_son(condition_node) - condition_node.add_father(father) + father.add_son(condition_node.underlying_node) + condition_node.underlying_node.add_father(father) for son in node.sons: son.remove_father(node) - son.add_father(endif_node) - endif_node.add_son(son) - - link_nodes(condition_node, true_node) - link_nodes(condition_node, false_node) - + son.add_father(endif_node.underlying_node) + endif_node.underlying_node.add_son(son) - if not true_node.type in [NodeType.THROW, NodeType.RETURN]: - link_nodes(true_node, endif_node) - if not false_node.type in [NodeType.THROW, NodeType.RETURN]: - link_nodes(false_node, endif_node) - - self._nodes = [n for n in self._nodes if n.node_id != node.node_id] + link_underlying_nodes(condition_node, true_node) + link_underlying_nodes(condition_node, false_node) + if true_node.type not in [NodeType.THROW, NodeType.RETURN]: + link_underlying_nodes(true_node, endif_node) + if false_node.type not in [NodeType.THROW, NodeType.RETURN]: + link_underlying_nodes(false_node, endif_node) + self._function.nodes = [n for n in self._function.nodes if n.node_id != node.node_id] # endregion - - diff --git a/slither/solc_parsing/declarations/modifier.py b/slither/solc_parsing/declarations/modifier.py index 5e2748c4e..99268804a 100644 --- a/slither/solc_parsing/declarations/modifier.py +++ b/slither/solc_parsing/declarations/modifier.py @@ -1,15 +1,29 @@ """ Event module """ -from slither.core.declarations.modifier import Modifier -from slither.solc_parsing.declarations.function import FunctionSolc +from typing import Dict, TYPE_CHECKING from slither.core.cfg.node import NodeType from slither.core.cfg.node import link_nodes +from slither.core.declarations.modifier import Modifier +from slither.solc_parsing.cfg.node import NodeSolc +from slither.solc_parsing.declarations.function import FunctionSolc -class ModifierSolc(Modifier, FunctionSolc): +if TYPE_CHECKING: + from slither.solc_parsing.declarations.contract import ContractSolc +class ModifierSolc(FunctionSolc): + def __init__(self, modifier: Modifier, function_data: Dict, contract_parser: "ContractSolc"): + super().__init__(modifier, function_data, contract_parser) + # _modifier is equal to _function, but keep it here to prevent + # confusion for mypy in underlying_function + self._modifier = modifier + + @property + def underlying_function(self) -> Modifier: + return self._modifier + def analyze_params(self): # Can be re-analyzed due to inheritance if self._params_was_analyzed: @@ -20,9 +34,9 @@ class ModifierSolc(Modifier, FunctionSolc): self._analyze_attributes() if self.is_compact_ast: - params = self._functionNotParsed['parameters'] + params = self._functionNotParsed["parameters"] else: - children = self._functionNotParsed['children'] + children = self._functionNotParsed["children"] params = children[0] if params: @@ -34,41 +48,40 @@ class ModifierSolc(Modifier, FunctionSolc): self._content_was_analyzed = True - if self.is_compact_ast: - body = self._functionNotParsed['body'] + body = self._functionNotParsed["body"] - if body and body[self.get_key()] == 'Block': - self._is_implemented = True + if body and body[self.get_key()] == "Block": + self._function.is_implemented = True self._parse_cfg(body) else: - children = self._functionNotParsed['children'] + children = self._functionNotParsed["children"] - self._isImplemented = False + self._function.is_implemented = False if len(children) > 1: assert len(children) == 2 block = children[1] - assert block['name'] == 'Block' - self._is_implemented = True + assert block["name"] == "Block" + self._function.is_implemented = True self._parse_cfg(block) - for local_vars in self.variables: - local_vars.analyze(self) + for local_var_parser in self._local_variables_parser: + local_var_parser.analyze(self) - for node in self.nodes: + for node in self._node_to_nodesolc.values(): node.analyze_expressions(self) self._filter_ternary() self._remove_alone_endif() - self._analyze_read_write() - self._analyze_calls() + # self._analyze_read_write() + # self._analyze_calls() - def _parse_statement(self, statement, node): + def _parse_statement(self, statement: Dict, node: NodeSolc) -> NodeSolc: name = statement[self.get_key()] - if name == 'PlaceholderStatement': - placeholder_node = self._new_node(NodeType.PLACEHOLDER, statement['src']) - link_nodes(node, placeholder_node) + if name == "PlaceholderStatement": + placeholder_node = self._new_node(NodeType.PLACEHOLDER, statement["src"]) + link_nodes(node.underlying_node, placeholder_node.underlying_node) return placeholder_node return super(ModifierSolc, self)._parse_statement(statement, node) diff --git a/slither/solc_parsing/declarations/structure.py b/slither/solc_parsing/declarations/structure.py index 825e85c98..513d2a5b8 100644 --- a/slither/solc_parsing/declarations/structure.py +++ b/slither/solc_parsing/declarations/structure.py @@ -1,34 +1,47 @@ """ Structure module """ +from typing import List, TYPE_CHECKING + +from slither.core.variables.structure_variable import StructureVariable from slither.solc_parsing.variables.structure_variable import StructureVariableSolc from slither.core.declarations.structure import Structure -class StructureSolc(Structure): +if TYPE_CHECKING: + from slither.solc_parsing.declarations.contract import ContractSolc + + +class StructureSolc: """ Structure class """ - # elems = [(type, name)] + # elems = [(type, name)] - def __init__(self, name, canonicalName, elems): - super(StructureSolc, self).__init__() - self._name = name - self._canonical_name = canonicalName - self._elems = {} - self._elems_ordered = [] + def __init__( + self, + st: Structure, + name: str, + canonicalName: str, + elems: List[str], + contract_parser: "ContractSolc", + ): + self._structure = st + st.name = name + st.canonical_name = canonicalName + self._contract_parser = contract_parser self._elemsNotParsed = elems def analyze(self): for elem_to_parse in self._elemsNotParsed: - elem = StructureVariableSolc(elem_to_parse) - elem.set_structure(self) - elem.set_offset(elem_to_parse['src'], self.contract.slither) + elem = StructureVariable() + elem.set_structure(self._structure) + elem.set_offset(elem_to_parse["src"], self._structure.contract.slither) - elem.analyze(self.contract) + elem_parser = StructureVariableSolc(elem, elem_to_parse) + elem_parser.analyze(self._contract_parser) - self._elems[elem.name] = elem - self._elems_ordered.append(elem.name) + self._structure.elems[elem.name] = elem + self._structure.add_elem_in_order(elem.name) self._elemsNotParsed = [] - diff --git a/slither/solc_parsing/exceptions.py b/slither/solc_parsing/exceptions.py index b9797a65c..cf3928f90 100644 --- a/slither/solc_parsing/exceptions.py +++ b/slither/solc_parsing/exceptions.py @@ -1,5 +1,9 @@ from slither.exceptions import SlitherException -class ParsingError(SlitherException): pass -class VariableNotFound(SlitherException): pass +class ParsingError(SlitherException): + pass + + +class VariableNotFound(SlitherException): + pass diff --git a/slither/solc_parsing/expressions/expression_parsing.py b/slither/solc_parsing/expressions/expression_parsing.py index 5323a0f08..d2ee972df 100644 --- a/slither/solc_parsing/expressions/expression_parsing.py +++ b/slither/solc_parsing/expressions/expression_parsing.py @@ -1,23 +1,26 @@ import logging import re +from typing import Dict, TYPE_CHECKING, Optional, Union +from slither.core.declarations import Event, Enum, Structure from slither.core.declarations.contract import Contract from slither.core.declarations.function import Function -from slither.core.declarations.solidity_variables import (SOLIDITY_FUNCTIONS, - SOLIDITY_VARIABLES, - SOLIDITY_VARIABLES_COMPOSED, - SolidityFunction, - SolidityVariable, - SolidityVariableComposed) -from slither.core.expressions.assignment_operation import (AssignmentOperation, - AssignmentOperationType) -from slither.core.expressions.binary_operation import (BinaryOperation, - BinaryOperationType) +from slither.core.declarations.solidity_variables import ( + SOLIDITY_FUNCTIONS, + SOLIDITY_VARIABLES, + SOLIDITY_VARIABLES_COMPOSED, + SolidityFunction, + SolidityVariable, + SolidityVariableComposed, +) +from slither.core.expressions.assignment_operation import ( + AssignmentOperation, + AssignmentOperationType, +) +from slither.core.expressions.binary_operation import BinaryOperation, BinaryOperationType from slither.core.expressions.call_expression import CallExpression -from slither.core.expressions.conditional_expression import \ - ConditionalExpression -from slither.core.expressions.elementary_type_name_expression import \ - ElementaryTypeNameExpression +from slither.core.expressions.conditional_expression import ConditionalExpression +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression from slither.core.expressions.identifier import Identifier from slither.core.expressions.index_access import IndexAccess from slither.core.expressions.literal import Literal @@ -29,16 +32,18 @@ from slither.core.expressions.super_call_expression import SuperCallExpression from slither.core.expressions.super_identifier import SuperIdentifier from slither.core.expressions.tuple_expression import TupleExpression from slither.core.expressions.type_conversion import TypeConversion -from slither.core.expressions.unary_operation import (UnaryOperation, - UnaryOperationType) -from slither.core.solidity_types import (ArrayType, ElementaryType, - FunctionType, MappingType) -from slither.solc_parsing.solidity_types.type_parsing import (UnknownType, - parse_type) +from slither.core.expressions.unary_operation import UnaryOperation, UnaryOperationType +from slither.core.solidity_types import ArrayType, ElementaryType, FunctionType, MappingType +from slither.core.variables.variable import Variable from slither.solc_parsing.exceptions import ParsingError, VariableNotFound +from slither.solc_parsing.solidity_types.type_parsing import UnknownType, parse_type -logger = logging.getLogger("ExpressionParsing") +if TYPE_CHECKING: + from slither.core.expressions.expression import Expression + from slither.solc_parsing.declarations.function import FunctionSolc + from slither.solc_parsing.declarations.contract import ContractSolc +logger = logging.getLogger("ExpressionParsing") ################################################################################### ################################################################################### @@ -46,21 +51,34 @@ logger = logging.getLogger("ExpressionParsing") ################################################################################### ################################################################################### -def get_pointer_name(variable): +CallerContext = Union["ContractSolc", "FunctionSolc"] + + +def get_pointer_name(variable: Variable): curr_type = variable.type - while (isinstance(curr_type, (ArrayType, MappingType))): + while isinstance(curr_type, (ArrayType, MappingType)): if isinstance(curr_type, ArrayType): curr_type = curr_type.type else: assert isinstance(curr_type, MappingType) curr_type = curr_type.type_to - if isinstance(curr_type, (FunctionType)): + if isinstance(curr_type, FunctionType): return variable.name + curr_type.parameters_signature return None -def find_variable(var_name, caller_context, referenced_declaration=None, is_super=False): +def find_variable( + var_name: str, + caller_context: CallerContext, + referenced_declaration: Optional[int] = None, + is_super=False, +) -> Union[ + Variable, Function, Contract, SolidityVariable, SolidityFunction, Event, Enum, Structure +]: + from slither.solc_parsing.declarations.contract import ContractSolc + from slither.solc_parsing.declarations.function import FunctionSolc + # variable are looked from the contract declarer # functions can be shadowed, but are looked from the contract instance, rather than the contract declarer # the difference between function and variable come from the fact that an internal call, or an variable access @@ -76,24 +94,24 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe # for events it's unclear what should be the behavior, as they can be shadowed, but there is not impact # structure/enums cannot be shadowed - if isinstance(caller_context, Contract): - function = None - contract = caller_context - contract_declarer = caller_context - elif isinstance(caller_context, Function): + if isinstance(caller_context, ContractSolc): + function: Optional[FunctionSolc] = None + contract = caller_context.underlying_contract + contract_declarer = caller_context.underlying_contract + elif isinstance(caller_context, FunctionSolc): function = caller_context - contract = function.contract - contract_declarer = function.contract_declarer + contract = function.underlying_function.contract + contract_declarer = function.underlying_function.contract_declarer else: - raise ParsingError('Incorrect caller context') + raise ParsingError("Incorrect caller context") if function: # We look for variable declared with the referencedDeclaration attr func_variables = function.variables_renamed if referenced_declaration and referenced_declaration in func_variables: - return func_variables[referenced_declaration] + return func_variables[referenced_declaration].underlying_variable # If not found, check for name - func_variables = function.variables_as_dict() + func_variables = function.underlying_function.variables_as_dict if var_name in func_variables: return func_variables[var_name] # A local variable can be a pointer @@ -101,12 +119,14 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe # function test(function(uint) internal returns(bool) t) interna{ # Will have a local variable t which will match the signature # t(uint256) - func_variables_ptr = {get_pointer_name(f): f for f in function.variables} + func_variables_ptr = { + get_pointer_name(f): f for f in function.underlying_function.variables + } if var_name and var_name in func_variables_ptr: return func_variables_ptr[var_name] # variable are looked from the contract declarer - contract_variables = contract_declarer.variables_as_dict() + contract_variables = contract_declarer.variables_as_dict if var_name in contract_variables: return contract_variables[var_name] @@ -118,8 +138,12 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe if is_super: getter_available = lambda f: f.functions_declared d = {f.canonical_name: f for f in contract.functions} - functions = {f.full_name: f for f in - contract_declarer.available_elements_from_inheritances(d, getter_available).values()} + functions = { + f.full_name: f + for f in contract_declarer.available_elements_from_inheritances( + d, getter_available + ).values() + } else: functions = contract.available_functions_as_dict() if var_name in functions: @@ -128,23 +152,27 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe if is_super: getter_available = lambda m: m.modifiers_declared d = {m.canonical_name: m for m in contract.modifiers} - modifiers = {m.full_name: m for m in - contract_declarer.available_elements_from_inheritances(d, getter_available).values()} + modifiers = { + m.full_name: m + for m in contract_declarer.available_elements_from_inheritances( + d, getter_available + ).values() + } else: modifiers = contract.available_modifiers_as_dict() if var_name in modifiers: return modifiers[var_name] # structures are looked on the contract declarer - structures = contract.structures_as_dict() + structures = contract.structures_as_dict if var_name in structures: return structures[var_name] - events = contract.events_as_dict() + events = contract.events_as_dict if var_name in events: return events[var_name] - enums = contract.enums_as_dict() + enums = contract.enums_as_dict if var_name in enums: return enums[var_name] @@ -154,7 +182,7 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe return enums[var_name] # Could refer to any enum - all_enums = [c.enums_as_dict() for c in contract.slither.contracts] + all_enums = [c.enums_as_dict for c in contract.slither.contracts] all_enums = {k: v for d in all_enums for k, v in d.items()} if var_name in all_enums: return all_enums[var_name] @@ -165,19 +193,22 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe if var_name in SOLIDITY_FUNCTIONS: return SolidityFunction(var_name) - contracts = contract.slither.contracts_as_dict() + contracts = contract.slither.contracts_as_dict if var_name in contracts: return contracts[var_name] if referenced_declaration: - for contract in contract.slither.contracts: - if contract.id == referenced_declaration: - return contract - for function in contract.slither.functions: - if function.referenced_declaration == referenced_declaration: - return function + # id of the contracts is the referenced declaration + # This is not true for the functions, as we dont always have the referenced_declaration + # But maybe we could? (TODO) + for contract_candidate in contract.slither.contracts: + if contract_candidate.id == referenced_declaration: + return contract_candidate + for function_candidate in caller_context.slither_parser.all_functions_parser: + if function_candidate.referenced_declaration == referenced_declaration: + return function_candidate.underlying_function - raise VariableNotFound('Variable not found: {} (context {})'.format(var_name, caller_context)) + raise VariableNotFound("Variable not found: {} (context {})".format(var_name, caller_context)) # endregion @@ -187,38 +218,39 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe ################################################################################### ################################################################################### -def filter_name(value): - value = value.replace(' memory', '') - value = value.replace(' storage', '') - value = value.replace(' external', '') - value = value.replace(' internal', '') - value = value.replace('struct ', '') - value = value.replace('contract ', '') - value = value.replace('enum ', '') - value = value.replace(' ref', '') - value = value.replace(' pointer', '') - value = value.replace(' pure', '') - value = value.replace(' view', '') - value = value.replace(' constant', '') - value = value.replace(' payable', '') - value = value.replace('function (', 'function(') - value = value.replace('returns (', 'returns(') + +def filter_name(value: str) -> str: + value = value.replace(" memory", "") + value = value.replace(" storage", "") + value = value.replace(" external", "") + value = value.replace(" internal", "") + value = value.replace("struct ", "") + value = value.replace("contract ", "") + value = value.replace("enum ", "") + value = value.replace(" ref", "") + value = value.replace(" pointer", "") + value = value.replace(" pure", "") + value = value.replace(" view", "") + value = value.replace(" constant", "") + value = value.replace(" payable", "") + value = value.replace("function (", "function(") + value = value.replace("returns (", "returns(") # remove the text remaining after functio(...) # which should only be ..returns(...) # nested parenthesis so we use a system of counter on parenthesis - idx = value.find('(') + idx = value.find("(") if idx: counter = 1 max_idx = len(value) while counter: assert idx < max_idx idx = idx + 1 - if value[idx] == '(': + if value[idx] == "(": counter += 1 - elif value[idx] == ')': + elif value[idx] == ")": counter -= 1 - value = value[:idx + 1] + value = value[: idx + 1] return value @@ -230,35 +262,38 @@ def filter_name(value): ################################################################################### ################################################################################### -def parse_call(expression, caller_context): - src = expression['src'] + +def parse_call(expression: Dict, caller_context): + src = expression["src"] if caller_context.is_compact_ast: attributes = expression - type_conversion = expression['kind'] == 'typeConversion' - type_return = attributes['typeDescriptions']['typeString'] + type_conversion = expression["kind"] == "typeConversion" + type_return = attributes["typeDescriptions"]["typeString"] else: - attributes = expression['attributes'] - type_conversion = attributes['type_conversion'] - type_return = attributes['type'] + attributes = expression["attributes"] + type_conversion = attributes["type_conversion"] + type_return = attributes["type"] if type_conversion: type_call = parse_type(UnknownType(type_return), caller_context) if caller_context.is_compact_ast: - assert len(expression['arguments']) == 1 - expression_to_parse = expression['arguments'][0] + assert len(expression["arguments"]) == 1 + expression_to_parse = expression["arguments"][0] else: - children = expression['children'] + children = expression["children"] assert len(children) == 2 type_info = children[0] expression_to_parse = children[1] - assert type_info['name'] in ['ElementaryTypenameExpression', - 'ElementaryTypeNameExpression', - 'Identifier', - 'TupleExpression', - 'IndexAccess', - 'MemberAccess'] + assert type_info["name"] in [ + "ElementaryTypenameExpression", + "ElementaryTypeNameExpression", + "Identifier", + "TupleExpression", + "IndexAccess", + "MemberAccess", + ] expression = parse_expression(expression_to_parse, caller_context) t = TypeConversion(expression, type_call) @@ -269,33 +304,33 @@ def parse_call(expression, caller_context): call_value = None call_salt = None if caller_context.is_compact_ast: - called = parse_expression(expression['expression'], caller_context) + called = parse_expression(expression["expression"], caller_context) # If the next expression is a FunctionCallOptions # We can here the gas/value information # This is only available if the syntax is {gas: , value: } # For the .gas().value(), the member are considered as function call # And converted later to the correct info (convert.py) - if expression['expression'][caller_context.get_key()] == 'FunctionCallOptions': - call_with_options = expression['expression'] - for idx, name in enumerate(call_with_options.get('names', [])): - option = parse_expression(call_with_options['options'][idx], caller_context) - if name == 'value': + if expression["expression"][caller_context.get_key()] == "FunctionCallOptions": + call_with_options = expression["expression"] + for idx, name in enumerate(call_with_options.get("names", [])): + option = parse_expression(call_with_options["options"][idx], caller_context) + if name == "value": call_value = option - if name == 'gas': + if name == "gas": call_gas = option - if name == 'salt': + if name == "salt": call_salt = option arguments = [] - if expression['arguments']: - arguments = [parse_expression(a, caller_context) for a in expression['arguments']] + if expression["arguments"]: + arguments = [parse_expression(a, caller_context) for a in expression["arguments"]] else: - children = expression['children'] + children = expression["children"] called = parse_expression(children[0], caller_context) arguments = [parse_expression(a, caller_context) for a in children[1::]] if isinstance(called, SuperCallExpression): sp = SuperCallExpression(called, arguments, type_return) - sp.set_offset(expression['src'], caller_context.slither) + sp.set_offset(expression["src"], caller_context.slither) return sp call_expression = CallExpression(called, arguments, type_return) call_expression.set_offset(src, caller_context.slither) @@ -307,46 +342,48 @@ def parse_call(expression, caller_context): return call_expression -def parse_super_name(expression, is_compact_ast): +def parse_super_name(expression: Dict, is_compact_ast: bool) -> str: if is_compact_ast: - assert expression['nodeType'] == 'MemberAccess' - base_name = expression['memberName'] - arguments = expression['typeDescriptions']['typeString'] + assert expression["nodeType"] == "MemberAccess" + base_name = expression["memberName"] + arguments = expression["typeDescriptions"]["typeString"] else: - assert expression['name'] == 'MemberAccess' - attributes = expression['attributes'] - base_name = attributes['member_name'] - arguments = attributes['type'] + assert expression["name"] == "MemberAccess" + attributes = expression["attributes"] + base_name = attributes["member_name"] + arguments = attributes["type"] - assert arguments.startswith('function ') + assert arguments.startswith("function ") # remove function (...() - arguments = arguments[len('function '):] + arguments = arguments[len("function ") :] arguments = filter_name(arguments) - if ' ' in arguments: - arguments = arguments[:arguments.find(' ')] + if " " in arguments: + arguments = arguments[: arguments.find(" ")] return base_name + arguments -def _parse_elementary_type_name_expression(expression, is_compact_ast, caller_context): +def _parse_elementary_type_name_expression( + expression: Dict, is_compact_ast: bool, caller_context +) -> ElementaryTypeNameExpression: # nop exression # uint; if is_compact_ast: - value = expression['typeName'] + value = expression["typeName"] else: - assert 'children' not in expression - value = expression['attributes']['value'] + assert "children" not in expression + value = expression["attributes"]["value"] if isinstance(value, dict): t = parse_type(value, caller_context) else: t = parse_type(UnknownType(value), caller_context) e = ElementaryTypeNameExpression(t) - e.set_offset(expression['src'], caller_context.slither) + e.set_offset(expression["src"], caller_context.slither) return e -def parse_expression(expression, caller_context): +def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expression": """ Returns: @@ -375,56 +412,56 @@ def parse_expression(expression, caller_context): # | Expression ('=' | '|=' | '^=' | '&=' | '<<=' | '>>=' | '+=' | '-=' | '*=' | '/=' | '%=') Expression # | PrimaryExpression - # The AST naming does not follow the spec + # The AST naming does not follow the spec name = expression[caller_context.get_key()] is_compact_ast = caller_context.is_compact_ast - src = expression['src'] + src = expression["src"] - if name == 'UnaryOperation': + if name == "UnaryOperation": if is_compact_ast: attributes = expression else: - attributes = expression['attributes'] - assert 'prefix' in attributes - operation_type = UnaryOperationType.get_type(attributes['operator'], attributes['prefix']) + attributes = expression["attributes"] + assert "prefix" in attributes + operation_type = UnaryOperationType.get_type(attributes["operator"], attributes["prefix"]) if is_compact_ast: - expression = parse_expression(expression['subExpression'], caller_context) + expression = parse_expression(expression["subExpression"], caller_context) else: - assert len(expression['children']) == 1 - expression = parse_expression(expression['children'][0], caller_context) + assert len(expression["children"]) == 1 + expression = parse_expression(expression["children"][0], caller_context) unary_op = UnaryOperation(expression, operation_type) unary_op.set_offset(src, caller_context.slither) return unary_op - elif name == 'BinaryOperation': + elif name == "BinaryOperation": if is_compact_ast: attributes = expression else: - attributes = expression['attributes'] - operation_type = BinaryOperationType.get_type(attributes['operator']) + attributes = expression["attributes"] + operation_type = BinaryOperationType.get_type(attributes["operator"]) if is_compact_ast: - left_expression = parse_expression(expression['leftExpression'], caller_context) - right_expression = parse_expression(expression['rightExpression'], caller_context) + left_expression = parse_expression(expression["leftExpression"], caller_context) + right_expression = parse_expression(expression["rightExpression"], caller_context) else: - assert len(expression['children']) == 2 - left_expression = parse_expression(expression['children'][0], caller_context) - right_expression = parse_expression(expression['children'][1], caller_context) + assert len(expression["children"]) == 2 + left_expression = parse_expression(expression["children"][0], caller_context) + right_expression = parse_expression(expression["children"][1], caller_context) binary_op = BinaryOperation(left_expression, right_expression, operation_type) binary_op.set_offset(src, caller_context.slither) return binary_op - elif name in 'FunctionCall': + elif name in "FunctionCall": return parse_call(expression, caller_context) - elif name == 'FunctionCallOptions': + elif name == "FunctionCallOptions": # call/gas info are handled in parse_call - called = parse_expression(expression['expression'], caller_context) + called = parse_expression(expression["expression"], caller_context) assert isinstance(called, (MemberAccess, NewContract)) return called - elif name == 'TupleExpression': + elif name == "TupleExpression": """ For expression like (a,,c) = (1,2,3) @@ -437,35 +474,39 @@ def parse_expression(expression, caller_context): Note: this is only possible with Solidity >= 0.4.12 """ if is_compact_ast: - expressions = [parse_expression(e, caller_context) if e else None for e in expression['components']] + expressions = [ + parse_expression(e, caller_context) if e else None for e in expression["components"] + ] else: - if 'children' not in expression: - attributes = expression['attributes'] - components = attributes['components'] - expressions = [parse_expression(c, caller_context) if c else None for c in components] + if "children" not in expression: + attributes = expression["attributes"] + components = attributes["components"] + expressions = [ + parse_expression(c, caller_context) if c else None for c in components + ] else: - expressions = [parse_expression(e, caller_context) for e in expression['children']] + expressions = [parse_expression(e, caller_context) for e in expression["children"]] # Add none for empty tuple items if "attributes" in expression: - if "type" in expression['attributes']: - t = expression['attributes']['type'] - if ',,' in t or '(,' in t or ',)' in t: - t = t[len('tuple('):-1] - elems = t.split(',') + if "type" in expression["attributes"]: + t = expression["attributes"]["type"] + if ",," in t or "(," in t or ",)" in t: + t = t[len("tuple(") : -1] + elems = t.split(",") for idx in range(len(elems)): - if elems[idx] == '': + if elems[idx] == "": expressions.insert(idx, None) t = TupleExpression(expressions) t.set_offset(src, caller_context.slither) return t - elif name == 'Conditional': + elif name == "Conditional": if is_compact_ast: - if_expression = parse_expression(expression['condition'], caller_context) - then_expression = parse_expression(expression['trueExpression'], caller_context) - else_expression = parse_expression(expression['falseExpression'], caller_context) + if_expression = parse_expression(expression["condition"], caller_context) + then_expression = parse_expression(expression["trueExpression"], caller_context) + else_expression = parse_expression(expression["falseExpression"], caller_context) else: - children = expression['children'] + children = expression["children"] assert len(children) == 3 if_expression = parse_expression(children[0], caller_context) then_expression = parse_expression(children[1], caller_context) @@ -474,100 +515,103 @@ def parse_expression(expression, caller_context): conditional.set_offset(src, caller_context.slither) return conditional - elif name == 'Assignment': + elif name == "Assignment": if is_compact_ast: - left_expression = parse_expression(expression['leftHandSide'], caller_context) - right_expression = parse_expression(expression['rightHandSide'], caller_context) + left_expression = parse_expression(expression["leftHandSide"], caller_context) + right_expression = parse_expression(expression["rightHandSide"], caller_context) - operation_type = AssignmentOperationType.get_type(expression['operator']) + operation_type = AssignmentOperationType.get_type(expression["operator"]) - operation_return_type = expression['typeDescriptions']['typeString'] + operation_return_type = expression["typeDescriptions"]["typeString"] else: - attributes = expression['attributes'] - children = expression['children'] - assert len(expression['children']) == 2 + attributes = expression["attributes"] + children = expression["children"] + assert len(expression["children"]) == 2 left_expression = parse_expression(children[0], caller_context) right_expression = parse_expression(children[1], caller_context) - operation_type = AssignmentOperationType.get_type(attributes['operator']) - operation_return_type = attributes['type'] + operation_type = AssignmentOperationType.get_type(attributes["operator"]) + operation_return_type = attributes["type"] - assignement = AssignmentOperation(left_expression, right_expression, operation_type, operation_return_type) + assignement = AssignmentOperation( + left_expression, right_expression, operation_type, operation_return_type + ) assignement.set_offset(src, caller_context.slither) return assignement - - - elif name == 'Literal': + elif name == "Literal": subdenomination = None - assert 'children' not in expression + assert "children" not in expression if is_compact_ast: - value = expression['value'] + value = expression["value"] if value: - if 'subdenomination' in expression and expression['subdenomination']: - subdenomination = expression['subdenomination'] + if "subdenomination" in expression and expression["subdenomination"]: + subdenomination = expression["subdenomination"] elif not value and value != "": - value = '0x' + expression['hexValue'] - type = expression['typeDescriptions']['typeString'] + value = "0x" + expression["hexValue"] + type_candidate = expression["typeDescriptions"]["typeString"] # Length declaration for array was None until solc 0.5.5 - if type is None: - if expression['kind'] == 'number': - type = 'int_const' + if type_candidate is None: + if expression["kind"] == "number": + type_candidate = "int_const" else: - value = expression['attributes']['value'] + value = expression["attributes"]["value"] if value: - if 'subdenomination' in expression['attributes'] and expression['attributes']['subdenomination']: - subdenomination = expression['attributes']['subdenomination'] + if ( + "subdenomination" in expression["attributes"] + and expression["attributes"]["subdenomination"] + ): + subdenomination = expression["attributes"]["subdenomination"] elif value is None: # for literal declared as hex # see https://solidity.readthedocs.io/en/v0.4.25/types.html?highlight=hex#hexadecimal-literals - assert 'hexvalue' in expression['attributes'] - value = '0x' + expression['attributes']['hexvalue'] - type = expression['attributes']['type'] + assert "hexvalue" in expression["attributes"] + value = "0x" + expression["attributes"]["hexvalue"] + type_candidate = expression["attributes"]["type"] - if type is None: + if type_candidate is None: if value.isdecimal(): - type = ElementaryType('uint256') + type_candidate = ElementaryType("uint256") else: - type = ElementaryType('string') - elif type.startswith('int_const '): - type = ElementaryType('uint256') - elif type.startswith('bool'): - type = ElementaryType('bool') - elif type.startswith('address'): - type = ElementaryType('address') + type_candidate = ElementaryType("string") + elif type_candidate.startswith("int_const "): + type_candidate = ElementaryType("uint256") + elif type_candidate.startswith("bool"): + type_candidate = ElementaryType("bool") + elif type_candidate.startswith("address"): + type_candidate = ElementaryType("address") else: - type = ElementaryType('string') - literal = Literal(value, type, subdenomination) + type_candidate = ElementaryType("string") + literal = Literal(value, type_candidate, subdenomination) literal.set_offset(src, caller_context.slither) return literal - elif name == 'Identifier': - assert 'children' not in expression + elif name == "Identifier": + assert "children" not in expression t = None if caller_context.is_compact_ast: - value = expression['name'] - t = expression['typeDescriptions']['typeString'] + value = expression["name"] + t = expression["typeDescriptions"]["typeString"] else: - value = expression['attributes']['value'] - if 'type' in expression['attributes']: - t = expression['attributes']['type'] + value = expression["attributes"]["value"] + if "type" in expression["attributes"]: + t = expression["attributes"]["type"] if t: - found = re.findall('[struct|enum|function|modifier] \(([\[\] ()a-zA-Z0-9\.,_]*)\)', t) + found = re.findall("[struct|enum|function|modifier] \(([\[\] ()a-zA-Z0-9\.,_]*)\)", t) assert len(found) <= 1 if found: - value = value + '(' + found[0] + ')' + value = value + "(" + found[0] + ")" value = filter_name(value) - if 'referencedDeclaration' in expression: - referenced_declaration = expression['referencedDeclaration'] + if "referencedDeclaration" in expression: + referenced_declaration = expression["referencedDeclaration"] else: referenced_declaration = None @@ -577,14 +621,14 @@ def parse_expression(expression, caller_context): identifier.set_offset(src, caller_context.slither) return identifier - elif name == 'IndexAccess': + elif name == "IndexAccess": if is_compact_ast: - index_type = expression['typeDescriptions']['typeString'] - left = expression['baseExpression'] - right = expression['indexExpression'] + index_type = expression["typeDescriptions"]["typeString"] + left = expression["baseExpression"] + right = expression["indexExpression"] else: - index_type = expression['attributes']['type'] - children = expression['children'] + index_type = expression["attributes"]["type"] + children = expression["children"] assert len(children) == 2 left = children[0] right = children[1] @@ -600,22 +644,22 @@ def parse_expression(expression, caller_context): index.set_offset(src, caller_context.slither) return index - elif name == 'MemberAccess': + elif name == "MemberAccess": if caller_context.is_compact_ast: - member_name = expression['memberName'] - member_type = expression['typeDescriptions']['typeString'] - member_expression = parse_expression(expression['expression'], caller_context) + member_name = expression["memberName"] + member_type = expression["typeDescriptions"]["typeString"] + member_expression = parse_expression(expression["expression"], caller_context) else: - member_name = expression['attributes']['member_name'] - member_type = expression['attributes']['type'] - children = expression['children'] + member_name = expression["attributes"]["member_name"] + member_type = expression["attributes"]["type"] + children = expression["children"] assert len(children) == 1 member_expression = parse_expression(children[0], caller_context) - if str(member_expression) == 'super': + if str(member_expression) == "super": super_name = parse_super_name(expression, is_compact_ast) var = find_variable(super_name, caller_context, is_super=True) if var is None: - raise VariableNotFound('Variable not found: {}'.format(super_name)) + raise VariableNotFound("Variable not found: {}".format(super_name)) sup = SuperIdentifier(var) sup.set_offset(src, caller_context.slither) return sup @@ -627,81 +671,82 @@ def parse_expression(expression, caller_context): return idx return member_access - elif name == 'ElementaryTypeNameExpression': + elif name == "ElementaryTypeNameExpression": return _parse_elementary_type_name_expression(expression, is_compact_ast, caller_context) - # NewExpression is not a root expression, it's always the child of another expression - elif name == 'NewExpression': + elif name == "NewExpression": if is_compact_ast: - type_name = expression['typeName'] + type_name = expression["typeName"] else: - children = expression['children'] + children = expression["children"] assert len(children) == 1 type_name = children[0] - if type_name[caller_context.get_key()] == 'ArrayTypeName': + if type_name[caller_context.get_key()] == "ArrayTypeName": depth = 0 - while type_name[caller_context.get_key()] == 'ArrayTypeName': + while type_name[caller_context.get_key()] == "ArrayTypeName": # Note: dont conserve the size of the array if provided # We compute it directly if is_compact_ast: - type_name = type_name['baseType'] + type_name = type_name["baseType"] else: - type_name = type_name['children'][0] + type_name = type_name["children"][0] depth += 1 - if type_name[caller_context.get_key()] == 'ElementaryTypeName': + if type_name[caller_context.get_key()] == "ElementaryTypeName": if is_compact_ast: - array_type = ElementaryType(type_name['name']) + array_type = ElementaryType(type_name["name"]) else: - array_type = ElementaryType(type_name['attributes']['name']) - elif type_name[caller_context.get_key()] == 'UserDefinedTypeName': + array_type = ElementaryType(type_name["attributes"]["name"]) + elif type_name[caller_context.get_key()] == "UserDefinedTypeName": if is_compact_ast: - array_type = parse_type(UnknownType(type_name['name']), caller_context) + array_type = parse_type(UnknownType(type_name["name"]), caller_context) else: - array_type = parse_type(UnknownType(type_name['attributes']['name']), caller_context) - elif type_name[caller_context.get_key()] == 'FunctionTypeName': + array_type = parse_type( + UnknownType(type_name["attributes"]["name"]), caller_context + ) + elif type_name[caller_context.get_key()] == "FunctionTypeName": array_type = parse_type(type_name, caller_context) else: - raise ParsingError('Incorrect type array {}'.format(type_name)) + raise ParsingError("Incorrect type array {}".format(type_name)) array = NewArray(depth, array_type) array.set_offset(src, caller_context.slither) return array - if type_name[caller_context.get_key()] == 'ElementaryTypeName': + if type_name[caller_context.get_key()] == "ElementaryTypeName": if is_compact_ast: - elem_type = ElementaryType(type_name['name']) + elem_type = ElementaryType(type_name["name"]) else: - elem_type = ElementaryType(type_name['attributes']['name']) + elem_type = ElementaryType(type_name["attributes"]["name"]) new_elem = NewElementaryType(elem_type) new_elem.set_offset(src, caller_context.slither) return new_elem - assert type_name[caller_context.get_key()] == 'UserDefinedTypeName' + assert type_name[caller_context.get_key()] == "UserDefinedTypeName" if is_compact_ast: - contract_name = type_name['name'] + contract_name = type_name["name"] else: - contract_name = type_name['attributes']['name'] + contract_name = type_name["attributes"]["name"] new = NewContract(contract_name) new.set_offset(src, caller_context.slither) return new - elif name == 'ModifierInvocation': + elif name == "ModifierInvocation": if is_compact_ast: - called = parse_expression(expression['modifierName'], caller_context) + called = parse_expression(expression["modifierName"], caller_context) arguments = [] - if expression['arguments']: - arguments = [parse_expression(a, caller_context) for a in expression['arguments']] + if expression["arguments"]: + arguments = [parse_expression(a, caller_context) for a in expression["arguments"]] else: - children = expression['children'] + children = expression["children"] called = parse_expression(children[0], caller_context) arguments = [parse_expression(a, caller_context) for a in children[1::]] - call = CallExpression(called, arguments, 'Modifier') + call = CallExpression(called, arguments, "Modifier") call.set_offset(src, caller_context.slither) return call - raise ParsingError('Expression not parsed %s' % name) + raise ParsingError("Expression not parsed %s" % name) diff --git a/slither/solc_parsing/slitherSolc.py b/slither/solc_parsing/slitherSolc.py index cd786a985..d5de1b56a 100644 --- a/slither/solc_parsing/slitherSolc.py +++ b/slither/solc_parsing/slitherSolc.py @@ -1,8 +1,8 @@ -import os import json -import re import logging -from typing import Optional, List +import os +import re +from typing import List, Dict from slither.core.declarations import Contract from slither.exceptions import SlitherException @@ -11,44 +11,63 @@ logging.basicConfig() logger = logging.getLogger("SlitherSolcParsing") logger.setLevel(logging.INFO) -from slither.solc_parsing.declarations.contract import ContractSolc04 -from slither.core.slither_core import Slither +from slither.solc_parsing.declarations.contract import ContractSolc +from slither.solc_parsing.declarations.function import FunctionSolc +from slither.core.slither_core import SlitherCore from slither.core.declarations.pragma_directive import Pragma from slither.core.declarations.import_directive import Import from slither.analyses.data_dependency.data_dependency import compute_dependency -class SlitherSolc(Slither): - - def __init__(self, filename): +class SlitherSolc: + def __init__(self, filename: str, core: SlitherCore): super(SlitherSolc, self).__init__() - self._filename = filename - self._contractsNotParsed = [] - self._contracts_by_id = {} + core.filename = filename + self._contracts_by_id: Dict[int, ContractSolc] = {} self._analyzed = False + self._underlying_contract_to_parser: Dict[Contract, ContractSolc] = dict() + self._is_compact_ast = False + self._core: SlitherCore = core + + self._all_functions_parser: List[FunctionSolc] = [] self._top_level_contracts_counter = 0 + @property + def core(self): + return self._core + + @property + def all_functions_parser(self) -> List[FunctionSolc]: + return self._all_functions_parser + + def add_functions_parser(self, f: FunctionSolc): + self._all_functions_parser.append(f) + + @property + def underlying_contract_to_parser(self) -> Dict[Contract, ContractSolc]: + return self._underlying_contract_to_parser + ################################################################################### ################################################################################### # region AST ################################################################################### ################################################################################### - def get_key(self): + def get_key(self) -> str: if self._is_compact_ast: - return 'nodeType' - return 'name' + return "nodeType" + return "name" - def get_children(self): + def get_children(self) -> str: if self._is_compact_ast: - return 'nodes' - return 'children' + return "nodes" + return "children" @property - def is_compact_ast(self): + def is_compact_ast(self) -> bool: return self._is_compact_ast # endregion @@ -58,139 +77,147 @@ class SlitherSolc(Slither): ################################################################################### ################################################################################### - def _parse_contracts_from_json(self, json_data): + def parse_contracts_from_json(self, json_data: str) -> bool: try: data_loaded = json.loads(json_data) # Truffle AST - if 'ast' in data_loaded: - self._parse_contracts_from_loaded_json(data_loaded['ast'], data_loaded['sourcePath']) + if "ast" in data_loaded: + self.parse_contracts_from_loaded_json(data_loaded["ast"], data_loaded["sourcePath"]) return True # solc AST, where the non-json text was removed else: - if 'attributes' in data_loaded: - filename = data_loaded['attributes']['absolutePath'] + if "attributes" in data_loaded: + filename = data_loaded["attributes"]["absolutePath"] else: - filename = data_loaded['absolutePath'] - self._parse_contracts_from_loaded_json(data_loaded, filename) + filename = data_loaded["absolutePath"] + self.parse_contracts_from_loaded_json(data_loaded, filename) return True except ValueError: - first = json_data.find('{') + first = json_data.find("{") if first != -1: - last = json_data.rfind('}') + 1 + last = json_data.rfind("}") + 1 filename = json_data[0:first] json_data = json_data[first:last] data_loaded = json.loads(json_data) - self._parse_contracts_from_loaded_json(data_loaded, filename) + self.parse_contracts_from_loaded_json(data_loaded, filename) return True return False - def _parse_contracts_from_loaded_json(self, data_loaded, filename): - if 'nodeType' in data_loaded: + def parse_contracts_from_loaded_json(self, data_loaded: Dict, filename: str): + if "nodeType" in data_loaded: self._is_compact_ast = True - if 'sourcePaths' in data_loaded: - for sourcePath in data_loaded['sourcePaths']: + if "sourcePaths" in data_loaded: + for sourcePath in data_loaded["sourcePaths"]: if os.path.isfile(sourcePath): - self._add_source_code(sourcePath) + self._core.add_source_code(sourcePath) - if data_loaded[self.get_key()] == 'root': - self._solc_version = '0.3' - logger.error('solc <0.4 is not supported') + if data_loaded[self.get_key()] == "root": + self._solc_version = "0.3" + logger.error("solc <0.4 is not supported") return - elif data_loaded[self.get_key()] == 'SourceUnit': - self._solc_version = '0.4' + elif data_loaded[self.get_key()] == "SourceUnit": + self._solc_version = "0.4" self._parse_source_unit(data_loaded, filename) else: - logger.error('solc version is not supported') + logger.error("solc version is not supported") return for contract_data in data_loaded[self.get_children()]: - assert contract_data[self.get_key()] in ['ContractDefinition', - 'PragmaDirective', - 'ImportDirective', - 'StructDefinition', - 'EnumDefinition'] - if contract_data[self.get_key()] == 'ContractDefinition': - contract = ContractSolc04(self, contract_data) - if 'src' in contract_data: - contract.set_offset(contract_data['src'], self) - self._contractsNotParsed.append(contract) - elif contract_data[self.get_key()] == 'PragmaDirective': + assert contract_data[self.get_key()] in [ + "ContractDefinition", + "PragmaDirective", + "ImportDirective", + "StructDefinition", + "EnumDefinition", + ] + if contract_data[self.get_key()] == "ContractDefinition": + contract = Contract() + contract_parser = ContractSolc(self, contract, contract_data) + if "src" in contract_data: + contract.set_offset(contract_data["src"], self._core) + + self._underlying_contract_to_parser[contract] = contract_parser + + elif contract_data[self.get_key()] == "PragmaDirective": if self._is_compact_ast: - pragma = Pragma(contract_data['literals']) + pragma = Pragma(contract_data["literals"]) else: - pragma = Pragma(contract_data['attributes']["literals"]) - pragma.set_offset(contract_data['src'], self) - self._pragma_directives.append(pragma) - elif contract_data[self.get_key()] == 'ImportDirective': + pragma = Pragma(contract_data["attributes"]["literals"]) + pragma.set_offset(contract_data["src"], self._core) + self._core.pragma_directives.append(pragma) + elif contract_data[self.get_key()] == "ImportDirective": if self.is_compact_ast: import_directive = Import(contract_data["absolutePath"]) else: - import_directive = Import(contract_data['attributes']["absolutePath"]) - import_directive.set_offset(contract_data['src'], self) - self._import_directives.append(import_directive) + import_directive = Import(contract_data["attributes"]["absolutePath"]) + import_directive.set_offset(contract_data["src"], self._core) + self._core.import_directives.append(import_directive) - elif contract_data[self.get_key()] in ['StructDefinition', 'EnumDefinition']: + elif contract_data[self.get_key()] in ["StructDefinition", "EnumDefinition"]: # This can only happen for top-level structure and enum # They were introduced with 0.6.5 - assert self._is_compact_ast # Do not support top level definition for legacy AST + assert self._is_compact_ast # Do not support top level definition for legacy AST fake_contract_data = { - 'name': f'SlitherInternalTopLevelContract{self._top_level_contracts_counter}', - 'id': -1000, # TODO: determine if collission possible - 'linearizedBaseContracts': [], - 'fullyImplemented': True, - 'contractKind': 'SLitherInternal' + "name": f"SlitherInternalTopLevelContract{self._top_level_contracts_counter}", + "id": -1000, # TODO: determine if collission possible + "linearizedBaseContracts": [], + "fullyImplemented": True, + "contractKind": "SLitherInternal", } self._top_level_contracts_counter += 1 - top_level_contract = ContractSolc04(self, fake_contract_data) + top_level_contract = ContractSolc(self, fake_contract_data) top_level_contract.is_top_level = True - top_level_contract.set_offset(contract_data['src'], self) + top_level_contract.set_offset(contract_data["src"], self) - if contract_data[self.get_key()] == 'StructDefinition': - top_level_contract._structuresNotParsed.append(contract_data) # Todo add proper setters + if contract_data[self.get_key()] == "StructDefinition": + top_level_contract._structuresNotParsed.append( + contract_data + ) # Todo add proper setters else: - top_level_contract._enumsNotParsed.append(contract_data) # Todo add proper setters + top_level_contract._enumsNotParsed.append( + contract_data + ) # Todo add proper setters self._contractsNotParsed.append(top_level_contract) - - def _parse_source_unit(self, data, filename): - if data[self.get_key()] != 'SourceUnit': + def _parse_source_unit(self, data: Dict, filename: str): + if data[self.get_key()] != "SourceUnit": return -1 # handle solc prior 0.3.6 # match any char for filename # filename can contain space, /, -, .. - name = re.findall('=+ (.+) =+', filename) - if name: - assert len(name) == 1 - name = name[0] + name_candidates = re.findall("=+ (.+) =+", filename) + if name_candidates: + assert len(name_candidates) == 1 + name: str = name_candidates[0] else: name = filename sourceUnit = -1 # handle old solc, or error - if 'src' in data: - sourceUnit = re.findall('[0-9]*:[0-9]*:([0-9]*)', data['src']) - if len(sourceUnit) == 1: - sourceUnit = int(sourceUnit[0]) + if "src" in data: + sourceUnit_candidates = re.findall("[0-9]*:[0-9]*:([0-9]*)", data["src"]) + if len(sourceUnit_candidates) == 1: + sourceUnit = int(sourceUnit_candidates[0]) if sourceUnit == -1: # if source unit is not found # We can still deduce it, by assigning to the last source_code added # This works only for crytic compile. # which used --combined-json ast, rather than --ast-json # As a result -1 is not used as index - if self.crytic_compile is not None: - sourceUnit = len(self.source_code) + if self._core.crytic_compile is not None: + sourceUnit = len(self._core.source_code) - self._source_units[sourceUnit] = name - if os.path.isfile(name) and not name in self.source_code: - self._add_source_code(name) + self._core.source_units[sourceUnit] = name + if os.path.isfile(name) and not name in self._core.source_code: + self._core.add_source_code(name) else: - lib_name = os.path.join('node_modules', name) - if os.path.isfile(lib_name) and not name in self.source_code: - self._add_source_code(lib_name) + lib_name = os.path.join("node_modules", name) + if os.path.isfile(lib_name) and not name in self._core.source_code: + self._core.add_source_code(lib_name) # endregion ################################################################################### @@ -200,32 +227,42 @@ class SlitherSolc(Slither): ################################################################################### @property - def analyzed(self): + def analyzed(self) -> bool: return self._analyzed - def _analyze_contracts(self): - if not self._contractsNotParsed: - logger.info(f'No contract were found in {self.filename}, check the correct compilation') + def analyze_contracts(self): + if not self._underlying_contract_to_parser: + logger.info( + f"No contract were found in {self._core.filename}, check the correct compilation" + ) if self._analyzed: - raise Exception('Contract analysis can be run only once!') + raise Exception("Contract analysis can be run only once!") # First we save all the contracts in a dict # the key is the contractid - for contract in self._contractsNotParsed: - if contract.name.startswith('SlitherInternalTopLevelContract') and not contract.is_top_level: - raise SlitherException("""Your codebase has a contract named 'SlitherInternalTopLevelContract'. -Please rename it, this name is reserved for Slither's internals""") - if contract.name in self._contracts: - if contract.id != self._contracts[contract.name].id: - self._contract_name_collisions[contract.name].append(contract.source_mapping_str) - self._contract_name_collisions[contract.name].append( - self._contracts[contract.name].source_mapping_str) + for contract in self._underlying_contract_to_parser.keys(): + if ( + contract.name.startswith("SlitherInternalTopLevelContract") + and not contract.is_top_level + ): + raise SlitherException( + """Your codebase has a contract named 'SlitherInternalTopLevelContract'. +Please rename it, this name is reserved for Slither's internals""" + ) + if contract.name in self._core.contracts_as_dict: + if contract.id != self._core.contracts_as_dict[contract.name].id: + self._core.contract_name_collisions[contract.name].append( + contract.source_mapping_str + ) + self._core.contract_name_collisions[contract.name].append( + self._core.contracts_as_dict[contract.name].source_mapping_str + ) else: self._contracts_by_id[contract.id] = contract - self._contracts[contract.name] = contract + self._core.contracts_as_dict[contract.name] = contract # Update of the inheritance - for contract in self._contractsNotParsed: + for contract_parser in self._underlying_contract_to_parser.values(): # remove the first elem in linearizedBaseContracts as it is the contract itself ancestors = [] fathers = [] @@ -234,58 +271,70 @@ Please rename it, this name is reserved for Slither's internals""") # Resolve linearized base contracts. missing_inheritance = False - for i in contract.linearizedBaseContracts[1:]: - if i in contract.remapping: - ancestors.append(self.get_contract_from_name(contract.remapping[i])) + for i in contract_parser.linearized_base_contracts[1:]: + if i in contract_parser.remapping: + ancestors.append( + self._core.get_contract_from_name(contract_parser.remapping[i]) + ) elif i in self._contracts_by_id: ancestors.append(self._contracts_by_id[i]) else: missing_inheritance = True # Resolve immediate base contracts - for i in contract.baseContracts: - if i in contract.remapping: - fathers.append(self.get_contract_from_name(contract.remapping[i])) + for i in contract_parser.baseContracts: + if i in contract_parser.remapping: + fathers.append(self._core.get_contract_from_name(contract_parser.remapping[i])) elif i in self._contracts_by_id: fathers.append(self._contracts_by_id[i]) else: missing_inheritance = True # Resolve immediate base constructor calls - for i in contract.baseConstructorContractsCalled: - if i in contract.remapping: - father_constructors.append(self.get_contract_from_name(contract.remapping[i])) + for i in contract_parser.baseConstructorContractsCalled: + if i in contract_parser.remapping: + father_constructors.append( + self._core.get_contract_from_name(contract_parser.remapping[i]) + ) elif i in self._contracts_by_id: father_constructors.append(self._contracts_by_id[i]) else: missing_inheritance = True - contract.setInheritance(ancestors, fathers, father_constructors) + contract_parser.underlying_contract.set_inheritance( + ancestors, fathers, father_constructors + ) if missing_inheritance: - self._contract_with_missing_inheritance.add(contract) - contract.log_incorrect_parsing(f'Missing inheritance {contract}') - contract.set_is_analyzed(True) - contract.delete_content() + self._core.contracts_with_missing_inheritance.add( + contract_parser.underlying_contract + ) + contract_parser.log_incorrect_parsing(f"Missing inheritance {contract_parser}") + contract_parser.set_is_analyzed(True) + contract_parser.delete_content() - contracts_to_be_analyzed = self.contracts + contracts_to_be_analyzed = list(self._underlying_contract_to_parser.values()) # Any contract can refer another contract enum without need for inheritance self._analyze_all_enums(contracts_to_be_analyzed) - [c.set_is_analyzed(False) for c in self.contracts] + [c.set_is_analyzed(False) for c in self._underlying_contract_to_parser.values()] - libraries = [c for c in contracts_to_be_analyzed if c.contract_kind == 'library'] - contracts_to_be_analyzed = [c for c in contracts_to_be_analyzed if c.contract_kind != 'library'] + libraries = [ + c for c in contracts_to_be_analyzed if c.underlying_contract.contract_kind == "library" + ] + contracts_to_be_analyzed = [ + c for c in contracts_to_be_analyzed if c.underlying_contract.contract_kind != "library" + ] # We first parse the struct/variables/functions/contract self._analyze_first_part(contracts_to_be_analyzed, libraries) - [c.set_is_analyzed(False) for c in self.contracts] + [c.set_is_analyzed(False) for c in self._underlying_contract_to_parser.values()] # We analyze the struct and parse and analyze the events # A contract can refer in the variables a struct or a event from any contract # (without inheritance link) self._analyze_second_part(contracts_to_be_analyzed, libraries) - [c.set_is_analyzed(False) for c in self.contracts] + [c.set_is_analyzed(False) for c in self._underlying_contract_to_parser.values()] # Then we analyse state variables, functions and modifiers self._analyze_third_part(contracts_to_be_analyzed, libraries) @@ -294,22 +343,27 @@ Please rename it, this name is reserved for Slither's internals""") self._convert_to_slithir() - compute_dependency(self) + compute_dependency(self._core) - def _analyze_all_enums(self, contracts_to_be_analyzed): + def _analyze_all_enums(self, contracts_to_be_analyzed: List[ContractSolc]): while contracts_to_be_analyzed: contract = contracts_to_be_analyzed[0] contracts_to_be_analyzed = contracts_to_be_analyzed[1:] - all_father_analyzed = all(father.is_analyzed for father in contract.inheritance) + all_father_analyzed = all( + self._underlying_contract_to_parser[father].is_analyzed + for father in contract.underlying_contract.inheritance + ) - if not contract.inheritance or all_father_analyzed: + if not contract.underlying_contract.inheritance or all_father_analyzed: self._analyze_enums(contract) else: contracts_to_be_analyzed += [contract] return - def _analyze_first_part(self, contracts_to_be_analyzed, libraries): + def _analyze_first_part( + self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc] + ): for lib in libraries: self._parse_struct_var_modifiers_functions(lib) @@ -321,16 +375,21 @@ Please rename it, this name is reserved for Slither's internals""") contract = contracts_to_be_analyzed[0] contracts_to_be_analyzed = contracts_to_be_analyzed[1:] - all_father_analyzed = all(father.is_analyzed for father in contract.inheritance) + all_father_analyzed = all( + self._underlying_contract_to_parser[father].is_analyzed + for father in contract.underlying_contract.inheritance + ) - if not contract.inheritance or all_father_analyzed: + if not contract.underlying_contract.inheritance or all_father_analyzed: self._parse_struct_var_modifiers_functions(contract) else: contracts_to_be_analyzed += [contract] return - def _analyze_second_part(self, contracts_to_be_analyzed, libraries): + def _analyze_second_part( + self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc] + ): for lib in libraries: self._analyze_struct_events(lib) @@ -342,16 +401,21 @@ Please rename it, this name is reserved for Slither's internals""") contract = contracts_to_be_analyzed[0] contracts_to_be_analyzed = contracts_to_be_analyzed[1:] - all_father_analyzed = all(father.is_analyzed for father in contract.inheritance) + all_father_analyzed = all( + self._underlying_contract_to_parser[father].is_analyzed + for father in contract.underlying_contract.inheritance + ) - if not contract.inheritance or all_father_analyzed: + if not contract.underlying_contract.inheritance or all_father_analyzed: self._analyze_struct_events(contract) else: contracts_to_be_analyzed += [contract] return - def _analyze_third_part(self, contracts_to_be_analyzed, libraries): + def _analyze_third_part( + self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc] + ): for lib in libraries: self._analyze_variables_modifiers_functions(lib) @@ -363,28 +427,31 @@ Please rename it, this name is reserved for Slither's internals""") contract = contracts_to_be_analyzed[0] contracts_to_be_analyzed = contracts_to_be_analyzed[1:] - all_father_analyzed = all(father.is_analyzed for father in contract.inheritance) + all_father_analyzed = all( + self._underlying_contract_to_parser[father].is_analyzed + for father in contract.underlying_contract.inheritance + ) - if not contract.inheritance or all_father_analyzed: + if not contract.underlying_contract.inheritance or all_father_analyzed: self._analyze_variables_modifiers_functions(contract) else: contracts_to_be_analyzed += [contract] return - def _analyze_enums(self, contract): + def _analyze_enums(self, contract: ContractSolc): # Enum must be analyzed first contract.analyze_enums() contract.set_is_analyzed(True) - def _parse_struct_var_modifiers_functions(self, contract): + def _parse_struct_var_modifiers_functions(self, contract: ContractSolc): contract.parse_structs() # struct can refer another struct contract.parse_state_variables() contract.parse_modifiers() contract.parse_functions() contract.set_is_analyzed(True) - def _analyze_struct_events(self, contract): + def _analyze_struct_events(self, contract: ContractSolc): contract.analyze_constant_state_variables() @@ -397,7 +464,7 @@ Please rename it, this name is reserved for Slither's internals""") contract.set_is_analyzed(True) - def _analyze_variables_modifiers_functions(self, contract): + def _analyze_variables_modifiers_functions(self, contract: ContractSolc): # State variables, modifiers and functions can refer to anything contract.analyze_params_modifiers() @@ -412,11 +479,23 @@ Please rename it, this name is reserved for Slither's internals""") def _convert_to_slithir(self): - for contract in self.contracts: + for contract in self._core.contracts: contract.add_constructor_variables() - contract.convert_expression_to_slithir() - self._propagate_function_calls() - for contract in self.contracts: + + for func in contract.functions + contract.modifiers: + try: + func.generate_slithir_and_analyze() + except AttributeError: + # This can happens for example if there is a call to an interface + # And the interface is redefined due to contract's name reuse + # But the available version misses some functions + self._underlying_contract_to_parser[contract].log_incorrect_parsing( + f"Impossible to generate IR for {contract.name}.{func.name}" + ) + + contract.convert_expression_to_slithir_ssa() + self._core.propagate_function_calls() + for contract in self._core.contracts: contract.fix_phi() contract.update_read_write_using_ssa() diff --git a/slither/solc_parsing/solidity_types/type_parsing.py b/slither/solc_parsing/solidity_types/type_parsing.py index 305c38812..9d11936ad 100644 --- a/slither/solc_parsing/solidity_types/type_parsing.py +++ b/slither/solc_parsing/solidity_types/type_parsing.py @@ -1,6 +1,8 @@ import logging +from typing import List, TYPE_CHECKING, Union, Dict from slither.core.solidity_types.elementary_type import ElementaryType, ElementaryTypeName +from slither.core.solidity_types.type import Type from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.array_type import ArrayType from slither.core.solidity_types.mapping_type import MappingType @@ -9,14 +11,17 @@ from slither.core.solidity_types.function_type import FunctionType from slither.core.variables.function_type_variable import FunctionTypeVariable from slither.core.declarations.contract import Contract -from slither.core.declarations.function import Function from slither.core.expressions.literal import Literal from slither.solc_parsing.exceptions import ParsingError import re -logger = logging.getLogger('TypeParsing') +if TYPE_CHECKING: + from slither.core.declarations import Structure, Enum + +logger = logging.getLogger("TypeParsing") + class UnknownType: def __init__(self, name): @@ -26,24 +31,31 @@ class UnknownType: def name(self): return self._name -def _find_from_type_name(name, contract, contracts, structures, enums): - name_elementary = name.split(' ')[0] - if '[' in name_elementary: - name_elementary = name_elementary[0:name_elementary.find('[')] + +def _find_from_type_name( + name: str, + contract: Contract, + contracts: List[Contract], + structures: List["Structure"], + enums: List["Enum"], +) -> Type: + name_elementary = name.split(" ")[0] + if "[" in name_elementary: + name_elementary = name_elementary[0 : name_elementary.find("[")] if name_elementary in ElementaryTypeName: - depth = name.count('[') + depth = name.count("[") if depth: - return ArrayType(ElementaryType(name_elementary), Literal(depth, 'uint256')) + return ArrayType(ElementaryType(name_elementary), Literal(depth, "uint256")) else: return ElementaryType(name_elementary) # We first look for contract - # To avoid collision + # To avoid collision # Ex: a structure with the name of a contract name_contract = name - if name_contract.startswith('contract '): - name_contract = name_contract[len('contract '):] - if name_contract.startswith('library '): - name_contract = name_contract[len('library '):] + if name_contract.startswith("contract "): + name_contract = name_contract[len("contract ") :] + if name_contract.startswith("library "): + name_contract = name_contract[len("library ") :] var_type = next((c for c in contracts if c.name == name_contract), None) if not var_type: @@ -53,8 +65,8 @@ def _find_from_type_name(name, contract, contracts, structures, enums): if not var_type: # any contract can refer to another contract's enum enum_name = name - if enum_name.startswith('enum '): - enum_name = enum_name[len('enum '):] + if enum_name.startswith("enum "): + enum_name = enum_name[len("enum ") :] all_enums = [c.enums for c in contracts] all_enums = [item for sublist in all_enums for item in sublist] var_type = next((e for e in all_enums if e.name == enum_name), None) @@ -63,9 +75,9 @@ def _find_from_type_name(name, contract, contracts, structures, enums): if not var_type: # any contract can refer to another contract's structure name_struct = name - if name_struct.startswith('struct '): - name_struct = name_struct[len('struct '):] - name_struct = name_struct.split(' ')[0] # remove stuff like storage pointer at the end + if name_struct.startswith("struct "): + name_struct = name_struct[len("struct ") :] + name_struct = name_struct.split(" ")[0] # remove stuff like storage pointer at the end all_structures = [c.structures for c in contracts] all_structures = [item for sublist in all_structures for item in sublist] var_type = next((st for st in all_structures if st.name == name_struct), None) @@ -74,75 +86,86 @@ def _find_from_type_name(name, contract, contracts, structures, enums): # case where struct xxx.xx[] where not well formed in the AST if not var_type: depth = 0 - while name_struct.endswith('[]'): + while name_struct.endswith("[]"): name_struct = name_struct[0:-2] - depth+=1 + depth += 1 var_type = next((st for st in all_structures if st.canonical_name == name_struct), None) if var_type: - return ArrayType(UserDefinedType(var_type), Literal(depth, 'uint256')) + return ArrayType(UserDefinedType(var_type), Literal(depth, "uint256")) if not var_type: var_type = next((f for f in contract.functions if f.name == name), None) if not var_type: - if name.startswith('function '): - found = re.findall('function \(([ ()a-zA-Z0-9\.,]*)\) returns \(([a-zA-Z0-9\.,]*)\)', name) + if name.startswith("function "): + found = re.findall( + "function \(([ ()a-zA-Z0-9\.,]*)\) returns \(([a-zA-Z0-9\.,]*)\)", name + ) assert len(found) == 1 - params = found[0][0].split(',') - return_values = found[0][1].split(',') - params = [_find_from_type_name(p, contract, contracts, structures, enums) for p in params] - return_values = [_find_from_type_name(r, contract, contracts, structures, enums) for r in return_values] + params = found[0][0].split(",") + return_values = found[0][1].split(",") + params = [ + _find_from_type_name(p, contract, contracts, structures, enums) for p in params + ] + return_values = [ + _find_from_type_name(r, contract, contracts, structures, enums) + for r in return_values + ] params_vars = [] return_vars = [] for p in params: - var = FunctionTypeVariable() - var.set_type(p) - params_vars.append(var) + var = FunctionTypeVariable() + var.set_type(p) + params_vars.append(var) for r in return_values: - var = FunctionTypeVariable() - var.set_type(r) - return_vars.append(var) + var = FunctionTypeVariable() + var.set_type(r) + return_vars.append(var) return FunctionType(params_vars, return_vars) if not var_type: - if name.startswith('mapping('): + if name.startswith("mapping("): # nested mapping declared with var - if name.count('mapping(') == 1 : - found = re.findall('mapping\(([a-zA-Z0-9\.]*) => ([a-zA-Z0-9\.\[\]]*)\)', name) + if name.count("mapping(") == 1: + found = re.findall("mapping\(([a-zA-Z0-9\.]*) => ([a-zA-Z0-9\.\[\]]*)\)", name) else: - found = re.findall('mapping\(([a-zA-Z0-9\.]*) => (mapping\([=> a-zA-Z0-9\.\[\]]*\))\)', name) + found = re.findall( + "mapping\(([a-zA-Z0-9\.]*) => (mapping\([=> a-zA-Z0-9\.\[\]]*\))\)", name + ) assert len(found) == 1 from_ = found[0][0] to_ = found[0][1] - + from_type = _find_from_type_name(from_, contract, contracts, structures, enums) to_type = _find_from_type_name(to_, contract, contracts, structures, enums) return MappingType(from_type, to_type) if not var_type: - raise ParsingError('Type not found '+str(name)) + raise ParsingError("Type not found " + str(name)) return UserDefinedType(var_type) - -def parse_type(t, caller_context): - # local import to avoid circular dependency +def parse_type(t: Union[Dict, UnknownType], caller_context): + # local import to avoid circular dependency from slither.solc_parsing.expressions.expression_parsing import parse_expression from slither.solc_parsing.variables.function_type_variable import FunctionTypeVariableSolc - - if isinstance(caller_context, Contract): - contract = caller_context - elif isinstance(caller_context, Function): - contract = caller_context.contract + from slither.solc_parsing.declarations.contract import ContractSolc + from slither.solc_parsing.declarations.function import FunctionSolc + + if isinstance(caller_context, ContractSolc): + contract = caller_context.underlying_contract + contract_parser = caller_context + is_compact_ast = caller_context.is_compact_ast + elif isinstance(caller_context, FunctionSolc): + contract = caller_context.underlying_function.contract + contract_parser = caller_context.contract_parser + is_compact_ast = caller_context.is_compact_ast else: - raise ParsingError('Incorrect caller context') - - - is_compact_ast = caller_context.is_compact_ast + raise ParsingError(f"Incorrect caller context: {type(caller_context)}") if is_compact_ast: - key = 'nodeType' + key = "nodeType" else: - key = 'name' + key = "name" structures = contract.structures enums = contract.enums @@ -151,75 +174,84 @@ def parse_type(t, caller_context): if isinstance(t, UnknownType): return _find_from_type_name(t.name, contract, contracts, structures, enums) - elif t[key] == 'ElementaryTypeName': + elif t[key] == "ElementaryTypeName": if is_compact_ast: - return ElementaryType(t['name']) - return ElementaryType(t['attributes'][key]) + return ElementaryType(t["name"]) + return ElementaryType(t["attributes"][key]) - elif t[key] == 'UserDefinedTypeName': + elif t[key] == "UserDefinedTypeName": if is_compact_ast: - return _find_from_type_name(t['typeDescriptions']['typeString'], contract, contracts, structures, enums) + return _find_from_type_name( + t["typeDescriptions"]["typeString"], contract, contracts, structures, enums + ) # Determine if we have a type node (otherwise we use the name node, as some older solc did not have 'type'). - type_name_key = 'type' if 'type' in t['attributes'] else key - return _find_from_type_name(t['attributes'][type_name_key], contract, contracts, structures, enums) + type_name_key = "type" if "type" in t["attributes"] else key + return _find_from_type_name( + t["attributes"][type_name_key], contract, contracts, structures, enums + ) - elif t[key] == 'ArrayTypeName': + elif t[key] == "ArrayTypeName": length = None if is_compact_ast: - if t['length']: - length = parse_expression(t['length'], caller_context) - array_type = parse_type(t['baseType'], contract) + if t["length"]: + length = parse_expression(t["length"], caller_context) + array_type = parse_type(t["baseType"], contract_parser) else: - if len(t['children']) == 2: - length = parse_expression(t['children'][1], caller_context) + if len(t["children"]) == 2: + length = parse_expression(t["children"][1], caller_context) else: - assert len(t['children']) == 1 - array_type = parse_type(t['children'][0], contract) + assert len(t["children"]) == 1 + array_type = parse_type(t["children"][0], contract_parser) return ArrayType(array_type, length) - elif t[key] == 'Mapping': + elif t[key] == "Mapping": if is_compact_ast: - mappingFrom = parse_type(t['keyType'], contract) - mappingTo = parse_type(t['valueType'], contract) + mappingFrom = parse_type(t["keyType"], contract_parser) + mappingTo = parse_type(t["valueType"], contract_parser) else: - assert len(t['children']) == 2 + assert len(t["children"]) == 2 - mappingFrom = parse_type(t['children'][0], contract) - mappingTo = parse_type(t['children'][1], contract) + mappingFrom = parse_type(t["children"][0], contract_parser) + mappingTo = parse_type(t["children"][1], contract_parser) return MappingType(mappingFrom, mappingTo) - elif t[key] == 'FunctionTypeName': + elif t[key] == "FunctionTypeName": if is_compact_ast: - params = t['parameterTypes'] - return_values = t['returnParameterTypes'] - index = 'parameters' + params = t["parameterTypes"] + return_values = t["returnParameterTypes"] + index = "parameters" else: - assert len(t['children']) == 2 - params = t['children'][0] - return_values = t['children'][1] - index = 'children' + assert len(t["children"]) == 2 + params = t["children"][0] + return_values = t["children"][1] + index = "children" - assert params[key] == 'ParameterList' - assert return_values[key] == 'ParameterList' + assert params[key] == "ParameterList" + assert return_values[key] == "ParameterList" - params_vars = [] - return_values_vars = [] + params_vars: List[FunctionTypeVariable] = [] + return_values_vars: List[FunctionTypeVariable] = [] for p in params[index]: - var = FunctionTypeVariableSolc(p) - var.set_offset(p['src'], caller_context.slither) - var.analyze(caller_context) + var = FunctionTypeVariable() + var.set_offset(p["src"], caller_context.slither) + + var_parser = FunctionTypeVariableSolc(var, p) + var_parser.analyze(caller_context) + params_vars.append(var) for p in return_values[index]: - var = FunctionTypeVariableSolc(p) + var = FunctionTypeVariable() + var.set_offset(p["src"], caller_context.slither) + + var_parser = FunctionTypeVariableSolc(var, p) + var_parser.analyze(caller_context) - var.set_offset(p['src'], caller_context.slither) - var.analyze(caller_context) return_values_vars.append(var) return FunctionType(params_vars, return_values_vars) - raise ParsingError('Type name not found '+str(t)) + raise ParsingError("Type name not found " + str(t)) diff --git a/slither/solc_parsing/variables/event_variable.py b/slither/solc_parsing/variables/event_variable.py index 8cf3aed9d..6d743ba11 100644 --- a/slither/solc_parsing/variables/event_variable.py +++ b/slither/solc_parsing/variables/event_variable.py @@ -1,10 +1,20 @@ +from typing import Dict from .variable_declaration import VariableDeclarationSolc from slither.core.variables.event_variable import EventVariable -class EventVariableSolc(VariableDeclarationSolc, EventVariable): - def _analyze_variable_attributes(self, attributes): +class EventVariableSolc(VariableDeclarationSolc): + def __init__(self, variable: EventVariable, variable_data: Dict): + super(EventVariableSolc, self).__init__(variable, variable_data) + + @property + def underlying_variable(self) -> EventVariable: + # Todo: Not sure how to overcome this with mypy + assert isinstance(self._variable, EventVariable) + return self._variable + + def _analyze_variable_attributes(self, attributes: Dict): """ Analyze event variable attributes :param attributes: The event variable attributes to parse. @@ -12,8 +22,7 @@ class EventVariableSolc(VariableDeclarationSolc, EventVariable): """ # Check for the indexed attribute - if 'indexed' in attributes: - self._indexed = attributes['indexed'] + if "indexed" in attributes: + self.underlying_variable.indexed = attributes["indexed"] super(EventVariableSolc, self)._analyze_variable_attributes(attributes) - diff --git a/slither/solc_parsing/variables/function_type_variable.py b/slither/solc_parsing/variables/function_type_variable.py index 50af5145c..d54a53bb1 100644 --- a/slither/solc_parsing/variables/function_type_variable.py +++ b/slither/solc_parsing/variables/function_type_variable.py @@ -1,5 +1,15 @@ +from typing import Dict from slither.solc_parsing.variables.variable_declaration import VariableDeclarationSolc from slither.core.variables.function_type_variable import FunctionTypeVariable -class FunctionTypeVariableSolc(VariableDeclarationSolc, FunctionTypeVariable): pass + +class FunctionTypeVariableSolc(VariableDeclarationSolc): + def __init__(self, variable: FunctionTypeVariable, variable_data: Dict): + super(FunctionTypeVariableSolc, self).__init__(variable, variable_data) + + @property + def underlying_variable(self) -> FunctionTypeVariable: + # Todo: Not sure how to overcome this with mypy + assert isinstance(self._variable, FunctionTypeVariable) + return self._variable diff --git a/slither/solc_parsing/variables/local_variable.py b/slither/solc_parsing/variables/local_variable.py index 1df53261b..c2594c938 100644 --- a/slither/solc_parsing/variables/local_variable.py +++ b/slither/solc_parsing/variables/local_variable.py @@ -1,24 +1,33 @@ +from typing import Dict from .variable_declaration import VariableDeclarationSolc from slither.core.variables.local_variable import LocalVariable -class LocalVariableSolc(VariableDeclarationSolc, LocalVariable): - def _analyze_variable_attributes(self, attributes): - '''' +class LocalVariableSolc(VariableDeclarationSolc): + def __init__(self, variable: LocalVariable, variable_data: Dict): + super(LocalVariableSolc, self).__init__(variable, variable_data) + + @property + def underlying_variable(self) -> LocalVariable: + # Todo: Not sure how to overcome this with mypy + assert isinstance(self._variable, LocalVariable) + return self._variable + + def _analyze_variable_attributes(self, attributes: Dict): + """' Variable Location Can be storage/memory or default - ''' - if 'storageLocation' in attributes: - location = attributes['storageLocation'] - self._location = location + """ + if "storageLocation" in attributes: + location = attributes["storageLocation"] + self.underlying_variable.set_location(location) else: - if 'memory' in attributes['type']: - self._location = 'memory' - elif'storage' in attributes['type']: - self._location = 'storage' + if "memory" in attributes["type"]: + self.underlying_variable.set_location("memory") + elif "storage" in attributes["type"]: + self.underlying_variable.set_location("storage") else: - self._location = 'default' + self.underlying_variable.set_location("default") super(LocalVariableSolc, self)._analyze_variable_attributes(attributes) - diff --git a/slither/solc_parsing/variables/local_variable_init_from_tuple.py b/slither/solc_parsing/variables/local_variable_init_from_tuple.py index e962e4c3c..3384482c1 100644 --- a/slither/solc_parsing/variables/local_variable_init_from_tuple.py +++ b/slither/solc_parsing/variables/local_variable_init_from_tuple.py @@ -1,11 +1,16 @@ +from typing import Dict from .variable_declaration import VariableDeclarationSolc from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple -class LocalVariableInitFromTupleSolc(VariableDeclarationSolc, LocalVariableInitFromTuple): - - def __init__(self, var, index): - super(LocalVariableInitFromTupleSolc, self).__init__(var) - self._tuple_index = index +class LocalVariableInitFromTupleSolc(VariableDeclarationSolc): + def __init__(self, variable: LocalVariableInitFromTuple, variable_data: Dict, index: int): + super(LocalVariableInitFromTupleSolc, self).__init__(variable, variable_data) + variable.tuple_index = index + @property + def underlying_variable(self) -> LocalVariableInitFromTuple: + # Todo: Not sure how to overcome this with mypy + assert isinstance(self._variable, LocalVariableInitFromTuple) + return self._variable diff --git a/slither/solc_parsing/variables/state_variable.py b/slither/solc_parsing/variables/state_variable.py index 242f441ca..398b8ff3c 100644 --- a/slither/solc_parsing/variables/state_variable.py +++ b/slither/solc_parsing/variables/state_variable.py @@ -1,5 +1,15 @@ +from typing import Dict from .variable_declaration import VariableDeclarationSolc from slither.core.variables.state_variable import StateVariable -class StateVariableSolc(VariableDeclarationSolc, StateVariable): pass + +class StateVariableSolc(VariableDeclarationSolc): + def __init__(self, variable: StateVariable, variable_data: Dict): + super(StateVariableSolc, self).__init__(variable, variable_data) + + @property + def underlying_variable(self) -> StateVariable: + # Todo: Not sure how to overcome this with mypy + assert isinstance(self._variable, StateVariable) + return self._variable diff --git a/slither/solc_parsing/variables/structure_variable.py b/slither/solc_parsing/variables/structure_variable.py index f0823f67d..750778678 100644 --- a/slither/solc_parsing/variables/structure_variable.py +++ b/slither/solc_parsing/variables/structure_variable.py @@ -1,5 +1,15 @@ +from typing import Dict from .variable_declaration import VariableDeclarationSolc from slither.core.variables.structure_variable import StructureVariable -class StructureVariableSolc(VariableDeclarationSolc, StructureVariable): pass + +class StructureVariableSolc(VariableDeclarationSolc): + def __init__(self, variable: StructureVariable, variable_data: Dict): + super(StructureVariableSolc, self).__init__(variable, variable_data) + + @property + def underlying_variable(self) -> StructureVariable: + # Todo: Not sure how to overcome this with mypy + assert isinstance(self._variable, StructureVariable) + return self._variable diff --git a/slither/solc_parsing/variables/variable_declaration.py b/slither/solc_parsing/variables/variable_declaration.py index 72d8fc55e..61cc8d403 100644 --- a/slither/solc_parsing/variables/variable_declaration.py +++ b/slither/solc_parsing/variables/variable_declaration.py @@ -1,4 +1,6 @@ import logging +from typing import Dict + from slither.solc_parsing.expressions.expression_parsing import parse_expression from slither.core.variables.variable import Variable @@ -7,28 +9,31 @@ from slither.solc_parsing.solidity_types.type_parsing import parse_type, Unknown from slither.core.solidity_types.elementary_type import ElementaryType, NonElementaryType from slither.solc_parsing.exceptions import ParsingError + logger = logging.getLogger("VariableDeclarationSolcParsing") + class MultipleVariablesDeclaration(Exception): - ''' + """ This is raised on var (a,b) = ... It should occur only on local variable definition - ''' + """ + pass -class VariableDeclarationSolc(Variable): - def __init__(self, var): - ''' +class VariableDeclarationSolc: + def __init__(self, variable: Variable, variable_data: Dict): + """ A variable can be declared through a statement, or directly. If it is through a statement, the following children may contain the init value. It may be possible that the variable is declared through a statement, but the init value is declared at the VariableDeclaration children level - ''' + """ - super(VariableDeclarationSolc, self).__init__() + self._variable = variable self._was_analyzed = False self._elem_to_parse = None self._initializedNotParsed = None @@ -37,125 +42,122 @@ class VariableDeclarationSolc(Variable): self._reference_id = None - - if 'nodeType' in var: + if "nodeType" in variable_data: self._is_compact_ast = True - nodeType = var['nodeType'] - if nodeType in ['VariableDeclarationStatement', 'VariableDefinitionStatement']: - if len(var['declarations'])>1: + nodeType = variable_data["nodeType"] + if nodeType in ["VariableDeclarationStatement", "VariableDefinitionStatement"]: + if len(variable_data["declarations"]) > 1: raise MultipleVariablesDeclaration init = None - if 'initialValue' in var: - init = var['initialValue'] - self._init_from_declaration(var['declarations'][0], init) - elif nodeType == 'VariableDeclaration': - self._init_from_declaration(var, var['value']) + if "initialValue" in variable_data: + init = variable_data["initialValue"] + self._init_from_declaration(variable_data["declarations"][0], init) + elif nodeType == "VariableDeclaration": + self._init_from_declaration(variable_data, variable_data["value"]) else: - raise ParsingError('Incorrect variable declaration type {}'.format(nodeType)) + raise ParsingError("Incorrect variable declaration type {}".format(nodeType)) else: - nodeType = var['name'] + nodeType = variable_data["name"] - if nodeType in ['VariableDeclarationStatement', 'VariableDefinitionStatement']: - if len(var['children']) == 2: - init = var['children'][1] - elif len(var['children']) == 1: + if nodeType in ["VariableDeclarationStatement", "VariableDefinitionStatement"]: + if len(variable_data["children"]) == 2: + init = variable_data["children"][1] + elif len(variable_data["children"]) == 1: init = None - elif len(var['children']) > 2: + elif len(variable_data["children"]) > 2: raise MultipleVariablesDeclaration else: - raise ParsingError('Variable declaration without children?'+var) - declaration = var['children'][0] + raise ParsingError( + "Variable declaration without children?" + str(variable_data) + ) + declaration = variable_data["children"][0] self._init_from_declaration(declaration, init) - elif nodeType == 'VariableDeclaration': - self._init_from_declaration(var, None) + elif nodeType == "VariableDeclaration": + self._init_from_declaration(variable_data, False) else: - raise ParsingError('Incorrect variable declaration type {}'.format(nodeType)) - - @property - def initialized(self): - return self._initialized + raise ParsingError("Incorrect variable declaration type {}".format(nodeType)) @property - def uninitialized(self): - return not self._initialized + def underlying_variable(self) -> Variable: + return self._variable @property - def reference_id(self): - ''' + def reference_id(self) -> int: + """ Return the solc id. It can be compared with the referencedDeclaration attr Returns None if it was not parsed (legacy AST) - ''' + """ return self._reference_id - def _analyze_variable_attributes(self, attributes): - if 'visibility' in attributes: - self._visibility = attributes['visibility'] + def _analyze_variable_attributes(self, attributes: Dict): + if "visibility" in attributes: + self._variable.visibility = attributes["visibility"] else: - self._visibility = 'internal' + self._variable.visibility = "internal" - def _init_from_declaration(self, var, init): + def _init_from_declaration(self, var: Dict, init: bool): if self._is_compact_ast: attributes = var - self._typeName = attributes['typeDescriptions']['typeString'] + self._typeName = attributes["typeDescriptions"]["typeString"] else: - assert len(var['children']) <= 2 - assert var['name'] == 'VariableDeclaration' + assert len(var["children"]) <= 2 + assert var["name"] == "VariableDeclaration" - attributes = var['attributes'] - self._typeName = attributes['type'] + attributes = var["attributes"] + self._typeName = attributes["type"] - self._name = attributes['name'] - self._arrayDepth = 0 - self._isMapping = False - self._mappingFrom = None - self._mappingTo = False - self._initial_expression = None - self._type = None + self._variable.name = attributes["name"] + # self._arrayDepth = 0 + # self._isMapping = False + # self._mappingFrom = None + # self._mappingTo = False + # self._initial_expression = None + # self._type = None # Only for comapct ast format # the id can be used later if referencedDeclaration # is provided - if 'id' in var: - self._reference_id = var['id'] + if "id" in var: + self._reference_id = var["id"] - if 'constant' in attributes: - self._is_constant = attributes['constant'] + if "constant" in attributes: + self._variable.is_constant = attributes["constant"] self._analyze_variable_attributes(attributes) if self._is_compact_ast: - if var['typeName']: - self._elem_to_parse = var['typeName'] + if var["typeName"]: + self._elem_to_parse = var["typeName"] else: - self._elem_to_parse = UnknownType(var['typeDescriptions']['typeString']) + self._elem_to_parse = UnknownType(var["typeDescriptions"]["typeString"]) else: - if not var['children']: + if not var["children"]: # It happens on variable declared inside loop declaration try: - self._type = ElementaryType(self._typeName) + self._variable.type = ElementaryType(self._typeName) self._elem_to_parse = None except NonElementaryType: self._elem_to_parse = UnknownType(self._typeName) else: - self._elem_to_parse = var['children'][0] + self._elem_to_parse = var["children"][0] if self._is_compact_ast: self._initializedNotParsed = init if init: - self._initialized = True + self._variable.initialized = True else: - if init: # there are two way to init a var local in the AST - assert len(var['children']) <= 1 - self._initialized = True + if init: # there are two way to init a var local in the AST + assert len(var["children"]) <= 1 + self._variable.initialized = True self._initializedNotParsed = init - elif len(var['children']) in [0, 1]: - self._initialized = False + elif len(var["children"]) in [0, 1]: + self._variable.initialized = False self._initializedNotParsed = [] else: - assert len(var['children']) == 2 - self._initialized = True - self._initializedNotParsed = var['children'][1] + assert len(var["children"]) == 2 + self._variable.initialized = True + self._initializedNotParsed = var["children"][1] def analyze(self, caller_context): # Can be re-analyzed due to inheritance @@ -164,9 +166,9 @@ class VariableDeclarationSolc(Variable): self._was_analyzed = True if self._elem_to_parse: - self._type = parse_type(self._elem_to_parse, caller_context) + self._variable.type = parse_type(self._elem_to_parse, caller_context) self._elem_to_parse = None - if self._initialized: - self._initial_expression = parse_expression(self._initializedNotParsed, caller_context) + if self._variable.initialized: + self._variable.expression = parse_expression(self._initializedNotParsed, caller_context) self._initializedNotParsed = None diff --git a/slither/tools/demo/__main__.py b/slither/tools/demo/__main__.py index 4bee3b449..03dc984b6 100644 --- a/slither/tools/demo/__main__.py +++ b/slither/tools/demo/__main__.py @@ -9,16 +9,17 @@ logging.getLogger("Slither").setLevel(logging.INFO) logger = logging.getLogger("Slither-demo") + def parse_args(): """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='Demo', - usage='slither-demo filename') + parser = argparse.ArgumentParser(description="Demo", usage="slither-demo filename") - parser.add_argument('filename', - help='The filename of the contract or truffle directory to analyze.') + parser.add_argument( + "filename", help="The filename of the contract or truffle directory to analyze." + ) # Add default arguments from crytic-compile cryticparser.init(parser) @@ -32,7 +33,8 @@ def main(): # Perform slither analysis on the given filename slither = Slither(args.filename, **vars(args)) - logger.info('Analysis done!') + logger.info("Analysis done!") + -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/slither/tools/erc_conformance/__main__.py b/slither/tools/erc_conformance/__main__.py index fe0fbb845..449ae8669 100644 --- a/slither/tools/erc_conformance/__main__.py +++ b/slither/tools/erc_conformance/__main__.py @@ -17,28 +17,29 @@ logger.setLevel(logging.INFO) ch = logging.StreamHandler() ch.setLevel(logging.INFO) -formatter = logging.Formatter('%(message)s') +formatter = logging.Formatter("%(message)s") logger.addHandler(ch) logger.handlers[0].setFormatter(formatter) logger.propagate = False -ADDITIONAL_CHECKS = { - "ERC20": check_erc20 -} +ADDITIONAL_CHECKS = {"ERC20": check_erc20} + def parse_args(): """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='Check the ERC 20 conformance', - usage='slither-erc project contractName') + parser = argparse.ArgumentParser( + description="Check the ERC 20 conformance", usage="slither-erc project contractName" + ) - parser.add_argument('project', - help='The codebase to be tested.') + parser.add_argument("project", help="The codebase to be tested.") - parser.add_argument('contract_name', - help='The name of the contract. Specify the first case contract that follow the standard. Derived contracts will be checked.') + parser.add_argument( + "contract_name", + help="The name of the contract. Specify the first case contract that follow the standard. Derived contracts will be checked.", + ) parser.add_argument( "--erc", @@ -47,22 +48,26 @@ def parse_args(): default="erc20", ) - parser.add_argument('--json', - help='Export the results as a JSON file ("--json -" to export to stdout)', - action='store', - default=False) + parser.add_argument( + "--json", + help='Export the results as a JSON file ("--json -" to export to stdout)', + action="store", + default=False, + ) # Add default arguments from crytic-compile cryticparser.init(parser) return parser.parse_args() + def _log_error(err, args): if args.json: output_to_json(args.json, str(err), {"upgradeability-check": []}) logger.error(err) + def main(): args = parse_args() @@ -76,7 +81,7 @@ def main(): contract = slither.get_contract_from_name(args.contract_name) if not contract: - err = f'Contract not found: {args.contract_name}' + err = f"Contract not found: {args.contract_name}" _log_error(err, args) return # First elem is the function, second is the event @@ -87,7 +92,7 @@ def main(): ADDITIONAL_CHECKS[args.erc.upper()](contract, ret) else: - err = f'Incorrect ERC selected {args.erc}' + err = f"Incorrect ERC selected {args.erc}" _log_error(err, args) return @@ -95,5 +100,5 @@ def main(): output_to_json(args.json, None, {"upgradeability-check": ret}) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/slither/tools/erc_conformance/erc/erc20.py b/slither/tools/erc_conformance/erc/erc20.py index 25473bc84..720b08322 100644 --- a/slither/tools/erc_conformance/erc/erc20.py +++ b/slither/tools/erc_conformance/erc/erc20.py @@ -6,21 +6,25 @@ logger = logging.getLogger("Slither-conformance") def approval_race_condition(contract, ret): - increaseAllowance = contract.get_function_from_signature('increaseAllowance(address,uint256)') + increaseAllowance = contract.get_function_from_signature("increaseAllowance(address,uint256)") if not increaseAllowance: - increaseAllowance = contract.get_function_from_signature('safeIncreaseAllowance(address,uint256)') + increaseAllowance = contract.get_function_from_signature( + "safeIncreaseAllowance(address,uint256)" + ) if increaseAllowance: - txt = f'\t[✓] {contract.name} has {increaseAllowance.full_name}' + txt = f"\t[✓] {contract.name} has {increaseAllowance.full_name}" logger.info(txt) else: - txt = f'\t[ ] {contract.name} is not protected for the ERC20 approval race condition' + txt = f"\t[ ] {contract.name} is not protected for the ERC20 approval race condition" logger.info(txt) lack_of_erc20_race_condition_protection = output.Output(txt) lack_of_erc20_race_condition_protection.add(contract) - ret["lack_of_erc20_race_condition_protection"].append(lack_of_erc20_race_condition_protection.data) + ret["lack_of_erc20_race_condition_protection"].append( + lack_of_erc20_race_condition_protection.data + ) def check_erc20(contract, ret, explored=None): diff --git a/slither/tools/erc_conformance/erc/ercs.py b/slither/tools/erc_conformance/erc/ercs.py index 5af000357..334b2a408 100644 --- a/slither/tools/erc_conformance/erc/ercs.py +++ b/slither/tools/erc_conformance/erc/ercs.py @@ -22,13 +22,15 @@ def _check_signature(erc_function, contract, ret): # The check on state variable is needed until we have a better API to handle state variable getters state_variable_as_function = contract.get_state_variable_from_name(name) - if not state_variable_as_function or not state_variable_as_function.visibility in ['public', 'external']: + if not state_variable_as_function or not state_variable_as_function.visibility in [ + "public", + "external", + ]: txt = f'[ ] {sig} is missing {"" if required else "(optional)"}' logger.info(txt) - missing_func = output.Output(txt, additional_fields={ - "function": sig, - "required": required - }) + missing_func = output.Output( + txt, additional_fields={"function": sig, "required": required} + ) missing_func.add(contract) ret["missing_function"].append(missing_func.data) return @@ -38,10 +40,9 @@ def _check_signature(erc_function, contract, ret): if types != parameters: txt = f'[ ] {sig} is missing {"" if required else "(optional)"}' logger.info(txt) - missing_func = output.Output(txt, additional_fields={ - "function": sig, - "required": required - }) + missing_func = output.Output( + txt, additional_fields={"function": sig, "required": required} + ) missing_func.add(contract) ret["missing_function"].append(missing_func.data) return @@ -53,45 +54,51 @@ def _check_signature(erc_function, contract, ret): function_return_type = function.return_type function_view = function.view - txt = f'[✓] {sig} is present' + txt = f"[✓] {sig} is present" logger.info(txt) if function_return_type: - function_return_type = ','.join([str(x) for x in function_return_type]) + function_return_type = ",".join([str(x) for x in function_return_type]) if function_return_type == return_type: - txt = f'\t[✓] {sig} -> () (correct return value)' + txt = f"\t[✓] {sig} -> () (correct return value)" logger.info(txt) else: - txt = f'\t[ ] {sig} -> () should return {return_type}' + txt = f"\t[ ] {sig} -> () should return {return_type}" logger.info(txt) - incorrect_return = output.Output(txt, additional_fields={ - "expected_return_type": return_type, - "actual_return_type": function_return_type - }) + incorrect_return = output.Output( + txt, + additional_fields={ + "expected_return_type": return_type, + "actual_return_type": function_return_type, + }, + ) incorrect_return.add(function) ret["incorrect_return_type"].append(incorrect_return.data) elif not return_type: - txt = f'\t[✓] {sig} -> () (correct return type)' + txt = f"\t[✓] {sig} -> () (correct return type)" logger.info(txt) else: - txt = f'\t[ ] {sig} -> () should return {return_type}' + txt = f"\t[ ] {sig} -> () should return {return_type}" logger.info(txt) - incorrect_return = output.Output(txt, additional_fields={ - "expected_return_type": return_type, - "actual_return_type": function_return_type - }) + incorrect_return = output.Output( + txt, + additional_fields={ + "expected_return_type": return_type, + "actual_return_type": function_return_type, + }, + ) incorrect_return.add(function) ret["incorrect_return_type"].append(incorrect_return.data) if view: if function_view: - txt = f'\t[✓] {sig} is view' + txt = f"\t[✓] {sig} is view" logger.info(txt) else: - txt = f'\t[ ] {sig} should be view' + txt = f"\t[ ] {sig} should be view" logger.info(txt) should_be_view = output.Output(txt) @@ -103,12 +110,12 @@ def _check_signature(erc_function, contract, ret): event_sig = f'{event.name}({",".join(event.parameters)})' if not function: - txt = f'\t[ ] Must emit be view {event_sig}' + txt = f"\t[ ] Must emit be view {event_sig}" logger.info(txt) - missing_event_emmited = output.Output(txt, additional_fields={ - "missing_event": event_sig - }) + missing_event_emmited = output.Output( + txt, additional_fields={"missing_event": event_sig} + ) missing_event_emmited.add(function) ret["missing_event_emmited"].append(missing_event_emmited.data) @@ -121,15 +128,15 @@ def _check_signature(erc_function, contract, ret): event_found = True break if event_found: - txt = f'\t[✓] {event_sig} is emitted' + txt = f"\t[✓] {event_sig} is emitted" logger.info(txt) else: - txt = f'\t[ ] Must emit be view {event_sig}' + txt = f"\t[ ] Must emit be view {event_sig}" logger.info(txt) - missing_event_emmited = output.Output(txt, additional_fields={ - "missing_event": event_sig - }) + missing_event_emmited = output.Output( + txt, additional_fields={"missing_event": event_sig} + ) missing_event_emmited.add(function) ret["missing_event_emmited"].append(missing_event_emmited.data) @@ -143,31 +150,27 @@ def _check_events(erc_event, contract, ret): event = contract.get_event_from_signature(sig) if not event: - txt = f'[ ] {sig} is missing' + txt = f"[ ] {sig} is missing" logger.info(txt) - missing_event = output.Output(txt, additional_fields={ - "event": sig - }) + missing_event = output.Output(txt, additional_fields={"event": sig}) missing_event.add(contract) ret["missing_event"].append(missing_event.data) return - txt = f'[✓] {sig} is present' + txt = f"[✓] {sig} is present" logger.info(txt) for i, index in enumerate(indexes): if index: if event.elems[i].indexed: - txt = f'\t[✓] parameter {i} is indexed' + txt = f"\t[✓] parameter {i} is indexed" logger.info(txt) else: - txt = f'\t[ ] parameter {i} should be indexed' + txt = f"\t[ ] parameter {i} should be indexed" logger.info(txt) - missing_event_index = output.Output(txt, additional_fields={ - "missing_index": i - }) + missing_event_index = output.Output(txt, additional_fields={"missing_index": i}) missing_event_index.add_event(event) ret["missing_event_index"].append(missing_event_index.data) @@ -179,16 +182,16 @@ def generic_erc_checks(contract, erc_functions, erc_events, ret, explored=None): explored.add(contract) - logger.info(f'# Check {contract.name}\n') + logger.info(f"# Check {contract.name}\n") - logger.info(f'## Check functions') + logger.info(f"## Check functions") for erc_function in erc_functions: _check_signature(erc_function, contract, ret) - logger.info(f'\n## Check events') + logger.info(f"\n## Check events") for erc_event in erc_events: _check_events(erc_event, contract, ret) - logger.info('\n') + logger.info("\n") for derived_contract in contract.derived_contracts: generic_erc_checks(derived_contract, erc_functions, erc_events, ret, explored) diff --git a/slither/tools/kspec_coverage/__main__.py b/slither/tools/kspec_coverage/__main__.py index 47dc5c9fa..33bd3a162 100644 --- a/slither/tools/kspec_coverage/__main__.py +++ b/slither/tools/kspec_coverage/__main__.py @@ -11,35 +11,44 @@ logger.setLevel(logging.INFO) ch = logging.StreamHandler() ch.setLevel(logging.INFO) -formatter = logging.Formatter('%(message)s') +formatter = logging.Formatter("%(message)s") logger.addHandler(ch) logger.handlers[0].setFormatter(formatter) logger.propagate = False + def parse_args(): """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='slither-kspec-coverage', - usage='slither-kspec-coverage contract.sol kspec.md') - - parser.add_argument('contract', help='The filename of the contract or truffle directory to analyze.') - parser.add_argument('kspec', help='The filename of the Klab spec markdown for the analyzed contract(s)') - - parser.add_argument('--version', help='displays the current version', version='0.1.0',action='version') - parser.add_argument('--json', - help='Export the results as a JSON file ("--json -" to export to stdout)', - action='store', - default=False + parser = argparse.ArgumentParser( + description="slither-kspec-coverage", usage="slither-kspec-coverage contract.sol kspec.md" + ) + + parser.add_argument( + "contract", help="The filename of the contract or truffle directory to analyze." + ) + parser.add_argument( + "kspec", help="The filename of the Klab spec markdown for the analyzed contract(s)" ) - cryticparser.init(parser) - - if len(sys.argv) < 2: - parser.print_help(sys.stderr) + parser.add_argument( + "--version", help="displays the current version", version="0.1.0", action="version" + ) + parser.add_argument( + "--json", + help='Export the results as a JSON file ("--json -" to export to stdout)', + action="store", + default=False, + ) + + cryticparser.init(parser) + + if len(sys.argv) < 2: + parser.print_help(sys.stderr) sys.exit(1) - + return parser.parse_args() @@ -53,6 +62,7 @@ def main(): args = parse_args() kspec_coverage(args) - -if __name__ == '__main__': + + +if __name__ == "__main__": main() diff --git a/slither/tools/kspec_coverage/analysis.py b/slither/tools/kspec_coverage/analysis.py index d2daf03d0..08e42ad76 100755 --- a/slither/tools/kspec_coverage/analysis.py +++ b/slither/tools/kspec_coverage/analysis.py @@ -7,25 +7,22 @@ from slither.utils.colors import yellow, green, red from slither.utils import output logging.basicConfig(level=logging.WARNING) -logger = logging.getLogger('Slither.kspec') +logger = logging.getLogger("Slither.kspec") def _refactor_type(type): - return { - 'uint': 'uint256', - 'int': 'int256' - }.get(type, type) + return {"uint": "uint256", "int": "int256"}.get(type, type) def _get_all_covered_kspec_functions(target): # Create a set of our discovered functions which are covered covered_functions = set() - BEHAVIOUR_PATTERN = re.compile('behaviour\s+(\S+)\s+of\s+(\S+)') - INTERFACE_PATTERN = re.compile('interface\s+([^\r\n]+)') + BEHAVIOUR_PATTERN = re.compile("behaviour\s+(\S+)\s+of\s+(\S+)") + INTERFACE_PATTERN = re.compile("interface\s+([^\r\n]+)") # Read the file contents - with open(target, 'r', encoding='utf8') as target_file: + with open(target, "r", encoding="utf8") as target_file: lines = target_file.readlines() # Loop for each line, if a line matches our behaviour regex, and the next one matches our interface regex, @@ -38,10 +35,12 @@ def _get_all_covered_kspec_functions(target): match = INTERFACE_PATTERN.match(lines[i + 1]) if match: function_full_name = match.groups()[0] - start, end = function_full_name.index('(') + 1, function_full_name.index(')') - function_arguments = function_full_name[start:end].split(',') - function_arguments = [_refactor_type(arg.strip().split(' ')[0]) for arg in function_arguments] - function_full_name = function_full_name[:start] + ','.join(function_arguments) + ')' + start, end = function_full_name.index("(") + 1, function_full_name.index(")") + function_arguments = function_full_name[start:end].split(",") + function_arguments = [ + _refactor_type(arg.strip().split(" ")[0]) for arg in function_arguments + ] + function_full_name = function_full_name[:start] + ",".join(function_arguments) + ")" covered_functions.add((contract_name, function_full_name)) i += 1 i += 1 @@ -50,14 +49,25 @@ def _get_all_covered_kspec_functions(target): def _get_slither_functions(slither): # Use contract == contract_declarer to avoid dupplicate - all_functions_declared = [f for f in slither.functions if (f.contract == f.contract_declarer - and f.is_implemented - and not f.is_constructor - and not f.is_constructor_variables)] + all_functions_declared = [ + f + for f in slither.functions + if ( + f.contract == f.contract_declarer + and f.is_implemented + and not f.is_constructor + and not f.is_constructor_variables + ) + ] # Use list(set()) because same state variable instances can be shared accross contracts # TODO: integrate state variables - all_functions_declared += list(set([s for s in slither.state_variables if s.visibility in ['public', 'external']])) - slither_functions = {(function.contract.name, function.full_name): function for function in all_functions_declared} + all_functions_declared += list( + set([s for s in slither.state_variables if s.visibility in ["public", "external"]]) + ) + slither_functions = { + (function.contract.name, function.full_name): function + for function in all_functions_declared + } return slither_functions @@ -110,35 +120,42 @@ def _run_coverage_analysis(args, slither, kspec_functions): else: kspec_missing.append(slither_func) - logger.info('## Check for functions coverage') + logger.info("## Check for functions coverage") json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json) - json_kspec_missing_functions = _generate_output([f for f in kspec_missing if isinstance(f, Function)], - "[ ] (Missing function)", - red, - args.json) - json_kspec_missing_variables = _generate_output([f for f in kspec_missing if isinstance(f, Variable)], - "[ ] (Missing variable)", - yellow, - args.json) - json_kspec_unresolved = _generate_output_unresolved(kspec_functions_unresolved, - "[ ] (Unresolved)", - yellow, - args.json) + json_kspec_missing_functions = _generate_output( + [f for f in kspec_missing if isinstance(f, Function)], + "[ ] (Missing function)", + red, + args.json, + ) + json_kspec_missing_variables = _generate_output( + [f for f in kspec_missing if isinstance(f, Variable)], + "[ ] (Missing variable)", + yellow, + args.json, + ) + json_kspec_unresolved = _generate_output_unresolved( + kspec_functions_unresolved, "[ ] (Unresolved)", yellow, args.json + ) # Handle unresolved kspecs if args.json: - output.output_to_json(args.json, None, { - "functions_present": json_kspec_present, - "functions_missing": json_kspec_missing_functions, - "variables_missing": json_kspec_missing_variables, - "functions_unresolved": json_kspec_unresolved - }) + output.output_to_json( + args.json, + None, + { + "functions_present": json_kspec_present, + "functions_missing": json_kspec_missing_functions, + "variables_missing": json_kspec_missing_variables, + "functions_unresolved": json_kspec_unresolved, + }, + ) def run_analysis(args, slither, kspec): # Get all of our kspec'd functions (tuple(contract_name, function_name)). - if ',' in kspec: - kspecs = kspec.split(',') + if "," in kspec: + kspecs = kspec.split(",") kspec_functions = set() for kspec in kspecs: kspec_functions |= _get_all_covered_kspec_functions(kspec) diff --git a/slither/tools/kspec_coverage/kspec_coverage.py b/slither/tools/kspec_coverage/kspec_coverage.py index 2ee25477f..86b59be53 100755 --- a/slither/tools/kspec_coverage/kspec_coverage.py +++ b/slither/tools/kspec_coverage/kspec_coverage.py @@ -1,6 +1,7 @@ from slither.tools.kspec_coverage.analysis import run_analysis from slither import Slither + def kspec_coverage(args): contract = args.contract @@ -10,5 +11,3 @@ def kspec_coverage(args): # Run the analysis on the Klab specs run_analysis(args, slither, kspec) - - diff --git a/slither/tools/possible_paths/__main__.py b/slither/tools/possible_paths/__main__.py index 70aea8003..c13a3f390 100644 --- a/slither/tools/possible_paths/__main__.py +++ b/slither/tools/possible_paths/__main__.py @@ -9,18 +9,21 @@ from crytic_compile import cryticparser logging.basicConfig() logging.getLogger("Slither").setLevel(logging.INFO) + def parse_args(): """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='PossiblePaths', - usage='possible_paths.py filename [contract.function targets]') + parser = argparse.ArgumentParser( + description="PossiblePaths", usage="possible_paths.py filename [contract.function targets]" + ) - parser.add_argument('filename', - help='The filename of the contract or truffle directory to analyze.') + parser.add_argument( + "filename", help="The filename of the contract or truffle directory to analyze." + ) - parser.add_argument('targets', nargs='+') + parser.add_argument("targets", nargs="+") cryticparser.init(parser) @@ -62,12 +65,16 @@ def main(): print("\n") # Format all function paths. - reaching_paths_str = [' -> '.join([f"{f.canonical_name}" for f in reaching_path]) for reaching_path in reaching_paths] + reaching_paths_str = [ + " -> ".join([f"{f.canonical_name}" for f in reaching_path]) + for reaching_path in reaching_paths + ] # Print a sorted list of all function paths which can reach the targets. print(f"The following paths reach the specified targets:") for reaching_path in sorted(reaching_paths_str): print(f"{reaching_path}\n") -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/slither/tools/possible_paths/possible_paths.py b/slither/tools/possible_paths/possible_paths.py index e638b00ad..8137f3d1b 100644 --- a/slither/tools/possible_paths/possible_paths.py +++ b/slither/tools/possible_paths/possible_paths.py @@ -1,4 +1,5 @@ -class ResolveFunctionException(Exception): pass +class ResolveFunctionException(Exception): + pass def resolve_function(slither, contract_name, function_name): @@ -16,11 +17,15 @@ def resolve_function(slither, contract_name, function_name): raise ResolveFunctionException(f"Could not resolve target contract: {contract_name}") # Obtain the target function - target_function = next((function for function in contract.functions if function.name == function_name), None) + target_function = next( + (function for function in contract.functions if function.name == function_name), None + ) # Verify we have resolved the function specified. if target_function is None: - raise ResolveFunctionException(f"Could not resolve target function: {contract_name}.{function_name}") + raise ResolveFunctionException( + f"Could not resolve target function: {contract_name}.{function_name}" + ) # Add the resolved function to the new list. return target_function @@ -44,17 +49,23 @@ def resolve_functions(slither, functions): for item in functions: if isinstance(item, str): # If the item is a single string, we assume it is of form 'ContractName.FunctionName'. - parts = item.split('.') + parts = item.split(".") if len(parts) < 2: - raise ResolveFunctionException("Provided string descriptor must be of form 'ContractName.FunctionName'") + raise ResolveFunctionException( + "Provided string descriptor must be of form 'ContractName.FunctionName'" + ) resolved.append(resolve_function(slither, parts[0], parts[1])) elif isinstance(item, tuple): # If the item is a tuple, it should be a 2-tuple providing contract and function names. if len(item) != 2: - raise ResolveFunctionException("Provided tuple descriptor must provide a contract and function name.") + raise ResolveFunctionException( + "Provided tuple descriptor must provide a contract and function name." + ) resolved.append(resolve_function(slither, item[0], item[1])) else: - raise ResolveFunctionException(f"Unexpected function descriptor type to resolve in list: {type(item)}") + raise ResolveFunctionException( + f"Unexpected function descriptor type to resolve in list: {type(item)}" + ) # Return the resolved list. return resolved @@ -66,9 +77,12 @@ def all_function_definitions(function): :param function: The function to obtain all definitions at and beneath. :return: Returns a list composed of the provided function definition and any base definitions. """ - return [function] + [f for c in function.contract.inheritance - for f in c.functions_and_modifiers_declared - if f.full_name == function.full_name] + return [function] + [ + f + for c in function.contract.inheritance + for f in c.functions_and_modifiers_declared + if f.full_name == function.full_name + ] def __find_target_paths(slither, target_function, current_path=[]): @@ -102,7 +116,7 @@ def __find_target_paths(slither, target_function, current_path=[]): results = results.union(path_results) # If this path is external accessible from this point, we add the current path to the list. - if target_function.visibility in ['public', 'external'] and len(current_path) > 1: + if target_function.visibility in ["public", "external"] and len(current_path) > 1: results.add(tuple(current_path)) return results @@ -122,6 +136,3 @@ def find_target_paths(slither, target_functions): results = results.union(__find_target_paths(slither, target_function)) return results - - - diff --git a/slither/tools/properties/__main__.py b/slither/tools/properties/__main__.py index 664d5c35a..25685fb6b 100644 --- a/slither/tools/properties/__main__.py +++ b/slither/tools/properties/__main__.py @@ -16,20 +16,21 @@ logging.getLogger("Slither").setLevel(logging.INFO) logger = logging.getLogger("Slither") ch = logging.StreamHandler() ch.setLevel(logging.INFO) -formatter = logging.Formatter('%(message)s') +formatter = logging.Formatter("%(message)s") logger.addHandler(ch) logger.handlers[0].setFormatter(formatter) logger.propagate = False def _all_scenarios(): - txt = '\n' - txt += '#################### ERC20 ####################\n' + txt = "\n" + txt += "#################### ERC20 ####################\n" for k, value in ERC20_PROPERTIES.items(): - txt += f'{k} - {value.description}\n' + txt += f"{k} - {value.description}\n" return txt + def _all_properties(): table = MyPrettyTable(["Num", "Description", "Scenario"]) idx = 0 @@ -39,6 +40,7 @@ def _all_properties(): idx = idx + 1 return table + class ListScenarios(argparse.Action): def __call__(self, parser, *args, **kwargs): logger.info(_all_scenarios()) @@ -56,43 +58,51 @@ def parse_args(): Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='Demo', - usage='slither-demo filename', - formatter_class=argparse.RawDescriptionHelpFormatter) - - parser.add_argument('filename', - help='The filename of the contract or truffle directory to analyze.') - - parser.add_argument('--contract', - help='The targeted contract.') - - parser.add_argument('--scenario', - help=f'Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable', - default='Transferable') - - parser.add_argument('--list-scenarios', - help='List available scenarios', - action=ListScenarios, - nargs=0, - default=False) - - parser.add_argument('--list-properties', - help='List available properties', - action=ListProperties, - nargs=0, - default=False) - - parser.add_argument('--address-owner', - help=f'Owner address. Default {OWNER_ADDRESS}', - default=None) - - parser.add_argument('--address-user', - help=f'Owner address. Default {USER_ADDRESS}', - default=None) - - parser.add_argument('--address-attacker', - help=f'Attacker address. Default {ATTACKER_ADDRESS}', - default=None) + parser = argparse.ArgumentParser( + description="Demo", + usage="slither-demo filename", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "filename", help="The filename of the contract or truffle directory to analyze." + ) + + parser.add_argument("--contract", help="The targeted contract.") + + parser.add_argument( + "--scenario", + help=f"Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable", + default="Transferable", + ) + + parser.add_argument( + "--list-scenarios", + help="List available scenarios", + action=ListScenarios, + nargs=0, + default=False, + ) + + parser.add_argument( + "--list-properties", + help="List available properties", + action=ListProperties, + nargs=0, + default=False, + ) + + parser.add_argument( + "--address-owner", help=f"Owner address. Default {OWNER_ADDRESS}", default=None + ) + + parser.add_argument( + "--address-user", help=f"Owner address. Default {USER_ADDRESS}", default=None + ) + + parser.add_argument( + "--address-attacker", help=f"Attacker address. Default {ATTACKER_ADDRESS}", default=None + ) # Add default arguments from crytic-compile cryticparser.init(parser) @@ -116,9 +126,9 @@ def main(): contract = slither.contracts[0] else: if args.contract is None: - logger.error(f'Specify the target: --contract ContractName') + logger.error(f"Specify the target: --contract ContractName") else: - logger.error(f'{args.contract} not found') + logger.error(f"{args.contract} not found") return addresses = Addresses(args.address_owner, args.address_user, args.address_attacker) @@ -126,5 +136,5 @@ def main(): generate_erc20(contract, args.scenario, addresses) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/slither/tools/properties/addresses/address.py b/slither/tools/properties/addresses/address.py index e183fd3f7..2068bca23 100644 --- a/slither/tools/properties/addresses/address.py +++ b/slither/tools/properties/addresses/address.py @@ -8,8 +8,12 @@ ATTACKER_ADDRESS = "0xC5fdf4076b8F3A5357c5E395ab970B5B54098Fef" class Addresses: - - def __init__(self, owner: Optional[str] = None, user: Optional[str] = None, attacker: Optional[str] = None): + def __init__( + self, + owner: Optional[str] = None, + user: Optional[str] = None, + attacker: Optional[str] = None, + ): self.owner = owner if owner else OWNER_ADDRESS self.user = user if user else USER_ADDRESS self.attacker = attacker if attacker else ATTACKER_ADDRESS diff --git a/slither/tools/properties/platforms/echidna.py b/slither/tools/properties/platforms/echidna.py index f02783429..6ab372cfe 100644 --- a/slither/tools/properties/platforms/echidna.py +++ b/slither/tools/properties/platforms/echidna.py @@ -11,11 +11,11 @@ def generate_echidna_config(output_dir: Path, addresses: Addresses) -> str: :param addresses: :return: """ - content = 'prefix: crytic_\n' + content = "prefix: crytic_\n" content += f'deployer: "{addresses.owner}"\n' content += f'sender: ["{addresses.user}", "{addresses.attacker}"]\n' content += f'psender: "{addresses.user}"\n' - content += 'coverage: true\n' - filename = 'echidna_config.yaml' + content += "coverage: true\n" + filename = "echidna_config.yaml" write_file(output_dir, filename, content) return filename diff --git a/slither/tools/properties/platforms/truffle.py b/slither/tools/properties/platforms/truffle.py index 0715ef7eb..48627fab4 100644 --- a/slither/tools/properties/platforms/truffle.py +++ b/slither/tools/properties/platforms/truffle.py @@ -7,21 +7,21 @@ from slither.tools.properties.addresses.address import Addresses from slither.tools.properties.properties.properties import PropertyReturn, Property, PropertyCaller from slither.tools.properties.utils import write_file -PATTERN_TRUFFLE_MIGRATION = re.compile('^[0-9]*_') +PATTERN_TRUFFLE_MIGRATION = re.compile("^[0-9]*_") logger = logging.getLogger("Slither") def _extract_caller(p: PropertyCaller): if p == PropertyCaller.OWNER: - return ['owner'] + return ["owner"] if p == PropertyCaller.SENDER: - return ['user'] + return ["user"] if p == PropertyCaller.ATTACKER: - return ['attacker'] + return ["attacker"] if p == PropertyCaller.ALL: - return ['owner', 'user', 'attacker'] + return ["owner", "user", "attacker"] assert p == PropertyCaller.ANY - return ['user'] + return ["user"] def _helpers(): @@ -31,7 +31,7 @@ def _helpers(): - catchRevertThrow: check if the call revert/throw :return: """ - return ''' + return """ async function catchRevertThrowReturnFalse(promise) { try { const ret = await promise; @@ -61,12 +61,17 @@ async function catchRevertThrow(promise) { } assert(false, "Expected revert/throw/or return false"); }; -''' +""" -def generate_unit_test(test_contract: str, filename: str, - unit_tests: List[Property], output_dir: Path, - addresses: Addresses, assert_message: str = ''): +def generate_unit_test( + test_contract: str, + filename: str, + unit_tests: List[Property], + output_dir: Path, + addresses: Addresses, + assert_message: str = "", +): """ Generate unit tests files :param test_contract: @@ -88,37 +93,37 @@ def generate_unit_test(test_contract: str, filename: str, content += f'\tlet attacker = "{addresses.attacker}";\n' for unit_test in unit_tests: content += f'\tit("{unit_test.description}", async () => {{\n' - content += f'\t\tlet instance = await {test_contract}.deployed();\n' + content += f"\t\tlet instance = await {test_contract}.deployed();\n" callers = _extract_caller(unit_test.caller) if unit_test.return_type == PropertyReturn.SUCCESS: for caller in callers: - content += f'\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n' + content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n" if assert_message: content += f'\t\tassert.equal(test_{caller}, true, "{assert_message}");\n' else: - content += f'\t\tassert.equal(test_{caller}, true);\n' + content += f"\t\tassert.equal(test_{caller}, true);\n" elif unit_test.return_type == PropertyReturn.FAIL: for caller in callers: - content += f'\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n' + content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n" if assert_message: content += f'\t\tassert.equal(test_{caller}, false, "{assert_message}");\n' else: - content += f'\t\tassert.equal(test_{caller}, false);\n' + content += f"\t\tassert.equal(test_{caller}, false);\n" elif unit_test.return_type == PropertyReturn.FAIL_OR_THROW: for caller in callers: - content += f'\t\tawait catchRevertThrowReturnFalse(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n' + content += f"\t\tawait catchRevertThrowReturnFalse(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n" elif unit_test.return_type == PropertyReturn.THROW: callers = _extract_caller(unit_test.caller) for caller in callers: - content += f'\t\tawait catchRevertThrow(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n' - content += '\t});\n' + content += f"\t\tawait catchRevertThrow(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n" + content += "\t});\n" - content += '});\n' + content += "});\n" - output_dir = Path(output_dir, 'test') + output_dir = Path(output_dir, "test") output_dir.mkdir(exist_ok=True) - output_dir = Path(output_dir, 'crytic') + output_dir = Path(output_dir, "crytic") output_dir.mkdir(exist_ok=True) write_file(output_dir, filename, content) @@ -133,28 +138,31 @@ def generate_migration(test_contract: str, output_dir: Path, owner_address: str) :param owner_address: :return: """ - content = f'''{test_contract} = artifacts.require("{test_contract}"); + content = f"""{test_contract} = artifacts.require("{test_contract}"); module.exports = function(deployer) {{ deployer.deploy({test_contract}, {{from: "{owner_address}"}}); }}; -''' +""" - output_dir = Path(output_dir, 'migrations') + output_dir = Path(output_dir, "migrations") output_dir.mkdir(exist_ok=True) - migration_files = [js_file for js_file in output_dir.iterdir() if js_file.suffix == '.js' - and PATTERN_TRUFFLE_MIGRATION.match(js_file.name)] + migration_files = [ + js_file + for js_file in output_dir.iterdir() + if js_file.suffix == ".js" and PATTERN_TRUFFLE_MIGRATION.match(js_file.name) + ] idx = len(migration_files) - filename = f'{idx + 1}_{test_contract}.js' - potential_previous_filename = f'{idx}_{test_contract}.js' + filename = f"{idx + 1}_{test_contract}.js" + potential_previous_filename = f"{idx}_{test_contract}.js" for m in migration_files: if m.name == potential_previous_filename: write_file(output_dir, potential_previous_filename, content) return if test_contract in m.name: - logger.error(f'Potential conflicts with {m.name}') + logger.error(f"Potential conflicts with {m.name}") write_file(output_dir, filename, content) diff --git a/slither/tools/properties/properties/erc20.py b/slither/tools/properties/properties/erc20.py index 5c83b1168..31f6ff8cc 100644 --- a/slither/tools/properties/properties/erc20.py +++ b/slither/tools/properties/properties/erc20.py @@ -12,27 +12,37 @@ from slither.tools.properties.platforms.echidna import generate_echidna_config from slither.tools.properties.properties.ercs.erc20.properties.burn import ERC20_NotBurnable from slither.tools.properties.properties.ercs.erc20.properties.initialization import ERC20_CONFIG from slither.tools.properties.properties.ercs.erc20.properties.mint import ERC20_NotMintable -from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import ERC20_NotMintableNotBurnable -from slither.tools.properties.properties.ercs.erc20.properties.transfer import ERC20_Transferable, ERC20_Pausable +from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import ( + ERC20_NotMintableNotBurnable, +) +from slither.tools.properties.properties.ercs.erc20.properties.transfer import ( + ERC20_Transferable, + ERC20_Pausable, +) from slither.tools.properties.properties.ercs.erc20.unit_tests.truffle import generate_truffle_test from slither.tools.properties.properties.properties import property_to_solidity, Property -from slither.tools.properties.solidity.generate_properties import generate_solidity_properties, generate_test_contract, \ - generate_solidity_interface +from slither.tools.properties.solidity.generate_properties import ( + generate_solidity_properties, + generate_test_contract, + generate_solidity_interface, +) from slither.utils.colors import red, green logger = logging.getLogger("Slither") -PropertyDescription = namedtuple('PropertyDescription', ['properties', 'description']) +PropertyDescription = namedtuple("PropertyDescription", ["properties", "description"]) ERC20_PROPERTIES = { - "Transferable": PropertyDescription(ERC20_Transferable, 'Test the correct tokens transfer'), - "Pausable": PropertyDescription(ERC20_Pausable, 'Test the pausable functionality'), - "NotMintable": PropertyDescription(ERC20_NotMintable, 'Test that no one can mint tokens'), - "NotMintableNotBurnable": PropertyDescription(ERC20_NotMintableNotBurnable, - 'Test that no one can mint or burn tokens'), - "NotBurnable": PropertyDescription(ERC20_NotBurnable, 'Test that no one can burn tokens'), - "Burnable": PropertyDescription(ERC20_NotBurnable, - 'Test the burn of tokens. Require the "burn(address) returns()" function') + "Transferable": PropertyDescription(ERC20_Transferable, "Test the correct tokens transfer"), + "Pausable": PropertyDescription(ERC20_Pausable, "Test the pausable functionality"), + "NotMintable": PropertyDescription(ERC20_NotMintable, "Test that no one can mint tokens"), + "NotMintableNotBurnable": PropertyDescription( + ERC20_NotMintableNotBurnable, "Test that no one can mint or burn tokens" + ), + "NotBurnable": PropertyDescription(ERC20_NotBurnable, "Test that no one can burn tokens"), + "Burnable": PropertyDescription( + ERC20_NotBurnable, 'Test the burn of tokens. Require the "burn(address) returns()" function' + ), } @@ -54,7 +64,7 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses) :return: """ if contract.slither.crytic_compile.type not in [PlatformType.TRUFFLE, PlatformType.SOLC]: - logging.error(f'{contract.slither.crytic_compile.type} not yet supported by slither-prop') + logging.error(f"{contract.slither.crytic_compile.type} not yet supported by slither-prop") return # Check if the contract is an ERC20 contract and if the functions have the correct visibility @@ -65,7 +75,9 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses) properties = ERC20_PROPERTIES.get(type_property, None) if properties is None: - logger.error(f'{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}') + logger.error( + f"{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}" + ) return properties = properties.properties @@ -78,51 +90,53 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses) # Generate the contract containing the properties generate_solidity_interface(output_dir, addresses) - property_file = generate_solidity_properties(contract, type_property, solidity_properties, output_dir) + property_file = generate_solidity_properties( + contract, type_property, solidity_properties, output_dir + ) # Generate the Test contract initialization_recommendation = _initialization_recommendation(type_property) - contract_filename, contract_name = generate_test_contract(contract, - type_property, - output_dir, - property_file, - initialization_recommendation) + contract_filename, contract_name = generate_test_contract( + contract, type_property, output_dir, property_file, initialization_recommendation + ) # Generate Echidna config file - echidna_config_filename = generate_echidna_config(Path(contract.slither.crytic_compile.target).parent, addresses) + echidna_config_filename = generate_echidna_config( + Path(contract.slither.crytic_compile.target).parent, addresses + ) - unit_test_info = '' + unit_test_info = "" # If truffle, generate unit tests if contract.slither.crytic_compile.type == PlatformType.TRUFFLE: unit_test_info = generate_truffle_test(contract, type_property, unit_tests, addresses) - logger.info('################################################') - logger.info(green(f'Update the constructor in {Path(output_dir, contract_filename)}')) + logger.info("################################################") + logger.info(green(f"Update the constructor in {Path(output_dir, contract_filename)}")) if unit_test_info: logger.info(green(unit_test_info)) - logger.info(green('To run Echidna:')) - txt = f'\t echidna-test {contract.slither.crytic_compile.target} ' - txt += f'--contract {contract_name} --config {echidna_config_filename}' + logger.info(green("To run Echidna:")) + txt = f"\t echidna-test {contract.slither.crytic_compile.target} " + txt += f"--contract {contract_name} --config {echidna_config_filename}" logger.info(green(txt)) def _initialization_recommendation(type_property: str) -> str: - content = '' - content += '\t\t// Add below a minimal configuration:\n' - content += '\t\t// - crytic_owner must have some tokens \n' - content += '\t\t// - crytic_user must have some tokens \n' - content += '\t\t// - crytic_attacker must have some tokens \n' - if type_property in ['Pausable']: - content += '\t\t// - The contract must be paused \n' - if type_property in ['NotMintable', 'NotMintableNotBurnable']: - content += '\t\t// - The contract must not be mintable \n' - if type_property in ['NotBurnable', 'NotMintableNotBurnable']: - content += '\t\t// - The contract must not be burnable \n' - content += '\n' - content += '\n' + content = "" + content += "\t\t// Add below a minimal configuration:\n" + content += "\t\t// - crytic_owner must have some tokens \n" + content += "\t\t// - crytic_user must have some tokens \n" + content += "\t\t// - crytic_attacker must have some tokens \n" + if type_property in ["Pausable"]: + content += "\t\t// - The contract must be paused \n" + if type_property in ["NotMintable", "NotMintableNotBurnable"]: + content += "\t\t// - The contract must not be mintable \n" + if type_property in ["NotBurnable", "NotMintableNotBurnable"]: + content += "\t\t// - The contract must not be burnable \n" + content += "\n" + content += "\n" return content @@ -130,44 +144,44 @@ def _initialization_recommendation(type_property: str) -> str: # TODO: move this to crytic-compile def _platform_to_output_dir(platform: AbstractPlatform) -> Path: if platform.TYPE == PlatformType.TRUFFLE: - return Path(platform.target, 'contracts', 'crytic') + return Path(platform.target, "contracts", "crytic") if platform.TYPE == PlatformType.SOLC: return Path(platform.target).parent def _check_compatibility(contract): - errors = '' + errors = "" if not contract.is_erc20(): - errors = f'{contract} is not ERC20 compliant. Consider checking the contract with slither-check-erc' + errors = f"{contract} is not ERC20 compliant. Consider checking the contract with slither-check-erc" return errors - transfer = contract.get_function_from_signature('transfer(address,uint256)') + transfer = contract.get_function_from_signature("transfer(address,uint256)") - if transfer.visibility != 'public': - errors = f'slither-prop requires {transfer.canonical_name} to be public. Please change the visibility' + if transfer.visibility != "public": + errors = f"slither-prop requires {transfer.canonical_name} to be public. Please change the visibility" - transfer_from = contract.get_function_from_signature('transferFrom(address,address,uint256)') - if transfer_from.visibility != 'public': + transfer_from = contract.get_function_from_signature("transferFrom(address,address,uint256)") + if transfer_from.visibility != "public": if errors: - errors += '\n' - errors += f'slither-prop requires {transfer_from.canonical_name} to be public. Please change the visibility' + errors += "\n" + errors += f"slither-prop requires {transfer_from.canonical_name} to be public. Please change the visibility" - approve = contract.get_function_from_signature('approve(address,uint256)') - if approve.visibility != 'public': + approve = contract.get_function_from_signature("approve(address,uint256)") + if approve.visibility != "public": if errors: - errors += '\n' - errors += f'slither-prop requires {approve.canonical_name} to be public. Please change the visibility' + errors += "\n" + errors += f"slither-prop requires {approve.canonical_name} to be public. Please change the visibility" return errors def _get_properties(contract, properties: List[Property]) -> Tuple[str, List[Property]]: - solidity_properties = '' + solidity_properties = "" if contract.slither.crytic_compile.type == PlatformType.TRUFFLE: - solidity_properties += '\n'.join([property_to_solidity(p) for p in ERC20_CONFIG]) + solidity_properties += "\n".join([property_to_solidity(p) for p in ERC20_CONFIG]) - solidity_properties += '\n'.join([property_to_solidity(p) for p in properties]) + solidity_properties += "\n".join([property_to_solidity(p) for p in properties]) unit_tests = [p for p in properties if p.is_unit_test] return solidity_properties, unit_tests diff --git a/slither/tools/properties/properties/ercs/erc20/properties/burn.py b/slither/tools/properties/properties/ercs/erc20/properties/burn.py index 47e744995..d612abf87 100644 --- a/slither/tools/properties/properties/ercs/erc20/properties/burn.py +++ b/slither/tools/properties/properties/ercs/erc20/properties/burn.py @@ -1,30 +1,38 @@ -from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller +from slither.tools.properties.properties.properties import ( + Property, + PropertyType, + PropertyReturn, + PropertyCaller, +) ERC20_NotBurnable = [ - Property(name='crytic_supply_constant_ERC20PropertiesNotBurnable()', - description='The total supply does not decrease.', - content=''' -\t\treturn initialTotalSupply == this.totalSupply();''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ANY), + Property( + name="crytic_supply_constant_ERC20PropertiesNotBurnable()", + description="The total supply does not decrease.", + content=""" +\t\treturn initialTotalSupply == this.totalSupply();""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ANY, + ), ] # Require burn(address) returns() ERC20_Burnable = [ - Property(name='crytic_supply_constant_ERC20PropertiesNotBurnable()', - description='Cannot burn more than available balance', - content=''' + Property( + name="crytic_supply_constant_ERC20PropertiesNotBurnable()", + description="Cannot burn more than available balance", + content=""" \t\tuint balance = balanceOf(msg.sender); \t\tburn(balance + 1); -\t\treturn false;''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.THROW, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL) +\t\treturn false;""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.THROW, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ) ] - diff --git a/slither/tools/properties/properties/ercs/erc20/properties/initialization.py b/slither/tools/properties/properties/ercs/erc20/properties/initialization.py index c01a1d973..5f954b512 100644 --- a/slither/tools/properties/properties/ercs/erc20/properties/initialization.py +++ b/slither/tools/properties/properties/ercs/erc20/properties/initialization.py @@ -1,65 +1,76 @@ -from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller +from slither.tools.properties.properties.properties import ( + Property, + PropertyType, + PropertyReturn, + PropertyCaller, +) ERC20_CONFIG = [ - - Property(name='init_total_supply()', - description='The total supply is correctly initialized.', - content=''' -\t\treturn this.totalSupply() >= 0 && this.totalSupply() == initialTotalSupply;''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=False, - caller=PropertyCaller.ANY), - - Property(name='init_owner_balance()', - description="Owner's balance is correctly initialized.", - content=''' -\t\treturn initialBalance_owner == this.balanceOf(crytic_owner);''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=False, - caller=PropertyCaller.ANY), - - Property(name='init_user_balance()', - description="User's balance is correctly initialized.", - content=''' -\t\treturn initialBalance_user == this.balanceOf(crytic_user);''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=False, - caller=PropertyCaller.ANY), - - Property(name='init_attacker_balance()', - description="Attacker's balance is correctly initialized.", - content=''' -\t\treturn initialBalance_attacker == this.balanceOf(crytic_attacker);''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=False, - caller=PropertyCaller.ANY), - - Property(name='init_caller_balance()', - description="All the users have a positive balance.", - content=''' -\t\treturn this.balanceOf(msg.sender) >0 ;''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=False, - caller=PropertyCaller.ALL), - + Property( + name="init_total_supply()", + description="The total supply is correctly initialized.", + content=""" +\t\treturn this.totalSupply() >= 0 && this.totalSupply() == initialTotalSupply;""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=False, + caller=PropertyCaller.ANY, + ), + Property( + name="init_owner_balance()", + description="Owner's balance is correctly initialized.", + content=""" +\t\treturn initialBalance_owner == this.balanceOf(crytic_owner);""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=False, + caller=PropertyCaller.ANY, + ), + Property( + name="init_user_balance()", + description="User's balance is correctly initialized.", + content=""" +\t\treturn initialBalance_user == this.balanceOf(crytic_user);""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=False, + caller=PropertyCaller.ANY, + ), + Property( + name="init_attacker_balance()", + description="Attacker's balance is correctly initialized.", + content=""" +\t\treturn initialBalance_attacker == this.balanceOf(crytic_attacker);""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=False, + caller=PropertyCaller.ANY, + ), + Property( + name="init_caller_balance()", + description="All the users have a positive balance.", + content=""" +\t\treturn this.balanceOf(msg.sender) >0 ;""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=False, + caller=PropertyCaller.ALL, + ), # Note: there is a potential overflow on the addition, but we dont consider it - Property(name='init_total_supply_is_balances()', - description="The total supply is the user and owner balance.", - content=''' -\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) == this.totalSupply();''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=False, - caller=PropertyCaller.ANY), -] \ No newline at end of file + Property( + name="init_total_supply_is_balances()", + description="The total supply is the user and owner balance.", + content=""" +\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) == this.totalSupply();""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=False, + caller=PropertyCaller.ANY, + ), +] diff --git a/slither/tools/properties/properties/ercs/erc20/properties/mint.py b/slither/tools/properties/properties/ercs/erc20/properties/mint.py index a1355a2ce..4aafa907e 100644 --- a/slither/tools/properties/properties/ercs/erc20/properties/mint.py +++ b/slither/tools/properties/properties/ercs/erc20/properties/mint.py @@ -1,13 +1,20 @@ -from slither.tools.properties.properties.properties import PropertyType, PropertyReturn, Property, PropertyCaller +from slither.tools.properties.properties.properties import ( + PropertyType, + PropertyReturn, + Property, + PropertyCaller, +) ERC20_NotMintable = [ - Property(name='crytic_supply_constant_ERC20PropertiesNotMintable()', - description='The total supply does not increase.', - content=''' -\t\treturn initialTotalSupply >= totalSupply();''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ANY), + Property( + name="crytic_supply_constant_ERC20PropertiesNotMintable()", + description="The total supply does not increase.", + content=""" +\t\treturn initialTotalSupply >= totalSupply();""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ANY, + ), ] diff --git a/slither/tools/properties/properties/ercs/erc20/properties/mint_and_burn.py b/slither/tools/properties/properties/ercs/erc20/properties/mint_and_burn.py index 58b9709a7..5f99838d0 100644 --- a/slither/tools/properties/properties/ercs/erc20/properties/mint_and_burn.py +++ b/slither/tools/properties/properties/ercs/erc20/properties/mint_and_burn.py @@ -1,14 +1,20 @@ -from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller +from slither.tools.properties.properties.properties import ( + Property, + PropertyType, + PropertyReturn, + PropertyCaller, +) ERC20_NotMintableNotBurnable = [ - - Property(name='crytic_supply_constant_ERC20PropertiesNotMintableNotBurnable()', - description='The total supply does not change.', - content=''' -\t\treturn initialTotalSupply == this.totalSupply();''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ANY), -] \ No newline at end of file + Property( + name="crytic_supply_constant_ERC20PropertiesNotMintableNotBurnable()", + description="The total supply does not change.", + content=""" +\t\treturn initialTotalSupply == this.totalSupply();""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ANY, + ), +] diff --git a/slither/tools/properties/properties/ercs/erc20/properties/transfer.py b/slither/tools/properties/properties/ercs/erc20/properties/transfer.py index 02860e4ab..bea208f02 100644 --- a/slither/tools/properties/properties/ercs/erc20/properties/transfer.py +++ b/slither/tools/properties/properties/ercs/erc20/properties/transfer.py @@ -1,96 +1,108 @@ -from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller +from slither.tools.properties.properties.properties import ( + Property, + PropertyType, + PropertyReturn, + PropertyCaller, +) ERC20_Transferable = [ - - Property(name='crytic_zero_always_empty_ERC20Properties()', - description='The address 0x0 should not receive tokens.', - content=''' -\t\treturn this.balanceOf(address(0x0)) == 0;''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ANY), - - Property(name='crytic_approve_overwrites()', - description='Allowance can be changed.', - content=''' + Property( + name="crytic_zero_always_empty_ERC20Properties()", + description="The address 0x0 should not receive tokens.", + content=""" +\t\treturn this.balanceOf(address(0x0)) == 0;""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ANY, + ), + Property( + name="crytic_approve_overwrites()", + description="Allowance can be changed.", + content=""" \t\tbool approve_return; \t\tapprove_return = approve(crytic_user, 10); \t\trequire(approve_return); \t\tapprove_return = approve(crytic_user, 20); \t\trequire(approve_return); -\t\treturn this.allowance(msg.sender, crytic_user) == 20;''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_less_than_total_ERC20Properties()', - description='Balance of one user must be less or equal to the total supply.', - content=''' -\t\treturn this.balanceOf(msg.sender) <= totalSupply();''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_totalSupply_consistant_ERC20Properties()', - description='Balance of the crytic users must be less or equal to the total supply.', - content=''' -\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) <= totalSupply();''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ANY), - - Property(name='crytic_revert_transfer_to_zero_ERC20PropertiesTransferable()', - description='No one should be able to send tokens to the address 0x0 (transfer).', - content=''' +\t\treturn this.allowance(msg.sender, crytic_user) == 20;""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_less_than_total_ERC20Properties()", + description="Balance of one user must be less or equal to the total supply.", + content=""" +\t\treturn this.balanceOf(msg.sender) <= totalSupply();""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_totalSupply_consistant_ERC20Properties()", + description="Balance of the crytic users must be less or equal to the total supply.", + content=""" +\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) <= totalSupply();""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ANY, + ), + Property( + name="crytic_revert_transfer_to_zero_ERC20PropertiesTransferable()", + description="No one should be able to send tokens to the address 0x0 (transfer).", + content=""" \t\tif (this.balanceOf(msg.sender) == 0){ \t\t\trevert(); \t\t} -\t\treturn transfer(address(0x0), this.balanceOf(msg.sender));''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.FAIL_OR_THROW, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_revert_transferFrom_to_zero_ERC20PropertiesTransferable()', - description='No one should be able to send tokens to the address 0x0 (transferFrom).', - content=''' +\t\treturn transfer(address(0x0), this.balanceOf(msg.sender));""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.FAIL_OR_THROW, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_revert_transferFrom_to_zero_ERC20PropertiesTransferable()", + description="No one should be able to send tokens to the address 0x0 (transferFrom).", + content=""" \t\tuint balance = this.balanceOf(msg.sender); \t\tif (balance == 0){ \t\t\trevert(); \t\t} \t\tapprove(msg.sender, balance); -\t\treturn transferFrom(msg.sender, address(0x0), this.balanceOf(msg.sender));''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.FAIL_OR_THROW, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_self_transferFrom_ERC20PropertiesTransferable()', - description='Self transferFrom works.', - content=''' +\t\treturn transferFrom(msg.sender, address(0x0), this.balanceOf(msg.sender));""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.FAIL_OR_THROW, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_self_transferFrom_ERC20PropertiesTransferable()", + description="Self transferFrom works.", + content=""" \t\tuint balance = this.balanceOf(msg.sender); \t\tbool approve_return = approve(msg.sender, balance); \t\tbool transfer_return = transferFrom(msg.sender, msg.sender, balance); -\t\treturn (this.balanceOf(msg.sender) == balance) && approve_return && transfer_return;''', - type=PropertyType.HIGH_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_self_transferFrom_to_other_ERC20PropertiesTransferable()', - description='transferFrom works.', - content=''' +\t\treturn (this.balanceOf(msg.sender) == balance) && approve_return && transfer_return;""", + type=PropertyType.HIGH_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_self_transferFrom_to_other_ERC20PropertiesTransferable()", + description="transferFrom works.", + content=""" \t\tuint balance = this.balanceOf(msg.sender); \t\tbool approve_return = approve(msg.sender, balance); \t\taddress other = crytic_user; @@ -98,29 +110,30 @@ ERC20_Transferable = [ \t\t\tother = crytic_owner; \t\t} \t\tbool transfer_return = transferFrom(msg.sender, other, balance); -\t\treturn (this.balanceOf(msg.sender) == 0) && approve_return && transfer_return;''', - type=PropertyType.HIGH_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - - Property(name='crytic_self_transfer_ERC20PropertiesTransferable()', - description='Self transfer works.', - content=''' +\t\treturn (this.balanceOf(msg.sender) == 0) && approve_return && transfer_return;""", + type=PropertyType.HIGH_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_self_transfer_ERC20PropertiesTransferable()", + description="Self transfer works.", + content=""" \t\tuint balance = this.balanceOf(msg.sender); \t\tbool transfer_return = transfer(msg.sender, balance); -\t\treturn (this.balanceOf(msg.sender) == balance) && transfer_return;''', - type=PropertyType.HIGH_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_transfer_to_other_ERC20PropertiesTransferable()', - description='transfer works.', - content=''' +\t\treturn (this.balanceOf(msg.sender) == balance) && transfer_return;""", + type=PropertyType.HIGH_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_transfer_to_other_ERC20PropertiesTransferable()", + description="transfer works.", + content=""" \t\tuint balance = this.balanceOf(msg.sender); \t\taddress other = crytic_user; \t\tif (other == msg.sender) { @@ -130,74 +143,76 @@ ERC20_Transferable = [ \t\t\tbool transfer_other = transfer(other, 1); \t\t\treturn (this.balanceOf(msg.sender) == balance-1) && (this.balanceOf(other) >= 1) && transfer_other; \t\t} -\t\treturn true;''', - type=PropertyType.HIGH_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_revert_transfer_to_user_ERC20PropertiesTransferable()', - description='Cannot transfer more than the balance.', - content=''' +\t\treturn true;""", + type=PropertyType.HIGH_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_revert_transfer_to_user_ERC20PropertiesTransferable()", + description="Cannot transfer more than the balance.", + content=""" \t\tuint balance = this.balanceOf(msg.sender); \t\tif (balance == (2 ** 256 - 1)) \t\t\treturn true; \t\tbool transfer_other = transfer(crytic_user, balance+1); -\t\treturn transfer_other;''', - type=PropertyType.HIGH_SEVERITY, - return_type=PropertyReturn.FAIL_OR_THROW, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - +\t\treturn transfer_other;""", + type=PropertyType.HIGH_SEVERITY, + return_type=PropertyReturn.FAIL_OR_THROW, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), ] ERC20_Pausable = [ - - Property(name='crytic_revert_transfer_ERC20AlwaysTruePropertiesNotTransferable()', - description='Cannot transfer.', - content=''' -\t\treturn transfer(crytic_user, this.balanceOf(msg.sender));''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.FAIL_OR_THROW, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_revert_transferFrom_ERC20AlwaysTruePropertiesNotTransferable()', - description='Cannot execute transferFrom.', - content=''' + Property( + name="crytic_revert_transfer_ERC20AlwaysTruePropertiesNotTransferable()", + description="Cannot transfer.", + content=""" +\t\treturn transfer(crytic_user, this.balanceOf(msg.sender));""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.FAIL_OR_THROW, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_revert_transferFrom_ERC20AlwaysTruePropertiesNotTransferable()", + description="Cannot execute transferFrom.", + content=""" \t\tapprove(msg.sender, this.balanceOf(msg.sender)); -\t\ttransferFrom(msg.sender, msg.sender, this.balanceOf(msg.sender));''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.FAIL_OR_THROW, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_constantBalance()', - description='Cannot change the balance.', - content=''' -\t\treturn this.balanceOf(crytic_user) == initialBalance_user && this.balanceOf(crytic_attacker) == initialBalance_attacker;''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_constantAllowance()', - description='Cannot change the allowance.', - content=''' +\t\ttransferFrom(msg.sender, msg.sender, this.balanceOf(msg.sender));""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.FAIL_OR_THROW, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_constantBalance()", + description="Cannot change the balance.", + content=""" +\t\treturn this.balanceOf(crytic_user) == initialBalance_user && this.balanceOf(crytic_attacker) == initialBalance_attacker;""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_constantAllowance()", + description="Cannot change the allowance.", + content=""" \t\treturn (this.allowance(crytic_user, crytic_attacker) == initialAllowance_user_attacker) && -\t\t\t(this.allowance(crytic_attacker, crytic_attacker) == initialAllowance_attacker_attacker);''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - +\t\t\t(this.allowance(crytic_attacker, crytic_attacker) == initialAllowance_attacker_attacker);""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), ] - - diff --git a/slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py b/slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py index 97bdf167c..614447b4d 100644 --- a/slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py +++ b/slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py @@ -11,26 +11,32 @@ from slither.tools.properties.properties.properties import Property logger = logging.getLogger("Slither") -def generate_truffle_test(contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses) -> str: - test_contract = f'Test{contract.name}{type_property}' - filename_init = f'Initialization{test_contract}.js' - filename = f'{test_contract}.js' +def generate_truffle_test( + contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses +) -> str: + test_contract = f"Test{contract.name}{type_property}" + filename_init = f"Initialization{test_contract}.js" + filename = f"{test_contract}.js" output_dir = Path(contract.slither.crytic_compile.target) generate_migration(test_contract, output_dir, addresses.owner) - generate_unit_test(test_contract, - filename_init, - ERC20_CONFIG, - output_dir, - addresses, - f'Check the constructor of {test_contract}') - - generate_unit_test(test_contract, filename, unit_tests, output_dir, addresses,) - - log_info = '\n' - log_info += 'To run the unit tests:\n' + generate_unit_test( + test_contract, + filename_init, + ERC20_CONFIG, + output_dir, + addresses, + f"Check the constructor of {test_contract}", + ) + + generate_unit_test( + test_contract, filename, unit_tests, output_dir, addresses, + ) + + log_info = "\n" + log_info += "To run the unit tests:\n" log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename_init)}\n" log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename)}\n" return log_info diff --git a/slither/tools/properties/properties/properties.py b/slither/tools/properties/properties/properties.py index 3280325d8..f90cde7be 100644 --- a/slither/tools/properties/properties/properties.py +++ b/slither/tools/properties/properties/properties.py @@ -36,4 +36,4 @@ class Property(NamedTuple): def property_to_solidity(p: Property): - return f'\tfunction {p.name} public returns(bool){{{p.content}\n\t}}\n' + return f"\tfunction {p.name} public returns(bool){{{p.content}\n\t}}\n" diff --git a/slither/tools/properties/solidity/generate_properties.py b/slither/tools/properties/solidity/generate_properties.py index d02b730c7..006ab24c3 100644 --- a/slither/tools/properties/solidity/generate_properties.py +++ b/slither/tools/properties/solidity/generate_properties.py @@ -9,59 +9,64 @@ from slither.tools.properties.utils import write_file logger = logging.getLogger("Slither") -def generate_solidity_properties(contract: Contract, type_property: str, solidity_properties: str, - output_dir: Path) -> Path: +def generate_solidity_properties( + contract: Contract, type_property: str, solidity_properties: str, output_dir: Path +) -> Path: solidity_import = f'import "./interfaces.sol";\n' solidity_import += f'import "../{contract.source_mapping["filename_short"]}";' - test_contract_name = f'Properties{contract.name}{type_property}' + test_contract_name = f"Properties{contract.name}{type_property}" - solidity_content = f'{solidity_import}\ncontract {test_contract_name} is CryticInterface,{contract.name}' - solidity_content += f'{{\n\n{solidity_properties}\n}}\n' + solidity_content = ( + f"{solidity_import}\ncontract {test_contract_name} is CryticInterface,{contract.name}" + ) + solidity_content += f"{{\n\n{solidity_properties}\n}}\n" - filename = f'{test_contract_name}.sol' + filename = f"{test_contract_name}.sol" write_file(output_dir, filename, solidity_content) return Path(filename) -def generate_test_contract(contract: Contract, - type_property: str, - output_dir: Path, - property_file: Path, - initialization_recommendation: str) -> Tuple[str, str]: - test_contract_name = f'Test{contract.name}{type_property}' - properties_name = f'Properties{contract.name}{type_property}' +def generate_test_contract( + contract: Contract, + type_property: str, + output_dir: Path, + property_file: Path, + initialization_recommendation: str, +) -> Tuple[str, str]: + test_contract_name = f"Test{contract.name}{type_property}" + properties_name = f"Properties{contract.name}{type_property}" - content = '' + content = "" content += f'import "./{property_file}";\n' content += f"contract {test_contract_name} is {properties_name} {{\n" - content += '\tconstructor() public{\n' - content += '\t\t// Existing addresses:\n' - content += '\t\t// - crytic_owner: If the contract has an owner, it must be crytic_owner\n' - content += '\t\t// - crytic_user: Legitimate user\n' - content += '\t\t// - crytic_attacker: Attacker\n' - content += '\t\t// \n' + content += "\tconstructor() public{\n" + content += "\t\t// Existing addresses:\n" + content += "\t\t// - crytic_owner: If the contract has an owner, it must be crytic_owner\n" + content += "\t\t// - crytic_user: Legitimate user\n" + content += "\t\t// - crytic_attacker: Attacker\n" + content += "\t\t// \n" content += initialization_recommendation - content += '\t\t// \n' - content += '\t\t// \n' - content += '\t\t// Update the following if totalSupply and balanceOf are external functions or state variables:\n\n' - content += '\t\tinitialTotalSupply = totalSupply();\n' - content += '\t\tinitialBalance_owner = balanceOf(crytic_owner);\n' - content += '\t\tinitialBalance_user = balanceOf(crytic_user);\n' - content += '\t\tinitialBalance_attacker = balanceOf(crytic_attacker);\n' + content += "\t\t// \n" + content += "\t\t// \n" + content += "\t\t// Update the following if totalSupply and balanceOf are external functions or state variables:\n\n" + content += "\t\tinitialTotalSupply = totalSupply();\n" + content += "\t\tinitialBalance_owner = balanceOf(crytic_owner);\n" + content += "\t\tinitialBalance_user = balanceOf(crytic_user);\n" + content += "\t\tinitialBalance_attacker = balanceOf(crytic_attacker);\n" - content += '\t}\n}\n' + content += "\t}\n}\n" - filename = f'{test_contract_name}.sol' + filename = f"{test_contract_name}.sol" write_file(output_dir, filename, content, allow_overwrite=False) return filename, test_contract_name def generate_solidity_interface(output_dir: Path, addresses: Addresses): - content = f''' + content = f""" contract CryticInterface{{ address internal crytic_owner = address({addresses.owner}); address internal crytic_user = address({addresses.user}); @@ -70,7 +75,7 @@ contract CryticInterface{{ uint internal initialBalance_owner; uint internal initialBalance_user; uint internal initialBalance_attacker; -}}''' +}}""" # Static file, we discard if it exists as it should never change - write_file(output_dir, 'interfaces.sol', content, discard_if_exist=True) + write_file(output_dir, "interfaces.sol", content, discard_if_exist=True) diff --git a/slither/tools/properties/utils.py b/slither/tools/properties/utils.py index 239885319..541d85712 100644 --- a/slither/tools/properties/utils.py +++ b/slither/tools/properties/utils.py @@ -6,11 +6,13 @@ from slither.utils.colors import green, yellow logger = logging.getLogger("Slither") -def write_file(output_dir: Path, - filename: str, - content: str, - allow_overwrite: bool = True, - discard_if_exist: bool = False): +def write_file( + output_dir: Path, + filename: str, + content: str, + allow_overwrite: bool = True, + discard_if_exist: bool = False, +): """ Write the content into output_dir/filename :param output_dir: @@ -25,10 +27,10 @@ def write_file(output_dir: Path, if discard_if_exist: return if not allow_overwrite: - logger.info(yellow(f'{file_to_write} already exist and will not be overwritten')) + logger.info(yellow(f"{file_to_write} already exist and will not be overwritten")) return - logger.info(yellow(f'Overwrite {file_to_write}')) + logger.info(yellow(f"Overwrite {file_to_write}")) else: - logger.info(green(f'Write {file_to_write}')) - with open(file_to_write, 'w') as f: + logger.info(green(f"Write {file_to_write}")) + with open(file_to_write, "w") as f: f.write(content) diff --git a/slither/tools/similarity/__main__.py b/slither/tools/similarity/__main__.py index 239b68b62..85f837115 100755 --- a/slither/tools/similarity/__main__.py +++ b/slither/tools/similarity/__main__.py @@ -8,62 +8,56 @@ import operator from crytic_compile import cryticparser -from .info import info -from .test import test -from .train import train -from .plot import plot +from .info import info +from .test import test +from .train import train +from .plot import plot logging.basicConfig() logger = logging.getLogger("Slither-simil") modes = ["info", "test", "train", "plot"] + def parse_args(): - parser = argparse.ArgumentParser(description='Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector') - - parser.add_argument('mode', - help="|".join(modes)) - - parser.add_argument('model', - help='model.bin') - - parser.add_argument('--filename', - action='store', - dest='filename', - help='contract.sol') - - parser.add_argument('--fname', - action='store', - dest='fname', - help='Target function') - - parser.add_argument('--ext', - action='store', - dest='ext', - help='Extension to filter contracts') - - parser.add_argument('--nsamples', - action='store', - type=int, - dest='nsamples', - help='Number of contract samples used for training') - - parser.add_argument('--ntop', - action='store', - type=int, - dest='ntop', - default=10, - help='Number of more similar contracts to show for testing') - - parser.add_argument('--input', - action='store', - dest='input', - help='File or directory used as input') - - parser.add_argument('--version', - help='displays the current version', - version="0.0", - action='version') + parser = argparse.ArgumentParser( + description="Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector" + ) + + parser.add_argument("mode", help="|".join(modes)) + + parser.add_argument("model", help="model.bin") + + parser.add_argument("--filename", action="store", dest="filename", help="contract.sol") + + parser.add_argument("--fname", action="store", dest="fname", help="Target function") + + parser.add_argument("--ext", action="store", dest="ext", help="Extension to filter contracts") + + parser.add_argument( + "--nsamples", + action="store", + type=int, + dest="nsamples", + help="Number of contract samples used for training", + ) + + parser.add_argument( + "--ntop", + action="store", + type=int, + dest="ntop", + default=10, + help="Number of more similar contracts to show for testing", + ) + + parser.add_argument( + "--input", action="store", dest="input", help="File or directory used as input" + ) + + parser.add_argument( + "--version", help="displays the current version", version="0.0", action="version" + ) cryticparser.init(parser) @@ -74,6 +68,7 @@ def parse_args(): args = parser.parse_args() return args + # endregion ################################################################################### ################################################################################### @@ -81,27 +76,29 @@ def parse_args(): ################################################################################### ################################################################################### + def main(): args = parse_args() default_log = logging.INFO logger.setLevel(default_log) - + mode = args.mode if mode == "info": info(args) elif mode == "train": - train(args) + train(args) elif mode == "test": test(args) elif mode == "plot": plot(args) else: - logger.error('Invalid mode!. It should be one of these: %s' % ", ".join(modes)) + logger.error("Invalid mode!. It should be one of these: %s" % ", ".join(modes)) sys.exit(-1) -if __name__ == '__main__': + +if __name__ == "__main__": main() # endregion diff --git a/slither/tools/similarity/cache.py b/slither/tools/similarity/cache.py index de8896d02..81df40972 100644 --- a/slither/tools/similarity/cache.py +++ b/slither/tools/similarity/cache.py @@ -7,16 +7,18 @@ except ImportError: print("$ pip3 install numpy --user\n") sys.exit(-1) + def load_cache(infile, nsamples=None): cache = dict() with np.load(infile, allow_pickle=True) as data: - array = data['arr_0'][0] - for i,(x,y) in enumerate(array): + array = data["arr_0"][0] + for i, (x, y) in enumerate(array): cache[x] = y if i == nsamples: break return cache + def save_cache(cache, outfile): - np.savez(outfile,[np.array(cache)]) + np.savez(outfile, [np.array(cache)]) diff --git a/slither/tools/similarity/encode.py b/slither/tools/similarity/encode.py index 06b3691ed..615ed98ca 100644 --- a/slither/tools/similarity/encode.py +++ b/slither/tools/similarity/encode.py @@ -2,15 +2,46 @@ import logging import os from slither import Slither -from slither.core.declarations import Structure, Enum, SolidityVariableComposed, SolidityVariable, Function +from slither.core.declarations import ( + Structure, + Enum, + SolidityVariableComposed, + SolidityVariable, + Function, +) from slither.core.solidity_types import ElementaryType, ArrayType, MappingType, UserDefinedType from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple from slither.core.variables.state_variable import StateVariable -from slither.slithir.operations import Assignment, Index, Member, Length, Balance, Binary, \ - Unary, Condition, NewArray, NewStructure, NewContract, NewElementaryType, \ - SolidityCall, Push, Delete, EventCall, LibraryCall, InternalDynamicCall, \ - HighLevelCall, LowLevelCall, TypeConversion, Return, Transfer, Send, Unpack, InitArray, InternalCall +from slither.slithir.operations import ( + Assignment, + Index, + Member, + Length, + Balance, + Binary, + Unary, + Condition, + NewArray, + NewStructure, + NewContract, + NewElementaryType, + SolidityCall, + Push, + Delete, + EventCall, + LibraryCall, + InternalDynamicCall, + HighLevelCall, + LowLevelCall, + TypeConversion, + Return, + Transfer, + Send, + Unpack, + InitArray, + InternalCall, +) from slither.slithir.variables import TemporaryVariable, TupleVariable, Constant, ReferenceVariable from .cache import load_cache @@ -20,11 +51,12 @@ compiler_logger.setLevel(logging.CRITICAL) slither_logger = logging.getLogger("Slither") slither_logger.setLevel(logging.CRITICAL) + def parse_target(target): if target is None: return None, None - parts = target.split('.') + parts = target.split(".") if len(parts) == 1: return None, parts[0] elif len(parts) == 2: @@ -32,25 +64,27 @@ def parse_target(target): else: simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'") + def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs): r = dict() if infile.endswith(".npz"): r = load_cache(infile, nsamples=nsamples) - else: + else: contracts = load_contracts(infile, ext=ext, nsamples=nsamples) for contract in contracts: - for x,ir in encode_contract(contract, **kwargs).items(): + for x, ir in encode_contract(contract, **kwargs).items(): if ir != []: y = " ".join(ir) r[x] = vmodel.get_sentence_vector(y) return r + def load_contracts(dirname, ext=None, nsamples=None, **kwargs): r = [] walk = list(os.walk(dirname)) for x, y, files in walk: - for f in files: + for f in files: if ext is None or f.endswith(ext): r.append(x + "/".join(y) + "/" + f) @@ -60,6 +94,7 @@ def load_contracts(dirname, ext=None, nsamples=None, **kwargs): # TODO: shuffle return r[:nsamples] + def ntype(_type): if isinstance(_type, ElementaryType): _type = str(_type) @@ -79,8 +114,8 @@ def ntype(_type): else: _type = str(_type) - _type = _type.replace(" memory","") - _type = _type.replace(" storage ref","") + _type = _type.replace(" memory", "") + _type = _type.replace(" storage ref", "") if "struct" in _type: return "struct" @@ -93,92 +128,94 @@ def ntype(_type): elif "mapping" in _type: return "mapping" else: - return _type.replace(" ","_") + return _type.replace(" ", "_") + def encode_ir(ir): # operations if isinstance(ir, Assignment): - return '({}):=({})'.format(encode_ir(ir.lvalue), encode_ir(ir.rvalue)) + return "({}):=({})".format(encode_ir(ir.lvalue), encode_ir(ir.rvalue)) if isinstance(ir, Index): - return 'index({})'.format(ntype(ir._type)) + return "index({})".format(ntype(ir._type)) if isinstance(ir, Member): - return 'member' #.format(ntype(ir._type)) + return "member" # .format(ntype(ir._type)) if isinstance(ir, Length): - return 'length' + return "length" if isinstance(ir, Balance): - return 'balance' + return "balance" if isinstance(ir, Binary): - return 'binary({})'.format(ir.type_str) + return "binary({})".format(str(ir.type)) if isinstance(ir, Unary): - return 'unary({})'.format(ir.type_str) + return "unary({})".format(str(ir.type)) if isinstance(ir, Condition): - return 'condition({})'.format(encode_ir(ir.value)) + return "condition({})".format(encode_ir(ir.value)) if isinstance(ir, NewStructure): - return 'new_structure' + return "new_structure" if isinstance(ir, NewContract): - return 'new_contract' + return "new_contract" if isinstance(ir, NewArray): - return 'new_array({})'.format(ntype(ir._array_type)) + return "new_array({})".format(ntype(ir._array_type)) if isinstance(ir, NewElementaryType): - return 'new_elementary({})'.format(ntype(ir._type)) + return "new_elementary({})".format(ntype(ir._type)) if isinstance(ir, Push): - return 'push({},{})'.format(encode_ir(ir.value), encode_ir(ir.lvalue)) + return "push({},{})".format(encode_ir(ir.value), encode_ir(ir.lvalue)) if isinstance(ir, Delete): - return 'delete({},{})'.format(encode_ir(ir.lvalue), encode_ir(ir.variable)) + return "delete({},{})".format(encode_ir(ir.lvalue), encode_ir(ir.variable)) if isinstance(ir, SolidityCall): - return 'solidity_call({})'.format(ir.function.full_name) + return "solidity_call({})".format(ir.function.full_name) if isinstance(ir, InternalCall): - return 'internal_call({})'.format(ntype(ir._type_call)) - if isinstance(ir, EventCall): # is this useful? - return 'event' + return "internal_call({})".format(ntype(ir._type_call)) + if isinstance(ir, EventCall): # is this useful? + return "event" if isinstance(ir, LibraryCall): - return 'library_call' + return "library_call" if isinstance(ir, InternalDynamicCall): - return 'internal_dynamic_call' - if isinstance(ir, HighLevelCall): # TODO: improve - return 'high_level_call' - if isinstance(ir, LowLevelCall): # TODO: improve - return 'low_level_call' + return "internal_dynamic_call" + if isinstance(ir, HighLevelCall): # TODO: improve + return "high_level_call" + if isinstance(ir, LowLevelCall): # TODO: improve + return "low_level_call" if isinstance(ir, TypeConversion): - return 'type_conversion({})'.format(ntype(ir.type)) - if isinstance(ir, Return): # this can be improved using values - return 'return' #.format(ntype(ir.type)) + return "type_conversion({})".format(ntype(ir.type)) + if isinstance(ir, Return): # this can be improved using values + return "return" # .format(ntype(ir.type)) if isinstance(ir, Transfer): - return 'transfer({})'.format(encode_ir(ir.call_value)) + return "transfer({})".format(encode_ir(ir.call_value)) if isinstance(ir, Send): - return 'send({})'.format(encode_ir(ir.call_value)) - if isinstance(ir, Unpack): # TODO: improve - return 'unpack' - if isinstance(ir, InitArray): # TODO: improve - return 'init_array' - if isinstance(ir, Function): # TODO: investigate this - return 'function_solc' + return "send({})".format(encode_ir(ir.call_value)) + if isinstance(ir, Unpack): # TODO: improve + return "unpack" + if isinstance(ir, InitArray): # TODO: improve + return "init_array" + if isinstance(ir, Function): # TODO: investigate this + return "function_solc" # variables if isinstance(ir, Constant): - return 'constant({})'.format(ntype(ir._type)) + return "constant({})".format(ntype(ir._type)) if isinstance(ir, SolidityVariableComposed): - return 'solidity_variable_composed({})'.format(ir.name) + return "solidity_variable_composed({})".format(ir.name) if isinstance(ir, SolidityVariable): - return 'solidity_variable{}'.format(ir.name) + return "solidity_variable{}".format(ir.name) if isinstance(ir, TemporaryVariable): - return 'temporary_variable' + return "temporary_variable" if isinstance(ir, ReferenceVariable): - return 'reference({})'.format(ntype(ir._type)) + return "reference({})".format(ntype(ir._type)) if isinstance(ir, LocalVariable): - return 'local_solc_variable({})'.format(ir._location) + return "local_solc_variable({})".format(ir._location) if isinstance(ir, StateVariable): - return 'state_solc_variable({})'.format(ntype(ir._type)) + return "state_solc_variable({})".format(ntype(ir._type)) if isinstance(ir, LocalVariableInitFromTuple): - return 'local_variable_init_tuple' + return "local_variable_init_tuple" if isinstance(ir, TupleVariable): - return 'tuple_variable' + return "tuple_variable" # default else: - simil_logger.error(type(ir),"is missing encoding!") - return '' - + simil_logger.error(type(ir), "is missing encoding!") + return "" + + def encode_contract(cfilename, **kwargs): r = dict() @@ -186,7 +223,7 @@ def encode_contract(cfilename, **kwargs): try: slither = Slither(cfilename, **kwargs) except: - simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs['solc']) + simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs["solc"]) return r # Iterate over all the contracts @@ -198,7 +235,7 @@ def encode_contract(cfilename, **kwargs): if function.nodes == [] or function.is_constructor_variables: continue - x = (cfilename,contract.name,function.name) + x = (cfilename, contract.name, function.name) r[x] = [] @@ -210,5 +247,3 @@ def encode_contract(cfilename, **kwargs): for ir in node.irs: r[x].append(encode_ir(ir)) return r - - diff --git a/slither/tools/similarity/info.py b/slither/tools/similarity/info.py index e250aa991..b577bfd93 100644 --- a/slither/tools/similarity/info.py +++ b/slither/tools/similarity/info.py @@ -3,18 +3,19 @@ import sys import os.path import traceback -from .model import load_model +from .model import load_model from .encode import parse_target, encode_contract logging.basicConfig() logger = logging.getLogger("Slither-simil") + def info(args): try: model = args.model - if os.path.isfile(model): + if os.path.isfile(model): model = load_model(model) else: model = None @@ -22,22 +23,22 @@ def info(args): filename = args.filename contract, fname = parse_target(args.fname) solc = args.solc - + if filename is None and contract is None and fname is None: - logger.info("%s uses the following words:",args.model) + logger.info("%s uses the following words:", args.model) for word in model.get_words(): logger.info(word) sys.exit(0) if filename is None or contract is None or fname is None: - logger.error('The encode mode requires filename, contract and fname parameters.') + logger.error("The encode mode requires filename, contract and fname parameters.") sys.exit(-1) irs = encode_contract(filename, **vars(args)) if len(irs) == 0: sys.exit(-1) - - x = (filename,contract,fname) + + x = (filename, contract, fname) y = " ".join(irs[x]) logger.info("Function {} in contract {} is encoded as:".format(fname, contract)) @@ -47,8 +48,6 @@ def info(args): logger.info(fvector) except Exception: - logger.error('Error in %s' % args.filename) + logger.error("Error in %s" % args.filename) logger.error(traceback.format_exc()) sys.exit(-1) - - diff --git a/slither/tools/similarity/plot.py b/slither/tools/similarity/plot.py index 05d8bf921..75ef90c15 100644 --- a/slither/tools/similarity/plot.py +++ b/slither/tools/similarity/plot.py @@ -5,7 +5,7 @@ import operator import numpy as np import random -from .model import load_model +from .model import load_model from .encode import load_and_encode, parse_target try: @@ -17,10 +17,13 @@ except ImportError: logger = logging.getLogger("Slither-simil") + def plot(args): if decomposition is None or plt is None: - logger.error("ERROR: In order to use plot mode in slither-simil, you need to install sklearn and matplotlib:") + logger.error( + "ERROR: In order to use plot mode in slither-simil, you need to install sklearn and matplotlib:" + ) logger.error("$ pip3 install sklearn matplotlib --user") sys.exit(-1) @@ -29,50 +32,50 @@ def plot(args): model = args.model model = load_model(model) filename = args.filename - #contract = args.contract + # contract = args.contract contract, fname = parse_target(args.fname) - #solc = args.solc + # solc = args.solc infile = args.input - #ext = args.filter - #nsamples = args.nsamples + # ext = args.filter + # nsamples = args.nsamples if fname is None or infile is None: - logger.error('The plot mode requieres fname and input parameters.') + logger.error("The plot mode requieres fname and input parameters.") sys.exit(-1) - logger.info('Loading data..') + logger.info("Loading data..") cache = load_and_encode(infile, **vars(args)) data = list() fs = list() - logger.info('Procesing data..') - for (f,c,n),y in cache.items(): + logger.info("Procesing data..") + for (f, c, n), y in cache.items(): if (c == contract or contract is None) and n == fname: fs.append(f) data.append(y) if len(data) == 0: - logger.error('No contract was found with function %s', fname) + logger.error("No contract was found with function %s", fname) sys.exit(-1) data = np.array(data) pca = decomposition.PCA(n_components=2) tdata = pca.fit_transform(data) - logger.info('Plotting data..') - plt.figure(figsize=(20,10)) - assert(len(tdata) == len(fs)) - for ([x,y],l) in zip(tdata, fs): + logger.info("Plotting data..") + plt.figure(figsize=(20, 10)) + assert len(tdata) == len(fs) + for ([x, y], l) in zip(tdata, fs): x = random.gauss(0, 0.01) + x y = random.gauss(0, 0.01) + y - plt.scatter(x, y, c='blue') - plt.text(x-0.001,y+0.001, l) + plt.scatter(x, y, c="blue") + plt.text(x - 0.001, y + 0.001, l) + + logger.info("Saving figure to plot.png..") + plt.savefig("plot.png", bbox_inches="tight") - logger.info('Saving figure to plot.png..') - plt.savefig('plot.png', bbox_inches='tight') - except Exception: - logger.error('Error in %s' % args.filename) + logger.error("Error in %s" % args.filename) logger.error(traceback.format_exc()) sys.exit(-1) diff --git a/slither/tools/similarity/similarity.py b/slither/tools/similarity/similarity.py index 4cc3f2b35..3cf30acda 100644 --- a/slither/tools/similarity/similarity.py +++ b/slither/tools/similarity/similarity.py @@ -1,5 +1,6 @@ import numpy as np + def similarity(v1, v2): n1 = np.linalg.norm(v1) n2 = np.linalg.norm(v2) diff --git a/slither/tools/similarity/test.py b/slither/tools/similarity/test.py index 15a39cc13..89043a5a1 100755 --- a/slither/tools/similarity/test.py +++ b/slither/tools/similarity/test.py @@ -5,50 +5,51 @@ import traceback import operator import numpy as np -from .model import load_model -from .encode import encode_contract, load_and_encode, parse_target -from .cache import save_cache +from .model import load_model +from .encode import encode_contract, load_and_encode, parse_target +from .cache import save_cache from .similarity import similarity logger = logging.getLogger("Slither-simil") + def test(args): try: model = args.model model = load_model(model) filename = args.filename - contract, fname = parse_target(args.fname) + contract, fname = parse_target(args.fname) infile = args.input ntop = args.ntop if filename is None or contract is None or fname is None or infile is None: - logger.error('The test mode requires filename, contract, fname and input parameters.') + logger.error("The test mode requires filename, contract, fname and input parameters.") sys.exit(-1) irs = encode_contract(filename, **vars(args)) if len(irs) == 0: sys.exit(-1) - y = " ".join(irs[(filename,contract,fname)]) - + y = " ".join(irs[(filename, contract, fname)]) + fvector = model.get_sentence_vector(y) cache = load_and_encode(infile, model, **vars(args)) - #save_cache("cache.npz", cache) + # save_cache("cache.npz", cache) r = dict() - for x,y in cache.items(): + for x, y in cache.items(): r[x] = similarity(fvector, y) r = sorted(r.items(), key=operator.itemgetter(1), reverse=True) logger.info("Reviewed %d functions, listing the %d most similar ones:", len(r), ntop) format_table = "{: <65} {: <20} {: <20} {: <10}" logger.info(format_table.format(*["filename", "contract", "function", "score"])) - for x,score in r[:ntop]: + for x, score in r[:ntop]: score = str(round(score, 3)) - logger.info(format_table.format(*(list(x)+[score]))) + logger.info(format_table.format(*(list(x) + [score]))) except Exception: - logger.error('Error in %s' % args.filename) + logger.error("Error in %s" % args.filename) logger.error(traceback.format_exc()) sys.exit(-1) diff --git a/slither/tools/similarity/train.py b/slither/tools/similarity/train.py index e810450a6..3052ae6c5 100755 --- a/slither/tools/similarity/train.py +++ b/slither/tools/similarity/train.py @@ -5,12 +5,13 @@ import traceback import operator import os -from .model import train_unsupervised -from .encode import encode_contract, load_contracts -from .cache import save_cache +from .model import train_unsupervised +from .encode import encode_contract, load_contracts +from .cache import save_cache logger = logging.getLogger("Slither-simil") + def train(args): try: @@ -20,35 +21,37 @@ def train(args): nsamples = args.nsamples if dirname is None: - logger.error('The train mode requires the input parameter.') + logger.error("The train mode requires the input parameter.") sys.exit(-1) contracts = load_contracts(dirname, **vars(args)) - logger.info('Saving extracted data into %s', last_data_train_filename) + logger.info("Saving extracted data into %s", last_data_train_filename) cache = [] - with open(last_data_train_filename, 'w') as f: + with open(last_data_train_filename, "w") as f: for filename in contracts: - #cache[filename] = dict() - for (filename, contract, function), ir in encode_contract(filename, **vars(args)).items(): + # cache[filename] = dict() + for (filename, contract, function), ir in encode_contract( + filename, **vars(args) + ).items(): if ir != []: x = " ".join(ir) - f.write(x+"\n") + f.write(x + "\n") cache.append((os.path.split(filename)[-1], contract, function, x)) - logger.info('Starting training') - model = train_unsupervised(input=last_data_train_filename, model='skipgram') - logger.info('Training complete') - logger.info('Saving model') + logger.info("Starting training") + model = train_unsupervised(input=last_data_train_filename, model="skipgram") + logger.info("Training complete") + logger.info("Saving model") model.save_model(model_filename) - for i,(filename, contract, function, irs) in enumerate(cache): + for i, (filename, contract, function, irs) in enumerate(cache): cache[i] = ((filename, contract, function), model.get_sentence_vector(irs)) - logger.info('Saving cache in cache.npz') + logger.info("Saving cache in cache.npz") save_cache(cache, "cache.npz") - logger.info('Done!') - + logger.info("Done!") + except Exception: - logger.error('Error in %s' % args.filename) + logger.error("Error in %s" % args.filename) logger.error(traceback.format_exc()) sys.exit(-1) diff --git a/slither/tools/slither_format/__main__.py b/slither/tools/slither_format/__main__.py index 66787ee0e..b99485236 100644 --- a/slither/tools/slither_format/__main__.py +++ b/slither/tools/slither_format/__main__.py @@ -10,63 +10,77 @@ logging.basicConfig() logger = logging.getLogger("Slither").setLevel(logging.INFO) # Slither detectors for which slither-format currently works -available_detectors = ["unused-state", - "solc-version", - "pragma", - "naming-convention", - "external-function", - "constable-states", - "constant-function-asm", - "constatnt-function-state"] +available_detectors = [ + "unused-state", + "solc-version", + "pragma", + "naming-convention", + "external-function", + "constable-states", + "constant-function-asm", + "constatnt-function-state", +] detectors_to_run = [] + def parse_args(): """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='slither_format', - usage='slither_format filename') - - parser.add_argument('filename', help='The filename of the contract or truffle directory to analyze.') - parser.add_argument('--verbose-test', '-v', help='verbose mode output for testing',action='store_true',default=False) - parser.add_argument('--verbose-json', '-j', help='verbose json output',action='store_true',default=False) - parser.add_argument('--version', - help='displays the current version', - version='0.1.0', - action='version') - - parser.add_argument('--config-file', - help='Provide a config file (default: slither.config.json)', - action='store', - dest='config_file', - default='slither.config.json') - - - group_detector = parser.add_argument_group('Detectors') - group_detector.add_argument('--detect', - help='Comma-separated list of detectors, defaults to all, ' - 'available detectors: {}'.format( - ', '.join(d for d in available_detectors)), - action='store', - dest='detectors_to_run', - default='all') - - group_detector.add_argument('--exclude', - help='Comma-separated list of detectors to exclude,' - 'available detectors: {}'.format( - ', '.join(d for d in available_detectors)), - action='store', - dest='detectors_to_exclude', - default='all') - - cryticparser.init(parser) - - if len(sys.argv) == 1: - parser.print_help(sys.stderr) + parser = argparse.ArgumentParser(description="slither_format", usage="slither_format filename") + + parser.add_argument( + "filename", help="The filename of the contract or truffle directory to analyze." + ) + parser.add_argument( + "--verbose-test", + "-v", + help="verbose mode output for testing", + action="store_true", + default=False, + ) + parser.add_argument( + "--verbose-json", "-j", help="verbose json output", action="store_true", default=False + ) + parser.add_argument( + "--version", help="displays the current version", version="0.1.0", action="version" + ) + + parser.add_argument( + "--config-file", + help="Provide a config file (default: slither.config.json)", + action="store", + dest="config_file", + default="slither.config.json", + ) + + group_detector = parser.add_argument_group("Detectors") + group_detector.add_argument( + "--detect", + help="Comma-separated list of detectors, defaults to all, " + "available detectors: {}".format(", ".join(d for d in available_detectors)), + action="store", + dest="detectors_to_run", + default="all", + ) + + group_detector.add_argument( + "--exclude", + help="Comma-separated list of detectors to exclude," + "available detectors: {}".format(", ".join(d for d in available_detectors)), + action="store", + dest="detectors_to_exclude", + default="all", + ) + + cryticparser.init(parser) + + if len(sys.argv) == 1: + parser.print_help(sys.stderr) sys.exit(1) - + return parser.parse_args() @@ -80,11 +94,12 @@ def main(): read_config_file(args) - # Perform slither analysis on the given filename slither = Slither(args.filename, **vars(args)) # Format the input files based on slither analysis slither_format(slither, **vars(args)) -if __name__ == '__main__': + + +if __name__ == "__main__": main() diff --git a/slither/tools/slither_format/slither_format.py b/slither/tools/slither_format/slither_format.py index 597c17753..659b69557 100644 --- a/slither/tools/slither_format/slither_format.py +++ b/slither/tools/slither_format/slither_format.py @@ -11,27 +11,29 @@ from slither.detectors.attributes.const_functions_state import ConstantFunctions from slither.utils.colors import yellow logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('Slither.Format') +logger = logging.getLogger("Slither.Format") all_detectors = { - 'unused-state': UnusedStateVars, - 'solc-version': IncorrectSolc, - 'pragma': ConstantPragma, - 'naming-convention': NamingConvention, - 'external-function': ExternalFunction, - 'constable-states' : ConstCandidateStateVars, - 'constant-function-asm': ConstantFunctionsAsm, - 'constant-functions-state': ConstantFunctionsState + "unused-state": UnusedStateVars, + "solc-version": IncorrectSolc, + "pragma": ConstantPragma, + "naming-convention": NamingConvention, + "external-function": ExternalFunction, + "constable-states": ConstCandidateStateVars, + "constant-function-asm": ConstantFunctionsAsm, + "constant-functions-state": ConstantFunctionsState, } + def slither_format(slither, **kwargs): - '''' + """' Keyword Args: detectors_to_run (str): Comma-separated list of detectors, defaults to all - ''' + """ - detectors_to_run = choose_detectors(kwargs.get('detectors_to_run', 'all'), - kwargs.get('detectors_to_exclude', '')) + detectors_to_run = choose_detectors( + kwargs.get("detectors_to_run", "all"), kwargs.get("detectors_to_exclude", "") + ) for detector in detectors_to_run: slither.register_detector(detector) @@ -42,32 +44,32 @@ def slither_format(slither, **kwargs): detector_results = [x for x in detector_results if x] # remove empty results detector_results = [item for sublist in detector_results for item in sublist] # flatten - export = Path('crytic-export', 'patches') + export = Path("crytic-export", "patches") export.mkdir(parents=True, exist_ok=True) counter_result = 0 - logger.info(yellow('slither-format is in beta, carefully review each patch before merging it.')) + logger.info(yellow("slither-format is in beta, carefully review each patch before merging it.")) for result in detector_results: - if not 'patches' in result: + if not "patches" in result: continue one_line_description = result["description"].split("\n")[0] - export_result = Path(export, f'{counter_result}') + export_result = Path(export, f"{counter_result}") export_result.mkdir(parents=True, exist_ok=True) counter_result += 1 counter = 0 - logger.info(f'Issue: {one_line_description}') - logger.info(f'Generated: ({export_result})') + logger.info(f"Issue: {one_line_description}") + logger.info(f"Generated: ({export_result})") - for file, diff, in result['patches_diff'].items(): - filename = f'fix_{counter}.patch' + for file, diff, in result["patches_diff"].items(): + filename = f"fix_{counter}.patch" path = Path(export_result, filename) - logger.info(f'\t- {filename}') - with open(path, 'w') as f: + logger.info(f"\t- {filename}") + with open(path, "w") as f: f.write(diff) counter += 1 @@ -79,26 +81,28 @@ def slither_format(slither, **kwargs): ################################################################################### ################################################################################### + def choose_detectors(detectors_to_run, detectors_to_exclude): # If detectors are specified, run only these ones cls_detectors_to_run = [] - exclude = detectors_to_exclude.split(',') - if detectors_to_run == 'all': + exclude = detectors_to_exclude.split(",") + if detectors_to_run == "all": for d in all_detectors: if d in exclude: continue cls_detectors_to_run.append(all_detectors[d]) else: - exclude = detectors_to_exclude.split(',') - for d in detectors_to_run.split(','): + exclude = detectors_to_exclude.split(",") + for d in detectors_to_run.split(","): if d in all_detectors: if d in exclude: continue cls_detectors_to_run.append(all_detectors[d]) else: - raise Exception('Error: {} is not a detector'.format(d)) + raise Exception("Error: {} is not a detector".format(d)) return cls_detectors_to_run + # endregion ################################################################################### ################################################################################### @@ -106,6 +110,7 @@ def choose_detectors(detectors_to_run, detectors_to_exclude): ################################################################################### ################################################################################### + def print_patches(number_of_slither_results, patches): logger.info("Number of Slither results: " + str(number_of_slither_results)) number_of_patches = 0 @@ -115,39 +120,38 @@ def print_patches(number_of_slither_results, patches): for file in patches: logger.info("Patch file: " + file) for patch in patches[file]: - logger.info("Detector: " + patch['detector']) - logger.info("Old string: " + patch['old_string'].replace("\n","")) - logger.info("New string: " + patch['new_string'].replace("\n","")) - logger.info("Location start: " + str(patch['start'])) - logger.info("Location end: " + str(patch['end'])) + logger.info("Detector: " + patch["detector"]) + logger.info("Old string: " + patch["old_string"].replace("\n", "")) + logger.info("New string: " + patch["new_string"].replace("\n", "")) + logger.info("Location start: " + str(patch["start"])) + logger.info("Location end: " + str(patch["end"])) + def print_patches_json(number_of_slither_results, patches): - print('{',end='') - print("\"Number of Slither results\":" + '"' + str(number_of_slither_results) + '",') - print("\"Number of patchlets\":" + "\"" + str(len(patches)) + "\"", ',') - print("\"Patchlets\":" + '[') + print("{", end="") + print('"Number of Slither results":' + '"' + str(number_of_slither_results) + '",') + print('"Number of patchlets":' + '"' + str(len(patches)) + '"', ",") + print('"Patchlets":' + "[") for index, file in enumerate(patches): if index > 0: - print(',') - print('{',end='') - print("\"Patch file\":" + '"' + file + '",') - print("\"Number of patches\":" + "\"" + str(len(patches[file])) + "\"", ',') - print("\"Patches\":" + '[') + print(",") + print("{", end="") + print('"Patch file":' + '"' + file + '",') + print('"Number of patches":' + '"' + str(len(patches[file])) + '"', ",") + print('"Patches":' + "[") for index, patch in enumerate(patches[file]): if index > 0: - print(',') - print('{',end='') - print("\"Detector\":" + '"' + patch['detector'] + '",') - print("\"Old string\":" + '"' + patch['old_string'].replace("\n","") + '",') - print("\"New string\":" + '"' + patch['new_string'].replace("\n","") + '",') - print("\"Location start\":" + '"' + str(patch['start']) + '",') - print("\"Location end\":" + '"' + str(patch['end']) + '"') - if 'overlaps' in patch: - print("\"Overlaps\":" + "Yes") - print('}',end='') - print(']',end='') - print('}',end='') - print(']',end='') - print('}') - - + print(",") + print("{", end="") + print('"Detector":' + '"' + patch["detector"] + '",') + print('"Old string":' + '"' + patch["old_string"].replace("\n", "") + '",') + print('"New string":' + '"' + patch["new_string"].replace("\n", "") + '",') + print('"Location start":' + '"' + str(patch["start"]) + '",') + print('"Location end":' + '"' + str(patch["end"]) + '"') + if "overlaps" in patch: + print('"Overlaps":' + "Yes") + print("}", end="") + print("]", end="") + print("}", end="") + print("]", end="") + print("}") diff --git a/slither/tools/upgradeability/__main__.py b/slither/tools/upgradeability/__main__.py index 52b3be060..8903c01c5 100644 --- a/slither/tools/upgradeability/__main__.py +++ b/slither/tools/upgradeability/__main__.py @@ -12,7 +12,12 @@ from slither.utils.colors import red from slither.utils.output import output_to_json from .checks import all_checks from .checks.abstract_checks import AbstractCheck -from .utils.command_line import output_detectors_json, output_wiki, output_detectors, output_to_markdown +from .utils.command_line import ( + output_detectors_json, + output_wiki, + output_detectors, + output_to_markdown, +) logging.basicConfig() logger = logging.getLogger("Slither") @@ -21,49 +26,53 @@ logger.setLevel(logging.INFO) def parse_args(): parser = argparse.ArgumentParser( - description='Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.', - usage="slither-check-upgradeability contract.sol ContractName") - - parser.add_argument('contract.sol', help='Codebase to analyze') - parser.add_argument('ContractName', help='Contract name (logic contract)') - - parser.add_argument('--proxy-name', help='Proxy name') - parser.add_argument('--proxy-filename', help='Proxy filename (if different)') - - parser.add_argument('--new-contract-name', help='New contract name (if changed)') - parser.add_argument('--new-contract-filename', help='New implementation filename (if different)') - - parser.add_argument('--json', - help='Export the results as a JSON file ("--json -" to export to stdout)', - action='store', - default=False) - - parser.add_argument('--list-detectors', - help='List available detectors', - action=ListDetectors, - nargs=0, - default=False) - - parser.add_argument('--markdown-root', - help='URL for markdown generation', - action='store', - default="") - - parser.add_argument('--wiki-detectors', - help=argparse.SUPPRESS, - action=OutputWiki, - default=False) - - parser.add_argument('--list-detectors-json', - help=argparse.SUPPRESS, - action=ListDetectorsJson, - nargs=0, - default=False) - - parser.add_argument('--markdown', - help=argparse.SUPPRESS, - action=OutputMarkdown, - default=False) + description="Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.", + usage="slither-check-upgradeability contract.sol ContractName", + ) + + parser.add_argument("contract.sol", help="Codebase to analyze") + parser.add_argument("ContractName", help="Contract name (logic contract)") + + parser.add_argument("--proxy-name", help="Proxy name") + parser.add_argument("--proxy-filename", help="Proxy filename (if different)") + + parser.add_argument("--new-contract-name", help="New contract name (if changed)") + parser.add_argument( + "--new-contract-filename", help="New implementation filename (if different)" + ) + + parser.add_argument( + "--json", + help='Export the results as a JSON file ("--json -" to export to stdout)', + action="store", + default=False, + ) + + parser.add_argument( + "--list-detectors", + help="List available detectors", + action=ListDetectors, + nargs=0, + default=False, + ) + + parser.add_argument( + "--markdown-root", help="URL for markdown generation", action="store", default="" + ) + + parser.add_argument( + "--wiki-detectors", help=argparse.SUPPRESS, action=OutputWiki, default=False + ) + + parser.add_argument( + "--list-detectors-json", + help=argparse.SUPPRESS, + action=ListDetectorsJson, + nargs=0, + default=False, + ) + + parser.add_argument("--markdown", help=argparse.SUPPRESS, action=OutputMarkdown, default=False) cryticparser.init(parser) @@ -80,6 +89,7 @@ def parse_args(): ################################################################################### ################################################################################### + def _get_checks(): detectors = [getattr(all_checks, name) for name in dir(all_checks)] detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractCheck)] @@ -123,13 +133,18 @@ def _run_checks(detectors): def _checks_on_contract(detectors, contract): - detectors = [d(logger, contract) for d in detectors if (not d.REQUIRE_PROXY and - not d.REQUIRE_CONTRACT_V2)] + detectors = [ + d(logger, contract) + for d in detectors + if (not d.REQUIRE_PROXY and not d.REQUIRE_CONTRACT_V2) + ] return _run_checks(detectors), len(detectors) def _checks_on_contract_update(detectors, contract_v1, contract_v2): - detectors = [d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2] + detectors = [ + d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2 + ] return _run_checks(detectors), len(detectors) @@ -147,15 +162,11 @@ def _checks_on_contract_and_proxy(detectors, contract, proxy): def main(): - json_results = { - 'proxy-present': False, - 'contract_v2-present': False, - 'detectors': [] - } + json_results = {"proxy-present": False, "contract_v2-present": False, "detectors": []} args = parse_args() - v1_filename = vars(args)['contract.sol'] + v1_filename = vars(args)["contract.sol"] number_detectors_run = 0 detectors = _get_checks() try: @@ -165,14 +176,14 @@ def main(): v1_name = args.ContractName v1_contract = v1.get_contract_from_name(v1_name) if v1_contract is None: - info = 'Contract {} not found in {}'.format(v1_name, v1.filename) + info = "Contract {} not found in {}".format(v1_name, v1.filename) logger.error(red(info)) if args.json: output_to_json(args.json, str(info), json_results) return detectors_results, number_detectors = _checks_on_contract(detectors, v1_contract) - json_results['detectors'] += detectors_results + json_results["detectors"] += detectors_results number_detectors_run += number_detectors # Analyze Proxy @@ -185,15 +196,17 @@ def main(): proxy_contract = proxy.get_contract_from_name(args.proxy_name) if proxy_contract is None: - info = 'Proxy {} not found in {}'.format(args.proxy_name, proxy.filename) + info = "Proxy {} not found in {}".format(args.proxy_name, proxy.filename) logger.error(red(info)) if args.json: output_to_json(args.json, str(info), json_results) return - json_results['proxy-present'] = True + json_results["proxy-present"] = True - detectors_results, number_detectors = _checks_on_contract_and_proxy(detectors, v1_contract, proxy_contract) - json_results['detectors'] += detectors_results + detectors_results, number_detectors = _checks_on_contract_and_proxy( + detectors, v1_contract, proxy_contract + ) + json_results["detectors"] += detectors_results number_detectors_run += number_detectors # Analyze new version if args.new_contract_name: @@ -204,30 +217,36 @@ def main(): v2_contract = v2.get_contract_from_name(args.new_contract_name) if v2_contract is None: - info = 'New logic contract {} not found in {}'.format(args.new_contract_name, v2.filename) + info = "New logic contract {} not found in {}".format( + args.new_contract_name, v2.filename + ) logger.error(red(info)) if args.json: output_to_json(args.json, str(info), json_results) return - json_results['contract_v2-present'] = True + json_results["contract_v2-present"] = True if proxy_contract: - detectors_results, _ = _checks_on_contract_and_proxy(detectors, - v2_contract, - proxy_contract) + detectors_results, _ = _checks_on_contract_and_proxy( + detectors, v2_contract, proxy_contract + ) - json_results['detectors'] += detectors_results + json_results["detectors"] += detectors_results - detectors_results, number_detectors = _checks_on_contract_update(detectors, v1_contract, v2_contract) - json_results['detectors'] += detectors_results + detectors_results, number_detectors = _checks_on_contract_update( + detectors, v1_contract, v2_contract + ) + json_results["detectors"] += detectors_results number_detectors_run += number_detectors # If there is a V2, we run the contract-only check on the V2 detectors_results, _ = _checks_on_contract(detectors, v2_contract) - json_results['detectors'] += detectors_results + json_results["detectors"] += detectors_results number_detectors_run += number_detectors - logger.info(f'{len(json_results["detectors"])} findings, {number_detectors_run} detectors run') + logger.info( + f'{len(json_results["detectors"])} findings, {number_detectors_run} detectors run' + ) if args.json: output_to_json(args.json, None, json_results) @@ -237,4 +256,5 @@ def main(): output_to_json(args.json, str(e), json_results) return + # endregion diff --git a/slither/tools/upgradeability/checks/abstract_checks.py b/slither/tools/upgradeability/checks/abstract_checks.py index 94d55e107..05a0d1182 100644 --- a/slither/tools/upgradeability/checks/abstract_checks.py +++ b/slither/tools/upgradeability/checks/abstract_checks.py @@ -19,28 +19,28 @@ classification_colors = { CheckClassification.INFORMATIONAL: green, CheckClassification.LOW: yellow, CheckClassification.MEDIUM: yellow, - CheckClassification.HIGH: red + CheckClassification.HIGH: red, } classification_txt = { - CheckClassification.INFORMATIONAL: 'Informational', - CheckClassification.LOW: 'Low', - CheckClassification.MEDIUM: 'Medium', - CheckClassification.HIGH: 'High', + CheckClassification.INFORMATIONAL: "Informational", + CheckClassification.LOW: "Low", + CheckClassification.MEDIUM: "Medium", + CheckClassification.HIGH: "High", } class AbstractCheck(metaclass=abc.ABCMeta): - ARGUMENT = '' - HELP = '' + ARGUMENT = "" + HELP = "" IMPACT = None - WIKI = '' + WIKI = "" - WIKI_TITLE = '' - WIKI_DESCRIPTION = '' - WIKI_EXPLOIT_SCENARIO = '' - WIKI_RECOMMENDATION = '' + WIKI_TITLE = "" + WIKI_DESCRIPTION = "" + WIKI_EXPLOIT_SCENARIO = "" + WIKI_RECOMMENDATION = "" REQUIRE_CONTRACT = False REQUIRE_PROXY = False @@ -53,43 +53,69 @@ class AbstractCheck(metaclass=abc.ABCMeta): self.contract_v2 = contract_v2 if not self.ARGUMENT: - raise IncorrectCheckInitialization('NAME is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "NAME is not initialized {}".format(self.__class__.__name__) + ) if not self.HELP: - raise IncorrectCheckInitialization('HELP is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "HELP is not initialized {}".format(self.__class__.__name__) + ) if not self.WIKI: - raise IncorrectCheckInitialization('WIKI is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "WIKI is not initialized {}".format(self.__class__.__name__) + ) if not self.WIKI_TITLE: - raise IncorrectCheckInitialization('WIKI_TITLE is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "WIKI_TITLE is not initialized {}".format(self.__class__.__name__) + ) if not self.WIKI_DESCRIPTION: - raise IncorrectCheckInitialization('WIKI_DESCRIPTION is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "WIKI_DESCRIPTION is not initialized {}".format(self.__class__.__name__) + ) - if not self.WIKI_EXPLOIT_SCENARIO and self.IMPACT not in [CheckClassification.INFORMATIONAL]: - raise IncorrectCheckInitialization('WIKI_EXPLOIT_SCENARIO is not initialized {}'.format(self.__class__.__name__)) + if not self.WIKI_EXPLOIT_SCENARIO and self.IMPACT not in [ + CheckClassification.INFORMATIONAL + ]: + raise IncorrectCheckInitialization( + "WIKI_EXPLOIT_SCENARIO is not initialized {}".format(self.__class__.__name__) + ) if not self.WIKI_RECOMMENDATION: - raise IncorrectCheckInitialization('WIKI_RECOMMENDATION is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "WIKI_RECOMMENDATION is not initialized {}".format(self.__class__.__name__) + ) if self.REQUIRE_PROXY and self.REQUIRE_CONTRACT_V2: # This is not a fundatemenal issues # But it requires to change __main__ to avoid running two times the detectors - txt = 'REQUIRE_PROXY and REQUIRE_CONTRACT_V2 needs change in __main___ {}'.format(self.__class__.__name__) + txt = "REQUIRE_PROXY and REQUIRE_CONTRACT_V2 needs change in __main___ {}".format( + self.__class__.__name__ + ) raise IncorrectCheckInitialization(txt) - if self.IMPACT not in [CheckClassification.LOW, - CheckClassification.MEDIUM, - CheckClassification.HIGH, - CheckClassification.INFORMATIONAL]: - raise IncorrectCheckInitialization('IMPACT is not initialized {}'.format(self.__class__.__name__)) + if self.IMPACT not in [ + CheckClassification.LOW, + CheckClassification.MEDIUM, + CheckClassification.HIGH, + CheckClassification.INFORMATIONAL, + ]: + raise IncorrectCheckInitialization( + "IMPACT is not initialized {}".format(self.__class__.__name__) + ) if self.REQUIRE_CONTRACT_V2 and contract_v2 is None: - raise IncorrectCheckInitialization('ContractV2 is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "ContractV2 is not initialized {}".format(self.__class__.__name__) + ) if self.REQUIRE_PROXY and proxy is None: - raise IncorrectCheckInitialization('Proxy is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "Proxy is not initialized {}".format(self.__class__.__name__) + ) @abc.abstractmethod def _check(self): @@ -102,19 +128,17 @@ class AbstractCheck(metaclass=abc.ABCMeta): all_results = [r.data for r in all_results] if all_results: if self.logger: - info = '\n' + info = "\n" for idx, result in enumerate(all_results): - info += result['description'] - info += 'Reference: {}'.format(self.WIKI) + info += result["description"] + info += "Reference: {}".format(self.WIKI) self._log(info) return all_results def generate_result(self, info, additional_fields=None): - output = Output(info, - additional_fields, - markdown_root=self.contract.slither.markdown_root) + output = Output(info, additional_fields, markdown_root=self.contract.slither.markdown_root) - output.data['check'] = self.ARGUMENT + output.data["check"] = self.ARGUMENT return output diff --git a/slither/tools/upgradeability/checks/all_checks.py b/slither/tools/upgradeability/checks/all_checks.py index 1c41316c6..fcedb69f8 100644 --- a/slither/tools/upgradeability/checks/all_checks.py +++ b/slither/tools/upgradeability/checks/all_checks.py @@ -1,11 +1,23 @@ -from .initialization import (InitializablePresent, InitializableInherited, - InitializableInitializer, MissingInitializerModifier, MissingCalls, MultipleCalls, InitializeTarget) +from .initialization import ( + InitializablePresent, + InitializableInherited, + InitializableInitializer, + MissingInitializerModifier, + MissingCalls, + MultipleCalls, + InitializeTarget, +) from .functions_ids import IDCollision, FunctionShadowing from .variable_initialization import VariableWithInit -from .variables_order import (MissingVariable, DifferentVariableContractProxy, - DifferentVariableContractNewContract, ExtraVariablesProxy, ExtraVariablesNewContract) +from .variables_order import ( + MissingVariable, + DifferentVariableContractProxy, + DifferentVariableContractNewContract, + ExtraVariablesProxy, + ExtraVariablesNewContract, +) -from .constant import WereConstant, BecameConstant \ No newline at end of file +from .constant import WereConstant, BecameConstant diff --git a/slither/tools/upgradeability/checks/constant.py b/slither/tools/upgradeability/checks/constant.py index 60f37f8d3..e1d547e28 100644 --- a/slither/tools/upgradeability/checks/constant.py +++ b/slither/tools/upgradeability/checks/constant.py @@ -2,17 +2,17 @@ from slither.tools.upgradeability.checks.abstract_checks import AbstractCheck, C class WereConstant(AbstractCheck): - ARGUMENT = 'were-constant' + ARGUMENT = "were-constant" IMPACT = CheckClassification.HIGH - HELP = 'Variables that should be constant' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-be-constant' - WIKI_TITLE = 'Variables that should be constant' - WIKI_DESCRIPTION = ''' + HELP = "Variables that should be constant" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-be-constant" + WIKI_TITLE = "Variables that should be constant" + WIKI_DESCRIPTION = """ Detect state variables that should be `constant̀`. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ uint variable1; @@ -28,11 +28,11 @@ contract ContractV2{ ``` Because `variable2` is not anymore a `constant`, the storage location of `variable3` will be different. As a result, `ContractV2` will have a corrupted storage layout. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Do not remove `constant` from a state variables during an update. -''' +""" REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True @@ -66,12 +66,13 @@ Do not remove `constant` from a state variables during an update. if state_v1.is_constant: if not state_v2.is_constant: # If v2 has additional non constant variables, we need to skip them - if ((state_v1.name != state_v2.name or state_v1.type != state_v2.type) - and v2_additional_variables > 0): + if ( + state_v1.name != state_v2.name or state_v1.type != state_v2.type + ) and v2_additional_variables > 0: v2_additional_variables -= 1 idx_v2 += 1 continue - info = [state_v1, ' was constant, but ', state_v2, 'is not.\n'] + info = [state_v1, " was constant, but ", state_v2, "is not.\n"] json = self.generate_result(info) results.append(json) @@ -80,19 +81,20 @@ Do not remove `constant` from a state variables during an update. return results + class BecameConstant(AbstractCheck): - ARGUMENT = 'became-constant' + ARGUMENT = "became-constant" IMPACT = CheckClassification.HIGH - HELP = 'Variables that should not be constant' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-not-be-constant' - WIKI_TITLE = 'Variables that should not be constant' + HELP = "Variables that should not be constant" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-not-be-constant" + WIKI_TITLE = "Variables that should not be constant" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect state variables that should not be `constant̀`. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ uint variable1; @@ -108,11 +110,11 @@ contract ContractV2{ ``` Because `variable2` is now a `constant`, the storage location of `variable3` will be different. As a result, `ContractV2` will have a corrupted storage layout. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Do not make an existing state variable `constant`. -''' +""" REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True @@ -146,13 +148,14 @@ Do not make an existing state variable `constant`. if state_v1.is_constant: if not state_v2.is_constant: # If v2 has additional non constant variables, we need to skip them - if ((state_v1.name != state_v2.name or state_v1.type != state_v2.type) - and v2_additional_variables > 0): + if ( + state_v1.name != state_v2.name or state_v1.type != state_v2.type + ) and v2_additional_variables > 0: v2_additional_variables -= 1 idx_v2 += 1 continue elif state_v2.is_constant: - info = [state_v1, ' was not constant but ', state_v2, ' is.\n'] + info = [state_v1, " was not constant but ", state_v2, " is.\n"] json = self.generate_result(info) results.append(json) diff --git a/slither/tools/upgradeability/checks/functions_ids.py b/slither/tools/upgradeability/checks/functions_ids.py index ecacbb798..cbc822d20 100644 --- a/slither/tools/upgradeability/checks/functions_ids.py +++ b/slither/tools/upgradeability/checks/functions_ids.py @@ -5,11 +5,16 @@ from slither.utils.function import get_function_id def get_signatures(c): functions = c.functions - functions = [f.full_name for f in functions if f.visibility in ['public', 'external'] and - not f.is_constructor and not f.is_fallback] + functions = [ + f.full_name + for f in functions + if f.visibility in ["public", "external"] and not f.is_constructor and not f.is_fallback + ] variables = c.state_variables - variables = [variable.name + '()' for variable in variables if variable.visibility in ['public']] + variables = [ + variable.name + "()" for variable in variables if variable.visibility in ["public"] + ] return list(set(functions + variables)) @@ -21,26 +26,26 @@ def _get_function_or_variable(contract, signature): for variable in contract.state_variables: # Todo: can lead to incorrect variable in case of shadowing - if variable.visibility in ['public']: - if variable.name + '()' == signature: + if variable.visibility in ["public"]: + if variable.name + "()" == signature: return variable - raise SlitherError(f'Function id checks: {signature} not found in {contract.name}') + raise SlitherError(f"Function id checks: {signature} not found in {contract.name}") class IDCollision(AbstractCheck): - ARGUMENT = 'function-id-collision' + ARGUMENT = "function-id-collision" IMPACT = CheckClassification.HIGH - HELP = 'Functions ids collision' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-ids-collisions' - WIKI_TITLE = 'Functions ids collisions' + HELP = "Functions ids collision" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-ids-collisions" + WIKI_TITLE = "Functions ids collisions" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect function id collision between the contract and the proxy. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ function gsf() public { @@ -56,11 +61,11 @@ contract Proxy{ ``` `Proxy.tgeo()` and `Contract.gsf()` have the same function id (0x67e43e43). As a result, `Proxy.tgeo()` will shadow Contract.gsf()`. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Rename the function. Avoid public functions in the proxy. -''' +""" REQUIRE_CONTRACT = True REQUIRE_PROXY = True @@ -77,11 +82,18 @@ Rename the function. Avoid public functions in the proxy. for (k, _) in signatures_ids_implem.items(): if k in signatures_ids_proxy: if signatures_ids_implem[k] != signatures_ids_proxy[k]: - implem_function = _get_function_or_variable(self.contract, signatures_ids_implem[k]) + implem_function = _get_function_or_variable( + self.contract, signatures_ids_implem[k] + ) proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k]) - info = ['Function id collision found: ', implem_function, - ' ', proxy_function, '\n'] + info = [ + "Function id collision found: ", + implem_function, + " ", + proxy_function, + "\n", + ] json = self.generate_result(info) results.append(json) @@ -89,18 +101,18 @@ Rename the function. Avoid public functions in the proxy. class FunctionShadowing(AbstractCheck): - ARGUMENT = 'function-shadowing' + ARGUMENT = "function-shadowing" IMPACT = CheckClassification.HIGH - HELP = 'Functions shadowing' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-shadowing' - WIKI_TITLE = 'Functions shadowing' + HELP = "Functions shadowing" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-shadowing" + WIKI_TITLE = "Functions shadowing" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect function shadowing between the contract and the proxy. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ function get() public { @@ -115,11 +127,11 @@ contract Proxy{ } ``` `Proxy.get` will shadow any call to `get()`. As a result `get()` is never executed in the logic contract and cannot be updated. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Rename the function. Avoid public functions in the proxy. -''' +""" REQUIRE_CONTRACT = True REQUIRE_PROXY = True @@ -136,11 +148,18 @@ Rename the function. Avoid public functions in the proxy. for (k, _) in signatures_ids_implem.items(): if k in signatures_ids_proxy: if signatures_ids_implem[k] == signatures_ids_proxy[k]: - implem_function = _get_function_or_variable(self.contract, signatures_ids_implem[k]) + implem_function = _get_function_or_variable( + self.contract, signatures_ids_implem[k] + ) proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k]) - info = ['Function shadowing found: ', implem_function, - ' ', proxy_function, '\n'] + info = [ + "Function shadowing found: ", + implem_function, + " ", + proxy_function, + "\n", + ] json = self.generate_result(info) results.append(json) diff --git a/slither/tools/upgradeability/checks/initialization.py b/slither/tools/upgradeability/checks/initialization.py index 3809f294b..2e37457dc 100644 --- a/slither/tools/upgradeability/checks/initialization.py +++ b/slither/tools/upgradeability/checks/initialization.py @@ -13,16 +13,20 @@ class MultipleInitTarget(Exception): def _get_initialize_functions(contract): - return [f for f in contract.functions if f.name == 'initialize' and f.is_implemented] + return [f for f in contract.functions if f.name == "initialize" and f.is_implemented] def _get_all_internal_calls(function): all_ir = function.all_slithir_operations() - return [i.function for i in all_ir if isinstance(i, InternalCall) and i.function_name == "initialize"] + return [ + i.function + for i in all_ir + if isinstance(i, InternalCall) and i.function_name == "initialize" + ] def _get_most_derived_init(contract): - init_functions = [f for f in contract.functions if not f.is_shadowed and f.name == 'initialize'] + init_functions = [f for f in contract.functions if not f.is_shadowed and f.name == "initialize"] if len(init_functions) > 1: if len([f for f in init_functions if f.contract_declarer == contract]) == 1: return next((f for f in init_functions if f.contract_declarer == contract)) @@ -33,80 +37,82 @@ def _get_most_derived_init(contract): class InitializablePresent(AbstractCheck): - ARGUMENT = 'init-missing' + ARGUMENT = "init-missing" IMPACT = CheckClassification.INFORMATIONAL - HELP = 'Initializable is missing' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing' - WIKI_TITLE = 'Initializable is missing' - WIKI_DESCRIPTION = ''' + HELP = "Initializable is missing" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing" + WIKI_TITLE = "Initializable is missing" + WIKI_DESCRIPTION = """ Detect if a contract `Initializable` is present. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Review manually the contract's initialization.. Consider using a `Initializable` contract to follow [standard practice](https://docs.openzeppelin.com/upgrades/2.7/writing-upgradeable). -''' +""" def _check(self): - initializable = self.contract.slither.get_contract_from_name('Initializable') + initializable = self.contract.slither.get_contract_from_name("Initializable") if initializable is None: - info = ["Initializable contract not found, the contract does not follow a standard initalization schema.\n"] + info = [ + "Initializable contract not found, the contract does not follow a standard initalization schema.\n" + ] json = self.generate_result(info) return [json] return [] class InitializableInherited(AbstractCheck): - ARGUMENT = 'init-inherited' + ARGUMENT = "init-inherited" IMPACT = CheckClassification.INFORMATIONAL - HELP = 'Initializable is not inherited' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-not-inherited' - WIKI_TITLE = 'Initializable is not inherited' + HELP = "Initializable is not inherited" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-not-inherited" + WIKI_TITLE = "Initializable is not inherited" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect if `Initializable` is inherited. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Review manually the contract's initialization. Consider inheriting `Initializable`. -''' +""" REQUIRE_CONTRACT = True def _check(self): - initializable = self.contract.slither.get_contract_from_name('Initializable') + initializable = self.contract.slither.get_contract_from_name("Initializable") # See InitializablePresent if initializable is None: return [] if initializable not in self.contract.inheritance: - info = [self.contract, ' does not inherit from ', initializable, '.\n'] + info = [self.contract, " does not inherit from ", initializable, ".\n"] json = self.generate_result(info) return [json] return [] class InitializableInitializer(AbstractCheck): - ARGUMENT = 'initializer-missing' + ARGUMENT = "initializer-missing" IMPACT = CheckClassification.INFORMATIONAL - HELP = 'initializer() is missing' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-missing' - WIKI_TITLE = 'initializer() is missing' + HELP = "initializer() is missing" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-missing" + WIKI_TITLE = "initializer() is missing" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect the lack of `Initializable.initializer()` modifier. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Review manually the contract's initialization. Consider inheriting a `Initializable.initializer()` modifier. -''' +""" REQUIRE_CONTRACT = True def _check(self): - initializable = self.contract.slither.get_contract_from_name('Initializable') + initializable = self.contract.slither.get_contract_from_name("Initializable") # See InitializablePresent if initializable is None: return [] @@ -114,26 +120,26 @@ Review manually the contract's initialization. Consider inheriting a `Initializa if initializable not in self.contract.inheritance: return [] - initializer = self.contract.get_modifier_from_canonical_name('Initializable.initializer()') + initializer = self.contract.get_modifier_from_canonical_name("Initializable.initializer()") if initializer is None: - info = ['Initializable.initializer() does not exist.\n'] + info = ["Initializable.initializer() does not exist.\n"] json = self.generate_result(info) return [json] return [] class MissingInitializerModifier(AbstractCheck): - ARGUMENT = 'missing-init-modifier' + ARGUMENT = "missing-init-modifier" IMPACT = CheckClassification.HIGH - HELP = 'initializer() is not called' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-not-called' - WIKI_TITLE = 'initializer() is not called' - WIKI_DESCRIPTION = ''' + HELP = "initializer() is not called" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-not-called" + WIKI_TITLE = "initializer() is not called" + WIKI_DESCRIPTION = """ Detect if `Initializable.initializer()` is called. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ function initialize() public{ @@ -143,23 +149,23 @@ contract Contract{ ``` `initialize` should have the `initializer` modifier to prevent someone from initializing the contract multiple times. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Use `Initializable.initializer()`. -''' +""" REQUIRE_CONTRACT = True def _check(self): - initializable = self.contract.slither.get_contract_from_name('Initializable') + initializable = self.contract.slither.get_contract_from_name("Initializable") # See InitializablePresent if initializable is None: return [] # See InitializableInherited if initializable not in self.contract.inheritance: return [] - initializer = self.contract.get_modifier_from_canonical_name('Initializable.initializer()') + initializer = self.contract.get_modifier_from_canonical_name("Initializable.initializer()") # InitializableInitializer if initializer is None: return [] @@ -168,24 +174,24 @@ Use `Initializable.initializer()`. all_init_functions = _get_initialize_functions(self.contract) for f in all_init_functions: if initializer not in f.modifiers: - info = [f, ' does not call the initializer modifier.\n'] + info = [f, " does not call the initializer modifier.\n"] json = self.generate_result(info) results.append(json) return results class MissingCalls(AbstractCheck): - ARGUMENT = 'missing-calls' + ARGUMENT = "missing-calls" IMPACT = CheckClassification.HIGH - HELP = 'Missing calls to init functions' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-not-called' - WIKI_TITLE = 'Initialize functions are not called' - WIKI_DESCRIPTION = ''' + HELP = "Missing calls to init functions" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-not-called" + WIKI_TITLE = "Initialize functions are not called" + WIKI_DESCRIPTION = """ Detect missing calls to initialize functions. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Base{ function initialize() public{ @@ -200,11 +206,11 @@ contract Derived is Base{ ``` `Derived.initialize` does not call `Base.initialize` leading the contract to not be correctly initialized. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Ensure all the initialize functions are reached by the most derived initialize function. -''' +""" REQUIRE_CONTRACT = True @@ -215,7 +221,7 @@ Ensure all the initialize functions are reached by the most derived initialize f try: most_derived_init = _get_most_derived_init(self.contract) except MultipleInitTarget: - logger.error(red(f'Too many init targets in {self.contract}')) + logger.error(red(f"Too many init targets in {self.contract}")) return [] if most_derived_init is None: @@ -225,24 +231,24 @@ Ensure all the initialize functions are reached by the most derived initialize f all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init] missing_calls = [f for f in all_init_functions if not f in all_init_functions_called] for f in missing_calls: - info = ['Missing call to ', f, ' in ', most_derived_init, '.\n'] + info = ["Missing call to ", f, " in ", most_derived_init, ".\n"] json = self.generate_result(info) results.append(json) return results class MultipleCalls(AbstractCheck): - ARGUMENT = 'multiple-calls' + ARGUMENT = "multiple-calls" IMPACT = CheckClassification.HIGH - HELP = 'Init functions called multiple times' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-called-multiple-times' - WIKI_TITLE = 'Initialize functions are called multiple times' - WIKI_DESCRIPTION = ''' + HELP = "Init functions called multiple times" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-called-multiple-times" + WIKI_TITLE = "Initialize functions are called multiple times" + WIKI_DESCRIPTION = """ Detect multiple calls to a initialize function. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Base{ function initialize(uint) public{ @@ -264,11 +270,11 @@ contract DerivedDerived is Derived{ ``` `Base.initialize(uint)` is called two times in `DerivedDerived.initiliaze` execution, leading to a potential corruption. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Call only one time every initialize function. -''' +""" REQUIRE_CONTRACT = True @@ -280,38 +286,41 @@ Call only one time every initialize function. most_derived_init = _get_most_derived_init(self.contract) except MultipleInitTarget: # Should be already reported by MissingCalls - #logger.error(red(f'Too many init targets in {self.contract}')) + # logger.error(red(f'Too many init targets in {self.contract}')) return [] if most_derived_init is None: return [] all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init] - double_calls = list(set([f for f in all_init_functions_called if all_init_functions_called.count(f) > 1])) + double_calls = list( + set([f for f in all_init_functions_called if all_init_functions_called.count(f) > 1]) + ) for f in double_calls: - info = [f, ' is called multiple times in ', most_derived_init, '.\n'] + info = [f, " is called multiple times in ", most_derived_init, ".\n"] json = self.generate_result(info) results.append(json) return results + class InitializeTarget(AbstractCheck): - ARGUMENT = 'initialize-target' + ARGUMENT = "initialize-target" IMPACT = CheckClassification.INFORMATIONAL - HELP = 'Initialize function that must be called' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-function' - WIKI_TITLE = 'Initialize function' + HELP = "Initialize function that must be called" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-function" + WIKI_TITLE = "Initialize function" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Show the function that must be called at deployment. This finding does not have an immediate security impact and is informative. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Ensure that the function is called at deployment. -''' +""" REQUIRE_CONTRACT = True @@ -322,12 +331,12 @@ Ensure that the function is called at deployment. most_derived_init = _get_most_derived_init(self.contract) except MultipleInitTarget: # Should be already reported by MissingCalls - #logger.error(red(f'Too many init targets in {self.contract}')) + # logger.error(red(f'Too many init targets in {self.contract}')) return [] if most_derived_init is None: return [] - info = [self.contract, f' needs to be initialized by ', most_derived_init, '.\n'] + info = [self.contract, f" needs to be initialized by ", most_derived_init, ".\n"] json = self.generate_result(info) return [json] diff --git a/slither/tools/upgradeability/checks/variable_initialization.py b/slither/tools/upgradeability/checks/variable_initialization.py index 9c03e0789..7b9316ef7 100644 --- a/slither/tools/upgradeability/checks/variable_initialization.py +++ b/slither/tools/upgradeability/checks/variable_initialization.py @@ -2,29 +2,29 @@ from slither.tools.upgradeability.checks.abstract_checks import CheckClassificat class VariableWithInit(AbstractCheck): - ARGUMENT = 'variables-initialized' + ARGUMENT = "variables-initialized" IMPACT = CheckClassification.HIGH - HELP = 'State variables with an initial value' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#state-variable-initialized' - WIKI_TITLE = 'State variable initialized' + HELP = "State variables with an initial value" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#state-variable-initialized" + WIKI_TITLE = "State variable initialized" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect state variables that are initialized. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ uint variable = 10; } ``` Using `Contract` will the delegatecall proxy pattern will lead `variable` to be 0 when called through the proxy. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Using initialize functions to write initial values in state variables. -''' +""" REQUIRE_CONTRACT = True @@ -32,7 +32,7 @@ Using initialize functions to write initial values in state variables. results = [] for s in self.contract.state_variables: if s.initialized and not s.is_constant: - info = [s, ' is a state variable with an initial value.\n'] + info = [s, " is a state variable with an initial value.\n"] json = self.generate_result(info) results.append(json) return results diff --git a/slither/tools/upgradeability/checks/variables_order.py b/slither/tools/upgradeability/checks/variables_order.py index f300b9dca..735a8a1e1 100644 --- a/slither/tools/upgradeability/checks/variables_order.py +++ b/slither/tools/upgradeability/checks/variables_order.py @@ -2,16 +2,16 @@ from slither.tools.upgradeability.checks.abstract_checks import CheckClassificat class MissingVariable(AbstractCheck): - ARGUMENT = 'missing-variables' + ARGUMENT = "missing-variables" IMPACT = CheckClassification.MEDIUM - HELP = 'Variable missing in the v2' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-variables' - WIKI_TITLE = 'Missing variables' - WIKI_DESCRIPTION = ''' + HELP = "Variable missing in the v2" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-variables" + WIKI_TITLE = "Missing variables" + WIKI_DESCRIPTION = """ Detect variables that were present in the original contracts but are not in the updated one. -''' - WIKI_EXPLOIT_SCENARIO = ''' +""" + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract V1{ uint variable1; @@ -25,11 +25,11 @@ contract V2{ The new version, `V2` does not contain `variable1`. If a new variable is added in an update of `V2`, this variable will hold the latest value of `variable2` and will be corrupted. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Do not change the order of the state variables in the updated contract. -''' +""" REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True @@ -44,7 +44,7 @@ Do not change the order of the state variables in the updated contract. for idx in range(0, len(order1)): variable1 = order1[idx] if len(order2) <= idx: - info = ['Variable missing in ', contract2, ': ', variable1, '\n'] + info = ["Variable missing in ", contract2, ": ", variable1, "\n"] json = self.generate_result(info) results.append(json) @@ -52,18 +52,18 @@ Do not change the order of the state variables in the updated contract. class DifferentVariableContractProxy(AbstractCheck): - ARGUMENT = 'order-vars-proxy' + ARGUMENT = "order-vars-proxy" IMPACT = CheckClassification.HIGH - HELP = 'Incorrect vars order with the proxy' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-proxy' - WIKI_TITLE = 'Incorrect variables with the proxy' + HELP = "Incorrect vars order with the proxy" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-proxy" + WIKI_TITLE = "Incorrect variables with the proxy" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect variables that are different between the contract and the proxy. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ uint variable1; @@ -74,11 +74,11 @@ contract Proxy{ } ``` `Contract` and `Proxy` do not have the same storage layout. As a result the storage of both contracts can be corrupted. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the same layout than in the contract. -''' +""" REQUIRE_CONTRACT = True REQUIRE_PROXY = True @@ -104,9 +104,9 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s variable1 = order1[idx] variable2 = order2[idx] if (variable1.name != variable2.name) or (variable1.type != variable2.type): - info = ['Different variables between ', contract1, ' and ', contract2, '\n'] - info += [f'\t ', variable1, '\n'] - info += [f'\t ', variable2, '\n'] + info = ["Different variables between ", contract1, " and ", contract2, "\n"] + info += [f"\t ", variable1, "\n"] + info += [f"\t ", variable2, "\n"] json = self.generate_result(info) results.append(json) @@ -114,17 +114,17 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s class DifferentVariableContractNewContract(DifferentVariableContractProxy): - ARGUMENT = 'order-vars-contracts' + ARGUMENT = "order-vars-contracts" - HELP = 'Incorrect vars order with the v2' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-v2' - WIKI_TITLE = 'Incorrect variables with the v2' + HELP = "Incorrect vars order with the v2" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-v2" + WIKI_TITLE = "Incorrect variables with the v2" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect variables that are different between the original contract and the updated one. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ uint variable1; @@ -135,11 +135,11 @@ contract ContractV2{ } ``` `Contract` and `ContractV2` do not have the same storage layout. As a result the storage of both contracts can be corrupted. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Respect the variable order of the original contract in the updated contract. -''' +""" REQUIRE_CONTRACT = True REQUIRE_PROXY = False @@ -150,18 +150,20 @@ Respect the variable order of the original contract in the updated contract. class ExtraVariablesProxy(AbstractCheck): - ARGUMENT = 'extra-vars-proxy' + ARGUMENT = "extra-vars-proxy" IMPACT = CheckClassification.MEDIUM - HELP = 'Extra vars in the proxy' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-proxy' - WIKI_TITLE = 'Extra variables in the proxy' + HELP = "Extra vars in the proxy" + WIKI = ( + "https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-proxy" + ) + WIKI_TITLE = "Extra variables in the proxy" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect variables that are in the proxy and not in the contract. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ uint variable1; @@ -173,11 +175,11 @@ contract Proxy{ } ``` `Proxy` contains additional variables. A future update of `Contract` is likely to corrupt the proxy. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the same layout than in the contract. -''' +""" REQUIRE_CONTRACT = True REQUIRE_PROXY = True @@ -203,7 +205,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s while idx < len(order2): variable2 = order2[idx] - info = ['Extra variables in ', contract2, ': ', variable2, '\n'] + info = ["Extra variables in ", contract2, ": ", variable2, "\n"] json = self.generate_result(info) results.append(json) idx = idx + 1 @@ -212,21 +214,21 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s class ExtraVariablesNewContract(ExtraVariablesProxy): - ARGUMENT = 'extra-vars-v2' + ARGUMENT = "extra-vars-v2" - HELP = 'Extra vars in the v2' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-v2' - WIKI_TITLE = 'Extra variables in the v2' + HELP = "Extra vars in the v2" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-v2" + WIKI_TITLE = "Extra variables in the v2" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Show new variables in the updated contract. This finding does not have an immediate security impact and is informative. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Ensure that all the new variables are expected. -''' +""" IMPACT = CheckClassification.INFORMATIONAL diff --git a/slither/tools/upgradeability/utils/command_line.py b/slither/tools/upgradeability/utils/command_line.py index 2af6daa54..57a6ef88d 100644 --- a/slither/tools/upgradeability/utils/command_line.py +++ b/slither/tools/upgradeability/utils/command_line.py @@ -4,7 +4,9 @@ from slither.utils.myprettytable import MyPrettyTable def output_wiki(detector_classes, filter_wiki): # Sort by impact, confidence, and name - detectors_list = sorted(detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT)) + detectors_list = sorted( + detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT) + ) for detector in detectors_list: if filter_wiki not in detector.WIKI: @@ -16,16 +18,16 @@ def output_wiki(detector_classes, filter_wiki): exploit_scenario = detector.WIKI_EXPLOIT_SCENARIO recommendation = detector.WIKI_RECOMMENDATION - print('\n## {}'.format(title)) - print('### Configuration') - print('* Check: `{}`'.format(argument)) - print('* Severity: `{}`'.format(impact)) - print('\n### Description') + print("\n## {}".format(title)) + print("### Configuration") + print("* Check: `{}`".format(argument)) + print("* Severity: `{}`".format(impact)) + print("\n### Description") print(description) if exploit_scenario: - print('\n### Exploit Scenario:') + print("\n### Exploit Scenario:") print(exploit_scenario) - print('\n### Recommendation') + print("\n### Recommendation") print(recommendation) @@ -38,27 +40,31 @@ def output_detectors(detector_classes): require_proxy = detector.REQUIRE_PROXY require_v2 = detector.REQUIRE_CONTRACT_V2 detectors_list.append((argument, help_info, impact, require_proxy, require_v2)) - table = MyPrettyTable(["Num", - "Check", - "What it Detects", - "Impact", - "Proxy", - "Contract V2"]) + table = MyPrettyTable(["Num", "Check", "What it Detects", "Impact", "Proxy", "Contract V2"]) # Sort by impact, confidence, and name detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) idx = 1 for (argument, help_info, impact, proxy, v2) in detectors_list: - table.add_row([idx, argument, help_info, classification_txt[impact], 'X' if proxy else '', 'X' if v2 else '']) + table.add_row( + [ + idx, + argument, + help_info, + classification_txt[impact], + "X" if proxy else "", + "X" if v2 else "", + ] + ) idx = idx + 1 print(table) def output_to_markdown(detector_classes, filter_wiki): def extract_help(cls): - if cls.WIKI == '': + if cls.WIKI == "": return cls.HELP - return '[{}]({})'.format(cls.HELP, cls.WIKI) + return "[{}]({})".format(cls.HELP, cls.WIKI) detectors_list = [] for detector in detector_classes: @@ -73,12 +79,16 @@ def output_to_markdown(detector_classes, filter_wiki): detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) idx = 1 for (argument, help_info, impact, proxy, v2) in detectors_list: - print('{} | `{}` | {} | {} | {} | {}'.format(idx, - argument, - help_info, - classification_txt[impact], - 'X' if proxy else '', - 'X' if v2 else '')) + print( + "{} | `{}` | {} | {} | {} | {}".format( + idx, + argument, + help_info, + classification_txt[impact], + "X" if proxy else "", + "X" if v2 else "", + ) + ) idx = idx + 1 @@ -92,26 +102,42 @@ def output_detectors_json(detector_classes): wiki_description = detector.WIKI_DESCRIPTION wiki_exploit_scenario = detector.WIKI_EXPLOIT_SCENARIO wiki_recommendation = detector.WIKI_RECOMMENDATION - detectors_list.append((argument, - help_info, - impact, - wiki_url, - wiki_description, - wiki_exploit_scenario, - wiki_recommendation)) + detectors_list.append( + ( + argument, + help_info, + impact, + wiki_url, + wiki_description, + wiki_exploit_scenario, + wiki_recommendation, + ) + ) # Sort by impact, confidence, and name detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) idx = 1 table = [] - for (argument, help_info, impact, wiki_url, description, exploit, recommendation) in detectors_list: - table.append({'index': idx, - 'check': argument, - 'title': help_info, - 'impact': classification_txt[impact], - 'wiki_url': wiki_url, - 'description': description, - 'exploit_scenario': exploit, - 'recommendation': recommendation}) + for ( + argument, + help_info, + impact, + wiki_url, + description, + exploit, + recommendation, + ) in detectors_list: + table.append( + { + "index": idx, + "check": argument, + "title": help_info, + "impact": classification_txt[impact], + "wiki_url": wiki_url, + "description": description, + "exploit_scenario": exploit, + "recommendation": recommendation, + } + ) idx = idx + 1 return table diff --git a/slither/utils/arithmetic.py b/slither/utils/arithmetic.py index 0612b0a26..03b2c1de1 100644 --- a/slither/utils/arithmetic.py +++ b/slither/utils/arithmetic.py @@ -3,32 +3,32 @@ from decimal import Decimal from slither.exceptions import SlitherException -def convert_subdenomination(value, sub): +def convert_subdenomination(value: str, sub: str) -> int: # to allow 0.1 ether conversion if value[0:2] == "0x": - value = Decimal(int(value, 16)) + decimal_value = Decimal(int(value, 16)) else: - value = Decimal(value) - if sub == 'wei': - return int(value) - if sub == 'szabo': - return int(value * int(1e12)) - if sub == 'finney': - return int(value * int(1e15)) - if sub == 'ether': - return int(value * int(1e18)) - if sub == 'seconds': - return int(value) - if sub == 'minutes': - return int(value * 60) - if sub == 'hours': - return int(value * 60 * 60) - if sub == 'days': - return int(value * 60 * 60 * 24) - if sub == 'weeks': - return int(value * 60 * 60 * 24 * 7) - if sub == 'years': - return int(value * 60 * 60 * 24 * 7 * 365) + decimal_value = Decimal(value) + if sub == "wei": + return int(decimal_value) + if sub == "szabo": + return int(decimal_value * int(1e12)) + if sub == "finney": + return int(decimal_value * int(1e15)) + if sub == "ether": + return int(decimal_value * int(1e18)) + if sub == "seconds": + return int(decimal_value) + if sub == "minutes": + return int(decimal_value * 60) + if sub == "hours": + return int(decimal_value * 60 * 60) + if sub == "days": + return int(decimal_value * 60 * 60 * 24) + if sub == "weeks": + return int(decimal_value * 60 * 60 * 24 * 7) + if sub == "years": + return int(decimal_value * 60 * 60 * 24 * 7 * 365) - raise SlitherException(f'Subdemonination conversion impossible {value} {sub}') \ No newline at end of file + raise SlitherException(f"Subdemonination conversion impossible {decimal_value} {sub}") diff --git a/slither/utils/code_complexity.py b/slither/utils/code_complexity.py index 6d533ab42..54b14f028 100644 --- a/slither/utils/code_complexity.py +++ b/slither/utils/code_complexity.py @@ -1,6 +1,12 @@ # Function computing the code complexity +from typing import TYPE_CHECKING, List -def compute_number_edges(function): +if TYPE_CHECKING: + from slither.core.declarations import Function + from slither.core.cfg.node import Node + + +def compute_number_edges(function: "Function") -> int: """ Compute the number of edges of the CFG Args: @@ -14,7 +20,7 @@ def compute_number_edges(function): return n -def compute_strongly_connected_components(function): +def compute_strongly_connected_components(function: "Function") -> List[List["Node"]]: """ Compute strongly connected components Based on Kosaraju algo @@ -24,8 +30,8 @@ def compute_strongly_connected_components(function): Returns: list(list(nodes)) """ - visited = {n:False for n in function.nodes} - assigned = {n:False for n in function.nodes} + visited = {n: False for n in function.nodes} + assigned = {n: False for n in function.nodes} components = [] l = [] @@ -39,7 +45,7 @@ def compute_strongly_connected_components(function): for n in function.nodes: visit(n) - def assign(node, root): + def assign(node: "Node", root: List["Node"]): if not assigned[node]: assigned[node] = True root.append(node) @@ -47,14 +53,15 @@ def compute_strongly_connected_components(function): assign(father, root) for n in l: - component = [] + component: List["Node"] = [] assign(n, component) if component: components.append(component) return components -def compute_cyclomatic_complexity(function): + +def compute_cyclomatic_complexity(function: "Function") -> int: """ Compute the cyclomatic complexity of a function Args: diff --git a/slither/utils/colors.py b/slither/utils/colors.py index ba6cedc4a..c1a0e73af 100644 --- a/slither/utils/colors.py +++ b/slither/utils/colors.py @@ -4,22 +4,22 @@ import platform class Colors: COLORIZATION_ENABLED = True - RED = '\033[91m' - GREEN = '\033[92m' - YELLOW = '\033[93m' - BLUE = '\033[94m' - MAGENTA = '\033[95m' - END = '\033[0m' + RED = "\033[91m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + MAGENTA = "\033[95m" + END = "\033[0m" -def colorize(color, txt): +def colorize(color: Colors, txt: str) -> str: if Colors.COLORIZATION_ENABLED: - return '{}{}{}'.format(color, txt, Colors.END) + return "{}{}{}".format(color, txt, Colors.END) else: return txt -def enable_windows_virtual_terminal_sequences(): +def enable_windows_virtual_terminal_sequences() -> bool: """ Sets the appropriate flags to enable virtual terminal sequences in a Windows command prompt. Reference: https://docs.microsoft.com/en-us/windows/console/console-virtual-terminal-sequences @@ -51,7 +51,9 @@ def enable_windows_virtual_terminal_sequences(): # If the virtual terminal sequence processing is not yet enabled, we enable it. if (current_mode.value & virtual_terminal_flag) == 0: - if not kernel32.SetConsoleMode(current_handle, current_mode.value | virtual_terminal_flag): + if not kernel32.SetConsoleMode( + current_handle, current_mode.value | virtual_terminal_flag + ): return False except: # Any generic failure (possibly from calling these methods on older Windows builds where they do not exist) @@ -61,14 +63,14 @@ def enable_windows_virtual_terminal_sequences(): return True -def set_colorization_enabled(enabled): +def set_colorization_enabled(enabled: bool): """ Sets the enabled state of output colorization. :param enabled: Boolean indicating whether output should be colorized. :return: None """ # If color is supposed to be enabled and this is windows, we have to enable console virtual terminal sequences: - if enabled and platform.system() == 'Windows': + if enabled and platform.system() == "Windows": Colors.COLORIZATION_ENABLED = enable_windows_virtual_terminal_sequences() else: # This is not windows so we can enable color immediately. diff --git a/slither/utils/command_line.py b/slither/utils/command_line.py index 8527f911f..3e0659772 100644 --- a/slither/utils/command_line.py +++ b/slither/utils/command_line.py @@ -2,7 +2,9 @@ import json import os import logging from collections import defaultdict -from crytic_compile.cryticparser.defaults import DEFAULTS_FLAG_IN_CONFIG as DEFAULTS_FLAG_IN_CONFIG_CRYTIC_COMPILE +from crytic_compile.cryticparser.defaults import ( + DEFAULTS_FLAG_IN_CONFIG as DEFAULTS_FLAG_IN_CONFIG_CRYTIC_COMPILE, +) from slither.detectors.abstract_detector import classification_txt from .colors import yellow, red @@ -11,30 +13,37 @@ from .myprettytable import MyPrettyTable logger = logging.getLogger("Slither") DEFAULT_JSON_OUTPUT_TYPES = ["detectors", "printers"] -JSON_OUTPUT_TYPES = ["compilations", "console", "detectors", "printers", "list-detectors", "list-printers"] +JSON_OUTPUT_TYPES = [ + "compilations", + "console", + "detectors", + "printers", + "list-detectors", + "list-printers", +] # Those are the flags shared by the command line and the config file defaults_flag_in_config = { - 'detectors_to_run': 'all', - 'printers_to_run': None, - 'detectors_to_exclude': None, - 'exclude_dependencies': False, - 'exclude_informational': False, - 'exclude_optimization': False, - 'exclude_low': False, - 'exclude_medium': False, - 'exclude_high': False, - 'json': None, - 'json-types': ','.join(DEFAULT_JSON_OUTPUT_TYPES), - 'disable_color': False, - 'filter_paths': None, - 'generate_patches': False, + "detectors_to_run": "all", + "printers_to_run": None, + "detectors_to_exclude": None, + "exclude_dependencies": False, + "exclude_informational": False, + "exclude_optimization": False, + "exclude_low": False, + "exclude_medium": False, + "exclude_high": False, + "json": None, + "json-types": ",".join(DEFAULT_JSON_OUTPUT_TYPES), + "disable_color": False, + "filter_paths": None, + "generate_patches": False, # debug command - 'legacy_ast': False, - 'ignore_return_value': False, - 'zip': None, - 'zip_type': 'lzma', - **DEFAULTS_FLAG_IN_CONFIG_CRYTIC_COMPILE + "legacy_ast": False, + "ignore_return_value": False, + "zip": None, + "zip_type": "lzma", + **DEFAULTS_FLAG_IN_CONFIG_CRYTIC_COMPILE, } @@ -45,26 +54,32 @@ def read_config_file(args): config = json.load(f) for key, elem in config.items(): if key not in defaults_flag_in_config: - logger.info(yellow('{} has an unknown key: {} : {}'.format(args.config_file, key, elem))) + logger.info( + yellow( + "{} has an unknown key: {} : {}".format(args.config_file, key, elem) + ) + ) continue if getattr(args, key) == defaults_flag_in_config[key]: setattr(args, key, elem) except json.decoder.JSONDecodeError as e: - logger.error(red('Impossible to read {}, please check the file {}'.format(args.config_file, e))) + logger.error( + red("Impossible to read {}, please check the file {}".format(args.config_file, e)) + ) def output_to_markdown(detector_classes, printer_classes, filter_wiki): def extract_help(cls): - if cls.WIKI == '': + if cls.WIKI == "": return cls.HELP - return '[{}]({})'.format(cls.HELP, cls.WIKI) + return "[{}]({})".format(cls.HELP, cls.WIKI) detectors_list = [] print(filter_wiki) for detector in detector_classes: argument = detector.ARGUMENT # dont show the backdoor example - if argument == 'backdoor': + if argument == "backdoor": continue if not filter_wiki in detector.WIKI: continue @@ -74,14 +89,16 @@ def output_to_markdown(detector_classes, printer_classes, filter_wiki): detectors_list.append((argument, help_info, impact, confidence)) # Sort by impact, confidence, and name - detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[3], element[0])) + detectors_list = sorted( + detectors_list, key=lambda element: (element[2], element[3], element[0]) + ) idx = 1 for (argument, help_info, impact, confidence) in detectors_list: - print('{} | `{}` | {} | {} | {}'.format(idx, - argument, - help_info, - classification_txt[impact], - confidence)) + print( + "{} | `{}` | {} | {} | {}".format( + idx, argument, help_info, classification_txt[impact], confidence + ) + ) idx = idx + 1 print() @@ -95,67 +112,72 @@ def output_to_markdown(detector_classes, printer_classes, filter_wiki): printers_list = sorted(printers_list, key=lambda element: (element[0])) idx = 1 for (argument, help_info) in printers_list: - print('{} | `{}` | {}'.format(idx, argument, help_info)) + print("{} | `{}` | {}".format(idx, argument, help_info)) idx = idx + 1 def get_level(l): - tab = l.count('\t') + 1 - if l.replace('\t', '').startswith(' -'): + tab = l.count("\t") + 1 + if l.replace("\t", "").startswith(" -"): tab = tab + 1 - if l.replace('\t', '').startswith('-'): + if l.replace("\t", "").startswith("-"): tab = tab + 1 return tab def convert_result_to_markdown(txt): # -1 to remove the last \n - lines = txt[0:-1].split('\n') + lines = txt[0:-1].split("\n") ret = [] level = 0 for l in lines: next_level = get_level(l) - prefix = '
  • ' + prefix = "
  • " if next_level < level: - prefix = '' * (level - next_level) + prefix + prefix = "" * (level - next_level) + prefix if next_level > level: - prefix = '