pull/1666/head
Feist Josselin 2 years ago
parent 060f550b0d
commit bd2d572f48
  1. 6
      examples/scripts/data_dependency.py
  2. 1
      examples/scripts/variable_in_condition.py
  3. 4
      slither/__main__.py
  4. 26
      slither/analyses/data_dependency/data_dependency.py
  5. 27
      slither/core/declarations/contract.py
  6. 6
      slither/core/declarations/custom_error.py
  7. 3
      slither/core/declarations/using_for_top_level.py
  8. 5
      slither/core/variables/variable.py
  9. 2
      slither/detectors/statements/costly_operations_in_loop.py
  10. 2
      slither/detectors/statements/write_after_write.py
  11. 3
      slither/printers/call/call_graph.py
  12. 3
      slither/printers/summary/constructor_calls.py
  13. 19
      slither/printers/summary/contract.py
  14. 3
      slither/printers/summary/variable_order.py
  15. 2
      slither/slithir/utils/ssa.py
  16. 2
      slither/slithir/utils/utils.py
  17. 4
      slither/slithir/variables/local_variable.py
  18. 4
      slither/slithir/variables/state_variable.py
  19. 5
      slither/slithir/variables/variable.py
  20. 78
      slither/solc_parsing/declarations/contract.py
  21. 10
      slither/solc_parsing/declarations/function.py
  22. 16
      slither/solc_parsing/variables/variable_declaration.py
  23. 5
      slither/tools/doctor/checks/versions.py
  24. 5
      slither/tools/read_storage/utils/utils.py
  25. 8
      slither/tools/upgradeability/checks/variable_initialization.py
  26. 38
      slither/tools/upgradeability/checks/variables_order.py
  27. 93
      tests/test_ssa_generation.py

@ -18,6 +18,8 @@ assert len(contracts) == 1
contract = contracts[0] contract = contracts[0]
destination = contract.get_state_variable_from_name("destination") destination = contract.get_state_variable_from_name("destination")
source = contract.get_state_variable_from_name("source") source = contract.get_state_variable_from_name("source")
assert source
assert destination
print(f"{source} is dependent of {destination}: {is_dependent(source, destination, contract)}") print(f"{source} is dependent of {destination}: {is_dependent(source, destination, contract)}")
assert not is_dependent(source, destination, contract) assert not is_dependent(source, destination, contract)
@ -47,9 +49,11 @@ print(f"{destination} is tainted {is_tainted(destination, contract)}")
assert is_tainted(destination, contract) assert is_tainted(destination, contract)
destination_indirect_1 = contract.get_state_variable_from_name("destination_indirect_1") destination_indirect_1 = contract.get_state_variable_from_name("destination_indirect_1")
assert destination_indirect_1
print(f"{destination_indirect_1} is tainted {is_tainted(destination_indirect_1, contract)}") print(f"{destination_indirect_1} is tainted {is_tainted(destination_indirect_1, contract)}")
assert is_tainted(destination_indirect_1, contract) assert is_tainted(destination_indirect_1, contract)
destination_indirect_2 = contract.get_state_variable_from_name("destination_indirect_2") destination_indirect_2 = contract.get_state_variable_from_name("destination_indirect_2")
assert destination_indirect_2
print(f"{destination_indirect_2} is tainted {is_tainted(destination_indirect_2, contract)}") print(f"{destination_indirect_2} is tainted {is_tainted(destination_indirect_2, contract)}")
assert is_tainted(destination_indirect_2, contract) assert is_tainted(destination_indirect_2, contract)
@ -88,6 +92,8 @@ contract = contracts[0]
contract_derived = slither.get_contract_from_name("Derived")[0] contract_derived = slither.get_contract_from_name("Derived")[0]
destination = contract.get_state_variable_from_name("destination") destination = contract.get_state_variable_from_name("destination")
source = contract.get_state_variable_from_name("source") source = contract.get_state_variable_from_name("source")
assert destination
assert source
print(f"{destination} is dependent of {source}: {is_dependent(destination, source, contract)}") print(f"{destination} is dependent of {source}: {is_dependent(destination, source, contract)}")
assert not is_dependent(destination, source, contract) assert not is_dependent(destination, source, contract)

@ -14,6 +14,7 @@ assert len(contracts) == 1
contract = contracts[0] contract = contracts[0]
# Get the variable # Get the variable
var_a = contract.get_state_variable_from_name("a") var_a = contract.get_state_variable_from_name("a")
assert var_a
# Get the functions reading the variable # Get the functions reading the variable
functions_reading_a = contract.get_functions_reading_from_variable(var_a) functions_reading_a = contract.get_functions_reading_from_variable(var_a)

