Fix some mypy issues

pull/1055/head
Josselin 3 years ago
parent 03425069da
commit 0618fb46f5
  1. 78
      slither/analyses/data_dependency/data_dependency.py
  2. 16
      slither/core/expressions/binary_operation.py
  3. 2
      slither/core/expressions/expression.py
  4. 2
      slither/core/expressions/expression_typed.py
  5. 22
      slither/printers/abstract_printer.py

@ -16,7 +16,7 @@ from slither.core.declarations import (
from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder
from slither.core.variables.top_level_variable import TopLevelVariable from slither.core.variables.top_level_variable import TopLevelVariable
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.slithir.operations import Index, OperationWithLValue, InternalCall from slither.slithir.operations import Index, OperationWithLValue, InternalCall, Operation
from slither.slithir.variables import ( from slither.slithir.variables import (
Constant, Constant,
LocalIRVariable, LocalIRVariable,
@ -39,10 +39,14 @@ if TYPE_CHECKING:
################################################################################### ###################################################################################
Variable_types = Union[Variable, SolidityVariable]
Context_types = Union[Contract, Function]
def is_dependent( def is_dependent(
variable: Variable, variable: Variable_types,
source: Variable, source: Variable_types,
context: Union[Contract, Function], context: Context_types,
only_unprotected: bool = False, only_unprotected: bool = False,
) -> bool: ) -> bool:
""" """
@ -70,9 +74,9 @@ def is_dependent(
def is_dependent_ssa( def is_dependent_ssa(
variable: Variable, variable: Variable_types,
source: Variable, source: Variable_types,
context: Union[Contract, Function], context: Context_types,
only_unprotected: bool = False, only_unprotected: bool = False,
) -> bool: ) -> bool:
""" """
@ -106,7 +110,12 @@ GENERIC_TAINT = {
} }
def is_tainted(variable, context, only_unprotected=False, ignore_generic_taint=False): def is_tainted(
variable: Variable_types,
context: Context_types,
only_unprotected: bool = False,
ignore_generic_taint: bool = False,
) -> bool:
""" """
Args: Args:
variable variable
@ -128,7 +137,12 @@ def is_tainted(variable, context, only_unprotected=False, ignore_generic_taint=F
) )
def is_tainted_ssa(variable, context, only_unprotected=False, ignore_generic_taint=False): def is_tainted_ssa(
variable: Variable_types,
context: Context_types,
only_unprotected: bool = False,
ignore_generic_taint: bool = False,
):
""" """
Args: Args:
variable variable
@ -151,8 +165,8 @@ def is_tainted_ssa(variable, context, only_unprotected=False, ignore_generic_tai
def get_dependencies( def get_dependencies(
variable: Variable, variable: Variable_types,
context: Union[Contract, Function], context: Context_types,
only_unprotected: bool = False, only_unprotected: bool = False,
) -> Set[Variable]: ) -> Set[Variable]:
""" """
@ -171,7 +185,7 @@ def get_dependencies(
def get_all_dependencies( def get_all_dependencies(
context: Union[Contract, Function], only_unprotected: bool = False context: Context_types, only_unprotected: bool = False
) -> Dict[Variable, Set[Variable]]: ) -> Dict[Variable, Set[Variable]]:
""" """
Return the dictionary of dependencies. Return the dictionary of dependencies.
@ -188,8 +202,8 @@ def get_all_dependencies(
def get_dependencies_ssa( def get_dependencies_ssa(
variable: Variable, variable: Variable_types,
context: Union[Contract, Function], context: Context_types,
only_unprotected: bool = False, only_unprotected: bool = False,
) -> Set[Variable]: ) -> Set[Variable]:
""" """
@ -208,7 +222,7 @@ def get_dependencies_ssa(
def get_all_dependencies_ssa( def get_all_dependencies_ssa(
context: Union[Contract, Function], only_unprotected: bool = False context: Context_types, only_unprotected: bool = False
) -> Dict[Variable, Set[Variable]]: ) -> Dict[Variable, Set[Variable]]:
""" """
Return the dictionary of dependencies. Return the dictionary of dependencies.
@ -250,9 +264,9 @@ KEY_INPUT_SSA = "DATA_DEPENDENCY_INPUT_SSA"
################################################################################### ###################################################################################
def pprint_dependency(context): def pprint_dependency(caller_context: Context_types) -> None:
print("#### SSA ####") print("#### SSA ####")
context = context.context context = caller_context.context
for k, values in context[KEY_SSA].items(): for k, values in context[KEY_SSA].items():
print("{} ({}):".format(k, id(k))) print("{} ({}):".format(k, id(k)))
for v in values: for v in values:
@ -273,7 +287,7 @@ def pprint_dependency(context):
################################################################################### ###################################################################################
def compute_dependency(compilation_unit: "SlitherCompilationUnit"): def compute_dependency(compilation_unit: "SlitherCompilationUnit") -> None:
compilation_unit.context[KEY_INPUT] = set() compilation_unit.context[KEY_INPUT] = set()
compilation_unit.context[KEY_INPUT_SSA] = set() compilation_unit.context[KEY_INPUT_SSA] = set()
@ -281,14 +295,16 @@ def compute_dependency(compilation_unit: "SlitherCompilationUnit"):
compute_dependency_contract(contract, compilation_unit) compute_dependency_contract(contract, compilation_unit)
def compute_dependency_contract(contract, compilation_unit: "SlitherCompilationUnit"): def compute_dependency_contract(
contract: Contract, compilation_unit: "SlitherCompilationUnit"
) -> None:
if KEY_SSA in contract.context: if KEY_SSA in contract.context:
return return
contract.context[KEY_SSA] = {} contract.context[KEY_SSA] = {}
contract.context[KEY_SSA_UNPROTECTED] = {} contract.context[KEY_SSA_UNPROTECTED] = {}
for function in contract.functions + contract.modifiers: for function in contract.functions + list(contract.modifiers):
compute_dependency_function(function) compute_dependency_function(function)
propagate_function(contract, function, KEY_SSA, KEY_NON_SSA) propagate_function(contract, function, KEY_SSA, KEY_NON_SSA)
@ -303,7 +319,9 @@ def compute_dependency_contract(contract, compilation_unit: "SlitherCompilationU
propagate_contract(contract, KEY_SSA_UNPROTECTED, KEY_NON_SSA_UNPROTECTED) propagate_contract(contract, KEY_SSA_UNPROTECTED, KEY_NON_SSA_UNPROTECTED)
def propagate_function(contract, function, context_key, context_key_non_ssa): def propagate_function(
contract: Contract, function: Function, context_key: str, context_key_non_ssa: str
) -> None:
transitive_close_dependencies(function, context_key, context_key_non_ssa) transitive_close_dependencies(function, context_key, context_key_non_ssa)
# Propage data dependency # Propage data dependency
data_depencencies = function.context[context_key] data_depencencies = function.context[context_key]
@ -314,7 +332,9 @@ def propagate_function(contract, function, context_key, context_key_non_ssa):
contract.context[context_key][key].union(values) contract.context[context_key][key].union(values)
def transitive_close_dependencies(context, context_key, context_key_non_ssa): def transitive_close_dependencies(
context: Context_types, context_key: str, context_key_non_ssa: str
) -> None:
# transitive closure # transitive closure
changed = True changed = True
keys = context.context[context_key].keys() keys = context.context[context_key].keys()
@ -337,11 +357,11 @@ def transitive_close_dependencies(context, context_key, context_key_non_ssa):
context.context[context_key_non_ssa] = convert_to_non_ssa(context.context[context_key]) context.context[context_key_non_ssa] = convert_to_non_ssa(context.context[context_key])
def propagate_contract(contract, context_key, context_key_non_ssa): def propagate_contract(contract: Contract, context_key: str, context_key_non_ssa: str) -> None:
transitive_close_dependencies(contract, context_key, context_key_non_ssa) transitive_close_dependencies(contract, context_key, context_key_non_ssa)
def add_dependency(lvalue, function, ir, is_protected): def add_dependency(lvalue: Variable, function: Function, ir: Operation, is_protected: bool) -> None:
if not lvalue in function.context[KEY_SSA]: if not lvalue in function.context[KEY_SSA]:
function.context[KEY_SSA][lvalue] = set() function.context[KEY_SSA][lvalue] = set()
if not is_protected: if not is_protected:
@ -362,7 +382,7 @@ def add_dependency(lvalue, function, ir, is_protected):
] ]
def compute_dependency_function(function): def compute_dependency_function(function: Function) -> None:
if KEY_SSA in function.context: if KEY_SSA in function.context:
return return
@ -387,7 +407,7 @@ def compute_dependency_function(function):
) )
def convert_variable_to_non_ssa(v): def convert_variable_to_non_ssa(v: Variable_types) -> Variable_types:
if isinstance( if isinstance(
v, v,
( (
@ -417,9 +437,11 @@ def convert_variable_to_non_ssa(v):
return v return v
def convert_to_non_ssa(data_depencies): def convert_to_non_ssa(
data_depencies: Dict[Variable_types, Set[Variable_types]]
) -> Dict[Variable_types, Set[Variable_types]]:
# Need to create new set() as its changed during iteration # Need to create new set() as its changed during iteration
ret = {} ret: Dict[Variable_types, Set[Variable_types]] = {}
for (k, values) in data_depencies.items(): for (k, values) in data_depencies.items():
var = convert_variable_to_non_ssa(k) var = convert_variable_to_non_ssa(k)
if not var in ret: if not var in ret:

@ -40,8 +40,11 @@ class BinaryOperationType(Enum):
GREATER_SIGNED = 22 GREATER_SIGNED = 22
RIGHT_SHIFT_ARITHMETIC = 23 RIGHT_SHIFT_ARITHMETIC = 23
# pylint: disable=too-many-branches
@staticmethod @staticmethod
def get_type(operation_type: "BinaryOperation"): # pylint: disable=too-many-branches def get_type(
operation_type: "BinaryOperation",
) -> "BinaryOperationType":
if operation_type == "**": if operation_type == "**":
return BinaryOperationType.POWER return BinaryOperationType.POWER
if operation_type == "*": if operation_type == "*":
@ -93,7 +96,7 @@ class BinaryOperationType(Enum):
raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type)) raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type))
def __str__(self): # pylint: disable=too-many-branches def __str__(self) -> str: # pylint: disable=too-many-branches
if self == BinaryOperationType.POWER: if self == BinaryOperationType.POWER:
return "**" return "**"
if self == BinaryOperationType.MULTIPLICATION: if self == BinaryOperationType.MULTIPLICATION:
@ -146,7 +149,12 @@ class BinaryOperationType(Enum):
class BinaryOperation(ExpressionTyped): class BinaryOperation(ExpressionTyped):
def __init__(self, left_expression, right_expression, expression_type): def __init__(
self,
left_expression: Expression,
right_expression: Expression,
expression_type: BinaryOperationType,
) -> None:
assert isinstance(left_expression, Expression) assert isinstance(left_expression, Expression)
assert isinstance(right_expression, Expression) assert isinstance(right_expression, Expression)
super().__init__() super().__init__()
@ -169,5 +177,5 @@ class BinaryOperation(ExpressionTyped):
def type(self) -> BinaryOperationType: def type(self) -> BinaryOperationType:
return self._type return self._type
def __str__(self): def __str__(self) -> str:
return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right) return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right)

