pull/1666/head
Feist Josselin 2 years ago
parent 060f550b0d
commit bd2d572f48
  1. 6
      examples/scripts/data_dependency.py
  2. 1
      examples/scripts/variable_in_condition.py
  3. 4
      slither/__main__.py
  4. 26
      slither/analyses/data_dependency/data_dependency.py
  5. 27
      slither/core/declarations/contract.py
  6. 6
      slither/core/declarations/custom_error.py
  7. 3
      slither/core/declarations/using_for_top_level.py
  8. 5
      slither/core/variables/variable.py
  9. 2
      slither/detectors/statements/costly_operations_in_loop.py
  10. 2
      slither/detectors/statements/write_after_write.py
  11. 3
      slither/printers/call/call_graph.py
  12. 3
      slither/printers/summary/constructor_calls.py
  13. 19
      slither/printers/summary/contract.py
  14. 3
      slither/printers/summary/variable_order.py
  15. 2
      slither/slithir/utils/ssa.py
  16. 2
      slither/slithir/utils/utils.py
  17. 4
      slither/slithir/variables/local_variable.py
  18. 4
      slither/slithir/variables/state_variable.py
  19. 5
      slither/slithir/variables/variable.py
  20. 78
      slither/solc_parsing/declarations/contract.py
  21. 10
      slither/solc_parsing/declarations/function.py
  22. 16
      slither/solc_parsing/variables/variable_declaration.py
  23. 5
      slither/tools/doctor/checks/versions.py
  24. 5
      slither/tools/read_storage/utils/utils.py
  25. 8
      slither/tools/upgradeability/checks/variable_initialization.py
  26. 38
      slither/tools/upgradeability/checks/variables_order.py
  27. 93
      tests/test_ssa_generation.py

@ -18,6 +18,8 @@ assert len(contracts) == 1
contract = contracts[0]
destination = contract.get_state_variable_from_name("destination")
source = contract.get_state_variable_from_name("source")
assert source
assert destination
print(f"{source} is dependent of {destination}: {is_dependent(source, destination, contract)}")
assert not is_dependent(source, destination, contract)
@ -47,9 +49,11 @@ print(f"{destination} is tainted {is_tainted(destination, contract)}")
assert is_tainted(destination, contract)
destination_indirect_1 = contract.get_state_variable_from_name("destination_indirect_1")
assert destination_indirect_1
print(f"{destination_indirect_1} is tainted {is_tainted(destination_indirect_1, contract)}")
assert is_tainted(destination_indirect_1, contract)
destination_indirect_2 = contract.get_state_variable_from_name("destination_indirect_2")
assert destination_indirect_2
print(f"{destination_indirect_2} is tainted {is_tainted(destination_indirect_2, contract)}")
assert is_tainted(destination_indirect_2, contract)
@ -88,6 +92,8 @@ contract = contracts[0]
contract_derived = slither.get_contract_from_name("Derived")[0]
destination = contract.get_state_variable_from_name("destination")
source = contract.get_state_variable_from_name("source")
assert destination
assert source
print(f"{destination} is dependent of {source}: {is_dependent(destination, source, contract)}")
assert not is_dependent(destination, source, contract)

@ -14,6 +14,7 @@ assert len(contracts) == 1
contract = contracts[0]
# Get the variable
var_a = contract.get_state_variable_from_name("a")
assert var_a
# Get the functions reading the variable
functions_reading_a = contract.get_functions_reading_from_variable(var_a)

