Merge pull request #1640 from crytic/dev-types-2

more types
pull/1624/head
Feist Josselin 2 years ago committed by GitHub
commit f7e087d865
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      slither/core/children/child_expression.py
  2. 1
      slither/core/declarations/__init__.py
  3. 16
      slither/core/declarations/contract.py
  4. 2
      slither/core/declarations/using_for_top_level.py
  5. 6
      slither/core/slither_core.py
  6. 2
      slither/core/solidity_types/array_type.py
  7. 4
      slither/core/solidity_types/function_type.py
  8. 6
      slither/core/solidity_types/type_alias.py
  9. 6
      slither/core/solidity_types/type_information.py
  10. 6
      slither/core/solidity_types/user_defined_type.py
  11. 2
      slither/core/variables/__init__.py
  12. 8
      slither/slithir/operations/phi_callback.py
  13. 2
      slither/slithir/variables/constant.py
  14. 3
      slither/slithir/variables/local_variable.py
  15. 14
      slither/slithir/variables/reference.py
  16. 6
      slither/solc_parsing/declarations/function.py
  17. 7
      slither/solc_parsing/solidity_types/type_parsing.py
  18. 4
      slither/solc_parsing/variables/variable_declaration.py
  19. 6
      slither/visitors/expression/constants_folding.py
  20. 18
      slither/visitors/slithir/expression_to_slithir.py

@ -1,7 +1,8 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.slithir.operations import Operation
class ChildExpression: class ChildExpression:
@ -9,9 +10,9 @@ class ChildExpression:
super().__init__() super().__init__()
self._expression = None self._expression = None
def set_expression(self, expression: "Expression") -> None: def set_expression(self, expression: Union["Expression", "Operation"]) -> None:
self._expression = expression self._expression = expression
@property @property
def expression(self) -> "Expression": def expression(self) -> Union["Expression", "Operation"]:
return self._expression return self._expression

@ -17,3 +17,4 @@ from .structure_contract import StructureContract
from .structure_top_level import StructureTopLevel from .structure_top_level import StructureTopLevel
from .function_contract import FunctionContract from .function_contract import FunctionContract
from .function_top_level import FunctionTopLevel from .function_top_level import FunctionTopLevel
from .custom_error_contract import CustomErrorContract

@ -38,13 +38,13 @@ if TYPE_CHECKING:
EnumContract, EnumContract,
StructureContract, StructureContract,
FunctionContract, FunctionContract,
CustomErrorContract,
) )
from slither.slithir.variables.variable import SlithIRVariable from slither.slithir.variables.variable import SlithIRVariable
from slither.core.variables.variable import Variable from slither.core.variables import Variable, StateVariable
from slither.core.variables.state_variable import StateVariable
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.scope.scope import FileScope from slither.core.scope.scope import FileScope
from slither.core.cfg.node import Node
LOGGER = logging.getLogger("Contract") LOGGER = logging.getLogger("Contract")
@ -803,23 +803,25 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
""" """
return next((v for v in self.state_variables if v.name == canonical_name), None) return next((v for v in self.state_variables if v.name == canonical_name), None)
def get_structure_from_name(self, structure_name: str) -> Optional["Structure"]: def get_structure_from_name(self, structure_name: str) -> Optional["StructureContract"]:
""" """
Return a structure from a name Return a structure from a name
Args: Args:
structure_name (str): name of the structure structure_name (str): name of the structure
Returns: Returns:
Structure StructureContract
""" """
return next((st for st in self.structures if st.name == structure_name), None) return next((st for st in self.structures if st.name == structure_name), None)
def get_structure_from_canonical_name(self, structure_name: str) -> Optional["Structure"]: def get_structure_from_canonical_name(
self, structure_name: str
) -> Optional["StructureContract"]:
""" """
Return a structure from a canonical name Return a structure from a canonical name
Args: Args:
structure_name (str): canonical name of the structure structure_name (str): canonical name of the structure
Returns: Returns:
Structure StructureContract
""" """
return next((st for st in self.structures if st.canonical_name == structure_name), None) return next((st for st in self.structures if st.canonical_name == structure_name), None)