@ -2,7 +2,7 @@ from slither.core.source_mapping.source_mapping import SourceMapping
class Expression(SourceMapping): class Expression(SourceMapping):
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self._is_lvalue = False self._is_lvalue = False

@ -7,7 +7,7 @@ if TYPE_CHECKING:
class ExpressionTyped(Expression): # pylint: disable=too-few-public-methods class ExpressionTyped(Expression): # pylint: disable=too-few-public-methods
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self._type: Optional["Type"] = None self._type: Optional["Type"] = None

@ -1,6 +1,13 @@
import abc import abc
from logging import Logger
from typing import TYPE_CHECKING, Union, List, Optional, Dict
from slither.utils import output from slither.utils import output
from slither.utils.output import SupportedOutput
if TYPE_CHECKING:
from slither import Slither
class IncorrectPrinterInitialization(Exception): class IncorrectPrinterInitialization(Exception):
@ -13,7 +20,7 @@ class AbstractPrinter(metaclass=abc.ABCMeta):
WIKI = "" WIKI = ""
def __init__(self, slither, logger): def __init__(self, slither: "Slither", logger: Logger) -> None:
self.slither = slither self.slither = slither
self.contracts = slither.contracts self.contracts = slither.contracts
self.filename = slither.filename self.filename = slither.filename
@ -34,11 +41,15 @@ class AbstractPrinter(metaclass=abc.ABCMeta):
f"WIKI is not initialized {self.__class__.__name__}" f"WIKI is not initialized {self.__class__.__name__}"
) )
def info(self, info): def info(self, info: str) -> None:
if self.logger: if self.logger:
self.logger.info(info) self.logger.info(info)
def generate_output(self, info, additional_fields=None): def generate_output(
self,
info: Union[str, List[Union[str, SupportedOutput]]],
additional_fields: Optional[Dict] = None,
) -> output.Output:
if additional_fields is None: if additional_fields is None:
additional_fields = {} additional_fields = {}
printer_output = output.Output(info, additional_fields) printer_output = output.Output(info, additional_fields)
@ -47,6 +58,5 @@ class AbstractPrinter(metaclass=abc.ABCMeta):
return printer_output return printer_output
@abc.abstractmethod @abc.abstractmethod
def output(self, filename): def output(self, filename: str) -> output.Output:
"""TODO Documentation""" pass
return

Loading…
Cancel
Save