@ -615,7 +615,9 @@ def parse_args(
class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs
def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
detectors, _ = get_detectors_and_printers()
output_detectors(detectors)
parser.exit()

@ -2,7 +2,7 @@
Compute the data depenency between all the SSA variables
"""
from collections import defaultdict
from typing import Union, Set, Dict, TYPE_CHECKING
from typing import Union, Set, Dict, TYPE_CHECKING, List
from slither.core.cfg.node import Node
from slither.core.declarations import (
@ -20,6 +20,7 @@ from slither.core.solidity_types.type import Type
from slither.core.variables.top_level_variable import TopLevelVariable
from slither.core.variables.variable import Variable
from slither.slithir.operations import Index, OperationWithLValue, InternalCall, Operation
from slither.slithir.utils.utils import LVALUE
from slither.slithir.variables import (
Constant,
LocalIRVariable,
@ -29,6 +30,7 @@ from slither.slithir.variables import (
TemporaryVariableSSA,
TupleVariableSSA,
)
from slither.slithir.variables.variable import SlithIRVariable
if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit
@ -393,13 +395,9 @@ def transitive_close_dependencies(
while changed:
changed = False
to_add = defaultdict(set)
[ # pylint: disable=expression-not-assigned
[
for key, items in context.context[context_key].items():
for item in items & keys:
to_add[key].update(context.context[context_key][item] - {key} - items)
for item in items & keys
]
for key, items in context.context[context_key].items()
]
for k, v in to_add.items():
# Because we dont have any check on the update operation
# We might update an empty set with an empty set
@ -418,20 +416,20 @@ def add_dependency(lvalue: Variable, function: Function, ir: Operation, is_prote
function.context[KEY_SSA][lvalue] = set()
if not is_protected:
function.context[KEY_SSA_UNPROTECTED][lvalue] = set()
read: Union[List[Union[LVALUE, SolidityVariableComposed]], List[SlithIRVariable]]
if isinstance(ir, Index):
read = [ir.variable_left]
elif isinstance(ir, InternalCall):
elif isinstance(ir, InternalCall) and ir.function:
read = ir.function.return_values_ssa
else:
read = ir.read
# pylint: disable=expression-not-assigned
[function.context[KEY_SSA][lvalue].add(v) for v in read if not isinstance(v, Constant)]
for v in read:
if not isinstance(v, Constant):
function.context[KEY_SSA][lvalue].add(v)
if not is_protected:
[
for v in read:
if not isinstance(v, Constant):
function.context[KEY_SSA_UNPROTECTED][lvalue].add(v)
for v in read
if not isinstance(v, Constant)
]
def compute_dependency_function(function: Function) -> None:

@ -49,6 +49,9 @@ if TYPE_CHECKING:
LOGGER = logging.getLogger("Contract")
USING_FOR_KEY = Union[str, Type]
USING_FOR_ITEM = List[Union[Type, Function]]
class Contract(SourceMapping): # pylint: disable=too-many-public-methods
"""
@ -80,8 +83,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
self._custom_errors: Dict[str, "CustomErrorContract"] = {}
# The only str is "*"
self._using_for: Dict[Union[str, Type], List[Type]] = {}
self._using_for_complete: Optional[Dict[Union[str, Type], List[Type]]] = None
self._using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = {}
self._using_for_complete: Optional[Dict[USING_FOR_KEY, USING_FOR_ITEM]] = None
self._kind: Optional[str] = None
self._is_interface: bool = False
self._is_library: bool = False
@ -123,7 +126,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._name
@name.setter
def name(self, name: str):
def name(self, name: str) -> None:
self._name = name
@property
@ -133,7 +136,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._id
@id.setter
def id(self, new_id):
def id(self, new_id: int) -> None:
"""Unique id."""
self._id = new_id
@ -146,7 +149,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._kind
@contract_kind.setter
def contract_kind(self, kind):
def contract_kind(self, kind: str) -> None:
self._kind = kind
@property
@ -154,7 +157,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_interface
@is_interface.setter
def is_interface(self, is_interface: bool):
def is_interface(self, is_interface: bool) -> None:
self._is_interface = is_interface
@property
@ -162,7 +165,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_library
@is_library.setter
def is_library(self, is_library: bool):
def is_library(self, is_library: bool) -> None:
self._is_library = is_library
# endregion
@ -266,16 +269,18 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
###################################################################################
@property
def using_for(self) -> Dict[Union[str, Type], List[Type]]:
def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]:
return self._using_for
@property
def using_for_complete(self) -> Dict[Union[str, Type], List[Type]]:
def using_for_complete(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]:
"""
Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive
"""
def _merge_using_for(uf1: Dict, uf2: Dict) -> Dict:
def _merge_using_for(
uf1: Dict[USING_FOR_KEY, USING_FOR_ITEM], uf2: Dict[USING_FOR_KEY, USING_FOR_ITEM]
) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]:
result = {**uf1, **uf2}
for key, value in result.items():
if key in uf1 and key in uf2:
@ -1452,7 +1457,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
result = func.get_last_ssa_state_variables_instances()
for variable_name, instances in result.items():
# TODO: investigate the next operation
last_state_variables_instances[variable_name] += instances
last_state_variables_instances[variable_name] += list(instances)
for func in self.functions + list(self.modifiers):
func.fix_phi(last_state_variables_instances, initial_state_variables_instances)

@ -1,4 +1,4 @@
from typing import List, TYPE_CHECKING, Optional, Type, Union
from typing import List, TYPE_CHECKING, Optional, Type
from slither.core.solidity_types import UserDefinedType
from slither.core.source_mapping.source_mapping import SourceMapping
@ -42,7 +42,7 @@ class CustomError(SourceMapping):
###################################################################################
@staticmethod
def _convert_type_for_solidity_signature(t: Optional[Union[Type, List[Type]]]) -> str:
def _convert_type_for_solidity_signature(t: Optional[Type]) -> str:
# pylint: disable=import-outside-toplevel
from slither.core.declarations import Contract
@ -72,7 +72,7 @@ class CustomError(SourceMapping):
Returns:
"""
parameters = [x.type for x in self.parameters]
parameters = [x.type for x in self.parameters if x.type]
self._full_name = self.name + "(" + ",".join(map(str, parameters)) + ")"
solidity_parameters = map(self._convert_type_for_solidity_signature, parameters)
self._solidity_signature = self.name + "(" + ",".join(solidity_parameters) + ")"

@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, List, Dict, Union
from slither.core.declarations.contract import USING_FOR_KEY, USING_FOR_ITEM
from slither.core.solidity_types.type import Type
from slither.core.declarations.top_level import TopLevel
@ -14,5 +15,5 @@ class UsingForTopLevel(TopLevel):
self.file_scope: "FileScope" = scope
@property
def using_for(self) -> Dict[Union[str, Type], List[Type]]:
def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]:
return self._using_for

@ -160,8 +160,8 @@ class Variable(SourceMapping):
return (
self.name,
[str(x) for x in export_nested_types_from_variable(self)],
[str(x) for x in export_return_type_from_variable(self)],
[str(x) for x in export_nested_types_from_variable(self)], # type: ignore
[str(x) for x in export_return_type_from_variable(self)], # type: ignore
)
@property
@ -179,4 +179,5 @@ class Variable(SourceMapping):
return f'{name}({",".join(parameters)})'
def __str__(self) -> str:
assert self._name
return self._name

@ -43,7 +43,7 @@ def costly_operations_in_loop(
if isinstance(ir, OperationWithLValue) and isinstance(ir.lvalue, StateVariable):
ret.append(ir.node)
break
if isinstance(ir, (InternalCall)):
if isinstance(ir, (InternalCall)) and ir.function:
costly_operations_in_loop(ir.function.entry_point, in_loop_counter, visited, ret)
for son in node.sons:

@ -37,6 +37,8 @@ def _handle_ir(
_remove_states(written)
if isinstance(ir, InternalCall):
if not ir.function:
return
if ir.function.all_high_level_calls() or ir.function.all_library_calls():
_remove_states(written)

@ -13,6 +13,7 @@ from slither.core.declarations.function import Function
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.variables.variable import Variable
from slither.printers.abstract_printer import AbstractPrinter
from slither.utils.output import Output
def _contract_subgraph(contract: Contract) -> str:
@ -222,7 +223,7 @@ class PrinterCallGraph(AbstractPrinter):
WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#call-graph"
def output(self, filename):
def output(self, filename: str) -> Output:
"""
Output the graph in filename
Args:

@ -5,6 +5,7 @@ from slither.core.declarations import Function
from slither.core.source_mapping.source_mapping import Source
from slither.printers.abstract_printer import AbstractPrinter
from slither.utils import output
from slither.utils.output import Output
def _get_source_code(cst: Function) -> str:
@ -17,7 +18,7 @@ class ConstructorPrinter(AbstractPrinter):
ARGUMENT = "constructor-calls"
HELP = "Print the constructors executed"
def output(self, _filename):
def output(self, _filename: str) -> Output:
info = ""
for contract in self.slither.contracts_derived:
stack_name = []

@ -2,9 +2,13 @@
Module printing summary of the contract
"""
import collections
from typing import Dict, List
from slither.core.declarations import FunctionContract
from slither.printers.abstract_printer import AbstractPrinter
from slither.utils import output
from slither.utils.colors import blue, green, magenta
from slither.utils.output import Output
class ContractSummary(AbstractPrinter):
@ -13,7 +17,7 @@ class ContractSummary(AbstractPrinter):
WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#contract-summary"
def output(self, _filename): # pylint: disable=too-many-locals
def output(self, _filename: str) -> Output: # pylint: disable=too-many-locals
"""
_filename is not used
Args:
@ -53,17 +57,16 @@ class ContractSummary(AbstractPrinter):
# Order the function with
# contract_declarer -> list_functions
public = [
public_function = [
(f.contract_declarer.name, f)
for f in c.functions
if (not f.is_shadowed and not f.is_constructor_variables)
]
collect = collections.defaultdict(list)
for a, b in public:
collect: Dict[str, List[FunctionContract]] = collections.defaultdict(list)
for a, b in public_function:
collect[a].append(b)
public = list(collect.items())
for contract, functions in public:
for contract, functions in collect.items():
txt += blue(f" - From {contract}\n")
functions = sorted(functions, key=lambda f: f.full_name)
@ -90,7 +93,7 @@ class ContractSummary(AbstractPrinter):
self.info(txt)
res = self.generate_output(txt)
for contract, additional_fields in all_contracts:
res.add(contract, additional_fields=additional_fields)
for current_contract, current_additional_fields in all_contracts:
res.add(current_contract, additional_fields=current_additional_fields)
return res

@ -4,6 +4,7 @@
from slither.printers.abstract_printer import AbstractPrinter
from slither.utils.myprettytable import MyPrettyTable
from slither.utils.output import Output
class VariableOrder(AbstractPrinter):
@ -13,7 +14,7 @@ class VariableOrder(AbstractPrinter):
WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#variable-order"
def output(self, _filename):
def output(self, _filename: str) -> Output:
"""
_filename is not used
Args:

@ -366,7 +366,7 @@ def last_name(
def is_used_later(
initial_node: Node,
variable: Union[StateIRVariable, LocalVariable],
variable: Union[StateIRVariable, LocalVariable, TemporaryVariableSSA],
) -> bool:
# TODO: does not handle the case where its read and written in the declaration node
# It can be problematic if this happens in a loop/if structure

@ -46,7 +46,7 @@ def is_valid_rvalue(v: SourceMapping) -> bool:
)
def is_valid_lvalue(v) -> bool:
def is_valid_lvalue(v: SourceMapping) -> bool:
return isinstance(
v,
(

@ -41,11 +41,11 @@ class LocalIRVariable(
self._non_ssa_version = local_variable
@property
def index(self):
def index(self) -> int:
return self._index
@index.setter
def index(self, idx):
def index(self, idx: int) -> None:
self._index = idx
@property

@ -30,11 +30,11 @@ class StateIRVariable(
self._non_ssa_version = state_variable
@property
def index(self):
def index(self) -> int:
return self._index
@index.setter
def index(self, idx):
def index(self, idx: int) -> None:
self._index = idx
@property

@ -7,8 +7,9 @@ class SlithIRVariable(Variable):
self._index = 0
@property
def ssa_name(self):
def ssa_name(self) -> str:
assert self.name
return self.name
def __str__(self):
def __str__(self) -> str:
return self.ssa_name

@ -1,6 +1,6 @@
import logging
import re
from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set
from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set, Sequence
from slither.core.declarations import (
Modifier,
@ -9,10 +9,10 @@ from slither.core.declarations import (
StructureContract,
Function,
)
from slither.core.declarations.contract import Contract
from slither.core.declarations.contract import Contract, USING_FOR_KEY
from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.declarations.function_contract import FunctionContract
from slither.core.solidity_types import ElementaryType, TypeAliasContract, Type
from slither.core.solidity_types import ElementaryType, TypeAliasContract
from slither.core.variables.state_variable import StateVariable
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
@ -302,7 +302,7 @@ class ContractSolc(CallerContextExpression):
st.set_contract(self._contract)
st.set_offset(struct["src"], self._contract.compilation_unit)
st_parser = StructureContractSolc(st, struct, self)
st_parser = StructureContractSolc(st, struct, self) # type: ignore
self._contract.structures_as_dict[st.name] = st
self._structures_parser.append(st_parser)
@ -312,7 +312,7 @@ class ContractSolc(CallerContextExpression):
for struct in self._structuresNotParsed:
self._parse_struct(struct)
self._structuresNotParsed = None
self._structuresNotParsed = []
def _parse_custom_error(self, custom_error: Dict) -> None:
ce = CustomErrorContract(self.compilation_unit)
@ -329,7 +329,7 @@ class ContractSolc(CallerContextExpression):
for custom_error in self._customErrorParsed:
self._parse_custom_error(custom_error)
self._customErrorParsed = None
self._customErrorParsed = []
def parse_state_variables(self) -> None:
for father in self._contract.inheritance_reverse:
@ -356,6 +356,7 @@ class ContractSolc(CallerContextExpression):
var_parser = StateVariableSolc(var, varNotParsed)
self._variables_parser.append(var_parser)
assert var.name
self._contract.variables_as_dict[var.name] = var
self._contract.add_variables_ordered([var])
@ -365,7 +366,7 @@ class ContractSolc(CallerContextExpression):
modif.set_contract(self._contract)
modif.set_contract_declarer(self._contract)
modif_parser = ModifierSolc(modif, modifier_data, self, self.slither_parser)
modif_parser = ModifierSolc(modif, modifier_data, self, self.slither_parser) # type: ignore
self._contract.compilation_unit.add_modifier(modif)
self._modifiers_no_params.append(modif_parser)
self._modifiers_parser.append(modif_parser)
@ -375,7 +376,7 @@ class ContractSolc(CallerContextExpression):
def parse_modifiers(self) -> None:
for modifier in self._modifiersNotParsed:
self._parse_modifier(modifier)
self._modifiersNotParsed = None
self._modifiersNotParsed = []
def _parse_function(self, function_data: Dict) -> None:
func = FunctionContract(self._contract.compilation_unit)
@ -383,7 +384,7 @@ class ContractSolc(CallerContextExpression):
func.set_contract(self._contract)
func.set_contract_declarer(self._contract)
func_parser = FunctionSolc(func, function_data, self, self._slither_parser)
func_parser = FunctionSolc(func, function_data, self, self._slither_parser) # type: ignore
self._contract.compilation_unit.add_function(func)
self._functions_no_params.append(func_parser)
self._functions_parser.append(func_parser)
@ -395,7 +396,7 @@ class ContractSolc(CallerContextExpression):
for function in self._functionsNotParsed:
self._parse_function(function)
self._functionsNotParsed = None
self._functionsNotParsed = []
# endregion
###################################################################################
@ -439,7 +440,8 @@ class ContractSolc(CallerContextExpression):
Cls_parser,
self._modifiers_parser,
)
self._contract.set_modifiers(modifiers)
# modifiers will be using Modifier so we can ignore the next type check
self._contract.set_modifiers(modifiers) # type: ignore
except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing params {e}")
self._modifiers_no_params = []
@ -459,7 +461,8 @@ class ContractSolc(CallerContextExpression):
Cls_parser,
self._functions_parser,
)
self._contract.set_functions(functions)
# function will be using FunctionContract so we can ignore the next type check
self._contract.set_functions(functions) # type: ignore
except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing params {e}")
self._functions_no_params = []
@ -470,7 +473,7 @@ class ContractSolc(CallerContextExpression):
Cls_parser: Callable,
element_parser: FunctionSolc,
explored_reference_id: Set[str],
parser: List[FunctionSolc],
parser: Union[List[FunctionSolc], List[ModifierSolc]],
all_elements: Dict[str, Function],
) -> None:
elem = Cls(self._contract.compilation_unit)
@ -508,13 +511,13 @@ class ContractSolc(CallerContextExpression):
def _analyze_params_elements( # pylint: disable=too-many-arguments,too-many-locals
self,
elements_no_params: List[FunctionSolc],
elements_no_params: Sequence[FunctionSolc],
getter: Callable[["ContractSolc"], List[FunctionSolc]],
getter_available: Callable[[Contract], List[FunctionContract]],
Cls: Callable,
Cls_parser: Callable,
parser: List[FunctionSolc],
) -> Dict[str, Union[FunctionContract, Modifier]]:
parser: Union[List[FunctionSolc], List[ModifierSolc]],
) -> Dict[str, Function]:
"""
Analyze the parameters of the given elements (Function or Modifier).
The function iterates over the inheritance to create an instance or inherited elements (Function or Modifier)
@ -526,13 +529,13 @@ class ContractSolc(CallerContextExpression):
:param Cls: Class to create for collision
:return:
"""
all_elements = {}
all_elements: Dict[str, Function] = {}
explored_reference_id = set()
explored_reference_id: Set[str] = set()
try:
for father in self._contract.inheritance:
father_parser = self._slither_parser.underlying_contract_to_parser[father]
for element_parser in getter(father_parser):
for element_parser in getter(father_parser): # type: ignore
self._analyze_params_element(
Cls, Cls_parser, element_parser, explored_reference_id, parser, all_elements
)
@ -597,7 +600,7 @@ class ContractSolc(CallerContextExpression):
if self.is_compact_ast:
for using_for in self._usingForNotParsed:
if "typeName" in using_for and using_for["typeName"]:
type_name = parse_type(using_for["typeName"], self)
type_name: USING_FOR_KEY = parse_type(using_for["typeName"], self)
else:
type_name = "*"
if type_name not in self._contract.using_for:
@ -616,7 +619,7 @@ class ContractSolc(CallerContextExpression):
assert children and len(children) <= 2
if len(children) == 2:
new = parse_type(children[0], self)
old = parse_type(children[1], self)
old: USING_FOR_KEY = parse_type(children[1], self)
else:
new = parse_type(children[0], self)
old = "*"
@ -627,7 +630,7 @@ class ContractSolc(CallerContextExpression):
except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing using for {e}")
def _analyze_function_list(self, function_list: List, type_name: Type) -> None:
def _analyze_function_list(self, function_list: List, type_name: USING_FOR_KEY) -> None:
for f in function_list:
full_name_split = f["function"]["name"].split(".")
if len(full_name_split) == 1:
@ -646,7 +649,9 @@ class ContractSolc(CallerContextExpression):
function_name = full_name_split[2]
self._analyze_library_function(library_name, function_name, type_name)
def _check_aliased_import(self, first_part: str, function_name: str, type_name: Type) -> None:
def _check_aliased_import(
self, first_part: str, function_name: str, type_name: USING_FOR_KEY
) -> None:
# We check if the first part appear as alias for an import
# if it is then function_name must be a top level function
# otherwise it's a library function
@ -656,13 +661,13 @@ class ContractSolc(CallerContextExpression):
return
self._analyze_library_function(first_part, function_name, type_name)
def _analyze_top_level_function(self, function_name: str, type_name: Type) -> None:
def _analyze_top_level_function(self, function_name: str, type_name: USING_FOR_KEY) -> None:
for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name:
self._contract.using_for[type_name].append(tl_function)
def _analyze_library_function(
self, library_name: str, function_name: str, type_name: Type
self, library_name: str, function_name: str, type_name: USING_FOR_KEY
) -> None:
# Get the library function
found = False
@ -689,22 +694,13 @@ class ContractSolc(CallerContextExpression):
# for enum, we can parse and analyze it
# at the same time
self._analyze_enum(enum)
self._enumsNotParsed = None
self._enumsNotParsed = []
except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing enum {e}")
def _analyze_enum(
self,
enum: Dict[
str,
Union[
str,
int,
List[Dict[str, Union[int, str]]],
Dict[str, str],
List[Dict[str, Union[Dict[str, str], int, str]]],
],
],
enum: Dict,
) -> None:
# Enum can be parsed in one pass
if self.is_compact_ast:
@ -753,13 +749,13 @@ class ContractSolc(CallerContextExpression):
event.set_contract(self._contract)
event.set_offset(event_to_parse["src"], self._contract.compilation_unit)
event_parser = EventSolc(event, event_to_parse, self)
event_parser.analyze(self)
event_parser = EventSolc(event, event_to_parse, self) # type: ignore
event_parser.analyze(self) # type: ignore
self._contract.events_as_dict[event.full_name] = event
except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing event {e}")
self._eventsNotParsed = None
self._eventsNotParsed = []
# endregion
###################################################################################
@ -768,7 +764,7 @@ class ContractSolc(CallerContextExpression):
###################################################################################
###################################################################################
def delete_content(self):
def delete_content(self) -> None:
"""
Remove everything not parsed from the contract
This is used only if something went wrong with the inheritance parsing
@ -810,7 +806,7 @@ class ContractSolc(CallerContextExpression):
###################################################################################
###################################################################################
def __hash__(self):
def __hash__(self) -> int:
return self._contract.id
# endregion

@ -242,7 +242,7 @@ class FunctionSolc(CallerContextExpression):
if "payable" in attributes:
self._function.payable = attributes["payable"]
def analyze_params(self):
def analyze_params(self) -> None:
# Can be re-analyzed due to inheritance
if self._params_was_analyzed:
return
@ -272,7 +272,7 @@ class FunctionSolc(CallerContextExpression):
if returns:
self._parse_returns(returns)
def analyze_content(self):
def analyze_content(self) -> None:
if self._content_was_analyzed:
return
@ -308,8 +308,8 @@ class FunctionSolc(CallerContextExpression):
for node_parser in self._node_to_nodesolc.values():
node_parser.analyze_expressions(self)
for node_parser in self._node_to_yulobject.values():
node_parser.analyze_expressions()
for yul_parser in self._node_to_yulobject.values():
yul_parser.analyze_expressions()
self._rewrite_ternary_as_if_else()
@ -1297,7 +1297,7 @@ class FunctionSolc(CallerContextExpression):
son.remove_father(node)
node.set_sons(new_sons)
def _remove_alone_endif(self):
def _remove_alone_endif(self) -> None:
"""
Can occur on:
if(..){

@ -1,6 +1,6 @@
import logging
import re
from typing import Dict, Optional
from typing import Dict, Optional, Union
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.expressions.expression_parsing import parse_expression
@ -42,12 +42,12 @@ class VariableDeclarationSolc:
self._variable = variable
self._was_analyzed = False
self._elem_to_parse = None
self._initializedNotParsed = None
self._elem_to_parse: Optional[Union[Dict, UnknownType]] = None
self._initializedNotParsed: Optional[Dict] = None
self._is_compact_ast = False
self._reference_id = None
self._reference_id: Optional[int] = None
if "nodeType" in variable_data:
self._is_compact_ast = True
@ -87,7 +87,7 @@ class VariableDeclarationSolc:
declaration = variable_data["children"][0]
self._init_from_declaration(declaration, init)
elif nodeType == "VariableDeclaration":
self._init_from_declaration(variable_data, False)
self._init_from_declaration(variable_data, None)
else:
raise ParsingError(f"Incorrect variable declaration type {nodeType}")
@ -101,6 +101,7 @@ class VariableDeclarationSolc:
Return the solc id. It can be compared with the referencedDeclaration attr
Returns None if it was not parsed (legacy AST)
"""
assert self._reference_id
return self._reference_id
def _handle_comment(self, attributes: Dict) -> None:
@ -127,7 +128,7 @@ class VariableDeclarationSolc:
self._variable.visibility = "internal"
def _init_from_declaration(
self, var: Dict, init: Optional[bool]
self, var: Dict, init: Optional[Dict]
) -> None: # pylint: disable=too-many-branches
if self._is_compact_ast:
attributes = var
@ -195,7 +196,7 @@ class VariableDeclarationSolc:
self._initializedNotParsed = init
elif len(var["children"]) in [0, 1]:
self._variable.initialized = False
self._initializedNotParsed = []
self._initializedNotParsed = None
else:
assert len(var["children"]) == 2
self._variable.initialized = True
@ -212,5 +213,6 @@ class VariableDeclarationSolc:
self._elem_to_parse = None
if self._variable.initialized:
assert self._initializedNotParsed
self._variable.expression = parse_expression(self._initializedNotParsed, caller_context)
self._initializedNotParsed = None

@ -1,6 +1,6 @@
from importlib import metadata
import json
from typing import Optional
from typing import Optional, Any
import urllib
from packaging.version import parse, Version
@ -17,6 +17,7 @@ def get_installed_version(name: str) -> Optional[Version]:
def get_github_version(name: str) -> Optional[Version]:
try:
# type: ignore
with urllib.request.urlopen(
f"https://api.github.com/repos/crytic/{name}/releases/latest"
) as response:
@ -27,7 +28,7 @@ def get_github_version(name: str) -> Optional[Version]:
return None
def show_versions(**_kwargs) -> None:
def show_versions(**_kwargs: Any) -> None:
versions = {
"Slither": (get_installed_version("slither-analyzer"), get_github_version("slither")),
"crytic-compile": (

@ -2,6 +2,7 @@ from typing import Union
from eth_typing.evm import ChecksumAddress
from eth_utils import to_int, to_text, to_checksum_address
from web3 import Web3
def get_offset_value(hex_bytes: bytes, offset: int, size: int) -> bytes:
@ -48,7 +49,7 @@ def coerce_type(
if "address" in solidity_type:
if not isinstance(value, (str, bytes)):
raise TypeError
return to_checksum_address(value)
return to_checksum_address(value) # type: ignore
if not isinstance(value, bytes):
raise TypeError
@ -56,7 +57,7 @@ def coerce_type(
def get_storage_data(
web3, checksum_address: ChecksumAddress, slot: bytes, block: Union[int, str]
web3: Web3, checksum_address: ChecksumAddress, slot: bytes, block: Union[int, str]
) -> bytes:
"""
Retrieves the storage data from the blockchain at target address and slot.

@ -1,7 +1,11 @@
from typing import List
from slither.tools.upgradeability.checks.abstract_checks import (
CheckClassification,
AbstractCheck,
CHECK_INFO,
)
from slither.utils.output import Output
class VariableWithInit(AbstractCheck):
@ -37,11 +41,11 @@ Using initialize functions to write initial values in state variables.
REQUIRE_CONTRACT = True
def _check(self):
def _check(self) -> List[Output]:
results = []
for s in self.contract.state_variables_ordered:
if s.initialized and not (s.is_constant or s.is_immutable):
info = [s, " is a state variable with an initial value.\n"]
info: CHECK_INFO = [s, " is a state variable with an initial value.\n"]
json = self.generate_result(info)
results.append(json)
return results

@ -1,7 +1,12 @@
from typing import List
from slither.core.declarations import Contract
from slither.tools.upgradeability.checks.abstract_checks import (
CheckClassification,
AbstractCheck,
CHECK_INFO,
)
from slither.utils.output import Output
class MissingVariable(AbstractCheck):
@ -45,9 +50,12 @@ Do not change the order of the state variables in the updated contract.
REQUIRE_CONTRACT = True
REQUIRE_CONTRACT_V2 = True
def _check(self):
def _check(self) -> List[Output]:
contract1 = self.contract
contract2 = self.contract_v2
assert contract2
order1 = [
variable
for variable in contract1.state_variables_ordered
@ -63,7 +71,7 @@ Do not change the order of the state variables in the updated contract.
for idx, _ in enumerate(order1):
variable1 = order1[idx]
if len(order2) <= idx:
info = ["Variable missing in ", contract2, ": ", variable1, "\n"]
info: CHECK_INFO = ["Variable missing in ", contract2, ": ", variable1, "\n"]
json = self.generate_result(info)
results.append(json)
@ -108,13 +116,14 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
REQUIRE_CONTRACT = True
REQUIRE_PROXY = True
def _contract1(self):
def _contract1(self) -> Contract:
return self.contract
def _contract2(self):
def _contract2(self) -> Contract:
assert self.proxy
return self.proxy
def _check(self):
def _check(self) -> List[Output]:
contract1 = self._contract1()
contract2 = self._contract2()
order1 = [
@ -128,7 +137,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
if not (variable.is_constant or variable.is_immutable)
]
results = []
results: List[Output] = []
for idx, _ in enumerate(order1):
if len(order2) <= idx:
# Handle by MissingVariable
@ -137,7 +146,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
variable1 = order1[idx]
variable2 = order2[idx]
if (variable1.name != variable2.name) or (variable1.type != variable2.type):
info = [
info: CHECK_INFO = [
"Different variables between ",
contract1,
" and ",
@ -190,7 +199,8 @@ Respect the variable order of the original contract in the updated contract.
REQUIRE_PROXY = False
REQUIRE_CONTRACT_V2 = True
def _contract2(self):
def _contract2(self) -> Contract:
assert self.contract_v2
return self.contract_v2
@ -235,13 +245,14 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
REQUIRE_CONTRACT = True
REQUIRE_PROXY = True
def _contract1(self):
def _contract1(self) -> Contract:
return self.contract
def _contract2(self):
def _contract2(self) -> Contract:
assert self.proxy
return self.proxy
def _check(self):
def _check(self) -> List[Output]:
contract1 = self._contract1()
contract2 = self._contract2()
order1 = [
@ -264,7 +275,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
while idx < len(order2):
variable2 = order2[idx]
info = ["Extra variables in ", contract2, ": ", variable2, "\n"]
info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)
idx = idx + 1
@ -299,5 +310,6 @@ Ensure that all the new variables are expected.
REQUIRE_PROXY = False
REQUIRE_CONTRACT_V2 = True
def _contract2(self):
def _contract2(self) -> Contract:
assert self.contract_v2
return self.contract_v2

@ -6,7 +6,7 @@ from collections import defaultdict
from contextlib import contextmanager
from inspect import getsourcefile
from tempfile import NamedTemporaryFile
from typing import Union, List, Optional
from typing import Union, List, Optional, Dict, Callable
import pytest
from solc_select import solc_select
@ -15,6 +15,7 @@ from solc_select.solc_select import valid_version as solc_valid_version
from slither import Slither
from slither.core.cfg.node import Node, NodeType
from slither.core.declarations import Function, Contract
from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import (
OperationWithLValue,
@ -34,10 +35,11 @@ from slither.slithir.variables import (
ReferenceVariable,
LocalIRVariable,
StateIRVariable,
TemporaryVariableSSA,
)
# Directory of currently executing script. Will be used as basis for temporary file names.
SCRIPT_DIR = pathlib.Path(getsourcefile(lambda: 0)).parent
SCRIPT_DIR = pathlib.Path(getsourcefile(lambda: 0)).parent # type:ignore
def valid_version(ver: str) -> bool:
@ -53,15 +55,15 @@ def valid_version(ver: str) -> bool:
return False
def have_ssa_if_ir(function: Function):
def have_ssa_if_ir(function: Function) -> None:
"""Verifies that all nodes in a function that have IR also have SSA IR"""
for n in function.nodes:
if n.irs:
assert n.irs_ssa
# pylint: disable=too-many-branches
def ssa_basic_properties(function: Function):
# pylint: disable=too-many-branches, too-many-locals
def ssa_basic_properties(function: Function) -> None:
"""Verifies that basic properties of ssa holds
1. Every name is defined only once
@ -75,12 +77,14 @@ def ssa_basic_properties(function: Function):
"""
ssa_lvalues = set()
ssa_rvalues = set()
lvalue_assignments = {}
lvalue_assignments: Dict[str, int] = {}
for n in function.nodes:
for ir in n.irs:
if isinstance(ir, OperationWithLValue):
if isinstance(ir, OperationWithLValue) and ir.lvalue:
name = ir.lvalue.name
if name is None:
continue
if name in lvalue_assignments:
lvalue_assignments[name] += 1
else:
@ -93,8 +97,9 @@ def ssa_basic_properties(function: Function):
ssa_lvalues.add(ssa.lvalue)
# 2 (if Local/State Var)
if isinstance(ssa.lvalue, (StateIRVariable, LocalIRVariable)):
assert ssa.lvalue.index > 0
ssa_lvalue = ssa.lvalue
if isinstance(ssa_lvalue, (StateIRVariable, LocalIRVariable)):
assert ssa_lvalue.index > 0
for rvalue in filter(
lambda x: not isinstance(x, (StateIRVariable, Constant)), ssa.read
@ -111,15 +116,18 @@ def ssa_basic_properties(function: Function):
undef_vars.add(rvalue.non_ssa_version)
# 4
ssa_defs = defaultdict(int)
ssa_defs: Dict[str, int] = defaultdict(int)
for v in ssa_lvalues:
if v and v.name:
ssa_defs[v.name] += 1
for (k, n) in lvalue_assignments.items():
assert ssa_defs[k] >= n
for (k, count) in lvalue_assignments.items():
assert ssa_defs[k] >= count
# Helper 5/6
def check_property_5_and_6(variables, ssavars):
def check_property_5_and_6(
variables: List[LocalVariable], ssavars: List[LocalIRVariable]
) -> None:
for var in filter(lambda x: x.name, variables):
ssa_vars = [x for x in ssavars if x.non_ssa_version == var]
assert len(ssa_vars) == 1
@ -136,7 +144,7 @@ def ssa_basic_properties(function: Function):
check_property_5_and_6(function.returns, function.returns_ssa)
def ssa_phi_node_properties(f: Function):
def ssa_phi_node_properties(f: Function) -> None:
"""Every phi-function should have as many args as predecessors
This does not apply if the phi-node refers to state variables,
@ -152,7 +160,7 @@ def ssa_phi_node_properties(f: Function):
# TODO (hbrodin): This should probably go into another file, not specific to SSA
def dominance_properties(f: Function):
def dominance_properties(f: Function) -> None:
"""Verifies properties related to dominators holds
1. Every node have an immediate dominator except entry_node which have none
@ -180,14 +188,16 @@ def dominance_properties(f: Function):
assert find_path(node.immediate_dominator, node)
def phi_values_inserted(f: Function):
def phi_values_inserted(f: Function) -> None:
"""Verifies that phi-values are inserted at the right places
For every node that has a dominance frontier, any def (including
phi) should be a phi function in its dominance frontier
"""
def have_phi_for_var(node: Node, var):
def have_phi_for_var(
node: Node, var: Union[StateIRVariable, LocalIRVariable, TemporaryVariableSSA]
) -> bool:
"""Checks if a node has a phi-instruction for var
The ssa version would ideally be checked, but then
@ -198,7 +208,14 @@ def phi_values_inserted(f: Function):
non_ssa = var.non_ssa_version
for ssa in node.irs_ssa:
if isinstance(ssa, Phi):
if non_ssa in map(lambda ssa_var: ssa_var.non_ssa_version, ssa.read):
if non_ssa in map(
lambda ssa_var: ssa_var.non_ssa_version,
[
r
for r in ssa.read
if isinstance(r, (StateIRVariable, LocalIRVariable, TemporaryVariableSSA))
],
):
return True
return False
@ -206,12 +223,15 @@ def phi_values_inserted(f: Function):
for df in node.dominance_frontier:
for ssa in node.irs_ssa:
if isinstance(ssa, OperationWithLValue):
if is_used_later(node, ssa.lvalue):
assert have_phi_for_var(df, ssa.lvalue)
ssa_lvalue = ssa.lvalue
if isinstance(
ssa_lvalue, (StateIRVariable, LocalIRVariable, TemporaryVariableSSA)
) and is_used_later(node, ssa_lvalue):
assert have_phi_for_var(df, ssa_lvalue)
@contextmanager
def select_solc_version(version: Optional[str]):
def select_solc_version(version: Optional[str]) -> None:
"""Selects solc version to use for running tests.
If no version is provided, latest is used."""
@ -256,17 +276,17 @@ def slither_from_source(source_code: str, solc_version: Optional[str] = None):
pathlib.Path(fname).unlink()
def verify_properties_hold(source_code_or_slither: Union[str, Slither]):
def verify_properties_hold(source_code_or_slither: Union[str, Slither]) -> None:
"""Ensures that basic properties of SSA hold true"""
def verify_func(func: Function):
def verify_func(func: Function) -> None:
have_ssa_if_ir(func)
phi_values_inserted(func)
ssa_basic_properties(func)
ssa_phi_node_properties(func)
dominance_properties(func)
def verify(slither):
def verify(slither: Slither) -> None:
for cu in slither.compilation_units:
for func in cu.functions_and_modifiers:
_dump_function(func)
@ -280,11 +300,12 @@ def verify_properties_hold(source_code_or_slither: Union[str, Slither]):
if isinstance(source_code_or_slither, Slither):
verify(source_code_or_slither)
else:
slither: Slither
with slither_from_source(source_code_or_slither) as slither:
verify(slither)
def _dump_function(f: Function):
def _dump_function(f: Function) -> None:
"""Helper function to print nodes/ssa ir for a function or modifier"""
print(f"---- {f.name} ----")
for n in f.nodes:
@ -294,13 +315,13 @@ def _dump_function(f: Function):
print("")
def _dump_functions(c: Contract):
def _dump_functions(c: Contract) -> None:
"""Helper function to print functions and modifiers of a contract"""
for f in c.functions_and_modifiers:
_dump_function(f)
def get_filtered_ssa(f: Union[Function, Node], flt) -> List[Operation]:
def get_filtered_ssa(f: Union[Function, Node], flt: Callable) -> List[Operation]:
"""Returns a list of all ssanodes filtered by filter for all nodes in function f"""
if isinstance(f, Function):
return [ssanode for node in f.nodes for ssanode in node.irs_ssa if flt(ssanode)]
@ -314,7 +335,7 @@ def get_ssa_of_type(f: Union[Function, Node], ssatype) -> List[Operation]:
return get_filtered_ssa(f, lambda ssanode: isinstance(ssanode, ssatype))
def test_multi_write():
def test_multi_write() -> None:
contract = """
pragma solidity ^0.8.11;
contract Test {
@ -327,7 +348,7 @@ def test_multi_write():
verify_properties_hold(contract)
def test_single_branch_phi():
def test_single_branch_phi() -> None:
contract = """
pragma solidity ^0.8.11;
contract Test {
@ -342,7 +363,7 @@ def test_single_branch_phi():
verify_properties_hold(contract)
def test_basic_phi():
def test_basic_phi() -> None:
contract = """
pragma solidity ^0.8.11;
contract Test {
@ -359,7 +380,7 @@ def test_basic_phi():
verify_properties_hold(contract)
def test_basic_loop_phi():
def test_basic_loop_phi() -> None:
contract = """
pragma solidity ^0.8.11;
contract Test {
@ -375,7 +396,7 @@ def test_basic_loop_phi():
@pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.")
def test_phi_propagation_loop():
def test_phi_propagation_loop() -> None:
contract = """
pragma solidity ^0.8.11;
contract Test {
@ -396,7 +417,7 @@ def test_phi_propagation_loop():
@pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.")
def test_free_function_properties():
def test_free_function_properties() -> None:
contract = """
pragma solidity ^0.8.11;
@ -417,7 +438,7 @@ def test_free_function_properties():
verify_properties_hold(contract)
def test_ssa_inter_transactional():
def test_ssa_inter_transactional() -> None:
source = """
pragma solidity ^0.8.11;
contract A {
@ -460,7 +481,7 @@ def test_ssa_inter_transactional():
@pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.")
def test_ssa_phi_callbacks():
def test_ssa_phi_callbacks() -> None:
source = """
pragma solidity ^0.8.11;
contract A {
@ -519,7 +540,7 @@ def test_ssa_phi_callbacks():
@pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.")
def test_storage_refers_to():
def test_storage_refers_to() -> None:
"""Test the storage aspects of the SSA IR
When declaring a var as being storage, start tracking what storage it refers_to.

Loading…
Cancel
Save