@ -615,7 +615,9 @@ def parse_args(
class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
detectors, _ = get_detectors_and_printers() detectors, _ = get_detectors_and_printers()
output_detectors(detectors) output_detectors(detectors)
parser.exit() parser.exit()

@ -2,7 +2,7 @@
Compute the data depenency between all the SSA variables Compute the data depenency between all the SSA variables
""" """
from collections import defaultdict from collections import defaultdict
from typing import Union, Set, Dict, TYPE_CHECKING from typing import Union, Set, Dict, TYPE_CHECKING, List
from slither.core.cfg.node import Node from slither.core.cfg.node import Node
from slither.core.declarations import ( from slither.core.declarations import (
@ -20,6 +20,7 @@ from slither.core.solidity_types.type import Type
from slither.core.variables.top_level_variable import TopLevelVariable from slither.core.variables.top_level_variable import TopLevelVariable
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.slithir.operations import Index, OperationWithLValue, InternalCall, Operation from slither.slithir.operations import Index, OperationWithLValue, InternalCall, Operation
from slither.slithir.utils.utils import LVALUE
from slither.slithir.variables import ( from slither.slithir.variables import (
Constant, Constant,
LocalIRVariable, LocalIRVariable,
@ -29,6 +30,7 @@ from slither.slithir.variables import (
TemporaryVariableSSA, TemporaryVariableSSA,
TupleVariableSSA, TupleVariableSSA,
) )
from slither.slithir.variables.variable import SlithIRVariable
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
@ -393,13 +395,9 @@ def transitive_close_dependencies(
while changed: while changed:
changed = False changed = False
to_add = defaultdict(set) to_add = defaultdict(set)
[ # pylint: disable=expression-not-assigned for key, items in context.context[context_key].items():
[ for item in items & keys:
to_add[key].update(context.context[context_key][item] - {key} - items) to_add[key].update(context.context[context_key][item] - {key} - items)
for item in items & keys
]
for key, items in context.context[context_key].items()
]
for k, v in to_add.items(): for k, v in to_add.items():
# Because we dont have any check on the update operation # Because we dont have any check on the update operation
# We might update an empty set with an empty set # We might update an empty set with an empty set
@ -418,20 +416,20 @@ def add_dependency(lvalue: Variable, function: Function, ir: Operation, is_prote
function.context[KEY_SSA][lvalue] = set() function.context[KEY_SSA][lvalue] = set()
if not is_protected: if not is_protected:
function.context[KEY_SSA_UNPROTECTED][lvalue] = set() function.context[KEY_SSA_UNPROTECTED][lvalue] = set()
read: Union[List[Union[LVALUE, SolidityVariableComposed]], List[SlithIRVariable]]
if isinstance(ir, Index): if isinstance(ir, Index):
read = [ir.variable_left] read = [ir.variable_left]
elif isinstance(ir, InternalCall): elif isinstance(ir, InternalCall) and ir.function:
read = ir.function.return_values_ssa read = ir.function.return_values_ssa
else: else:
read = ir.read read = ir.read
# pylint: disable=expression-not-assigned for v in read:
[function.context[KEY_SSA][lvalue].add(v) for v in read if not isinstance(v, Constant)] if not isinstance(v, Constant):
function.context[KEY_SSA][lvalue].add(v)
if not is_protected: if not is_protected:
[ for v in read:
if not isinstance(v, Constant):
function.context[KEY_SSA_UNPROTECTED][lvalue].add(v) function.context[KEY_SSA_UNPROTECTED][lvalue].add(v)
for v in read
if not isinstance(v, Constant)
]
def compute_dependency_function(function: Function) -> None: def compute_dependency_function(function: Function) -> None:

@ -49,6 +49,9 @@ if TYPE_CHECKING:
LOGGER = logging.getLogger("Contract") LOGGER = logging.getLogger("Contract")
USING_FOR_KEY = Union[str, Type]
USING_FOR_ITEM = List[Union[Type, Function]]
class Contract(SourceMapping): # pylint: disable=too-many-public-methods class Contract(SourceMapping): # pylint: disable=too-many-public-methods
""" """
@ -80,8 +83,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
self._custom_errors: Dict[str, "CustomErrorContract"] = {} self._custom_errors: Dict[str, "CustomErrorContract"] = {}
# The only str is "*" # The only str is "*"
self._using_for: Dict[Union[str, Type], List[Type]] = {} self._using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = {}
self._using_for_complete: Optional[Dict[Union[str, Type], List[Type]]] = None self._using_for_complete: Optional[Dict[USING_FOR_KEY, USING_FOR_ITEM]] = None
self._kind: Optional[str] = None self._kind: Optional[str] = None
self._is_interface: bool = False self._is_interface: bool = False
self._is_library: bool = False self._is_library: bool = False
@ -123,7 +126,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._name return self._name
@name.setter @name.setter
def name(self, name: str): def name(self, name: str) -> None:
self._name = name self._name = name
@property @property
@ -133,7 +136,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._id return self._id
@id.setter @id.setter
def id(self, new_id): def id(self, new_id: int) -> None:
"""Unique id.""" """Unique id."""
self._id = new_id self._id = new_id
@ -146,7 +149,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._kind return self._kind
@contract_kind.setter @contract_kind.setter
def contract_kind(self, kind): def contract_kind(self, kind: str) -> None:
self._kind = kind self._kind = kind
@property @property
@ -154,7 +157,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_interface return self._is_interface
@is_interface.setter @is_interface.setter
def is_interface(self, is_interface: bool): def is_interface(self, is_interface: bool) -> None:
self._is_interface = is_interface self._is_interface = is_interface
@property @property
@ -162,7 +165,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_library return self._is_library
@is_library.setter @is_library.setter
def is_library(self, is_library: bool): def is_library(self, is_library: bool) -> None:
self._is_library = is_library self._is_library = is_library
# endregion # endregion
@ -266,16 +269,18 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
################################################################################### ###################################################################################
@property @property
def using_for(self) -> Dict[Union[str, Type], List[Type]]: def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]:
return self._using_for return self._using_for
@property @property
def using_for_complete(self) -> Dict[Union[str, Type], List[Type]]: def using_for_complete(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]:
""" """
Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive
""" """
def _merge_using_for(uf1: Dict, uf2: Dict) -> Dict: def _merge_using_for(
uf1: Dict[USING_FOR_KEY, USING_FOR_ITEM], uf2: Dict[USING_FOR_KEY, USING_FOR_ITEM]
) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]:
result = {**uf1, **uf2} result = {**uf1, **uf2}
for key, value in result.items(): for key, value in result.items():
if key in uf1 and key in uf2: if key in uf1 and key in uf2:
@ -1452,7 +1457,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
result = func.get_last_ssa_state_variables_instances() result = func.get_last_ssa_state_variables_instances()
for variable_name, instances in result.items(): for variable_name, instances in result.items():
# TODO: investigate the next operation # TODO: investigate the next operation
last_state_variables_instances[variable_name] += instances last_state_variables_instances[variable_name] += list(instances)
for func in self.functions + list(self.modifiers): for func in self.functions + list(self.modifiers):
func.fix_phi(last_state_variables_instances, initial_state_variables_instances) func.fix_phi(last_state_variables_instances, initial_state_variables_instances)

@ -1,4 +1,4 @@
from typing import List, TYPE_CHECKING, Optional, Type, Union from typing import List, TYPE_CHECKING, Optional, Type
from slither.core.solidity_types import UserDefinedType from slither.core.solidity_types import UserDefinedType
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
@ -42,7 +42,7 @@ class CustomError(SourceMapping):
################################################################################### ###################################################################################
@staticmethod @staticmethod
def _convert_type_for_solidity_signature(t: Optional[Union[Type, List[Type]]]) -> str: def _convert_type_for_solidity_signature(t: Optional[Type]) -> str:
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from slither.core.declarations import Contract from slither.core.declarations import Contract
@ -72,7 +72,7 @@ class CustomError(SourceMapping):
Returns: Returns:
""" """
parameters = [x.type for x in self.parameters] parameters = [x.type for x in self.parameters if x.type]
self._full_name = self.name + "(" + ",".join(map(str, parameters)) + ")" self._full_name = self.name + "(" + ",".join(map(str, parameters)) + ")"
solidity_parameters = map(self._convert_type_for_solidity_signature, parameters) solidity_parameters = map(self._convert_type_for_solidity_signature, parameters)
self._solidity_signature = self.name + "(" + ",".join(solidity_parameters) + ")" self._solidity_signature = self.name + "(" + ",".join(solidity_parameters) + ")"

@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, List, Dict, Union from typing import TYPE_CHECKING, List, Dict, Union
from slither.core.declarations.contract import USING_FOR_KEY, USING_FOR_ITEM
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
from slither.core.declarations.top_level import TopLevel from slither.core.declarations.top_level import TopLevel
@ -14,5 +15,5 @@ class UsingForTopLevel(TopLevel):
self.file_scope: "FileScope" = scope self.file_scope: "FileScope" = scope
@property @property
def using_for(self) -> Dict[Union[str, Type], List[Type]]: def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]:
return self._using_for return self._using_for

@ -160,8 +160,8 @@ class Variable(SourceMapping):
return ( return (
self.name, self.name,
[str(x) for x in export_nested_types_from_variable(self)], [str(x) for x in export_nested_types_from_variable(self)], # type: ignore
[str(x) for x in export_return_type_from_variable(self)], [str(x) for x in export_return_type_from_variable(self)], # type: ignore
) )
@property @property
@ -179,4 +179,5 @@ class Variable(SourceMapping):
return f'{name}({",".join(parameters)})' return f'{name}({",".join(parameters)})'
def __str__(self) -> str: def __str__(self) -> str:
assert self._name
return self._name return self._name

@ -43,7 +43,7 @@ def costly_operations_in_loop(
if isinstance(ir, OperationWithLValue) and isinstance(ir.lvalue, StateVariable): if isinstance(ir, OperationWithLValue) and isinstance(ir.lvalue, StateVariable):
ret.append(ir.node) ret.append(ir.node)
break break
if isinstance(ir, (InternalCall)): if isinstance(ir, (InternalCall)) and ir.function:
costly_operations_in_loop(ir.function.entry_point, in_loop_counter, visited, ret) costly_operations_in_loop(ir.function.entry_point, in_loop_counter, visited, ret)
for son in node.sons: for son in node.sons:

@ -37,6 +37,8 @@ def _handle_ir(
_remove_states(written) _remove_states(written)
if isinstance(ir, InternalCall): if isinstance(ir, InternalCall):
if not ir.function:
return
if ir.function.all_high_level_calls() or ir.function.all_library_calls(): if ir.function.all_high_level_calls() or ir.function.all_library_calls():
_remove_states(written) _remove_states(written)

@ -13,6 +13,7 @@ from slither.core.declarations.function import Function
from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.printers.abstract_printer import AbstractPrinter from slither.printers.abstract_printer import AbstractPrinter
from slither.utils.output import Output
def _contract_subgraph(contract: Contract) -> str: def _contract_subgraph(contract: Contract) -> str:
@ -222,7 +223,7 @@ class PrinterCallGraph(AbstractPrinter):
WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#call-graph" WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#call-graph"
def output(self, filename): def output(self, filename: str) -> Output:
""" """
Output the graph in filename Output the graph in filename
Args: Args:

@ -5,6 +5,7 @@ from slither.core.declarations import Function
from slither.core.source_mapping.source_mapping import Source from slither.core.source_mapping.source_mapping import Source
from slither.printers.abstract_printer import AbstractPrinter from slither.printers.abstract_printer import AbstractPrinter
from slither.utils import output from slither.utils import output
from slither.utils.output import Output
def _get_source_code(cst: Function) -> str: def _get_source_code(cst: Function) -> str:
@ -17,7 +18,7 @@ class ConstructorPrinter(AbstractPrinter):
ARGUMENT = "constructor-calls" ARGUMENT = "constructor-calls"
HELP = "Print the constructors executed" HELP = "Print the constructors executed"
def output(self, _filename): def output(self, _filename: str) -> Output:
info = "" info = ""
for contract in self.slither.contracts_derived: for contract in self.slither.contracts_derived:
stack_name = [] stack_name = []

@ -2,9 +2,13 @@
Module printing summary of the contract Module printing summary of the contract
""" """
import collections import collections
from typing import Dict, List
from slither.core.declarations import FunctionContract
from slither.printers.abstract_printer import AbstractPrinter from slither.printers.abstract_printer import AbstractPrinter
from slither.utils import output from slither.utils import output
from slither.utils.colors import blue, green, magenta from slither.utils.colors import blue, green, magenta
from slither.utils.output import Output
class ContractSummary(AbstractPrinter): class ContractSummary(AbstractPrinter):
@ -13,7 +17,7 @@ class ContractSummary(AbstractPrinter):
WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#contract-summary" WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#contract-summary"
def output(self, _filename): # pylint: disable=too-many-locals def output(self, _filename: str) -> Output: # pylint: disable=too-many-locals
""" """
_filename is not used _filename is not used
Args: Args:
@ -53,17 +57,16 @@ class ContractSummary(AbstractPrinter):
# Order the function with # Order the function with
# contract_declarer -> list_functions # contract_declarer -> list_functions
public = [ public_function = [
(f.contract_declarer.name, f) (f.contract_declarer.name, f)
for f in c.functions for f in c.functions
if (not f.is_shadowed and not f.is_constructor_variables) if (not f.is_shadowed and not f.is_constructor_variables)
] ]
collect = collections.defaultdict(list) collect: Dict[str, List[FunctionContract]] = collections.defaultdict(list)
for a, b in public: for a, b in public_function:
collect[a].append(b) collect[a].append(b)
public = list(collect.items())
for contract, functions in public: for contract, functions in collect.items():
txt += blue(f" - From {contract}\n") txt += blue(f" - From {contract}\n")
functions = sorted(functions, key=lambda f: f.full_name) functions = sorted(functions, key=lambda f: f.full_name)
@ -90,7 +93,7 @@ class ContractSummary(AbstractPrinter):
self.info(txt) self.info(txt)
res = self.generate_output(txt) res = self.generate_output(txt)
for contract, additional_fields in all_contracts: for current_contract, current_additional_fields in all_contracts:
res.add(contract, additional_fields=additional_fields) res.add(current_contract, additional_fields=current_additional_fields)
return res return res

@ -4,6 +4,7 @@
from slither.printers.abstract_printer import AbstractPrinter from slither.printers.abstract_printer import AbstractPrinter
from slither.utils.myprettytable import MyPrettyTable from slither.utils.myprettytable import MyPrettyTable
from slither.utils.output import Output
class VariableOrder(AbstractPrinter): class VariableOrder(AbstractPrinter):
@ -13,7 +14,7 @@ class VariableOrder(AbstractPrinter):
WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#variable-order" WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#variable-order"
def output(self, _filename): def output(self, _filename: str) -> Output:
""" """
_filename is not used _filename is not used
Args: Args:

@ -366,7 +366,7 @@ def last_name(
def is_used_later( def is_used_later(
initial_node: Node, initial_node: Node,
variable: Union[StateIRVariable, LocalVariable], variable: Union[StateIRVariable, LocalVariable, TemporaryVariableSSA],
) -> bool: ) -> bool:
# TODO: does not handle the case where its read and written in the declaration node # TODO: does not handle the case where its read and written in the declaration node
# It can be problematic if this happens in a loop/if structure # It can be problematic if this happens in a loop/if structure

@ -46,7 +46,7 @@ def is_valid_rvalue(v: SourceMapping) -> bool:
) )
def is_valid_lvalue(v) -> bool: def is_valid_lvalue(v: SourceMapping) -> bool:
return isinstance( return isinstance(
v, v,
( (

@ -41,11 +41,11 @@ class LocalIRVariable(
self._non_ssa_version = local_variable self._non_ssa_version = local_variable
@property @property
def index(self): def index(self) -> int:
return self._index return self._index
@index.setter @index.setter
def index(self, idx): def index(self, idx: int) -> None:
self._index = idx self._index = idx
@property @property

@ -30,11 +30,11 @@ class StateIRVariable(
self._non_ssa_version = state_variable self._non_ssa_version = state_variable
@property @property
def index(self): def index(self) -> int:
return self._index return self._index
@index.setter @index.setter
def index(self, idx): def index(self, idx: int) -> None:
self._index = idx self._index = idx
@property @property

@ -7,8 +7,9 @@ class SlithIRVariable(Variable):
self._index = 0 self._index = 0
@property @property
def ssa_name(self): def ssa_name(self) -> str:
assert self.name
return self.name return self.name
def __str__(self): def __str__(self) -> str:
return self.ssa_name return self.ssa_name

@ -1,6 +1,6 @@
import logging import logging
import re import re
from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set, Sequence
from slither.core.declarations import ( from slither.core.declarations import (
Modifier, Modifier,
@ -9,10 +9,10 @@ from slither.core.declarations import (
StructureContract, StructureContract,
Function, Function,
) )
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract, USING_FOR_KEY
from slither.core.declarations.custom_error_contract import CustomErrorContract from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.core.solidity_types import ElementaryType, TypeAliasContract, Type from slither.core.solidity_types import ElementaryType, TypeAliasContract
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
@ -302,7 +302,7 @@ class ContractSolc(CallerContextExpression):
st.set_contract(self._contract) st.set_contract(self._contract)
st.set_offset(struct["src"], self._contract.compilation_unit) st.set_offset(struct["src"], self._contract.compilation_unit)
st_parser = StructureContractSolc(st, struct, self) st_parser = StructureContractSolc(st, struct, self) # type: ignore
self._contract.structures_as_dict[st.name] = st self._contract.structures_as_dict[st.name] = st
self._structures_parser.append(st_parser) self._structures_parser.append(st_parser)
@ -312,7 +312,7 @@ class ContractSolc(CallerContextExpression):
for struct in self._structuresNotParsed: for struct in self._structuresNotParsed:
self._parse_struct(struct) self._parse_struct(struct)
self._structuresNotParsed = None self._structuresNotParsed = []
def _parse_custom_error(self, custom_error: Dict) -> None: def _parse_custom_error(self, custom_error: Dict) -> None:
ce = CustomErrorContract(self.compilation_unit) ce = CustomErrorContract(self.compilation_unit)
@ -329,7 +329,7 @@ class ContractSolc(CallerContextExpression):
for custom_error in self._customErrorParsed: for custom_error in self._customErrorParsed:
self._parse_custom_error(custom_error) self._parse_custom_error(custom_error)
self._customErrorParsed = None self._customErrorParsed = []
def parse_state_variables(self) -> None: def parse_state_variables(self) -> None:
for father in self._contract.inheritance_reverse: for father in self._contract.inheritance_reverse:
@ -356,6 +356,7 @@ class ContractSolc(CallerContextExpression):
var_parser = StateVariableSolc(var, varNotParsed) var_parser = StateVariableSolc(var, varNotParsed)
self._variables_parser.append(var_parser) self._variables_parser.append(var_parser)
assert var.name
self._contract.variables_as_dict[var.name] = var self._contract.variables_as_dict[var.name] = var
self._contract.add_variables_ordered([var]) self._contract.add_variables_ordered([var])
@ -365,7 +366,7 @@ class ContractSolc(CallerContextExpression):
modif.set_contract(self._contract) modif.set_contract(self._contract)
modif.set_contract_declarer(self._contract) modif.set_contract_declarer(self._contract)
modif_parser = ModifierSolc(modif, modifier_data, self, self.slither_parser) modif_parser = ModifierSolc(modif, modifier_data, self, self.slither_parser) # type: ignore
self._contract.compilation_unit.add_modifier(modif) self._contract.compilation_unit.add_modifier(modif)
self._modifiers_no_params.append(modif_parser) self._modifiers_no_params.append(modif_parser)
self._modifiers_parser.append(modif_parser) self._modifiers_parser.append(modif_parser)
@ -375,7 +376,7 @@ class ContractSolc(CallerContextExpression):
def parse_modifiers(self) -> None: def parse_modifiers(self) -> None:
for modifier in self._modifiersNotParsed: for modifier in self._modifiersNotParsed:
self._parse_modifier(modifier) self._parse_modifier(modifier)
self._modifiersNotParsed = None self._modifiersNotParsed = []
def _parse_function(self, function_data: Dict) -> None: def _parse_function(self, function_data: Dict) -> None:
func = FunctionContract(self._contract.compilation_unit) func = FunctionContract(self._contract.compilation_unit)
@ -383,7 +384,7 @@ class ContractSolc(CallerContextExpression):
func.set_contract(self._contract) func.set_contract(self._contract)
func.set_contract_declarer(self._contract) func.set_contract_declarer(self._contract)
func_parser = FunctionSolc(func, function_data, self, self._slither_parser) func_parser = FunctionSolc(func, function_data, self, self._slither_parser) # type: ignore
self._contract.compilation_unit.add_function(func) self._contract.compilation_unit.add_function(func)
self._functions_no_params.append(func_parser) self._functions_no_params.append(func_parser)
self._functions_parser.append(func_parser) self._functions_parser.append(func_parser)
@ -395,7 +396,7 @@ class ContractSolc(CallerContextExpression):
for function in self._functionsNotParsed: for function in self._functionsNotParsed:
self._parse_function(function) self._parse_function(function)
self._functionsNotParsed = None self._functionsNotParsed = []
# endregion # endregion
################################################################################### ###################################################################################
@ -439,7 +440,8 @@ class ContractSolc(CallerContextExpression):
Cls_parser, Cls_parser,
self._modifiers_parser, self._modifiers_parser,
) )
self._contract.set_modifiers(modifiers) # modifiers will be using Modifier so we can ignore the next type check
self._contract.set_modifiers(modifiers) # type: ignore
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing params {e}") self.log_incorrect_parsing(f"Missing params {e}")
self._modifiers_no_params = [] self._modifiers_no_params = []
@ -459,7 +461,8 @@ class ContractSolc(CallerContextExpression):
Cls_parser, Cls_parser,
self._functions_parser, self._functions_parser,
) )
self._contract.set_functions(functions) # function will be using FunctionContract so we can ignore the next type check
self._contract.set_functions(functions) # type: ignore
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing params {e}") self.log_incorrect_parsing(f"Missing params {e}")
self._functions_no_params = [] self._functions_no_params = []
@ -470,7 +473,7 @@ class ContractSolc(CallerContextExpression):
Cls_parser: Callable, Cls_parser: Callable,
element_parser: FunctionSolc, element_parser: FunctionSolc,
explored_reference_id: Set[str], explored_reference_id: Set[str],
parser: List[FunctionSolc], parser: Union[List[FunctionSolc], List[ModifierSolc]],
all_elements: Dict[str, Function], all_elements: Dict[str, Function],
) -> None: ) -> None:
elem = Cls(self._contract.compilation_unit) elem = Cls(self._contract.compilation_unit)
@ -508,13 +511,13 @@ class ContractSolc(CallerContextExpression):
def _analyze_params_elements( # pylint: disable=too-many-arguments,too-many-locals def _analyze_params_elements( # pylint: disable=too-many-arguments,too-many-locals
self, self,
elements_no_params: List[FunctionSolc], elements_no_params: Sequence[FunctionSolc],
getter: Callable[["ContractSolc"], List[FunctionSolc]], getter: Callable[["ContractSolc"], List[FunctionSolc]],
getter_available: Callable[[Contract], List[FunctionContract]], getter_available: Callable[[Contract], List[FunctionContract]],
Cls: Callable, Cls: Callable,
Cls_parser: Callable, Cls_parser: Callable,
parser: List[FunctionSolc], parser: Union[List[FunctionSolc], List[ModifierSolc]],
) -> Dict[str, Union[FunctionContract, Modifier]]: ) -> Dict[str, Function]:
""" """
Analyze the parameters of the given elements (Function or Modifier). Analyze the parameters of the given elements (Function or Modifier).
The function iterates over the inheritance to create an instance or inherited elements (Function or Modifier) The function iterates over the inheritance to create an instance or inherited elements (Function or Modifier)
@ -526,13 +529,13 @@ class ContractSolc(CallerContextExpression):
:param Cls: Class to create for collision :param Cls: Class to create for collision
:return: :return:
""" """
all_elements = {} all_elements: Dict[str, Function] = {}
explored_reference_id = set() explored_reference_id: Set[str] = set()
try: try:
for father in self._contract.inheritance: for father in self._contract.inheritance:
father_parser = self._slither_parser.underlying_contract_to_parser[father] father_parser = self._slither_parser.underlying_contract_to_parser[father]
for element_parser in getter(father_parser): for element_parser in getter(father_parser): # type: ignore
self._analyze_params_element( self._analyze_params_element(
Cls, Cls_parser, element_parser, explored_reference_id, parser, all_elements Cls, Cls_parser, element_parser, explored_reference_id, parser, all_elements
) )
@ -597,7 +600,7 @@ class ContractSolc(CallerContextExpression):
if self.is_compact_ast: if self.is_compact_ast:
for using_for in self._usingForNotParsed: for using_for in self._usingForNotParsed:
if "typeName" in using_for and using_for["typeName"]: if "typeName" in using_for and using_for["typeName"]:
type_name = parse_type(using_for["typeName"], self) type_name: USING_FOR_KEY = parse_type(using_for["typeName"], self)
else: else:
type_name = "*" type_name = "*"
if type_name not in self._contract.using_for: if type_name not in self._contract.using_for:
@ -616,7 +619,7 @@ class ContractSolc(CallerContextExpression):
assert children and len(children) <= 2 assert children and len(children) <= 2
if len(children) == 2: if len(children) == 2:
new = parse_type(children[0], self) new = parse_type(children[0], self)
old = parse_type(children[1], self) old: USING_FOR_KEY = parse_type(children[1], self)
else: else:
new = parse_type(children[0], self) new = parse_type(children[0], self)
old = "*" old = "*"
@ -627,7 +630,7 @@ class ContractSolc(CallerContextExpression):
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing using for {e}") self.log_incorrect_parsing(f"Missing using for {e}")
def _analyze_function_list(self, function_list: List, type_name: Type) -> None: def _analyze_function_list(self, function_list: List, type_name: USING_FOR_KEY) -> None:
for f in function_list: for f in function_list:
full_name_split = f["function"]["name"].split(".") full_name_split = f["function"]["name"].split(".")
if len(full_name_split) == 1: if len(full_name_split) == 1:
@ -646,7 +649,9 @@ class ContractSolc(CallerContextExpression):
function_name = full_name_split[2] function_name = full_name_split[2]
self._analyze_library_function(library_name, function_name, type_name) self._analyze_library_function(library_name, function_name, type_name)
def _check_aliased_import(self, first_part: str, function_name: str, type_name: Type) -> None: def _check_aliased_import(
self, first_part: str, function_name: str, type_name: USING_FOR_KEY
) -> None:
# We check if the first part appear as alias for an import # We check if the first part appear as alias for an import
# if it is then function_name must be a top level function # if it is then function_name must be a top level function
# otherwise it's a library function # otherwise it's a library function
@ -656,13 +661,13 @@ class ContractSolc(CallerContextExpression):
return return
self._analyze_library_function(first_part, function_name, type_name) self._analyze_library_function(first_part, function_name, type_name)
def _analyze_top_level_function(self, function_name: str, type_name: Type) -> None: def _analyze_top_level_function(self, function_name: str, type_name: USING_FOR_KEY) -> None:
for tl_function in self.compilation_unit.functions_top_level: for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name: if tl_function.name == function_name:
self._contract.using_for[type_name].append(tl_function) self._contract.using_for[type_name].append(tl_function)
def _analyze_library_function( def _analyze_library_function(
self, library_name: str, function_name: str, type_name: Type self, library_name: str, function_name: str, type_name: USING_FOR_KEY
) -> None: ) -> None:
# Get the library function # Get the library function
found = False found = False
@ -689,22 +694,13 @@ class ContractSolc(CallerContextExpression):
# for enum, we can parse and analyze it # for enum, we can parse and analyze it
# at the same time # at the same time
self._analyze_enum(enum) self._analyze_enum(enum)
self._enumsNotParsed = None self._enumsNotParsed = []
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing enum {e}") self.log_incorrect_parsing(f"Missing enum {e}")
def _analyze_enum( def _analyze_enum(
self, self,
enum: Dict[ enum: Dict,
str,
Union[
str,
int,
List[Dict[str, Union[int, str]]],
Dict[str, str],
List[Dict[str, Union[Dict[str, str], int, str]]],
],
],
) -> None: ) -> None:
# Enum can be parsed in one pass # Enum can be parsed in one pass
if self.is_compact_ast: if self.is_compact_ast:
@ -753,13 +749,13 @@ class ContractSolc(CallerContextExpression):
event.set_contract(self._contract) event.set_contract(self._contract)
event.set_offset(event_to_parse["src"], self._contract.compilation_unit) event.set_offset(event_to_parse["src"], self._contract.compilation_unit)
event_parser = EventSolc(event, event_to_parse, self) event_parser = EventSolc(event, event_to_parse, self) # type: ignore
event_parser.analyze(self) event_parser.analyze(self) # type: ignore
self._contract.events_as_dict[event.full_name] = event self._contract.events_as_dict[event.full_name] = event
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing event {e}") self.log_incorrect_parsing(f"Missing event {e}")
self._eventsNotParsed = None self._eventsNotParsed = []
# endregion # endregion
################################################################################### ###################################################################################
@ -768,7 +764,7 @@ class ContractSolc(CallerContextExpression):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def delete_content(self): def delete_content(self) -> None:
""" """
Remove everything not parsed from the contract Remove everything not parsed from the contract
This is used only if something went wrong with the inheritance parsing This is used only if something went wrong with the inheritance parsing
@ -810,7 +806,7 @@ class ContractSolc(CallerContextExpression):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def __hash__(self): def __hash__(self) -> int:
return self._contract.id return self._contract.id
# endregion # endregion

@ -242,7 +242,7 @@ class FunctionSolc(CallerContextExpression):
if "payable" in attributes: if "payable" in attributes:
self._function.payable = attributes["payable"] self._function.payable = attributes["payable"]
def analyze_params(self): def analyze_params(self) -> None:
# Can be re-analyzed due to inheritance # Can be re-analyzed due to inheritance
if self._params_was_analyzed: if self._params_was_analyzed:
return return
@ -272,7 +272,7 @@ class FunctionSolc(CallerContextExpression):
if returns: if returns:
self._parse_returns(returns) self._parse_returns(returns)
def analyze_content(self): def analyze_content(self) -> None:
if self._content_was_analyzed: if self._content_was_analyzed:
return return
@ -308,8 +308,8 @@ class FunctionSolc(CallerContextExpression):
for node_parser in self._node_to_nodesolc.values(): for node_parser in self._node_to_nodesolc.values():
node_parser.analyze_expressions(self) node_parser.analyze_expressions(self)
for node_parser in self._node_to_yulobject.values(): for yul_parser in self._node_to_yulobject.values():
node_parser.analyze_expressions() yul_parser.analyze_expressions()
self._rewrite_ternary_as_if_else() self._rewrite_ternary_as_if_else()
@ -1297,7 +1297,7 @@ class FunctionSolc(CallerContextExpression):
son.remove_father(node) son.remove_father(node)
node.set_sons(new_sons) node.set_sons(new_sons)
def _remove_alone_endif(self): def _remove_alone_endif(self) -> None:
""" """
Can occur on: Can occur on:
if(..){ if(..){

@ -1,6 +1,6 @@
import logging import logging
import re import re
from typing import Dict, Optional from typing import Dict, Optional, Union
from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.expressions.expression_parsing import parse_expression from slither.solc_parsing.expressions.expression_parsing import parse_expression
@ -42,12 +42,12 @@ class VariableDeclarationSolc:
self._variable = variable self._variable = variable
self._was_analyzed = False self._was_analyzed = False
self._elem_to_parse = None self._elem_to_parse: Optional[Union[Dict, UnknownType]] = None
self._initializedNotParsed = None self._initializedNotParsed: Optional[Dict] = None
self._is_compact_ast = False self._is_compact_ast = False
self._reference_id = None self._reference_id: Optional[int] = None
if "nodeType" in variable_data: if "nodeType" in variable_data:
self._is_compact_ast = True self._is_compact_ast = True
@ -87,7 +87,7 @@ class VariableDeclarationSolc:
declaration = variable_data["children"][0] declaration = variable_data["children"][0]
self._init_from_declaration(declaration, init) self._init_from_declaration(declaration, init)
elif nodeType == "VariableDeclaration": elif nodeType == "VariableDeclaration":
self._init_from_declaration(variable_data, False) self._init_from_declaration(variable_data, None)
else: else:
raise ParsingError(f"Incorrect variable declaration type {nodeType}") raise ParsingError(f"Incorrect variable declaration type {nodeType}")
@ -101,6 +101,7 @@ class VariableDeclarationSolc:
Return the solc id. It can be compared with the referencedDeclaration attr Return the solc id. It can be compared with the referencedDeclaration attr
Returns None if it was not parsed (legacy AST) Returns None if it was not parsed (legacy AST)
""" """
assert self._reference_id
return self._reference_id return self._reference_id
def _handle_comment(self, attributes: Dict) -> None: def _handle_comment(self, attributes: Dict) -> None:
@ -127,7 +128,7 @@ class VariableDeclarationSolc:
self._variable.visibility = "internal" self._variable.visibility = "internal"
def _init_from_declaration( def _init_from_declaration(
self, var: Dict, init: Optional[bool] self, var: Dict, init: Optional[Dict]
) -> None: # pylint: disable=too-many-branches ) -> None: # pylint: disable=too-many-branches
if self._is_compact_ast: if self._is_compact_ast:
attributes = var attributes = var
@ -195,7 +196,7 @@ class VariableDeclarationSolc:
self._initializedNotParsed = init self._initializedNotParsed = init
elif len(var["children"]) in [0, 1]: elif len(var["children"]) in [0, 1]:
self._variable.initialized = False self._variable.initialized = False
self._initializedNotParsed = [] self._initializedNotParsed = None
else: else:
assert len(var["children"]) == 2 assert len(var["children"]) == 2
self._variable.initialized = True self._variable.initialized = True
@ -212,5 +213,6 @@ class VariableDeclarationSolc:
self._elem_to_parse = None self._elem_to_parse = None
if self._variable.initialized: if self._variable.initialized:
assert self._initializedNotParsed
self._variable.expression = parse_expression(self._initializedNotParsed, caller_context) self._variable.expression = parse_expression(self._initializedNotParsed, caller_context)
self._initializedNotParsed = None self._initializedNotParsed = None

@ -1,6 +1,6 @@
from importlib import metadata from importlib import metadata
import json import json
from typing import Optional from typing import Optional, Any
import urllib import urllib
from packaging.version import parse, Version from packaging.version import parse, Version
@ -17,6 +17,7 @@ def get_installed_version(name: str) -> Optional[Version]:
def get_github_version(name: str) -> Optional[Version]: def get_github_version(name: str) -> Optional[Version]:
try: try:
# type: ignore
with urllib.request.urlopen( with urllib.request.urlopen(
f"https://api.github.com/repos/crytic/{name}/releases/latest" f"https://api.github.com/repos/crytic/{name}/releases/latest"
) as response: ) as response:
@ -27,7 +28,7 @@ def get_github_version(name: str) -> Optional[Version]:
return None return None
def show_versions(**_kwargs) -> None: def show_versions(**_kwargs: Any) -> None:
versions = { versions = {
"Slither": (get_installed_version("slither-analyzer"), get_github_version("slither")), "Slither": (get_installed_version("slither-analyzer"), get_github_version("slither")),
"crytic-compile": ( "crytic-compile": (

@ -2,6 +2,7 @@ from typing import Union
from eth_typing.evm import ChecksumAddress from eth_typing.evm import ChecksumAddress
from eth_utils import to_int, to_text, to_checksum_address from eth_utils import to_int, to_text, to_checksum_address
from web3 import Web3
def get_offset_value(hex_bytes: bytes, offset: int, size: int) -> bytes: def get_offset_value(hex_bytes: bytes, offset: int, size: int) -> bytes:
@ -48,7 +49,7 @@ def coerce_type(
if "address" in solidity_type: if "address" in solidity_type:
if not isinstance(value, (str, bytes)): if not isinstance(value, (str, bytes)):
raise TypeError raise TypeError
return to_checksum_address(value) return to_checksum_address(value) # type: ignore
if not isinstance(value, bytes): if not isinstance(value, bytes):
raise TypeError raise TypeError
@ -56,7 +57,7 @@ def coerce_type(
def get_storage_data( def get_storage_data(
web3, checksum_address: ChecksumAddress, slot: bytes, block: Union[int, str] web3: Web3, checksum_address: ChecksumAddress, slot: bytes, block: Union[int, str]
) -> bytes: ) -> bytes:
""" """
Retrieves the storage data from the blockchain at target address and slot. Retrieves the storage data from the blockchain at target address and slot.

@ -1,7 +1,11 @@
from typing import List
from slither.tools.upgradeability.checks.abstract_checks import ( from slither.tools.upgradeability.checks.abstract_checks import (
CheckClassification, CheckClassification,
AbstractCheck, AbstractCheck,
CHECK_INFO,
) )
from slither.utils.output import Output
class VariableWithInit(AbstractCheck): class VariableWithInit(AbstractCheck):
@ -37,11 +41,11 @@ Using initialize functions to write initial values in state variables.
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
def _check(self): def _check(self) -> List[Output]:
results = [] results = []
for s in self.contract.state_variables_ordered: for s in self.contract.state_variables_ordered:
if s.initialized and not (s.is_constant or s.is_immutable): if s.initialized and not (s.is_constant or s.is_immutable):
info = [s, " is a state variable with an initial value.\n"] info: CHECK_INFO = [s, " is a state variable with an initial value.\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)
return results return results

@ -1,7 +1,12 @@
from typing import List
from slither.core.declarations import Contract
from slither.tools.upgradeability.checks.abstract_checks import ( from slither.tools.upgradeability.checks.abstract_checks import (
CheckClassification, CheckClassification,
AbstractCheck, AbstractCheck,
CHECK_INFO,
) )
from slither.utils.output import Output
class MissingVariable(AbstractCheck): class MissingVariable(AbstractCheck):
@ -45,9 +50,12 @@ Do not change the order of the state variables in the updated contract.
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
REQUIRE_CONTRACT_V2 = True REQUIRE_CONTRACT_V2 = True
def _check(self): def _check(self) -> List[Output]:
contract1 = self.contract contract1 = self.contract
contract2 = self.contract_v2 contract2 = self.contract_v2
assert contract2
order1 = [ order1 = [
variable variable
for variable in contract1.state_variables_ordered for variable in contract1.state_variables_ordered
@ -63,7 +71,7 @@ Do not change the order of the state variables in the updated contract.
for idx, _ in enumerate(order1): for idx, _ in enumerate(order1):
variable1 = order1[idx] variable1 = order1[idx]
if len(order2) <= idx: if len(order2) <= idx:
info = ["Variable missing in ", contract2, ": ", variable1, "\n"] info: CHECK_INFO = ["Variable missing in ", contract2, ": ", variable1, "\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)
@ -108,13 +116,14 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
REQUIRE_PROXY = True REQUIRE_PROXY = True
def _contract1(self): def _contract1(self) -> Contract:
return self.contract return self.contract
def _contract2(self): def _contract2(self) -> Contract:
assert self.proxy
return self.proxy return self.proxy
def _check(self): def _check(self) -> List[Output]:
contract1 = self._contract1() contract1 = self._contract1()
contract2 = self._contract2() contract2 = self._contract2()
order1 = [ order1 = [
@ -128,7 +137,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
if not (variable.is_constant or variable.is_immutable) if not (variable.is_constant or variable.is_immutable)
] ]
results = [] results: List[Output] = []
for idx, _ in enumerate(order1): for idx, _ in enumerate(order1):
if len(order2) <= idx: if len(order2) <= idx:
# Handle by MissingVariable # Handle by MissingVariable
@ -137,7 +146,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
variable1 = order1[idx] variable1 = order1[idx]
variable2 = order2[idx] variable2 = order2[idx]
if (variable1.name != variable2.name) or (variable1.type != variable2.type): if (variable1.name != variable2.name) or (variable1.type != variable2.type):
info = [ info: CHECK_INFO = [
"Different variables between ", "Different variables between ",
contract1, contract1,
" and ", " and ",
@ -190,7 +199,8 @@ Respect the variable order of the original contract in the updated contract.
REQUIRE_PROXY = False REQUIRE_PROXY = False
REQUIRE_CONTRACT_V2 = True REQUIRE_CONTRACT_V2 = True
def _contract2(self): def _contract2(self) -> Contract:
assert self.contract_v2
return self.contract_v2 return self.contract_v2
@ -235,13 +245,14 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
REQUIRE_PROXY = True REQUIRE_PROXY = True
def _contract1(self): def _contract1(self) -> Contract:
return self.contract return self.contract
def _contract2(self): def _contract2(self) -> Contract:
assert self.proxy
return self.proxy return self.proxy
def _check(self): def _check(self) -> List[Output]:
contract1 = self._contract1() contract1 = self._contract1()
contract2 = self._contract2() contract2 = self._contract2()
order1 = [ order1 = [
@ -264,7 +275,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
while idx < len(order2): while idx < len(order2):
variable2 = order2[idx] variable2 = order2[idx]
info = ["Extra variables in ", contract2, ": ", variable2, "\n"] info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)
idx = idx + 1 idx = idx + 1
@ -299,5 +310,6 @@ Ensure that all the new variables are expected.
REQUIRE_PROXY = False REQUIRE_PROXY = False
REQUIRE_CONTRACT_V2 = True REQUIRE_CONTRACT_V2 = True
def _contract2(self): def _contract2(self) -> Contract:
assert self.contract_v2
return self.contract_v2 return self.contract_v2

@ -6,7 +6,7 @@ from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from inspect import getsourcefile from inspect import getsourcefile
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Union, List, Optional from typing import Union, List, Optional, Dict, Callable
import pytest import pytest
from solc_select import solc_select from solc_select import solc_select
@ -15,6 +15,7 @@ from solc_select.solc_select import valid_version as solc_valid_version
from slither import Slither from slither import Slither
from slither.core.cfg.node import Node, NodeType from slither.core.cfg.node import Node, NodeType
from slither.core.declarations import Function, Contract from slither.core.declarations import Function, Contract
from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import ( from slither.slithir.operations import (
OperationWithLValue, OperationWithLValue,
@ -34,10 +35,11 @@ from slither.slithir.variables import (
ReferenceVariable, ReferenceVariable,
LocalIRVariable, LocalIRVariable,
StateIRVariable, StateIRVariable,
TemporaryVariableSSA,
) )
# Directory of currently executing script. Will be used as basis for temporary file names. # Directory of currently executing script. Will be used as basis for temporary file names.
SCRIPT_DIR = pathlib.Path(getsourcefile(lambda: 0)).parent SCRIPT_DIR = pathlib.Path(getsourcefile(lambda: 0)).parent # type:ignore
def valid_version(ver: str) -> bool: def valid_version(ver: str) -> bool:
@ -53,15 +55,15 @@ def valid_version(ver: str) -> bool:
return False return False
def have_ssa_if_ir(function: Function): def have_ssa_if_ir(function: Function) -> None:
"""Verifies that all nodes in a function that have IR also have SSA IR""" """Verifies that all nodes in a function that have IR also have SSA IR"""
for n in function.nodes: for n in function.nodes:
if n.irs: if n.irs:
assert n.irs_ssa assert n.irs_ssa
# pylint: disable=too-many-branches # pylint: disable=too-many-branches, too-many-locals
def ssa_basic_properties(function: Function): def ssa_basic_properties(function: Function) -> None:
"""Verifies that basic properties of ssa holds """Verifies that basic properties of ssa holds
1. Every name is defined only once 1. Every name is defined only once
@ -75,12 +77,14 @@ def ssa_basic_properties(function: Function):
""" """
ssa_lvalues = set() ssa_lvalues = set()
ssa_rvalues = set() ssa_rvalues = set()
lvalue_assignments = {} lvalue_assignments: Dict[str, int] = {}
for n in function.nodes: for n in function.nodes:
for ir in n.irs: for ir in n.irs:
if isinstance(ir, OperationWithLValue): if isinstance(ir, OperationWithLValue) and ir.lvalue:
name = ir.lvalue.name name = ir.lvalue.name
if name is None:
continue
if name in lvalue_assignments: if name in lvalue_assignments:
lvalue_assignments[name] += 1 lvalue_assignments[name] += 1
else: else:
@ -93,8 +97,9 @@ def ssa_basic_properties(function: Function):
ssa_lvalues.add(ssa.lvalue) ssa_lvalues.add(ssa.lvalue)
# 2 (if Local/State Var) # 2 (if Local/State Var)
if isinstance(ssa.lvalue, (StateIRVariable, LocalIRVariable)): ssa_lvalue = ssa.lvalue
assert ssa.lvalue.index > 0 if isinstance(ssa_lvalue, (StateIRVariable, LocalIRVariable)):
assert ssa_lvalue.index > 0
for rvalue in filter( for rvalue in filter(
lambda x: not isinstance(x, (StateIRVariable, Constant)), ssa.read lambda x: not isinstance(x, (StateIRVariable, Constant)), ssa.read
@ -111,15 +116,18 @@ def ssa_basic_properties(function: Function):
undef_vars.add(rvalue.non_ssa_version) undef_vars.add(rvalue.non_ssa_version)
# 4 # 4
ssa_defs = defaultdict(int) ssa_defs: Dict[str, int] = defaultdict(int)
for v in ssa_lvalues: for v in ssa_lvalues:
if v and v.name:
ssa_defs[v.name] += 1 ssa_defs[v.name] += 1
for (k, n) in lvalue_assignments.items(): for (k, count) in lvalue_assignments.items():
assert ssa_defs[k] >= n assert ssa_defs[k] >= count
# Helper 5/6 # Helper 5/6
def check_property_5_and_6(variables, ssavars): def check_property_5_and_6(
variables: List[LocalVariable], ssavars: List[LocalIRVariable]
) -> None:
for var in filter(lambda x: x.name, variables): for var in filter(lambda x: x.name, variables):
ssa_vars = [x for x in ssavars if x.non_ssa_version == var] ssa_vars = [x for x in ssavars if x.non_ssa_version == var]
assert len(ssa_vars) == 1 assert len(ssa_vars) == 1
@ -136,7 +144,7 @@ def ssa_basic_properties(function: Function):
check_property_5_and_6(function.returns, function.returns_ssa) check_property_5_and_6(function.returns, function.returns_ssa)
def ssa_phi_node_properties(f: Function): def ssa_phi_node_properties(f: Function) -> None:
"""Every phi-function should have as many args as predecessors """Every phi-function should have as many args as predecessors
This does not apply if the phi-node refers to state variables, This does not apply if the phi-node refers to state variables,
@ -152,7 +160,7 @@ def ssa_phi_node_properties(f: Function):
# TODO (hbrodin): This should probably go into another file, not specific to SSA # TODO (hbrodin): This should probably go into another file, not specific to SSA
def dominance_properties(f: Function): def dominance_properties(f: Function) -> None:
"""Verifies properties related to dominators holds """Verifies properties related to dominators holds
1. Every node have an immediate dominator except entry_node which have none 1. Every node have an immediate dominator except entry_node which have none
@ -180,14 +188,16 @@ def dominance_properties(f: Function):
assert find_path(node.immediate_dominator, node) assert find_path(node.immediate_dominator, node)
def phi_values_inserted(f: Function): def phi_values_inserted(f: Function) -> None:
"""Verifies that phi-values are inserted at the right places """Verifies that phi-values are inserted at the right places
For every node that has a dominance frontier, any def (including For every node that has a dominance frontier, any def (including
phi) should be a phi function in its dominance frontier phi) should be a phi function in its dominance frontier
""" """
def have_phi_for_var(node: Node, var): def have_phi_for_var(
node: Node, var: Union[StateIRVariable, LocalIRVariable, TemporaryVariableSSA]
) -> bool:
"""Checks if a node has a phi-instruction for var """Checks if a node has a phi-instruction for var
The ssa version would ideally be checked, but then The ssa version would ideally be checked, but then
@ -198,7 +208,14 @@ def phi_values_inserted(f: Function):
non_ssa = var.non_ssa_version non_ssa = var.non_ssa_version
for ssa in node.irs_ssa: for ssa in node.irs_ssa:
if isinstance(ssa, Phi): if isinstance(ssa, Phi):
if non_ssa in map(lambda ssa_var: ssa_var.non_ssa_version, ssa.read): if non_ssa in map(
lambda ssa_var: ssa_var.non_ssa_version,
[
r
for r in ssa.read
if isinstance(r, (StateIRVariable, LocalIRVariable, TemporaryVariableSSA))
],
):
return True return True
return False return False
@ -206,12 +223,15 @@ def phi_values_inserted(f: Function):
for df in node.dominance_frontier: for df in node.dominance_frontier:
for ssa in node.irs_ssa: for ssa in node.irs_ssa:
if isinstance(ssa, OperationWithLValue): if isinstance(ssa, OperationWithLValue):
if is_used_later(node, ssa.lvalue): ssa_lvalue = ssa.lvalue
assert have_phi_for_var(df, ssa.lvalue) if isinstance(
ssa_lvalue, (StateIRVariable, LocalIRVariable, TemporaryVariableSSA)
) and is_used_later(node, ssa_lvalue):
assert have_phi_for_var(df, ssa_lvalue)
@contextmanager @contextmanager
def select_solc_version(version: Optional[str]): def select_solc_version(version: Optional[str]) -> None:
"""Selects solc version to use for running tests. """Selects solc version to use for running tests.
If no version is provided, latest is used.""" If no version is provided, latest is used."""
@ -256,17 +276,17 @@ def slither_from_source(source_code: str, solc_version: Optional[str] = None):
pathlib.Path(fname).unlink() pathlib.Path(fname).unlink()
def verify_properties_hold(source_code_or_slither: Union[str, Slither]): def verify_properties_hold(source_code_or_slither: Union[str, Slither]) -> None:
"""Ensures that basic properties of SSA hold true""" """Ensures that basic properties of SSA hold true"""
def verify_func(func: Function): def verify_func(func: Function) -> None:
have_ssa_if_ir(func) have_ssa_if_ir(func)
phi_values_inserted(func) phi_values_inserted(func)
ssa_basic_properties(func) ssa_basic_properties(func)
ssa_phi_node_properties(func) ssa_phi_node_properties(func)
dominance_properties(func) dominance_properties(func)
def verify(slither): def verify(slither: Slither) -> None:
for cu in slither.compilation_units: for cu in slither.compilation_units:
for func in cu.functions_and_modifiers: for func in cu.functions_and_modifiers:
_dump_function(func) _dump_function(func)
@ -280,11 +300,12 @@ def verify_properties_hold(source_code_or_slither: Union[str, Slither]):
if isinstance(source_code_or_slither, Slither): if isinstance(source_code_or_slither, Slither):
verify(source_code_or_slither) verify(source_code_or_slither)
else: else:
slither: Slither
with slither_from_source(source_code_or_slither) as slither: with slither_from_source(source_code_or_slither) as slither:
verify(slither) verify(slither)
def _dump_function(f: Function): def _dump_function(f: Function) -> None:
"""Helper function to print nodes/ssa ir for a function or modifier""" """Helper function to print nodes/ssa ir for a function or modifier"""
print(f"---- {f.name} ----") print(f"---- {f.name} ----")
for n in f.nodes: for n in f.nodes:
@ -294,13 +315,13 @@ def _dump_function(f: Function):
print("") print("")
def _dump_functions(c: Contract): def _dump_functions(c: Contract) -> None:
"""Helper function to print functions and modifiers of a contract""" """Helper function to print functions and modifiers of a contract"""
for f in c.functions_and_modifiers: for f in c.functions_and_modifiers:
_dump_function(f) _dump_function(f)
def get_filtered_ssa(f: Union[Function, Node], flt) -> List[Operation]: def get_filtered_ssa(f: Union[Function, Node], flt: Callable) -> List[Operation]:
"""Returns a list of all ssanodes filtered by filter for all nodes in function f""" """Returns a list of all ssanodes filtered by filter for all nodes in function f"""
if isinstance(f, Function): if isinstance(f, Function):
return [ssanode for node in f.nodes for ssanode in node.irs_ssa if flt(ssanode)] return [ssanode for node in f.nodes for ssanode in node.irs_ssa if flt(ssanode)]
@ -314,7 +335,7 @@ def get_ssa_of_type(f: Union[Function, Node], ssatype) -> List[Operation]:
return get_filtered_ssa(f, lambda ssanode: isinstance(ssanode, ssatype)) return get_filtered_ssa(f, lambda ssanode: isinstance(ssanode, ssatype))
def test_multi_write(): def test_multi_write() -> None:
contract = """ contract = """
pragma solidity ^0.8.11; pragma solidity ^0.8.11;
contract Test { contract Test {
@ -327,7 +348,7 @@ def test_multi_write():
verify_properties_hold(contract) verify_properties_hold(contract)
def test_single_branch_phi(): def test_single_branch_phi() -> None:
contract = """ contract = """
pragma solidity ^0.8.11; pragma solidity ^0.8.11;
contract Test { contract Test {
@ -342,7 +363,7 @@ def test_single_branch_phi():
verify_properties_hold(contract) verify_properties_hold(contract)
def test_basic_phi(): def test_basic_phi() -> None:
contract = """ contract = """
pragma solidity ^0.8.11; pragma solidity ^0.8.11;
contract Test { contract Test {
@ -359,7 +380,7 @@ def test_basic_phi():
verify_properties_hold(contract) verify_properties_hold(contract)
def test_basic_loop_phi(): def test_basic_loop_phi() -> None:
contract = """ contract = """
pragma solidity ^0.8.11; pragma solidity ^0.8.11;
contract Test { contract Test {
@ -375,7 +396,7 @@ def test_basic_loop_phi():
@pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.")
def test_phi_propagation_loop(): def test_phi_propagation_loop() -> None:
contract = """ contract = """
pragma solidity ^0.8.11; pragma solidity ^0.8.11;
contract Test { contract Test {
@ -396,7 +417,7 @@ def test_phi_propagation_loop():
@pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.")
def test_free_function_properties(): def test_free_function_properties() -> None:
contract = """ contract = """
pragma solidity ^0.8.11; pragma solidity ^0.8.11;
@ -417,7 +438,7 @@ def test_free_function_properties():
verify_properties_hold(contract) verify_properties_hold(contract)
def test_ssa_inter_transactional(): def test_ssa_inter_transactional() -> None:
source = """ source = """
pragma solidity ^0.8.11; pragma solidity ^0.8.11;
contract A { contract A {
@ -460,7 +481,7 @@ def test_ssa_inter_transactional():
@pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.")
def test_ssa_phi_callbacks(): def test_ssa_phi_callbacks() -> None:
source = """ source = """
pragma solidity ^0.8.11; pragma solidity ^0.8.11;
contract A { contract A {
@ -519,7 +540,7 @@ def test_ssa_phi_callbacks():
@pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.")
def test_storage_refers_to(): def test_storage_refers_to() -> None:
"""Test the storage aspects of the SSA IR """Test the storage aspects of the SSA IR
When declaring a var as being storage, start tracking what storage it refers_to. When declaring a var as being storage, start tracking what storage it refers_to.

Loading…
Cancel
Save