Fix some mypy issues

pull/1055/head
Josselin 3 years ago
parent 03425069da
commit 0618fb46f5
  1. 78
      slither/analyses/data_dependency/data_dependency.py
  2. 16
      slither/core/expressions/binary_operation.py
  3. 2
      slither/core/expressions/expression.py
  4. 2
      slither/core/expressions/expression_typed.py
  5. 22
      slither/printers/abstract_printer.py

@ -16,7 +16,7 @@ from slither.core.declarations import (
from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder
from slither.core.variables.top_level_variable import TopLevelVariable
from slither.core.variables.variable import Variable
from slither.slithir.operations import Index, OperationWithLValue, InternalCall
from slither.slithir.operations import Index, OperationWithLValue, InternalCall, Operation
from slither.slithir.variables import (
Constant,
LocalIRVariable,
@ -39,10 +39,14 @@ if TYPE_CHECKING:
###################################################################################
Variable_types = Union[Variable, SolidityVariable]
Context_types = Union[Contract, Function]
def is_dependent(
variable: Variable,
source: Variable,
context: Union[Contract, Function],
variable: Variable_types,
source: Variable_types,
context: Context_types,
only_unprotected: bool = False,
) -> bool:
"""
@ -70,9 +74,9 @@ def is_dependent(
def is_dependent_ssa(
variable: Variable,
source: Variable,
context: Union[Contract, Function],
variable: Variable_types,
source: Variable_types,
context: Context_types,
only_unprotected: bool = False,
) -> bool:
"""
@ -106,7 +110,12 @@ GENERIC_TAINT = {
}
def is_tainted(variable, context, only_unprotected=False, ignore_generic_taint=False):
def is_tainted(
variable: Variable_types,
context: Context_types,
only_unprotected: bool = False,
ignore_generic_taint: bool = False,
) -> bool:
"""
Args:
variable
@ -128,7 +137,12 @@ def is_tainted(variable, context, only_unprotected=False, ignore_generic_taint=F
)
def is_tainted_ssa(variable, context, only_unprotected=False, ignore_generic_taint=False):
def is_tainted_ssa(
variable: Variable_types,
context: Context_types,
only_unprotected: bool = False,
ignore_generic_taint: bool = False,
):
"""
Args:
variable
@ -151,8 +165,8 @@ def is_tainted_ssa(variable, context, only_unprotected=False, ignore_generic_tai
def get_dependencies(
variable: Variable,
context: Union[Contract, Function],
variable: Variable_types,
context: Context_types,
only_unprotected: bool = False,
) -> Set[Variable]:
"""
@ -171,7 +185,7 @@ def get_dependencies(
def get_all_dependencies(
context: Union[Contract, Function], only_unprotected: bool = False
context: Context_types, only_unprotected: bool = False
) -> Dict[Variable, Set[Variable]]:
"""
Return the dictionary of dependencies.
@ -188,8 +202,8 @@ def get_all_dependencies(
def get_dependencies_ssa(
variable: Variable,
context: Union[Contract, Function],
variable: Variable_types,
context: Context_types,
only_unprotected: bool = False,
) -> Set[Variable]:
"""
@ -208,7 +222,7 @@ def get_dependencies_ssa(
def get_all_dependencies_ssa(
context: Union[Contract, Function], only_unprotected: bool = False
context: Context_types, only_unprotected: bool = False
) -> Dict[Variable, Set[Variable]]:
"""
Return the dictionary of dependencies.
@ -250,9 +264,9 @@ KEY_INPUT_SSA = "DATA_DEPENDENCY_INPUT_SSA"
###################################################################################
def pprint_dependency(context):
def pprint_dependency(caller_context: Context_types) -> None:
print("#### SSA ####")
context = context.context
context = caller_context.context
for k, values in context[KEY_SSA].items():
print("{} ({}):".format(k, id(k)))
for v in values:
@ -273,7 +287,7 @@ def pprint_dependency(context):
###################################################################################
def compute_dependency(compilation_unit: "SlitherCompilationUnit"):
def compute_dependency(compilation_unit: "SlitherCompilationUnit") -> None:
compilation_unit.context[KEY_INPUT] = set()
compilation_unit.context[KEY_INPUT_SSA] = set()
@ -281,14 +295,16 @@ def compute_dependency(compilation_unit: "SlitherCompilationUnit"):
compute_dependency_contract(contract, compilation_unit)
def compute_dependency_contract(contract, compilation_unit: "SlitherCompilationUnit"):
def compute_dependency_contract(
contract: Contract, compilation_unit: "SlitherCompilationUnit"
) -> None:
if KEY_SSA in contract.context:
return
contract.context[KEY_SSA] = {}
contract.context[KEY_SSA_UNPROTECTED] = {}
for function in contract.functions + contract.modifiers:
for function in contract.functions + list(contract.modifiers):
compute_dependency_function(function)
propagate_function(contract, function, KEY_SSA, KEY_NON_SSA)
@ -303,7 +319,9 @@ def compute_dependency_contract(contract, compilation_unit: "SlitherCompilationU
propagate_contract(contract, KEY_SSA_UNPROTECTED, KEY_NON_SSA_UNPROTECTED)
def propagate_function(contract, function, context_key, context_key_non_ssa):
def propagate_function(
contract: Contract, function: Function, context_key: str, context_key_non_ssa: str
) -> None:
transitive_close_dependencies(function, context_key, context_key_non_ssa)
# Propage data dependency
data_depencencies = function.context[context_key]
@ -314,7 +332,9 @@ def propagate_function(contract, function, context_key, context_key_non_ssa):
contract.context[context_key][key].union(values)
def transitive_close_dependencies(context, context_key, context_key_non_ssa):
def transitive_close_dependencies(
context: Context_types, context_key: str, context_key_non_ssa: str
) -> None:
# transitive closure
changed = True
keys = context.context[context_key].keys()
@ -337,11 +357,11 @@ def transitive_close_dependencies(context, context_key, context_key_non_ssa):
context.context[context_key_non_ssa] = convert_to_non_ssa(context.context[context_key])
def propagate_contract(contract, context_key, context_key_non_ssa):
def propagate_contract(contract: Contract, context_key: str, context_key_non_ssa: str) -> None:
transitive_close_dependencies(contract, context_key, context_key_non_ssa)
def add_dependency(lvalue, function, ir, is_protected):
def add_dependency(lvalue: Variable, function: Function, ir: Operation, is_protected: bool) -> None:
if not lvalue in function.context[KEY_SSA]:
function.context[KEY_SSA][lvalue] = set()
if not is_protected:
@ -362,7 +382,7 @@ def add_dependency(lvalue, function, ir, is_protected):
]
def compute_dependency_function(function):
def compute_dependency_function(function: Function) -> None:
if KEY_SSA in function.context:
return
@ -387,7 +407,7 @@ def compute_dependency_function(function):
)
def convert_variable_to_non_ssa(v):
def convert_variable_to_non_ssa(v: Variable_types) -> Variable_types:
if isinstance(
v,
(
@ -417,9 +437,11 @@ def convert_variable_to_non_ssa(v):
return v
def convert_to_non_ssa(data_depencies):
def convert_to_non_ssa(
data_depencies: Dict[Variable_types, Set[Variable_types]]
) -> Dict[Variable_types, Set[Variable_types]]:
# Need to create new set() as its changed during iteration
ret = {}
ret: Dict[Variable_types, Set[Variable_types]] = {}
for (k, values) in data_depencies.items():
var = convert_variable_to_non_ssa(k)
if not var in ret:

@ -40,8 +40,11 @@ class BinaryOperationType(Enum):
GREATER_SIGNED = 22
RIGHT_SHIFT_ARITHMETIC = 23
# pylint: disable=too-many-branches
@staticmethod
def get_type(operation_type: "BinaryOperation"): # pylint: disable=too-many-branches
def get_type(
operation_type: "BinaryOperation",
) -> "BinaryOperationType":
if operation_type == "**":
return BinaryOperationType.POWER
if operation_type == "*":
@ -93,7 +96,7 @@ class BinaryOperationType(Enum):
raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type))
def __str__(self): # pylint: disable=too-many-branches
def __str__(self) -> str: # pylint: disable=too-many-branches
if self == BinaryOperationType.POWER:
return "**"
if self == BinaryOperationType.MULTIPLICATION:
@ -146,7 +149,12 @@ class BinaryOperationType(Enum):
class BinaryOperation(ExpressionTyped):
def __init__(self, left_expression, right_expression, expression_type):
def __init__(
self,
left_expression: Expression,
right_expression: Expression,
expression_type: BinaryOperationType,
) -> None:
assert isinstance(left_expression, Expression)
assert isinstance(right_expression, Expression)
super().__init__()
@ -169,5 +177,5 @@ class BinaryOperation(ExpressionTyped):
def type(self) -> BinaryOperationType:
return self._type
def __str__(self):
def __str__(self) -> str:
return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right)

@ -2,7 +2,7 @@ from slither.core.source_mapping.source_mapping import SourceMapping
class Expression(SourceMapping):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._is_lvalue = False

@ -7,7 +7,7 @@ if TYPE_CHECKING:
class ExpressionTyped(Expression): # pylint: disable=too-few-public-methods
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._type: Optional["Type"] = None

@ -1,6 +1,13 @@
import abc
from logging import Logger
from typing import TYPE_CHECKING, Union, List, Optional, Dict
from slither.utils import output
from slither.utils.output import SupportedOutput
if TYPE_CHECKING:
from slither import Slither
class IncorrectPrinterInitialization(Exception):
@ -13,7 +20,7 @@ class AbstractPrinter(metaclass=abc.ABCMeta):
WIKI = ""
def __init__(self, slither, logger):
def __init__(self, slither: "Slither", logger: Logger) -> None:
self.slither = slither
self.contracts = slither.contracts
self.filename = slither.filename
@ -34,11 +41,15 @@ class AbstractPrinter(metaclass=abc.ABCMeta):
f"WIKI is not initialized {self.__class__.__name__}"
)
def info(self, info):
def info(self, info: str) -> None:
if self.logger:
self.logger.info(info)
def generate_output(self, info, additional_fields=None):
def generate_output(
self,
info: Union[str, List[Union[str, SupportedOutput]]],
additional_fields: Optional[Dict] = None,
) -> output.Output:
if additional_fields is None:
additional_fields = {}
printer_output = output.Output(info, additional_fields)
@ -47,6 +58,5 @@ class AbstractPrinter(metaclass=abc.ABCMeta):
return printer_output
@abc.abstractmethod
def output(self, filename):
"""TODO Documentation"""
return
def output(self, filename: str) -> output.Output:
pass

Loading…
Cancel
Save