@ -14,5 +14,5 @@ class UsingForTopLevel(TopLevel):
self.file_scope: "FileScope" = scope self.file_scope: "FileScope" = scope
@property @property
def using_for(self) -> Dict[Type, List[Type]]: def using_for(self) -> Dict[Union[str, Type], List[Type]]:
return self._using_for return self._using_for

@ -8,7 +8,7 @@ import pathlib
import posixpath import posixpath
import re import re
from collections import defaultdict from collections import defaultdict
from typing import Optional, Dict, List, Set, Union from typing import Optional, Dict, List, Set, Union, Tuple
from crytic_compile import CryticCompile from crytic_compile import CryticCompile
from crytic_compile.utils.naming import Filename from crytic_compile.utils.naming import Filename
@ -73,8 +73,8 @@ class SlitherCore(Context):
# Maps from file to detector name to the start/end ranges for that detector. # Maps from file to detector name to the start/end ranges for that detector.
# Infinity is used to signal a detector has no end range. # Infinity is used to signal a detector has no end range.
self._ignore_ranges: defaultdict[str, defaultdict[str, List[(int, int)]]] = defaultdict( self._ignore_ranges: Dict[str, Dict[str, List[Tuple[int, ...]]]] = defaultdict(
lambda: defaultdict(lambda: []) lambda: defaultdict(lambda: [(-1, -1)])
) )
self._compilation_units: List[SlitherCompilationUnit] = [] self._compilation_units: List[SlitherCompilationUnit] = []

@ -17,7 +17,7 @@ class ArrayType(Type):
def __init__( def __init__(
self, self,
t: Union["TypeAliasTopLevel", "ArrayType", "FunctionType", "ElementaryType"], t: Union["TypeAliasTopLevel", "ArrayType", "FunctionType", "ElementaryType"],
length: Optional[Union["Identifier", Literal, "BinaryOperation"]], length: Optional[Union["Identifier", Literal, "BinaryOperation", int]],
) -> None: ) -> None:
assert isinstance(t, Type) assert isinstance(t, Type)
if length: if length:

@ -1,4 +1,4 @@
from typing import List, Tuple from typing import List, Tuple, Any
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
from slither.core.variables.function_type_variable import FunctionTypeVariable from slither.core.variables.function_type_variable import FunctionTypeVariable
@ -69,7 +69,7 @@ class FunctionType(Type):
return f"({params}) returns({return_values})" return f"({params}) returns({return_values})"
return f"({params})" return f"({params})"
def __eq__(self, other: ElementaryType) -> bool: def __eq__(self, other: Any) -> bool:
if not isinstance(other, FunctionType): if not isinstance(other, FunctionType):
return False return False
return self.params == other.params and self.return_values == other.return_values return self.params == other.params and self.return_values == other.return_values

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Tuple
from slither.core.children.child_contract import ChildContract from slither.core.children.child_contract import ChildContract
from slither.core.declarations.top_level import TopLevel from slither.core.declarations.top_level import TopLevel
from slither.core.solidity_types import Type from slither.core.solidity_types import Type, ElementaryType
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.declarations import Contract from slither.core.declarations import Contract
@ -10,13 +10,13 @@ if TYPE_CHECKING:
class TypeAlias(Type): class TypeAlias(Type):
def __init__(self, underlying_type: Type, name: str) -> None: def __init__(self, underlying_type: ElementaryType, name: str) -> None:
super().__init__() super().__init__()
self.name = name self.name = name
self.underlying_type = underlying_type self.underlying_type = underlying_type
@property @property
def type(self) -> Type: def type(self) -> ElementaryType:
""" """
Return the underlying type. Alias for underlying_type Return the underlying type. Alias for underlying_type

@ -5,13 +5,13 @@ from slither.core.solidity_types.type import Type
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.enum_top_level import EnumTopLevel from slither.core.declarations.enum import Enum
# Use to model the Type(X) function, which returns an undefined type # Use to model the Type(X) function, which returns an undefined type
# https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information # https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information
class TypeInformation(Type): class TypeInformation(Type):
def __init__(self, c: Union[ElementaryType, "Contract", "EnumTopLevel"]) -> None: def __init__(self, c: Union[ElementaryType, "Contract", "Enum"]) -> None:
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.enum import Enum from slither.core.declarations.enum import Enum
@ -21,7 +21,7 @@ class TypeInformation(Type):
self._type = c self._type = c
@property @property
def type(self) -> "Contract": def type(self) -> Union["Contract", ElementaryType, "Enum"]:
return self._type return self._type
@property @property

