Merge branch 'dev' of github.com:crytic/slither into dev

pull/516/head
Josselin 4 years ago
commit 63dbe8d747
  1. 2
      .github/workflows/ci.yml
  2. 592
      slither/core/cfg/node.py
  3. 11
      slither/core/children/child_contract.py
  4. 9
      slither/core/children/child_event.py
  5. 9
      slither/core/children/child_expression.py
  6. 9
      slither/core/children/child_function.py
  7. 10
      slither/core/children/child_inheritance.py
  8. 17
      slither/core/children/child_node.py
  9. 10
      slither/core/children/child_slither.py
  10. 10
      slither/core/children/child_structure.py
  11. 15
      slither/core/context/context.py
  12. 798
      slither/core/declarations/contract.py
  13. 17
      slither/core/declarations/enum.py
  14. 40
      slither/core/declarations/event.py
  15. 802
      slither/core/declarations/function.py
  16. 6
      slither/core/declarations/import_directive.py
  17. 3
      slither/core/declarations/modifier.py
  18. 28
      slither/core/declarations/pragma_directive.py
  19. 151
      slither/core/declarations/solidity_variables.py
  20. 32
      slither/core/declarations/structure.py
  21. 34
      slither/core/dominators/node_dominator_tree.py
  22. 38
      slither/core/dominators/utils.py
  23. 4
      slither/core/exceptions.py
  24. 113
      slither/core/expressions/assignment_operation.py
  25. 149
      slither/core/expressions/binary_operation.py
  26. 43
      slither/core/expressions/call_expression.py
  27. 28
      slither/core/expressions/conditional_expression.py
  28. 5
      slither/core/expressions/elementary_type_name_expression.py
  29. 5
      slither/core/expressions/expression.py
  30. 9
      slither/core/expressions/expression_typed.py
  31. 12
      slither/core/expressions/identifier.py
  32. 24
      slither/core/expressions/index_access.py
  33. 17
      slither/core/expressions/literal.py
  34. 20
      slither/core/expressions/member_access.py
  35. 12
      slither/core/expressions/new_array.py
  36. 11
      slither/core/expressions/new_contract.py
  37. 7
      slither/core/expressions/new_elementary_type.py
  38. 4
      slither/core/expressions/super_call_expression.py
  39. 6
      slither/core/expressions/super_identifier.py
  40. 9
      slither/core/expressions/tuple_expression.py
  41. 10
      slither/core/expressions/type_conversion.py
  42. 101
      slither/core/expressions/unary_operation.py
  43. 172
      slither/core/slither_core.py
  44. 28
      slither/core/solidity_types/array_type.py
  45. 167
      slither/core/solidity_types/elementary_type.py
  46. 42
      slither/core/solidity_types/function_type.py
  47. 9
      slither/core/solidity_types/mapping_type.py
  48. 4
      slither/core/solidity_types/type.py
  49. 12
      slither/core/solidity_types/type_information.py
  50. 14
      slither/core/solidity_types/user_defined_type.py
  51. 76
      slither/core/source_mapping/source_mapping.py
  52. 6
      slither/core/variables/event_variable.py
  53. 3
      slither/core/variables/function_type_variable.py
  54. 32
      slither/core/variables/local_variable.py
  55. 11
      slither/core/variables/local_variable_init_from_tuple.py
  56. 33
      slither/core/variables/state_variable.py
  57. 3
      slither/core/variables/structure_variable.py
  58. 84
      slither/core/variables/variable.py
  59. 24
      slither/printers/guidance/echidna.py
  60. 87
      slither/slither.py
  61. 89
      slither/slithir/operations/binary.py
  62. 19
      slither/slithir/operations/unary.py
  63. 88
      slither/solc_parsing/cfg/node.py
  64. 632
      slither/solc_parsing/declarations/contract.py
  65. 54
      slither/solc_parsing/declarations/event.py
  66. 992
      slither/solc_parsing/declarations/function.py
  67. 57
      slither/solc_parsing/declarations/modifier.py
  68. 43
      slither/solc_parsing/declarations/structure.py
  69. 8
      slither/solc_parsing/exceptions.py
  70. 549
      slither/solc_parsing/expressions/expression_parsing.py
  71. 369
      slither/solc_parsing/slitherSolc.py
  72. 198
      slither/solc_parsing/solidity_types/type_parsing.py
  73. 19
      slither/solc_parsing/variables/event_variable.py
  74. 12
      slither/solc_parsing/variables/function_type_variable.py
  75. 35
      slither/solc_parsing/variables/local_variable.py
  76. 15
      slither/solc_parsing/variables/local_variable_init_from_tuple.py
  77. 12
      slither/solc_parsing/variables/state_variable.py
  78. 12
      slither/solc_parsing/variables/structure_variable.py
  79. 154
      slither/solc_parsing/variables/variable_declaration.py
  80. 14
      slither/tools/demo/__main__.py
  81. 37
      slither/tools/erc_conformance/__main__.py
  82. 14
      slither/tools/erc_conformance/erc/erc20.py
  83. 95
      slither/tools/erc_conformance/erc/ercs.py
  84. 30
      slither/tools/kspec_coverage/__main__.py
  85. 77
      slither/tools/kspec_coverage/analysis.py
  86. 3
      slither/tools/kspec_coverage/kspec_coverage.py
  87. 21
      slither/tools/possible_paths/__main__.py
  88. 37
      slither/tools/possible_paths/possible_paths.py
  89. 82
      slither/tools/properties/__main__.py
  90. 8
      slither/tools/properties/addresses/address.py
  91. 6
      slither/tools/properties/platforms/echidna.py
  92. 68
      slither/tools/properties/platforms/truffle.py
  93. 130
      slither/tools/properties/properties/erc20.py
  94. 32
      slither/tools/properties/properties/ercs/erc20/properties/burn.py
  95. 75
      slither/tools/properties/properties/ercs/erc20/properties/initialization.py
  96. 19
      slither/tools/properties/properties/ercs/erc20/properties/mint.py
  97. 20
      slither/tools/properties/properties/ercs/erc20/properties/mint_and_burn.py
  98. 207
      slither/tools/properties/properties/ercs/erc20/properties/transfer.py
  99. 24
      slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py
  100. 2
      slither/tools/properties/properties/properties.py
  101. Some files were not shown because too many files have changed in this diff Show More

@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
type: ["4", "5", "cli", "dapp", "data_dependency", "embark", "erc", "etherlime", "find_paths", "kspec", "printers", "simil", "slither_config", "truffle", "upgradability", "prop"] type: ["4", "5", "cli", "data_dependency", "embark", "erc", "etherlime", "find_paths", "kspec", "printers", "simil", "slither_config", "truffle", "upgradability", "prop"]
steps: steps:
- uses: actions/checkout@v1 - uses: actions/checkout@v1
- name: Set up Python 3.6 - name: Set up Python 3.6

File diff suppressed because it is too large Load Diff

@ -1,14 +1,17 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither.core.declarations import Contract
class ChildContract:
class ChildContract:
def __init__(self): def __init__(self):
super(ChildContract, self).__init__() super(ChildContract, self).__init__()
self._contract = None self._contract = None
def set_contract(self, contract): def set_contract(self, contract: "Contract"):
self._contract = contract self._contract = contract
@property @property
def contract(self): def contract(self) -> "Contract":
return self._contract return self._contract

@ -1,12 +1,17 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither.core.declarations import Event
class ChildEvent: class ChildEvent:
def __init__(self): def __init__(self):
super(ChildEvent, self).__init__() super(ChildEvent, self).__init__()
self._event = None self._event = None
def set_event(self, event): def set_event(self, event: "Event"):
self._event = event self._event = event
@property @property
def event(self): def event(self) -> "Event":
return self._event return self._event

@ -1,12 +1,17 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither.core.expressions.expression import Expression
class ChildExpression: class ChildExpression:
def __init__(self): def __init__(self):
super(ChildExpression, self).__init__() super(ChildExpression, self).__init__()
self._expression = None self._expression = None
def set_expression(self, expression): def set_expression(self, expression: "Expression"):
self._expression = expression self._expression = expression
@property @property
def expression(self): def expression(self) -> "Expression":
return self._expression return self._expression

@ -1,12 +1,17 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither.core.declarations import Function
class ChildFunction: class ChildFunction:
def __init__(self): def __init__(self):
super(ChildFunction, self).__init__() super(ChildFunction, self).__init__()
self._function = None self._function = None
def set_function(self, function): def set_function(self, function: "Function"):
self._function = function self._function = function
@property @property
def function(self): def function(self) -> "Function":
return self._function return self._function

@ -1,13 +1,17 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither.core.declarations import Contract
class ChildInheritance:
class ChildInheritance:
def __init__(self): def __init__(self):
super(ChildInheritance, self).__init__() super(ChildInheritance, self).__init__()
self._contract_declarer = None self._contract_declarer = None
def set_contract_declarer(self, contract): def set_contract_declarer(self, contract: "Contract"):
self._contract_declarer = contract self._contract_declarer = contract
@property @property
def contract_declarer(self): def contract_declarer(self) -> "Contract":
return self._contract_declarer return self._contract_declarer

@ -1,24 +1,31 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither import Slither
from slither.core.cfg.node import Node
from slither.core.declarations import Function, Contract
class ChildNode(object): class ChildNode(object):
def __init__(self): def __init__(self):
super(ChildNode, self).__init__() super(ChildNode, self).__init__()
self._node = None self._node = None
def set_node(self, node): def set_node(self, node: "Node"):
self._node = node self._node = node
@property @property
def node(self): def node(self) -> "Node":
return self._node return self._node
@property @property
def function(self): def function(self) -> "Function":
return self.node.function return self.node.function
@property @property
def contract(self): def contract(self) -> "Contract":
return self.node.function.contract return self.node.function.contract
@property @property
def slither(self): def slither(self) -> "Slither":
return self.contract.slither return self.contract.slither

@ -1,13 +1,17 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither import Slither
class ChildSlither:
class ChildSlither:
def __init__(self): def __init__(self):
super(ChildSlither, self).__init__() super(ChildSlither, self).__init__()
self._slither = None self._slither = None
def set_slither(self, slither): def set_slither(self, slither: "Slither"):
self._slither = slither self._slither = slither
@property @property
def slither(self): def slither(self) -> "Slither":
return self._slither return self._slither

@ -1,13 +1,17 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither.core.declarations import Structure
class ChildStructure:
class ChildStructure:
def __init__(self): def __init__(self):
super(ChildStructure, self).__init__() super(ChildStructure, self).__init__()
self._structure = None self._structure = None
def set_structure(self, structure): def set_structure(self, structure: "Structure"):
self._structure = structure self._structure = structure
@property @property
def structure(self): def structure(self) -> "Structure":
return self._structure return self._structure

@ -1,14 +1,15 @@
class Context: from collections import defaultdict
from typing import Dict
class Context:
def __init__(self): def __init__(self):
super(Context, self).__init__() super(Context, self).__init__()
self._context = {} self._context = {"MEMBERS": defaultdict(None)}
@property @property
def context(self): def context(self) -> Dict:
''' """
Dict used by analysis Dict used by analysis
''' """
return self._context return self._context

File diff suppressed because it is too large Load Diff

@ -1,25 +1,32 @@
from typing import List, TYPE_CHECKING
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.children.child_contract import ChildContract from slither.core.children.child_contract import ChildContract
if TYPE_CHECKING:
from slither.core.declarations import Contract
class Enum(ChildContract, SourceMapping): class Enum(ChildContract, SourceMapping):
def __init__(self, name, canonical_name, values): def __init__(self, name: str, canonical_name: str, values: List[str]):
super().__init__()
self._name = name self._name = name
self._canonical_name = canonical_name self._canonical_name = canonical_name
self._values = values self._values = values
@property @property
def canonical_name(self): def canonical_name(self) -> str:
return self._canonical_name return self._canonical_name
@property @property
def name(self): def name(self) -> str:
return self._name return self._name
@property @property
def values(self): def values(self) -> List[str]:
return self._values return self._values
def is_declared_by(self, contract): def is_declared_by(self, contract: "Contract") -> bool:
""" """
Check if the element is declared by the contract Check if the element is declared by the contract
:param contract: :param contract:

@ -1,47 +1,57 @@
from typing import List, Tuple, TYPE_CHECKING
from slither.core.children.child_contract import ChildContract from slither.core.children.child_contract import ChildContract
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.event_variable import EventVariable
class Event(ChildContract, SourceMapping): if TYPE_CHECKING:
from slither.core.declarations import Contract
class Event(ChildContract, SourceMapping):
def __init__(self): def __init__(self):
super(Event, self).__init__() super(Event, self).__init__()
self._name = None self._name = None
self._elems = [] self._elems: List[EventVariable] = []
@property @property
def name(self): def name(self) -> str:
return self._name return self._name
@name.setter
def name(self, name: str):
self._name = name
@property @property
def signature(self): def signature(self) -> Tuple[str, List[str]]:
''' Return the function signature """ Return the function signature
Returns: Returns:
(str, list(str)): name, list parameters type (str, list(str)): name, list parameters type
''' """
return self.name, [str(x.type) for x in self.elems] return self.name, [str(x.type) for x in self.elems]
@property @property
def full_name(self): def full_name(self) -> str:
''' Return the function signature as a str """ Return the function signature as a str
Returns: Returns:
str: func_name(type1,type2) str: func_name(type1,type2)
''' """
name, parameters = self.signature name, parameters = self.signature
return name+'('+','.join(parameters)+')' return name + "(" + ",".join(parameters) + ")"
@property @property
def canonical_name(self): def canonical_name(self) -> str:
''' Return the function signature as a str """ Return the function signature as a str
Returns: Returns:
str: contract.func_name(type1,type2) str: contract.func_name(type1,type2)
''' """
return self.contract.name + self.full_name return self.contract.name + self.full_name
@property @property
def elems(self): def elems(self) -> List["EventVariable"]:
return self._elems return self._elems
def is_declared_by(self, contract): def is_declared_by(self, contract: "Contract") -> bool:
""" """
Check if the element is declared by the contract Check if the element is declared by the contract
:param contract: :param contract:

File diff suppressed because it is too large Load Diff

@ -1,13 +1,13 @@
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
class Import(SourceMapping):
def __init__(self, filename): class Import(SourceMapping):
def __init__(self, filename: str):
super(Import, self).__init__() super(Import, self).__init__()
self._filename = filename self._filename = filename
@property @property
def filename(self): def filename(self) -> str:
return self._filename return self._filename
def __str__(self): def __str__(self):

@ -3,5 +3,6 @@
""" """
from .function import Function from .function import Function
class Modifier(Function): pass
class Modifier(Function):
pass

@ -1,37 +1,39 @@
from typing import List
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
class Pragma(SourceMapping):
def __init__(self, directive): class Pragma(SourceMapping):
def __init__(self, directive: List[str]):
super(Pragma, self).__init__() super(Pragma, self).__init__()
self._directive = directive self._directive = directive
@property @property
def directive(self): def directive(self) -> List[str]:
''' """
list(str) list(str)
''' """
return self._directive return self._directive
@property @property
def version(self): def version(self) -> str:
return ''.join(self.directive[1:]) return "".join(self.directive[1:])
@property @property
def name(self): def name(self) -> str:
return self.version return self.version
@property @property
def is_solidity_version(self): def is_solidity_version(self) -> bool:
if len(self._directive) > 0: if len(self._directive) > 0:
return self._directive[0].lower() == 'solidity' return self._directive[0].lower() == "solidity"
return False return False
@property @property
def is_abi_encoder_v2(self): def is_abi_encoder_v2(self) -> bool:
if len(self._directive) == 2: if len(self._directive) == 2:
return self._directive[0] == 'experimental' and self._directive[1] == 'ABIEncoderV2' return self._directive[0] == "experimental" and self._directive[1] == "ABIEncoderV2"
return False return False
def __str__(self): def __str__(self):
return 'pragma '+''.join(self.directive) return "pragma " + "".join(self.directive)

@ -1,64 +1,73 @@
# https://solidity.readthedocs.io/en/v0.4.24/units-and-global-variables.html # https://solidity.readthedocs.io/en/v0.4.24/units-and-global-variables.html
from typing import List, Dict, Union
from slither.core.context.context import Context from slither.core.context.context import Context
from slither.core.solidity_types import ElementaryType, TypeInformation from slither.core.solidity_types import ElementaryType, TypeInformation
SOLIDITY_VARIABLES = {"now":'uint256', SOLIDITY_VARIABLES = {
"this":'address', "now": "uint256",
'abi':'address', # to simplify the conversion, assume that abi return an address "this": "address",
'msg':'', "abi": "address", # to simplify the conversion, assume that abi return an address
'tx':'', "msg": "",
'block':'', "tx": "",
'super':''} "block": "",
"super": "",
SOLIDITY_VARIABLES_COMPOSED = {"block.coinbase":"address", }
"block.difficulty":"uint256",
"block.gaslimit":"uint256", SOLIDITY_VARIABLES_COMPOSED = {
"block.number":"uint256", "block.coinbase": "address",
"block.timestamp":"uint256", "block.difficulty": "uint256",
"block.blockhash":"uint256", # alias for blockhash. It's a call "block.gaslimit": "uint256",
"msg.data":"bytes", "block.number": "uint256",
"msg.gas":"uint256", "block.timestamp": "uint256",
"msg.sender":"address", "block.blockhash": "uint256", # alias for blockhash. It's a call
"msg.sig":"bytes4", "msg.data": "bytes",
"msg.value":"uint256", "msg.gas": "uint256",
"tx.gasprice":"uint256", "msg.sender": "address",
"tx.origin":"address"} "msg.sig": "bytes4",
"msg.value": "uint256",
"tx.gasprice": "uint256",
SOLIDITY_FUNCTIONS = {"gasleft()":['uint256'], "tx.origin": "address",
"assert(bool)":[], }
"require(bool)":[],
"require(bool,string)":[],
"revert()":[], SOLIDITY_FUNCTIONS: Dict[str, List[str]] = {
"revert(string)":[], "gasleft()": ["uint256"],
"addmod(uint256,uint256,uint256)":['uint256'], "assert(bool)": [],
"mulmod(uint256,uint256,uint256)":['uint256'], "require(bool)": [],
"keccak256()":['bytes32'], "require(bool,string)": [],
"keccak256(bytes)":['bytes32'], # Solidity 0.5 "revert()": [],
"sha256()":['bytes32'], "revert(string)": [],
"sha256(bytes)":['bytes32'], # Solidity 0.5 "addmod(uint256,uint256,uint256)": ["uint256"],
"sha3()":['bytes32'], "mulmod(uint256,uint256,uint256)": ["uint256"],
"ripemd160()":['bytes32'], "keccak256()": ["bytes32"],
"ripemd160(bytes)":['bytes32'], # Solidity 0.5 "keccak256(bytes)": ["bytes32"], # Solidity 0.5
"ecrecover(bytes32,uint8,bytes32,bytes32)":['address'], "sha256()": ["bytes32"],
"selfdestruct(address)":[], "sha256(bytes)": ["bytes32"], # Solidity 0.5
"suicide(address)":[], "sha3()": ["bytes32"],
"log0(bytes32)":[], "ripemd160()": ["bytes32"],
"log1(bytes32,bytes32)":[], "ripemd160(bytes)": ["bytes32"], # Solidity 0.5
"log2(bytes32,bytes32,bytes32)":[], "ecrecover(bytes32,uint8,bytes32,bytes32)": ["address"],
"log3(bytes32,bytes32,bytes32,bytes32)":[], "selfdestruct(address)": [],
"blockhash(uint256)":['bytes32'], "suicide(address)": [],
"log0(bytes32)": [],
"log1(bytes32,bytes32)": [],
"log2(bytes32,bytes32,bytes32)": [],
"log3(bytes32,bytes32,bytes32,bytes32)": [],
"blockhash(uint256)": ["bytes32"],
# the following need a special handling # the following need a special handling
# as they are recognized as a SolidityVariableComposed # as they are recognized as a SolidityVariableComposed
# and converted to a SolidityFunction by SlithIR # and converted to a SolidityFunction by SlithIR
"this.balance()":['uint256'], "this.balance()": ["uint256"],
"abi.encode()":['bytes'], "abi.encode()": ["bytes"],
"abi.encodePacked()":['bytes'], "abi.encodePacked()": ["bytes"],
"abi.encodeWithSelector()":["bytes"], "abi.encodeWithSelector()": ["bytes"],
"abi.encodeWithSignature()":["bytes"], "abi.encodeWithSignature()": ["bytes"],
# abi.decode returns an a list arbitrary types # abi.decode returns an a list arbitrary types
"abi.decode()":[], "abi.decode()": [],
"type(address)":[]} "type(address)": [],
}
def solidity_function_signature(name): def solidity_function_signature(name):
""" """
@ -70,25 +79,25 @@ def solidity_function_signature(name):
Returns: Returns:
str str
""" """
return name+' returns({})'.format(','.join(SOLIDITY_FUNCTIONS[name])) return name + " returns({})".format(",".join(SOLIDITY_FUNCTIONS[name]))
class SolidityVariable(Context):
def __init__(self, name): class SolidityVariable(Context):
def __init__(self, name: str):
super(SolidityVariable, self).__init__() super(SolidityVariable, self).__init__()
self._check_name(name) self._check_name(name)
self._name = name self._name = name
# dev function, will be removed once the code is stable # dev function, will be removed once the code is stable
def _check_name(self, name): def _check_name(self, name: str):
assert name in SOLIDITY_VARIABLES assert name in SOLIDITY_VARIABLES
@property @property
def name(self): def name(self) -> str:
return self._name return self._name
@property @property
def type(self): def type(self) -> ElementaryType:
return ElementaryType(SOLIDITY_VARIABLES[self.name]) return ElementaryType(SOLIDITY_VARIABLES[self.name])
def __str__(self): def __str__(self):
@ -100,19 +109,20 @@ class SolidityVariable(Context):
def __hash__(self): def __hash__(self):
return hash(self.name) return hash(self.name)
class SolidityVariableComposed(SolidityVariable): class SolidityVariableComposed(SolidityVariable):
def __init__(self, name): def __init__(self, name: str):
super(SolidityVariableComposed, self).__init__(name) super(SolidityVariableComposed, self).__init__(name)
def _check_name(self, name): def _check_name(self, name: str):
assert name in SOLIDITY_VARIABLES_COMPOSED assert name in SOLIDITY_VARIABLES_COMPOSED
@property @property
def name(self): def name(self) -> str:
return self._name return self._name
@property @property
def type(self): def type(self) -> ElementaryType:
return ElementaryType(SOLIDITY_VARIABLES_COMPOSED[self.name]) return ElementaryType(SOLIDITY_VARIABLES_COMPOSED[self.name])
def __str__(self): def __str__(self):
@ -131,25 +141,28 @@ class SolidityFunction:
# https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information # https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information
# As a result, we set return_type during the Ir conversion # As a result, we set return_type during the Ir conversion
def __init__(self, name): def __init__(self, name: str):
assert name in SOLIDITY_FUNCTIONS assert name in SOLIDITY_FUNCTIONS
self._name = name self._name = name
self._return_type = [ElementaryType(x) for x in SOLIDITY_FUNCTIONS[self.name]] # Can be TypeInformation if type(address) is used
self._return_type: List[Union[TypeInformation, ElementaryType]] = [
ElementaryType(x) for x in SOLIDITY_FUNCTIONS[self.name]
]
@property @property
def name(self): def name(self) -> str:
return self._name return self._name
@property @property
def full_name(self): def full_name(self) -> str:
return self.name return self.name
@property @property
def return_type(self): def return_type(self) -> List[Union[TypeInformation, ElementaryType]]:
return self._return_type return self._return_type
@return_type.setter @return_type.setter
def return_type(self, r): def return_type(self, r: List[Union[TypeInformation, ElementaryType]]):
self._return_type = r self._return_type = r
def __str__(self): def __str__(self):

@ -1,30 +1,43 @@
from slither.core.source_mapping.source_mapping import SourceMapping from typing import List, TYPE_CHECKING, Dict
from slither.core.children.child_contract import ChildContract from slither.core.children.child_contract import ChildContract
from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.variable import Variable if TYPE_CHECKING:
from slither.core.variables.structure_variable import StructureVariable
class Structure(ChildContract, SourceMapping):
class Structure(ChildContract, SourceMapping):
def __init__(self): def __init__(self):
super(Structure, self).__init__() super(Structure, self).__init__()
self._name = None self._name = None
self._canonical_name = None self._canonical_name = None
self._elems = None self._elems: Dict[str, "StructureVariable"] = dict()
# Name of the elements in the order of declaration # Name of the elements in the order of declaration
self._elems_ordered = None self._elems_ordered: List[str] = []
@property @property
def canonical_name(self): def canonical_name(self) -> str:
return self._canonical_name return self._canonical_name
@canonical_name.setter
def canonical_name(self, name: str):
self._canonical_name = name
@property @property
def name(self): def name(self) -> str:
return self._name return self._name
@name.setter
def name(self, new_name: str):
self._name = new_name
@property @property
def elems(self): def elems(self) -> Dict[str, "StructureVariable"]:
return self._elems return self._elems
def add_elem_in_order(self, s: str):
self._elems_ordered.append(s)
def is_declared_by(self, contract): def is_declared_by(self, contract):
""" """
@ -35,12 +48,11 @@ class Structure(ChildContract, SourceMapping):
return self.contract == contract return self.contract == contract
@property @property
def elems_ordered(self): def elems_ordered(self) -> List["StructureVariable"]:
ret = [] ret = []
for e in self._elems_ordered: for e in self._elems_ordered:
ret.append(self._elems[e]) ret.append(self._elems[e])
return ret return ret
def __str__(self): def __str__(self):
return self.name return self.name

@ -1,37 +1,27 @@
''' """
Nodes of the dominator tree Nodes of the dominator tree
''' """
from typing import TYPE_CHECKING, Set, List
from slither.core.children.child_function import ChildFunction if TYPE_CHECKING:
from slither.core.cfg.node import Node
class DominatorNode(object):
class DominatorNode(object):
def __init__(self): def __init__(self):
self._succ = set() self._succ: Set["Node"] = set()
self._nodes = [] self._nodes: List["Node"] = []
def add_node(self, node): def add_node(self, node: "Node"):
self._nodes.append(node) self._nodes.append(node)
def add_successor(self, succ): def add_successor(self, succ: "Node"):
self._succ.add(succ) self._succ.add(succ)
@property @property
def cfg_nodes(self): def cfg_nodes(self) -> List["Node"]:
return self._nodes return self._nodes
@property @property
def sucessors(self): def sucessors(self) -> Set["Node"]:
'''
Returns:
dict(Node)
'''
return self._succ return self._succ
class DominatorTree(ChildFunction):
def __init__(self, entry_point):
super(DominatorTree, self).__init__()

@ -1,6 +1,12 @@
from typing import List, TYPE_CHECKING
from slither.core.cfg.node import NodeType from slither.core.cfg.node import NodeType
def intersection_predecessor(node): if TYPE_CHECKING:
from slither.core.cfg.node import Node
def intersection_predecessor(node: "Node"):
if not node.fathers: if not node.fathers:
return set() return set()
ret = node.fathers[0].dominators ret = node.fathers[0].dominators
@ -8,13 +14,14 @@ def intersection_predecessor(node):
ret = ret.intersection(pred.dominators) ret = ret.intersection(pred.dominators)
return ret return ret
def compute_dominators(nodes):
''' def compute_dominators(nodes: List["Node"]):
"""
Naive implementation of Cooper, Harvey, Kennedy algo Naive implementation of Cooper, Harvey, Kennedy algo
See 'A Simple,Fast Dominance Algorithm' See 'A Simple,Fast Dominance Algorithm'
Compute strict domniators Compute strict domniators
''' """
changed = True changed = True
for n in nodes: for n in nodes:
@ -36,33 +43,38 @@ def compute_dominators(nodes):
for dominator in node.dominators: for dominator in node.dominators:
if dominator != node: if dominator != node:
[idom_candidates.remove(d) for d in dominator.dominators if d in idom_candidates and d!=dominator] [
idom_candidates.remove(d)
for d in dominator.dominators
if d in idom_candidates and d != dominator
]
assert len(idom_candidates)<=1 assert len(idom_candidates) <= 1
if idom_candidates: if idom_candidates:
idom = idom_candidates.pop() idom = idom_candidates.pop()
node.immediate_dominator = idom node.immediate_dominator = idom
idom.dominator_successors.add(node) idom.dominator_successors.add(node)
def compute_dominance_frontier(nodes: List["Node"]):
def compute_dominance_frontier(nodes): """
'''
Naive implementation of Cooper, Harvey, Kennedy algo Naive implementation of Cooper, Harvey, Kennedy algo
See 'A Simple,Fast Dominance Algorithm' See 'A Simple,Fast Dominance Algorithm'
Compute dominance frontier Compute dominance frontier
''' """
for node in nodes: for node in nodes:
if len(node.fathers) >= 2: if len(node.fathers) >= 2:
for father in node.fathers: for father in node.fathers:
runner = father runner = father
# Corner case: if there is a if without else # Corner case: if there is a if without else
# we need to add update the conditional node # we need to add update the conditional node
if runner == node.immediate_dominator and runner.type == NodeType.IF and node.type == NodeType.ENDIF: if (
runner == node.immediate_dominator
and runner.type == NodeType.IF
and node.type == NodeType.ENDIF
):
runner.dominance_frontier = runner.dominance_frontier.union({node}) runner.dominance_frontier = runner.dominance_frontier.union({node})
while runner != node.immediate_dominator: while runner != node.immediate_dominator:
runner.dominance_frontier = runner.dominance_frontier.union({node}) runner.dominance_frontier = runner.dominance_frontier.union({node})
runner = runner.immediate_dominator runner = runner.immediate_dominator

@ -1,3 +1,5 @@
from slither.exceptions import SlitherException from slither.exceptions import SlitherException
class SlitherCoreError(SlitherException): pass
class SlitherCoreError(SlitherException):
pass

@ -1,11 +1,18 @@
import logging import logging
from enum import Enum
from typing import Optional, TYPE_CHECKING, List
from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression_typed import ExpressionTyped
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.exceptions import SlitherCoreError from slither.core.exceptions import SlitherCoreError
if TYPE_CHECKING:
from slither.core.solidity_types.type import Type
logger = logging.getLogger("AssignmentOperation") logger = logging.getLogger("AssignmentOperation")
class AssignmentOperationType:
class AssignmentOperationType(Enum):
ASSIGN = 0 # = ASSIGN = 0 # =
ASSIGN_OR = 1 # |= ASSIGN_OR = 1 # |=
ASSIGN_CARET = 2 # ^= ASSIGN_CARET = 2 # ^=
@ -19,93 +26,93 @@ class AssignmentOperationType:
ASSIGN_MODULO = 10 # %= ASSIGN_MODULO = 10 # %=
@staticmethod @staticmethod
def get_type(operation_type): def get_type(operation_type: "AssignmentOperationType"):
if operation_type == '=': if operation_type == "=":
return AssignmentOperationType.ASSIGN return AssignmentOperationType.ASSIGN
if operation_type == '|=': if operation_type == "|=":
return AssignmentOperationType.ASSIGN_OR return AssignmentOperationType.ASSIGN_OR
if operation_type == '^=': if operation_type == "^=":
return AssignmentOperationType.ASSIGN_CARET return AssignmentOperationType.ASSIGN_CARET
if operation_type == '&=': if operation_type == "&=":
return AssignmentOperationType.ASSIGN_AND return AssignmentOperationType.ASSIGN_AND
if operation_type == '<<=': if operation_type == "<<=":
return AssignmentOperationType.ASSIGN_LEFT_SHIFT return AssignmentOperationType.ASSIGN_LEFT_SHIFT
if operation_type == '>>=': if operation_type == ">>=":
return AssignmentOperationType.ASSIGN_RIGHT_SHIFT return AssignmentOperationType.ASSIGN_RIGHT_SHIFT
if operation_type == '+=': if operation_type == "+=":
return AssignmentOperationType.ASSIGN_ADDITION return AssignmentOperationType.ASSIGN_ADDITION
if operation_type == '-=': if operation_type == "-=":
return AssignmentOperationType.ASSIGN_SUBTRACTION return AssignmentOperationType.ASSIGN_SUBTRACTION
if operation_type == '*=': if operation_type == "*=":
return AssignmentOperationType.ASSIGN_MULTIPLICATION return AssignmentOperationType.ASSIGN_MULTIPLICATION
if operation_type == '/=': if operation_type == "/=":
return AssignmentOperationType.ASSIGN_DIVISION return AssignmentOperationType.ASSIGN_DIVISION
if operation_type == '%=': if operation_type == "%=":
return AssignmentOperationType.ASSIGN_MODULO return AssignmentOperationType.ASSIGN_MODULO
raise SlitherCoreError('get_type: Unknown operation type {})'.format(operation_type)) raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type))
@staticmethod def __str__(self):
def str(operation_type): if self == AssignmentOperationType.ASSIGN:
if operation_type == AssignmentOperationType.ASSIGN: return "="
return '=' if self == AssignmentOperationType.ASSIGN_OR:
if operation_type == AssignmentOperationType.ASSIGN_OR: return "|="
return '|=' if self == AssignmentOperationType.ASSIGN_CARET:
if operation_type == AssignmentOperationType.ASSIGN_CARET: return "^="
return '^=' if self == AssignmentOperationType.ASSIGN_AND:
if operation_type == AssignmentOperationType.ASSIGN_AND: return "&="
return '&=' if self == AssignmentOperationType.ASSIGN_LEFT_SHIFT:
if operation_type == AssignmentOperationType.ASSIGN_LEFT_SHIFT: return "<<="
return '<<=' if self == AssignmentOperationType.ASSIGN_RIGHT_SHIFT:
if operation_type == AssignmentOperationType.ASSIGN_RIGHT_SHIFT: return ">>="
return '>>=' if self == AssignmentOperationType.ASSIGN_ADDITION:
if operation_type == AssignmentOperationType.ASSIGN_ADDITION: return "+="
return '+=' if self == AssignmentOperationType.ASSIGN_SUBTRACTION:
if operation_type == AssignmentOperationType.ASSIGN_SUBTRACTION: return "-="
return '-=' if self == AssignmentOperationType.ASSIGN_MULTIPLICATION:
if operation_type == AssignmentOperationType.ASSIGN_MULTIPLICATION: return "*="
return '*=' if self == AssignmentOperationType.ASSIGN_DIVISION:
if operation_type == AssignmentOperationType.ASSIGN_DIVISION: return "/="
return '/=' if self == AssignmentOperationType.ASSIGN_MODULO:
if operation_type == AssignmentOperationType.ASSIGN_MODULO: return "%="
return '%=' raise SlitherCoreError("str: Unknown operation type {})".format(self))
raise SlitherCoreError('str: Unknown operation type {})'.format(operation_type))
class AssignmentOperation(ExpressionTyped):
def __init__(self, left_expression, right_expression, expression_type, expression_return_type): class AssignmentOperation(ExpressionTyped):
def __init__(
self,
left_expression: Expression,
right_expression: Expression,
expression_type: AssignmentOperationType,
expression_return_type: Optional["Type"],
):
assert isinstance(left_expression, Expression) assert isinstance(left_expression, Expression)
assert isinstance(right_expression, Expression) assert isinstance(right_expression, Expression)
super(AssignmentOperation, self).__init__() super(AssignmentOperation, self).__init__()
left_expression.set_lvalue() left_expression.set_lvalue()
self._expressions = [left_expression, right_expression] self._expressions = [left_expression, right_expression]
self._type = expression_type self._type = expression_type
self._expression_return_type = expression_return_type self._expression_return_type: Optional["Type"] = expression_return_type
@property @property
def expressions(self): def expressions(self) -> List[Expression]:
return self._expressions return self._expressions
@property @property
def expression_return_type(self): def expression_return_type(self) -> Optional["Type"]:
return self._expression_return_type return self._expression_return_type
@property @property
def expression_left(self): def expression_left(self) -> Expression:
return self._expressions[0] return self._expressions[0]
@property @property
def expression_right(self): def expression_right(self) -> Expression:
return self._expressions[1] return self._expressions[1]
@property @property
def type(self): def type(self) -> AssignmentOperationType:
return self._type return self._type
@property
def type_str(self):
return AssignmentOperationType.str(self._type)
def __str__(self): def __str__(self):
return str(self.expression_left) + " "+ self.type_str + " " + str(self.expression_right) return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right)

@ -1,11 +1,16 @@
import logging import logging
from enum import Enum
from typing import List
from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression_typed import ExpressionTyped
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.exceptions import SlitherCoreError from slither.core.exceptions import SlitherCoreError
logger = logging.getLogger("BinaryOperation") logger = logging.getLogger("BinaryOperation")
class BinaryOperationType:
class BinaryOperationType(Enum):
POWER = 0 # ** POWER = 0 # **
MULTIPLICATION = 1 # * MULTIPLICATION = 1 # *
DIVISION = 2 # / DIVISION = 2 # /
@ -27,119 +32,113 @@ class BinaryOperationType:
OROR = 18 # || OROR = 18 # ||
@staticmethod @staticmethod
def get_type(operation_type): def get_type(operation_type: "BinaryOperation"):
if operation_type == '**': if operation_type == "**":
return BinaryOperationType.POWER return BinaryOperationType.POWER
if operation_type == '*': if operation_type == "*":
return BinaryOperationType.MULTIPLICATION return BinaryOperationType.MULTIPLICATION
if operation_type == '/': if operation_type == "/":
return BinaryOperationType.DIVISION return BinaryOperationType.DIVISION
if operation_type == '%': if operation_type == "%":
return BinaryOperationType.MODULO return BinaryOperationType.MODULO
if operation_type == '+': if operation_type == "+":
return BinaryOperationType.ADDITION return BinaryOperationType.ADDITION
if operation_type == '-': if operation_type == "-":
return BinaryOperationType.SUBTRACTION return BinaryOperationType.SUBTRACTION
if operation_type == '<<': if operation_type == "<<":
return BinaryOperationType.LEFT_SHIFT return BinaryOperationType.LEFT_SHIFT
if operation_type == '>>': if operation_type == ">>":
return BinaryOperationType.RIGHT_SHIFT return BinaryOperationType.RIGHT_SHIFT
if operation_type == '&': if operation_type == "&":
return BinaryOperationType.AND return BinaryOperationType.AND
if operation_type == '^': if operation_type == "^":
return BinaryOperationType.CARET return BinaryOperationType.CARET
if operation_type == '|': if operation_type == "|":
return BinaryOperationType.OR return BinaryOperationType.OR
if operation_type == '<': if operation_type == "<":
return BinaryOperationType.LESS return BinaryOperationType.LESS
if operation_type == '>': if operation_type == ">":
return BinaryOperationType.GREATER return BinaryOperationType.GREATER
if operation_type == '<=': if operation_type == "<=":
return BinaryOperationType.LESS_EQUAL return BinaryOperationType.LESS_EQUAL
if operation_type == '>=': if operation_type == ">=":
return BinaryOperationType.GREATER_EQUAL return BinaryOperationType.GREATER_EQUAL
if operation_type == '==': if operation_type == "==":
return BinaryOperationType.EQUAL return BinaryOperationType.EQUAL
if operation_type == '!=': if operation_type == "!=":
return BinaryOperationType.NOT_EQUAL return BinaryOperationType.NOT_EQUAL
if operation_type == '&&': if operation_type == "&&":
return BinaryOperationType.ANDAND return BinaryOperationType.ANDAND
if operation_type == '||': if operation_type == "||":
return BinaryOperationType.OROR return BinaryOperationType.OROR
raise SlitherCoreError('get_type: Unknown operation type {})'.format(operation_type)) raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type))
@staticmethod def __str__(self):
def str(operation_type): if self == BinaryOperationType.POWER:
if operation_type == BinaryOperationType.POWER: return "**"
return '**' if self == BinaryOperationType.MULTIPLICATION:
if operation_type == BinaryOperationType.MULTIPLICATION: return "*"
return '*' if self == BinaryOperationType.DIVISION:
if operation_type == BinaryOperationType.DIVISION: return "/"
return '/' if self == BinaryOperationType.MODULO:
if operation_type == BinaryOperationType.MODULO: return "%"
return '%' if self == BinaryOperationType.ADDITION:
if operation_type == BinaryOperationType.ADDITION: return "+"
return '+' if self == BinaryOperationType.SUBTRACTION:
if operation_type == BinaryOperationType.SUBTRACTION: return "-"
return '-' if self == BinaryOperationType.LEFT_SHIFT:
if operation_type == BinaryOperationType.LEFT_SHIFT: return "<<"
return '<<' if self == BinaryOperationType.RIGHT_SHIFT:
if operation_type == BinaryOperationType.RIGHT_SHIFT: return ">>"
return '>>' if self == BinaryOperationType.AND:
if operation_type == BinaryOperationType.AND: return "&"
return '&' if self == BinaryOperationType.CARET:
if operation_type == BinaryOperationType.CARET: return "^"
return '^' if self == BinaryOperationType.OR:
if operation_type == BinaryOperationType.OR: return "|"
return '|' if self == BinaryOperationType.LESS:
if operation_type == BinaryOperationType.LESS: return "<"
return '<' if self == BinaryOperationType.GREATER:
if operation_type == BinaryOperationType.GREATER: return ">"
return '>' if self == BinaryOperationType.LESS_EQUAL:
if operation_type == BinaryOperationType.LESS_EQUAL: return "<="
return '<=' if self == BinaryOperationType.GREATER_EQUAL:
if operation_type == BinaryOperationType.GREATER_EQUAL: return ">="
return '>=' if self == BinaryOperationType.EQUAL:
if operation_type == BinaryOperationType.EQUAL: return "=="
return '==' if self == BinaryOperationType.NOT_EQUAL:
if operation_type == BinaryOperationType.NOT_EQUAL: return "!="
return '!=' if self == BinaryOperationType.ANDAND:
if operation_type == BinaryOperationType.ANDAND: return "&&"
return '&&' if self == BinaryOperationType.OROR:
if operation_type == BinaryOperationType.OROR: return "||"
return '||' raise SlitherCoreError("str: Unknown operation type {})".format(self))
raise SlitherCoreError('str: Unknown operation type {})'.format(operation_type))
class BinaryOperation(ExpressionTyped):
class BinaryOperation(ExpressionTyped):
def __init__(self, left_expression, right_expression, expression_type): def __init__(self, left_expression, right_expression, expression_type):
assert isinstance(left_expression, Expression) assert isinstance(left_expression, Expression)
assert isinstance(right_expression, Expression) assert isinstance(right_expression, Expression)
super(BinaryOperation, self).__init__() super(BinaryOperation, self).__init__()
self._expressions = [left_expression, right_expression] self._expressions = [left_expression, right_expression]
self._type = expression_type self._type: BinaryOperationType = expression_type
@property @property
def expressions(self): def expressions(self) -> List[Expression]:
return self._expressions return self._expressions
@property @property
def expression_left(self): def expression_left(self) -> Expression:
return self._expressions[0] return self._expressions[0]
@property @property
def expression_right(self): def expression_right(self) -> Expression:
return self._expressions[1] return self._expressions[1]
@property @property
def type(self): def type(self) -> BinaryOperationType:
return self._type return self._type
@property
def type_str(self):
return BinaryOperationType.str(self._type)
def __str__(self): def __str__(self):
return str(self.expression_left) + ' ' + self.type_str + ' ' + str(self.expression_right) return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right)

@ -1,23 +1,24 @@
from typing import Optional, List
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
class CallExpression(Expression): class CallExpression(Expression):
def __init__(self, called, arguments, type_call): def __init__(self, called, arguments, type_call):
assert isinstance(called, Expression) assert isinstance(called, Expression)
super(CallExpression, self).__init__() super(CallExpression, self).__init__()
self._called = called self._called: Expression = called
self._arguments = arguments self._arguments: List[Expression] = arguments
self._type_call = type_call self._type_call: str = type_call
# gas and value are only available if the syntax is {gas: , value: } # gas and value are only available if the syntax is {gas: , value: }
# For the .gas().value(), the member are considered as function call # For the .gas().value(), the member are considered as function call
# And converted later to the correct info (convert.py) # And converted later to the correct info (convert.py)
self._gas = None self._gas: Optional[Expression] = None
self._value = None self._value: Optional[Expression] = None
self._salt = None self._salt: Optional[Expression] = None
@property @property
def call_value(self): def call_value(self) -> Optional[Expression]:
return self._value return self._value
@call_value.setter @call_value.setter
@ -25,7 +26,7 @@ class CallExpression(Expression):
self._value = v self._value = v
@property @property
def call_gas(self): def call_gas(self) -> Optional[Expression]:
return self._gas return self._gas
@call_gas.setter @call_gas.setter
@ -41,24 +42,32 @@ class CallExpression(Expression):
self._salt = salt self._salt = salt
@property @property
def called(self): def call_salt(self):
return self._salt
@call_salt.setter
def call_salt(self, salt):
self._salt = salt
@property
def called(self) -> Expression:
return self._called return self._called
@property @property
def arguments(self): def arguments(self) -> List[Expression]:
return self._arguments return self._arguments
@property @property
def type_call(self): def type_call(self) -> str:
return self._type_call return self._type_call
def __str__(self): def __str__(self):
txt = str(self._called) txt = str(self._called)
if self.call_gas or self.call_value: if self.call_gas or self.call_value:
gas = f'gas: {self.call_gas}' if self.call_gas else '' gas = f"gas: {self.call_gas}" if self.call_gas else ""
value = f'value: {self.call_value}' if self.call_value else '' value = f"value: {self.call_value}" if self.call_value else ""
salt = f'salt: {self.call_salt}' if self.call_salt else '' salt = f"salt: {self.call_salt}" if self.call_salt else ""
if gas or value or salt: if gas or value or salt:
options = [gas, value, salt] options = [gas, value, salt]
txt += '{' + ','.join([o for o in options if o != '']) + '}' txt += "{" + ",".join([o for o in options if o != ""]) + "}"
return txt + '(' + ','.join([str(a) for a in self._arguments]) + ')' return txt + "(" + ",".join([str(a) for a in self._arguments]) + ")"

@ -1,32 +1,40 @@
from typing import List
from .expression import Expression from .expression import Expression
class ConditionalExpression(Expression):
class ConditionalExpression(Expression):
def __init__(self, if_expression, then_expression, else_expression): def __init__(self, if_expression, then_expression, else_expression):
assert isinstance(if_expression, Expression) assert isinstance(if_expression, Expression)
assert isinstance(then_expression, Expression) assert isinstance(then_expression, Expression)
assert isinstance(else_expression, Expression) assert isinstance(else_expression, Expression)
super(ConditionalExpression, self).__init__() super(ConditionalExpression, self).__init__()
self._if_expression = if_expression self._if_expression: Expression = if_expression
self._then_expression = then_expression self._then_expression: Expression = then_expression
self._else_expression = else_expression self._else_expression: Expression = else_expression
@property @property
def expressions(self): def expressions(self) -> List[Expression]:
return [self._if_expression, self._then_expression, self._else_expression] return [self._if_expression, self._then_expression, self._else_expression]
@property @property
def if_expression(self): def if_expression(self) -> Expression:
return self._if_expression return self._if_expression
@property @property
def else_expression(self): def else_expression(self) -> Expression:
return self._else_expression return self._else_expression
@property @property
def then_expression(self): def then_expression(self) -> Expression:
return self._then_expression return self._then_expression
def __str__(self): def __str__(self):
return 'if ' + str(self._if_expression) + ' then ' + str(self._then_expression) + ' else ' + str(self._else_expression) return (
"if "
+ str(self._if_expression)
+ " then "
+ str(self._then_expression)
+ " else "
+ str(self._else_expression)
)

@ -4,17 +4,16 @@
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
class ElementaryTypeNameExpression(Expression):
class ElementaryTypeNameExpression(Expression):
def __init__(self, t): def __init__(self, t):
assert isinstance(t, Type) assert isinstance(t, Type)
super(ElementaryTypeNameExpression, self).__init__() super(ElementaryTypeNameExpression, self).__init__()
self._type = t self._type = t
@property @property
def type(self): def type(self) -> Type:
return self._type return self._type
def __str__(self): def __str__(self):
return str(self._type) return str(self._type)

@ -1,15 +1,14 @@
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
class Expression( SourceMapping):
class Expression(SourceMapping):
def __init__(self): def __init__(self):
super(Expression, self).__init__() super(Expression, self).__init__()
self._is_lvalue = False self._is_lvalue = False
@property @property
def is_lvalue(self): def is_lvalue(self) -> bool:
return self._is_lvalue return self._is_lvalue
def set_lvalue(self): def set_lvalue(self):
self._is_lvalue = True self._is_lvalue = True

@ -1,13 +1,16 @@
from typing import Optional, TYPE_CHECKING
from .expression import Expression from .expression import Expression
class ExpressionTyped(Expression): if TYPE_CHECKING:
from ..solidity_types.type import Type
class ExpressionTyped(Expression):
def __init__(self): def __init__(self):
super(ExpressionTyped, self).__init__() super(ExpressionTyped, self).__init__()
self._type = None self._type: Optional["Type"] = None
@property @property
def type(self): def type(self):
return self._type return self._type

@ -1,15 +1,19 @@
from typing import TYPE_CHECKING
from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression_typed import ExpressionTyped
class Identifier(ExpressionTyped): if TYPE_CHECKING:
from slither.core.variables.variable import Variable
class Identifier(ExpressionTyped):
def __init__(self, value): def __init__(self, value):
super(Identifier, self).__init__() super(Identifier, self).__init__()
self._value = value self._value: "Variable" = value
@property @property
def value(self): def value(self) -> "Variable":
return self._value return self._value
def __str__(self): def __str__(self):
return str(self._value) return str(self._value)

@ -1,32 +1,36 @@
from typing import List, TYPE_CHECKING
from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression_typed import ExpressionTyped
from slither.core.solidity_types.type import Type
class IndexAccess(ExpressionTyped): if TYPE_CHECKING:
from slither.core.expressions.expression import Expression
from slither.core.solidity_types.type import Type
class IndexAccess(ExpressionTyped):
def __init__(self, left_expression, right_expression, index_type): def __init__(self, left_expression, right_expression, index_type):
super(IndexAccess, self).__init__() super(IndexAccess, self).__init__()
self._expressions = [left_expression, right_expression] self._expressions = [left_expression, right_expression]
# TODO type of undexAccess is not always a Type # TODO type of undexAccess is not always a Type
# assert isinstance(index_type, Type) # assert isinstance(index_type, Type)
self._type = index_type self._type: "Type" = index_type
@property @property
def expressions(self): def expressions(self) -> List["Expression"]:
return self._expressions return self._expressions
@property @property
def expression_left(self): def expression_left(self) -> "Expression":
return self._expressions[0] return self._expressions[0]
@property @property
def expression_right(self): def expression_right(self) -> "Expression":
return self._expressions[1] return self._expressions[1]
@property @property
def type(self): def type(self) -> "Type":
return self._type return self._type
def __str__(self): def __str__(self):
return str(self.expression_left) + '[' + str(self.expression_right) + ']' return str(self.expression_left) + "[" + str(self.expression_right) + "]"

@ -1,24 +1,29 @@
from typing import Optional, Union, TYPE_CHECKING
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.utils.arithmetic import convert_subdenomination from slither.utils.arithmetic import convert_subdenomination
class Literal(Expression): if TYPE_CHECKING:
from slither.core.solidity_types.type import Type
class Literal(Expression):
def __init__(self, value, type, subdenomination=None): def __init__(self, value, type, subdenomination=None):
super(Literal, self).__init__() super(Literal, self).__init__()
self._value = value self._value: Union[int, str] = value
self._type = type self._type = type
self._subdenomination = subdenomination self._subdenomination: Optional[str] = subdenomination
@property @property
def value(self): def value(self) -> Union[int, str]:
return self._value return self._value
@property @property
def type(self): def type(self) -> "Type":
return self._type return self._type
@property @property
def subdenomination(self): def subdenomination(self) -> Optional[str]:
return self._subdenomination return self._subdenomination
def __str__(self): def __str__(self):

@ -1,29 +1,33 @@
from typing import TYPE_CHECKING
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression_typed import ExpressionTyped
if TYPE_CHECKING:
from slither.core.solidity_types.type import Type
class MemberAccess(ExpressionTyped):
class MemberAccess(ExpressionTyped):
def __init__(self, member_name, member_type, expression): def __init__(self, member_name, member_type, expression):
# assert isinstance(member_type, Type) # assert isinstance(member_type, Type)
# TODO member_type is not always a Type # TODO member_type is not always a Type
assert isinstance(expression, Expression) assert isinstance(expression, Expression)
super(MemberAccess, self).__init__() super(MemberAccess, self).__init__()
self._type = member_type self._type: "Type" = member_type
self._member_name = member_name self._member_name: str = member_name
self._expression = expression self._expression: Expression = expression
@property @property
def expression(self): def expression(self) -> Expression:
return self._expression return self._expression
@property @property
def member_name(self): def member_name(self) -> str:
return self._member_name return self._member_name
@property @property
def type(self): def type(self) -> "Type":
return self._type return self._type
def __str__(self): def __str__(self):
return str(self.expression) + '.' + self.member_name return str(self.expression) + "." + self.member_name

@ -1,23 +1,23 @@
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
class NewArray(Expression): class NewArray(Expression):
# note: dont conserve the size of the array if provided # note: dont conserve the size of the array if provided
def __init__(self, depth, array_type): def __init__(self, depth, array_type):
super(NewArray, self).__init__() super(NewArray, self).__init__()
assert isinstance(array_type, Type) assert isinstance(array_type, Type)
self._depth = depth self._depth: int = depth
self._array_type = array_type self._array_type: Type = array_type
@property @property
def array_type(self): def array_type(self) -> Type:
return self._array_type return self._array_type
@property @property
def depth(self): def depth(self) -> int:
return self._depth return self._depth
def __str__(self): def __str__(self):
return 'new ' + str(self._array_type) + '[]'* self._depth return "new " + str(self._array_type) + "[]" * self._depth

@ -1,17 +1,16 @@
from .expression import Expression from .expression import Expression
class NewContract(Expression):
class NewContract(Expression):
def __init__(self, contract_name): def __init__(self, contract_name):
super(NewContract, self).__init__() super(NewContract, self).__init__()
self._contract_name = contract_name self._contract_name: str = contract_name
self._gas = None self._gas = None
self._value = None self._value = None
self._salt = None self._salt = None
@property @property
def contract_name(self): def contract_name(self) -> str:
return self._contract_name return self._contract_name
@property @property
@ -30,7 +29,5 @@ class NewContract(Expression):
def call_salt(self, salt): def call_salt(self, salt):
self._salt = salt self._salt = salt
def __str__(self): def __str__(self):
return 'new ' + str(self._contract_name) return "new " + str(self._contract_name)

@ -1,17 +1,16 @@
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.elementary_type import ElementaryType
class NewElementaryType(Expression):
class NewElementaryType(Expression):
def __init__(self, new_type): def __init__(self, new_type):
assert isinstance(new_type, ElementaryType) assert isinstance(new_type, ElementaryType)
super(NewElementaryType, self).__init__() super(NewElementaryType, self).__init__()
self._type = new_type self._type = new_type
@property @property
def type(self): def type(self) -> ElementaryType:
return self._type return self._type
def __str__(self): def __str__(self):
return 'new ' + str(self._type) return "new " + str(self._type)

@ -1,4 +1,6 @@
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.expressions.call_expression import CallExpression from slither.core.expressions.call_expression import CallExpression
class SuperCallExpression(CallExpression): pass
class SuperCallExpression(CallExpression):
pass

@ -1,8 +1,6 @@
from slither.core.expressions.expression_typed import ExpressionTyped
from slither.core.expressions.identifier import Identifier from slither.core.expressions.identifier import Identifier
class SuperIdentifier(Identifier):
class SuperIdentifier(Identifier):
def __str__(self): def __str__(self):
return 'super.' + str(self._value) return "super." + str(self._value)

@ -1,17 +1,18 @@
from typing import List
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
class TupleExpression(Expression):
class TupleExpression(Expression):
def __init__(self, expressions): def __init__(self, expressions):
assert all(isinstance(x, Expression) for x in expressions if x) assert all(isinstance(x, Expression) for x in expressions if x)
super(TupleExpression, self).__init__() super(TupleExpression, self).__init__()
self._expressions = expressions self._expressions = expressions
@property @property
def expressions(self): def expressions(self) -> List[Expression]:
return self._expressions return self._expressions
def __str__(self): def __str__(self):
expressions_str = [str(e) for e in self.expressions] expressions_str = [str(e) for e in self.expressions]
return '(' + ','.join(expressions_str) + ')' return "(" + ",".join(expressions_str) + ")"

@ -4,18 +4,16 @@ from slither.core.solidity_types.type import Type
class TypeConversion(ExpressionTyped): class TypeConversion(ExpressionTyped):
def __init__(self, expression, expression_type): def __init__(self, expression, expression_type):
super(TypeConversion, self).__init__() super(TypeConversion, self).__init__()
assert isinstance(expression, Expression) assert isinstance(expression, Expression)
assert isinstance(expression_type, Type) assert isinstance(expression_type, Type)
self._expression = expression self._expression: Expression = expression
self._type = expression_type self._type: Type = expression_type
@property @property
def expression(self): def expression(self) -> Expression:
return self._expression return self._expression
def __str__(self): def __str__(self):
return str(self.type) + '(' + str(self.expression) + ')' return str(self.type) + "(" + str(self.expression) + ")"

@ -1,11 +1,14 @@
import logging import logging
from enum import Enum
from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression_typed import ExpressionTyped
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.exceptions import SlitherCoreError from slither.core.exceptions import SlitherCoreError
logger = logging.getLogger("UnaryOperation") logger = logging.getLogger("UnaryOperation")
class UnaryOperationType:
class UnaryOperationType(Enum):
BANG = 0 # ! BANG = 0 # !
TILD = 1 # ~ TILD = 1 # ~
DELETE = 2 # delete DELETE = 2 # delete
@ -19,96 +22,100 @@ class UnaryOperationType:
@staticmethod @staticmethod
def get_type(operation_type, isprefix): def get_type(operation_type, isprefix):
if isprefix: if isprefix:
if operation_type == '!': if operation_type == "!":
return UnaryOperationType.BANG return UnaryOperationType.BANG
if operation_type == '~': if operation_type == "~":
return UnaryOperationType.TILD return UnaryOperationType.TILD
if operation_type == 'delete': if operation_type == "delete":
return UnaryOperationType.DELETE return UnaryOperationType.DELETE
if operation_type == '++': if operation_type == "++":
return UnaryOperationType.PLUSPLUS_PRE return UnaryOperationType.PLUSPLUS_PRE
if operation_type == '--': if operation_type == "--":
return UnaryOperationType.MINUSMINUS_PRE return UnaryOperationType.MINUSMINUS_PRE
if operation_type == '+': if operation_type == "+":
return UnaryOperationType.PLUS_PRE return UnaryOperationType.PLUS_PRE
if operation_type == '-': if operation_type == "-":
return UnaryOperationType.MINUS_PRE return UnaryOperationType.MINUS_PRE
else: else:
if operation_type == '++': if operation_type == "++":
return UnaryOperationType.PLUSPLUS_POST return UnaryOperationType.PLUSPLUS_POST
if operation_type == '--': if operation_type == "--":
return UnaryOperationType.MINUSMINUS_POST return UnaryOperationType.MINUSMINUS_POST
raise SlitherCoreError('get_type: Unknown operation type {}'.format(operation_type)) raise SlitherCoreError("get_type: Unknown operation type {}".format(operation_type))
@staticmethod def __str__(self):
def str(operation_type): if self == UnaryOperationType.BANG:
if operation_type == UnaryOperationType.BANG: return "!"
return '!' if self == UnaryOperationType.TILD:
if operation_type == UnaryOperationType.TILD: return "~"
return '~' if self == UnaryOperationType.DELETE:
if operation_type == UnaryOperationType.DELETE: return "delete"
return 'delete' if self == UnaryOperationType.PLUS_PRE:
if operation_type == UnaryOperationType.PLUS_PRE: return "+"
return '+' if self == UnaryOperationType.MINUS_PRE:
if operation_type == UnaryOperationType.MINUS_PRE: return "-"
return '-' if self in [UnaryOperationType.PLUSPLUS_PRE, UnaryOperationType.PLUSPLUS_POST]:
if operation_type in [UnaryOperationType.PLUSPLUS_PRE, UnaryOperationType.PLUSPLUS_POST]: return "++"
return '++' if self in [
if operation_type in [UnaryOperationType.MINUSMINUS_PRE, UnaryOperationType.MINUSMINUS_POST]: UnaryOperationType.MINUSMINUS_PRE,
return '--' UnaryOperationType.MINUSMINUS_POST,
]:
return "--"
raise SlitherCoreError('str: Unknown operation type {}'.format(operation_type)) raise SlitherCoreError("str: Unknown operation type {}".format(self))
@staticmethod @staticmethod
def is_prefix(operation_type): def is_prefix(operation_type):
if operation_type in [UnaryOperationType.BANG, if operation_type in [
UnaryOperationType.BANG,
UnaryOperationType.TILD, UnaryOperationType.TILD,
UnaryOperationType.DELETE, UnaryOperationType.DELETE,
UnaryOperationType.PLUSPLUS_PRE, UnaryOperationType.PLUSPLUS_PRE,
UnaryOperationType.MINUSMINUS_PRE, UnaryOperationType.MINUSMINUS_PRE,
UnaryOperationType.PLUS_PRE, UnaryOperationType.PLUS_PRE,
UnaryOperationType.MINUS_PRE]: UnaryOperationType.MINUS_PRE,
]:
return True return True
elif operation_type in [UnaryOperationType.PLUSPLUS_POST, UnaryOperationType.MINUSMINUS_POST]: elif operation_type in [
UnaryOperationType.PLUSPLUS_POST,
UnaryOperationType.MINUSMINUS_POST,
]:
return False return False
raise SlitherCoreError('is_prefix: Unknown operation type {}'.format(operation_type)) raise SlitherCoreError("is_prefix: Unknown operation type {}".format(operation_type))
class UnaryOperation(ExpressionTyped):
class UnaryOperation(ExpressionTyped):
def __init__(self, expression, expression_type): def __init__(self, expression, expression_type):
assert isinstance(expression, Expression) assert isinstance(expression, Expression)
super(UnaryOperation, self).__init__() super(UnaryOperation, self).__init__()
self._expression = expression self._expression: Expression = expression
self._type = expression_type self._type: UnaryOperationType = expression_type
if expression_type in [UnaryOperationType.DELETE, if expression_type in [
UnaryOperationType.DELETE,
UnaryOperationType.PLUSPLUS_PRE, UnaryOperationType.PLUSPLUS_PRE,
UnaryOperationType.MINUSMINUS_PRE, UnaryOperationType.MINUSMINUS_PRE,
UnaryOperationType.PLUSPLUS_POST, UnaryOperationType.PLUSPLUS_POST,
UnaryOperationType.MINUSMINUS_POST, UnaryOperationType.MINUSMINUS_POST,
UnaryOperationType.PLUS_PRE, UnaryOperationType.PLUS_PRE,
UnaryOperationType.MINUS_PRE]: UnaryOperationType.MINUS_PRE,
]:
expression.set_lvalue() expression.set_lvalue()
@property @property
def expression(self): def expression(self) -> Expression:
return self._expression return self._expression
@property @property
def type_str(self): def type(self) -> UnaryOperationType:
return UnaryOperationType.str(self._type)
@property
def type(self):
return self._type return self._type
@property @property
def is_prefix(self): def is_prefix(self) -> bool:
return UnaryOperationType.is_prefix(self._type) return UnaryOperationType.is_prefix(self._type)
def __str__(self): def __str__(self):
if self.is_prefix: if self.is_prefix:
return self.type_str + ' ' + str(self._expression) return str(self.type) + " " + str(self._expression)
else: else:
return str(self._expression) + ' ' + self.type_str return str(self._expression) + " " + str(self.type)

@ -6,39 +6,47 @@ import logging
import json import json
import re import re
from collections import defaultdict from collections import defaultdict
from typing import Optional, Dict, List, Set, Union
from crytic_compile import CryticCompile
from slither.core.context.context import Context from slither.core.context.context import Context
from slither.core.declarations import Contract, Pragma, Import, Function, Modifier
from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import InternalCall from slither.slithir.operations import InternalCall
from slither.slithir.variables import Constant
from slither.utils.colors import red from slither.utils.colors import red
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
logging.basicConfig() logging.basicConfig()
class Slither(Context):
class SlitherCore(Context):
""" """
Slither static analyzer Slither static analyzer
""" """
def __init__(self): def __init__(self):
super(Slither, self).__init__() super(SlitherCore, self).__init__()
self._contracts = {} self._contracts: Dict[str, Contract] = {}
self._filename = None self._filename: Optional[str] = None
self._source_units = {} self._source_units: Dict[int, str] = {}
self._solc_version = None # '0.3' or '0.4':! self._solc_version: Optional[str] = None # '0.3' or '0.4':!
self._pragma_directives = [] self._pragma_directives: List[Pragma] = []
self._import_directives = [] self._import_directives: List[Import] = []
self._raw_source_code = {} self._raw_source_code: Dict[str, str] = {}
self._all_functions = set() self._all_functions: Set[Function] = set()
self._all_modifiers = set() self._all_modifiers: Set[Modifier] = set()
self._all_state_variables = None # Memoize
self._all_state_variables: Optional[Set[StateVariable]] = None
self._previous_results_filename = 'slither.db.json'
self._results_to_hide = [] self._previous_results_filename: str = "slither.db.json"
self._previous_results = [] self._results_to_hide: List = []
self._previous_results_ids = set() self._previous_results: List = []
self._paths_to_filter = set() self._previous_results_ids: Set[str] = set()
self._paths_to_filter: Set[str] = set()
self._crytic_compile = None
self._crytic_compile: Optional[CryticCompile] = None
self._generate_patches = False self._generate_patches = False
self._exclude_dependencies = False self._exclude_dependencies = False
@ -55,20 +63,24 @@ class Slither(Context):
################################################################################### ###################################################################################
@property @property
def source_code(self): def source_code(self) -> Dict[str, str]:
""" {filename: source_code (str)}: source code """ """ {filename: source_code (str)}: source code """
return self._raw_source_code return self._raw_source_code
@property @property
def source_units(self): def source_units(self) -> Dict[int, str]:
return self._source_units return self._source_units
@property @property
def filename(self): def filename(self) -> Optional[str]:
"""str: Filename.""" """str: Filename."""
return self._filename return self._filename
def _add_source_code(self, path): @filename.setter
def filename(self, filename: str):
self._filename = filename
def add_source_code(self, path):
""" """
:param path: :param path:
:return: :return:
@ -76,11 +88,11 @@ class Slither(Context):
if self.crytic_compile and path in self.crytic_compile.src_content: if self.crytic_compile and path in self.crytic_compile.src_content:
self.source_code[path] = self.crytic_compile.src_content[path] self.source_code[path] = self.crytic_compile.src_content[path]
else: else:
with open(path, encoding='utf8', newline='') as f: with open(path, encoding="utf8", newline="") as f:
self.source_code[path] = f.read() self.source_code[path] = f.read()
@property @property
def markdown_root(self): def markdown_root(self) -> str:
return self._markdown_root return self._markdown_root
# endregion # endregion
@ -91,19 +103,19 @@ class Slither(Context):
################################################################################### ###################################################################################
@property @property
def solc_version(self): def solc_version(self) -> str:
"""str: Solidity version.""" """str: Solidity version."""
if self.crytic_compile: if self.crytic_compile:
return self.crytic_compile.compiler_version.version return self.crytic_compile.compiler_version.version
return self._solc_version return self._solc_version
@property @property
def pragma_directives(self): def pragma_directives(self) -> List[Pragma]:
""" list(core.declarations.Pragma): Pragma directives.""" """ list(core.declarations.Pragma): Pragma directives."""
return self._pragma_directives return self._pragma_directives
@property @property
def import_directives(self): def import_directives(self) -> List[Import]:
""" list(core.declarations.Import): Import directives""" """ list(core.declarations.Import): Import directives"""
return self._import_directives return self._import_directives
@ -115,22 +127,23 @@ class Slither(Context):
################################################################################### ###################################################################################
@property @property
def contracts(self): def contracts(self) -> List[Contract]:
"""list(Contract): List of contracts.""" """list(Contract): List of contracts."""
return list(self._contracts.values()) return list(self._contracts.values())
@property @property
def contracts_derived(self): def contracts_derived(self) -> List[Contract]:
"""list(Contract): List of contracts that are derived and not inherited.""" """list(Contract): List of contracts that are derived and not inherited."""
inheritance = (x.inheritance for x in self.contracts) inheritance = (x.inheritance for x in self.contracts)
inheritance = [item for sublist in inheritance for item in sublist] inheritance = [item for sublist in inheritance for item in sublist]
return [c for c in self._contracts.values() if c not in inheritance] return [c for c in self._contracts.values() if c not in inheritance]
def contracts_as_dict(self): @property
def contracts_as_dict(self) -> Dict[str, Contract]:
"""list(dict(str: Contract): List of contracts as dict: name -> Contract.""" """list(dict(str: Contract): List of contracts as dict: name -> Contract."""
return self._contracts return self._contracts
def get_contract_from_name(self, contract_name): def get_contract_from_name(self, contract_name: Union[str, Constant]) -> Optional[Contract]:
""" """
Return a contract from a name Return a contract from a name
Args: Args:
@ -148,24 +161,24 @@ class Slither(Context):
################################################################################### ###################################################################################
@property @property
def functions(self): def functions(self) -> List[Function]:
return list(self._all_functions) return list(self._all_functions)
def add_function(self, func): def add_function(self, func: Function):
self._all_functions.add(func) self._all_functions.add(func)
@property @property
def modifiers(self): def modifiers(self) -> List[Modifier]:
return list(self._all_modifiers) return list(self._all_modifiers)
def add_modifier(self, modif): def add_modifier(self, modif: Modifier):
self._all_modifiers.add(modif) self._all_modifiers.add(modif)
@property @property
def functions_and_modifiers(self): def functions_and_modifiers(self) -> List[Function]:
return self.functions + self.modifiers return self.functions + self.modifiers
def _propagate_function_calls(self): def propagate_function_calls(self):
for f in self.functions_and_modifiers: for f in self.functions_and_modifiers:
for node in f.nodes: for node in f.nodes:
for ir in node.irs_ssa: for ir in node.irs_ssa:
@ -180,7 +193,7 @@ class Slither(Context):
################################################################################### ###################################################################################
@property @property
def state_variables(self): def state_variables(self) -> List[StateVariable]:
if self._all_state_variables is None: if self._all_state_variables is None:
state_variables = [c.state_variables for c in self.contracts] state_variables = [c.state_variables for c in self.contracts]
state_variables = [item for sublist in state_variables for item in sublist] state_variables = [item for sublist in state_variables for item in sublist]
@ -194,13 +207,13 @@ class Slither(Context):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def print_functions(self, d): def print_functions(self, d: str):
""" """
Export all the functions to dot files Export all the functions to dot files
""" """
for c in self.contracts: for c in self.contracts:
for f in c.functions: for f in c.functions:
f.cfg_to_dot(os.path.join(d, '{}.{}.dot'.format(c.name, f.name))) f.cfg_to_dot(os.path.join(d, "{}.{}.dot".format(c.name, f.name)))
# endregion # endregion
################################################################################### ###################################################################################
@ -209,44 +222,53 @@ class Slither(Context):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def relative_path_format(self, path): def relative_path_format(self, path: str) -> str:
""" """
Strip relative paths of "." and ".." Strip relative paths of "." and ".."
""" """
return path.split('..')[-1].strip('.').strip('/') return path.split("..")[-1].strip(".").strip("/")
def valid_result(self, r): def valid_result(self, r: Dict) -> bool:
''' """
Check if the result is valid Check if the result is valid
A result is invalid if: A result is invalid if:
- All its source paths belong to the source path filtered - All its source paths belong to the source path filtered
- Or a similar result was reported and saved during a previous run - Or a similar result was reported and saved during a previous run
- The --exclude-dependencies flag is set and results are only related to dependencies - The --exclude-dependencies flag is set and results are only related to dependencies
''' """
source_mapping_elements = [elem['source_mapping']['filename_absolute'] source_mapping_elements = [
for elem in r['elements'] if 'source_mapping' in elem] elem["source_mapping"]["filename_absolute"]
source_mapping_elements = map(lambda x: os.path.normpath(x) if x else x, source_mapping_elements) for elem in r["elements"]
if "source_mapping" in elem
]
source_mapping_elements = map(
lambda x: os.path.normpath(x) if x else x, source_mapping_elements
)
matching = False matching = False
for path in self._paths_to_filter: for path in self._paths_to_filter:
try: try:
if any(bool(re.search(self.relative_path_format(path), src_mapping)) if any(
for src_mapping in source_mapping_elements): bool(re.search(self.relative_path_format(path), src_mapping))
for src_mapping in source_mapping_elements
):
matching = True matching = True
break break
except re.error: except re.error:
logger.error(f'Incorrect regular expression for --filter-paths {path}.' logger.error(
'\nSlither supports the Python re format' f"Incorrect regular expression for --filter-paths {path}."
': https://docs.python.org/3/library/re.html') "\nSlither supports the Python re format"
": https://docs.python.org/3/library/re.html"
)
if r['elements'] and matching: if r["elements"] and matching:
return False return False
if r['elements'] and self._exclude_dependencies: if r["elements"] and self._exclude_dependencies:
return not all(element['source_mapping']['is_dependency'] for element in r['elements']) return not all(element["source_mapping"]["is_dependency"] for element in r["elements"])
if r['id'] in self._previous_results_ids: if r["id"] in self._previous_results_ids:
return False return False
# Conserve previous result filtering. This is conserved for compatibility, but is meant to be removed # Conserve previous result filtering. This is conserved for compatibility, but is meant to be removed
return not r['description'] in [pr['description'] for pr in self._previous_results] return not r["description"] in [pr["description"] for pr in self._previous_results]
def load_previous_results(self): def load_previous_results(self):
filename = self._previous_results_filename filename = self._previous_results_filename
@ -256,27 +278,29 @@ class Slither(Context):
self._previous_results = json.load(f) self._previous_results = json.load(f)
if self._previous_results: if self._previous_results:
for r in self._previous_results: for r in self._previous_results:
if 'id' in r: if "id" in r:
self._previous_results_ids.add(r['id']) self._previous_results_ids.add(r["id"])
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
logger.error(red('Impossible to decode {}. Consider removing the file'.format(filename))) logger.error(
red("Impossible to decode {}. Consider removing the file".format(filename))
)
def write_results_to_hide(self): def write_results_to_hide(self):
if not self._results_to_hide: if not self._results_to_hide:
return return
filename = self._previous_results_filename filename = self._previous_results_filename
with open(filename, 'w', encoding='utf8') as f: with open(filename, "w", encoding="utf8") as f:
results = self._results_to_hide + self._previous_results results = self._results_to_hide + self._previous_results
json.dump(results, f) json.dump(results, f)
def save_results_to_hide(self, results): def save_results_to_hide(self, results: List[Dict]):
self._results_to_hide += results self._results_to_hide += results
def add_path_to_filter(self, path): def add_path_to_filter(self, path: str):
''' """
Add path to filter Add path to filter
Path are used through direct comparison (no regex) Path are used through direct comparison (no regex)
''' """
self._paths_to_filter.add(path) self._paths_to_filter.add(path)
# endregion # endregion
@ -287,7 +311,7 @@ class Slither(Context):
################################################################################### ###################################################################################
@property @property
def crytic_compile(self): def crytic_compile(self) -> Optional[CryticCompile]:
return self._crytic_compile return self._crytic_compile
# endregion # endregion
@ -298,14 +322,13 @@ class Slither(Context):
################################################################################### ###################################################################################
@property @property
def generate_patches(self): def generate_patches(self) -> bool:
return self._generate_patches return self._generate_patches
@generate_patches.setter @generate_patches.setter
def generate_patches(self, p): def generate_patches(self, p: bool):
self._generate_patches = p self._generate_patches = p
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -314,10 +337,11 @@ class Slither(Context):
################################################################################### ###################################################################################
@property @property
def contract_name_collisions(self): def contract_name_collisions(self) -> Dict:
return self._contract_name_collisions return self._contract_name_collisions
@property @property
def contracts_with_missing_inheritance(self): def contracts_with_missing_inheritance(self) -> Set:
return self._contract_with_missing_inheritance return self._contract_with_missing_inheritance
# endregion # endregion

@ -1,20 +1,21 @@
from slither.core.variables.variable import Variable from typing import Optional
from slither.core.solidity_types.type import Type
from slither.core.expressions.expression import Expression
from slither.core.expressions import Literal from slither.core.expressions import Literal
from slither.core.expressions.expression import Expression
from slither.core.solidity_types.type import Type
from slither.visitors.expression.constants_folding import ConstantFolding from slither.visitors.expression.constants_folding import ConstantFolding
class ArrayType(Type):
class ArrayType(Type):
def __init__(self, t, length): def __init__(self, t, length):
assert isinstance(t, Type) assert isinstance(t, Type)
if length: if length:
if isinstance(length, int): if isinstance(length, int):
length = Literal(length, 'uint256') length = Literal(length, "uint256")
assert isinstance(length, Expression) assert isinstance(length, Expression)
super(ArrayType, self).__init__() super(ArrayType, self).__init__()
self._type = t self._type: Type = t
self._length = length self._length: Optional[Expression] = length
if length: if length:
if not isinstance(length, Literal): if not isinstance(length, Literal):
@ -25,18 +26,21 @@ class ArrayType(Type):
self._length_value = None self._length_value = None
@property @property
def type(self): def type(self) -> Type:
return self._type return self._type
@property @property
def length(self): def length(self) -> Optional[Expression]:
return self._length return self._length
@property
def lenght_value(self) -> Optional[Literal]:
return self._length_value
def __str__(self): def __str__(self):
if self._length: if self._length:
return str(self._type)+'[{}]'.format(str(self._length_value)) return str(self._type) + "[{}]".format(str(self._length_value))
return str(self._type)+'[]' return str(self._type) + "[]"
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, ArrayType): if not isinstance(other, ArrayType):

@ -1,69 +1,175 @@
import itertools import itertools
from typing import Optional
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
# see https://solidity.readthedocs.io/en/v0.4.24/miscellaneous.html?highlight=grammar # see https://solidity.readthedocs.io/en/v0.4.24/miscellaneous.html?highlight=grammar
Int = ['int', 'int8', 'int16', 'int24', 'int32', 'int40', 'int48', 'int56', 'int64', 'int72', 'int80', 'int88', 'int96', 'int104', 'int112', 'int120', 'int128', 'int136', 'int144', 'int152', 'int160', 'int168', 'int176', 'int184', 'int192', 'int200', 'int208', 'int216', 'int224', 'int232', 'int240', 'int248', 'int256'] Int = [
"int",
Uint = ['uint', 'uint8', 'uint16', 'uint24', 'uint32', 'uint40', 'uint48', 'uint56', 'uint64', 'uint72', 'uint80', 'uint88', 'uint96', 'uint104', 'uint112', 'uint120', 'uint128', 'uint136', 'uint144', 'uint152', 'uint160', 'uint168', 'uint176', 'uint184', 'uint192', 'uint200', 'uint208', 'uint216', 'uint224', 'uint232', 'uint240', 'uint248', 'uint256'] "int8",
"int16",
Byte = ['byte', 'bytes', 'bytes1', 'bytes2', 'bytes3', 'bytes4', 'bytes5', 'bytes6', 'bytes7', 'bytes8', 'bytes9', 'bytes10', 'bytes11', 'bytes12', 'bytes13', 'bytes14', 'bytes15', 'bytes16', 'bytes17', 'bytes18', 'bytes19', 'bytes20', 'bytes21', 'bytes22', 'bytes23', 'bytes24', 'bytes25', 'bytes26', 'bytes27', 'bytes28', 'bytes29', 'bytes30', 'bytes31', 'bytes32'] "int24",
"int32",
"int40",
"int48",
"int56",
"int64",
"int72",
"int80",
"int88",
"int96",
"int104",
"int112",
"int120",
"int128",
"int136",
"int144",
"int152",
"int160",
"int168",
"int176",
"int184",
"int192",
"int200",
"int208",
"int216",
"int224",
"int232",
"int240",
"int248",
"int256",
]
Uint = [
"uint",
"uint8",
"uint16",
"uint24",
"uint32",
"uint40",
"uint48",
"uint56",
"uint64",
"uint72",
"uint80",
"uint88",
"uint96",
"uint104",
"uint112",
"uint120",
"uint128",
"uint136",
"uint144",
"uint152",
"uint160",
"uint168",
"uint176",
"uint184",
"uint192",
"uint200",
"uint208",
"uint216",
"uint224",
"uint232",
"uint240",
"uint248",
"uint256",
]
Byte = [
"byte",
"bytes",
"bytes1",
"bytes2",
"bytes3",
"bytes4",
"bytes5",
"bytes6",
"bytes7",
"bytes8",
"bytes9",
"bytes10",
"bytes11",
"bytes12",
"bytes13",
"bytes14",
"bytes15",
"bytes16",
"bytes17",
"bytes18",
"bytes19",
"bytes20",
"bytes21",
"bytes22",
"bytes23",
"bytes24",
"bytes25",
"bytes26",
"bytes27",
"bytes28",
"bytes29",
"bytes30",
"bytes31",
"bytes32",
]
# https://solidity.readthedocs.io/en/v0.4.24/types.html#fixed-point-numbers # https://solidity.readthedocs.io/en/v0.4.24/types.html#fixed-point-numbers
M = list(range(8, 257, 8)) M = list(range(8, 257, 8))
N = list(range(0, 81)) N = list(range(0, 81))
MN = list(itertools.product(M,N)) MN = list(itertools.product(M, N))
Fixed = ['fixed{}x{}'.format(m,n) for (m,n) in MN] + ['fixed'] Fixed = ["fixed{}x{}".format(m, n) for (m, n) in MN] + ["fixed"]
Ufixed = ['ufixed{}x{}'.format(m,n) for (m,n) in MN] + ['ufixed'] Ufixed = ["ufixed{}x{}".format(m, n) for (m, n) in MN] + ["ufixed"]
ElementaryTypeName = ['address', 'bool', 'string', 'var'] + Int + Uint + Byte + Fixed + Ufixed ElementaryTypeName = ["address", "bool", "string", "var"] + Int + Uint + Byte + Fixed + Ufixed
class NonElementaryType(Exception): pass
class ElementaryType(Type): class NonElementaryType(Exception):
pass
class ElementaryType(Type):
def __init__(self, t): def __init__(self, t):
if t not in ElementaryTypeName: if t not in ElementaryTypeName:
raise NonElementaryType raise NonElementaryType
super(ElementaryType, self).__init__() super(ElementaryType, self).__init__()
if t == 'uint': if t == "uint":
t = 'uint256' t = "uint256"
elif t == 'int': elif t == "int":
t = 'int256' t = "int256"
elif t == 'byte': elif t == "byte":
t = 'bytes1' t = "bytes1"
self._type = t self._type = t
@property @property
def type(self): def type(self) -> str:
return self._type return self._type
@property @property
def name(self): def name(self) -> str:
return self.type return self.type
@property @property
def size(self): def size(self) -> Optional[int]:
''' """
Return the size in bits Return the size in bits
Return None if the size is not known Return None if the size is not known
Returns: Returns:
int int
''' """
t = self._type t = self._type
if t.startswith('uint'): if t.startswith("uint"):
return int(t[len('uint'):]) return int(t[len("uint") :])
if t.startswith('int'): if t.startswith("int"):
return int(t[len('int'):]) return int(t[len("int") :])
if t == 'bool': if t == "bool":
return int(8) return int(8)
if t == 'address': if t == "address":
return int(160) return int(160)
if t.startswith('bytes'): if t.startswith("bytes"):
return int(t[len('bytes'):]) return int(t[len("bytes") :])
return None return None
def __str__(self): def __str__(self):
@ -76,4 +182,3 @@ class ElementaryType(Type):
def __hash__(self): def __hash__(self):
return hash(str(self)) return hash(str(self))

@ -1,25 +1,29 @@
from typing import List
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
from slither.core.variables.function_type_variable import FunctionTypeVariable from slither.core.variables.function_type_variable import FunctionTypeVariable
class FunctionType(Type):
def __init__(self, params, return_values): class FunctionType(Type):
def __init__(
self, params: List[FunctionTypeVariable], return_values: List[FunctionTypeVariable]
):
assert all(isinstance(x, FunctionTypeVariable) for x in params) assert all(isinstance(x, FunctionTypeVariable) for x in params)
assert all(isinstance(x, FunctionTypeVariable) for x in return_values) assert all(isinstance(x, FunctionTypeVariable) for x in return_values)
super(FunctionType, self).__init__() super(FunctionType, self).__init__()
self._params = params self._params: List[FunctionTypeVariable] = params
self._return_values = return_values self._return_values: List[FunctionTypeVariable] = return_values
@property @property
def params(self): def params(self) -> List[FunctionTypeVariable]:
return self._params return self._params
@property @property
def return_values(self): def return_values(self) -> List[FunctionTypeVariable]:
return self._return_values return self._return_values
@property @property
def return_type(self): def return_type(self) -> List[Type]:
return [x.type for x in self.return_values] return [x.type for x in self.return_values]
def __str__(self): def __str__(self):
@ -28,33 +32,31 @@ class FunctionType(Type):
params = ",".join([str(x.type) for x in self._params]) params = ",".join([str(x.type) for x in self._params])
return_values = ",".join([str(x.type) for x in self._return_values]) return_values = ",".join([str(x.type) for x in self._return_values])
if return_values: if return_values:
return 'function({}) returns({})'.format(params, return_values) return "function({}) returns({})".format(params, return_values)
return 'function({})'.format(params) return "function({})".format(params)
@property @property
def parameters_signature(self): def parameters_signature(self) -> str:
''' """
Return the parameters signature(without the return statetement) Return the parameters signature(without the return statetement)
''' """
# Use x.type # Use x.type
# x.name may be empty # x.name may be empty
params = ",".join([str(x.type) for x in self._params]) params = ",".join([str(x.type) for x in self._params])
return '({})'.format(params) return "({})".format(params)
@property @property
def signature(self): def signature(self) -> str:
''' """
Return the signature(with the return statetement if it exists) Return the signature(with the return statetement if it exists)
''' """
# Use x.type # Use x.type
# x.name may be empty # x.name may be empty
params = ",".join([str(x.type) for x in self._params]) params = ",".join([str(x.type) for x in self._params])
return_values = ",".join([str(x.type) for x in self._return_values]) return_values = ",".join([str(x.type) for x in self._return_values])
if return_values: if return_values:
return '({}) returns({})'.format(params, return_values) return "({}) returns({})".format(params, return_values)
return '({})'.format(params) return "({})".format(params)
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, FunctionType): if not isinstance(other, FunctionType):

@ -1,7 +1,7 @@
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
class MappingType(Type):
class MappingType(Type):
def __init__(self, type_from, type_to): def __init__(self, type_from, type_to):
assert isinstance(type_from, Type) assert isinstance(type_from, Type)
assert isinstance(type_to, Type) assert isinstance(type_to, Type)
@ -10,15 +10,15 @@ class MappingType(Type):
self._to = type_to self._to = type_to
@property @property
def type_from(self): def type_from(self) -> Type:
return self._from return self._from
@property @property
def type_to(self): def type_to(self) -> Type:
return self._to return self._to
def __str__(self): def __str__(self):
return 'mapping({} => {})'.format(str(self._from), str(self._to)) return "mapping({} => {})".format(str(self._from), str(self._to))
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, MappingType): if not isinstance(other, MappingType):
@ -27,4 +27,3 @@ class MappingType(Type):
def __hash__(self): def __hash__(self):
return hash(str(self)) return hash(str(self))

@ -1,3 +1,5 @@
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
class Type(SourceMapping): pass
class Type(SourceMapping):
pass

@ -1,21 +1,27 @@
from typing import TYPE_CHECKING
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
if TYPE_CHECKING:
from slither.core.declarations.contract import Contract
# Use to model the Type(X) function, which returns an undefined type # Use to model the Type(X) function, which returns an undefined type
# https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information # https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information
class TypeInformation(Type): class TypeInformation(Type):
def __init__(self, c): def __init__(self, c):
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
assert isinstance(c, (Contract)) assert isinstance(c, Contract)
super(TypeInformation, self).__init__() super(TypeInformation, self).__init__()
self._type = c self._type = c
@property @property
def type(self): def type(self) -> "Contract":
return self._type return self._type
def __str__(self): def __str__(self):
return f'type({self.type.name})' return f"type({self.type.name})"
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, TypeInformation): if not isinstance(other, TypeInformation):

@ -1,8 +1,14 @@
from typing import Union, TYPE_CHECKING
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
if TYPE_CHECKING:
from slither.core.declarations.structure import Structure
from slither.core.declarations.enum import Enum
from slither.core.declarations.contract import Contract
class UserDefinedType(Type): class UserDefinedType(Type):
def __init__(self, t): def __init__(self, t):
from slither.core.declarations.structure import Structure from slither.core.declarations.structure import Structure
from slither.core.declarations.enum import Enum from slither.core.declarations.enum import Enum
@ -13,7 +19,7 @@ class UserDefinedType(Type):
self._type = t self._type = t
@property @property
def type(self): def type(self) -> Union["Contract", "Enum", "Structure"]:
return self._type return self._type
def __str__(self): def __str__(self):
@ -21,7 +27,7 @@ class UserDefinedType(Type):
from slither.core.declarations.enum import Enum from slither.core.declarations.enum import Enum
if isinstance(self.type, (Enum, Structure)): if isinstance(self.type, (Enum, Structure)):
return str(self.type.contract)+'.'+str(self.type.name) return str(self.type.contract) + "." + str(self.type.name)
return str(self.type.name) return str(self.type.name)
def __eq__(self, other): def __eq__(self, other):
@ -29,7 +35,5 @@ class UserDefinedType(Type):
return False return False
return self.type == other.type return self.type == other.type
def __hash__(self): def __hash__(self):
return hash(str(self)) return hash(str(self))

@ -1,16 +1,17 @@
import re import re
from typing import Dict, Union, Optional
from slither.core.context.context import Context from slither.core.context.context import Context
class SourceMapping(Context): class SourceMapping(Context):
def __init__(self): def __init__(self):
super(SourceMapping, self).__init__() super(SourceMapping, self).__init__()
self._source_mapping = None # TODO create a namedtuple for the source mapping rather than a dict
self._source_mapping: Optional[Dict] = None
@property @property
def source_mapping(self): def source_mapping(self) -> Optional[Dict]:
return self._source_mapping return self._source_mapping
@staticmethod @staticmethod
@ -21,7 +22,7 @@ class SourceMapping(Context):
Not done in an efficient way Not done in an efficient way
""" """
source_code = source_code.encode('utf-8') source_code = source_code.encode("utf-8")
total_length = len(source_code) total_length = len(source_code)
source_code = source_code.splitlines(True) source_code = source_code.splitlines(True)
counter = 0 counter = 0
@ -38,7 +39,11 @@ class SourceMapping(Context):
# Determine our column numbers. # Determine our column numbers.
if starting_column is None and counter + line_length > start: if starting_column is None and counter + line_length > start:
starting_column = (start - counter) + 1 starting_column = (start - counter) + 1
if starting_column is not None and ending_column is None and counter + line_length > start + length: if (
starting_column is not None
and ending_column is None
and counter + line_length > start + length
):
ending_column = ((start + length) - counter) + 1 ending_column = ((start + length) - counter) + 1
# Advance the current position counter, and determine line numbers. # Advance the current position counter, and determine line numbers.
@ -50,19 +55,19 @@ class SourceMapping(Context):
if counter > start + length: if counter > start + length:
break break
return (lines, starting_column, ending_column) return lines, starting_column, ending_column
@staticmethod @staticmethod
def _convert_source_mapping(offset, slither): def _convert_source_mapping(offset: str, slither):
''' """
Convert a text offset to a real offset Convert a text offset to a real offset
see https://solidity.readthedocs.io/en/develop/miscellaneous.html#source-mappings see https://solidity.readthedocs.io/en/develop/miscellaneous.html#source-mappings
Returns: Returns:
(dict): {'start':0, 'length':0, 'filename': 'file.sol'} (dict): {'start':0, 'length':0, 'filename': 'file.sol'}
''' """
sourceUnits = slither.source_units sourceUnits = slither.source_units
position = re.findall('([0-9]*):([0-9]*):([-]?[0-9]*)', offset) position = re.findall("([0-9]*):([0-9]*):([-]?[0-9]*)", offset)
if len(position) != 1: if len(position) != 1:
return {} return {}
@ -72,7 +77,7 @@ class SourceMapping(Context):
f = int(f) f = int(f)
if f not in sourceUnits: if f not in sourceUnits:
return {'start':s, 'length':l} return {"start": s, "length": l}
filename_used = sourceUnits[f] filename_used = sourceUnits[f]
filename_absolute = None filename_absolute = None
filename_relative = None filename_relative = None
@ -91,7 +96,10 @@ class SourceMapping(Context):
is_dependency = slither.crytic_compile.is_dependency(filename_absolute) is_dependency = slither.crytic_compile.is_dependency(filename_absolute)
if filename_absolute in slither.source_code or filename_absolute in slither.crytic_compile.src_content: if (
filename_absolute in slither.source_code
or filename_absolute in slither.crytic_compile.src_content
):
filename = filename_absolute filename = filename_absolute
elif filename_relative in slither.source_code: elif filename_relative in slither.source_code:
filename = filename_relative filename = filename_relative
@ -104,51 +112,47 @@ class SourceMapping(Context):
if slither.crytic_compile and filename in slither.crytic_compile.src_content: if slither.crytic_compile and filename in slither.crytic_compile.src_content:
source_code = slither.crytic_compile.src_content[filename] source_code = slither.crytic_compile.src_content[filename]
(lines, starting_column, ending_column) = SourceMapping._compute_line(source_code, (lines, starting_column, ending_column) = SourceMapping._compute_line(source_code, s, l)
s,
l)
elif filename in slither.source_code: elif filename in slither.source_code:
source_code = slither.source_code[filename] source_code = slither.source_code[filename]
(lines, starting_column, ending_column) = SourceMapping._compute_line(source_code, (lines, starting_column, ending_column) = SourceMapping._compute_line(source_code, s, l)
s,
l)
else: else:
(lines, starting_column, ending_column) = ([], None, None) (lines, starting_column, ending_column) = ([], None, None)
return {'start':s, return {
'length':l, "start": s,
'filename_used': filename_used, "length": l,
'filename_relative': filename_relative, "filename_used": filename_used,
'filename_absolute': filename_absolute, "filename_relative": filename_relative,
'filename_short': filename_short, "filename_absolute": filename_absolute,
'is_dependency': is_dependency, "filename_short": filename_short,
'lines' : lines, "is_dependency": is_dependency,
'starting_column': starting_column, "lines": lines,
'ending_column': ending_column "starting_column": starting_column,
"ending_column": ending_column,
} }
def set_offset(self, offset, slither): def set_offset(self, offset: Union[Dict, str], slither):
if isinstance(offset, dict): if isinstance(offset, dict):
self._source_mapping = offset self._source_mapping = offset
else: else:
self._source_mapping = self._convert_source_mapping(offset, slither) self._source_mapping = self._convert_source_mapping(offset, slither)
def _get_lines_str(self, line_descr=""): def _get_lines_str(self, line_descr=""):
lines = self.source_mapping.get('lines', None) lines = self.source_mapping.get("lines", None)
if not lines: if not lines:
lines = '' lines = ""
elif len(lines) == 1: elif len(lines) == 1:
lines = '#{}{}'.format(line_descr, lines[0]) lines = "#{}{}".format(line_descr, lines[0])
else: else:
lines = '#{}{}-{}{}'.format(line_descr, lines[0], line_descr, lines[-1]) lines = "#{}{}-{}{}".format(line_descr, lines[0], line_descr, lines[-1])
return lines return lines
def source_mapping_to_markdown(self, markdown_root): def source_mapping_to_markdown(self, markdown_root: str) -> str:
lines = self._get_lines_str(line_descr="L") lines = self._get_lines_str(line_descr="L")
return f'{markdown_root}{self.source_mapping["filename_relative"]}{lines}' return f'{markdown_root}{self.source_mapping["filename_relative"]}{lines}'
@property @property
def source_mapping_str(self): def source_mapping_str(self) -> str:
lines = self._get_lines_str() lines = self._get_lines_str()
return f'{self.source_mapping["filename_short"]}{lines}' return f'{self.source_mapping["filename_short"]}{lines}'

@ -1,16 +1,20 @@
from .variable import Variable from .variable import Variable
from slither.core.children.child_event import ChildEvent from slither.core.children.child_event import ChildEvent
class EventVariable(ChildEvent, Variable): class EventVariable(ChildEvent, Variable):
def __init__(self): def __init__(self):
super(EventVariable, self).__init__() super(EventVariable, self).__init__()
self._indexed = False self._indexed = False
@property @property
def indexed(self): def indexed(self) -> bool:
""" """
Indicates whether the event variable is indexed in the bloom filter. Indicates whether the event variable is indexed in the bloom filter.
:return: Returns True if the variable is indexed in bloom filter, False otherwise. :return: Returns True if the variable is indexed in bloom filter, False otherwise.
""" """
return self._indexed return self._indexed
@indexed.setter
def indexed(self, is_indexed: bool):
self._indexed = is_indexed

@ -8,5 +8,6 @@
from .variable import Variable from .variable import Variable
class FunctionTypeVariable(Variable): pass
class FunctionTypeVariable(Variable):
pass

@ -1,45 +1,51 @@
from typing import Optional
from .variable import Variable from .variable import Variable
from slither.core.children.child_function import ChildFunction from slither.core.children.child_function import ChildFunction
from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.user_defined_type import UserDefinedType
from slither.core.solidity_types.array_type import ArrayType from slither.core.solidity_types.array_type import ArrayType
from slither.core.solidity_types.mapping_type import MappingType from slither.core.solidity_types.mapping_type import MappingType
from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.declarations.structure import Structure from slither.core.declarations.structure import Structure
class LocalVariable(ChildFunction, Variable): class LocalVariable(ChildFunction, Variable):
def __init__(self): def __init__(self):
super(LocalVariable, self).__init__() super(LocalVariable, self).__init__()
self._location = None self._location: Optional[str] = None
def set_location(self, loc): def set_location(self, loc: str):
self._location = loc self._location = loc
@property @property
def location(self): def location(self) -> Optional[str]:
''' """
Variable Location Variable Location
Can be storage/memory or default Can be storage/memory or default
Returns: Returns:
(str) (str)
''' """
return self._location return self._location
@property @property
def is_storage(self): def is_scalar(self) -> bool:
return isinstance(self.type, ElementaryType) and not self.is_storage
@property
def is_storage(self) -> bool:
""" """
Return true if the variable is located in storage Return true if the variable is located in storage
See https://solidity.readthedocs.io/en/v0.4.24/types.html?highlight=storage%20location#data-location See https://solidity.readthedocs.io/en/v0.4.24/types.html?highlight=storage%20location#data-location
Returns: Returns:
(bool) (bool)
""" """
if self.location == 'memory': if self.location == "memory":
return False return False
# Use by slithIR SSA # Use by slithIR SSA
if self.location == 'reference_to_storage': if self.location == "reference_to_storage":
return False return False
if self.location == 'storage': if self.location == "storage":
return True return True
if isinstance(self.type, (ArrayType, MappingType)): if isinstance(self.type, (ArrayType, MappingType)):
@ -51,7 +57,5 @@ class LocalVariable(ChildFunction, Variable):
return False return False
@property @property
def canonical_name(self): def canonical_name(self) -> str:
return '{}.{}'.format(self.function.canonical_name, self.name) return "{}.{}".format(self.function.canonical_name, self.name)

@ -1,5 +1,8 @@
from typing import Optional
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
class LocalVariableInitFromTuple(LocalVariable): class LocalVariableInitFromTuple(LocalVariable):
""" """
Use on this pattern: Use on this pattern:
@ -12,8 +15,12 @@ class LocalVariableInitFromTuple(LocalVariable):
def __init__(self): def __init__(self):
super(LocalVariableInitFromTuple, self).__init__() super(LocalVariableInitFromTuple, self).__init__()
self._tuple_index = None self._tuple_index: Optional[int] = None
@property @property
def tuple_index(self): def tuple_index(self) -> Optional[int]:
return self._tuple_index return self._tuple_index
@tuple_index.setter
def tuple_index(self, idx: int):
self._tuple_index = idx

@ -1,14 +1,20 @@
from typing import Optional, TYPE_CHECKING, Tuple, List
from .variable import Variable from .variable import Variable
from slither.core.children.child_contract import ChildContract from slither.core.children.child_contract import ChildContract
from slither.utils.type import export_nested_types_from_variable from slither.utils.type import export_nested_types_from_variable
class StateVariable(ChildContract, Variable): if TYPE_CHECKING:
from ..cfg.node import Node
from ..declarations import Contract
class StateVariable(ChildContract, Variable):
def __init__(self): def __init__(self):
super(StateVariable, self).__init__() super(StateVariable, self).__init__()
self._node_initialization = None self._node_initialization: Optional["Node"] = None
def is_declared_by(self, contract): def is_declared_by(self, contract: "Contract") -> bool:
""" """
Check if the element is declared by the contract Check if the element is declared by the contract
:param contract: :param contract:
@ -16,7 +22,6 @@ class StateVariable(ChildContract, Variable):
""" """
return self.contract == contract return self.contract == contract
################################################################################### ###################################################################################
################################################################################### ###################################################################################
# region Signature # region Signature
@ -24,21 +29,21 @@ class StateVariable(ChildContract, Variable):
################################################################################### ###################################################################################
@property @property
def signature(self): def signature(self) -> Tuple[str, List[str], str]:
""" """
Return the signature of the state variable as a function signature Return the signature of the state variable as a function signature
:return: (str, list(str), list(str)), as (name, list parameters type, list return values type) :return: (str, list(str), list(str)), as (name, list parameters type, list return values type)
""" """
return self.name, [str(x) for x in export_nested_types_from_variable(self)], self.type return self.name, [str(x) for x in export_nested_types_from_variable(self)], str(self.type)
@property @property
def signature_str(self): def signature_str(self) -> str:
""" """
Return the signature of the state variable as a function signature Return the signature of the state variable as a function signature
:return: str: func_name(type1,type2) returns(type3) :return: str: func_name(type1,type2) returns(type3)
""" """
name, parameters, returnVars = self.signature name, parameters, returnVars = self.signature
return name+'('+','.join(parameters)+') returns('+','.join(returnVars)+')' return name + "(" + ",".join(parameters) + ") returns(" + ",".join(returnVars) + ")"
# endregion # endregion
################################################################################### ###################################################################################
@ -48,18 +53,18 @@ class StateVariable(ChildContract, Variable):
################################################################################### ###################################################################################
@property @property
def canonical_name(self): def canonical_name(self) -> str:
return '{}.{}'.format(self.contract.name, self.name) return "{}.{}".format(self.contract.name, self.name)
@property @property
def full_name(self): def full_name(self) -> str:
""" """
Return the name of the state variable as a function signaure Return the name of the state variable as a function signaure
str: func_name(type1,type2) str: func_name(type1,type2)
:return: the function signature without the return values :return: the function signature without the return values
""" """
name, parameters, _ = self.signature name, parameters, _ = self.signature
return name+'('+','.join(parameters)+')' return name + "(" + ",".join(parameters) + ")"
# endregion # endregion
################################################################################### ###################################################################################
@ -69,7 +74,7 @@ class StateVariable(ChildContract, Variable):
################################################################################### ###################################################################################
@property @property
def node_initialization(self): def node_initialization(self) -> Optional["Node"]:
""" """
Node for the state variable initalization Node for the state variable initalization
:return: :return:
@ -80,8 +85,6 @@ class StateVariable(ChildContract, Variable):
def node_initialization(self, node_initialization): def node_initialization(self, node_initialization):
self._node_initialization = node_initialization self._node_initialization = node_initialization
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################

@ -1,5 +1,6 @@
from .variable import Variable from .variable import Variable
from slither.core.children.child_structure import ChildStructure from slither.core.children.child_structure import ChildStructure
class StructureVariable(ChildStructure, Variable): pass
class StructureVariable(ChildStructure, Variable):
pass

@ -1,24 +1,32 @@
""" """
Variable module Variable module
""" """
from typing import Optional, TYPE_CHECKING, List, Union
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.elementary_type import ElementaryType
class Variable(SourceMapping): if TYPE_CHECKING:
from slither.core.expressions.expression import Expression
class Variable(SourceMapping):
def __init__(self): def __init__(self):
super(Variable, self).__init__() super(Variable, self).__init__()
self._name = None self._name: Optional[str] = None
self._initial_expression = None self._initial_expression: Optional["Expression"] = None
self._type = None self._type: Optional[Type] = None
self._initialized = None self._initialized: Optional[bool] = None
self._visibility = None self._visibility: Optional[str] = None
self._is_constant = False self._is_constant = False
@property @property
def expression(self): def is_scalar(self) -> bool:
return isinstance(self.type, ElementaryType)
@property
def expression(self) -> Optional["Expression"]:
""" """
Expression: Expression of the node (if initialized) Expression: Expression of the node (if initialized)
Initial expression may be different than the expression of the node Initial expression may be different than the expression of the node
@ -32,25 +40,33 @@ class Variable(SourceMapping):
""" """
return self._initial_expression return self._initial_expression
@expression.setter
def expression(self, expr: "Expression"):
self._initial_expression = expr
@property @property
def initialized(self): def initialized(self) -> Optional[bool]:
""" """
boolean: True if the variable is initialized at construction boolean: True if the variable is initialized at construction
""" """
return self._initialized return self._initialized
@initialized.setter
def initialized(self, is_init: bool):
self._initialized = is_init
@property @property
def uninitialized(self): def uninitialized(self) -> bool:
""" """
boolean: True if the variable is not initialized boolean: True if the variable is not initialized
""" """
return not self._initialized return not self._initialized
@property @property
def name(self): def name(self) -> str:
''' """
str: variable name str: variable name
''' """
return self._name return self._name
@name.setter @name.setter
@ -58,20 +74,32 @@ class Variable(SourceMapping):
self._name = name self._name = name
@property @property
def type(self): def type(self) -> Optional[Union[Type, List[Type]]]:
return self._type return self._type
@type.setter
def type(self, types: Union[Type, List[Type]]):
self._type = types
@property @property
def is_constant(self): def is_constant(self) -> bool:
return self._is_constant return self._is_constant
@is_constant.setter
def is_constant(self, is_cst: bool):
self._is_constant = is_cst
@property @property
def visibility(self): def visibility(self) -> Optional[str]:
''' """
str: variable visibility str: variable visibility
''' """
return self._visibility return self._visibility
@visibility.setter
def visibility(self, v: str):
self._visibility = v
def set_type(self, t): def set_type(self, t):
if isinstance(t, str): if isinstance(t, str):
t = ElementaryType(t) t = ElementaryType(t)
@ -80,25 +108,21 @@ class Variable(SourceMapping):
@property @property
def function_name(self): def function_name(self):
''' """
Return the name of the variable as a function signature Return the name of the variable as a function signature
:return: :return:
''' """
from slither.core.solidity_types import ArrayType, MappingType from slither.core.solidity_types import ArrayType, MappingType
from slither.utils.type import export_nested_types_from_variable
variable_getter_args = "" variable_getter_args = ""
if type(self.type) is ArrayType: return_type = self.type
length = 0 assert return_type
v = self
while type(v.type) is ArrayType: if isinstance(return_type, (ArrayType, MappingType)):
length += 1 variable_getter_args = ",".join(map(str, export_nested_types_from_variable(self)))
v = v.type
variable_getter_args = ','.join(["uint256"] * length)
elif type(self.type) is MappingType:
variable_getter_args = self.type.type_from
return f"{self.name}({variable_getter_args})" return f"{self.name}({variable_getter_args})"
def __str__(self): def __str__(self):
return self._name return self._name

@ -10,7 +10,7 @@ from slither.core.cfg.node import Node
from slither.core.declarations import Function from slither.core.declarations import Function
from slither.core.declarations.solidity_variables import SolidityVariableComposed, SolidityFunction, SolidityVariable from slither.core.declarations.solidity_variables import SolidityVariableComposed, SolidityFunction, SolidityVariable
from slither.core.expressions import NewContract from slither.core.expressions import NewContract
from slither.core.slither_core import Slither from slither.core.slither_core import SlitherCore
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.printers.abstract_printer import AbstractPrinter from slither.printers.abstract_printer import AbstractPrinter
@ -26,7 +26,7 @@ def _get_name(f: Function) -> str:
return f.solidity_signature return f.solidity_signature
def _extract_payable(slither: Slither) -> Dict[str, List[str]]: def _extract_payable(slither: SlitherCore) -> Dict[str, List[str]]:
ret: Dict[str, List[str]] = {} ret: Dict[str, List[str]] = {}
for contract in slither.contracts: for contract in slither.contracts:
payable_functions = [_get_name(f) for f in contract.functions_entry_points if f.payable] payable_functions = [_get_name(f) for f in contract.functions_entry_points if f.payable]
@ -35,7 +35,7 @@ def _extract_payable(slither: Slither) -> Dict[str, List[str]]:
return ret return ret
def _extract_solidity_variable_usage(slither: Slither, sol_var: SolidityVariable) -> Dict[str, List[str]]: def _extract_solidity_variable_usage(slither: SlitherCore, sol_var: SolidityVariable) -> Dict[str, List[str]]:
ret: Dict[str, List[str]] = {} ret: Dict[str, List[str]] = {}
for contract in slither.contracts: for contract in slither.contracts:
functions_using_sol_var = [] functions_using_sol_var = []
@ -53,7 +53,7 @@ def _is_constant(f: Function) -> bool:
""" """
Heuristic: Heuristic:
- If view/pure with Solidity >= 0.4 -> Return true - If view/pure with Solidity >= 0.4 -> Return true
- If it contains assembly -> Return false (Slither doesn't analyze asm) - If it contains assembly -> Return false (SlitherCore doesn't analyze asm)
- Otherwise check for the rules from - Otherwise check for the rules from
https://solidity.readthedocs.io/en/v0.5.0/contracts.html?highlight=pure#view-functions https://solidity.readthedocs.io/en/v0.5.0/contracts.html?highlight=pure#view-functions
with an exception: internal dynamic call are not correctly handled, so we consider them as non-constant with an exception: internal dynamic call are not correctly handled, so we consider them as non-constant
@ -93,7 +93,7 @@ def _is_constant(f: Function) -> bool:
return True return True
def _extract_constant_functions(slither: Slither) -> Dict[str, List[str]]: def _extract_constant_functions(slither: SlitherCore) -> Dict[str, List[str]]:
ret: Dict[str, List[str]] = {} ret: Dict[str, List[str]] = {}
for contract in slither.contracts: for contract in slither.contracts:
cst_functions = [_get_name(f) for f in contract.functions_entry_points if _is_constant(f)] cst_functions = [_get_name(f) for f in contract.functions_entry_points if _is_constant(f)]
@ -103,7 +103,7 @@ def _extract_constant_functions(slither: Slither) -> Dict[str, List[str]]:
return ret return ret
def _extract_assert(slither: Slither) -> Dict[str, List[str]]: def _extract_assert(slither: SlitherCore) -> Dict[str, List[str]]:
ret: Dict[str, List[str]] = {} ret: Dict[str, List[str]] = {}
for contract in slither.contracts: for contract in slither.contracts:
functions_using_assert = [] functions_using_assert = []
@ -145,7 +145,7 @@ def _extract_constants_from_irs(irs: List[Operation],
if isinstance(ir, Binary): if isinstance(ir, Binary):
for r in ir.read: for r in ir.read:
if isinstance(r, Constant): if isinstance(r, Constant):
all_cst_used_in_binary[BinaryType.str(ir.type)].append(ConstantValue(str(r.value), str(r.type))) all_cst_used_in_binary[str(ir.type)].append(ConstantValue(str(r.value), str(r.type)))
if isinstance(ir, TypeConversion): if isinstance(ir, TypeConversion):
if isinstance(ir.variable, Constant): if isinstance(ir.variable, Constant):
all_cst_used.append(ConstantValue(str(ir.variable.value), str(ir.type))) all_cst_used.append(ConstantValue(str(ir.variable.value), str(ir.type)))
@ -169,7 +169,7 @@ def _extract_constants_from_irs(irs: List[Operation],
context_explored) context_explored)
def _extract_constants(slither: Slither) -> Tuple[Dict[str, Dict[str, List]], Dict[str, Dict[str, Dict]]]: def _extract_constants(slither: SlitherCore) -> Tuple[Dict[str, Dict[str, List]], Dict[str, Dict[str, Dict]]]:
# contract -> function -> [ {"value": value, "type": type} ] # contract -> function -> [ {"value": value, "type": type} ]
ret_cst_used: Dict[str, Dict[str, List[ConstantValue]]] = defaultdict(dict) ret_cst_used: Dict[str, Dict[str, List[ConstantValue]]] = defaultdict(dict)
# contract -> function -> binary_operand -> [ {"value": value, "type": type ] # contract -> function -> binary_operand -> [ {"value": value, "type": type ]
@ -196,7 +196,7 @@ def _extract_constants(slither: Slither) -> Tuple[Dict[str, Dict[str, List]], Di
return ret_cst_used, ret_cst_used_in_binary return ret_cst_used, ret_cst_used_in_binary
def _extract_function_relations(slither: Slither) -> Dict[str, Dict[str, Dict[str, List[str]]]]: def _extract_function_relations(slither: SlitherCore) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
# contract -> function -> [functions] # contract -> function -> [functions]
ret: Dict[str, Dict[str, Dict[str, List[str]]]] = defaultdict(dict) ret: Dict[str, Dict[str, Dict[str, List[str]]]] = defaultdict(dict)
for contract in slither.contracts: for contract in slither.contracts:
@ -217,7 +217,7 @@ def _extract_function_relations(slither: Slither) -> Dict[str, Dict[str, Dict[st
return ret return ret
def _have_external_calls(slither: Slither) -> Dict[str, List[str]]: def _have_external_calls(slither: SlitherCore) -> Dict[str, List[str]]:
""" """
Detect the functions with external calls Detect the functions with external calls
:param slither: :param slither:
@ -233,7 +233,7 @@ def _have_external_calls(slither: Slither) -> Dict[str, List[str]]:
return ret return ret
def _use_balance(slither: Slither) -> Dict[str, List[str]]: def _use_balance(slither: SlitherCore) -> Dict[str, List[str]]:
""" """
Detect the functions with external calls Detect the functions with external calls
:param slither: :param slither:
@ -250,7 +250,7 @@ def _use_balance(slither: Slither) -> Dict[str, List[str]]:
return ret return ret
def _call_a_parameter(slither: Slither) -> Dict[str, List[Dict]]: def _call_a_parameter(slither: SlitherCore) -> Dict[str, List[Dict]]:
""" """
Detect the functions with external calls Detect the functions with external calls
:param slither: :param slither:

@ -1,18 +1,13 @@
import logging import logging
import os import os
import subprocess
import sys
import glob
import json
import platform
from crytic_compile import CryticCompile, InvalidCompilation from crytic_compile import CryticCompile, InvalidCompilation
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.printers.abstract_printer import AbstractPrinter from slither.printers.abstract_printer import AbstractPrinter
from .solc_parsing.exceptions import VariableNotFound from .core.slither_core import SlitherCore
from .solc_parsing.slitherSolc import SlitherSolc
from .exceptions import SlitherError from .exceptions import SlitherError
from .solc_parsing.slitherSolc import SlitherSolc
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
logging.basicConfig() logging.basicConfig()
@ -21,10 +16,9 @@ logger_detector = logging.getLogger("Detectors")
logger_printer = logging.getLogger("Printers") logger_printer = logging.getLogger("Printers")
class Slither(SlitherSolc): class Slither(SlitherCore):
def __init__(self, target, **kwargs): def __init__(self, target, **kwargs):
''' """
Args: Args:
target (str | list(json) | CryticCompile) target (str | list(json) | CryticCompile)
Keyword Args: Keyword Args:
@ -46,14 +40,16 @@ class Slither(SlitherSolc):
embark_ignore_compile (bool): do not run embark build (default False) embark_ignore_compile (bool): do not run embark build (default False)
embark_overwrite_config (bool): overwrite original config file (default false) embark_overwrite_config (bool): overwrite original config file (default false)
''' """
super().__init__()
self._parser: SlitherSolc # This could be another parser, like SlitherVyper, interface needs to be determined
# list of files provided (see --splitted option) # list of files provided (see --splitted option)
if isinstance(target, list): if isinstance(target, list):
self._init_from_list(target) self._init_from_list(target)
elif isinstance(target, str) and target.endswith('.json'): elif isinstance(target, str) and target.endswith(".json"):
self._init_from_raw_json(target) self._init_from_raw_json(target)
else: else:
super(Slither, self).__init__('') self._parser = SlitherSolc("", self)
try: try:
if isinstance(target, CryticCompile): if isinstance(target, CryticCompile):
crytic_compile = target crytic_compile = target
@ -61,54 +57,55 @@ class Slither(SlitherSolc):
crytic_compile = CryticCompile(target, **kwargs) crytic_compile = CryticCompile(target, **kwargs)
self._crytic_compile = crytic_compile self._crytic_compile = crytic_compile
except InvalidCompilation as e: except InvalidCompilation as e:
raise SlitherError('Invalid compilation: \n'+str(e)) raise SlitherError("Invalid compilation: \n" + str(e))
for path, ast in crytic_compile.asts.items(): for path, ast in crytic_compile.asts.items():
self._parse_contracts_from_loaded_json(ast, path) self._parser.parse_contracts_from_loaded_json(ast, path)
self._add_source_code(path) self.add_source_code(path)
if kwargs.get('generate_patches', False): if kwargs.get("generate_patches", False):
self.generate_patches = True self.generate_patches = True
self._markdown_root = kwargs.get('markdown_root', "") self._markdown_root = kwargs.get("markdown_root", "")
self._detectors = [] self._detectors = []
self._printers = [] self._printers = []
filter_paths = kwargs.get('filter_paths', []) filter_paths = kwargs.get("filter_paths", [])
for p in filter_paths: for p in filter_paths:
self.add_path_to_filter(p) self.add_path_to_filter(p)
self._exclude_dependencies = kwargs.get('exclude_dependencies', False) self._exclude_dependencies = kwargs.get("exclude_dependencies", False)
triage_mode = kwargs.get('triage_mode', False) triage_mode = kwargs.get("triage_mode", False)
self._triage_mode = triage_mode self._triage_mode = triage_mode
self._analyze_contracts() self._parser.analyze_contracts()
def _init_from_raw_json(self, filename): def _init_from_raw_json(self, filename):
if not os.path.isfile(filename): if not os.path.isfile(filename):
raise SlitherError('{} does not exist (are you in the correct directory?)'.format(filename)) raise SlitherError(
assert filename.endswith('json') "{} does not exist (are you in the correct directory?)".format(filename)
with open(filename, encoding='utf8') as astFile: )
assert filename.endswith("json")
with open(filename, encoding="utf8") as astFile:
stdout = astFile.read() stdout = astFile.read()
if not stdout: if not stdout:
raise SlitherError('Empty AST file: %s', filename) raise SlitherError("Empty AST file: %s", filename)
contracts_json = stdout.split('\n=') contracts_json = stdout.split("\n=")
super(Slither, self).__init__(filename) self._parser = SlitherSolc(filename, self)
for c in contracts_json: for c in contracts_json:
self._parse_contracts_from_json(c) self._parser.parse_contracts_from_json(c)
def _init_from_list(self, contract): def _init_from_list(self, contract):
super(Slither, self).__init__('') self._parser = SlitherSolc("", self)
for c in contract: for c in contract:
if 'absolutePath' in c: if "absolutePath" in c:
path = c['absolutePath'] path = c["absolutePath"]
else: else:
path = c['attributes']['absolutePath'] path = c["attributes"]["absolutePath"]
self._parse_contracts_from_loaded_json(c, path) self._parser.parse_contracts_from_loaded_json(c, path)
@property @property
def detectors(self): def detectors(self):
@ -138,7 +135,7 @@ class Slither(SlitherSolc):
""" """
:param detector_class: Class inheriting from `AbstractDetector`. :param detector_class: Class inheriting from `AbstractDetector`.
""" """
self._check_common_things('detector', detector_class, AbstractDetector, self._detectors) self._check_common_things("detector", detector_class, AbstractDetector, self._detectors)
instance = detector_class(self, logger_detector) instance = detector_class(self, logger_detector)
self._detectors.append(instance) self._detectors.append(instance)
@ -147,7 +144,7 @@ class Slither(SlitherSolc):
""" """
:param printer_class: Class inheriting from `AbstractPrinter`. :param printer_class: Class inheriting from `AbstractPrinter`.
""" """
self._check_common_things('printer', printer_class, AbstractPrinter, self._printers) self._check_common_things("printer", printer_class, AbstractPrinter, self._printers)
instance = printer_class(self, logger_printer) instance = printer_class(self, logger_printer)
self._printers.append(instance) self._printers.append(instance)
@ -179,19 +176,19 @@ class Slither(SlitherSolc):
) )
if any(type(obj) == cls for obj in instances_list): if any(type(obj) == cls for obj in instances_list):
raise Exception( raise Exception("You can't register {!r} twice.".format(cls))
"You can't register {!r} twice.".format(cls)
)
def _run_solc(self, filename, solc, disable_solc_warnings, solc_arguments, ast_format): def _run_solc(self, filename, solc, disable_solc_warnings, solc_arguments, ast_format):
if not os.path.isfile(filename): if not os.path.isfile(filename):
raise SlitherError('{} does not exist (are you in the correct directory?)'.format(filename)) raise SlitherError(
assert filename.endswith('json') "{} does not exist (are you in the correct directory?)".format(filename)
with open(filename, encoding='utf8') as astFile: )
assert filename.endswith("json")
with open(filename, encoding="utf8") as astFile:
stdout = astFile.read() stdout = astFile.read()
if not stdout: if not stdout:
raise SlitherError('Empty AST file: %s', filename) raise SlitherError("Empty AST file: %s", filename)
stdout = stdout.split('\n=') stdout = stdout.split("\n=")
return stdout return stdout

@ -1,4 +1,5 @@
import logging import logging
from enum import Enum
from slither.core.solidity_types import ElementaryType from slither.core.solidity_types import ElementaryType
from slither.slithir.exceptions import SlithIRError from slither.slithir.exceptions import SlithIRError
@ -8,7 +9,8 @@ from slither.slithir.variables import ReferenceVariable
logger = logging.getLogger("BinaryOperationIR") logger = logging.getLogger("BinaryOperationIR")
class BinaryType(object):
class BinaryType(Enum):
POWER = 0 # ** POWER = 0 # **
MULTIPLICATION = 1 # * MULTIPLICATION = 1 # *
DIVISION = 2 # / DIVISION = 2 # /
@ -83,47 +85,47 @@ class BinaryType(object):
raise SlithIRError('get_type: Unknown operation type {})'.format(operation_type)) raise SlithIRError('get_type: Unknown operation type {})'.format(operation_type))
@staticmethod def __str__(self):
def str(operation_type): if self == BinaryType.POWER:
if operation_type == BinaryType.POWER: return "**"
return '**' if self == BinaryType.MULTIPLICATION:
if operation_type == BinaryType.MULTIPLICATION: return "*"
return '*' if self == BinaryType.DIVISION:
if operation_type == BinaryType.DIVISION: return "/"
return '/' if self == BinaryType.MODULO:
if operation_type == BinaryType.MODULO: return "%"
return '%' if self == BinaryType.ADDITION:
if operation_type == BinaryType.ADDITION: return "+"
return '+' if self == BinaryType.SUBTRACTION:
if operation_type == BinaryType.SUBTRACTION: return "-"
return '-' if self == BinaryType.LEFT_SHIFT:
if operation_type == BinaryType.LEFT_SHIFT: return "<<"
return '<<' if self == BinaryType.RIGHT_SHIFT:
if operation_type == BinaryType.RIGHT_SHIFT: return ">>"
return '>>' if self == BinaryType.AND:
if operation_type == BinaryType.AND: return "&"
return '&' if self == BinaryType.CARET:
if operation_type == BinaryType.CARET: return "^"
return '^' if self == BinaryType.OR:
if operation_type == BinaryType.OR: return "|"
return '|' if self == BinaryType.LESS:
if operation_type == BinaryType.LESS: return "<"
return '<' if self == BinaryType.GREATER:
if operation_type == BinaryType.GREATER: return ">"
return '>' if self == BinaryType.LESS_EQUAL:
if operation_type == BinaryType.LESS_EQUAL: return "<="
return '<=' if self == BinaryType.GREATER_EQUAL:
if operation_type == BinaryType.GREATER_EQUAL: return ">="
return '>=' if self == BinaryType.EQUAL:
if operation_type == BinaryType.EQUAL: return "=="
return '==' if self == BinaryType.NOT_EQUAL:
if operation_type == BinaryType.NOT_EQUAL: return "!="
return '!=' if self == BinaryType.ANDAND:
if operation_type == BinaryType.ANDAND: return "&&"
return '&&' if self == BinaryType.OROR:
if operation_type == BinaryType.OROR: return "||"
return '||' raise SlithIRError("str: Unknown operation type {} {})".format(self, type(self)))
raise SlithIRError('str: Unknown operation type {})'.format(operation_type))
class Binary(OperationWithLValue): class Binary(OperationWithLValue):
@ -131,6 +133,7 @@ class Binary(OperationWithLValue):
assert is_valid_rvalue(left_variable) assert is_valid_rvalue(left_variable)
assert is_valid_rvalue(right_variable) assert is_valid_rvalue(right_variable)
assert is_valid_lvalue(result) assert is_valid_lvalue(result)
assert isinstance(operation_type, BinaryType)
super(Binary, self).__init__() super(Binary, self).__init__()
self._variables = [left_variable, right_variable] self._variables = [left_variable, right_variable]
self._type = operation_type self._type = operation_type
@ -162,7 +165,7 @@ class Binary(OperationWithLValue):
@property @property
def type_str(self): def type_str(self):
return BinaryType.str(self._type) return str(self._type)
def __str__(self): def __str__(self):
if isinstance(self.lvalue, ReferenceVariable): if isinstance(self.lvalue, ReferenceVariable):

@ -1,4 +1,6 @@
import logging import logging
from enum import Enum
from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.operations.lvalue import OperationWithLValue
from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue
from slither.slithir.exceptions import SlithIRError from slither.slithir.exceptions import SlithIRError
@ -6,7 +8,7 @@ from slither.slithir.exceptions import SlithIRError
logger = logging.getLogger("BinaryOperationIR") logger = logging.getLogger("BinaryOperationIR")
class UnaryType: class UnaryType(Enum):
BANG = 0 # ! BANG = 0 # !
TILD = 1 # ~ TILD = 1 # ~
@ -19,14 +21,13 @@ class UnaryType:
return UnaryType.TILD return UnaryType.TILD
raise SlithIRError('get_type: Unknown operation type {}'.format(operation_type)) raise SlithIRError('get_type: Unknown operation type {}'.format(operation_type))
@staticmethod def __str__(self):
def str(operation_type): if self == UnaryType.BANG:
if operation_type == UnaryType.BANG: return "!"
return '!' if self == UnaryType.TILD:
if operation_type == UnaryType.TILD: return "~"
return '~'
raise SlithIRError('str: Unknown operation type {}'.format(operation_type)) raise SlithIRError("str: Unknown operation type {}".format(self))
class Unary(OperationWithLValue): class Unary(OperationWithLValue):
@ -53,7 +54,7 @@ class Unary(OperationWithLValue):
@property @property
def type_str(self): def type_str(self):
return UnaryType.str(self._type) return str(self._type)
def __str__(self): def __str__(self):
return "{} = {} {} ".format(self.lvalue, self.type_str, self.rvalue) return "{} = {} {} ".format(self.lvalue, self.type_str, self.rvalue)

@ -1,68 +1,64 @@
from typing import Optional, Dict
from slither.core.cfg.node import Node from slither.core.cfg.node import Node
from slither.core.cfg.node import NodeType from slither.core.cfg.node import NodeType
from slither.core.expressions.assignment_operation import (
AssignmentOperation,
AssignmentOperationType,
)
from slither.core.expressions.identifier import Identifier
from slither.solc_parsing.expressions.expression_parsing import parse_expression from slither.solc_parsing.expressions.expression_parsing import parse_expression
from slither.visitors.expression.find_calls import FindCalls
from slither.visitors.expression.read_var import ReadVar from slither.visitors.expression.read_var import ReadVar
from slither.visitors.expression.write_var import WriteVar from slither.visitors.expression.write_var import WriteVar
from slither.visitors.expression.find_calls import FindCalls
from slither.visitors.expression.export_values import ExportValues
from slither.core.declarations.solidity_variables import SolidityVariable, SolidityFunction
from slither.core.declarations.function import Function
from slither.core.variables.state_variable import StateVariable
from slither.core.expressions.identifier import Identifier class NodeSolc:
from slither.core.expressions.assignment_operation import AssignmentOperation, AssignmentOperationType def __init__(self, node: Node):
self._unparsed_expression: Optional[Dict] = None
class NodeSolc(Node): self._node = node
def __init__(self, nodeType, nodeId): @property
super(NodeSolc, self).__init__(nodeType, nodeId) def underlying_node(self) -> Node:
self._unparsed_expression = None return self._node
def add_unparsed_expression(self, expression): def add_unparsed_expression(self, expression: Dict):
assert self._unparsed_expression is None assert self._unparsed_expression is None
self._unparsed_expression = expression self._unparsed_expression = expression
def analyze_expressions(self, caller_context): def analyze_expressions(self, caller_context):
if self.type == NodeType.VARIABLE and not self._expression: if self._node.type == NodeType.VARIABLE and not self._node.expression:
self._expression = self.variable_declaration.expression self._node.add_expression(self._node.variable_declaration.expression)
if self._unparsed_expression: if self._unparsed_expression:
expression = parse_expression(self._unparsed_expression, caller_context) expression = parse_expression(self._unparsed_expression, caller_context)
self._expression = expression self._node.add_expression(expression)
self._unparsed_expression = None # self._unparsed_expression = None
if self.expression: if self._node.expression:
if self.type == NodeType.VARIABLE: if self._node.type == NodeType.VARIABLE:
# Update the expression to be an assignement to the variable # Update the expression to be an assignement to the variable
#print(self.variable_declaration) _expression = AssignmentOperation(
_expression = AssignmentOperation(Identifier(self.variable_declaration), Identifier(self._node.variable_declaration),
self.expression, self._node.expression,
AssignmentOperationType.ASSIGN, AssignmentOperationType.ASSIGN,
self.variable_declaration.type) self._node.variable_declaration.type,
_expression.set_offset(self.expression.source_mapping, self.slither) )
self._expression = _expression _expression.set_offset(self._node.expression.source_mapping, self._node.slither)
self._node.add_expression(_expression, bypass_verif_empty=True)
expression = self.expression
pp = ReadVar(expression)
self._expression_vars_read = pp.result()
# self._vars_read = [item for sublist in vars_read for item in sublist]
# self._state_vars_read = [x for x in self.variables_read if\
# isinstance(x, (StateVariable))]
# self._solidity_vars_read = [x for x in self.variables_read if\
# isinstance(x, (SolidityVariable))]
pp = WriteVar(expression)
self._expression_vars_written = pp.result()
# self._vars_written = [item for sublist in vars_written for item in sublist] expression = self._node.expression
# self._state_vars_written = [x for x in self.variables_written if\ read_var = ReadVar(expression)
# isinstance(x, StateVariable)] self._node.variables_read_as_expression = read_var.result()
pp = FindCalls(expression) write_var = WriteVar(expression)
self._expression_calls = pp.result() self._node.variables_written_as_expression = write_var.result()
self._external_calls_as_expressions = [c for c in self.calls_as_expression if not isinstance(c.called, Identifier)]
self._internal_calls_as_expressions = [c for c in self.calls_as_expression if isinstance(c.called, Identifier)]
find_call = FindCalls(expression)
self._node.calls_as_expression = find_call.result()
self._node.external_calls_as_expressions = [
c for c in self._node.calls_as_expression if not isinstance(c.called, Identifier)
]
self._node.internal_calls_as_expressions = [
c for c in self._node.calls_as_expression if isinstance(c.called, Identifier)
]

@ -1,55 +1,67 @@
import logging import logging
from typing import List, Dict, Callable, TYPE_CHECKING, Union
from slither.core.declarations import Modifier, Structure, Event
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.function import Function, FunctionType
from slither.core.declarations.enum import Enum from slither.core.declarations.enum import Enum
from slither.core.cfg.node import Node, NodeType from slither.core.declarations.function import Function
from slither.core.expressions import AssignmentOperation, Identifier, AssignmentOperationType from slither.core.variables.state_variable import StateVariable
from slither.slithir.variables import StateIRVariable
from slither.solc_parsing.declarations.event import EventSolc from slither.solc_parsing.declarations.event import EventSolc
from slither.solc_parsing.declarations.function import FunctionSolc from slither.solc_parsing.declarations.function import FunctionSolc
from slither.solc_parsing.declarations.modifier import ModifierSolc from slither.solc_parsing.declarations.modifier import ModifierSolc
from slither.solc_parsing.declarations.structure import StructureSolc from slither.solc_parsing.declarations.structure import StructureSolc
from slither.solc_parsing.exceptions import ParsingError, VariableNotFound
from slither.solc_parsing.solidity_types.type_parsing import parse_type from slither.solc_parsing.solidity_types.type_parsing import parse_type
from slither.solc_parsing.variables.state_variable import StateVariableSolc from slither.solc_parsing.variables.state_variable import StateVariableSolc
from slither.solc_parsing.exceptions import ParsingError, VariableNotFound
logger = logging.getLogger("ContractSolcParsing") LOGGER = logging.getLogger("ContractSolcParsing")
if TYPE_CHECKING:
from slither.solc_parsing.slitherSolc import SlitherSolc
from slither.core.slither_core import SlitherCore
class ContractSolc04(Contract):
def __init__(self, slitherSolc, data): class ContractSolc:
def __init__(self, slither_parser: "SlitherSolc", contract: Contract, data):
# assert slitherSolc.solc_version.startswith('0.4') # assert slitherSolc.solc_version.startswith('0.4')
super(ContractSolc04, self).__init__()
self.set_slither(slitherSolc) self._contract = contract
self._contract.set_slither(slither_parser.core)
self._slither_parser = slither_parser
self._data = data self._data = data
self._functionsNotParsed = [] self._functionsNotParsed: List[Dict] = []
self._modifiersNotParsed = [] self._modifiersNotParsed: List[Dict] = []
self._functions_no_params = [] self._functions_no_params: List[FunctionSolc] = []
self._modifiers_no_params = [] self._modifiers_no_params: List[ModifierSolc] = []
self._eventsNotParsed = [] self._eventsNotParsed: List[EventSolc] = []
self._variablesNotParsed = [] self._variablesNotParsed: List[Dict] = []
self._enumsNotParsed = [] self._enumsNotParsed: List[Dict] = []
self._structuresNotParsed = [] self._structuresNotParsed: List[Dict] = []
self._usingForNotParsed = [] self._usingForNotParsed: List[Dict] = []
self._functions_parser: List[FunctionSolc] = []
self._modifiers_parser: List[ModifierSolc] = []
self._structures_parser: List[StructureSolc] = []
self._is_analyzed = False self._is_analyzed: bool = False
# use to remap inheritance id # use to remap inheritance id
self._remapping = {} self._remapping: Dict[str, str] = {}
self.baseContracts = []
self.baseConstructorContractsCalled = []
self._linearized_base_contracts: List[int]
self._variables_parser: List[StateVariableSolc] = []
# Export info # Export info
if self.is_compact_ast: if self.is_compact_ast:
self._name = self._data['name'] self._contract.name = self._data["name"]
else: else:
self._name = self._data['attributes'][self.get_key()] self._contract.name = self._data["attributes"][self.get_key()]
self._id = self._data['id'] self._contract.id = self._data["id"]
self._inheritance = []
self._parse_contract_info() self._parse_contract_info()
self._parse_contract_items() self._parse_contract_items()
@ -61,33 +73,57 @@ class ContractSolc04(Contract):
################################################################################### ###################################################################################
@property @property
def is_analyzed(self): def is_analyzed(self) -> bool:
return self._is_analyzed return self._is_analyzed
def set_is_analyzed(self, is_analyzed): def set_is_analyzed(self, is_analyzed: bool):
self._is_analyzed = is_analyzed self._is_analyzed = is_analyzed
@property
def underlying_contract(self) -> Contract:
return self._contract
@property
def linearized_base_contracts(self) -> List[int]:
return self._linearized_base_contracts
@property
def slither(self) -> "SlitherCore":
return self._contract.slither
@property
def slither_parser(self) -> "SlitherSolc":
return self._slither_parser
@property
def functions_parser(self) -> List["FunctionSolc"]:
return self._functions_parser
@property
def modifiers_parser(self) -> List["ModifierSolc"]:
return self._modifiers_parser
################################################################################### ###################################################################################
################################################################################### ###################################################################################
# region AST # region AST
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def get_key(self): def get_key(self) -> str:
return self.slither.get_key() return self._slither_parser.get_key()
def get_children(self, key='nodes'): def get_children(self, key="nodes") -> str:
if self.is_compact_ast: if self.is_compact_ast:
return key return key
return 'children' return "children"
@property @property
def remapping(self): def remapping(self) -> Dict[str, str]:
return self._remapping return self._remapping
@property @property
def is_compact_ast(self): def is_compact_ast(self) -> bool:
return self.slither.is_compact_ast return self._slither_parser.is_compact_ast
# endregion # endregion
################################################################################### ###################################################################################
@ -100,162 +136,190 @@ class ContractSolc04(Contract):
if self.is_compact_ast: if self.is_compact_ast:
attributes = self._data attributes = self._data
else: else:
attributes = self._data['attributes'] attributes = self._data["attributes"]
self.isInterface = False self._contract.is_interface = False
if 'contractKind' in attributes: if "contractKind" in attributes:
if attributes['contractKind'] == 'interface': if attributes["contractKind"] == "interface":
self.isInterface = True self._contract.is_interface = True
self._kind = attributes['contractKind'] self._contract.kind = attributes["contractKind"]
self.linearizedBaseContracts = attributes['linearizedBaseContracts'] self._linearized_base_contracts = attributes["linearizedBaseContracts"]
self.fullyImplemented = attributes['fullyImplemented'] # self._contract.fullyImplemented = attributes["fullyImplemented"]
# Parse base contract information # Parse base contract information
self._parse_base_contract_info() self._parse_base_contract_info()
# trufle does some re-mapping of id # trufle does some re-mapping of id
if 'baseContracts' in self._data: if "baseContracts" in self._data:
for elem in self._data['baseContracts']: for elem in self._data["baseContracts"]:
if elem['nodeType'] == 'InheritanceSpecifier': if elem["nodeType"] == "InheritanceSpecifier":
self._remapping[elem['baseName']['referencedDeclaration']] = elem['baseName']['name'] self._remapping[elem["baseName"]["referencedDeclaration"]] = elem["baseName"][
"name"
]
def _parse_base_contract_info(self): def _parse_base_contract_info(self):
# Parse base contracts (immediate, non-linearized) # Parse base contracts (immediate, non-linearized)
self.baseContracts = []
self.baseConstructorContractsCalled = []
if self.is_compact_ast: if self.is_compact_ast:
# Parse base contracts + constructors in compact-ast # Parse base contracts + constructors in compact-ast
if 'baseContracts' in self._data: if "baseContracts" in self._data:
for base_contract in self._data['baseContracts']: for base_contract in self._data["baseContracts"]:
if base_contract['nodeType'] != 'InheritanceSpecifier': if base_contract["nodeType"] != "InheritanceSpecifier":
continue continue
if 'baseName' not in base_contract or 'referencedDeclaration' not in base_contract['baseName']: if (
"baseName" not in base_contract
or "referencedDeclaration" not in base_contract["baseName"]
):
continue continue
# Obtain our contract reference and add it to our base contract list # Obtain our contract reference and add it to our base contract list
referencedDeclaration = base_contract['baseName']['referencedDeclaration'] referencedDeclaration = base_contract["baseName"]["referencedDeclaration"]
self.baseContracts.append(referencedDeclaration) self.baseContracts.append(referencedDeclaration)
# If we have defined arguments in our arguments object, this is a constructor invocation. # If we have defined arguments in our arguments object, this is a constructor invocation.
# (note: 'arguments' can be [], which is not the same as None. [] implies a constructor was # (note: 'arguments' can be [], which is not the same as None. [] implies a constructor was
# called with no arguments, while None implies no constructor was called). # called with no arguments, while None implies no constructor was called).
if 'arguments' in base_contract and base_contract['arguments'] is not None: if "arguments" in base_contract and base_contract["arguments"] is not None:
self.baseConstructorContractsCalled.append(referencedDeclaration) self.baseConstructorContractsCalled.append(referencedDeclaration)
else: else:
# Parse base contracts + constructors in legacy-ast # Parse base contracts + constructors in legacy-ast
if 'children' in self._data: if "children" in self._data:
for base_contract in self._data['children']: for base_contract in self._data["children"]:
if base_contract['name'] != 'InheritanceSpecifier': if base_contract["name"] != "InheritanceSpecifier":
continue continue
if 'children' not in base_contract or len(base_contract['children']) == 0: if "children" not in base_contract or len(base_contract["children"]) == 0:
continue continue
# Obtain all items for this base contract specification (base contract, followed by arguments) # Obtain all items for this base contract specification (base contract, followed by arguments)
base_contract_items = base_contract['children'] base_contract_items = base_contract["children"]
if 'name' not in base_contract_items[0] or base_contract_items[0]['name'] != 'UserDefinedTypeName': if (
"name" not in base_contract_items[0]
or base_contract_items[0]["name"] != "UserDefinedTypeName"
):
continue continue
if 'attributes' not in base_contract_items[0] or 'referencedDeclaration' not in \ if (
base_contract_items[0]['attributes']: "attributes" not in base_contract_items[0]
or "referencedDeclaration" not in base_contract_items[0]["attributes"]
):
continue continue
# Obtain our contract reference and add it to our base contract list # Obtain our contract reference and add it to our base contract list
referencedDeclaration = base_contract_items[0]['attributes']['referencedDeclaration'] referencedDeclaration = base_contract_items[0]["attributes"][
"referencedDeclaration"
]
self.baseContracts.append(referencedDeclaration) self.baseContracts.append(referencedDeclaration)
# If we have an 'attributes'->'arguments' which is None, this is not a constructor call. # If we have an 'attributes'->'arguments' which is None, this is not a constructor call.
if 'attributes' not in base_contract or 'arguments' not in base_contract['attributes'] or \ if (
base_contract['attributes']['arguments'] is not None: "attributes" not in base_contract
or "arguments" not in base_contract["attributes"]
or base_contract["attributes"]["arguments"] is not None
):
self.baseConstructorContractsCalled.append(referencedDeclaration) self.baseConstructorContractsCalled.append(referencedDeclaration)
def _parse_contract_items(self): def _parse_contract_items(self):
if not self.get_children() in self._data: # empty contract if not self.get_children() in self._data: # empty contract
return return
for item in self._data[self.get_children()]: for item in self._data[self.get_children()]:
if item[self.get_key()] == 'FunctionDefinition': if item[self.get_key()] == "FunctionDefinition":
self._functionsNotParsed.append(item) self._functionsNotParsed.append(item)
elif item[self.get_key()] == 'EventDefinition': elif item[self.get_key()] == "EventDefinition":
self._eventsNotParsed.append(item) self._eventsNotParsed.append(item)
elif item[self.get_key()] == 'InheritanceSpecifier': elif item[self.get_key()] == "InheritanceSpecifier":
# we dont need to parse it as it is redundant # we dont need to parse it as it is redundant
# with self.linearizedBaseContracts # with self.linearizedBaseContracts
continue continue
elif item[self.get_key()] == 'VariableDeclaration': elif item[self.get_key()] == "VariableDeclaration":
self._variablesNotParsed.append(item) self._variablesNotParsed.append(item)
elif item[self.get_key()] == 'EnumDefinition': elif item[self.get_key()] == "EnumDefinition":
self._enumsNotParsed.append(item) self._enumsNotParsed.append(item)
elif item[self.get_key()] == 'ModifierDefinition': elif item[self.get_key()] == "ModifierDefinition":
self._modifiersNotParsed.append(item) self._modifiersNotParsed.append(item)
elif item[self.get_key()] == 'StructDefinition': elif item[self.get_key()] == "StructDefinition":
self._structuresNotParsed.append(item) self._structuresNotParsed.append(item)
elif item[self.get_key()] == 'UsingForDirective': elif item[self.get_key()] == "UsingForDirective":
self._usingForNotParsed.append(item) self._usingForNotParsed.append(item)
else: else:
raise ParsingError('Unknown contract item: ' + item[self.get_key()]) raise ParsingError("Unknown contract item: " + item[self.get_key()])
return return
def _parse_struct(self, struct): def _parse_struct(self, struct: Dict):
if self.is_compact_ast: if self.is_compact_ast:
name = struct['name'] name = struct["name"]
attributes = struct attributes = struct
else: else:
name = struct['attributes'][self.get_key()] name = struct["attributes"][self.get_key()]
attributes = struct['attributes'] attributes = struct["attributes"]
if 'canonicalName' in attributes: if "canonicalName" in attributes:
canonicalName = attributes['canonicalName'] canonicalName = attributes["canonicalName"]
else: else:
canonicalName = self.name + '.' + name canonicalName = self._contract.name + "." + name
if self.get_children('members') in struct: if self.get_children("members") in struct:
children = struct[self.get_children('members')] children = struct[self.get_children("members")]
else: else:
children = [] # empty struct children = [] # empty struct
st = StructureSolc(name, canonicalName, children)
st.set_contract(self) st = Structure()
st.set_offset(struct['src'], self.slither) st.set_contract(self._contract)
self._structures[name] = st st.set_offset(struct["src"], self._contract.slither)
st_parser = StructureSolc(st, name, canonicalName, children, self)
self._contract.structures_as_dict[name] = st
self._structures_parser.append(st_parser)
def parse_structs(self): def parse_structs(self):
for father in self.inheritance_reverse: for father in self._contract.inheritance_reverse:
self._structures.update(father.structures_as_dict()) self._contract.structures_as_dict.update(father.structures_as_dict)
for struct in self._structuresNotParsed: for struct in self._structuresNotParsed:
self._parse_struct(struct) self._parse_struct(struct)
self._structuresNotParsed = None self._structuresNotParsed = None
def parse_state_variables(self): def parse_state_variables(self):
for father in self.inheritance_reverse: for father in self._contract.inheritance_reverse:
self._variables.update(father.variables_as_dict()) self._contract.variables_as_dict.update(father.variables_as_dict)
self._variables_ordered += father.state_variables_ordered self._contract.add_variables_ordered(father.state_variables_ordered)
for varNotParsed in self._variablesNotParsed: for varNotParsed in self._variablesNotParsed:
var = StateVariableSolc(varNotParsed) var = StateVariable()
var.set_offset(varNotParsed['src'], self.slither) var.set_offset(varNotParsed["src"], self._contract.slither)
var.set_contract(self) var.set_contract(self._contract)
self._variables[var.name] = var var_parser = StateVariableSolc(var, varNotParsed)
self._variables_ordered.append(var) self._variables_parser.append(var_parser)
def _parse_modifier(self, modifier): self._contract.variables_as_dict[var.name] = var
self._contract.add_variables_ordered([var])
modif = ModifierSolc(modifier, self, self) def _parse_modifier(self, modifier_data: Dict):
modif.set_contract(self) modif = Modifier()
modif.set_contract_declarer(self) modif.set_offset(modifier_data["src"], self._contract.slither)
modif.set_offset(modifier['src'], self.slither) modif.set_contract(self._contract)
self.slither.add_modifier(modif) modif.set_contract_declarer(self._contract)
self._modifiers_no_params.append(modif)
def parse_modifiers(self): modif_parser = ModifierSolc(modif, modifier_data, self)
self._contract.slither.add_modifier(modif)
self._modifiers_no_params.append(modif_parser)
self._modifiers_parser.append(modif_parser)
self._slither_parser.add_functions_parser(modif_parser)
def parse_modifiers(self):
for modifier in self._modifiersNotParsed: for modifier in self._modifiersNotParsed:
self._parse_modifier(modifier) self._parse_modifier(modifier)
self._modifiersNotParsed = None self._modifiersNotParsed = None
return def _parse_function(self, function_data: Dict):
func = Function()
func.set_offset(function_data["src"], self._contract.slither)
func.set_contract(self._contract)
func.set_contract_declarer(self._contract)
def _parse_function(self, function): func_parser = FunctionSolc(func, function_data, self)
func = FunctionSolc(function, self, self) self._contract.slither.add_function(func)
func.set_offset(function['src'], self.slither) self._functions_no_params.append(func_parser)
self.slither.add_function(func) self._functions_parser.append(func_parser)
self._functions_no_params.append(func)
self._slither_parser.add_functions_parser(func_parser)
def parse_functions(self): def parse_functions(self):
@ -264,8 +328,6 @@ class ContractSolc04(Contract):
self._functionsNotParsed = None self._functionsNotParsed = None
return
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -274,56 +336,78 @@ class ContractSolc04(Contract):
################################################################################### ###################################################################################
def log_incorrect_parsing(self, error): def log_incorrect_parsing(self, error):
logger.error(error) LOGGER.error(error)
self._is_incorrectly_parsed = True self._contract.is_incorrectly_parsed = True
def analyze_content_modifiers(self): def analyze_content_modifiers(self):
try: try:
for modifier in self.modifiers: for modifier_parser in self._modifiers_parser:
modifier.analyze_content() modifier_parser.analyze_content()
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f'Missing modifier {e}') self.log_incorrect_parsing(f"Missing modifier {e}")
return
def analyze_content_functions(self): def analyze_content_functions(self):
try: try:
for function in self.functions: for function_parser in self._functions_parser:
function.analyze_content() function_parser.analyze_content()
except (VariableNotFound, KeyError, ParsingError) as e: except (VariableNotFound, KeyError, ParsingError) as e:
self.log_incorrect_parsing(f'Missing function {e}') self.log_incorrect_parsing(f"Missing function {e}")
return return
def analyze_params_modifiers(self): def analyze_params_modifiers(self):
try: try:
elements_no_params = self._modifiers_no_params elements_no_params = self._modifiers_no_params
getter = lambda f: f.modifiers getter = lambda c: c.modifiers_parser
getter_available = lambda f: f.modifiers_declared getter_available = lambda c: c.modifiers_declared
Cls = ModifierSolc Cls = Modifier
self._modifiers = self._analyze_params_elements(elements_no_params, getter, getter_available, Cls) Cls_parser = ModifierSolc
modifiers = self._analyze_params_elements(
elements_no_params,
getter,
getter_available,
Cls,
Cls_parser,
self._modifiers_parser,
)
self._contract.set_modifiers(modifiers)
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f'Missing params {e}') self.log_incorrect_parsing(f"Missing params {e}")
self._modifiers_no_params = [] self._modifiers_no_params = []
return
def analyze_params_functions(self): def analyze_params_functions(self):
try: try:
elements_no_params = self._functions_no_params elements_no_params = self._functions_no_params
getter = lambda f: f.functions getter = lambda c: c.functions_parser
getter_available = lambda f: f.functions_declared getter_available = lambda c: c.functions_declared
Cls = FunctionSolc Cls = Function
self._functions = self._analyze_params_elements(elements_no_params, getter, getter_available, Cls) Cls_parser = FunctionSolc
functions = self._analyze_params_elements(
elements_no_params,
getter,
getter_available,
Cls,
Cls_parser,
self._functions_parser,
)
self._contract.set_functions(functions)
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f'Missing params {e}') self.log_incorrect_parsing(f"Missing params {e}")
self._functions_no_params = [] self._functions_no_params = []
return
def _analyze_params_elements(self, elements_no_params, getter, getter_available, Cls): def _analyze_params_elements(
self,
elements_no_params: List[FunctionSolc],
getter: Callable[["ContractSolc"], List[FunctionSolc]],
getter_available: Callable[[Contract], List[Function]],
Cls: Callable,
Cls_parser: Callable,
parser: List[FunctionSolc],
) -> Dict[str, Union[Function, Modifier]]:
""" """
Analyze the parameters of the given elements (Function or Modifier). Analyze the parameters of the given elements (Function or Modifier).
The function iterates over the inheritance to create an instance or inherited elements (Function or Modifier) The function iterates over the inheritance to create an instance or inherited elements (Function or Modifier)
If the element is shadowed, set is_shadowed to True If the element is shadowed, set is_shadowed to True
:param elements_no_params: list of elements to analyzer :param elements_no_params: list of elements to analyzer
:param getter: fun x :param getter: fun x
:param getter_available: fun x :param getter_available: fun x
@ -333,15 +417,31 @@ class ContractSolc04(Contract):
all_elements = {} all_elements = {}
try: try:
for father in self.inheritance: for father in self._contract.inheritance:
for element in getter(father): father_parser = self._slither_parser.underlying_contract_to_parser[father]
elem = Cls(element._functionNotParsed, self, element.contract_declarer) for element_parser in getter(father_parser):
elem.set_offset(element._functionNotParsed['src'], self.slither) elem = Cls()
elem.analyze_params() elem.set_contract(self._contract)
self.slither.add_function(elem) elem.set_contract_declarer(element_parser.underlying_function.contract_declarer)
elem.set_offset(
element_parser.function_not_parsed["src"], self._contract.slither
)
elem_parser = Cls_parser(elem, element_parser.function_not_parsed, self,)
elem_parser.analyze_params()
if isinstance(elem, Modifier):
self._contract.slither.add_modifier(elem)
else:
self._contract.slither.add_function(elem)
self._slither_parser.add_functions_parser(elem_parser)
all_elements[elem.canonical_name] = elem all_elements[elem.canonical_name] = elem
parser.append(elem_parser)
accessible_elements = self.available_elements_from_inheritances(all_elements, getter_available) accessible_elements = self._contract.available_elements_from_inheritances(
all_elements, getter_available
)
# If there is a constructor in the functions # If there is a constructor in the functions
# We remove the previous constructor # We remove the previous constructor
@ -349,128 +449,65 @@ class ContractSolc04(Contract):
# #
# Note: contract.all_functions_called returns the constructors of the base contracts # Note: contract.all_functions_called returns the constructors of the base contracts
has_constructor = False has_constructor = False
for element in elements_no_params: for element_parser in elements_no_params:
element.analyze_params() element_parser.analyze_params()
if element.is_constructor: if element_parser.underlying_function.is_constructor:
has_constructor = True has_constructor = True
if has_constructor: if has_constructor:
_accessible_functions = {k: v for (k, v) in accessible_elements.items() if not v.is_constructor} _accessible_functions = {
k: v for (k, v) in accessible_elements.items() if not v.is_constructor
for element in elements_no_params: }
accessible_elements[element.full_name] = element
all_elements[element.canonical_name] = element for element_parser in elements_no_params:
accessible_elements[
element_parser.underlying_function.full_name
] = element_parser.underlying_function
all_elements[
element_parser.underlying_function.canonical_name
] = element_parser.underlying_function
for element in all_elements.values(): for element in all_elements.values():
if accessible_elements[element.full_name] != all_elements[element.canonical_name]: if accessible_elements[element.full_name] != all_elements[element.canonical_name]:
element.is_shadowed = True element.is_shadowed = True
accessible_elements[element.full_name].shadows = True accessible_elements[element.full_name].shadows = True
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f'Missing params {e}') self.log_incorrect_parsing(f"Missing params {e}")
return all_elements return all_elements
def analyze_constant_state_variables(self): def analyze_constant_state_variables(self):
for var in self.variables: for var_parser in self._variables_parser:
if var.is_constant: if var_parser.underlying_variable.is_constant:
# cant parse constant expression based on function calls # cant parse constant expression based on function calls
try: try:
var.analyze(self) var_parser.analyze(self)
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
logger.error(e) LOGGER.error(e)
pass pass
return
def _create_node(self, func, counter, variable):
# Function uses to create node for state variable declaration statements
node = Node(NodeType.OTHER_ENTRYPOINT, counter)
node.set_offset(variable.source_mapping, self.slither)
node.set_function(func)
func.add_node(node)
expression = AssignmentOperation(Identifier(variable),
variable.expression,
AssignmentOperationType.ASSIGN,
variable.type)
expression.set_offset(variable.source_mapping, self.slither)
node.add_expression(expression)
return node
def add_constructor_variables(self):
if self.state_variables:
for (idx, variable_candidate) in enumerate(self.state_variables):
if variable_candidate.expression and not variable_candidate.is_constant:
constructor_variable = Function()
constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES)
constructor_variable.set_contract(self)
constructor_variable.set_contract_declarer(self)
constructor_variable.set_visibility('internal')
# For now, source mapping of the constructor variable is the whole contract
# Could be improved with a targeted source mapping
constructor_variable.set_offset(self.source_mapping, self.slither)
self._functions[constructor_variable.canonical_name] = constructor_variable
prev_node = self._create_node(constructor_variable, 0, variable_candidate)
variable_candidate.node_initialization = prev_node
counter = 1
for v in self.state_variables[idx + 1:]:
if v.expression and not v.is_constant:
next_node = self._create_node(constructor_variable, counter, v)
v.node_initialization = next_node
prev_node.add_son(next_node)
next_node.add_father(prev_node)
counter += 1
break
for (idx, variable_candidate) in enumerate(self.state_variables):
if variable_candidate.expression and variable_candidate.is_constant:
constructor_variable = Function()
constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES)
constructor_variable.set_contract(self)
constructor_variable.set_contract_declarer(self)
constructor_variable.set_visibility('internal')
# For now, source mapping of the constructor variable is the whole contract
# Could be improved with a targeted source mapping
constructor_variable.set_offset(self.source_mapping, self.slither)
self._functions[constructor_variable.canonical_name] = constructor_variable
prev_node = self._create_node(constructor_variable, 0, variable_candidate)
variable_candidate.node_initialization = prev_node
counter = 1
for v in self.state_variables[idx + 1:]:
if v.expression and v.is_constant:
next_node = self._create_node(constructor_variable, counter, v)
v.node_initialization = next_node
prev_node.add_son(next_node)
next_node.add_father(prev_node)
counter += 1
break
def analyze_state_variables(self): def analyze_state_variables(self):
try: try:
for var in self.variables: for var_parser in self._variables_parser:
var.analyze(self) var_parser.analyze(self)
return return
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f'Missing state variable {e}') self.log_incorrect_parsing(f"Missing state variable {e}")
def analyze_using_for(self): def analyze_using_for(self):
try: try:
for father in self.inheritance: for father in self._contract.inheritance:
self._using_for.update(father.using_for) self._contract.using_for.update(father.using_for)
if self.is_compact_ast: if self.is_compact_ast:
for using_for in self._usingForNotParsed: for using_for in self._usingForNotParsed:
lib_name = parse_type(using_for['libraryName'], self) lib_name = parse_type(using_for["libraryName"], self)
if 'typeName' in using_for and using_for['typeName']: if "typeName" in using_for and using_for["typeName"]:
type_name = parse_type(using_for['typeName'], self) type_name = parse_type(using_for["typeName"], self)
else: else:
type_name = '*' type_name = "*"
if not type_name in self._using_for: if type_name not in self._contract.using_for:
self.using_for[type_name] = [] self._contract.using_for[type_name] = []
self._using_for[type_name].append(lib_name) self._contract.using_for[type_name].append(lib_name)
else: else:
for using_for in self._usingForNotParsed: for using_for in self._usingForNotParsed:
children = using_for[self.get_children()] children = using_for[self.get_children()]
@ -480,18 +517,18 @@ class ContractSolc04(Contract):
old = parse_type(children[1], self) old = parse_type(children[1], self)
else: else:
new = parse_type(children[0], self) new = parse_type(children[0], self)
old = '*' old = "*"
if not old in self._using_for: if old not in self._contract.using_for:
self.using_for[old] = [] self._contract.using_for[old] = []
self._using_for[old].append(new) self._contract.using_for[old].append(new)
self._usingForNotParsed = [] self._usingForNotParsed = []
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f'Missing using for {e}') self.log_incorrect_parsing(f"Missing using for {e}")
def analyze_enums(self): def analyze_enums(self):
try: try:
for father in self.inheritance: for father in self._contract.inheritance:
self._enums.update(father.enums_as_dict()) self._contract.enums_as_dict.update(father.enums_as_dict)
for enum in self._enumsNotParsed: for enum in self._enumsNotParsed:
# for enum, we can parse and analyze it # for enum, we can parse and analyze it
@ -499,107 +536,60 @@ class ContractSolc04(Contract):
self._analyze_enum(enum) self._analyze_enum(enum)
self._enumsNotParsed = None self._enumsNotParsed = None
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f'Missing enum {e}') self.log_incorrect_parsing(f"Missing enum {e}")
def _analyze_enum(self, enum): def _analyze_enum(self, enum):
# Enum can be parsed in one pass # Enum can be parsed in one pass
if self.is_compact_ast: if self.is_compact_ast:
name = enum['name'] name = enum["name"]
canonicalName = enum['canonicalName'] canonicalName = enum["canonicalName"]
else: else:
name = enum['attributes'][self.get_key()] name = enum["attributes"][self.get_key()]
if 'canonicalName' in enum['attributes']: if "canonicalName" in enum["attributes"]:
canonicalName = enum['attributes']['canonicalName'] canonicalName = enum["attributes"]["canonicalName"]
else: else:
canonicalName = self.name + '.' + name canonicalName = self._contract.name + "." + name
values = [] values = []
for child in enum[self.get_children('members')]: for child in enum[self.get_children("members")]:
assert child[self.get_key()] == 'EnumValue' assert child[self.get_key()] == "EnumValue"
if self.is_compact_ast: if self.is_compact_ast:
values.append(child['name']) values.append(child["name"])
else: else:
values.append(child['attributes'][self.get_key()]) values.append(child["attributes"][self.get_key()])
new_enum = Enum(name, canonicalName, values) new_enum = Enum(name, canonicalName, values)
new_enum.set_contract(self) new_enum.set_contract(self._contract)
new_enum.set_offset(enum['src'], self.slither) new_enum.set_offset(enum["src"], self._contract.slither)
self._enums[canonicalName] = new_enum self._contract.enums_as_dict[canonicalName] = new_enum
def _analyze_struct(self, struct): def _analyze_struct(self, struct: StructureSolc):
struct.analyze() struct.analyze()
def analyze_structs(self): def analyze_structs(self):
try: try:
for struct in self.structures: for struct in self._structures_parser:
self._analyze_struct(struct) self._analyze_struct(struct)
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f'Missing struct {e}') self.log_incorrect_parsing(f"Missing struct {e}")
def analyze_events(self): def analyze_events(self):
try: try:
for father in self.inheritance_reverse: for father in self._contract.inheritance_reverse:
self._events.update(father.events_as_dict()) self._contract.events_as_dict.update(father.events_as_dict)
for event_to_parse in self._eventsNotParsed: for event_to_parse in self._eventsNotParsed:
event = EventSolc(event_to_parse, self) event = Event()
event.analyze(self) event.set_contract(self._contract)
event.set_contract(self) event.set_offset(event_to_parse["src"], self._contract.slither)
event.set_offset(event_to_parse['src'], self.slither)
self._events[event.full_name] = event event_parser = EventSolc(event, event_to_parse, self)
event_parser.analyze(self)
self._contract.events_as_dict[event.full_name] = event
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f'Missing event {e}') self.log_incorrect_parsing(f"Missing event {e}")
self._eventsNotParsed = None self._eventsNotParsed = None
# endregion
###################################################################################
###################################################################################
# region SlithIR
###################################################################################
###################################################################################
def convert_expression_to_slithir(self):
for func in self.functions + self.modifiers:
try:
func.generate_slithir_and_analyze()
except AttributeError:
# This can happens for example if there is a call to an interface
# And the interface is redefined due to contract's name reuse
# But the available version misses some functions
self.log_incorrect_parsing(f'Impossible to generate IR for {self.name}.{func.name}')
all_ssa_state_variables_instances = dict()
for contract in self.inheritance:
for v in contract.state_variables_declared:
new_var = StateIRVariable(v)
all_ssa_state_variables_instances[v.canonical_name] = new_var
self._initial_state_variables.append(new_var)
for v in self.variables:
if v.contract == self:
new_var = StateIRVariable(v)
all_ssa_state_variables_instances[v.canonical_name] = new_var
self._initial_state_variables.append(new_var)
for func in self.functions + self.modifiers:
func.generate_slithir_ssa(all_ssa_state_variables_instances)
def fix_phi(self):
last_state_variables_instances = dict()
initial_state_variables_instances = dict()
for v in self._initial_state_variables:
last_state_variables_instances[v.canonical_name] = []
initial_state_variables_instances[v.canonical_name] = v
for func in self.functions + self.modifiers:
result = func.get_last_ssa_state_variables_instances()
for variable_name, instances in result.items():
last_state_variables_instances[variable_name] += instances
for func in self.functions + self.modifiers:
func.fix_phi(last_state_variables_instances, initial_state_variables_instances)
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -631,6 +621,6 @@ class ContractSolc04(Contract):
################################################################################### ###################################################################################
def __hash__(self): def __hash__(self):
return self._id return self._contract.id
# endregion # endregion

@ -1,43 +1,55 @@
""" """
Event module Event module
""" """
from typing import TYPE_CHECKING, Dict
from slither.core.variables.event_variable import EventVariable
from slither.solc_parsing.variables.event_variable import EventVariableSolc from slither.solc_parsing.variables.event_variable import EventVariableSolc
from slither.core.declarations.event import Event from slither.core.declarations.event import Event
class EventSolc(Event): if TYPE_CHECKING:
from slither.solc_parsing.declarations.contract import ContractSolc
class EventSolc:
""" """
Event class Event class
""" """
def __init__(self, event, contract): def __init__(self, event: Event, event_data: Dict, contract_parser: "ContractSolc"):
super(EventSolc, self).__init__()
self._contract = contract self._event = event
event.set_contract(contract_parser.underlying_contract)
self._parser_contract = contract_parser
self._elems = []
if self.is_compact_ast: if self.is_compact_ast:
self._name = event['name'] self._event.name = event_data["name"]
elems = event['parameters'] elems = event_data["parameters"]
assert elems['nodeType'] == 'ParameterList' assert elems["nodeType"] == "ParameterList"
self._elemsNotParsed = elems['parameters'] self._elemsNotParsed = elems["parameters"]
else: else:
self._name = event['attributes']['name'] self._event.name = event_data["attributes"]["name"]
elems = event['children'][0] elems = event_data["children"][0]
assert elems['name'] == 'ParameterList' assert elems["name"] == "ParameterList"
if 'children' in elems: if "children" in elems:
self._elemsNotParsed = elems['children'] self._elemsNotParsed = elems["children"]
else: else:
self._elemsNotParsed = [] self._elemsNotParsed = []
@property @property
def is_compact_ast(self): def is_compact_ast(self) -> bool:
return self.contract.is_compact_ast return self._parser_contract.is_compact_ast
def analyze(self, contract): def analyze(self, contract: "ContractSolc"):
for elem_to_parse in self._elemsNotParsed: for elem_to_parse in self._elemsNotParsed:
elem = EventVariableSolc(elem_to_parse) elem = EventVariable()
elem.analyze(contract) # Todo: check if the source offset is always here
self._elems.append(elem) if "src" in elem_to_parse:
elem.set_offset(elem_to_parse["src"], self._parser_contract.slither)
elem_parser = EventVariableSolc(elem, elem_to_parse)
elem_parser.analyze(contract)
self._elemsNotParsed = [] self._event.elems.append(elem)
self._elemsNotParsed = []

File diff suppressed because it is too large Load Diff

@ -1,15 +1,29 @@
""" """
Event module Event module
""" """
from slither.core.declarations.modifier import Modifier from typing import Dict, TYPE_CHECKING
from slither.solc_parsing.declarations.function import FunctionSolc
from slither.core.cfg.node import NodeType from slither.core.cfg.node import NodeType
from slither.core.cfg.node import link_nodes from slither.core.cfg.node import link_nodes
from slither.core.declarations.modifier import Modifier
from slither.solc_parsing.cfg.node import NodeSolc
from slither.solc_parsing.declarations.function import FunctionSolc
class ModifierSolc(Modifier, FunctionSolc): if TYPE_CHECKING:
from slither.solc_parsing.declarations.contract import ContractSolc
class ModifierSolc(FunctionSolc):
def __init__(self, modifier: Modifier, function_data: Dict, contract_parser: "ContractSolc"):
super().__init__(modifier, function_data, contract_parser)
# _modifier is equal to _function, but keep it here to prevent
# confusion for mypy in underlying_function
self._modifier = modifier
@property
def underlying_function(self) -> Modifier:
return self._modifier
def analyze_params(self): def analyze_params(self):
# Can be re-analyzed due to inheritance # Can be re-analyzed due to inheritance
if self._params_was_analyzed: if self._params_was_analyzed:
@ -20,9 +34,9 @@ class ModifierSolc(Modifier, FunctionSolc):
self._analyze_attributes() self._analyze_attributes()
if self.is_compact_ast: if self.is_compact_ast:
params = self._functionNotParsed['parameters'] params = self._functionNotParsed["parameters"]
else: else:
children = self._functionNotParsed['children'] children = self._functionNotParsed["children"]
params = children[0] params = children[0]
if params: if params:
@ -34,41 +48,40 @@ class ModifierSolc(Modifier, FunctionSolc):
self._content_was_analyzed = True self._content_was_analyzed = True
if self.is_compact_ast: if self.is_compact_ast:
body = self._functionNotParsed['body'] body = self._functionNotParsed["body"]
if body and body[self.get_key()] == 'Block': if body and body[self.get_key()] == "Block":
self._is_implemented = True self._function.is_implemented = True
self._parse_cfg(body) self._parse_cfg(body)
else: else:
children = self._functionNotParsed['children'] children = self._functionNotParsed["children"]
self._isImplemented = False self._function.is_implemented = False
if len(children) > 1: if len(children) > 1:
assert len(children) == 2 assert len(children) == 2
block = children[1] block = children[1]
assert block['name'] == 'Block' assert block["name"] == "Block"
self._is_implemented = True self._function.is_implemented = True
self._parse_cfg(block) self._parse_cfg(block)
for local_vars in self.variables: for local_var_parser in self._local_variables_parser:
local_vars.analyze(self) local_var_parser.analyze(self)
for node in self.nodes: for node in self._node_to_nodesolc.values():
node.analyze_expressions(self) node.analyze_expressions(self)
self._filter_ternary() self._filter_ternary()
self._remove_alone_endif() self._remove_alone_endif()
self._analyze_read_write() # self._analyze_read_write()
self._analyze_calls() # self._analyze_calls()
def _parse_statement(self, statement, node): def _parse_statement(self, statement: Dict, node: NodeSolc) -> NodeSolc:
name = statement[self.get_key()] name = statement[self.get_key()]
if name == 'PlaceholderStatement': if name == "PlaceholderStatement":
placeholder_node = self._new_node(NodeType.PLACEHOLDER, statement['src']) placeholder_node = self._new_node(NodeType.PLACEHOLDER, statement["src"])
link_nodes(node, placeholder_node) link_nodes(node.underlying_node, placeholder_node.underlying_node)
return placeholder_node return placeholder_node
return super(ModifierSolc, self)._parse_statement(statement, node) return super(ModifierSolc, self)._parse_statement(statement, node)

@ -1,34 +1,47 @@
""" """
Structure module Structure module
""" """
from typing import List, TYPE_CHECKING
from slither.core.variables.structure_variable import StructureVariable
from slither.solc_parsing.variables.structure_variable import StructureVariableSolc from slither.solc_parsing.variables.structure_variable import StructureVariableSolc
from slither.core.declarations.structure import Structure from slither.core.declarations.structure import Structure
class StructureSolc(Structure): if TYPE_CHECKING:
from slither.solc_parsing.declarations.contract import ContractSolc
class StructureSolc:
""" """
Structure class Structure class
""" """
# elems = [(type, name)]
# elems = [(type, name)]
def __init__(self, name, canonicalName, elems): def __init__(
super(StructureSolc, self).__init__() self,
self._name = name st: Structure,
self._canonical_name = canonicalName name: str,
self._elems = {} canonicalName: str,
self._elems_ordered = [] elems: List[str],
contract_parser: "ContractSolc",
):
self._structure = st
st.name = name
st.canonical_name = canonicalName
self._contract_parser = contract_parser
self._elemsNotParsed = elems self._elemsNotParsed = elems
def analyze(self): def analyze(self):
for elem_to_parse in self._elemsNotParsed: for elem_to_parse in self._elemsNotParsed:
elem = StructureVariableSolc(elem_to_parse) elem = StructureVariable()
elem.set_structure(self) elem.set_structure(self._structure)
elem.set_offset(elem_to_parse['src'], self.contract.slither) elem.set_offset(elem_to_parse["src"], self._structure.contract.slither)
elem.analyze(self.contract) elem_parser = StructureVariableSolc(elem, elem_to_parse)
elem_parser.analyze(self._contract_parser)
self._elems[elem.name] = elem self._structure.elems[elem.name] = elem
self._elems_ordered.append(elem.name) self._structure.add_elem_in_order(elem.name)
self._elemsNotParsed = [] self._elemsNotParsed = []

@ -1,5 +1,9 @@
from slither.exceptions import SlitherException from slither.exceptions import SlitherException
class ParsingError(SlitherException): pass
class VariableNotFound(SlitherException): pass class ParsingError(SlitherException):
pass
class VariableNotFound(SlitherException):
pass

@ -1,23 +1,26 @@
import logging import logging
import re import re
from typing import Dict, TYPE_CHECKING, Optional, Union
from slither.core.declarations import Event, Enum, Structure
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.function import Function from slither.core.declarations.function import Function
from slither.core.declarations.solidity_variables import (SOLIDITY_FUNCTIONS, from slither.core.declarations.solidity_variables import (
SOLIDITY_FUNCTIONS,
SOLIDITY_VARIABLES, SOLIDITY_VARIABLES,
SOLIDITY_VARIABLES_COMPOSED, SOLIDITY_VARIABLES_COMPOSED,
SolidityFunction, SolidityFunction,
SolidityVariable, SolidityVariable,
SolidityVariableComposed) SolidityVariableComposed,
from slither.core.expressions.assignment_operation import (AssignmentOperation, )
AssignmentOperationType) from slither.core.expressions.assignment_operation import (
from slither.core.expressions.binary_operation import (BinaryOperation, AssignmentOperation,
BinaryOperationType) AssignmentOperationType,
)
from slither.core.expressions.binary_operation import BinaryOperation, BinaryOperationType
from slither.core.expressions.call_expression import CallExpression from slither.core.expressions.call_expression import CallExpression
from slither.core.expressions.conditional_expression import \ from slither.core.expressions.conditional_expression import ConditionalExpression
ConditionalExpression from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression
from slither.core.expressions.elementary_type_name_expression import \
ElementaryTypeNameExpression
from slither.core.expressions.identifier import Identifier from slither.core.expressions.identifier import Identifier
from slither.core.expressions.index_access import IndexAccess from slither.core.expressions.index_access import IndexAccess
from slither.core.expressions.literal import Literal from slither.core.expressions.literal import Literal
@ -29,16 +32,18 @@ from slither.core.expressions.super_call_expression import SuperCallExpression
from slither.core.expressions.super_identifier import SuperIdentifier from slither.core.expressions.super_identifier import SuperIdentifier
from slither.core.expressions.tuple_expression import TupleExpression from slither.core.expressions.tuple_expression import TupleExpression
from slither.core.expressions.type_conversion import TypeConversion from slither.core.expressions.type_conversion import TypeConversion
from slither.core.expressions.unary_operation import (UnaryOperation, from slither.core.expressions.unary_operation import UnaryOperation, UnaryOperationType
UnaryOperationType) from slither.core.solidity_types import ArrayType, ElementaryType, FunctionType, MappingType
from slither.core.solidity_types import (ArrayType, ElementaryType, from slither.core.variables.variable import Variable
FunctionType, MappingType)
from slither.solc_parsing.solidity_types.type_parsing import (UnknownType,
parse_type)
from slither.solc_parsing.exceptions import ParsingError, VariableNotFound from slither.solc_parsing.exceptions import ParsingError, VariableNotFound
from slither.solc_parsing.solidity_types.type_parsing import UnknownType, parse_type
logger = logging.getLogger("ExpressionParsing") if TYPE_CHECKING:
from slither.core.expressions.expression import Expression
from slither.solc_parsing.declarations.function import FunctionSolc
from slither.solc_parsing.declarations.contract import ContractSolc
logger = logging.getLogger("ExpressionParsing")
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -46,21 +51,34 @@ logger = logging.getLogger("ExpressionParsing")
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def get_pointer_name(variable): CallerContext = Union["ContractSolc", "FunctionSolc"]
def get_pointer_name(variable: Variable):
curr_type = variable.type curr_type = variable.type
while (isinstance(curr_type, (ArrayType, MappingType))): while isinstance(curr_type, (ArrayType, MappingType)):
if isinstance(curr_type, ArrayType): if isinstance(curr_type, ArrayType):
curr_type = curr_type.type curr_type = curr_type.type
else: else:
assert isinstance(curr_type, MappingType) assert isinstance(curr_type, MappingType)
curr_type = curr_type.type_to curr_type = curr_type.type_to
if isinstance(curr_type, (FunctionType)): if isinstance(curr_type, FunctionType):
return variable.name + curr_type.parameters_signature return variable.name + curr_type.parameters_signature
return None return None
def find_variable(var_name, caller_context, referenced_declaration=None, is_super=False): def find_variable(
var_name: str,
caller_context: CallerContext,
referenced_declaration: Optional[int] = None,
is_super=False,
) -> Union[
Variable, Function, Contract, SolidityVariable, SolidityFunction, Event, Enum, Structure
]:
from slither.solc_parsing.declarations.contract import ContractSolc
from slither.solc_parsing.declarations.function import FunctionSolc
# variable are looked from the contract declarer # variable are looked from the contract declarer
# functions can be shadowed, but are looked from the contract instance, rather than the contract declarer # functions can be shadowed, but are looked from the contract instance, rather than the contract declarer
# the difference between function and variable come from the fact that an internal call, or an variable access # the difference between function and variable come from the fact that an internal call, or an variable access
@ -76,24 +94,24 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe
# for events it's unclear what should be the behavior, as they can be shadowed, but there is not impact # for events it's unclear what should be the behavior, as they can be shadowed, but there is not impact
# structure/enums cannot be shadowed # structure/enums cannot be shadowed
if isinstance(caller_context, Contract): if isinstance(caller_context, ContractSolc):
function = None function: Optional[FunctionSolc] = None
contract = caller_context contract = caller_context.underlying_contract
contract_declarer = caller_context contract_declarer = caller_context.underlying_contract
elif isinstance(caller_context, Function): elif isinstance(caller_context, FunctionSolc):
function = caller_context function = caller_context
contract = function.contract contract = function.underlying_function.contract
contract_declarer = function.contract_declarer contract_declarer = function.underlying_function.contract_declarer
else: else:
raise ParsingError('Incorrect caller context') raise ParsingError("Incorrect caller context")
if function: if function:
# We look for variable declared with the referencedDeclaration attr # We look for variable declared with the referencedDeclaration attr
func_variables = function.variables_renamed func_variables = function.variables_renamed
if referenced_declaration and referenced_declaration in func_variables: if referenced_declaration and referenced_declaration in func_variables:
return func_variables[referenced_declaration] return func_variables[referenced_declaration].underlying_variable
# If not found, check for name # If not found, check for name
func_variables = function.variables_as_dict() func_variables = function.underlying_function.variables_as_dict
if var_name in func_variables: if var_name in func_variables:
return func_variables[var_name] return func_variables[var_name]
# A local variable can be a pointer # A local variable can be a pointer
@ -101,12 +119,14 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe
# function test(function(uint) internal returns(bool) t) interna{ # function test(function(uint) internal returns(bool) t) interna{
# Will have a local variable t which will match the signature # Will have a local variable t which will match the signature
# t(uint256) # t(uint256)
func_variables_ptr = {get_pointer_name(f): f for f in function.variables} func_variables_ptr = {
get_pointer_name(f): f for f in function.underlying_function.variables
}
if var_name and var_name in func_variables_ptr: if var_name and var_name in func_variables_ptr:
return func_variables_ptr[var_name] return func_variables_ptr[var_name]
# variable are looked from the contract declarer # variable are looked from the contract declarer
contract_variables = contract_declarer.variables_as_dict() contract_variables = contract_declarer.variables_as_dict
if var_name in contract_variables: if var_name in contract_variables:
return contract_variables[var_name] return contract_variables[var_name]
@ -118,8 +138,12 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe
if is_super: if is_super:
getter_available = lambda f: f.functions_declared getter_available = lambda f: f.functions_declared
d = {f.canonical_name: f for f in contract.functions} d = {f.canonical_name: f for f in contract.functions}
functions = {f.full_name: f for f in functions = {
contract_declarer.available_elements_from_inheritances(d, getter_available).values()} f.full_name: f
for f in contract_declarer.available_elements_from_inheritances(
d, getter_available
).values()
}
else: else:
functions = contract.available_functions_as_dict() functions = contract.available_functions_as_dict()
if var_name in functions: if var_name in functions:
@ -128,23 +152,27 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe
if is_super: if is_super:
getter_available = lambda m: m.modifiers_declared getter_available = lambda m: m.modifiers_declared
d = {m.canonical_name: m for m in contract.modifiers} d = {m.canonical_name: m for m in contract.modifiers}
modifiers = {m.full_name: m for m in modifiers = {
contract_declarer.available_elements_from_inheritances(d, getter_available).values()} m.full_name: m
for m in contract_declarer.available_elements_from_inheritances(
d, getter_available
).values()
}
else: else:
modifiers = contract.available_modifiers_as_dict() modifiers = contract.available_modifiers_as_dict()
if var_name in modifiers: if var_name in modifiers:
return modifiers[var_name] return modifiers[var_name]
# structures are looked on the contract declarer # structures are looked on the contract declarer
structures = contract.structures_as_dict() structures = contract.structures_as_dict
if var_name in structures: if var_name in structures:
return structures[var_name] return structures[var_name]
events = contract.events_as_dict() events = contract.events_as_dict
if var_name in events: if var_name in events:
return events[var_name] return events[var_name]
enums = contract.enums_as_dict() enums = contract.enums_as_dict
if var_name in enums: if var_name in enums:
return enums[var_name] return enums[var_name]
@ -154,7 +182,7 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe
return enums[var_name] return enums[var_name]
# Could refer to any enum # Could refer to any enum
all_enums = [c.enums_as_dict() for c in contract.slither.contracts] all_enums = [c.enums_as_dict for c in contract.slither.contracts]
all_enums = {k: v for d in all_enums for k, v in d.items()} all_enums = {k: v for d in all_enums for k, v in d.items()}
if var_name in all_enums: if var_name in all_enums:
return all_enums[var_name] return all_enums[var_name]
@ -165,19 +193,22 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe
if var_name in SOLIDITY_FUNCTIONS: if var_name in SOLIDITY_FUNCTIONS:
return SolidityFunction(var_name) return SolidityFunction(var_name)
contracts = contract.slither.contracts_as_dict() contracts = contract.slither.contracts_as_dict
if var_name in contracts: if var_name in contracts:
return contracts[var_name] return contracts[var_name]
if referenced_declaration: if referenced_declaration:
for contract in contract.slither.contracts: # id of the contracts is the referenced declaration
if contract.id == referenced_declaration: # This is not true for the functions, as we dont always have the referenced_declaration
return contract # But maybe we could? (TODO)
for function in contract.slither.functions: for contract_candidate in contract.slither.contracts:
if function.referenced_declaration == referenced_declaration: if contract_candidate.id == referenced_declaration:
return function return contract_candidate
for function_candidate in caller_context.slither_parser.all_functions_parser:
if function_candidate.referenced_declaration == referenced_declaration:
return function_candidate.underlying_function
raise VariableNotFound('Variable not found: {} (context {})'.format(var_name, caller_context)) raise VariableNotFound("Variable not found: {} (context {})".format(var_name, caller_context))
# endregion # endregion
@ -187,38 +218,39 @@ def find_variable(var_name, caller_context, referenced_declaration=None, is_supe
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def filter_name(value):
value = value.replace(' memory', '') def filter_name(value: str) -> str:
value = value.replace(' storage', '') value = value.replace(" memory", "")
value = value.replace(' external', '') value = value.replace(" storage", "")
value = value.replace(' internal', '') value = value.replace(" external", "")
value = value.replace('struct ', '') value = value.replace(" internal", "")
value = value.replace('contract ', '') value = value.replace("struct ", "")
value = value.replace('enum ', '') value = value.replace("contract ", "")
value = value.replace(' ref', '') value = value.replace("enum ", "")
value = value.replace(' pointer', '') value = value.replace(" ref", "")
value = value.replace(' pure', '') value = value.replace(" pointer", "")
value = value.replace(' view', '') value = value.replace(" pure", "")
value = value.replace(' constant', '') value = value.replace(" view", "")
value = value.replace(' payable', '') value = value.replace(" constant", "")
value = value.replace('function (', 'function(') value = value.replace(" payable", "")
value = value.replace('returns (', 'returns(') value = value.replace("function (", "function(")
value = value.replace("returns (", "returns(")
# remove the text remaining after functio(...) # remove the text remaining after functio(...)
# which should only be ..returns(...) # which should only be ..returns(...)
# nested parenthesis so we use a system of counter on parenthesis # nested parenthesis so we use a system of counter on parenthesis
idx = value.find('(') idx = value.find("(")
if idx: if idx:
counter = 1 counter = 1
max_idx = len(value) max_idx = len(value)
while counter: while counter:
assert idx < max_idx assert idx < max_idx
idx = idx + 1 idx = idx + 1
if value[idx] == '(': if value[idx] == "(":
counter += 1 counter += 1
elif value[idx] == ')': elif value[idx] == ")":
counter -= 1 counter -= 1
value = value[:idx + 1] value = value[: idx + 1]
return value return value
@ -230,35 +262,38 @@ def filter_name(value):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def parse_call(expression, caller_context):
src = expression['src'] def parse_call(expression: Dict, caller_context):
src = expression["src"]
if caller_context.is_compact_ast: if caller_context.is_compact_ast:
attributes = expression attributes = expression
type_conversion = expression['kind'] == 'typeConversion' type_conversion = expression["kind"] == "typeConversion"
type_return = attributes['typeDescriptions']['typeString'] type_return = attributes["typeDescriptions"]["typeString"]
else: else:
attributes = expression['attributes'] attributes = expression["attributes"]
type_conversion = attributes['type_conversion'] type_conversion = attributes["type_conversion"]
type_return = attributes['type'] type_return = attributes["type"]
if type_conversion: if type_conversion:
type_call = parse_type(UnknownType(type_return), caller_context) type_call = parse_type(UnknownType(type_return), caller_context)
if caller_context.is_compact_ast: if caller_context.is_compact_ast:
assert len(expression['arguments']) == 1 assert len(expression["arguments"]) == 1
expression_to_parse = expression['arguments'][0] expression_to_parse = expression["arguments"][0]
else: else:
children = expression['children'] children = expression["children"]
assert len(children) == 2 assert len(children) == 2
type_info = children[0] type_info = children[0]
expression_to_parse = children[1] expression_to_parse = children[1]
assert type_info['name'] in ['ElementaryTypenameExpression', assert type_info["name"] in [
'ElementaryTypeNameExpression', "ElementaryTypenameExpression",
'Identifier', "ElementaryTypeNameExpression",
'TupleExpression', "Identifier",
'IndexAccess', "TupleExpression",
'MemberAccess'] "IndexAccess",
"MemberAccess",
]
expression = parse_expression(expression_to_parse, caller_context) expression = parse_expression(expression_to_parse, caller_context)
t = TypeConversion(expression, type_call) t = TypeConversion(expression, type_call)
@ -269,33 +304,33 @@ def parse_call(expression, caller_context):
call_value = None call_value = None
call_salt = None call_salt = None
if caller_context.is_compact_ast: if caller_context.is_compact_ast:
called = parse_expression(expression['expression'], caller_context) called = parse_expression(expression["expression"], caller_context)
# If the next expression is a FunctionCallOptions # If the next expression is a FunctionCallOptions
# We can here the gas/value information # We can here the gas/value information
# This is only available if the syntax is {gas: , value: } # This is only available if the syntax is {gas: , value: }
# For the .gas().value(), the member are considered as function call # For the .gas().value(), the member are considered as function call
# And converted later to the correct info (convert.py) # And converted later to the correct info (convert.py)
if expression['expression'][caller_context.get_key()] == 'FunctionCallOptions': if expression["expression"][caller_context.get_key()] == "FunctionCallOptions":
call_with_options = expression['expression'] call_with_options = expression["expression"]
for idx, name in enumerate(call_with_options.get('names', [])): for idx, name in enumerate(call_with_options.get("names", [])):
option = parse_expression(call_with_options['options'][idx], caller_context) option = parse_expression(call_with_options["options"][idx], caller_context)
if name == 'value': if name == "value":
call_value = option call_value = option
if name == 'gas': if name == "gas":
call_gas = option call_gas = option
if name == 'salt': if name == "salt":
call_salt = option call_salt = option
arguments = [] arguments = []
if expression['arguments']: if expression["arguments"]:
arguments = [parse_expression(a, caller_context) for a in expression['arguments']] arguments = [parse_expression(a, caller_context) for a in expression["arguments"]]
else: else:
children = expression['children'] children = expression["children"]
called = parse_expression(children[0], caller_context) called = parse_expression(children[0], caller_context)
arguments = [parse_expression(a, caller_context) for a in children[1::]] arguments = [parse_expression(a, caller_context) for a in children[1::]]
if isinstance(called, SuperCallExpression): if isinstance(called, SuperCallExpression):
sp = SuperCallExpression(called, arguments, type_return) sp = SuperCallExpression(called, arguments, type_return)
sp.set_offset(expression['src'], caller_context.slither) sp.set_offset(expression["src"], caller_context.slither)
return sp return sp
call_expression = CallExpression(called, arguments, type_return) call_expression = CallExpression(called, arguments, type_return)
call_expression.set_offset(src, caller_context.slither) call_expression.set_offset(src, caller_context.slither)
@ -307,46 +342,48 @@ def parse_call(expression, caller_context):
return call_expression return call_expression
def parse_super_name(expression, is_compact_ast): def parse_super_name(expression: Dict, is_compact_ast: bool) -> str:
if is_compact_ast: if is_compact_ast:
assert expression['nodeType'] == 'MemberAccess' assert expression["nodeType"] == "MemberAccess"
base_name = expression['memberName'] base_name = expression["memberName"]
arguments = expression['typeDescriptions']['typeString'] arguments = expression["typeDescriptions"]["typeString"]
else: else:
assert expression['name'] == 'MemberAccess' assert expression["name"] == "MemberAccess"
attributes = expression['attributes'] attributes = expression["attributes"]
base_name = attributes['member_name'] base_name = attributes["member_name"]
arguments = attributes['type'] arguments = attributes["type"]
assert arguments.startswith('function ') assert arguments.startswith("function ")
# remove function (...() # remove function (...()
arguments = arguments[len('function '):] arguments = arguments[len("function ") :]
arguments = filter_name(arguments) arguments = filter_name(arguments)
if ' ' in arguments: if " " in arguments:
arguments = arguments[:arguments.find(' ')] arguments = arguments[: arguments.find(" ")]
return base_name + arguments return base_name + arguments
def _parse_elementary_type_name_expression(expression, is_compact_ast, caller_context): def _parse_elementary_type_name_expression(
expression: Dict, is_compact_ast: bool, caller_context
) -> ElementaryTypeNameExpression:
# nop exression # nop exression
# uint; # uint;
if is_compact_ast: if is_compact_ast:
value = expression['typeName'] value = expression["typeName"]
else: else:
assert 'children' not in expression assert "children" not in expression
value = expression['attributes']['value'] value = expression["attributes"]["value"]
if isinstance(value, dict): if isinstance(value, dict):
t = parse_type(value, caller_context) t = parse_type(value, caller_context)
else: else:
t = parse_type(UnknownType(value), caller_context) t = parse_type(UnknownType(value), caller_context)
e = ElementaryTypeNameExpression(t) e = ElementaryTypeNameExpression(t)
e.set_offset(expression['src'], caller_context.slither) e.set_offset(expression["src"], caller_context.slither)
return e return e
def parse_expression(expression, caller_context): def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expression":
""" """
Returns: Returns:
@ -378,53 +415,53 @@ def parse_expression(expression, caller_context):
# The AST naming does not follow the spec # The AST naming does not follow the spec
name = expression[caller_context.get_key()] name = expression[caller_context.get_key()]
is_compact_ast = caller_context.is_compact_ast is_compact_ast = caller_context.is_compact_ast
src = expression['src'] src = expression["src"]
if name == 'UnaryOperation': if name == "UnaryOperation":
if is_compact_ast: if is_compact_ast:
attributes = expression attributes = expression
else: else:
attributes = expression['attributes'] attributes = expression["attributes"]
assert 'prefix' in attributes assert "prefix" in attributes
operation_type = UnaryOperationType.get_type(attributes['operator'], attributes['prefix']) operation_type = UnaryOperationType.get_type(attributes["operator"], attributes["prefix"])
if is_compact_ast: if is_compact_ast:
expression = parse_expression(expression['subExpression'], caller_context) expression = parse_expression(expression["subExpression"], caller_context)
else: else:
assert len(expression['children']) == 1 assert len(expression["children"]) == 1
expression = parse_expression(expression['children'][0], caller_context) expression = parse_expression(expression["children"][0], caller_context)
unary_op = UnaryOperation(expression, operation_type) unary_op = UnaryOperation(expression, operation_type)
unary_op.set_offset(src, caller_context.slither) unary_op.set_offset(src, caller_context.slither)
return unary_op return unary_op
elif name == 'BinaryOperation': elif name == "BinaryOperation":
if is_compact_ast: if is_compact_ast:
attributes = expression attributes = expression
else: else:
attributes = expression['attributes'] attributes = expression["attributes"]
operation_type = BinaryOperationType.get_type(attributes['operator']) operation_type = BinaryOperationType.get_type(attributes["operator"])
if is_compact_ast: if is_compact_ast:
left_expression = parse_expression(expression['leftExpression'], caller_context) left_expression = parse_expression(expression["leftExpression"], caller_context)
right_expression = parse_expression(expression['rightExpression'], caller_context) right_expression = parse_expression(expression["rightExpression"], caller_context)
else: else:
assert len(expression['children']) == 2 assert len(expression["children"]) == 2
left_expression = parse_expression(expression['children'][0], caller_context) left_expression = parse_expression(expression["children"][0], caller_context)
right_expression = parse_expression(expression['children'][1], caller_context) right_expression = parse_expression(expression["children"][1], caller_context)
binary_op = BinaryOperation(left_expression, right_expression, operation_type) binary_op = BinaryOperation(left_expression, right_expression, operation_type)
binary_op.set_offset(src, caller_context.slither) binary_op.set_offset(src, caller_context.slither)
return binary_op return binary_op
elif name in 'FunctionCall': elif name in "FunctionCall":
return parse_call(expression, caller_context) return parse_call(expression, caller_context)
elif name == 'FunctionCallOptions': elif name == "FunctionCallOptions":
# call/gas info are handled in parse_call # call/gas info are handled in parse_call
called = parse_expression(expression['expression'], caller_context) called = parse_expression(expression["expression"], caller_context)
assert isinstance(called, (MemberAccess, NewContract)) assert isinstance(called, (MemberAccess, NewContract))
return called return called
elif name == 'TupleExpression': elif name == "TupleExpression":
""" """
For expression like For expression like
(a,,c) = (1,2,3) (a,,c) = (1,2,3)
@ -437,35 +474,39 @@ def parse_expression(expression, caller_context):
Note: this is only possible with Solidity >= 0.4.12 Note: this is only possible with Solidity >= 0.4.12
""" """
if is_compact_ast: if is_compact_ast:
expressions = [parse_expression(e, caller_context) if e else None for e in expression['components']] expressions = [
else: parse_expression(e, caller_context) if e else None for e in expression["components"]
if 'children' not in expression: ]
attributes = expression['attributes'] else:
components = attributes['components'] if "children" not in expression:
expressions = [parse_expression(c, caller_context) if c else None for c in components] attributes = expression["attributes"]
else: components = attributes["components"]
expressions = [parse_expression(e, caller_context) for e in expression['children']] expressions = [
parse_expression(c, caller_context) if c else None for c in components
]
else:
expressions = [parse_expression(e, caller_context) for e in expression["children"]]
# Add none for empty tuple items # Add none for empty tuple items
if "attributes" in expression: if "attributes" in expression:
if "type" in expression['attributes']: if "type" in expression["attributes"]:
t = expression['attributes']['type'] t = expression["attributes"]["type"]
if ',,' in t or '(,' in t or ',)' in t: if ",," in t or "(," in t or ",)" in t:
t = t[len('tuple('):-1] t = t[len("tuple(") : -1]
elems = t.split(',') elems = t.split(",")
for idx in range(len(elems)): for idx in range(len(elems)):
if elems[idx] == '': if elems[idx] == "":
expressions.insert(idx, None) expressions.insert(idx, None)
t = TupleExpression(expressions) t = TupleExpression(expressions)
t.set_offset(src, caller_context.slither) t.set_offset(src, caller_context.slither)
return t return t
elif name == 'Conditional': elif name == "Conditional":
if is_compact_ast: if is_compact_ast:
if_expression = parse_expression(expression['condition'], caller_context) if_expression = parse_expression(expression["condition"], caller_context)
then_expression = parse_expression(expression['trueExpression'], caller_context) then_expression = parse_expression(expression["trueExpression"], caller_context)
else_expression = parse_expression(expression['falseExpression'], caller_context) else_expression = parse_expression(expression["falseExpression"], caller_context)
else: else:
children = expression['children'] children = expression["children"]
assert len(children) == 3 assert len(children) == 3
if_expression = parse_expression(children[0], caller_context) if_expression = parse_expression(children[0], caller_context)
then_expression = parse_expression(children[1], caller_context) then_expression = parse_expression(children[1], caller_context)
@ -474,100 +515,103 @@ def parse_expression(expression, caller_context):
conditional.set_offset(src, caller_context.slither) conditional.set_offset(src, caller_context.slither)
return conditional return conditional
elif name == 'Assignment': elif name == "Assignment":
if is_compact_ast: if is_compact_ast:
left_expression = parse_expression(expression['leftHandSide'], caller_context) left_expression = parse_expression(expression["leftHandSide"], caller_context)
right_expression = parse_expression(expression['rightHandSide'], caller_context) right_expression = parse_expression(expression["rightHandSide"], caller_context)
operation_type = AssignmentOperationType.get_type(expression['operator']) operation_type = AssignmentOperationType.get_type(expression["operator"])
operation_return_type = expression['typeDescriptions']['typeString'] operation_return_type = expression["typeDescriptions"]["typeString"]
else: else:
attributes = expression['attributes'] attributes = expression["attributes"]
children = expression['children'] children = expression["children"]
assert len(expression['children']) == 2 assert len(expression["children"]) == 2
left_expression = parse_expression(children[0], caller_context) left_expression = parse_expression(children[0], caller_context)
right_expression = parse_expression(children[1], caller_context) right_expression = parse_expression(children[1], caller_context)
operation_type = AssignmentOperationType.get_type(attributes['operator']) operation_type = AssignmentOperationType.get_type(attributes["operator"])
operation_return_type = attributes['type'] operation_return_type = attributes["type"]
assignement = AssignmentOperation(left_expression, right_expression, operation_type, operation_return_type) assignement = AssignmentOperation(
left_expression, right_expression, operation_type, operation_return_type
)
assignement.set_offset(src, caller_context.slither) assignement.set_offset(src, caller_context.slither)
return assignement return assignement
elif name == "Literal":
elif name == 'Literal':
subdenomination = None subdenomination = None
assert 'children' not in expression assert "children" not in expression
if is_compact_ast: if is_compact_ast:
value = expression['value'] value = expression["value"]
if value: if value:
if 'subdenomination' in expression and expression['subdenomination']: if "subdenomination" in expression and expression["subdenomination"]:
subdenomination = expression['subdenomination'] subdenomination = expression["subdenomination"]
elif not value and value != "": elif not value and value != "":
value = '0x' + expression['hexValue'] value = "0x" + expression["hexValue"]
type = expression['typeDescriptions']['typeString'] type_candidate = expression["typeDescriptions"]["typeString"]
# Length declaration for array was None until solc 0.5.5 # Length declaration for array was None until solc 0.5.5
if type is None: if type_candidate is None:
if expression['kind'] == 'number': if expression["kind"] == "number":
type = 'int_const' type_candidate = "int_const"
else: else:
value = expression['attributes']['value'] value = expression["attributes"]["value"]
if value: if value:
if 'subdenomination' in expression['attributes'] and expression['attributes']['subdenomination']: if (
subdenomination = expression['attributes']['subdenomination'] "subdenomination" in expression["attributes"]
and expression["attributes"]["subdenomination"]
):
subdenomination = expression["attributes"]["subdenomination"]
elif value is None: elif value is None:
# for literal declared as hex # for literal declared as hex
# see https://solidity.readthedocs.io/en/v0.4.25/types.html?highlight=hex#hexadecimal-literals # see https://solidity.readthedocs.io/en/v0.4.25/types.html?highlight=hex#hexadecimal-literals
assert 'hexvalue' in expression['attributes'] assert "hexvalue" in expression["attributes"]
value = '0x' + expression['attributes']['hexvalue'] value = "0x" + expression["attributes"]["hexvalue"]
type = expression['attributes']['type'] type_candidate = expression["attributes"]["type"]
if type is None: if type_candidate is None:
if value.isdecimal(): if value.isdecimal():
type = ElementaryType('uint256') type_candidate = ElementaryType("uint256")
else: else:
type = ElementaryType('string') type_candidate = ElementaryType("string")
elif type.startswith('int_const '): elif type_candidate.startswith("int_const "):
type = ElementaryType('uint256') type_candidate = ElementaryType("uint256")
elif type.startswith('bool'): elif type_candidate.startswith("bool"):
type = ElementaryType('bool') type_candidate = ElementaryType("bool")
elif type.startswith('address'): elif type_candidate.startswith("address"):
type = ElementaryType('address') type_candidate = ElementaryType("address")
else: else:
type = ElementaryType('string') type_candidate = ElementaryType("string")
literal = Literal(value, type, subdenomination) literal = Literal(value, type_candidate, subdenomination)
literal.set_offset(src, caller_context.slither) literal.set_offset(src, caller_context.slither)
return literal return literal
elif name == 'Identifier': elif name == "Identifier":
assert 'children' not in expression assert "children" not in expression
t = None t = None
if caller_context.is_compact_ast: if caller_context.is_compact_ast:
value = expression['name'] value = expression["name"]
t = expression['typeDescriptions']['typeString'] t = expression["typeDescriptions"]["typeString"]
else: else:
value = expression['attributes']['value'] value = expression["attributes"]["value"]
if 'type' in expression['attributes']: if "type" in expression["attributes"]:
t = expression['attributes']['type'] t = expression["attributes"]["type"]
if t: if t:
found = re.findall('[struct|enum|function|modifier] \(([\[\] ()a-zA-Z0-9\.,_]*)\)', t) found = re.findall("[struct|enum|function|modifier] \(([\[\] ()a-zA-Z0-9\.,_]*)\)", t)
assert len(found) <= 1 assert len(found) <= 1
if found: if found:
value = value + '(' + found[0] + ')' value = value + "(" + found[0] + ")"
value = filter_name(value) value = filter_name(value)
if 'referencedDeclaration' in expression: if "referencedDeclaration" in expression:
referenced_declaration = expression['referencedDeclaration'] referenced_declaration = expression["referencedDeclaration"]
else: else:
referenced_declaration = None referenced_declaration = None
@ -577,14 +621,14 @@ def parse_expression(expression, caller_context):
identifier.set_offset(src, caller_context.slither) identifier.set_offset(src, caller_context.slither)
return identifier return identifier
elif name == 'IndexAccess': elif name == "IndexAccess":
if is_compact_ast: if is_compact_ast:
index_type = expression['typeDescriptions']['typeString'] index_type = expression["typeDescriptions"]["typeString"]
left = expression['baseExpression'] left = expression["baseExpression"]
right = expression['indexExpression'] right = expression["indexExpression"]
else: else:
index_type = expression['attributes']['type'] index_type = expression["attributes"]["type"]
children = expression['children'] children = expression["children"]
assert len(children) == 2 assert len(children) == 2
left = children[0] left = children[0]
right = children[1] right = children[1]
@ -600,22 +644,22 @@ def parse_expression(expression, caller_context):
index.set_offset(src, caller_context.slither) index.set_offset(src, caller_context.slither)
return index return index
elif name == 'MemberAccess': elif name == "MemberAccess":
if caller_context.is_compact_ast: if caller_context.is_compact_ast:
member_name = expression['memberName'] member_name = expression["memberName"]
member_type = expression['typeDescriptions']['typeString'] member_type = expression["typeDescriptions"]["typeString"]
member_expression = parse_expression(expression['expression'], caller_context) member_expression = parse_expression(expression["expression"], caller_context)
else: else:
member_name = expression['attributes']['member_name'] member_name = expression["attributes"]["member_name"]
member_type = expression['attributes']['type'] member_type = expression["attributes"]["type"]
children = expression['children'] children = expression["children"]
assert len(children) == 1 assert len(children) == 1
member_expression = parse_expression(children[0], caller_context) member_expression = parse_expression(children[0], caller_context)
if str(member_expression) == 'super': if str(member_expression) == "super":
super_name = parse_super_name(expression, is_compact_ast) super_name = parse_super_name(expression, is_compact_ast)
var = find_variable(super_name, caller_context, is_super=True) var = find_variable(super_name, caller_context, is_super=True)
if var is None: if var is None:
raise VariableNotFound('Variable not found: {}'.format(super_name)) raise VariableNotFound("Variable not found: {}".format(super_name))
sup = SuperIdentifier(var) sup = SuperIdentifier(var)
sup.set_offset(src, caller_context.slither) sup.set_offset(src, caller_context.slither)
return sup return sup
@ -627,81 +671,82 @@ def parse_expression(expression, caller_context):
return idx return idx
return member_access return member_access
elif name == 'ElementaryTypeNameExpression': elif name == "ElementaryTypeNameExpression":
return _parse_elementary_type_name_expression(expression, is_compact_ast, caller_context) return _parse_elementary_type_name_expression(expression, is_compact_ast, caller_context)
# NewExpression is not a root expression, it's always the child of another expression # NewExpression is not a root expression, it's always the child of another expression
elif name == 'NewExpression': elif name == "NewExpression":
if is_compact_ast: if is_compact_ast:
type_name = expression['typeName'] type_name = expression["typeName"]
else: else:
children = expression['children'] children = expression["children"]
assert len(children) == 1 assert len(children) == 1
type_name = children[0] type_name = children[0]
if type_name[caller_context.get_key()] == 'ArrayTypeName': if type_name[caller_context.get_key()] == "ArrayTypeName":
depth = 0 depth = 0
while type_name[caller_context.get_key()] == 'ArrayTypeName': while type_name[caller_context.get_key()] == "ArrayTypeName":
# Note: dont conserve the size of the array if provided # Note: dont conserve the size of the array if provided
# We compute it directly # We compute it directly
if is_compact_ast: if is_compact_ast:
type_name = type_name['baseType'] type_name = type_name["baseType"]
else: else:
type_name = type_name['children'][0] type_name = type_name["children"][0]
depth += 1 depth += 1
if type_name[caller_context.get_key()] == 'ElementaryTypeName': if type_name[caller_context.get_key()] == "ElementaryTypeName":
if is_compact_ast: if is_compact_ast:
array_type = ElementaryType(type_name['name']) array_type = ElementaryType(type_name["name"])
else: else:
array_type = ElementaryType(type_name['attributes']['name']) array_type = ElementaryType(type_name["attributes"]["name"])
elif type_name[caller_context.get_key()] == 'UserDefinedTypeName': elif type_name[caller_context.get_key()] == "UserDefinedTypeName":
if is_compact_ast: if is_compact_ast:
array_type = parse_type(UnknownType(type_name['name']), caller_context) array_type = parse_type(UnknownType(type_name["name"]), caller_context)
else: else:
array_type = parse_type(UnknownType(type_name['attributes']['name']), caller_context) array_type = parse_type(
elif type_name[caller_context.get_key()] == 'FunctionTypeName': UnknownType(type_name["attributes"]["name"]), caller_context
)
elif type_name[caller_context.get_key()] == "FunctionTypeName":
array_type = parse_type(type_name, caller_context) array_type = parse_type(type_name, caller_context)
else: else:
raise ParsingError('Incorrect type array {}'.format(type_name)) raise ParsingError("Incorrect type array {}".format(type_name))
array = NewArray(depth, array_type) array = NewArray(depth, array_type)
array.set_offset(src, caller_context.slither) array.set_offset(src, caller_context.slither)
return array return array
if type_name[caller_context.get_key()] == 'ElementaryTypeName': if type_name[caller_context.get_key()] == "ElementaryTypeName":
if is_compact_ast: if is_compact_ast:
elem_type = ElementaryType(type_name['name']) elem_type = ElementaryType(type_name["name"])
else: else:
elem_type = ElementaryType(type_name['attributes']['name']) elem_type = ElementaryType(type_name["attributes"]["name"])
new_elem = NewElementaryType(elem_type) new_elem = NewElementaryType(elem_type)
new_elem.set_offset(src, caller_context.slither) new_elem.set_offset(src, caller_context.slither)
return new_elem return new_elem
assert type_name[caller_context.get_key()] == 'UserDefinedTypeName' assert type_name[caller_context.get_key()] == "UserDefinedTypeName"
if is_compact_ast: if is_compact_ast:
contract_name = type_name['name'] contract_name = type_name["name"]
else: else:
contract_name = type_name['attributes']['name'] contract_name = type_name["attributes"]["name"]
new = NewContract(contract_name) new = NewContract(contract_name)
new.set_offset(src, caller_context.slither) new.set_offset(src, caller_context.slither)
return new return new
elif name == 'ModifierInvocation': elif name == "ModifierInvocation":
if is_compact_ast: if is_compact_ast:
called = parse_expression(expression['modifierName'], caller_context) called = parse_expression(expression["modifierName"], caller_context)
arguments = [] arguments = []
if expression['arguments']: if expression["arguments"]:
arguments = [parse_expression(a, caller_context) for a in expression['arguments']] arguments = [parse_expression(a, caller_context) for a in expression["arguments"]]
else: else:
children = expression['children'] children = expression["children"]
called = parse_expression(children[0], caller_context) called = parse_expression(children[0], caller_context)
arguments = [parse_expression(a, caller_context) for a in children[1::]] arguments = [parse_expression(a, caller_context) for a in children[1::]]
call = CallExpression(called, arguments, 'Modifier') call = CallExpression(called, arguments, "Modifier")
call.set_offset(src, caller_context.slither) call.set_offset(src, caller_context.slither)
return call return call
raise ParsingError('Expression not parsed %s' % name) raise ParsingError("Expression not parsed %s" % name)

@ -1,8 +1,8 @@
import os
import json import json
import re
import logging import logging
from typing import Optional, List import os
import re
from typing import List, Dict
from slither.core.declarations import Contract from slither.core.declarations import Contract
from slither.exceptions import SlitherException from slither.exceptions import SlitherException
@ -11,44 +11,63 @@ logging.basicConfig()
logger = logging.getLogger("SlitherSolcParsing") logger = logging.getLogger("SlitherSolcParsing")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
from slither.solc_parsing.declarations.contract import ContractSolc04 from slither.solc_parsing.declarations.contract import ContractSolc
from slither.core.slither_core import Slither from slither.solc_parsing.declarations.function import FunctionSolc
from slither.core.slither_core import SlitherCore
from slither.core.declarations.pragma_directive import Pragma from slither.core.declarations.pragma_directive import Pragma
from slither.core.declarations.import_directive import Import from slither.core.declarations.import_directive import Import
from slither.analyses.data_dependency.data_dependency import compute_dependency from slither.analyses.data_dependency.data_dependency import compute_dependency
class SlitherSolc(Slither): class SlitherSolc:
def __init__(self, filename: str, core: SlitherCore):
def __init__(self, filename):
super(SlitherSolc, self).__init__() super(SlitherSolc, self).__init__()
self._filename = filename core.filename = filename
self._contractsNotParsed = [] self._contracts_by_id: Dict[int, ContractSolc] = {}
self._contracts_by_id = {}
self._analyzed = False self._analyzed = False
self._underlying_contract_to_parser: Dict[Contract, ContractSolc] = dict()
self._is_compact_ast = False self._is_compact_ast = False
self._core: SlitherCore = core
self._all_functions_parser: List[FunctionSolc] = []
self._top_level_contracts_counter = 0 self._top_level_contracts_counter = 0
@property
def core(self):
return self._core
@property
def all_functions_parser(self) -> List[FunctionSolc]:
return self._all_functions_parser
def add_functions_parser(self, f: FunctionSolc):
self._all_functions_parser.append(f)
@property
def underlying_contract_to_parser(self) -> Dict[Contract, ContractSolc]:
return self._underlying_contract_to_parser
################################################################################### ###################################################################################
################################################################################### ###################################################################################
# region AST # region AST
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def get_key(self): def get_key(self) -> str:
if self._is_compact_ast: if self._is_compact_ast:
return 'nodeType' return "nodeType"
return 'name' return "name"
def get_children(self): def get_children(self) -> str:
if self._is_compact_ast: if self._is_compact_ast:
return 'nodes' return "nodes"
return 'children' return "children"
@property @property
def is_compact_ast(self): def is_compact_ast(self) -> bool:
return self._is_compact_ast return self._is_compact_ast
# endregion # endregion
@ -58,139 +77,147 @@ class SlitherSolc(Slither):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def _parse_contracts_from_json(self, json_data): def parse_contracts_from_json(self, json_data: str) -> bool:
try: try:
data_loaded = json.loads(json_data) data_loaded = json.loads(json_data)
# Truffle AST # Truffle AST
if 'ast' in data_loaded: if "ast" in data_loaded:
self._parse_contracts_from_loaded_json(data_loaded['ast'], data_loaded['sourcePath']) self.parse_contracts_from_loaded_json(data_loaded["ast"], data_loaded["sourcePath"])
return True return True
# solc AST, where the non-json text was removed # solc AST, where the non-json text was removed
else: else:
if 'attributes' in data_loaded: if "attributes" in data_loaded:
filename = data_loaded['attributes']['absolutePath'] filename = data_loaded["attributes"]["absolutePath"]
else: else:
filename = data_loaded['absolutePath'] filename = data_loaded["absolutePath"]
self._parse_contracts_from_loaded_json(data_loaded, filename) self.parse_contracts_from_loaded_json(data_loaded, filename)
return True return True
except ValueError: except ValueError:
first = json_data.find('{') first = json_data.find("{")
if first != -1: if first != -1:
last = json_data.rfind('}') + 1 last = json_data.rfind("}") + 1
filename = json_data[0:first] filename = json_data[0:first]
json_data = json_data[first:last] json_data = json_data[first:last]
data_loaded = json.loads(json_data) data_loaded = json.loads(json_data)
self._parse_contracts_from_loaded_json(data_loaded, filename) self.parse_contracts_from_loaded_json(data_loaded, filename)
return True return True
return False return False
def _parse_contracts_from_loaded_json(self, data_loaded, filename): def parse_contracts_from_loaded_json(self, data_loaded: Dict, filename: str):
if 'nodeType' in data_loaded: if "nodeType" in data_loaded:
self._is_compact_ast = True self._is_compact_ast = True
if 'sourcePaths' in data_loaded: if "sourcePaths" in data_loaded:
for sourcePath in data_loaded['sourcePaths']: for sourcePath in data_loaded["sourcePaths"]:
if os.path.isfile(sourcePath): if os.path.isfile(sourcePath):
self._add_source_code(sourcePath) self._core.add_source_code(sourcePath)
if data_loaded[self.get_key()] == 'root': if data_loaded[self.get_key()] == "root":
self._solc_version = '0.3' self._solc_version = "0.3"
logger.error('solc <0.4 is not supported') logger.error("solc <0.4 is not supported")
return return
elif data_loaded[self.get_key()] == 'SourceUnit': elif data_loaded[self.get_key()] == "SourceUnit":
self._solc_version = '0.4' self._solc_version = "0.4"
self._parse_source_unit(data_loaded, filename) self._parse_source_unit(data_loaded, filename)
else: else:
logger.error('solc version is not supported') logger.error("solc version is not supported")
return return
for contract_data in data_loaded[self.get_children()]: for contract_data in data_loaded[self.get_children()]:
assert contract_data[self.get_key()] in ['ContractDefinition', assert contract_data[self.get_key()] in [
'PragmaDirective', "ContractDefinition",
'ImportDirective', "PragmaDirective",
'StructDefinition', "ImportDirective",
'EnumDefinition'] "StructDefinition",
if contract_data[self.get_key()] == 'ContractDefinition': "EnumDefinition",
contract = ContractSolc04(self, contract_data) ]
if 'src' in contract_data: if contract_data[self.get_key()] == "ContractDefinition":
contract.set_offset(contract_data['src'], self) contract = Contract()
self._contractsNotParsed.append(contract) contract_parser = ContractSolc(self, contract, contract_data)
elif contract_data[self.get_key()] == 'PragmaDirective': if "src" in contract_data:
contract.set_offset(contract_data["src"], self._core)
self._underlying_contract_to_parser[contract] = contract_parser
elif contract_data[self.get_key()] == "PragmaDirective":
if self._is_compact_ast: if self._is_compact_ast:
pragma = Pragma(contract_data['literals']) pragma = Pragma(contract_data["literals"])
else: else:
pragma = Pragma(contract_data['attributes']["literals"]) pragma = Pragma(contract_data["attributes"]["literals"])
pragma.set_offset(contract_data['src'], self) pragma.set_offset(contract_data["src"], self._core)
self._pragma_directives.append(pragma) self._core.pragma_directives.append(pragma)
elif contract_data[self.get_key()] == 'ImportDirective': elif contract_data[self.get_key()] == "ImportDirective":
if self.is_compact_ast: if self.is_compact_ast:
import_directive = Import(contract_data["absolutePath"]) import_directive = Import(contract_data["absolutePath"])
else: else:
import_directive = Import(contract_data['attributes']["absolutePath"]) import_directive = Import(contract_data["attributes"]["absolutePath"])
import_directive.set_offset(contract_data['src'], self) import_directive.set_offset(contract_data["src"], self._core)
self._import_directives.append(import_directive) self._core.import_directives.append(import_directive)
elif contract_data[self.get_key()] in ['StructDefinition', 'EnumDefinition']: elif contract_data[self.get_key()] in ["StructDefinition", "EnumDefinition"]:
# This can only happen for top-level structure and enum # This can only happen for top-level structure and enum
# They were introduced with 0.6.5 # They were introduced with 0.6.5
assert self._is_compact_ast # Do not support top level definition for legacy AST assert self._is_compact_ast # Do not support top level definition for legacy AST
fake_contract_data = { fake_contract_data = {
'name': f'SlitherInternalTopLevelContract{self._top_level_contracts_counter}', "name": f"SlitherInternalTopLevelContract{self._top_level_contracts_counter}",
'id': -1000, # TODO: determine if collission possible "id": -1000, # TODO: determine if collission possible
'linearizedBaseContracts': [], "linearizedBaseContracts": [],
'fullyImplemented': True, "fullyImplemented": True,
'contractKind': 'SLitherInternal' "contractKind": "SLitherInternal",
} }
self._top_level_contracts_counter += 1 self._top_level_contracts_counter += 1
top_level_contract = ContractSolc04(self, fake_contract_data) top_level_contract = ContractSolc(self, fake_contract_data)
top_level_contract.is_top_level = True top_level_contract.is_top_level = True
top_level_contract.set_offset(contract_data['src'], self) top_level_contract.set_offset(contract_data["src"], self)
if contract_data[self.get_key()] == 'StructDefinition': if contract_data[self.get_key()] == "StructDefinition":
top_level_contract._structuresNotParsed.append(contract_data) # Todo add proper setters top_level_contract._structuresNotParsed.append(
contract_data
) # Todo add proper setters
else: else:
top_level_contract._enumsNotParsed.append(contract_data) # Todo add proper setters top_level_contract._enumsNotParsed.append(
contract_data
) # Todo add proper setters
self._contractsNotParsed.append(top_level_contract) self._contractsNotParsed.append(top_level_contract)
def _parse_source_unit(self, data: Dict, filename: str):
def _parse_source_unit(self, data, filename): if data[self.get_key()] != "SourceUnit":
if data[self.get_key()] != 'SourceUnit':
return -1 # handle solc prior 0.3.6 return -1 # handle solc prior 0.3.6
# match any char for filename # match any char for filename
# filename can contain space, /, -, .. # filename can contain space, /, -, ..
name = re.findall('=+ (.+) =+', filename) name_candidates = re.findall("=+ (.+) =+", filename)
if name: if name_candidates:
assert len(name) == 1 assert len(name_candidates) == 1
name = name[0] name: str = name_candidates[0]
else: else:
name = filename name = filename
sourceUnit = -1 # handle old solc, or error sourceUnit = -1 # handle old solc, or error
if 'src' in data: if "src" in data:
sourceUnit = re.findall('[0-9]*:[0-9]*:([0-9]*)', data['src']) sourceUnit_candidates = re.findall("[0-9]*:[0-9]*:([0-9]*)", data["src"])
if len(sourceUnit) == 1: if len(sourceUnit_candidates) == 1:
sourceUnit = int(sourceUnit[0]) sourceUnit = int(sourceUnit_candidates[0])
if sourceUnit == -1: if sourceUnit == -1:
# if source unit is not found # if source unit is not found
# We can still deduce it, by assigning to the last source_code added # We can still deduce it, by assigning to the last source_code added
# This works only for crytic compile. # This works only for crytic compile.
# which used --combined-json ast, rather than --ast-json # which used --combined-json ast, rather than --ast-json
# As a result -1 is not used as index # As a result -1 is not used as index
if self.crytic_compile is not None: if self._core.crytic_compile is not None:
sourceUnit = len(self.source_code) sourceUnit = len(self._core.source_code)
self._source_units[sourceUnit] = name self._core.source_units[sourceUnit] = name
if os.path.isfile(name) and not name in self.source_code: if os.path.isfile(name) and not name in self._core.source_code:
self._add_source_code(name) self._core.add_source_code(name)
else: else:
lib_name = os.path.join('node_modules', name) lib_name = os.path.join("node_modules", name)
if os.path.isfile(lib_name) and not name in self.source_code: if os.path.isfile(lib_name) and not name in self._core.source_code:
self._add_source_code(lib_name) self._core.add_source_code(lib_name)
# endregion # endregion
################################################################################### ###################################################################################
@ -200,32 +227,42 @@ class SlitherSolc(Slither):
################################################################################### ###################################################################################
@property @property
def analyzed(self): def analyzed(self) -> bool:
return self._analyzed return self._analyzed
def _analyze_contracts(self): def analyze_contracts(self):
if not self._contractsNotParsed: if not self._underlying_contract_to_parser:
logger.info(f'No contract were found in {self.filename}, check the correct compilation') logger.info(
f"No contract were found in {self._core.filename}, check the correct compilation"
)
if self._analyzed: if self._analyzed:
raise Exception('Contract analysis can be run only once!') raise Exception("Contract analysis can be run only once!")
# First we save all the contracts in a dict # First we save all the contracts in a dict
# the key is the contractid # the key is the contractid
for contract in self._contractsNotParsed: for contract in self._underlying_contract_to_parser.keys():
if contract.name.startswith('SlitherInternalTopLevelContract') and not contract.is_top_level: if (
raise SlitherException("""Your codebase has a contract named 'SlitherInternalTopLevelContract'. contract.name.startswith("SlitherInternalTopLevelContract")
Please rename it, this name is reserved for Slither's internals""") and not contract.is_top_level
if contract.name in self._contracts: ):
if contract.id != self._contracts[contract.name].id: raise SlitherException(
self._contract_name_collisions[contract.name].append(contract.source_mapping_str) """Your codebase has a contract named 'SlitherInternalTopLevelContract'.
self._contract_name_collisions[contract.name].append( Please rename it, this name is reserved for Slither's internals"""
self._contracts[contract.name].source_mapping_str) )
if contract.name in self._core.contracts_as_dict:
if contract.id != self._core.contracts_as_dict[contract.name].id:
self._core.contract_name_collisions[contract.name].append(
contract.source_mapping_str
)
self._core.contract_name_collisions[contract.name].append(
self._core.contracts_as_dict[contract.name].source_mapping_str
)
else: else:
self._contracts_by_id[contract.id] = contract self._contracts_by_id[contract.id] = contract
self._contracts[contract.name] = contract self._core.contracts_as_dict[contract.name] = contract
# Update of the inheritance # Update of the inheritance
for contract in self._contractsNotParsed: for contract_parser in self._underlying_contract_to_parser.values():
# remove the first elem in linearizedBaseContracts as it is the contract itself # remove the first elem in linearizedBaseContracts as it is the contract itself
ancestors = [] ancestors = []
fathers = [] fathers = []
@ -234,58 +271,70 @@ Please rename it, this name is reserved for Slither's internals""")
# Resolve linearized base contracts. # Resolve linearized base contracts.
missing_inheritance = False missing_inheritance = False
for i in contract.linearizedBaseContracts[1:]: for i in contract_parser.linearized_base_contracts[1:]:
if i in contract.remapping: if i in contract_parser.remapping:
ancestors.append(self.get_contract_from_name(contract.remapping[i])) ancestors.append(
self._core.get_contract_from_name(contract_parser.remapping[i])
)
elif i in self._contracts_by_id: elif i in self._contracts_by_id:
ancestors.append(self._contracts_by_id[i]) ancestors.append(self._contracts_by_id[i])
else: else:
missing_inheritance = True missing_inheritance = True
# Resolve immediate base contracts # Resolve immediate base contracts
for i in contract.baseContracts: for i in contract_parser.baseContracts:
if i in contract.remapping: if i in contract_parser.remapping:
fathers.append(self.get_contract_from_name(contract.remapping[i])) fathers.append(self._core.get_contract_from_name(contract_parser.remapping[i]))
elif i in self._contracts_by_id: elif i in self._contracts_by_id:
fathers.append(self._contracts_by_id[i]) fathers.append(self._contracts_by_id[i])
else: else:
missing_inheritance = True missing_inheritance = True
# Resolve immediate base constructor calls # Resolve immediate base constructor calls
for i in contract.baseConstructorContractsCalled: for i in contract_parser.baseConstructorContractsCalled:
if i in contract.remapping: if i in contract_parser.remapping:
father_constructors.append(self.get_contract_from_name(contract.remapping[i])) father_constructors.append(
self._core.get_contract_from_name(contract_parser.remapping[i])
)
elif i in self._contracts_by_id: elif i in self._contracts_by_id:
father_constructors.append(self._contracts_by_id[i]) father_constructors.append(self._contracts_by_id[i])
else: else:
missing_inheritance = True missing_inheritance = True
contract.setInheritance(ancestors, fathers, father_constructors) contract_parser.underlying_contract.set_inheritance(
ancestors, fathers, father_constructors
)
if missing_inheritance: if missing_inheritance:
self._contract_with_missing_inheritance.add(contract) self._core.contracts_with_missing_inheritance.add(
contract.log_incorrect_parsing(f'Missing inheritance {contract}') contract_parser.underlying_contract
contract.set_is_analyzed(True) )
contract.delete_content() contract_parser.log_incorrect_parsing(f"Missing inheritance {contract_parser}")
contract_parser.set_is_analyzed(True)
contract_parser.delete_content()
contracts_to_be_analyzed = self.contracts contracts_to_be_analyzed = list(self._underlying_contract_to_parser.values())
# Any contract can refer another contract enum without need for inheritance # Any contract can refer another contract enum without need for inheritance
self._analyze_all_enums(contracts_to_be_analyzed) self._analyze_all_enums(contracts_to_be_analyzed)
[c.set_is_analyzed(False) for c in self.contracts] [c.set_is_analyzed(False) for c in self._underlying_contract_to_parser.values()]
libraries = [c for c in contracts_to_be_analyzed if c.contract_kind == 'library'] libraries = [
contracts_to_be_analyzed = [c for c in contracts_to_be_analyzed if c.contract_kind != 'library'] c for c in contracts_to_be_analyzed if c.underlying_contract.contract_kind == "library"
]
contracts_to_be_analyzed = [
c for c in contracts_to_be_analyzed if c.underlying_contract.contract_kind != "library"
]
# We first parse the struct/variables/functions/contract # We first parse the struct/variables/functions/contract
self._analyze_first_part(contracts_to_be_analyzed, libraries) self._analyze_first_part(contracts_to_be_analyzed, libraries)
[c.set_is_analyzed(False) for c in self.contracts] [c.set_is_analyzed(False) for c in self._underlying_contract_to_parser.values()]
# We analyze the struct and parse and analyze the events # We analyze the struct and parse and analyze the events
# A contract can refer in the variables a struct or a event from any contract # A contract can refer in the variables a struct or a event from any contract
# (without inheritance link) # (without inheritance link)
self._analyze_second_part(contracts_to_be_analyzed, libraries) self._analyze_second_part(contracts_to_be_analyzed, libraries)
[c.set_is_analyzed(False) for c in self.contracts] [c.set_is_analyzed(False) for c in self._underlying_contract_to_parser.values()]
# Then we analyse state variables, functions and modifiers # Then we analyse state variables, functions and modifiers
self._analyze_third_part(contracts_to_be_analyzed, libraries) self._analyze_third_part(contracts_to_be_analyzed, libraries)
@ -294,22 +343,27 @@ Please rename it, this name is reserved for Slither's internals""")
self._convert_to_slithir() self._convert_to_slithir()
compute_dependency(self) compute_dependency(self._core)
def _analyze_all_enums(self, contracts_to_be_analyzed): def _analyze_all_enums(self, contracts_to_be_analyzed: List[ContractSolc]):
while contracts_to_be_analyzed: while contracts_to_be_analyzed:
contract = contracts_to_be_analyzed[0] contract = contracts_to_be_analyzed[0]
contracts_to_be_analyzed = contracts_to_be_analyzed[1:] contracts_to_be_analyzed = contracts_to_be_analyzed[1:]
all_father_analyzed = all(father.is_analyzed for father in contract.inheritance) all_father_analyzed = all(
self._underlying_contract_to_parser[father].is_analyzed
for father in contract.underlying_contract.inheritance
)
if not contract.inheritance or all_father_analyzed: if not contract.underlying_contract.inheritance or all_father_analyzed:
self._analyze_enums(contract) self._analyze_enums(contract)
else: else:
contracts_to_be_analyzed += [contract] contracts_to_be_analyzed += [contract]
return return
def _analyze_first_part(self, contracts_to_be_analyzed, libraries): def _analyze_first_part(
self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc]
):
for lib in libraries: for lib in libraries:
self._parse_struct_var_modifiers_functions(lib) self._parse_struct_var_modifiers_functions(lib)
@ -321,16 +375,21 @@ Please rename it, this name is reserved for Slither's internals""")
contract = contracts_to_be_analyzed[0] contract = contracts_to_be_analyzed[0]
contracts_to_be_analyzed = contracts_to_be_analyzed[1:] contracts_to_be_analyzed = contracts_to_be_analyzed[1:]
all_father_analyzed = all(father.is_analyzed for father in contract.inheritance) all_father_analyzed = all(
self._underlying_contract_to_parser[father].is_analyzed
for father in contract.underlying_contract.inheritance
)
if not contract.inheritance or all_father_analyzed: if not contract.underlying_contract.inheritance or all_father_analyzed:
self._parse_struct_var_modifiers_functions(contract) self._parse_struct_var_modifiers_functions(contract)
else: else:
contracts_to_be_analyzed += [contract] contracts_to_be_analyzed += [contract]
return return
def _analyze_second_part(self, contracts_to_be_analyzed, libraries): def _analyze_second_part(
self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc]
):
for lib in libraries: for lib in libraries:
self._analyze_struct_events(lib) self._analyze_struct_events(lib)
@ -342,16 +401,21 @@ Please rename it, this name is reserved for Slither's internals""")
contract = contracts_to_be_analyzed[0] contract = contracts_to_be_analyzed[0]
contracts_to_be_analyzed = contracts_to_be_analyzed[1:] contracts_to_be_analyzed = contracts_to_be_analyzed[1:]
all_father_analyzed = all(father.is_analyzed for father in contract.inheritance) all_father_analyzed = all(
self._underlying_contract_to_parser[father].is_analyzed
for father in contract.underlying_contract.inheritance
)
if not contract.inheritance or all_father_analyzed: if not contract.underlying_contract.inheritance or all_father_analyzed:
self._analyze_struct_events(contract) self._analyze_struct_events(contract)
else: else:
contracts_to_be_analyzed += [contract] contracts_to_be_analyzed += [contract]
return return
def _analyze_third_part(self, contracts_to_be_analyzed, libraries): def _analyze_third_part(
self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc]
):
for lib in libraries: for lib in libraries:
self._analyze_variables_modifiers_functions(lib) self._analyze_variables_modifiers_functions(lib)
@ -363,28 +427,31 @@ Please rename it, this name is reserved for Slither's internals""")
contract = contracts_to_be_analyzed[0] contract = contracts_to_be_analyzed[0]
contracts_to_be_analyzed = contracts_to_be_analyzed[1:] contracts_to_be_analyzed = contracts_to_be_analyzed[1:]
all_father_analyzed = all(father.is_analyzed for father in contract.inheritance) all_father_analyzed = all(
self._underlying_contract_to_parser[father].is_analyzed
for father in contract.underlying_contract.inheritance
)
if not contract.inheritance or all_father_analyzed: if not contract.underlying_contract.inheritance or all_father_analyzed:
self._analyze_variables_modifiers_functions(contract) self._analyze_variables_modifiers_functions(contract)
else: else:
contracts_to_be_analyzed += [contract] contracts_to_be_analyzed += [contract]
return return
def _analyze_enums(self, contract): def _analyze_enums(self, contract: ContractSolc):
# Enum must be analyzed first # Enum must be analyzed first
contract.analyze_enums() contract.analyze_enums()
contract.set_is_analyzed(True) contract.set_is_analyzed(True)
def _parse_struct_var_modifiers_functions(self, contract): def _parse_struct_var_modifiers_functions(self, contract: ContractSolc):
contract.parse_structs() # struct can refer another struct contract.parse_structs() # struct can refer another struct
contract.parse_state_variables() contract.parse_state_variables()
contract.parse_modifiers() contract.parse_modifiers()
contract.parse_functions() contract.parse_functions()
contract.set_is_analyzed(True) contract.set_is_analyzed(True)
def _analyze_struct_events(self, contract): def _analyze_struct_events(self, contract: ContractSolc):
contract.analyze_constant_state_variables() contract.analyze_constant_state_variables()
@ -397,7 +464,7 @@ Please rename it, this name is reserved for Slither's internals""")
contract.set_is_analyzed(True) contract.set_is_analyzed(True)
def _analyze_variables_modifiers_functions(self, contract): def _analyze_variables_modifiers_functions(self, contract: ContractSolc):
# State variables, modifiers and functions can refer to anything # State variables, modifiers and functions can refer to anything
contract.analyze_params_modifiers() contract.analyze_params_modifiers()
@ -412,11 +479,23 @@ Please rename it, this name is reserved for Slither's internals""")
def _convert_to_slithir(self): def _convert_to_slithir(self):
for contract in self.contracts: for contract in self._core.contracts:
contract.add_constructor_variables() contract.add_constructor_variables()
contract.convert_expression_to_slithir()
self._propagate_function_calls() for func in contract.functions + contract.modifiers:
for contract in self.contracts: try:
func.generate_slithir_and_analyze()
except AttributeError:
# This can happens for example if there is a call to an interface
# And the interface is redefined due to contract's name reuse
# But the available version misses some functions
self._underlying_contract_to_parser[contract].log_incorrect_parsing(
f"Impossible to generate IR for {contract.name}.{func.name}"
)
contract.convert_expression_to_slithir_ssa()
self._core.propagate_function_calls()
for contract in self._core.contracts:
contract.fix_phi() contract.fix_phi()
contract.update_read_write_using_ssa() contract.update_read_write_using_ssa()

@ -1,6 +1,8 @@
import logging import logging
from typing import List, TYPE_CHECKING, Union, Dict
from slither.core.solidity_types.elementary_type import ElementaryType, ElementaryTypeName from slither.core.solidity_types.elementary_type import ElementaryType, ElementaryTypeName
from slither.core.solidity_types.type import Type
from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.user_defined_type import UserDefinedType
from slither.core.solidity_types.array_type import ArrayType from slither.core.solidity_types.array_type import ArrayType
from slither.core.solidity_types.mapping_type import MappingType from slither.core.solidity_types.mapping_type import MappingType
@ -9,14 +11,17 @@ from slither.core.solidity_types.function_type import FunctionType
from slither.core.variables.function_type_variable import FunctionTypeVariable from slither.core.variables.function_type_variable import FunctionTypeVariable
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.function import Function
from slither.core.expressions.literal import Literal from slither.core.expressions.literal import Literal
from slither.solc_parsing.exceptions import ParsingError from slither.solc_parsing.exceptions import ParsingError
import re import re
logger = logging.getLogger('TypeParsing') if TYPE_CHECKING:
from slither.core.declarations import Structure, Enum
logger = logging.getLogger("TypeParsing")
class UnknownType: class UnknownType:
def __init__(self, name): def __init__(self, name):
@ -26,24 +31,31 @@ class UnknownType:
def name(self): def name(self):
return self._name return self._name
def _find_from_type_name(name, contract, contracts, structures, enums):
name_elementary = name.split(' ')[0] def _find_from_type_name(
if '[' in name_elementary: name: str,
name_elementary = name_elementary[0:name_elementary.find('[')] contract: Contract,
contracts: List[Contract],
structures: List["Structure"],
enums: List["Enum"],
) -> Type:
name_elementary = name.split(" ")[0]
if "[" in name_elementary:
name_elementary = name_elementary[0 : name_elementary.find("[")]
if name_elementary in ElementaryTypeName: if name_elementary in ElementaryTypeName:
depth = name.count('[') depth = name.count("[")
if depth: if depth:
return ArrayType(ElementaryType(name_elementary), Literal(depth, 'uint256')) return ArrayType(ElementaryType(name_elementary), Literal(depth, "uint256"))
else: else:
return ElementaryType(name_elementary) return ElementaryType(name_elementary)
# We first look for contract # We first look for contract
# To avoid collision # To avoid collision
# Ex: a structure with the name of a contract # Ex: a structure with the name of a contract
name_contract = name name_contract = name
if name_contract.startswith('contract '): if name_contract.startswith("contract "):
name_contract = name_contract[len('contract '):] name_contract = name_contract[len("contract ") :]
if name_contract.startswith('library '): if name_contract.startswith("library "):
name_contract = name_contract[len('library '):] name_contract = name_contract[len("library ") :]
var_type = next((c for c in contracts if c.name == name_contract), None) var_type = next((c for c in contracts if c.name == name_contract), None)
if not var_type: if not var_type:
@ -53,8 +65,8 @@ def _find_from_type_name(name, contract, contracts, structures, enums):
if not var_type: if not var_type:
# any contract can refer to another contract's enum # any contract can refer to another contract's enum
enum_name = name enum_name = name
if enum_name.startswith('enum '): if enum_name.startswith("enum "):
enum_name = enum_name[len('enum '):] enum_name = enum_name[len("enum ") :]
all_enums = [c.enums for c in contracts] all_enums = [c.enums for c in contracts]
all_enums = [item for sublist in all_enums for item in sublist] all_enums = [item for sublist in all_enums for item in sublist]
var_type = next((e for e in all_enums if e.name == enum_name), None) var_type = next((e for e in all_enums if e.name == enum_name), None)
@ -63,9 +75,9 @@ def _find_from_type_name(name, contract, contracts, structures, enums):
if not var_type: if not var_type:
# any contract can refer to another contract's structure # any contract can refer to another contract's structure
name_struct = name name_struct = name
if name_struct.startswith('struct '): if name_struct.startswith("struct "):
name_struct = name_struct[len('struct '):] name_struct = name_struct[len("struct ") :]
name_struct = name_struct.split(' ')[0] # remove stuff like storage pointer at the end name_struct = name_struct.split(" ")[0] # remove stuff like storage pointer at the end
all_structures = [c.structures for c in contracts] all_structures = [c.structures for c in contracts]
all_structures = [item for sublist in all_structures for item in sublist] all_structures = [item for sublist in all_structures for item in sublist]
var_type = next((st for st in all_structures if st.name == name_struct), None) var_type = next((st for st in all_structures if st.name == name_struct), None)
@ -74,23 +86,30 @@ def _find_from_type_name(name, contract, contracts, structures, enums):
# case where struct xxx.xx[] where not well formed in the AST # case where struct xxx.xx[] where not well formed in the AST
if not var_type: if not var_type:
depth = 0 depth = 0
while name_struct.endswith('[]'): while name_struct.endswith("[]"):
name_struct = name_struct[0:-2] name_struct = name_struct[0:-2]
depth+=1 depth += 1
var_type = next((st for st in all_structures if st.canonical_name == name_struct), None) var_type = next((st for st in all_structures if st.canonical_name == name_struct), None)
if var_type: if var_type:
return ArrayType(UserDefinedType(var_type), Literal(depth, 'uint256')) return ArrayType(UserDefinedType(var_type), Literal(depth, "uint256"))
if not var_type: if not var_type:
var_type = next((f for f in contract.functions if f.name == name), None) var_type = next((f for f in contract.functions if f.name == name), None)
if not var_type: if not var_type:
if name.startswith('function '): if name.startswith("function "):
found = re.findall('function \(([ ()a-zA-Z0-9\.,]*)\) returns \(([a-zA-Z0-9\.,]*)\)', name) found = re.findall(
"function \(([ ()a-zA-Z0-9\.,]*)\) returns \(([a-zA-Z0-9\.,]*)\)", name
)
assert len(found) == 1 assert len(found) == 1
params = found[0][0].split(',') params = found[0][0].split(",")
return_values = found[0][1].split(',') return_values = found[0][1].split(",")
params = [_find_from_type_name(p, contract, contracts, structures, enums) for p in params] params = [
return_values = [_find_from_type_name(r, contract, contracts, structures, enums) for r in return_values] _find_from_type_name(p, contract, contracts, structures, enums) for p in params
]
return_values = [
_find_from_type_name(r, contract, contracts, structures, enums)
for r in return_values
]
params_vars = [] params_vars = []
return_vars = [] return_vars = []
for p in params: for p in params:
@ -103,12 +122,14 @@ def _find_from_type_name(name, contract, contracts, structures, enums):
return_vars.append(var) return_vars.append(var)
return FunctionType(params_vars, return_vars) return FunctionType(params_vars, return_vars)
if not var_type: if not var_type:
if name.startswith('mapping('): if name.startswith("mapping("):
# nested mapping declared with var # nested mapping declared with var
if name.count('mapping(') == 1 : if name.count("mapping(") == 1:
found = re.findall('mapping\(([a-zA-Z0-9\.]*) => ([a-zA-Z0-9\.\[\]]*)\)', name) found = re.findall("mapping\(([a-zA-Z0-9\.]*) => ([a-zA-Z0-9\.\[\]]*)\)", name)
else: else:
found = re.findall('mapping\(([a-zA-Z0-9\.]*) => (mapping\([=> a-zA-Z0-9\.\[\]]*\))\)', name) found = re.findall(
"mapping\(([a-zA-Z0-9\.]*) => (mapping\([=> a-zA-Z0-9\.\[\]]*\))\)", name
)
assert len(found) == 1 assert len(found) == 1
from_ = found[0][0] from_ = found[0][0]
to_ = found[0][1] to_ = found[0][1]
@ -119,30 +140,32 @@ def _find_from_type_name(name, contract, contracts, structures, enums):
return MappingType(from_type, to_type) return MappingType(from_type, to_type)
if not var_type: if not var_type:
raise ParsingError('Type not found '+str(name)) raise ParsingError("Type not found " + str(name))
return UserDefinedType(var_type) return UserDefinedType(var_type)
def parse_type(t: Union[Dict, UnknownType], caller_context):
def parse_type(t, caller_context):
# local import to avoid circular dependency # local import to avoid circular dependency
from slither.solc_parsing.expressions.expression_parsing import parse_expression from slither.solc_parsing.expressions.expression_parsing import parse_expression
from slither.solc_parsing.variables.function_type_variable import FunctionTypeVariableSolc from slither.solc_parsing.variables.function_type_variable import FunctionTypeVariableSolc
from slither.solc_parsing.declarations.contract import ContractSolc
from slither.solc_parsing.declarations.function import FunctionSolc
if isinstance(caller_context, Contract): if isinstance(caller_context, ContractSolc):
contract = caller_context contract = caller_context.underlying_contract
elif isinstance(caller_context, Function): contract_parser = caller_context
contract = caller_context.contract is_compact_ast = caller_context.is_compact_ast
else: elif isinstance(caller_context, FunctionSolc):
raise ParsingError('Incorrect caller context') contract = caller_context.underlying_function.contract
contract_parser = caller_context.contract_parser
is_compact_ast = caller_context.is_compact_ast is_compact_ast = caller_context.is_compact_ast
else:
raise ParsingError(f"Incorrect caller context: {type(caller_context)}")
if is_compact_ast: if is_compact_ast:
key = 'nodeType' key = "nodeType"
else: else:
key = 'name' key = "name"
structures = contract.structures structures = contract.structures
enums = contract.enums enums = contract.enums
@ -151,75 +174,84 @@ def parse_type(t, caller_context):
if isinstance(t, UnknownType): if isinstance(t, UnknownType):
return _find_from_type_name(t.name, contract, contracts, structures, enums) return _find_from_type_name(t.name, contract, contracts, structures, enums)
elif t[key] == 'ElementaryTypeName': elif t[key] == "ElementaryTypeName":
if is_compact_ast: if is_compact_ast:
return ElementaryType(t['name']) return ElementaryType(t["name"])
return ElementaryType(t['attributes'][key]) return ElementaryType(t["attributes"][key])
elif t[key] == 'UserDefinedTypeName': elif t[key] == "UserDefinedTypeName":
if is_compact_ast: if is_compact_ast:
return _find_from_type_name(t['typeDescriptions']['typeString'], contract, contracts, structures, enums) return _find_from_type_name(
t["typeDescriptions"]["typeString"], contract, contracts, structures, enums
)
# Determine if we have a type node (otherwise we use the name node, as some older solc did not have 'type'). # Determine if we have a type node (otherwise we use the name node, as some older solc did not have 'type').
type_name_key = 'type' if 'type' in t['attributes'] else key type_name_key = "type" if "type" in t["attributes"] else key
return _find_from_type_name(t['attributes'][type_name_key], contract, contracts, structures, enums) return _find_from_type_name(
t["attributes"][type_name_key], contract, contracts, structures, enums
)
elif t[key] == 'ArrayTypeName': elif t[key] == "ArrayTypeName":
length = None length = None
if is_compact_ast: if is_compact_ast:
if t['length']: if t["length"]:
length = parse_expression(t['length'], caller_context) length = parse_expression(t["length"], caller_context)
array_type = parse_type(t['baseType'], contract) array_type = parse_type(t["baseType"], contract_parser)
else: else:
if len(t['children']) == 2: if len(t["children"]) == 2:
length = parse_expression(t['children'][1], caller_context) length = parse_expression(t["children"][1], caller_context)
else: else:
assert len(t['children']) == 1 assert len(t["children"]) == 1
array_type = parse_type(t['children'][0], contract) array_type = parse_type(t["children"][0], contract_parser)
return ArrayType(array_type, length) return ArrayType(array_type, length)
elif t[key] == 'Mapping': elif t[key] == "Mapping":
if is_compact_ast: if is_compact_ast:
mappingFrom = parse_type(t['keyType'], contract) mappingFrom = parse_type(t["keyType"], contract_parser)
mappingTo = parse_type(t['valueType'], contract) mappingTo = parse_type(t["valueType"], contract_parser)
else: else:
assert len(t['children']) == 2 assert len(t["children"]) == 2
mappingFrom = parse_type(t['children'][0], contract) mappingFrom = parse_type(t["children"][0], contract_parser)
mappingTo = parse_type(t['children'][1], contract) mappingTo = parse_type(t["children"][1], contract_parser)
return MappingType(mappingFrom, mappingTo) return MappingType(mappingFrom, mappingTo)
elif t[key] == 'FunctionTypeName': elif t[key] == "FunctionTypeName":
if is_compact_ast: if is_compact_ast:
params = t['parameterTypes'] params = t["parameterTypes"]
return_values = t['returnParameterTypes'] return_values = t["returnParameterTypes"]
index = 'parameters' index = "parameters"
else: else:
assert len(t['children']) == 2 assert len(t["children"]) == 2
params = t['children'][0] params = t["children"][0]
return_values = t['children'][1] return_values = t["children"][1]
index = 'children' index = "children"
assert params[key] == 'ParameterList' assert params[key] == "ParameterList"
assert return_values[key] == 'ParameterList' assert return_values[key] == "ParameterList"
params_vars = [] params_vars: List[FunctionTypeVariable] = []
return_values_vars = [] return_values_vars: List[FunctionTypeVariable] = []
for p in params[index]: for p in params[index]:
var = FunctionTypeVariableSolc(p) var = FunctionTypeVariable()
var.set_offset(p['src'], caller_context.slither) var.set_offset(p["src"], caller_context.slither)
var.analyze(caller_context)
var_parser = FunctionTypeVariableSolc(var, p)
var_parser.analyze(caller_context)
params_vars.append(var) params_vars.append(var)
for p in return_values[index]: for p in return_values[index]:
var = FunctionTypeVariableSolc(p) var = FunctionTypeVariable()
var.set_offset(p["src"], caller_context.slither)
var_parser = FunctionTypeVariableSolc(var, p)
var_parser.analyze(caller_context)
var.set_offset(p['src'], caller_context.slither)
var.analyze(caller_context)
return_values_vars.append(var) return_values_vars.append(var)
return FunctionType(params_vars, return_values_vars) return FunctionType(params_vars, return_values_vars)
raise ParsingError('Type name not found '+str(t)) raise ParsingError("Type name not found " + str(t))

@ -1,10 +1,20 @@
from typing import Dict
from .variable_declaration import VariableDeclarationSolc from .variable_declaration import VariableDeclarationSolc
from slither.core.variables.event_variable import EventVariable from slither.core.variables.event_variable import EventVariable
class EventVariableSolc(VariableDeclarationSolc, EventVariable):
def _analyze_variable_attributes(self, attributes): class EventVariableSolc(VariableDeclarationSolc):
def __init__(self, variable: EventVariable, variable_data: Dict):
super(EventVariableSolc, self).__init__(variable, variable_data)
@property
def underlying_variable(self) -> EventVariable:
# Todo: Not sure how to overcome this with mypy
assert isinstance(self._variable, EventVariable)
return self._variable
def _analyze_variable_attributes(self, attributes: Dict):
""" """
Analyze event variable attributes Analyze event variable attributes
:param attributes: The event variable attributes to parse. :param attributes: The event variable attributes to parse.
@ -12,8 +22,7 @@ class EventVariableSolc(VariableDeclarationSolc, EventVariable):
""" """
# Check for the indexed attribute # Check for the indexed attribute
if 'indexed' in attributes: if "indexed" in attributes:
self._indexed = attributes['indexed'] self.underlying_variable.indexed = attributes["indexed"]
super(EventVariableSolc, self)._analyze_variable_attributes(attributes) super(EventVariableSolc, self)._analyze_variable_attributes(attributes)

@ -1,5 +1,15 @@
from typing import Dict
from slither.solc_parsing.variables.variable_declaration import VariableDeclarationSolc from slither.solc_parsing.variables.variable_declaration import VariableDeclarationSolc
from slither.core.variables.function_type_variable import FunctionTypeVariable from slither.core.variables.function_type_variable import FunctionTypeVariable
class FunctionTypeVariableSolc(VariableDeclarationSolc, FunctionTypeVariable): pass
class FunctionTypeVariableSolc(VariableDeclarationSolc):
def __init__(self, variable: FunctionTypeVariable, variable_data: Dict):
super(FunctionTypeVariableSolc, self).__init__(variable, variable_data)
@property
def underlying_variable(self) -> FunctionTypeVariable:
# Todo: Not sure how to overcome this with mypy
assert isinstance(self._variable, FunctionTypeVariable)
return self._variable

@ -1,24 +1,33 @@
from typing import Dict
from .variable_declaration import VariableDeclarationSolc from .variable_declaration import VariableDeclarationSolc
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
class LocalVariableSolc(VariableDeclarationSolc, LocalVariable):
def _analyze_variable_attributes(self, attributes): class LocalVariableSolc(VariableDeclarationSolc):
'''' def __init__(self, variable: LocalVariable, variable_data: Dict):
super(LocalVariableSolc, self).__init__(variable, variable_data)
@property
def underlying_variable(self) -> LocalVariable:
# Todo: Not sure how to overcome this with mypy
assert isinstance(self._variable, LocalVariable)
return self._variable
def _analyze_variable_attributes(self, attributes: Dict):
"""'
Variable Location Variable Location
Can be storage/memory or default Can be storage/memory or default
''' """
if 'storageLocation' in attributes: if "storageLocation" in attributes:
location = attributes['storageLocation'] location = attributes["storageLocation"]
self._location = location self.underlying_variable.set_location(location)
else: else:
if 'memory' in attributes['type']: if "memory" in attributes["type"]:
self._location = 'memory' self.underlying_variable.set_location("memory")
elif'storage' in attributes['type']: elif "storage" in attributes["type"]:
self._location = 'storage' self.underlying_variable.set_location("storage")
else: else:
self._location = 'default' self.underlying_variable.set_location("default")
super(LocalVariableSolc, self)._analyze_variable_attributes(attributes) super(LocalVariableSolc, self)._analyze_variable_attributes(attributes)

@ -1,11 +1,16 @@
from typing import Dict
from .variable_declaration import VariableDeclarationSolc from .variable_declaration import VariableDeclarationSolc
from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple
class LocalVariableInitFromTupleSolc(VariableDeclarationSolc, LocalVariableInitFromTuple):
def __init__(self, var, index):
super(LocalVariableInitFromTupleSolc, self).__init__(var)
self._tuple_index = index
class LocalVariableInitFromTupleSolc(VariableDeclarationSolc):
def __init__(self, variable: LocalVariableInitFromTuple, variable_data: Dict, index: int):
super(LocalVariableInitFromTupleSolc, self).__init__(variable, variable_data)
variable.tuple_index = index
@property
def underlying_variable(self) -> LocalVariableInitFromTuple:
# Todo: Not sure how to overcome this with mypy
assert isinstance(self._variable, LocalVariableInitFromTuple)
return self._variable

@ -1,5 +1,15 @@
from typing import Dict
from .variable_declaration import VariableDeclarationSolc from .variable_declaration import VariableDeclarationSolc
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
class StateVariableSolc(VariableDeclarationSolc, StateVariable): pass
class StateVariableSolc(VariableDeclarationSolc):
def __init__(self, variable: StateVariable, variable_data: Dict):
super(StateVariableSolc, self).__init__(variable, variable_data)
@property
def underlying_variable(self) -> StateVariable:
# Todo: Not sure how to overcome this with mypy
assert isinstance(self._variable, StateVariable)
return self._variable

@ -1,5 +1,15 @@
from typing import Dict
from .variable_declaration import VariableDeclarationSolc from .variable_declaration import VariableDeclarationSolc
from slither.core.variables.structure_variable import StructureVariable from slither.core.variables.structure_variable import StructureVariable
class StructureVariableSolc(VariableDeclarationSolc, StructureVariable): pass
class StructureVariableSolc(VariableDeclarationSolc):
def __init__(self, variable: StructureVariable, variable_data: Dict):
super(StructureVariableSolc, self).__init__(variable, variable_data)
@property
def underlying_variable(self) -> StructureVariable:
# Todo: Not sure how to overcome this with mypy
assert isinstance(self._variable, StructureVariable)
return self._variable

@ -1,4 +1,6 @@
import logging import logging
from typing import Dict
from slither.solc_parsing.expressions.expression_parsing import parse_expression from slither.solc_parsing.expressions.expression_parsing import parse_expression
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
@ -7,28 +9,31 @@ from slither.solc_parsing.solidity_types.type_parsing import parse_type, Unknown
from slither.core.solidity_types.elementary_type import ElementaryType, NonElementaryType from slither.core.solidity_types.elementary_type import ElementaryType, NonElementaryType
from slither.solc_parsing.exceptions import ParsingError from slither.solc_parsing.exceptions import ParsingError
logger = logging.getLogger("VariableDeclarationSolcParsing") logger = logging.getLogger("VariableDeclarationSolcParsing")
class MultipleVariablesDeclaration(Exception): class MultipleVariablesDeclaration(Exception):
''' """
This is raised on This is raised on
var (a,b) = ... var (a,b) = ...
It should occur only on local variable definition It should occur only on local variable definition
''' """
pass pass
class VariableDeclarationSolc(Variable):
def __init__(self, var): class VariableDeclarationSolc:
''' def __init__(self, variable: Variable, variable_data: Dict):
"""
A variable can be declared through a statement, or directly. A variable can be declared through a statement, or directly.
If it is through a statement, the following children may contain If it is through a statement, the following children may contain
the init value. the init value.
It may be possible that the variable is declared through a statement, It may be possible that the variable is declared through a statement,
but the init value is declared at the VariableDeclaration children level but the init value is declared at the VariableDeclaration children level
''' """
super(VariableDeclarationSolc, self).__init__() self._variable = variable
self._was_analyzed = False self._was_analyzed = False
self._elem_to_parse = None self._elem_to_parse = None
self._initializedNotParsed = None self._initializedNotParsed = None
@ -37,125 +42,122 @@ class VariableDeclarationSolc(Variable):
self._reference_id = None self._reference_id = None
if "nodeType" in variable_data:
if 'nodeType' in var:
self._is_compact_ast = True self._is_compact_ast = True
nodeType = var['nodeType'] nodeType = variable_data["nodeType"]
if nodeType in ['VariableDeclarationStatement', 'VariableDefinitionStatement']: if nodeType in ["VariableDeclarationStatement", "VariableDefinitionStatement"]:
if len(var['declarations'])>1: if len(variable_data["declarations"]) > 1:
raise MultipleVariablesDeclaration raise MultipleVariablesDeclaration
init = None init = None
if 'initialValue' in var: if "initialValue" in variable_data:
init = var['initialValue'] init = variable_data["initialValue"]
self._init_from_declaration(var['declarations'][0], init) self._init_from_declaration(variable_data["declarations"][0], init)
elif nodeType == 'VariableDeclaration': elif nodeType == "VariableDeclaration":
self._init_from_declaration(var, var['value']) self._init_from_declaration(variable_data, variable_data["value"])
else: else:
raise ParsingError('Incorrect variable declaration type {}'.format(nodeType)) raise ParsingError("Incorrect variable declaration type {}".format(nodeType))
else: else:
nodeType = var['name'] nodeType = variable_data["name"]
if nodeType in ['VariableDeclarationStatement', 'VariableDefinitionStatement']: if nodeType in ["VariableDeclarationStatement", "VariableDefinitionStatement"]:
if len(var['children']) == 2: if len(variable_data["children"]) == 2:
init = var['children'][1] init = variable_data["children"][1]
elif len(var['children']) == 1: elif len(variable_data["children"]) == 1:
init = None init = None
elif len(var['children']) > 2: elif len(variable_data["children"]) > 2:
raise MultipleVariablesDeclaration raise MultipleVariablesDeclaration
else: else:
raise ParsingError('Variable declaration without children?'+var) raise ParsingError(
declaration = var['children'][0] "Variable declaration without children?" + str(variable_data)
)
declaration = variable_data["children"][0]
self._init_from_declaration(declaration, init) self._init_from_declaration(declaration, init)
elif nodeType == 'VariableDeclaration': elif nodeType == "VariableDeclaration":
self._init_from_declaration(var, None) self._init_from_declaration(variable_data, False)
else: else:
raise ParsingError('Incorrect variable declaration type {}'.format(nodeType)) raise ParsingError("Incorrect variable declaration type {}".format(nodeType))
@property
def initialized(self):
return self._initialized
@property @property
def uninitialized(self): def underlying_variable(self) -> Variable:
return not self._initialized return self._variable
@property @property
def reference_id(self): def reference_id(self) -> int:
''' """
Return the solc id. It can be compared with the referencedDeclaration attr Return the solc id. It can be compared with the referencedDeclaration attr
Returns None if it was not parsed (legacy AST) Returns None if it was not parsed (legacy AST)
''' """
return self._reference_id return self._reference_id
def _analyze_variable_attributes(self, attributes): def _analyze_variable_attributes(self, attributes: Dict):
if 'visibility' in attributes: if "visibility" in attributes:
self._visibility = attributes['visibility'] self._variable.visibility = attributes["visibility"]
else: else:
self._visibility = 'internal' self._variable.visibility = "internal"
def _init_from_declaration(self, var, init): def _init_from_declaration(self, var: Dict, init: bool):
if self._is_compact_ast: if self._is_compact_ast:
attributes = var attributes = var
self._typeName = attributes['typeDescriptions']['typeString'] self._typeName = attributes["typeDescriptions"]["typeString"]
else: else:
assert len(var['children']) <= 2 assert len(var["children"]) <= 2
assert var['name'] == 'VariableDeclaration' assert var["name"] == "VariableDeclaration"
attributes = var['attributes'] attributes = var["attributes"]
self._typeName = attributes['type'] self._typeName = attributes["type"]
self._name = attributes['name'] self._variable.name = attributes["name"]
self._arrayDepth = 0 # self._arrayDepth = 0
self._isMapping = False # self._isMapping = False
self._mappingFrom = None # self._mappingFrom = None
self._mappingTo = False # self._mappingTo = False
self._initial_expression = None # self._initial_expression = None
self._type = None # self._type = None
# Only for comapct ast format # Only for comapct ast format
# the id can be used later if referencedDeclaration # the id can be used later if referencedDeclaration
# is provided # is provided
if 'id' in var: if "id" in var:
self._reference_id = var['id'] self._reference_id = var["id"]
if 'constant' in attributes: if "constant" in attributes:
self._is_constant = attributes['constant'] self._variable.is_constant = attributes["constant"]
self._analyze_variable_attributes(attributes) self._analyze_variable_attributes(attributes)
if self._is_compact_ast: if self._is_compact_ast:
if var['typeName']: if var["typeName"]:
self._elem_to_parse = var['typeName'] self._elem_to_parse = var["typeName"]
else: else:
self._elem_to_parse = UnknownType(var['typeDescriptions']['typeString']) self._elem_to_parse = UnknownType(var["typeDescriptions"]["typeString"])
else: else:
if not var['children']: if not var["children"]:
# It happens on variable declared inside loop declaration # It happens on variable declared inside loop declaration
try: try:
self._type = ElementaryType(self._typeName) self._variable.type = ElementaryType(self._typeName)
self._elem_to_parse = None self._elem_to_parse = None
except NonElementaryType: except NonElementaryType:
self._elem_to_parse = UnknownType(self._typeName) self._elem_to_parse = UnknownType(self._typeName)
else: else:
self._elem_to_parse = var['children'][0] self._elem_to_parse = var["children"][0]
if self._is_compact_ast: if self._is_compact_ast:
self._initializedNotParsed = init self._initializedNotParsed = init
if init: if init:
self._initialized = True self._variable.initialized = True
else: else:
if init: # there are two way to init a var local in the AST if init: # there are two way to init a var local in the AST
assert len(var['children']) <= 1 assert len(var["children"]) <= 1
self._initialized = True self._variable.initialized = True
self._initializedNotParsed = init self._initializedNotParsed = init
elif len(var['children']) in [0, 1]: elif len(var["children"]) in [0, 1]:
self._initialized = False self._variable.initialized = False
self._initializedNotParsed = [] self._initializedNotParsed = []
else: else:
assert len(var['children']) == 2 assert len(var["children"]) == 2
self._initialized = True self._variable.initialized = True
self._initializedNotParsed = var['children'][1] self._initializedNotParsed = var["children"][1]
def analyze(self, caller_context): def analyze(self, caller_context):
# Can be re-analyzed due to inheritance # Can be re-analyzed due to inheritance
@ -164,9 +166,9 @@ class VariableDeclarationSolc(Variable):
self._was_analyzed = True self._was_analyzed = True
if self._elem_to_parse: if self._elem_to_parse:
self._type = parse_type(self._elem_to_parse, caller_context) self._variable.type = parse_type(self._elem_to_parse, caller_context)
self._elem_to_parse = None self._elem_to_parse = None
if self._initialized: if self._variable.initialized:
self._initial_expression = parse_expression(self._initializedNotParsed, caller_context) self._variable.expression = parse_expression(self._initializedNotParsed, caller_context)
self._initializedNotParsed = None self._initializedNotParsed = None

@ -9,16 +9,17 @@ logging.getLogger("Slither").setLevel(logging.INFO)
logger = logging.getLogger("Slither-demo") logger = logging.getLogger("Slither-demo")
def parse_args(): def parse_args():
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
""" """
parser = argparse.ArgumentParser(description='Demo', parser = argparse.ArgumentParser(description="Demo", usage="slither-demo filename")
usage='slither-demo filename')
parser.add_argument('filename', parser.add_argument(
help='The filename of the contract or truffle directory to analyze.') "filename", help="The filename of the contract or truffle directory to analyze."
)
# Add default arguments from crytic-compile # Add default arguments from crytic-compile
cryticparser.init(parser) cryticparser.init(parser)
@ -32,7 +33,8 @@ def main():
# Perform slither analysis on the given filename # Perform slither analysis on the given filename
slither = Slither(args.filename, **vars(args)) slither = Slither(args.filename, **vars(args))
logger.info('Analysis done!') logger.info("Analysis done!")
if __name__ == '__main__': if __name__ == "__main__":
main() main()

@ -17,28 +17,29 @@ logger.setLevel(logging.INFO)
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setLevel(logging.INFO) ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s') formatter = logging.Formatter("%(message)s")
logger.addHandler(ch) logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter) logger.handlers[0].setFormatter(formatter)
logger.propagate = False logger.propagate = False
ADDITIONAL_CHECKS = { ADDITIONAL_CHECKS = {"ERC20": check_erc20}
"ERC20": check_erc20
}
def parse_args(): def parse_args():
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
""" """
parser = argparse.ArgumentParser(description='Check the ERC 20 conformance', parser = argparse.ArgumentParser(
usage='slither-erc project contractName') description="Check the ERC 20 conformance", usage="slither-erc project contractName"
)
parser.add_argument('project', parser.add_argument("project", help="The codebase to be tested.")
help='The codebase to be tested.')
parser.add_argument('contract_name', parser.add_argument(
help='The name of the contract. Specify the first case contract that follow the standard. Derived contracts will be checked.') "contract_name",
help="The name of the contract. Specify the first case contract that follow the standard. Derived contracts will be checked.",
)
parser.add_argument( parser.add_argument(
"--erc", "--erc",
@ -47,22 +48,26 @@ def parse_args():
default="erc20", default="erc20",
) )
parser.add_argument('--json', parser.add_argument(
"--json",
help='Export the results as a JSON file ("--json -" to export to stdout)', help='Export the results as a JSON file ("--json -" to export to stdout)',
action='store', action="store",
default=False) default=False,
)
# Add default arguments from crytic-compile # Add default arguments from crytic-compile
cryticparser.init(parser) cryticparser.init(parser)
return parser.parse_args() return parser.parse_args()
def _log_error(err, args): def _log_error(err, args):
if args.json: if args.json:
output_to_json(args.json, str(err), {"upgradeability-check": []}) output_to_json(args.json, str(err), {"upgradeability-check": []})
logger.error(err) logger.error(err)
def main(): def main():
args = parse_args() args = parse_args()
@ -76,7 +81,7 @@ def main():
contract = slither.get_contract_from_name(args.contract_name) contract = slither.get_contract_from_name(args.contract_name)
if not contract: if not contract:
err = f'Contract not found: {args.contract_name}' err = f"Contract not found: {args.contract_name}"
_log_error(err, args) _log_error(err, args)
return return
# First elem is the function, second is the event # First elem is the function, second is the event
@ -87,7 +92,7 @@ def main():
ADDITIONAL_CHECKS[args.erc.upper()](contract, ret) ADDITIONAL_CHECKS[args.erc.upper()](contract, ret)
else: else:
err = f'Incorrect ERC selected {args.erc}' err = f"Incorrect ERC selected {args.erc}"
_log_error(err, args) _log_error(err, args)
return return
@ -95,5 +100,5 @@ def main():
output_to_json(args.json, None, {"upgradeability-check": ret}) output_to_json(args.json, None, {"upgradeability-check": ret})
if __name__ == '__main__': if __name__ == "__main__":
main() main()

@ -6,21 +6,25 @@ logger = logging.getLogger("Slither-conformance")
def approval_race_condition(contract, ret): def approval_race_condition(contract, ret):
increaseAllowance = contract.get_function_from_signature('increaseAllowance(address,uint256)') increaseAllowance = contract.get_function_from_signature("increaseAllowance(address,uint256)")
if not increaseAllowance: if not increaseAllowance:
increaseAllowance = contract.get_function_from_signature('safeIncreaseAllowance(address,uint256)') increaseAllowance = contract.get_function_from_signature(
"safeIncreaseAllowance(address,uint256)"
)
if increaseAllowance: if increaseAllowance:
txt = f'\t[✓] {contract.name} has {increaseAllowance.full_name}' txt = f"\t[✓] {contract.name} has {increaseAllowance.full_name}"
logger.info(txt) logger.info(txt)
else: else:
txt = f'\t[ ] {contract.name} is not protected for the ERC20 approval race condition' txt = f"\t[ ] {contract.name} is not protected for the ERC20 approval race condition"
logger.info(txt) logger.info(txt)
lack_of_erc20_race_condition_protection = output.Output(txt) lack_of_erc20_race_condition_protection = output.Output(txt)
lack_of_erc20_race_condition_protection.add(contract) lack_of_erc20_race_condition_protection.add(contract)
ret["lack_of_erc20_race_condition_protection"].append(lack_of_erc20_race_condition_protection.data) ret["lack_of_erc20_race_condition_protection"].append(
lack_of_erc20_race_condition_protection.data
)
def check_erc20(contract, ret, explored=None): def check_erc20(contract, ret, explored=None):

@ -22,13 +22,15 @@ def _check_signature(erc_function, contract, ret):
# The check on state variable is needed until we have a better API to handle state variable getters # The check on state variable is needed until we have a better API to handle state variable getters
state_variable_as_function = contract.get_state_variable_from_name(name) state_variable_as_function = contract.get_state_variable_from_name(name)
if not state_variable_as_function or not state_variable_as_function.visibility in ['public', 'external']: if not state_variable_as_function or not state_variable_as_function.visibility in [
"public",
"external",
]:
txt = f'[ ] {sig} is missing {"" if required else "(optional)"}' txt = f'[ ] {sig} is missing {"" if required else "(optional)"}'
logger.info(txt) logger.info(txt)
missing_func = output.Output(txt, additional_fields={ missing_func = output.Output(
"function": sig, txt, additional_fields={"function": sig, "required": required}
"required": required )
})
missing_func.add(contract) missing_func.add(contract)
ret["missing_function"].append(missing_func.data) ret["missing_function"].append(missing_func.data)
return return
@ -38,10 +40,9 @@ def _check_signature(erc_function, contract, ret):
if types != parameters: if types != parameters:
txt = f'[ ] {sig} is missing {"" if required else "(optional)"}' txt = f'[ ] {sig} is missing {"" if required else "(optional)"}'
logger.info(txt) logger.info(txt)
missing_func = output.Output(txt, additional_fields={ missing_func = output.Output(
"function": sig, txt, additional_fields={"function": sig, "required": required}
"required": required )
})
missing_func.add(contract) missing_func.add(contract)
ret["missing_function"].append(missing_func.data) ret["missing_function"].append(missing_func.data)
return return
@ -53,45 +54,51 @@ def _check_signature(erc_function, contract, ret):
function_return_type = function.return_type function_return_type = function.return_type
function_view = function.view function_view = function.view
txt = f'[✓] {sig} is present' txt = f"[✓] {sig} is present"
logger.info(txt) logger.info(txt)
if function_return_type: if function_return_type:
function_return_type = ','.join([str(x) for x in function_return_type]) function_return_type = ",".join([str(x) for x in function_return_type])
if function_return_type == return_type: if function_return_type == return_type:
txt = f'\t[✓] {sig} -> () (correct return value)' txt = f"\t[✓] {sig} -> () (correct return value)"
logger.info(txt) logger.info(txt)
else: else:
txt = f'\t[ ] {sig} -> () should return {return_type}' txt = f"\t[ ] {sig} -> () should return {return_type}"
logger.info(txt) logger.info(txt)
incorrect_return = output.Output(txt, additional_fields={ incorrect_return = output.Output(
txt,
additional_fields={
"expected_return_type": return_type, "expected_return_type": return_type,
"actual_return_type": function_return_type "actual_return_type": function_return_type,
}) },
)
incorrect_return.add(function) incorrect_return.add(function)
ret["incorrect_return_type"].append(incorrect_return.data) ret["incorrect_return_type"].append(incorrect_return.data)
elif not return_type: elif not return_type:
txt = f'\t[✓] {sig} -> () (correct return type)' txt = f"\t[✓] {sig} -> () (correct return type)"
logger.info(txt) logger.info(txt)
else: else:
txt = f'\t[ ] {sig} -> () should return {return_type}' txt = f"\t[ ] {sig} -> () should return {return_type}"
logger.info(txt) logger.info(txt)
incorrect_return = output.Output(txt, additional_fields={ incorrect_return = output.Output(
txt,
additional_fields={
"expected_return_type": return_type, "expected_return_type": return_type,
"actual_return_type": function_return_type "actual_return_type": function_return_type,
}) },
)
incorrect_return.add(function) incorrect_return.add(function)
ret["incorrect_return_type"].append(incorrect_return.data) ret["incorrect_return_type"].append(incorrect_return.data)
if view: if view:
if function_view: if function_view:
txt = f'\t[✓] {sig} is view' txt = f"\t[✓] {sig} is view"
logger.info(txt) logger.info(txt)
else: else:
txt = f'\t[ ] {sig} should be view' txt = f"\t[ ] {sig} should be view"
logger.info(txt) logger.info(txt)
should_be_view = output.Output(txt) should_be_view = output.Output(txt)
@ -103,12 +110,12 @@ def _check_signature(erc_function, contract, ret):
event_sig = f'{event.name}({",".join(event.parameters)})' event_sig = f'{event.name}({",".join(event.parameters)})'
if not function: if not function:
txt = f'\t[ ] Must emit be view {event_sig}' txt = f"\t[ ] Must emit be view {event_sig}"
logger.info(txt) logger.info(txt)
missing_event_emmited = output.Output(txt, additional_fields={ missing_event_emmited = output.Output(
"missing_event": event_sig txt, additional_fields={"missing_event": event_sig}
}) )
missing_event_emmited.add(function) missing_event_emmited.add(function)
ret["missing_event_emmited"].append(missing_event_emmited.data) ret["missing_event_emmited"].append(missing_event_emmited.data)
@ -121,15 +128,15 @@ def _check_signature(erc_function, contract, ret):
event_found = True event_found = True
break break
if event_found: if event_found:
txt = f'\t[✓] {event_sig} is emitted' txt = f"\t[✓] {event_sig} is emitted"
logger.info(txt) logger.info(txt)
else: else:
txt = f'\t[ ] Must emit be view {event_sig}' txt = f"\t[ ] Must emit be view {event_sig}"
logger.info(txt) logger.info(txt)
missing_event_emmited = output.Output(txt, additional_fields={ missing_event_emmited = output.Output(
"missing_event": event_sig txt, additional_fields={"missing_event": event_sig}
}) )
missing_event_emmited.add(function) missing_event_emmited.add(function)
ret["missing_event_emmited"].append(missing_event_emmited.data) ret["missing_event_emmited"].append(missing_event_emmited.data)
@ -143,31 +150,27 @@ def _check_events(erc_event, contract, ret):
event = contract.get_event_from_signature(sig) event = contract.get_event_from_signature(sig)
if not event: if not event:
txt = f'[ ] {sig} is missing' txt = f"[ ] {sig} is missing"
logger.info(txt) logger.info(txt)
missing_event = output.Output(txt, additional_fields={ missing_event = output.Output(txt, additional_fields={"event": sig})
"event": sig
})
missing_event.add(contract) missing_event.add(contract)
ret["missing_event"].append(missing_event.data) ret["missing_event"].append(missing_event.data)
return return
txt = f'[✓] {sig} is present' txt = f"[✓] {sig} is present"
logger.info(txt) logger.info(txt)
for i, index in enumerate(indexes): for i, index in enumerate(indexes):
if index: if index:
if event.elems[i].indexed: if event.elems[i].indexed:
txt = f'\t[✓] parameter {i} is indexed' txt = f"\t[✓] parameter {i} is indexed"
logger.info(txt) logger.info(txt)
else: else:
txt = f'\t[ ] parameter {i} should be indexed' txt = f"\t[ ] parameter {i} should be indexed"
logger.info(txt) logger.info(txt)
missing_event_index = output.Output(txt, additional_fields={ missing_event_index = output.Output(txt, additional_fields={"missing_index": i})
"missing_index": i
})
missing_event_index.add_event(event) missing_event_index.add_event(event)
ret["missing_event_index"].append(missing_event_index.data) ret["missing_event_index"].append(missing_event_index.data)
@ -179,16 +182,16 @@ def generic_erc_checks(contract, erc_functions, erc_events, ret, explored=None):
explored.add(contract) explored.add(contract)
logger.info(f'# Check {contract.name}\n') logger.info(f"# Check {contract.name}\n")
logger.info(f'## Check functions') logger.info(f"## Check functions")
for erc_function in erc_functions: for erc_function in erc_functions:
_check_signature(erc_function, contract, ret) _check_signature(erc_function, contract, ret)
logger.info(f'\n## Check events') logger.info(f"\n## Check events")
for erc_event in erc_events: for erc_event in erc_events:
_check_events(erc_event, contract, ret) _check_events(erc_event, contract, ret)
logger.info('\n') logger.info("\n")
for derived_contract in contract.derived_contracts: for derived_contract in contract.derived_contracts:
generic_erc_checks(derived_contract, erc_functions, erc_events, ret, explored) generic_erc_checks(derived_contract, erc_functions, erc_events, ret, explored)

@ -11,27 +11,36 @@ logger.setLevel(logging.INFO)
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setLevel(logging.INFO) ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s') formatter = logging.Formatter("%(message)s")
logger.addHandler(ch) logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter) logger.handlers[0].setFormatter(formatter)
logger.propagate = False logger.propagate = False
def parse_args(): def parse_args():
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
""" """
parser = argparse.ArgumentParser(description='slither-kspec-coverage', parser = argparse.ArgumentParser(
usage='slither-kspec-coverage contract.sol kspec.md') description="slither-kspec-coverage", usage="slither-kspec-coverage contract.sol kspec.md"
)
parser.add_argument('contract', help='The filename of the contract or truffle directory to analyze.') parser.add_argument(
parser.add_argument('kspec', help='The filename of the Klab spec markdown for the analyzed contract(s)') "contract", help="The filename of the contract or truffle directory to analyze."
)
parser.add_argument(
"kspec", help="The filename of the Klab spec markdown for the analyzed contract(s)"
)
parser.add_argument('--version', help='displays the current version', version='0.1.0',action='version') parser.add_argument(
parser.add_argument('--json', "--version", help="displays the current version", version="0.1.0", action="version"
)
parser.add_argument(
"--json",
help='Export the results as a JSON file ("--json -" to export to stdout)', help='Export the results as a JSON file ("--json -" to export to stdout)',
action='store', action="store",
default=False default=False,
) )
cryticparser.init(parser) cryticparser.init(parser)
@ -54,5 +63,6 @@ def main():
kspec_coverage(args) kspec_coverage(args)
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

@ -7,25 +7,22 @@ from slither.utils.colors import yellow, green, red
from slither.utils import output from slither.utils import output
logging.basicConfig(level=logging.WARNING) logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger('Slither.kspec') logger = logging.getLogger("Slither.kspec")
def _refactor_type(type): def _refactor_type(type):
return { return {"uint": "uint256", "int": "int256"}.get(type, type)
'uint': 'uint256',
'int': 'int256'
}.get(type, type)
def _get_all_covered_kspec_functions(target): def _get_all_covered_kspec_functions(target):
# Create a set of our discovered functions which are covered # Create a set of our discovered functions which are covered
covered_functions = set() covered_functions = set()
BEHAVIOUR_PATTERN = re.compile('behaviour\s+(\S+)\s+of\s+(\S+)') BEHAVIOUR_PATTERN = re.compile("behaviour\s+(\S+)\s+of\s+(\S+)")
INTERFACE_PATTERN = re.compile('interface\s+([^\r\n]+)') INTERFACE_PATTERN = re.compile("interface\s+([^\r\n]+)")
# Read the file contents # Read the file contents
with open(target, 'r', encoding='utf8') as target_file: with open(target, "r", encoding="utf8") as target_file:
lines = target_file.readlines() lines = target_file.readlines()
# Loop for each line, if a line matches our behaviour regex, and the next one matches our interface regex, # Loop for each line, if a line matches our behaviour regex, and the next one matches our interface regex,
@ -38,10 +35,12 @@ def _get_all_covered_kspec_functions(target):
match = INTERFACE_PATTERN.match(lines[i + 1]) match = INTERFACE_PATTERN.match(lines[i + 1])
if match: if match:
function_full_name = match.groups()[0] function_full_name = match.groups()[0]
start, end = function_full_name.index('(') + 1, function_full_name.index(')') start, end = function_full_name.index("(") + 1, function_full_name.index(")")
function_arguments = function_full_name[start:end].split(',') function_arguments = function_full_name[start:end].split(",")
function_arguments = [_refactor_type(arg.strip().split(' ')[0]) for arg in function_arguments] function_arguments = [
function_full_name = function_full_name[:start] + ','.join(function_arguments) + ')' _refactor_type(arg.strip().split(" ")[0]) for arg in function_arguments
]
function_full_name = function_full_name[:start] + ",".join(function_arguments) + ")"
covered_functions.add((contract_name, function_full_name)) covered_functions.add((contract_name, function_full_name))
i += 1 i += 1
i += 1 i += 1
@ -50,14 +49,25 @@ def _get_all_covered_kspec_functions(target):
def _get_slither_functions(slither): def _get_slither_functions(slither):
# Use contract == contract_declarer to avoid dupplicate # Use contract == contract_declarer to avoid dupplicate
all_functions_declared = [f for f in slither.functions if (f.contract == f.contract_declarer all_functions_declared = [
f
for f in slither.functions
if (
f.contract == f.contract_declarer
and f.is_implemented and f.is_implemented
and not f.is_constructor and not f.is_constructor
and not f.is_constructor_variables)] and not f.is_constructor_variables
)
]
# Use list(set()) because same state variable instances can be shared accross contracts # Use list(set()) because same state variable instances can be shared accross contracts
# TODO: integrate state variables # TODO: integrate state variables
all_functions_declared += list(set([s for s in slither.state_variables if s.visibility in ['public', 'external']])) all_functions_declared += list(
slither_functions = {(function.contract.name, function.full_name): function for function in all_functions_declared} set([s for s in slither.state_variables if s.visibility in ["public", "external"]])
)
slither_functions = {
(function.contract.name, function.full_name): function
for function in all_functions_declared
}
return slither_functions return slither_functions
@ -110,35 +120,42 @@ def _run_coverage_analysis(args, slither, kspec_functions):
else: else:
kspec_missing.append(slither_func) kspec_missing.append(slither_func)
logger.info('## Check for functions coverage') logger.info("## Check for functions coverage")
json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json) json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json)
json_kspec_missing_functions = _generate_output([f for f in kspec_missing if isinstance(f, Function)], json_kspec_missing_functions = _generate_output(
[f for f in kspec_missing if isinstance(f, Function)],
"[ ] (Missing function)", "[ ] (Missing function)",
red, red,
args.json) args.json,
json_kspec_missing_variables = _generate_output([f for f in kspec_missing if isinstance(f, Variable)], )
json_kspec_missing_variables = _generate_output(
[f for f in kspec_missing if isinstance(f, Variable)],
"[ ] (Missing variable)", "[ ] (Missing variable)",
yellow, yellow,
args.json) args.json,
json_kspec_unresolved = _generate_output_unresolved(kspec_functions_unresolved, )
"[ ] (Unresolved)", json_kspec_unresolved = _generate_output_unresolved(
yellow, kspec_functions_unresolved, "[ ] (Unresolved)", yellow, args.json
args.json) )
# Handle unresolved kspecs # Handle unresolved kspecs
if args.json: if args.json:
output.output_to_json(args.json, None, { output.output_to_json(
args.json,
None,
{
"functions_present": json_kspec_present, "functions_present": json_kspec_present,
"functions_missing": json_kspec_missing_functions, "functions_missing": json_kspec_missing_functions,
"variables_missing": json_kspec_missing_variables, "variables_missing": json_kspec_missing_variables,
"functions_unresolved": json_kspec_unresolved "functions_unresolved": json_kspec_unresolved,
}) },
)
def run_analysis(args, slither, kspec): def run_analysis(args, slither, kspec):
# Get all of our kspec'd functions (tuple(contract_name, function_name)). # Get all of our kspec'd functions (tuple(contract_name, function_name)).
if ',' in kspec: if "," in kspec:
kspecs = kspec.split(',') kspecs = kspec.split(",")
kspec_functions = set() kspec_functions = set()
for kspec in kspecs: for kspec in kspecs:
kspec_functions |= _get_all_covered_kspec_functions(kspec) kspec_functions |= _get_all_covered_kspec_functions(kspec)

@ -1,6 +1,7 @@
from slither.tools.kspec_coverage.analysis import run_analysis from slither.tools.kspec_coverage.analysis import run_analysis
from slither import Slither from slither import Slither
def kspec_coverage(args): def kspec_coverage(args):
contract = args.contract contract = args.contract
@ -10,5 +11,3 @@ def kspec_coverage(args):
# Run the analysis on the Klab specs # Run the analysis on the Klab specs
run_analysis(args, slither, kspec) run_analysis(args, slither, kspec)

@ -9,18 +9,21 @@ from crytic_compile import cryticparser
logging.basicConfig() logging.basicConfig()
logging.getLogger("Slither").setLevel(logging.INFO) logging.getLogger("Slither").setLevel(logging.INFO)
def parse_args(): def parse_args():
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
""" """
parser = argparse.ArgumentParser(description='PossiblePaths', parser = argparse.ArgumentParser(
usage='possible_paths.py filename [contract.function targets]') description="PossiblePaths", usage="possible_paths.py filename [contract.function targets]"
)
parser.add_argument('filename', parser.add_argument(
help='The filename of the contract or truffle directory to analyze.') "filename", help="The filename of the contract or truffle directory to analyze."
)
parser.add_argument('targets', nargs='+') parser.add_argument("targets", nargs="+")
cryticparser.init(parser) cryticparser.init(parser)
@ -62,12 +65,16 @@ def main():
print("\n") print("\n")
# Format all function paths. # Format all function paths.
reaching_paths_str = [' -> '.join([f"{f.canonical_name}" for f in reaching_path]) for reaching_path in reaching_paths] reaching_paths_str = [
" -> ".join([f"{f.canonical_name}" for f in reaching_path])
for reaching_path in reaching_paths
]
# Print a sorted list of all function paths which can reach the targets. # Print a sorted list of all function paths which can reach the targets.
print(f"The following paths reach the specified targets:") print(f"The following paths reach the specified targets:")
for reaching_path in sorted(reaching_paths_str): for reaching_path in sorted(reaching_paths_str):
print(f"{reaching_path}\n") print(f"{reaching_path}\n")
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

@ -1,4 +1,5 @@
class ResolveFunctionException(Exception): pass class ResolveFunctionException(Exception):
pass
def resolve_function(slither, contract_name, function_name): def resolve_function(slither, contract_name, function_name):
@ -16,11 +17,15 @@ def resolve_function(slither, contract_name, function_name):
raise ResolveFunctionException(f"Could not resolve target contract: {contract_name}") raise ResolveFunctionException(f"Could not resolve target contract: {contract_name}")
# Obtain the target function # Obtain the target function
target_function = next((function for function in contract.functions if function.name == function_name), None) target_function = next(
(function for function in contract.functions if function.name == function_name), None
)
# Verify we have resolved the function specified. # Verify we have resolved the function specified.
if target_function is None: if target_function is None:
raise ResolveFunctionException(f"Could not resolve target function: {contract_name}.{function_name}") raise ResolveFunctionException(
f"Could not resolve target function: {contract_name}.{function_name}"
)
# Add the resolved function to the new list. # Add the resolved function to the new list.
return target_function return target_function
@ -44,17 +49,23 @@ def resolve_functions(slither, functions):
for item in functions: for item in functions:
if isinstance(item, str): if isinstance(item, str):
# If the item is a single string, we assume it is of form 'ContractName.FunctionName'. # If the item is a single string, we assume it is of form 'ContractName.FunctionName'.
parts = item.split('.') parts = item.split(".")
if len(parts) < 2: if len(parts) < 2:
raise ResolveFunctionException("Provided string descriptor must be of form 'ContractName.FunctionName'") raise ResolveFunctionException(
"Provided string descriptor must be of form 'ContractName.FunctionName'"
)
resolved.append(resolve_function(slither, parts[0], parts[1])) resolved.append(resolve_function(slither, parts[0], parts[1]))
elif isinstance(item, tuple): elif isinstance(item, tuple):
# If the item is a tuple, it should be a 2-tuple providing contract and function names. # If the item is a tuple, it should be a 2-tuple providing contract and function names.
if len(item) != 2: if len(item) != 2:
raise ResolveFunctionException("Provided tuple descriptor must provide a contract and function name.") raise ResolveFunctionException(
"Provided tuple descriptor must provide a contract and function name."
)
resolved.append(resolve_function(slither, item[0], item[1])) resolved.append(resolve_function(slither, item[0], item[1]))
else: else:
raise ResolveFunctionException(f"Unexpected function descriptor type to resolve in list: {type(item)}") raise ResolveFunctionException(
f"Unexpected function descriptor type to resolve in list: {type(item)}"
)
# Return the resolved list. # Return the resolved list.
return resolved return resolved
@ -66,9 +77,12 @@ def all_function_definitions(function):
:param function: The function to obtain all definitions at and beneath. :param function: The function to obtain all definitions at and beneath.
:return: Returns a list composed of the provided function definition and any base definitions. :return: Returns a list composed of the provided function definition and any base definitions.
""" """
return [function] + [f for c in function.contract.inheritance return [function] + [
f
for c in function.contract.inheritance
for f in c.functions_and_modifiers_declared for f in c.functions_and_modifiers_declared
if f.full_name == function.full_name] if f.full_name == function.full_name
]
def __find_target_paths(slither, target_function, current_path=[]): def __find_target_paths(slither, target_function, current_path=[]):
@ -102,7 +116,7 @@ def __find_target_paths(slither, target_function, current_path=[]):
results = results.union(path_results) results = results.union(path_results)
# If this path is external accessible from this point, we add the current path to the list. # If this path is external accessible from this point, we add the current path to the list.
if target_function.visibility in ['public', 'external'] and len(current_path) > 1: if target_function.visibility in ["public", "external"] and len(current_path) > 1:
results.add(tuple(current_path)) results.add(tuple(current_path))
return results return results
@ -122,6 +136,3 @@ def find_target_paths(slither, target_functions):
results = results.union(__find_target_paths(slither, target_function)) results = results.union(__find_target_paths(slither, target_function))
return results return results

@ -16,20 +16,21 @@ logging.getLogger("Slither").setLevel(logging.INFO)
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setLevel(logging.INFO) ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s') formatter = logging.Formatter("%(message)s")
logger.addHandler(ch) logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter) logger.handlers[0].setFormatter(formatter)
logger.propagate = False logger.propagate = False
def _all_scenarios(): def _all_scenarios():
txt = '\n' txt = "\n"
txt += '#################### ERC20 ####################\n' txt += "#################### ERC20 ####################\n"
for k, value in ERC20_PROPERTIES.items(): for k, value in ERC20_PROPERTIES.items():
txt += f'{k} - {value.description}\n' txt += f"{k} - {value.description}\n"
return txt return txt
def _all_properties(): def _all_properties():
table = MyPrettyTable(["Num", "Description", "Scenario"]) table = MyPrettyTable(["Num", "Description", "Scenario"])
idx = 0 idx = 0
@ -39,6 +40,7 @@ def _all_properties():
idx = idx + 1 idx = idx + 1
return table return table
class ListScenarios(argparse.Action): class ListScenarios(argparse.Action):
def __call__(self, parser, *args, **kwargs): def __call__(self, parser, *args, **kwargs):
logger.info(_all_scenarios()) logger.info(_all_scenarios())
@ -56,43 +58,51 @@ def parse_args():
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
""" """
parser = argparse.ArgumentParser(description='Demo', parser = argparse.ArgumentParser(
usage='slither-demo filename', description="Demo",
formatter_class=argparse.RawDescriptionHelpFormatter) usage="slither-demo filename",
formatter_class=argparse.RawDescriptionHelpFormatter,
parser.add_argument('filename', )
help='The filename of the contract or truffle directory to analyze.')
parser.add_argument(
parser.add_argument('--contract', "filename", help="The filename of the contract or truffle directory to analyze."
help='The targeted contract.') )
parser.add_argument('--scenario', parser.add_argument("--contract", help="The targeted contract.")
help=f'Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable',
default='Transferable') parser.add_argument(
"--scenario",
parser.add_argument('--list-scenarios', help=f"Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable",
help='List available scenarios', default="Transferable",
)
parser.add_argument(
"--list-scenarios",
help="List available scenarios",
action=ListScenarios, action=ListScenarios,
nargs=0, nargs=0,
default=False) default=False,
)
parser.add_argument('--list-properties', parser.add_argument(
help='List available properties', "--list-properties",
help="List available properties",
action=ListProperties, action=ListProperties,
nargs=0, nargs=0,
default=False) default=False,
)
parser.add_argument('--address-owner', parser.add_argument(
help=f'Owner address. Default {OWNER_ADDRESS}', "--address-owner", help=f"Owner address. Default {OWNER_ADDRESS}", default=None
default=None) )
parser.add_argument('--address-user', parser.add_argument(
help=f'Owner address. Default {USER_ADDRESS}', "--address-user", help=f"Owner address. Default {USER_ADDRESS}", default=None
default=None) )
parser.add_argument('--address-attacker', parser.add_argument(
help=f'Attacker address. Default {ATTACKER_ADDRESS}', "--address-attacker", help=f"Attacker address. Default {ATTACKER_ADDRESS}", default=None
default=None) )
# Add default arguments from crytic-compile # Add default arguments from crytic-compile
cryticparser.init(parser) cryticparser.init(parser)
@ -116,9 +126,9 @@ def main():
contract = slither.contracts[0] contract = slither.contracts[0]
else: else:
if args.contract is None: if args.contract is None:
logger.error(f'Specify the target: --contract ContractName') logger.error(f"Specify the target: --contract ContractName")
else: else:
logger.error(f'{args.contract} not found') logger.error(f"{args.contract} not found")
return return
addresses = Addresses(args.address_owner, args.address_user, args.address_attacker) addresses = Addresses(args.address_owner, args.address_user, args.address_attacker)
@ -126,5 +136,5 @@ def main():
generate_erc20(contract, args.scenario, addresses) generate_erc20(contract, args.scenario, addresses)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

@ -8,8 +8,12 @@ ATTACKER_ADDRESS = "0xC5fdf4076b8F3A5357c5E395ab970B5B54098Fef"
class Addresses: class Addresses:
def __init__(
def __init__(self, owner: Optional[str] = None, user: Optional[str] = None, attacker: Optional[str] = None): self,
owner: Optional[str] = None,
user: Optional[str] = None,
attacker: Optional[str] = None,
):
self.owner = owner if owner else OWNER_ADDRESS self.owner = owner if owner else OWNER_ADDRESS
self.user = user if user else USER_ADDRESS self.user = user if user else USER_ADDRESS
self.attacker = attacker if attacker else ATTACKER_ADDRESS self.attacker = attacker if attacker else ATTACKER_ADDRESS

@ -11,11 +11,11 @@ def generate_echidna_config(output_dir: Path, addresses: Addresses) -> str:
:param addresses: :param addresses:
:return: :return:
""" """
content = 'prefix: crytic_\n' content = "prefix: crytic_\n"
content += f'deployer: "{addresses.owner}"\n' content += f'deployer: "{addresses.owner}"\n'
content += f'sender: ["{addresses.user}", "{addresses.attacker}"]\n' content += f'sender: ["{addresses.user}", "{addresses.attacker}"]\n'
content += f'psender: "{addresses.user}"\n' content += f'psender: "{addresses.user}"\n'
content += 'coverage: true\n' content += "coverage: true\n"
filename = 'echidna_config.yaml' filename = "echidna_config.yaml"
write_file(output_dir, filename, content) write_file(output_dir, filename, content)
return filename return filename

@ -7,21 +7,21 @@ from slither.tools.properties.addresses.address import Addresses
from slither.tools.properties.properties.properties import PropertyReturn, Property, PropertyCaller from slither.tools.properties.properties.properties import PropertyReturn, Property, PropertyCaller
from slither.tools.properties.utils import write_file from slither.tools.properties.utils import write_file
PATTERN_TRUFFLE_MIGRATION = re.compile('^[0-9]*_') PATTERN_TRUFFLE_MIGRATION = re.compile("^[0-9]*_")
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
def _extract_caller(p: PropertyCaller): def _extract_caller(p: PropertyCaller):
if p == PropertyCaller.OWNER: if p == PropertyCaller.OWNER:
return ['owner'] return ["owner"]
if p == PropertyCaller.SENDER: if p == PropertyCaller.SENDER:
return ['user'] return ["user"]
if p == PropertyCaller.ATTACKER: if p == PropertyCaller.ATTACKER:
return ['attacker'] return ["attacker"]
if p == PropertyCaller.ALL: if p == PropertyCaller.ALL:
return ['owner', 'user', 'attacker'] return ["owner", "user", "attacker"]
assert p == PropertyCaller.ANY assert p == PropertyCaller.ANY
return ['user'] return ["user"]
def _helpers(): def _helpers():
@ -31,7 +31,7 @@ def _helpers():
- catchRevertThrow: check if the call revert/throw - catchRevertThrow: check if the call revert/throw
:return: :return:
""" """
return ''' return """
async function catchRevertThrowReturnFalse(promise) { async function catchRevertThrowReturnFalse(promise) {
try { try {
const ret = await promise; const ret = await promise;
@ -61,12 +61,17 @@ async function catchRevertThrow(promise) {
} }
assert(false, "Expected revert/throw/or return false"); assert(false, "Expected revert/throw/or return false");
}; };
''' """
def generate_unit_test(test_contract: str, filename: str, def generate_unit_test(
unit_tests: List[Property], output_dir: Path, test_contract: str,
addresses: Addresses, assert_message: str = ''): filename: str,
unit_tests: List[Property],
output_dir: Path,
addresses: Addresses,
assert_message: str = "",
):
""" """
Generate unit tests files Generate unit tests files
:param test_contract: :param test_contract:
@ -88,37 +93,37 @@ def generate_unit_test(test_contract: str, filename: str,
content += f'\tlet attacker = "{addresses.attacker}";\n' content += f'\tlet attacker = "{addresses.attacker}";\n'
for unit_test in unit_tests: for unit_test in unit_tests:
content += f'\tit("{unit_test.description}", async () => {{\n' content += f'\tit("{unit_test.description}", async () => {{\n'
content += f'\t\tlet instance = await {test_contract}.deployed();\n' content += f"\t\tlet instance = await {test_contract}.deployed();\n"
callers = _extract_caller(unit_test.caller) callers = _extract_caller(unit_test.caller)
if unit_test.return_type == PropertyReturn.SUCCESS: if unit_test.return_type == PropertyReturn.SUCCESS:
for caller in callers: for caller in callers:
content += f'\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n' content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n"
if assert_message: if assert_message:
content += f'\t\tassert.equal(test_{caller}, true, "{assert_message}");\n' content += f'\t\tassert.equal(test_{caller}, true, "{assert_message}");\n'
else: else:
content += f'\t\tassert.equal(test_{caller}, true);\n' content += f"\t\tassert.equal(test_{caller}, true);\n"
elif unit_test.return_type == PropertyReturn.FAIL: elif unit_test.return_type == PropertyReturn.FAIL:
for caller in callers: for caller in callers:
content += f'\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n' content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n"
if assert_message: if assert_message:
content += f'\t\tassert.equal(test_{caller}, false, "{assert_message}");\n' content += f'\t\tassert.equal(test_{caller}, false, "{assert_message}");\n'
else: else:
content += f'\t\tassert.equal(test_{caller}, false);\n' content += f"\t\tassert.equal(test_{caller}, false);\n"
elif unit_test.return_type == PropertyReturn.FAIL_OR_THROW: elif unit_test.return_type == PropertyReturn.FAIL_OR_THROW:
for caller in callers: for caller in callers:
content += f'\t\tawait catchRevertThrowReturnFalse(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n' content += f"\t\tawait catchRevertThrowReturnFalse(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n"
elif unit_test.return_type == PropertyReturn.THROW: elif unit_test.return_type == PropertyReturn.THROW:
callers = _extract_caller(unit_test.caller) callers = _extract_caller(unit_test.caller)
for caller in callers: for caller in callers:
content += f'\t\tawait catchRevertThrow(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n' content += f"\t\tawait catchRevertThrow(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n"
content += '\t});\n' content += "\t});\n"
content += '});\n' content += "});\n"
output_dir = Path(output_dir, 'test') output_dir = Path(output_dir, "test")
output_dir.mkdir(exist_ok=True) output_dir.mkdir(exist_ok=True)
output_dir = Path(output_dir, 'crytic') output_dir = Path(output_dir, "crytic")
output_dir.mkdir(exist_ok=True) output_dir.mkdir(exist_ok=True)
write_file(output_dir, filename, content) write_file(output_dir, filename, content)
@ -133,28 +138,31 @@ def generate_migration(test_contract: str, output_dir: Path, owner_address: str)
:param owner_address: :param owner_address:
:return: :return:
""" """
content = f'''{test_contract} = artifacts.require("{test_contract}"); content = f"""{test_contract} = artifacts.require("{test_contract}");
module.exports = function(deployer) {{ module.exports = function(deployer) {{
deployer.deploy({test_contract}, {{from: "{owner_address}"}}); deployer.deploy({test_contract}, {{from: "{owner_address}"}});
}}; }};
''' """
output_dir = Path(output_dir, 'migrations') output_dir = Path(output_dir, "migrations")
output_dir.mkdir(exist_ok=True) output_dir.mkdir(exist_ok=True)
migration_files = [js_file for js_file in output_dir.iterdir() if js_file.suffix == '.js' migration_files = [
and PATTERN_TRUFFLE_MIGRATION.match(js_file.name)] js_file
for js_file in output_dir.iterdir()
if js_file.suffix == ".js" and PATTERN_TRUFFLE_MIGRATION.match(js_file.name)
]
idx = len(migration_files) idx = len(migration_files)
filename = f'{idx + 1}_{test_contract}.js' filename = f"{idx + 1}_{test_contract}.js"
potential_previous_filename = f'{idx}_{test_contract}.js' potential_previous_filename = f"{idx}_{test_contract}.js"
for m in migration_files: for m in migration_files:
if m.name == potential_previous_filename: if m.name == potential_previous_filename:
write_file(output_dir, potential_previous_filename, content) write_file(output_dir, potential_previous_filename, content)
return return
if test_contract in m.name: if test_contract in m.name:
logger.error(f'Potential conflicts with {m.name}') logger.error(f"Potential conflicts with {m.name}")
write_file(output_dir, filename, content) write_file(output_dir, filename, content)

@ -12,27 +12,37 @@ from slither.tools.properties.platforms.echidna import generate_echidna_config
from slither.tools.properties.properties.ercs.erc20.properties.burn import ERC20_NotBurnable from slither.tools.properties.properties.ercs.erc20.properties.burn import ERC20_NotBurnable
from slither.tools.properties.properties.ercs.erc20.properties.initialization import ERC20_CONFIG from slither.tools.properties.properties.ercs.erc20.properties.initialization import ERC20_CONFIG
from slither.tools.properties.properties.ercs.erc20.properties.mint import ERC20_NotMintable from slither.tools.properties.properties.ercs.erc20.properties.mint import ERC20_NotMintable
from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import ERC20_NotMintableNotBurnable from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import (
from slither.tools.properties.properties.ercs.erc20.properties.transfer import ERC20_Transferable, ERC20_Pausable ERC20_NotMintableNotBurnable,
)
from slither.tools.properties.properties.ercs.erc20.properties.transfer import (
ERC20_Transferable,
ERC20_Pausable,
)
from slither.tools.properties.properties.ercs.erc20.unit_tests.truffle import generate_truffle_test from slither.tools.properties.properties.ercs.erc20.unit_tests.truffle import generate_truffle_test
from slither.tools.properties.properties.properties import property_to_solidity, Property from slither.tools.properties.properties.properties import property_to_solidity, Property
from slither.tools.properties.solidity.generate_properties import generate_solidity_properties, generate_test_contract, \ from slither.tools.properties.solidity.generate_properties import (
generate_solidity_interface generate_solidity_properties,
generate_test_contract,
generate_solidity_interface,
)
from slither.utils.colors import red, green from slither.utils.colors import red, green
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
PropertyDescription = namedtuple('PropertyDescription', ['properties', 'description']) PropertyDescription = namedtuple("PropertyDescription", ["properties", "description"])
ERC20_PROPERTIES = { ERC20_PROPERTIES = {
"Transferable": PropertyDescription(ERC20_Transferable, 'Test the correct tokens transfer'), "Transferable": PropertyDescription(ERC20_Transferable, "Test the correct tokens transfer"),
"Pausable": PropertyDescription(ERC20_Pausable, 'Test the pausable functionality'), "Pausable": PropertyDescription(ERC20_Pausable, "Test the pausable functionality"),
"NotMintable": PropertyDescription(ERC20_NotMintable, 'Test that no one can mint tokens'), "NotMintable": PropertyDescription(ERC20_NotMintable, "Test that no one can mint tokens"),
"NotMintableNotBurnable": PropertyDescription(ERC20_NotMintableNotBurnable, "NotMintableNotBurnable": PropertyDescription(
'Test that no one can mint or burn tokens'), ERC20_NotMintableNotBurnable, "Test that no one can mint or burn tokens"
"NotBurnable": PropertyDescription(ERC20_NotBurnable, 'Test that no one can burn tokens'), ),
"Burnable": PropertyDescription(ERC20_NotBurnable, "NotBurnable": PropertyDescription(ERC20_NotBurnable, "Test that no one can burn tokens"),
'Test the burn of tokens. Require the "burn(address) returns()" function') "Burnable": PropertyDescription(
ERC20_NotBurnable, 'Test the burn of tokens. Require the "burn(address) returns()" function'
),
} }
@ -54,7 +64,7 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses)
:return: :return:
""" """
if contract.slither.crytic_compile.type not in [PlatformType.TRUFFLE, PlatformType.SOLC]: if contract.slither.crytic_compile.type not in [PlatformType.TRUFFLE, PlatformType.SOLC]:
logging.error(f'{contract.slither.crytic_compile.type} not yet supported by slither-prop') logging.error(f"{contract.slither.crytic_compile.type} not yet supported by slither-prop")
return return
# Check if the contract is an ERC20 contract and if the functions have the correct visibility # Check if the contract is an ERC20 contract and if the functions have the correct visibility
@ -65,7 +75,9 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses)
properties = ERC20_PROPERTIES.get(type_property, None) properties = ERC20_PROPERTIES.get(type_property, None)
if properties is None: if properties is None:
logger.error(f'{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}') logger.error(
f"{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}"
)
return return
properties = properties.properties properties = properties.properties
@ -78,51 +90,53 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses)
# Generate the contract containing the properties # Generate the contract containing the properties
generate_solidity_interface(output_dir, addresses) generate_solidity_interface(output_dir, addresses)
property_file = generate_solidity_properties(contract, type_property, solidity_properties, output_dir) property_file = generate_solidity_properties(
contract, type_property, solidity_properties, output_dir
)
# Generate the Test contract # Generate the Test contract
initialization_recommendation = _initialization_recommendation(type_property) initialization_recommendation = _initialization_recommendation(type_property)
contract_filename, contract_name = generate_test_contract(contract, contract_filename, contract_name = generate_test_contract(
type_property, contract, type_property, output_dir, property_file, initialization_recommendation
output_dir, )
property_file,
initialization_recommendation)
# Generate Echidna config file # Generate Echidna config file
echidna_config_filename = generate_echidna_config(Path(contract.slither.crytic_compile.target).parent, addresses) echidna_config_filename = generate_echidna_config(
Path(contract.slither.crytic_compile.target).parent, addresses
)
unit_test_info = '' unit_test_info = ""
# If truffle, generate unit tests # If truffle, generate unit tests
if contract.slither.crytic_compile.type == PlatformType.TRUFFLE: if contract.slither.crytic_compile.type == PlatformType.TRUFFLE:
unit_test_info = generate_truffle_test(contract, type_property, unit_tests, addresses) unit_test_info = generate_truffle_test(contract, type_property, unit_tests, addresses)
logger.info('################################################') logger.info("################################################")
logger.info(green(f'Update the constructor in {Path(output_dir, contract_filename)}')) logger.info(green(f"Update the constructor in {Path(output_dir, contract_filename)}"))
if unit_test_info: if unit_test_info:
logger.info(green(unit_test_info)) logger.info(green(unit_test_info))
logger.info(green('To run Echidna:')) logger.info(green("To run Echidna:"))
txt = f'\t echidna-test {contract.slither.crytic_compile.target} ' txt = f"\t echidna-test {contract.slither.crytic_compile.target} "
txt += f'--contract {contract_name} --config {echidna_config_filename}' txt += f"--contract {contract_name} --config {echidna_config_filename}"
logger.info(green(txt)) logger.info(green(txt))
def _initialization_recommendation(type_property: str) -> str: def _initialization_recommendation(type_property: str) -> str:
content = '' content = ""
content += '\t\t// Add below a minimal configuration:\n' content += "\t\t// Add below a minimal configuration:\n"
content += '\t\t// - crytic_owner must have some tokens \n' content += "\t\t// - crytic_owner must have some tokens \n"
content += '\t\t// - crytic_user must have some tokens \n' content += "\t\t// - crytic_user must have some tokens \n"
content += '\t\t// - crytic_attacker must have some tokens \n' content += "\t\t// - crytic_attacker must have some tokens \n"
if type_property in ['Pausable']: if type_property in ["Pausable"]:
content += '\t\t// - The contract must be paused \n' content += "\t\t// - The contract must be paused \n"
if type_property in ['NotMintable', 'NotMintableNotBurnable']: if type_property in ["NotMintable", "NotMintableNotBurnable"]:
content += '\t\t// - The contract must not be mintable \n' content += "\t\t// - The contract must not be mintable \n"
if type_property in ['NotBurnable', 'NotMintableNotBurnable']: if type_property in ["NotBurnable", "NotMintableNotBurnable"]:
content += '\t\t// - The contract must not be burnable \n' content += "\t\t// - The contract must not be burnable \n"
content += '\n' content += "\n"
content += '\n' content += "\n"
return content return content
@ -130,44 +144,44 @@ def _initialization_recommendation(type_property: str) -> str:
# TODO: move this to crytic-compile # TODO: move this to crytic-compile
def _platform_to_output_dir(platform: AbstractPlatform) -> Path: def _platform_to_output_dir(platform: AbstractPlatform) -> Path:
if platform.TYPE == PlatformType.TRUFFLE: if platform.TYPE == PlatformType.TRUFFLE:
return Path(platform.target, 'contracts', 'crytic') return Path(platform.target, "contracts", "crytic")
if platform.TYPE == PlatformType.SOLC: if platform.TYPE == PlatformType.SOLC:
return Path(platform.target).parent return Path(platform.target).parent
def _check_compatibility(contract): def _check_compatibility(contract):
errors = '' errors = ""
if not contract.is_erc20(): if not contract.is_erc20():
errors = f'{contract} is not ERC20 compliant. Consider checking the contract with slither-check-erc' errors = f"{contract} is not ERC20 compliant. Consider checking the contract with slither-check-erc"
return errors return errors
transfer = contract.get_function_from_signature('transfer(address,uint256)') transfer = contract.get_function_from_signature("transfer(address,uint256)")
if transfer.visibility != 'public': if transfer.visibility != "public":
errors = f'slither-prop requires {transfer.canonical_name} to be public. Please change the visibility' errors = f"slither-prop requires {transfer.canonical_name} to be public. Please change the visibility"
transfer_from = contract.get_function_from_signature('transferFrom(address,address,uint256)') transfer_from = contract.get_function_from_signature("transferFrom(address,address,uint256)")
if transfer_from.visibility != 'public': if transfer_from.visibility != "public":
if errors: if errors:
errors += '\n' errors += "\n"
errors += f'slither-prop requires {transfer_from.canonical_name} to be public. Please change the visibility' errors += f"slither-prop requires {transfer_from.canonical_name} to be public. Please change the visibility"
approve = contract.get_function_from_signature('approve(address,uint256)') approve = contract.get_function_from_signature("approve(address,uint256)")
if approve.visibility != 'public': if approve.visibility != "public":
if errors: if errors:
errors += '\n' errors += "\n"
errors += f'slither-prop requires {approve.canonical_name} to be public. Please change the visibility' errors += f"slither-prop requires {approve.canonical_name} to be public. Please change the visibility"
return errors return errors
def _get_properties(contract, properties: List[Property]) -> Tuple[str, List[Property]]: def _get_properties(contract, properties: List[Property]) -> Tuple[str, List[Property]]:
solidity_properties = '' solidity_properties = ""
if contract.slither.crytic_compile.type == PlatformType.TRUFFLE: if contract.slither.crytic_compile.type == PlatformType.TRUFFLE:
solidity_properties += '\n'.join([property_to_solidity(p) for p in ERC20_CONFIG]) solidity_properties += "\n".join([property_to_solidity(p) for p in ERC20_CONFIG])
solidity_properties += '\n'.join([property_to_solidity(p) for p in properties]) solidity_properties += "\n".join([property_to_solidity(p) for p in properties])
unit_tests = [p for p in properties if p.is_unit_test] unit_tests = [p for p in properties if p.is_unit_test]
return solidity_properties, unit_tests return solidity_properties, unit_tests

@ -1,30 +1,38 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_NotBurnable = [ ERC20_NotBurnable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotBurnable()', Property(
description='The total supply does not decrease.', name="crytic_supply_constant_ERC20PropertiesNotBurnable()",
content=''' description="The total supply does not decrease.",
\t\treturn initialTotalSupply == this.totalSupply();''', content="""
\t\treturn initialTotalSupply == this.totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
),
] ]
# Require burn(address) returns() # Require burn(address) returns()
ERC20_Burnable = [ ERC20_Burnable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotBurnable()', Property(
description='Cannot burn more than available balance', name="crytic_supply_constant_ERC20PropertiesNotBurnable()",
content=''' description="Cannot burn more than available balance",
content="""
\t\tuint balance = balanceOf(msg.sender); \t\tuint balance = balanceOf(msg.sender);
\t\tburn(balance + 1); \t\tburn(balance + 1);
\t\treturn false;''', \t\treturn false;""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.THROW, return_type=PropertyReturn.THROW,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL) caller=PropertyCaller.ALL,
)
] ]

@ -1,65 +1,76 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_CONFIG = [ ERC20_CONFIG = [
Property(
Property(name='init_total_supply()', name="init_total_supply()",
description='The total supply is correctly initialized.', description="The total supply is correctly initialized.",
content=''' content="""
\t\treturn this.totalSupply() >= 0 && this.totalSupply() == initialTotalSupply;''', \t\treturn this.totalSupply() >= 0 && this.totalSupply() == initialTotalSupply;""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=False, is_property_test=False,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
),
Property(name='init_owner_balance()', Property(
name="init_owner_balance()",
description="Owner's balance is correctly initialized.", description="Owner's balance is correctly initialized.",
content=''' content="""
\t\treturn initialBalance_owner == this.balanceOf(crytic_owner);''', \t\treturn initialBalance_owner == this.balanceOf(crytic_owner);""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=False, is_property_test=False,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
),
Property(name='init_user_balance()', Property(
name="init_user_balance()",
description="User's balance is correctly initialized.", description="User's balance is correctly initialized.",
content=''' content="""
\t\treturn initialBalance_user == this.balanceOf(crytic_user);''', \t\treturn initialBalance_user == this.balanceOf(crytic_user);""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=False, is_property_test=False,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
),
Property(name='init_attacker_balance()', Property(
name="init_attacker_balance()",
description="Attacker's balance is correctly initialized.", description="Attacker's balance is correctly initialized.",
content=''' content="""
\t\treturn initialBalance_attacker == this.balanceOf(crytic_attacker);''', \t\treturn initialBalance_attacker == this.balanceOf(crytic_attacker);""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=False, is_property_test=False,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
),
Property(name='init_caller_balance()', Property(
name="init_caller_balance()",
description="All the users have a positive balance.", description="All the users have a positive balance.",
content=''' content="""
\t\treturn this.balanceOf(msg.sender) >0 ;''', \t\treturn this.balanceOf(msg.sender) >0 ;""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=False, is_property_test=False,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
# Note: there is a potential overflow on the addition, but we dont consider it # Note: there is a potential overflow on the addition, but we dont consider it
Property(name='init_total_supply_is_balances()', Property(
name="init_total_supply_is_balances()",
description="The total supply is the user and owner balance.", description="The total supply is the user and owner balance.",
content=''' content="""
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) == this.totalSupply();''', \t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) == this.totalSupply();""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=False, is_property_test=False,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
),
] ]

@ -1,13 +1,20 @@
from slither.tools.properties.properties.properties import PropertyType, PropertyReturn, Property, PropertyCaller from slither.tools.properties.properties.properties import (
PropertyType,
PropertyReturn,
Property,
PropertyCaller,
)
ERC20_NotMintable = [ ERC20_NotMintable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotMintable()', Property(
description='The total supply does not increase.', name="crytic_supply_constant_ERC20PropertiesNotMintable()",
content=''' description="The total supply does not increase.",
\t\treturn initialTotalSupply >= totalSupply();''', content="""
\t\treturn initialTotalSupply >= totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
),
] ]

@ -1,14 +1,20 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_NotMintableNotBurnable = [ ERC20_NotMintableNotBurnable = [
Property(
Property(name='crytic_supply_constant_ERC20PropertiesNotMintableNotBurnable()', name="crytic_supply_constant_ERC20PropertiesNotMintableNotBurnable()",
description='The total supply does not change.', description="The total supply does not change.",
content=''' content="""
\t\treturn initialTotalSupply == this.totalSupply();''', \t\treturn initialTotalSupply == this.totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
),
] ]

@ -1,96 +1,108 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_Transferable = [ ERC20_Transferable = [
Property(
Property(name='crytic_zero_always_empty_ERC20Properties()', name="crytic_zero_always_empty_ERC20Properties()",
description='The address 0x0 should not receive tokens.', description="The address 0x0 should not receive tokens.",
content=''' content="""
\t\treturn this.balanceOf(address(0x0)) == 0;''', \t\treturn this.balanceOf(address(0x0)) == 0;""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
),
Property(name='crytic_approve_overwrites()', Property(
description='Allowance can be changed.', name="crytic_approve_overwrites()",
content=''' description="Allowance can be changed.",
content="""
\t\tbool approve_return; \t\tbool approve_return;
\t\tapprove_return = approve(crytic_user, 10); \t\tapprove_return = approve(crytic_user, 10);
\t\trequire(approve_return); \t\trequire(approve_return);
\t\tapprove_return = approve(crytic_user, 20); \t\tapprove_return = approve(crytic_user, 20);
\t\trequire(approve_return); \t\trequire(approve_return);
\t\treturn this.allowance(msg.sender, crytic_user) == 20;''', \t\treturn this.allowance(msg.sender, crytic_user) == 20;""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_less_than_total_ERC20Properties()', Property(
description='Balance of one user must be less or equal to the total supply.', name="crytic_less_than_total_ERC20Properties()",
content=''' description="Balance of one user must be less or equal to the total supply.",
\t\treturn this.balanceOf(msg.sender) <= totalSupply();''', content="""
\t\treturn this.balanceOf(msg.sender) <= totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_totalSupply_consistant_ERC20Properties()', Property(
description='Balance of the crytic users must be less or equal to the total supply.', name="crytic_totalSupply_consistant_ERC20Properties()",
content=''' description="Balance of the crytic users must be less or equal to the total supply.",
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) <= totalSupply();''', content="""
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) <= totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
),
Property(name='crytic_revert_transfer_to_zero_ERC20PropertiesTransferable()', Property(
description='No one should be able to send tokens to the address 0x0 (transfer).', name="crytic_revert_transfer_to_zero_ERC20PropertiesTransferable()",
content=''' description="No one should be able to send tokens to the address 0x0 (transfer).",
content="""
\t\tif (this.balanceOf(msg.sender) == 0){ \t\tif (this.balanceOf(msg.sender) == 0){
\t\t\trevert(); \t\t\trevert();
\t\t} \t\t}
\t\treturn transfer(address(0x0), this.balanceOf(msg.sender));''', \t\treturn transfer(address(0x0), this.balanceOf(msg.sender));""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.FAIL_OR_THROW, return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_revert_transferFrom_to_zero_ERC20PropertiesTransferable()', Property(
description='No one should be able to send tokens to the address 0x0 (transferFrom).', name="crytic_revert_transferFrom_to_zero_ERC20PropertiesTransferable()",
content=''' description="No one should be able to send tokens to the address 0x0 (transferFrom).",
content="""
\t\tuint balance = this.balanceOf(msg.sender); \t\tuint balance = this.balanceOf(msg.sender);
\t\tif (balance == 0){ \t\tif (balance == 0){
\t\t\trevert(); \t\t\trevert();
\t\t} \t\t}
\t\tapprove(msg.sender, balance); \t\tapprove(msg.sender, balance);
\t\treturn transferFrom(msg.sender, address(0x0), this.balanceOf(msg.sender));''', \t\treturn transferFrom(msg.sender, address(0x0), this.balanceOf(msg.sender));""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.FAIL_OR_THROW, return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_self_transferFrom_ERC20PropertiesTransferable()', Property(
description='Self transferFrom works.', name="crytic_self_transferFrom_ERC20PropertiesTransferable()",
content=''' description="Self transferFrom works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender); \t\tuint balance = this.balanceOf(msg.sender);
\t\tbool approve_return = approve(msg.sender, balance); \t\tbool approve_return = approve(msg.sender, balance);
\t\tbool transfer_return = transferFrom(msg.sender, msg.sender, balance); \t\tbool transfer_return = transferFrom(msg.sender, msg.sender, balance);
\t\treturn (this.balanceOf(msg.sender) == balance) && approve_return && transfer_return;''', \t\treturn (this.balanceOf(msg.sender) == balance) && approve_return && transfer_return;""",
type=PropertyType.HIGH_SEVERITY, type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_self_transferFrom_to_other_ERC20PropertiesTransferable()', Property(
description='transferFrom works.', name="crytic_self_transferFrom_to_other_ERC20PropertiesTransferable()",
content=''' description="transferFrom works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender); \t\tuint balance = this.balanceOf(msg.sender);
\t\tbool approve_return = approve(msg.sender, balance); \t\tbool approve_return = approve(msg.sender, balance);
\t\taddress other = crytic_user; \t\taddress other = crytic_user;
@ -98,29 +110,30 @@ ERC20_Transferable = [
\t\t\tother = crytic_owner; \t\t\tother = crytic_owner;
\t\t} \t\t}
\t\tbool transfer_return = transferFrom(msg.sender, other, balance); \t\tbool transfer_return = transferFrom(msg.sender, other, balance);
\t\treturn (this.balanceOf(msg.sender) == 0) && approve_return && transfer_return;''', \t\treturn (this.balanceOf(msg.sender) == 0) && approve_return && transfer_return;""",
type=PropertyType.HIGH_SEVERITY, type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(
Property(name='crytic_self_transfer_ERC20PropertiesTransferable()', name="crytic_self_transfer_ERC20PropertiesTransferable()",
description='Self transfer works.', description="Self transfer works.",
content=''' content="""
\t\tuint balance = this.balanceOf(msg.sender); \t\tuint balance = this.balanceOf(msg.sender);
\t\tbool transfer_return = transfer(msg.sender, balance); \t\tbool transfer_return = transfer(msg.sender, balance);
\t\treturn (this.balanceOf(msg.sender) == balance) && transfer_return;''', \t\treturn (this.balanceOf(msg.sender) == balance) && transfer_return;""",
type=PropertyType.HIGH_SEVERITY, type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_transfer_to_other_ERC20PropertiesTransferable()', Property(
description='transfer works.', name="crytic_transfer_to_other_ERC20PropertiesTransferable()",
content=''' description="transfer works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender); \t\tuint balance = this.balanceOf(msg.sender);
\t\taddress other = crytic_user; \t\taddress other = crytic_user;
\t\tif (other == msg.sender) { \t\tif (other == msg.sender) {
@ -130,74 +143,76 @@ ERC20_Transferable = [
\t\t\tbool transfer_other = transfer(other, 1); \t\t\tbool transfer_other = transfer(other, 1);
\t\t\treturn (this.balanceOf(msg.sender) == balance-1) && (this.balanceOf(other) >= 1) && transfer_other; \t\t\treturn (this.balanceOf(msg.sender) == balance-1) && (this.balanceOf(other) >= 1) && transfer_other;
\t\t} \t\t}
\t\treturn true;''', \t\treturn true;""",
type=PropertyType.HIGH_SEVERITY, type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_revert_transfer_to_user_ERC20PropertiesTransferable()', Property(
description='Cannot transfer more than the balance.', name="crytic_revert_transfer_to_user_ERC20PropertiesTransferable()",
content=''' description="Cannot transfer more than the balance.",
content="""
\t\tuint balance = this.balanceOf(msg.sender); \t\tuint balance = this.balanceOf(msg.sender);
\t\tif (balance == (2 ** 256 - 1)) \t\tif (balance == (2 ** 256 - 1))
\t\t\treturn true; \t\t\treturn true;
\t\tbool transfer_other = transfer(crytic_user, balance+1); \t\tbool transfer_other = transfer(crytic_user, balance+1);
\t\treturn transfer_other;''', \t\treturn transfer_other;""",
type=PropertyType.HIGH_SEVERITY, type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW, return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
] ]
ERC20_Pausable = [ ERC20_Pausable = [
Property(
Property(name='crytic_revert_transfer_ERC20AlwaysTruePropertiesNotTransferable()', name="crytic_revert_transfer_ERC20AlwaysTruePropertiesNotTransferable()",
description='Cannot transfer.', description="Cannot transfer.",
content=''' content="""
\t\treturn transfer(crytic_user, this.balanceOf(msg.sender));''', \t\treturn transfer(crytic_user, this.balanceOf(msg.sender));""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW, return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_revert_transferFrom_ERC20AlwaysTruePropertiesNotTransferable()', Property(
description='Cannot execute transferFrom.', name="crytic_revert_transferFrom_ERC20AlwaysTruePropertiesNotTransferable()",
content=''' description="Cannot execute transferFrom.",
content="""
\t\tapprove(msg.sender, this.balanceOf(msg.sender)); \t\tapprove(msg.sender, this.balanceOf(msg.sender));
\t\ttransferFrom(msg.sender, msg.sender, this.balanceOf(msg.sender));''', \t\ttransferFrom(msg.sender, msg.sender, this.balanceOf(msg.sender));""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW, return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_constantBalance()', Property(
description='Cannot change the balance.', name="crytic_constantBalance()",
content=''' description="Cannot change the balance.",
\t\treturn this.balanceOf(crytic_user) == initialBalance_user && this.balanceOf(crytic_attacker) == initialBalance_attacker;''', content="""
\t\treturn this.balanceOf(crytic_user) == initialBalance_user && this.balanceOf(crytic_attacker) == initialBalance_attacker;""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_constantAllowance()', Property(
description='Cannot change the allowance.', name="crytic_constantAllowance()",
content=''' description="Cannot change the allowance.",
content="""
\t\treturn (this.allowance(crytic_user, crytic_attacker) == initialAllowance_user_attacker) && \t\treturn (this.allowance(crytic_user, crytic_attacker) == initialAllowance_user_attacker) &&
\t\t\t(this.allowance(crytic_attacker, crytic_attacker) == initialAllowance_attacker_attacker);''', \t\t\t(this.allowance(crytic_attacker, crytic_attacker) == initialAllowance_attacker_attacker);""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
] ]

@ -11,26 +11,32 @@ from slither.tools.properties.properties.properties import Property
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
def generate_truffle_test(contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses) -> str: def generate_truffle_test(
test_contract = f'Test{contract.name}{type_property}' contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses
filename_init = f'Initialization{test_contract}.js' ) -> str:
filename = f'{test_contract}.js' test_contract = f"Test{contract.name}{type_property}"
filename_init = f"Initialization{test_contract}.js"
filename = f"{test_contract}.js"
output_dir = Path(contract.slither.crytic_compile.target) output_dir = Path(contract.slither.crytic_compile.target)
generate_migration(test_contract, output_dir, addresses.owner) generate_migration(test_contract, output_dir, addresses.owner)
generate_unit_test(test_contract, generate_unit_test(
test_contract,
filename_init, filename_init,
ERC20_CONFIG, ERC20_CONFIG,
output_dir, output_dir,
addresses, addresses,
f'Check the constructor of {test_contract}') f"Check the constructor of {test_contract}",
)
generate_unit_test(test_contract, filename, unit_tests, output_dir, addresses,) generate_unit_test(
test_contract, filename, unit_tests, output_dir, addresses,
)
log_info = '\n' log_info = "\n"
log_info += 'To run the unit tests:\n' log_info += "To run the unit tests:\n"
log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename_init)}\n" log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename_init)}\n"
log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename)}\n" log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename)}\n"
return log_info return log_info

@ -36,4 +36,4 @@ class Property(NamedTuple):
def property_to_solidity(p: Property): def property_to_solidity(p: Property):
return f'\tfunction {p.name} public returns(bool){{{p.content}\n\t}}\n' return f"\tfunction {p.name} public returns(bool){{{p.content}\n\t}}\n"

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save