@ -8,14 +8,10 @@ if TYPE_CHECKING:
from slither.core.declarations.structure import Structure from slither.core.declarations.structure import Structure
from slither.core.declarations.enum import Enum from slither.core.declarations.enum import Enum
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations import EnumContract
from slither.core.declarations.structure_top_level import StructureTopLevel
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
class UserDefinedType(Type): class UserDefinedType(Type):
def __init__( def __init__(self, t: Union["Enum", "Contract", "Structure"]) -> None:
self, t: Union["EnumContract", "StructureTopLevel", "Contract", "StructureContract"]
) -> None:
from slither.core.declarations.structure import Structure from slither.core.declarations.structure import Structure
from slither.core.declarations.enum import Enum from slither.core.declarations.enum import Enum
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract

@ -0,0 +1,2 @@
from .state_variable import StateVariable
from .variable import Variable

@ -38,6 +38,10 @@ class PhiCallback(Phi):
def rvalues(self): def rvalues(self):
return self._rvalues return self._rvalues
@rvalues.setter
def rvalues(self, vals):
self._rvalues = vals
@property @property
def rvalue_no_callback(self): def rvalue_no_callback(self):
""" """
@ -45,10 +49,6 @@ class PhiCallback(Phi):
""" """
return self._rvalue_no_callback return self._rvalue_no_callback
@rvalues.setter
def rvalues(self, vals):
self._rvalues = vals
@property @property
def nodes(self): def nodes(self):
return self._nodes return self._nodes

@ -11,7 +11,7 @@ from slither.utils.integer_conversion import convert_string_to_int
class Constant(SlithIRVariable): class Constant(SlithIRVariable):
def __init__( def __init__(
self, self,
val: str, val: Union[int, str],
constant_type: Optional[ElementaryType] = None, constant_type: Optional[ElementaryType] = None,
subdenomination: Optional[str] = None, subdenomination: Optional[str] = None,
) -> None: # pylint: disable=too-many-branches ) -> None: # pylint: disable=too-many-branches

@ -1,3 +1,4 @@
from typing import Set
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.slithir.variables.temporary import TemporaryVariable from slither.slithir.variables.temporary import TemporaryVariable
from slither.slithir.variables.variable import SlithIRVariable from slither.slithir.variables.variable import SlithIRVariable
@ -31,7 +32,7 @@ class LocalIRVariable(
# Additional field # Additional field
# points to state variables # points to state variables
self._refers_to = set() self._refers_to: Set[StateIRVariable] = set()
# keep un-ssa version # keep un-ssa version
if isinstance(local_variable, LocalIRVariable): if isinstance(local_variable, LocalIRVariable):

@ -35,13 +35,6 @@ class ReferenceVariable(ChildNode, Variable):
""" """
return self._points_to return self._points_to
@property
def points_to_origin(self):
points = self.points_to
while isinstance(points, ReferenceVariable):
points = points.points_to
return points
@points_to.setter @points_to.setter
def points_to(self, points_to): def points_to(self, points_to):
# Can only be a rvalue of # Can only be a rvalue of
@ -55,6 +48,13 @@ class ReferenceVariable(ChildNode, Variable):
self._points_to = points_to self._points_to = points_to
@property
def points_to_origin(self):
points = self.points_to
while isinstance(points, ReferenceVariable):
points = points.points_to
return points
@property @property
def name(self) -> str: def name(self) -> str:
return f"REF_{self.index}" return f"REF_{self.index}"

@ -1,5 +1,5 @@
import logging import logging
from typing import Dict, Optional, Union, List, TYPE_CHECKING from typing import Dict, Optional, Union, List, TYPE_CHECKING, Tuple
from slither.core.cfg.node import NodeType, link_nodes, insert_node, Node from slither.core.cfg.node import NodeType, link_nodes, insert_node, Node
from slither.core.cfg.scope import Scope from slither.core.cfg.scope import Scope
@ -445,7 +445,7 @@ class FunctionSolc(CallerContextExpression):
def _parse_for_compact_ast( # pylint: disable=no-self-use def _parse_for_compact_ast( # pylint: disable=no-self-use
self, statement: Dict self, statement: Dict
) -> (Optional[Dict], Optional[Dict], Optional[Dict], Dict): ) -> Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Dict]:
body = statement["body"] body = statement["body"]
init_expression = statement.get("initializationExpression", None) init_expression = statement.get("initializationExpression", None)
condition = statement.get("condition", None) condition = statement.get("condition", None)
@ -455,7 +455,7 @@ class FunctionSolc(CallerContextExpression):
def _parse_for_legacy_ast( def _parse_for_legacy_ast(
self, statement: Dict self, statement: Dict
) -> (Optional[Dict], Optional[Dict], Optional[Dict], Dict): ) -> Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Dict]:
# if we're using an old version of solc (anything below and including 0.4.11) or if the user # if we're using an old version of solc (anything below and including 0.4.11) or if the user
# explicitly enabled compact ast, we might need to make some best-effort guesses # explicitly enabled compact ast, we might need to make some best-effort guesses
children = statement[self.get_children("children")] children = statement[self.get_children("children")]

@ -22,7 +22,7 @@ from slither.solc_parsing.exceptions import ParsingError
from slither.solc_parsing.expressions.expression_parsing import CallerContextExpression from slither.solc_parsing.expressions.expression_parsing import CallerContextExpression
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.declarations import Structure, Enum from slither.core.declarations import Structure, Enum, Function
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
@ -233,6 +233,7 @@ def parse_type(
sl: "SlitherCompilationUnit" sl: "SlitherCompilationUnit"
renaming: Dict[str, str] renaming: Dict[str, str]
user_defined_types: Dict[str, TypeAlias] user_defined_types: Dict[str, TypeAlias]
enums_direct_access: List["Enum"] = []
# Note: for convenicence top level functions use the same parser than function in contract # Note: for convenicence top level functions use the same parser than function in contract
# but contract_parser is set to None # but contract_parser is set to None
if isinstance(caller_context, SlitherCompilationUnitSolc) or ( if isinstance(caller_context, SlitherCompilationUnitSolc) or (
@ -254,7 +255,7 @@ def parse_type(
all_structuress = [c.structures for c in sl.contracts] all_structuress = [c.structures for c in sl.contracts]
all_structures = [item for sublist in all_structuress for item in sublist] all_structures = [item for sublist in all_structuress for item in sublist]
all_structures += structures_direct_access all_structures += structures_direct_access
enums_direct_access = sl.enums_top_level enums_direct_access += sl.enums_top_level
all_enumss = [c.enums for c in sl.contracts] all_enumss = [c.enums for c in sl.contracts]
all_enums = [item for sublist in all_enumss for item in sublist] all_enums = [item for sublist in all_enumss for item in sublist]
all_enums += enums_direct_access all_enums += enums_direct_access
@ -316,7 +317,7 @@ def parse_type(
all_structuress = [c.structures for c in contract.file_scope.contracts.values()] all_structuress = [c.structures for c in contract.file_scope.contracts.values()]
all_structures = [item for sublist in all_structuress for item in sublist] all_structures = [item for sublist in all_structuress for item in sublist]
all_structures += contract.file_scope.structures.values() all_structures += contract.file_scope.structures.values()
enums_direct_access: List["Enum"] = contract.enums enums_direct_access += contract.enums
enums_direct_access += contract.file_scope.enums.values() enums_direct_access += contract.file_scope.enums.values()
all_enumss = [c.enums for c in contract.file_scope.contracts.values()] all_enumss = [c.enums for c in contract.file_scope.contracts.values()]
all_enums = [item for sublist in all_enumss for item in sublist] all_enums = [item for sublist in all_enumss for item in sublist]

@ -1,6 +1,6 @@
import logging import logging
import re import re
from typing import Dict from typing import Dict, Optional
from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.expressions.expression_parsing import parse_expression from slither.solc_parsing.expressions.expression_parsing import parse_expression
@ -127,7 +127,7 @@ class VariableDeclarationSolc:
self._variable.visibility = "internal" self._variable.visibility = "internal"
def _init_from_declaration( def _init_from_declaration(
self, var: Dict, init: bool self, var: Dict, init: Optional[bool]
) -> None: # pylint: disable=too-many-branches ) -> None: # pylint: disable=too-many-branches
if self._is_compact_ast: if self._is_compact_ast:
attributes = var attributes = var

@ -8,6 +8,8 @@ from slither.core.expressions import (
Identifier, Identifier,
BinaryOperation, BinaryOperation,
UnaryOperation, UnaryOperation,
TupleExpression,
TypeConversion,
) )
from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int
@ -23,7 +25,9 @@ class NotConstant(Exception):
KEY = "ConstantFolding" KEY = "ConstantFolding"
CONSTANT_TYPES_OPERATIONS = Union[Literal, BinaryOperation, UnaryOperation, Identifier] CONSTANT_TYPES_OPERATIONS = Union[
Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion
]
def get_val(expression: CONSTANT_TYPES_OPERATIONS) -> Union[bool, int, Fraction, str]: def get_val(expression: CONSTANT_TYPES_OPERATIONS) -> Union[bool, int, Fraction, str]:

@ -1,7 +1,6 @@
import logging import logging
from typing import Any, Union, List, TYPE_CHECKING from typing import Any, Union, List, TYPE_CHECKING, TypeVar, Generic
import slither.slithir.variables.reference
from slither.core.declarations import ( from slither.core.declarations import (
Function, Function,
SolidityVariable, SolidityVariable,
@ -67,13 +66,14 @@ from slither.visitors.expression.expression import ExpressionVisitor
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.cfg.node import Node from slither.core.cfg.node import Node
from slither.slithir.operations import Operation
logger = logging.getLogger("VISTIOR:ExpressionToSlithIR") logger = logging.getLogger("VISTIOR:ExpressionToSlithIR")
key = "expressionToSlithIR" key = "expressionToSlithIR"
def get(expression: Expression): def get(expression: Union[Expression, Operation]):
val = expression.context[key] val = expression.context[key]
# we delete the item to reduce memory use # we delete the item to reduce memory use
del expression.context[key] del expression.context[key]
@ -84,7 +84,7 @@ def get_without_removing(expression):
return expression.context[key] return expression.context[key]
def set_val(expression: Expression, val) -> None: def set_val(expression: Union[Expression, Operation], val) -> None:
expression.context[key] = val expression.context[key] = val
@ -121,7 +121,7 @@ _signed_to_unsigned = {
def convert_assignment( def convert_assignment(
left: Union[LocalVariable, StateVariable, ReferenceVariable], left: Union[LocalVariable, StateVariable, ReferenceVariable],
right: SourceMapping, right: Union[LocalVariable, StateVariable, ReferenceVariable],
t: AssignmentOperationType, t: AssignmentOperationType,
return_type, return_type,
) -> Union[Binary, Assignment]: ) -> Union[Binary, Assignment]:
@ -417,9 +417,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
cst = Constant(expression.value, expression.type, expression.subdenomination) cst = Constant(expression.value, expression.type, expression.subdenomination)
set_val(expression, cst) set_val(expression, cst)
def _post_member_access( def _post_member_access(self, expression: MemberAccess) -> None:
self, expression: slither.core.expressions.member_access.MemberAccess
) -> None:
expr = get(expression.expression) expr = get(expression.expression)
# Look for type(X).max / min # Look for type(X).max / min
@ -541,9 +539,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
val = expressions val = expressions
set_val(expression, val) set_val(expression, val)
def _post_type_conversion( def _post_type_conversion(self, expression: TypeConversion) -> None:
self, expression: slither.core.expressions.type_conversion.TypeConversion
) -> None:
expr = get(expression.expression) expr = get(expression.expression)
val = TemporaryVariable(self._node) val = TemporaryVariable(self._node)
operation = TypeConversion(val, expr, expression.type) operation = TypeConversion(val, expr, expression.type)

Loading…
Cancel
Save