Merge slither/tools slither/visitors/expression from dev-0.7

pull/514/head
Josselin 5 years ago
parent 010d84125a
commit 4319bb3605
  1. 14
      slither/tools/demo/__main__.py
  2. 39
      slither/tools/erc_conformance/__main__.py
  3. 14
      slither/tools/erc_conformance/erc/erc20.py
  4. 99
      slither/tools/erc_conformance/erc/ercs.py
  5. 48
      slither/tools/kspec_coverage/__main__.py
  6. 95
      slither/tools/kspec_coverage/analysis.py
  7. 3
      slither/tools/kspec_coverage/kspec_coverage.py
  8. 21
      slither/tools/possible_paths/__main__.py
  9. 39
      slither/tools/possible_paths/possible_paths.py
  10. 98
      slither/tools/properties/__main__.py
  11. 8
      slither/tools/properties/addresses/address.py
  12. 6
      slither/tools/properties/platforms/echidna.py
  13. 68
      slither/tools/properties/platforms/truffle.py
  14. 130
      slither/tools/properties/properties/erc20.py
  15. 48
      slither/tools/properties/properties/ercs/erc20/properties/burn.py
  16. 135
      slither/tools/properties/properties/ercs/erc20/properties/initialization.py
  17. 27
      slither/tools/properties/properties/ercs/erc20/properties/mint.py
  18. 30
      slither/tools/properties/properties/ercs/erc20/properties/mint_and_burn.py
  19. 327
      slither/tools/properties/properties/ercs/erc20/properties/transfer.py
  20. 36
      slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py
  21. 2
      slither/tools/properties/properties/properties.py
  22. 69
      slither/tools/properties/solidity/generate_properties.py
  23. 20
      slither/tools/properties/utils.py
  24. 103
      slither/tools/similarity/__main__.py
  25. 8
      slither/tools/similarity/cache.py
  26. 167
      slither/tools/similarity/encode.py
  27. 19
      slither/tools/similarity/info.py
  28. 45
      slither/tools/similarity/plot.py
  29. 1
      slither/tools/similarity/similarity.py
  30. 25
      slither/tools/similarity/test.py
  31. 39
      slither/tools/similarity/train.py
  32. 115
      slither/tools/slither_format/__main__.py
  33. 120
      slither/tools/slither_format/slither_format.py
  34. 158
      slither/tools/upgradeability/__main__.py
  35. 94
      slither/tools/upgradeability/checks/abstract_checks.py
  36. 22
      slither/tools/upgradeability/checks/all_checks.py
  37. 55
      slither/tools/upgradeability/checks/constant.py
  38. 83
      slither/tools/upgradeability/checks/functions_ids.py
  39. 173
      slither/tools/upgradeability/checks/initialization.py
  40. 22
      slither/tools/upgradeability/checks/variable_initialization.py
  41. 108
      slither/tools/upgradeability/checks/variables_order.py
  42. 104
      slither/tools/upgradeability/utils/command_line.py
  43. 28
      slither/visitors/expression/constants_folding.py
  44. 7
      slither/visitors/expression/export_values.py
  45. 20
      slither/visitors/expression/expression.py
  46. 21
      slither/visitors/expression/expression_printer.py
  47. 12
      slither/visitors/expression/find_calls.py
  48. 8
      slither/visitors/expression/find_push.py
  49. 7
      slither/visitors/expression/has_conditional.py
  50. 10
      slither/visitors/expression/left_value.py
  51. 8
      slither/visitors/expression/read_var.py
  52. 10
      slither/visitors/expression/right_value.py
  53. 36
      slither/visitors/expression/write_var.py

@ -9,16 +9,17 @@ logging.getLogger("Slither").setLevel(logging.INFO)
logger = logging.getLogger("Slither-demo") logger = logging.getLogger("Slither-demo")
def parse_args(): def parse_args():
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
""" """
parser = argparse.ArgumentParser(description='Demo', parser = argparse.ArgumentParser(description="Demo", usage="slither-demo filename")
usage='slither-demo filename')
parser.add_argument('filename', parser.add_argument(
help='The filename of the contract or truffle directory to analyze.') "filename", help="The filename of the contract or truffle directory to analyze."
)
# Add default arguments from crytic-compile # Add default arguments from crytic-compile
cryticparser.init(parser) cryticparser.init(parser)
@ -32,7 +33,8 @@ def main():
# Perform slither analysis on the given filename # Perform slither analysis on the given filename
slither = Slither(args.filename, **vars(args)) slither = Slither(args.filename, **vars(args))
logger.info('Analysis done!') logger.info("Analysis done!")
if __name__ == '__main__': if __name__ == "__main__":
main() main()

@ -17,28 +17,29 @@ logger.setLevel(logging.INFO)
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setLevel(logging.INFO) ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s') formatter = logging.Formatter("%(message)s")
logger.addHandler(ch) logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter) logger.handlers[0].setFormatter(formatter)
logger.propagate = False logger.propagate = False
ADDITIONAL_CHECKS = { ADDITIONAL_CHECKS = {"ERC20": check_erc20}
"ERC20": check_erc20
}
def parse_args(): def parse_args():
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
""" """
parser = argparse.ArgumentParser(description='Check the ERC 20 conformance', parser = argparse.ArgumentParser(
usage='slither-erc project contractName') description="Check the ERC 20 conformance", usage="slither-erc project contractName"
)
parser.add_argument('project', parser.add_argument("project", help="The codebase to be tested.")
help='The codebase to be tested.')
parser.add_argument('contract_name', parser.add_argument(
help='The name of the contract. Specify the first case contract that follow the standard. Derived contracts will be checked.') "contract_name",
help="The name of the contract. Specify the first case contract that follow the standard. Derived contracts will be checked.",
)
parser.add_argument( parser.add_argument(
"--erc", "--erc",
@ -47,22 +48,26 @@ def parse_args():
default="erc20", default="erc20",
) )
parser.add_argument('--json', parser.add_argument(
help='Export the results as a JSON file ("--json -" to export to stdout)', "--json",
action='store', help='Export the results as a JSON file ("--json -" to export to stdout)',
default=False) action="store",
default=False,
)
# Add default arguments from crytic-compile # Add default arguments from crytic-compile
cryticparser.init(parser) cryticparser.init(parser)
return parser.parse_args() return parser.parse_args()
def _log_error(err, args): def _log_error(err, args):
if args.json: if args.json:
output_to_json(args.json, str(err), {"upgradeability-check": []}) output_to_json(args.json, str(err), {"upgradeability-check": []})
logger.error(err) logger.error(err)
def main(): def main():
args = parse_args() args = parse_args()
@ -76,7 +81,7 @@ def main():
contract = slither.get_contract_from_name(args.contract_name) contract = slither.get_contract_from_name(args.contract_name)
if not contract: if not contract:
err = f'Contract not found: {args.contract_name}' err = f"Contract not found: {args.contract_name}"
_log_error(err, args) _log_error(err, args)
return return
# First elem is the function, second is the event # First elem is the function, second is the event
@ -87,7 +92,7 @@ def main():
ADDITIONAL_CHECKS[args.erc.upper()](contract, ret) ADDITIONAL_CHECKS[args.erc.upper()](contract, ret)
else: else:
err = f'Incorrect ERC selected {args.erc}' err = f"Incorrect ERC selected {args.erc}"
_log_error(err, args) _log_error(err, args)
return return
@ -95,5 +100,5 @@ def main():
output_to_json(args.json, None, {"upgradeability-check": ret}) output_to_json(args.json, None, {"upgradeability-check": ret})
if __name__ == '__main__': if __name__ == "__main__":
main() main()

@ -6,21 +6,25 @@ logger = logging.getLogger("Slither-conformance")
def approval_race_condition(contract, ret): def approval_race_condition(contract, ret):
increaseAllowance = contract.get_function_from_signature('increaseAllowance(address,uint256)') increaseAllowance = contract.get_function_from_signature("increaseAllowance(address,uint256)")
if not increaseAllowance: if not increaseAllowance:
increaseAllowance = contract.get_function_from_signature('safeIncreaseAllowance(address,uint256)') increaseAllowance = contract.get_function_from_signature(
"safeIncreaseAllowance(address,uint256)"
)
if increaseAllowance: if increaseAllowance:
txt = f'\t[✓] {contract.name} has {increaseAllowance.full_name}' txt = f"\t[✓] {contract.name} has {increaseAllowance.full_name}"
logger.info(txt) logger.info(txt)
else: else:
txt = f'\t[ ] {contract.name} is not protected for the ERC20 approval race condition' txt = f"\t[ ] {contract.name} is not protected for the ERC20 approval race condition"
logger.info(txt) logger.info(txt)
lack_of_erc20_race_condition_protection = output.Output(txt) lack_of_erc20_race_condition_protection = output.Output(txt)
lack_of_erc20_race_condition_protection.add(contract) lack_of_erc20_race_condition_protection.add(contract)
ret["lack_of_erc20_race_condition_protection"].append(lack_of_erc20_race_condition_protection.data) ret["lack_of_erc20_race_condition_protection"].append(
lack_of_erc20_race_condition_protection.data
)
def check_erc20(contract, ret, explored=None): def check_erc20(contract, ret, explored=None):

@ -22,13 +22,15 @@ def _check_signature(erc_function, contract, ret):
# The check on state variable is needed until we have a better API to handle state variable getters # The check on state variable is needed until we have a better API to handle state variable getters
state_variable_as_function = contract.get_state_variable_from_name(name) state_variable_as_function = contract.get_state_variable_from_name(name)
if not state_variable_as_function or not state_variable_as_function.visibility in ['public', 'external']: if not state_variable_as_function or not state_variable_as_function.visibility in [
"public",
"external",
]:
txt = f'[ ] {sig} is missing {"" if required else "(optional)"}' txt = f'[ ] {sig} is missing {"" if required else "(optional)"}'
logger.info(txt) logger.info(txt)
missing_func = output.Output(txt, additional_fields={ missing_func = output.Output(
"function": sig, txt, additional_fields={"function": sig, "required": required}
"required": required )
})
missing_func.add(contract) missing_func.add(contract)
ret["missing_function"].append(missing_func.data) ret["missing_function"].append(missing_func.data)
return return
@ -38,10 +40,9 @@ def _check_signature(erc_function, contract, ret):
if types != parameters: if types != parameters:
txt = f'[ ] {sig} is missing {"" if required else "(optional)"}' txt = f'[ ] {sig} is missing {"" if required else "(optional)"}'
logger.info(txt) logger.info(txt)
missing_func = output.Output(txt, additional_fields={ missing_func = output.Output(
"function": sig, txt, additional_fields={"function": sig, "required": required}
"required": required )
})
missing_func.add(contract) missing_func.add(contract)
ret["missing_function"].append(missing_func.data) ret["missing_function"].append(missing_func.data)
return return
@ -53,45 +54,51 @@ def _check_signature(erc_function, contract, ret):
function_return_type = function.return_type function_return_type = function.return_type
function_view = function.view function_view = function.view
txt = f'[✓] {sig} is present' txt = f"[✓] {sig} is present"
logger.info(txt) logger.info(txt)
if function_return_type: if function_return_type:
function_return_type = ','.join([str(x) for x in function_return_type]) function_return_type = ",".join([str(x) for x in function_return_type])
if function_return_type == return_type: if function_return_type == return_type:
txt = f'\t[✓] {sig} -> () (correct return value)' txt = f"\t[✓] {sig} -> () (correct return value)"
logger.info(txt) logger.info(txt)
else: else:
txt = f'\t[ ] {sig} -> () should return {return_type}' txt = f"\t[ ] {sig} -> () should return {return_type}"
logger.info(txt) logger.info(txt)
incorrect_return = output.Output(txt, additional_fields={ incorrect_return = output.Output(
"expected_return_type": return_type, txt,
"actual_return_type": function_return_type additional_fields={
}) "expected_return_type": return_type,
"actual_return_type": function_return_type,
},
)
incorrect_return.add(function) incorrect_return.add(function)
ret["incorrect_return_type"].append(incorrect_return.data) ret["incorrect_return_type"].append(incorrect_return.data)
elif not return_type: elif not return_type:
txt = f'\t[✓] {sig} -> () (correct return type)' txt = f"\t[✓] {sig} -> () (correct return type)"
logger.info(txt) logger.info(txt)
else: else:
txt = f'\t[ ] {sig} -> () should return {return_type}' txt = f"\t[ ] {sig} -> () should return {return_type}"
logger.info(txt) logger.info(txt)
incorrect_return = output.Output(txt, additional_fields={ incorrect_return = output.Output(
"expected_return_type": return_type, txt,
"actual_return_type": function_return_type additional_fields={
}) "expected_return_type": return_type,
"actual_return_type": function_return_type,
},
)
incorrect_return.add(function) incorrect_return.add(function)
ret["incorrect_return_type"].append(incorrect_return.data) ret["incorrect_return_type"].append(incorrect_return.data)
if view: if view:
if function_view: if function_view:
txt = f'\t[✓] {sig} is view' txt = f"\t[✓] {sig} is view"
logger.info(txt) logger.info(txt)
else: else:
txt = f'\t[ ] {sig} should be view' txt = f"\t[ ] {sig} should be view"
logger.info(txt) logger.info(txt)
should_be_view = output.Output(txt) should_be_view = output.Output(txt)
@ -103,12 +110,12 @@ def _check_signature(erc_function, contract, ret):
event_sig = f'{event.name}({",".join(event.parameters)})' event_sig = f'{event.name}({",".join(event.parameters)})'
if not function: if not function:
txt = f'\t[ ] Must emit be view {event_sig}' txt = f"\t[ ] Must emit be view {event_sig}"
logger.info(txt) logger.info(txt)
missing_event_emmited = output.Output(txt, additional_fields={ missing_event_emmited = output.Output(
"missing_event": event_sig txt, additional_fields={"missing_event": event_sig}
}) )
missing_event_emmited.add(function) missing_event_emmited.add(function)
ret["missing_event_emmited"].append(missing_event_emmited.data) ret["missing_event_emmited"].append(missing_event_emmited.data)
@ -121,15 +128,15 @@ def _check_signature(erc_function, contract, ret):
event_found = True event_found = True
break break
if event_found: if event_found:
txt = f'\t[✓] {event_sig} is emitted' txt = f"\t[✓] {event_sig} is emitted"
logger.info(txt) logger.info(txt)
else: else:
txt = f'\t[ ] Must emit be view {event_sig}' txt = f"\t[ ] Must emit be view {event_sig}"
logger.info(txt) logger.info(txt)
missing_event_emmited = output.Output(txt, additional_fields={ missing_event_emmited = output.Output(
"missing_event": event_sig txt, additional_fields={"missing_event": event_sig}
}) )
missing_event_emmited.add(function) missing_event_emmited.add(function)
ret["missing_event_emmited"].append(missing_event_emmited.data) ret["missing_event_emmited"].append(missing_event_emmited.data)
@ -143,31 +150,27 @@ def _check_events(erc_event, contract, ret):
event = contract.get_event_from_signature(sig) event = contract.get_event_from_signature(sig)
if not event: if not event:
txt = f'[ ] {sig} is missing' txt = f"[ ] {sig} is missing"
logger.info(txt) logger.info(txt)
missing_event = output.Output(txt, additional_fields={ missing_event = output.Output(txt, additional_fields={"event": sig})
"event": sig
})
missing_event.add(contract) missing_event.add(contract)
ret["missing_event"].append(missing_event.data) ret["missing_event"].append(missing_event.data)
return return
txt = f'[✓] {sig} is present' txt = f"[✓] {sig} is present"
logger.info(txt) logger.info(txt)
for i, index in enumerate(indexes): for i, index in enumerate(indexes):
if index: if index:
if event.elems[i].indexed: if event.elems[i].indexed:
txt = f'\t[✓] parameter {i} is indexed' txt = f"\t[✓] parameter {i} is indexed"
logger.info(txt) logger.info(txt)
else: else:
txt = f'\t[ ] parameter {i} should be indexed' txt = f"\t[ ] parameter {i} should be indexed"
logger.info(txt) logger.info(txt)
missing_event_index = output.Output(txt, additional_fields={ missing_event_index = output.Output(txt, additional_fields={"missing_index": i})
"missing_index": i
})
missing_event_index.add_event(event) missing_event_index.add_event(event)
ret["missing_event_index"].append(missing_event_index.data) ret["missing_event_index"].append(missing_event_index.data)
@ -179,16 +182,16 @@ def generic_erc_checks(contract, erc_functions, erc_events, ret, explored=None):
explored.add(contract) explored.add(contract)
logger.info(f'# Check {contract.name}\n') logger.info(f"# Check {contract.name}\n")
logger.info(f'## Check functions') logger.info(f"## Check functions")
for erc_function in erc_functions: for erc_function in erc_functions:
_check_signature(erc_function, contract, ret) _check_signature(erc_function, contract, ret)
logger.info(f'\n## Check events') logger.info(f"\n## Check events")
for erc_event in erc_events: for erc_event in erc_events:
_check_events(erc_event, contract, ret) _check_events(erc_event, contract, ret)
logger.info('\n') logger.info("\n")
for derived_contract in contract.derived_contracts: for derived_contract in contract.derived_contracts:
generic_erc_checks(derived_contract, erc_functions, erc_events, ret, explored) generic_erc_checks(derived_contract, erc_functions, erc_events, ret, explored)

@ -11,35 +11,44 @@ logger.setLevel(logging.INFO)
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setLevel(logging.INFO) ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s') formatter = logging.Formatter("%(message)s")
logger.addHandler(ch) logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter) logger.handlers[0].setFormatter(formatter)
logger.propagate = False logger.propagate = False
def parse_args(): def parse_args():
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
""" """
parser = argparse.ArgumentParser(description='slither-kspec-coverage', parser = argparse.ArgumentParser(
usage='slither-kspec-coverage contract.sol kspec.md') description="slither-kspec-coverage", usage="slither-kspec-coverage contract.sol kspec.md"
)
parser.add_argument('contract', help='The filename of the contract or truffle directory to analyze.')
parser.add_argument('kspec', help='The filename of the Klab spec markdown for the analyzed contract(s)') parser.add_argument(
"contract", help="The filename of the contract or truffle directory to analyze."
parser.add_argument('--version', help='displays the current version', version='0.1.0',action='version') )
parser.add_argument('--json', parser.add_argument(
help='Export the results as a JSON file ("--json -" to export to stdout)', "kspec", help="The filename of the Klab spec markdown for the analyzed contract(s)"
action='store',
default=False
) )
cryticparser.init(parser) parser.add_argument(
"--version", help="displays the current version", version="0.1.0", action="version"
if len(sys.argv) < 2: )
parser.print_help(sys.stderr) parser.add_argument(
"--json",
help='Export the results as a JSON file ("--json -" to export to stdout)',
action="store",
default=False,
)
cryticparser.init(parser)
if len(sys.argv) < 2:
parser.print_help(sys.stderr)
sys.exit(1) sys.exit(1)
return parser.parse_args() return parser.parse_args()
@ -53,6 +62,7 @@ def main():
args = parse_args() args = parse_args()
kspec_coverage(args) kspec_coverage(args)
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

@ -7,25 +7,22 @@ from slither.utils.colors import yellow, green, red
from slither.utils import output from slither.utils import output
logging.basicConfig(level=logging.WARNING) logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger('Slither.kspec') logger = logging.getLogger("Slither.kspec")
def _refactor_type(type): def _refactor_type(type):
return { return {"uint": "uint256", "int": "int256"}.get(type, type)
'uint': 'uint256',
'int': 'int256'
}.get(type, type)
def _get_all_covered_kspec_functions(target): def _get_all_covered_kspec_functions(target):
# Create a set of our discovered functions which are covered # Create a set of our discovered functions which are covered
covered_functions = set() covered_functions = set()
BEHAVIOUR_PATTERN = re.compile('behaviour\s+(\S+)\s+of\s+(\S+)') BEHAVIOUR_PATTERN = re.compile("behaviour\s+(\S+)\s+of\s+(\S+)")
INTERFACE_PATTERN = re.compile('interface\s+([^\r\n]+)') INTERFACE_PATTERN = re.compile("interface\s+([^\r\n]+)")
# Read the file contents # Read the file contents
with open(target, 'r', encoding='utf8') as target_file: with open(target, "r", encoding="utf8") as target_file:
lines = target_file.readlines() lines = target_file.readlines()
# Loop for each line, if a line matches our behaviour regex, and the next one matches our interface regex, # Loop for each line, if a line matches our behaviour regex, and the next one matches our interface regex,
@ -38,10 +35,12 @@ def _get_all_covered_kspec_functions(target):
match = INTERFACE_PATTERN.match(lines[i + 1]) match = INTERFACE_PATTERN.match(lines[i + 1])
if match: if match:
function_full_name = match.groups()[0] function_full_name = match.groups()[0]
start, end = function_full_name.index('(') + 1, function_full_name.index(')') start, end = function_full_name.index("(") + 1, function_full_name.index(")")
function_arguments = function_full_name[start:end].split(',') function_arguments = function_full_name[start:end].split(",")
function_arguments = [_refactor_type(arg.strip().split(' ')[0]) for arg in function_arguments] function_arguments = [
function_full_name = function_full_name[:start] + ','.join(function_arguments) + ')' _refactor_type(arg.strip().split(" ")[0]) for arg in function_arguments
]
function_full_name = function_full_name[:start] + ",".join(function_arguments) + ")"
covered_functions.add((contract_name, function_full_name)) covered_functions.add((contract_name, function_full_name))
i += 1 i += 1
i += 1 i += 1
@ -50,14 +49,25 @@ def _get_all_covered_kspec_functions(target):
def _get_slither_functions(slither): def _get_slither_functions(slither):
# Use contract == contract_declarer to avoid dupplicate # Use contract == contract_declarer to avoid dupplicate
all_functions_declared = [f for f in slither.functions if (f.contract == f.contract_declarer all_functions_declared = [
and f.is_implemented f
and not f.is_constructor for f in slither.functions
and not f.is_constructor_variables)] if (
f.contract == f.contract_declarer
and f.is_implemented
and not f.is_constructor
and not f.is_constructor_variables
)
]
# Use list(set()) because same state variable instances can be shared accross contracts # Use list(set()) because same state variable instances can be shared accross contracts
# TODO: integrate state variables # TODO: integrate state variables
all_functions_declared += list(set([s for s in slither.state_variables if s.visibility in ['public', 'external']])) all_functions_declared += list(
slither_functions = {(function.contract.name, function.full_name): function for function in all_functions_declared} set([s for s in slither.state_variables if s.visibility in ["public", "external"]])
)
slither_functions = {
(function.contract.name, function.full_name): function
for function in all_functions_declared
}
return slither_functions return slither_functions
@ -110,35 +120,42 @@ def _run_coverage_analysis(args, slither, kspec_functions):
else: else:
kspec_missing.append(slither_func) kspec_missing.append(slither_func)
logger.info('## Check for functions coverage') logger.info("## Check for functions coverage")
json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json) json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json)
json_kspec_missing_functions = _generate_output([f for f in kspec_missing if isinstance(f, Function)], json_kspec_missing_functions = _generate_output(
"[ ] (Missing function)", [f for f in kspec_missing if isinstance(f, Function)],
red, "[ ] (Missing function)",
args.json) red,
json_kspec_missing_variables = _generate_output([f for f in kspec_missing if isinstance(f, Variable)], args.json,
"[ ] (Missing variable)", )
yellow, json_kspec_missing_variables = _generate_output(
args.json) [f for f in kspec_missing if isinstance(f, Variable)],
json_kspec_unresolved = _generate_output_unresolved(kspec_functions_unresolved, "[ ] (Missing variable)",
"[ ] (Unresolved)", yellow,
yellow, args.json,
args.json) )
json_kspec_unresolved = _generate_output_unresolved(
kspec_functions_unresolved, "[ ] (Unresolved)", yellow, args.json
)
# Handle unresolved kspecs # Handle unresolved kspecs
if args.json: if args.json:
output.output_to_json(args.json, None, { output.output_to_json(
"functions_present": json_kspec_present, args.json,
"functions_missing": json_kspec_missing_functions, None,
"variables_missing": json_kspec_missing_variables, {
"functions_unresolved": json_kspec_unresolved "functions_present": json_kspec_present,
}) "functions_missing": json_kspec_missing_functions,
"variables_missing": json_kspec_missing_variables,
"functions_unresolved": json_kspec_unresolved,
},
)
def run_analysis(args, slither, kspec): def run_analysis(args, slither, kspec):
# Get all of our kspec'd functions (tuple(contract_name, function_name)). # Get all of our kspec'd functions (tuple(contract_name, function_name)).
if ',' in kspec: if "," in kspec:
kspecs = kspec.split(',') kspecs = kspec.split(",")
kspec_functions = set() kspec_functions = set()
for kspec in kspecs: for kspec in kspecs:
kspec_functions |= _get_all_covered_kspec_functions(kspec) kspec_functions |= _get_all_covered_kspec_functions(kspec)

@ -1,6 +1,7 @@
from slither.tools.kspec_coverage.analysis import run_analysis from slither.tools.kspec_coverage.analysis import run_analysis
from slither import Slither from slither import Slither
def kspec_coverage(args): def kspec_coverage(args):
contract = args.contract contract = args.contract
@ -10,5 +11,3 @@ def kspec_coverage(args):
# Run the analysis on the Klab specs # Run the analysis on the Klab specs
run_analysis(args, slither, kspec) run_analysis(args, slither, kspec)

@ -9,18 +9,21 @@ from crytic_compile import cryticparser
logging.basicConfig() logging.basicConfig()
logging.getLogger("Slither").setLevel(logging.INFO) logging.getLogger("Slither").setLevel(logging.INFO)
def parse_args(): def parse_args():
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
""" """
parser = argparse.ArgumentParser(description='PossiblePaths', parser = argparse.ArgumentParser(
usage='possible_paths.py filename [contract.function targets]') description="PossiblePaths", usage="possible_paths.py filename [contract.function targets]"
)
parser.add_argument('filename', parser.add_argument(
help='The filename of the contract or truffle directory to analyze.') "filename", help="The filename of the contract or truffle directory to analyze."
)
parser.add_argument('targets', nargs='+') parser.add_argument("targets", nargs="+")
cryticparser.init(parser) cryticparser.init(parser)
@ -62,12 +65,16 @@ def main():
print("\n") print("\n")
# Format all function paths. # Format all function paths.
reaching_paths_str = [' -> '.join([f"{f.canonical_name}" for f in reaching_path]) for reaching_path in reaching_paths] reaching_paths_str = [
" -> ".join([f"{f.canonical_name}" for f in reaching_path])
for reaching_path in reaching_paths
]
# Print a sorted list of all function paths which can reach the targets. # Print a sorted list of all function paths which can reach the targets.
print(f"The following paths reach the specified targets:") print(f"The following paths reach the specified targets:")
for reaching_path in sorted(reaching_paths_str): for reaching_path in sorted(reaching_paths_str):
print(f"{reaching_path}\n") print(f"{reaching_path}\n")
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

@ -1,4 +1,5 @@
class ResolveFunctionException(Exception): pass class ResolveFunctionException(Exception):
pass
def resolve_function(slither, contract_name, function_name): def resolve_function(slither, contract_name, function_name):
@ -16,11 +17,15 @@ def resolve_function(slither, contract_name, function_name):
raise ResolveFunctionException(f"Could not resolve target contract: {contract_name}") raise ResolveFunctionException(f"Could not resolve target contract: {contract_name}")
# Obtain the target function # Obtain the target function
target_function = next((function for function in contract.functions if function.name == function_name), None) target_function = next(
(function for function in contract.functions if function.name == function_name), None
)
# Verify we have resolved the function specified. # Verify we have resolved the function specified.
if target_function is None: if target_function is None:
raise ResolveFunctionException(f"Could not resolve target function: {contract_name}.{function_name}") raise ResolveFunctionException(
f"Could not resolve target function: {contract_name}.{function_name}"
)
# Add the resolved function to the new list. # Add the resolved function to the new list.
return target_function return target_function
@ -44,17 +49,23 @@ def resolve_functions(slither, functions):
for item in functions: for item in functions:
if isinstance(item, str): if isinstance(item, str):
# If the item is a single string, we assume it is of form 'ContractName.FunctionName'. # If the item is a single string, we assume it is of form 'ContractName.FunctionName'.
parts = item.split('.') parts = item.split(".")
if len(parts) < 2: if len(parts) < 2:
raise ResolveFunctionException("Provided string descriptor must be of form 'ContractName.FunctionName'") raise ResolveFunctionException(
"Provided string descriptor must be of form 'ContractName.FunctionName'"
)
resolved.append(resolve_function(slither, parts[0], parts[1])) resolved.append(resolve_function(slither, parts[0], parts[1]))
elif isinstance(item, tuple): elif isinstance(item, tuple):
# If the item is a tuple, it should be a 2-tuple providing contract and function names. # If the item is a tuple, it should be a 2-tuple providing contract and function names.
if len(item) != 2: if len(item) != 2:
raise ResolveFunctionException("Provided tuple descriptor must provide a contract and function name.") raise ResolveFunctionException(
"Provided tuple descriptor must provide a contract and function name."
)
resolved.append(resolve_function(slither, item[0], item[1])) resolved.append(resolve_function(slither, item[0], item[1]))
else: else:
raise ResolveFunctionException(f"Unexpected function descriptor type to resolve in list: {type(item)}") raise ResolveFunctionException(
f"Unexpected function descriptor type to resolve in list: {type(item)}"
)
# Return the resolved list. # Return the resolved list.
return resolved return resolved
@ -66,9 +77,12 @@ def all_function_definitions(function):
:param function: The function to obtain all definitions at and beneath. :param function: The function to obtain all definitions at and beneath.
:return: Returns a list composed of the provided function definition and any base definitions. :return: Returns a list composed of the provided function definition and any base definitions.
""" """
return [function] + [f for c in function.contract.inheritance return [function] + [
for f in c.functions_and_modifiers_declared f
if f.full_name == function.full_name] for c in function.contract.inheritance
for f in c.functions_and_modifiers_declared
if f.full_name == function.full_name
]
def __find_target_paths(slither, target_function, current_path=[]): def __find_target_paths(slither, target_function, current_path=[]):
@ -102,7 +116,7 @@ def __find_target_paths(slither, target_function, current_path=[]):
results = results.union(path_results) results = results.union(path_results)
# If this path is external accessible from this point, we add the current path to the list. # If this path is external accessible from this point, we add the current path to the list.
if target_function.visibility in ['public', 'external'] and len(current_path) > 1: if target_function.visibility in ["public", "external"] and len(current_path) > 1:
results.add(tuple(current_path)) results.add(tuple(current_path))
return results return results
@ -122,6 +136,3 @@ def find_target_paths(slither, target_functions):
results = results.union(__find_target_paths(slither, target_function)) results = results.union(__find_target_paths(slither, target_function))
return results return results

@ -16,20 +16,21 @@ logging.getLogger("Slither").setLevel(logging.INFO)
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setLevel(logging.INFO) ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s') formatter = logging.Formatter("%(message)s")
logger.addHandler(ch) logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter) logger.handlers[0].setFormatter(formatter)
logger.propagate = False logger.propagate = False
def _all_scenarios(): def _all_scenarios():
txt = '\n' txt = "\n"
txt += '#################### ERC20 ####################\n' txt += "#################### ERC20 ####################\n"
for k, value in ERC20_PROPERTIES.items(): for k, value in ERC20_PROPERTIES.items():
txt += f'{k} - {value.description}\n' txt += f"{k} - {value.description}\n"
return txt return txt
def _all_properties(): def _all_properties():
table = MyPrettyTable(["Num", "Description", "Scenario"]) table = MyPrettyTable(["Num", "Description", "Scenario"])
idx = 0 idx = 0
@ -39,6 +40,7 @@ def _all_properties():
idx = idx + 1 idx = idx + 1
return table return table
class ListScenarios(argparse.Action): class ListScenarios(argparse.Action):
def __call__(self, parser, *args, **kwargs): def __call__(self, parser, *args, **kwargs):
logger.info(_all_scenarios()) logger.info(_all_scenarios())
@ -56,43 +58,51 @@ def parse_args():
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
""" """
parser = argparse.ArgumentParser(description='Demo', parser = argparse.ArgumentParser(
usage='slither-demo filename', description="Demo",
formatter_class=argparse.RawDescriptionHelpFormatter) usage="slither-demo filename",
formatter_class=argparse.RawDescriptionHelpFormatter,
parser.add_argument('filename', )
help='The filename of the contract or truffle directory to analyze.')
parser.add_argument(
parser.add_argument('--contract', "filename", help="The filename of the contract or truffle directory to analyze."
help='The targeted contract.') )
parser.add_argument('--scenario', parser.add_argument("--contract", help="The targeted contract.")
help=f'Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable',
default='Transferable') parser.add_argument(
"--scenario",
parser.add_argument('--list-scenarios', help=f"Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable",
help='List available scenarios', default="Transferable",
action=ListScenarios, )
nargs=0,
default=False) parser.add_argument(
"--list-scenarios",
parser.add_argument('--list-properties', help="List available scenarios",
help='List available properties', action=ListScenarios,
action=ListProperties, nargs=0,
nargs=0, default=False,
default=False) )
parser.add_argument('--address-owner', parser.add_argument(
help=f'Owner address. Default {OWNER_ADDRESS}', "--list-properties",
default=None) help="List available properties",
action=ListProperties,
parser.add_argument('--address-user', nargs=0,
help=f'Owner address. Default {USER_ADDRESS}', default=False,
default=None) )
parser.add_argument('--address-attacker', parser.add_argument(
help=f'Attacker address. Default {ATTACKER_ADDRESS}', "--address-owner", help=f"Owner address. Default {OWNER_ADDRESS}", default=None
default=None) )
parser.add_argument(
"--address-user", help=f"Owner address. Default {USER_ADDRESS}", default=None
)
parser.add_argument(
"--address-attacker", help=f"Attacker address. Default {ATTACKER_ADDRESS}", default=None
)
# Add default arguments from crytic-compile # Add default arguments from crytic-compile
cryticparser.init(parser) cryticparser.init(parser)
@ -116,9 +126,9 @@ def main():
contract = slither.contracts[0] contract = slither.contracts[0]
else: else:
if args.contract is None: if args.contract is None:
logger.error(f'Specify the target: --contract ContractName') logger.error(f"Specify the target: --contract ContractName")
else: else:
logger.error(f'{args.contract} not found') logger.error(f"{args.contract} not found")
return return
addresses = Addresses(args.address_owner, args.address_user, args.address_attacker) addresses = Addresses(args.address_owner, args.address_user, args.address_attacker)
@ -126,5 +136,5 @@ def main():
generate_erc20(contract, args.scenario, addresses) generate_erc20(contract, args.scenario, addresses)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

@ -8,8 +8,12 @@ ATTACKER_ADDRESS = "0xC5fdf4076b8F3A5357c5E395ab970B5B54098Fef"
class Addresses: class Addresses:
def __init__(
def __init__(self, owner: Optional[str] = None, user: Optional[str] = None, attacker: Optional[str] = None): self,
owner: Optional[str] = None,
user: Optional[str] = None,
attacker: Optional[str] = None,
):
self.owner = owner if owner else OWNER_ADDRESS self.owner = owner if owner else OWNER_ADDRESS
self.user = user if user else USER_ADDRESS self.user = user if user else USER_ADDRESS
self.attacker = attacker if attacker else ATTACKER_ADDRESS self.attacker = attacker if attacker else ATTACKER_ADDRESS

@ -11,11 +11,11 @@ def generate_echidna_config(output_dir: Path, addresses: Addresses) -> str:
:param addresses: :param addresses:
:return: :return:
""" """
content = 'prefix: crytic_\n' content = "prefix: crytic_\n"
content += f'deployer: "{addresses.owner}"\n' content += f'deployer: "{addresses.owner}"\n'
content += f'sender: ["{addresses.user}", "{addresses.attacker}"]\n' content += f'sender: ["{addresses.user}", "{addresses.attacker}"]\n'
content += f'psender: "{addresses.user}"\n' content += f'psender: "{addresses.user}"\n'
content += 'coverage: true\n' content += "coverage: true\n"
filename = 'echidna_config.yaml' filename = "echidna_config.yaml"
write_file(output_dir, filename, content) write_file(output_dir, filename, content)
return filename return filename

@ -7,21 +7,21 @@ from slither.tools.properties.addresses.address import Addresses
from slither.tools.properties.properties.properties import PropertyReturn, Property, PropertyCaller from slither.tools.properties.properties.properties import PropertyReturn, Property, PropertyCaller
from slither.tools.properties.utils import write_file from slither.tools.properties.utils import write_file
PATTERN_TRUFFLE_MIGRATION = re.compile('^[0-9]*_') PATTERN_TRUFFLE_MIGRATION = re.compile("^[0-9]*_")
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
def _extract_caller(p: PropertyCaller): def _extract_caller(p: PropertyCaller):
if p == PropertyCaller.OWNER: if p == PropertyCaller.OWNER:
return ['owner'] return ["owner"]
if p == PropertyCaller.SENDER: if p == PropertyCaller.SENDER:
return ['user'] return ["user"]
if p == PropertyCaller.ATTACKER: if p == PropertyCaller.ATTACKER:
return ['attacker'] return ["attacker"]
if p == PropertyCaller.ALL: if p == PropertyCaller.ALL:
return ['owner', 'user', 'attacker'] return ["owner", "user", "attacker"]
assert p == PropertyCaller.ANY assert p == PropertyCaller.ANY
return ['user'] return ["user"]
def _helpers(): def _helpers():
@ -31,7 +31,7 @@ def _helpers():
- catchRevertThrow: check if the call revert/throw - catchRevertThrow: check if the call revert/throw
:return: :return:
""" """
return ''' return """
async function catchRevertThrowReturnFalse(promise) { async function catchRevertThrowReturnFalse(promise) {
try { try {
const ret = await promise; const ret = await promise;
@ -61,12 +61,17 @@ async function catchRevertThrow(promise) {
} }
assert(false, "Expected revert/throw/or return false"); assert(false, "Expected revert/throw/or return false");
}; };
''' """
def generate_unit_test(test_contract: str, filename: str, def generate_unit_test(
unit_tests: List[Property], output_dir: Path, test_contract: str,
addresses: Addresses, assert_message: str = ''): filename: str,
unit_tests: List[Property],
output_dir: Path,
addresses: Addresses,
assert_message: str = "",
):
""" """
Generate unit tests files Generate unit tests files
:param test_contract: :param test_contract:
@ -88,37 +93,37 @@ def generate_unit_test(test_contract: str, filename: str,
content += f'\tlet attacker = "{addresses.attacker}";\n' content += f'\tlet attacker = "{addresses.attacker}";\n'
for unit_test in unit_tests: for unit_test in unit_tests:
content += f'\tit("{unit_test.description}", async () => {{\n' content += f'\tit("{unit_test.description}", async () => {{\n'
content += f'\t\tlet instance = await {test_contract}.deployed();\n' content += f"\t\tlet instance = await {test_contract}.deployed();\n"
callers = _extract_caller(unit_test.caller) callers = _extract_caller(unit_test.caller)
if unit_test.return_type == PropertyReturn.SUCCESS: if unit_test.return_type == PropertyReturn.SUCCESS:
for caller in callers: for caller in callers:
content += f'\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n' content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n"
if assert_message: if assert_message:
content += f'\t\tassert.equal(test_{caller}, true, "{assert_message}");\n' content += f'\t\tassert.equal(test_{caller}, true, "{assert_message}");\n'
else: else:
content += f'\t\tassert.equal(test_{caller}, true);\n' content += f"\t\tassert.equal(test_{caller}, true);\n"
elif unit_test.return_type == PropertyReturn.FAIL: elif unit_test.return_type == PropertyReturn.FAIL:
for caller in callers: for caller in callers:
content += f'\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n' content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n"
if assert_message: if assert_message:
content += f'\t\tassert.equal(test_{caller}, false, "{assert_message}");\n' content += f'\t\tassert.equal(test_{caller}, false, "{assert_message}");\n'
else: else:
content += f'\t\tassert.equal(test_{caller}, false);\n' content += f"\t\tassert.equal(test_{caller}, false);\n"
elif unit_test.return_type == PropertyReturn.FAIL_OR_THROW: elif unit_test.return_type == PropertyReturn.FAIL_OR_THROW:
for caller in callers: for caller in callers:
content += f'\t\tawait catchRevertThrowReturnFalse(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n' content += f"\t\tawait catchRevertThrowReturnFalse(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n"
elif unit_test.return_type == PropertyReturn.THROW: elif unit_test.return_type == PropertyReturn.THROW:
callers = _extract_caller(unit_test.caller) callers = _extract_caller(unit_test.caller)
for caller in callers: for caller in callers:
content += f'\t\tawait catchRevertThrow(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n' content += f"\t\tawait catchRevertThrow(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n"
content += '\t});\n' content += "\t});\n"
content += '});\n' content += "});\n"
output_dir = Path(output_dir, 'test') output_dir = Path(output_dir, "test")
output_dir.mkdir(exist_ok=True) output_dir.mkdir(exist_ok=True)
output_dir = Path(output_dir, 'crytic') output_dir = Path(output_dir, "crytic")
output_dir.mkdir(exist_ok=True) output_dir.mkdir(exist_ok=True)
write_file(output_dir, filename, content) write_file(output_dir, filename, content)
@ -133,28 +138,31 @@ def generate_migration(test_contract: str, output_dir: Path, owner_address: str)
:param owner_address: :param owner_address:
:return: :return:
""" """
content = f'''{test_contract} = artifacts.require("{test_contract}"); content = f"""{test_contract} = artifacts.require("{test_contract}");
module.exports = function(deployer) {{ module.exports = function(deployer) {{
deployer.deploy({test_contract}, {{from: "{owner_address}"}}); deployer.deploy({test_contract}, {{from: "{owner_address}"}});
}}; }};
''' """
output_dir = Path(output_dir, 'migrations') output_dir = Path(output_dir, "migrations")
output_dir.mkdir(exist_ok=True) output_dir.mkdir(exist_ok=True)
migration_files = [js_file for js_file in output_dir.iterdir() if js_file.suffix == '.js' migration_files = [
and PATTERN_TRUFFLE_MIGRATION.match(js_file.name)] js_file
for js_file in output_dir.iterdir()
if js_file.suffix == ".js" and PATTERN_TRUFFLE_MIGRATION.match(js_file.name)
]
idx = len(migration_files) idx = len(migration_files)
filename = f'{idx + 1}_{test_contract}.js' filename = f"{idx + 1}_{test_contract}.js"
potential_previous_filename = f'{idx}_{test_contract}.js' potential_previous_filename = f"{idx}_{test_contract}.js"
for m in migration_files: for m in migration_files:
if m.name == potential_previous_filename: if m.name == potential_previous_filename:
write_file(output_dir, potential_previous_filename, content) write_file(output_dir, potential_previous_filename, content)
return return
if test_contract in m.name: if test_contract in m.name:
logger.error(f'Potential conflicts with {m.name}') logger.error(f"Potential conflicts with {m.name}")
write_file(output_dir, filename, content) write_file(output_dir, filename, content)

@ -12,27 +12,37 @@ from slither.tools.properties.platforms.echidna import generate_echidna_config
from slither.tools.properties.properties.ercs.erc20.properties.burn import ERC20_NotBurnable from slither.tools.properties.properties.ercs.erc20.properties.burn import ERC20_NotBurnable
from slither.tools.properties.properties.ercs.erc20.properties.initialization import ERC20_CONFIG from slither.tools.properties.properties.ercs.erc20.properties.initialization import ERC20_CONFIG
from slither.tools.properties.properties.ercs.erc20.properties.mint import ERC20_NotMintable from slither.tools.properties.properties.ercs.erc20.properties.mint import ERC20_NotMintable
from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import ERC20_NotMintableNotBurnable from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import (
from slither.tools.properties.properties.ercs.erc20.properties.transfer import ERC20_Transferable, ERC20_Pausable ERC20_NotMintableNotBurnable,
)
from slither.tools.properties.properties.ercs.erc20.properties.transfer import (
ERC20_Transferable,
ERC20_Pausable,
)
from slither.tools.properties.properties.ercs.erc20.unit_tests.truffle import generate_truffle_test from slither.tools.properties.properties.ercs.erc20.unit_tests.truffle import generate_truffle_test
from slither.tools.properties.properties.properties import property_to_solidity, Property from slither.tools.properties.properties.properties import property_to_solidity, Property
from slither.tools.properties.solidity.generate_properties import generate_solidity_properties, generate_test_contract, \ from slither.tools.properties.solidity.generate_properties import (
generate_solidity_interface generate_solidity_properties,
generate_test_contract,
generate_solidity_interface,
)
from slither.utils.colors import red, green from slither.utils.colors import red, green
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
PropertyDescription = namedtuple('PropertyDescription', ['properties', 'description']) PropertyDescription = namedtuple("PropertyDescription", ["properties", "description"])
ERC20_PROPERTIES = { ERC20_PROPERTIES = {
"Transferable": PropertyDescription(ERC20_Transferable, 'Test the correct tokens transfer'), "Transferable": PropertyDescription(ERC20_Transferable, "Test the correct tokens transfer"),
"Pausable": PropertyDescription(ERC20_Pausable, 'Test the pausable functionality'), "Pausable": PropertyDescription(ERC20_Pausable, "Test the pausable functionality"),
"NotMintable": PropertyDescription(ERC20_NotMintable, 'Test that no one can mint tokens'), "NotMintable": PropertyDescription(ERC20_NotMintable, "Test that no one can mint tokens"),
"NotMintableNotBurnable": PropertyDescription(ERC20_NotMintableNotBurnable, "NotMintableNotBurnable": PropertyDescription(
'Test that no one can mint or burn tokens'), ERC20_NotMintableNotBurnable, "Test that no one can mint or burn tokens"
"NotBurnable": PropertyDescription(ERC20_NotBurnable, 'Test that no one can burn tokens'), ),
"Burnable": PropertyDescription(ERC20_NotBurnable, "NotBurnable": PropertyDescription(ERC20_NotBurnable, "Test that no one can burn tokens"),
'Test the burn of tokens. Require the "burn(address) returns()" function') "Burnable": PropertyDescription(
ERC20_NotBurnable, 'Test the burn of tokens. Require the "burn(address) returns()" function'
),
} }
@ -54,7 +64,7 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses)
:return: :return:
""" """
if contract.slither.crytic_compile.type not in [PlatformType.TRUFFLE, PlatformType.SOLC]: if contract.slither.crytic_compile.type not in [PlatformType.TRUFFLE, PlatformType.SOLC]:
logging.error(f'{contract.slither.crytic_compile.type} not yet supported by slither-prop') logging.error(f"{contract.slither.crytic_compile.type} not yet supported by slither-prop")
return return
# Check if the contract is an ERC20 contract and if the functions have the correct visibility # Check if the contract is an ERC20 contract and if the functions have the correct visibility
@ -65,7 +75,9 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses)
properties = ERC20_PROPERTIES.get(type_property, None) properties = ERC20_PROPERTIES.get(type_property, None)
if properties is None: if properties is None:
logger.error(f'{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}') logger.error(
f"{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}"
)
return return
properties = properties.properties properties = properties.properties
@ -78,51 +90,53 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses)
# Generate the contract containing the properties # Generate the contract containing the properties
generate_solidity_interface(output_dir, addresses) generate_solidity_interface(output_dir, addresses)
property_file = generate_solidity_properties(contract, type_property, solidity_properties, output_dir) property_file = generate_solidity_properties(
contract, type_property, solidity_properties, output_dir
)
# Generate the Test contract # Generate the Test contract
initialization_recommendation = _initialization_recommendation(type_property) initialization_recommendation = _initialization_recommendation(type_property)
contract_filename, contract_name = generate_test_contract(contract, contract_filename, contract_name = generate_test_contract(
type_property, contract, type_property, output_dir, property_file, initialization_recommendation
output_dir, )
property_file,
initialization_recommendation)
# Generate Echidna config file # Generate Echidna config file
echidna_config_filename = generate_echidna_config(Path(contract.slither.crytic_compile.target).parent, addresses) echidna_config_filename = generate_echidna_config(
Path(contract.slither.crytic_compile.target).parent, addresses
)
unit_test_info = '' unit_test_info = ""
# If truffle, generate unit tests # If truffle, generate unit tests
if contract.slither.crytic_compile.type == PlatformType.TRUFFLE: if contract.slither.crytic_compile.type == PlatformType.TRUFFLE:
unit_test_info = generate_truffle_test(contract, type_property, unit_tests, addresses) unit_test_info = generate_truffle_test(contract, type_property, unit_tests, addresses)
logger.info('################################################') logger.info("################################################")
logger.info(green(f'Update the constructor in {Path(output_dir, contract_filename)}')) logger.info(green(f"Update the constructor in {Path(output_dir, contract_filename)}"))
if unit_test_info: if unit_test_info:
logger.info(green(unit_test_info)) logger.info(green(unit_test_info))
logger.info(green('To run Echidna:')) logger.info(green("To run Echidna:"))
txt = f'\t echidna-test {contract.slither.crytic_compile.target} ' txt = f"\t echidna-test {contract.slither.crytic_compile.target} "
txt += f'--contract {contract_name} --config {echidna_config_filename}' txt += f"--contract {contract_name} --config {echidna_config_filename}"
logger.info(green(txt)) logger.info(green(txt))
def _initialization_recommendation(type_property: str) -> str: def _initialization_recommendation(type_property: str) -> str:
content = '' content = ""
content += '\t\t// Add below a minimal configuration:\n' content += "\t\t// Add below a minimal configuration:\n"
content += '\t\t// - crytic_owner must have some tokens \n' content += "\t\t// - crytic_owner must have some tokens \n"
content += '\t\t// - crytic_user must have some tokens \n' content += "\t\t// - crytic_user must have some tokens \n"
content += '\t\t// - crytic_attacker must have some tokens \n' content += "\t\t// - crytic_attacker must have some tokens \n"
if type_property in ['Pausable']: if type_property in ["Pausable"]:
content += '\t\t// - The contract must be paused \n' content += "\t\t// - The contract must be paused \n"
if type_property in ['NotMintable', 'NotMintableNotBurnable']: if type_property in ["NotMintable", "NotMintableNotBurnable"]:
content += '\t\t// - The contract must not be mintable \n' content += "\t\t// - The contract must not be mintable \n"
if type_property in ['NotBurnable', 'NotMintableNotBurnable']: if type_property in ["NotBurnable", "NotMintableNotBurnable"]:
content += '\t\t// - The contract must not be burnable \n' content += "\t\t// - The contract must not be burnable \n"
content += '\n' content += "\n"
content += '\n' content += "\n"
return content return content
@ -130,44 +144,44 @@ def _initialization_recommendation(type_property: str) -> str:
# TODO: move this to crytic-compile # TODO: move this to crytic-compile
def _platform_to_output_dir(platform: AbstractPlatform) -> Path: def _platform_to_output_dir(platform: AbstractPlatform) -> Path:
if platform.TYPE == PlatformType.TRUFFLE: if platform.TYPE == PlatformType.TRUFFLE:
return Path(platform.target, 'contracts', 'crytic') return Path(platform.target, "contracts", "crytic")
if platform.TYPE == PlatformType.SOLC: if platform.TYPE == PlatformType.SOLC:
return Path(platform.target).parent return Path(platform.target).parent
def _check_compatibility(contract): def _check_compatibility(contract):
errors = '' errors = ""
if not contract.is_erc20(): if not contract.is_erc20():
errors = f'{contract} is not ERC20 compliant. Consider checking the contract with slither-check-erc' errors = f"{contract} is not ERC20 compliant. Consider checking the contract with slither-check-erc"
return errors return errors
transfer = contract.get_function_from_signature('transfer(address,uint256)') transfer = contract.get_function_from_signature("transfer(address,uint256)")
if transfer.visibility != 'public': if transfer.visibility != "public":
errors = f'slither-prop requires {transfer.canonical_name} to be public. Please change the visibility' errors = f"slither-prop requires {transfer.canonical_name} to be public. Please change the visibility"
transfer_from = contract.get_function_from_signature('transferFrom(address,address,uint256)') transfer_from = contract.get_function_from_signature("transferFrom(address,address,uint256)")
if transfer_from.visibility != 'public': if transfer_from.visibility != "public":
if errors: if errors:
errors += '\n' errors += "\n"
errors += f'slither-prop requires {transfer_from.canonical_name} to be public. Please change the visibility' errors += f"slither-prop requires {transfer_from.canonical_name} to be public. Please change the visibility"
approve = contract.get_function_from_signature('approve(address,uint256)') approve = contract.get_function_from_signature("approve(address,uint256)")
if approve.visibility != 'public': if approve.visibility != "public":
if errors: if errors:
errors += '\n' errors += "\n"
errors += f'slither-prop requires {approve.canonical_name} to be public. Please change the visibility' errors += f"slither-prop requires {approve.canonical_name} to be public. Please change the visibility"
return errors return errors
def _get_properties(contract, properties: List[Property]) -> Tuple[str, List[Property]]: def _get_properties(contract, properties: List[Property]) -> Tuple[str, List[Property]]:
solidity_properties = '' solidity_properties = ""
if contract.slither.crytic_compile.type == PlatformType.TRUFFLE: if contract.slither.crytic_compile.type == PlatformType.TRUFFLE:
solidity_properties += '\n'.join([property_to_solidity(p) for p in ERC20_CONFIG]) solidity_properties += "\n".join([property_to_solidity(p) for p in ERC20_CONFIG])
solidity_properties += '\n'.join([property_to_solidity(p) for p in properties]) solidity_properties += "\n".join([property_to_solidity(p) for p in properties])
unit_tests = [p for p in properties if p.is_unit_test] unit_tests = [p for p in properties if p.is_unit_test]
return solidity_properties, unit_tests return solidity_properties, unit_tests

@ -1,30 +1,38 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_NotBurnable = [ ERC20_NotBurnable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotBurnable()', Property(
description='The total supply does not decrease.', name="crytic_supply_constant_ERC20PropertiesNotBurnable()",
content=''' description="The total supply does not decrease.",
\t\treturn initialTotalSupply == this.totalSupply();''', content="""
type=PropertyType.MEDIUM_SEVERITY, \t\treturn initialTotalSupply == this.totalSupply();""",
return_type=PropertyReturn.SUCCESS, type=PropertyType.MEDIUM_SEVERITY,
is_unit_test=True, return_type=PropertyReturn.SUCCESS,
is_property_test=True, is_unit_test=True,
caller=PropertyCaller.ANY), is_property_test=True,
caller=PropertyCaller.ANY,
),
] ]
# Require burn(address) returns() # Require burn(address) returns()
ERC20_Burnable = [ ERC20_Burnable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotBurnable()', Property(
description='Cannot burn more than available balance', name="crytic_supply_constant_ERC20PropertiesNotBurnable()",
content=''' description="Cannot burn more than available balance",
content="""
\t\tuint balance = balanceOf(msg.sender); \t\tuint balance = balanceOf(msg.sender);
\t\tburn(balance + 1); \t\tburn(balance + 1);
\t\treturn false;''', \t\treturn false;""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.THROW, return_type=PropertyReturn.THROW,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL) caller=PropertyCaller.ALL,
)
] ]

@ -1,65 +1,76 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_CONFIG = [ ERC20_CONFIG = [
Property(
Property(name='init_total_supply()', name="init_total_supply()",
description='The total supply is correctly initialized.', description="The total supply is correctly initialized.",
content=''' content="""
\t\treturn this.totalSupply() >= 0 && this.totalSupply() == initialTotalSupply;''', \t\treturn this.totalSupply() >= 0 && this.totalSupply() == initialTotalSupply;""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=False, is_property_test=False,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
),
Property(name='init_owner_balance()', Property(
description="Owner's balance is correctly initialized.", name="init_owner_balance()",
content=''' description="Owner's balance is correctly initialized.",
\t\treturn initialBalance_owner == this.balanceOf(crytic_owner);''', content="""
type=PropertyType.CODE_QUALITY, \t\treturn initialBalance_owner == this.balanceOf(crytic_owner);""",
return_type=PropertyReturn.SUCCESS, type=PropertyType.CODE_QUALITY,
is_unit_test=True, return_type=PropertyReturn.SUCCESS,
is_property_test=False, is_unit_test=True,
caller=PropertyCaller.ANY), is_property_test=False,
caller=PropertyCaller.ANY,
Property(name='init_user_balance()', ),
description="User's balance is correctly initialized.", Property(
content=''' name="init_user_balance()",
\t\treturn initialBalance_user == this.balanceOf(crytic_user);''', description="User's balance is correctly initialized.",
type=PropertyType.CODE_QUALITY, content="""
return_type=PropertyReturn.SUCCESS, \t\treturn initialBalance_user == this.balanceOf(crytic_user);""",
is_unit_test=True, type=PropertyType.CODE_QUALITY,
is_property_test=False, return_type=PropertyReturn.SUCCESS,
caller=PropertyCaller.ANY), is_unit_test=True,
is_property_test=False,
Property(name='init_attacker_balance()', caller=PropertyCaller.ANY,
description="Attacker's balance is correctly initialized.", ),
content=''' Property(
\t\treturn initialBalance_attacker == this.balanceOf(crytic_attacker);''', name="init_attacker_balance()",
type=PropertyType.CODE_QUALITY, description="Attacker's balance is correctly initialized.",
return_type=PropertyReturn.SUCCESS, content="""
is_unit_test=True, \t\treturn initialBalance_attacker == this.balanceOf(crytic_attacker);""",
is_property_test=False, type=PropertyType.CODE_QUALITY,
caller=PropertyCaller.ANY), return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
Property(name='init_caller_balance()', is_property_test=False,
description="All the users have a positive balance.", caller=PropertyCaller.ANY,
content=''' ),
\t\treturn this.balanceOf(msg.sender) >0 ;''', Property(
type=PropertyType.CODE_QUALITY, name="init_caller_balance()",
return_type=PropertyReturn.SUCCESS, description="All the users have a positive balance.",
is_unit_test=True, content="""
is_property_test=False, \t\treturn this.balanceOf(msg.sender) >0 ;""",
caller=PropertyCaller.ALL), type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ALL,
),
# Note: there is a potential overflow on the addition, but we dont consider it # Note: there is a potential overflow on the addition, but we dont consider it
Property(name='init_total_supply_is_balances()', Property(
description="The total supply is the user and owner balance.", name="init_total_supply_is_balances()",
content=''' description="The total supply is the user and owner balance.",
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) == this.totalSupply();''', content="""
type=PropertyType.CODE_QUALITY, \t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) == this.totalSupply();""",
return_type=PropertyReturn.SUCCESS, type=PropertyType.CODE_QUALITY,
is_unit_test=True, return_type=PropertyReturn.SUCCESS,
is_property_test=False, is_unit_test=True,
caller=PropertyCaller.ANY), is_property_test=False,
] caller=PropertyCaller.ANY,
),
]

@ -1,13 +1,20 @@
from slither.tools.properties.properties.properties import PropertyType, PropertyReturn, Property, PropertyCaller from slither.tools.properties.properties.properties import (
PropertyType,
PropertyReturn,
Property,
PropertyCaller,
)
ERC20_NotMintable = [ ERC20_NotMintable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotMintable()', Property(
description='The total supply does not increase.', name="crytic_supply_constant_ERC20PropertiesNotMintable()",
content=''' description="The total supply does not increase.",
\t\treturn initialTotalSupply >= totalSupply();''', content="""
type=PropertyType.MEDIUM_SEVERITY, \t\treturn initialTotalSupply >= totalSupply();""",
return_type=PropertyReturn.SUCCESS, type=PropertyType.MEDIUM_SEVERITY,
is_unit_test=True, return_type=PropertyReturn.SUCCESS,
is_property_test=True, is_unit_test=True,
caller=PropertyCaller.ANY), is_property_test=True,
caller=PropertyCaller.ANY,
),
] ]

@ -1,14 +1,20 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_NotMintableNotBurnable = [ ERC20_NotMintableNotBurnable = [
Property(
Property(name='crytic_supply_constant_ERC20PropertiesNotMintableNotBurnable()', name="crytic_supply_constant_ERC20PropertiesNotMintableNotBurnable()",
description='The total supply does not change.', description="The total supply does not change.",
content=''' content="""
\t\treturn initialTotalSupply == this.totalSupply();''', \t\treturn initialTotalSupply == this.totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
] ),
]

@ -1,96 +1,108 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_Transferable = [ ERC20_Transferable = [
Property(
Property(name='crytic_zero_always_empty_ERC20Properties()', name="crytic_zero_always_empty_ERC20Properties()",
description='The address 0x0 should not receive tokens.', description="The address 0x0 should not receive tokens.",
content=''' content="""
\t\treturn this.balanceOf(address(0x0)) == 0;''', \t\treturn this.balanceOf(address(0x0)) == 0;""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ANY), caller=PropertyCaller.ANY,
),
Property(name='crytic_approve_overwrites()', Property(
description='Allowance can be changed.', name="crytic_approve_overwrites()",
content=''' description="Allowance can be changed.",
content="""
\t\tbool approve_return; \t\tbool approve_return;
\t\tapprove_return = approve(crytic_user, 10); \t\tapprove_return = approve(crytic_user, 10);
\t\trequire(approve_return); \t\trequire(approve_return);
\t\tapprove_return = approve(crytic_user, 20); \t\tapprove_return = approve(crytic_user, 20);
\t\trequire(approve_return); \t\trequire(approve_return);
\t\treturn this.allowance(msg.sender, crytic_user) == 20;''', \t\treturn this.allowance(msg.sender, crytic_user) == 20;""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_less_than_total_ERC20Properties()', Property(
description='Balance of one user must be less or equal to the total supply.', name="crytic_less_than_total_ERC20Properties()",
content=''' description="Balance of one user must be less or equal to the total supply.",
\t\treturn this.balanceOf(msg.sender) <= totalSupply();''', content="""
type=PropertyType.MEDIUM_SEVERITY, \t\treturn this.balanceOf(msg.sender) <= totalSupply();""",
return_type=PropertyReturn.SUCCESS, type=PropertyType.MEDIUM_SEVERITY,
is_unit_test=True, return_type=PropertyReturn.SUCCESS,
is_property_test=True, is_unit_test=True,
caller=PropertyCaller.ALL), is_property_test=True,
caller=PropertyCaller.ALL,
Property(name='crytic_totalSupply_consistant_ERC20Properties()', ),
description='Balance of the crytic users must be less or equal to the total supply.', Property(
content=''' name="crytic_totalSupply_consistant_ERC20Properties()",
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) <= totalSupply();''', description="Balance of the crytic users must be less or equal to the total supply.",
type=PropertyType.MEDIUM_SEVERITY, content="""
return_type=PropertyReturn.SUCCESS, \t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) <= totalSupply();""",
is_unit_test=True, type=PropertyType.MEDIUM_SEVERITY,
is_property_test=True, return_type=PropertyReturn.SUCCESS,
caller=PropertyCaller.ANY), is_unit_test=True,
is_property_test=True,
Property(name='crytic_revert_transfer_to_zero_ERC20PropertiesTransferable()', caller=PropertyCaller.ANY,
description='No one should be able to send tokens to the address 0x0 (transfer).', ),
content=''' Property(
name="crytic_revert_transfer_to_zero_ERC20PropertiesTransferable()",
description="No one should be able to send tokens to the address 0x0 (transfer).",
content="""
\t\tif (this.balanceOf(msg.sender) == 0){ \t\tif (this.balanceOf(msg.sender) == 0){
\t\t\trevert(); \t\t\trevert();
\t\t} \t\t}
\t\treturn transfer(address(0x0), this.balanceOf(msg.sender));''', \t\treturn transfer(address(0x0), this.balanceOf(msg.sender));""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.FAIL_OR_THROW, return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_revert_transferFrom_to_zero_ERC20PropertiesTransferable()', Property(
description='No one should be able to send tokens to the address 0x0 (transferFrom).', name="crytic_revert_transferFrom_to_zero_ERC20PropertiesTransferable()",
content=''' description="No one should be able to send tokens to the address 0x0 (transferFrom).",
content="""
\t\tuint balance = this.balanceOf(msg.sender); \t\tuint balance = this.balanceOf(msg.sender);
\t\tif (balance == 0){ \t\tif (balance == 0){
\t\t\trevert(); \t\t\trevert();
\t\t} \t\t}
\t\tapprove(msg.sender, balance); \t\tapprove(msg.sender, balance);
\t\treturn transferFrom(msg.sender, address(0x0), this.balanceOf(msg.sender));''', \t\treturn transferFrom(msg.sender, address(0x0), this.balanceOf(msg.sender));""",
type=PropertyType.CODE_QUALITY, type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.FAIL_OR_THROW, return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_self_transferFrom_ERC20PropertiesTransferable()', Property(
description='Self transferFrom works.', name="crytic_self_transferFrom_ERC20PropertiesTransferable()",
content=''' description="Self transferFrom works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender); \t\tuint balance = this.balanceOf(msg.sender);
\t\tbool approve_return = approve(msg.sender, balance); \t\tbool approve_return = approve(msg.sender, balance);
\t\tbool transfer_return = transferFrom(msg.sender, msg.sender, balance); \t\tbool transfer_return = transferFrom(msg.sender, msg.sender, balance);
\t\treturn (this.balanceOf(msg.sender) == balance) && approve_return && transfer_return;''', \t\treturn (this.balanceOf(msg.sender) == balance) && approve_return && transfer_return;""",
type=PropertyType.HIGH_SEVERITY, type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_self_transferFrom_to_other_ERC20PropertiesTransferable()', Property(
description='transferFrom works.', name="crytic_self_transferFrom_to_other_ERC20PropertiesTransferable()",
content=''' description="transferFrom works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender); \t\tuint balance = this.balanceOf(msg.sender);
\t\tbool approve_return = approve(msg.sender, balance); \t\tbool approve_return = approve(msg.sender, balance);
\t\taddress other = crytic_user; \t\taddress other = crytic_user;
@ -98,29 +110,30 @@ ERC20_Transferable = [
\t\t\tother = crytic_owner; \t\t\tother = crytic_owner;
\t\t} \t\t}
\t\tbool transfer_return = transferFrom(msg.sender, other, balance); \t\tbool transfer_return = transferFrom(msg.sender, other, balance);
\t\treturn (this.balanceOf(msg.sender) == 0) && approve_return && transfer_return;''', \t\treturn (this.balanceOf(msg.sender) == 0) && approve_return && transfer_return;""",
type=PropertyType.HIGH_SEVERITY, type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(
Property(name='crytic_self_transfer_ERC20PropertiesTransferable()', name="crytic_self_transfer_ERC20PropertiesTransferable()",
description='Self transfer works.', description="Self transfer works.",
content=''' content="""
\t\tuint balance = this.balanceOf(msg.sender); \t\tuint balance = this.balanceOf(msg.sender);
\t\tbool transfer_return = transfer(msg.sender, balance); \t\tbool transfer_return = transfer(msg.sender, balance);
\t\treturn (this.balanceOf(msg.sender) == balance) && transfer_return;''', \t\treturn (this.balanceOf(msg.sender) == balance) && transfer_return;""",
type=PropertyType.HIGH_SEVERITY, type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_transfer_to_other_ERC20PropertiesTransferable()', Property(
description='transfer works.', name="crytic_transfer_to_other_ERC20PropertiesTransferable()",
content=''' description="transfer works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender); \t\tuint balance = this.balanceOf(msg.sender);
\t\taddress other = crytic_user; \t\taddress other = crytic_user;
\t\tif (other == msg.sender) { \t\tif (other == msg.sender) {
@ -130,74 +143,76 @@ ERC20_Transferable = [
\t\t\tbool transfer_other = transfer(other, 1); \t\t\tbool transfer_other = transfer(other, 1);
\t\t\treturn (this.balanceOf(msg.sender) == balance-1) && (this.balanceOf(other) >= 1) && transfer_other; \t\t\treturn (this.balanceOf(msg.sender) == balance-1) && (this.balanceOf(other) >= 1) && transfer_other;
\t\t} \t\t}
\t\treturn true;''', \t\treturn true;""",
type=PropertyType.HIGH_SEVERITY, type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_revert_transfer_to_user_ERC20PropertiesTransferable()', Property(
description='Cannot transfer more than the balance.', name="crytic_revert_transfer_to_user_ERC20PropertiesTransferable()",
content=''' description="Cannot transfer more than the balance.",
content="""
\t\tuint balance = this.balanceOf(msg.sender); \t\tuint balance = this.balanceOf(msg.sender);
\t\tif (balance == (2 ** 256 - 1)) \t\tif (balance == (2 ** 256 - 1))
\t\t\treturn true; \t\t\treturn true;
\t\tbool transfer_other = transfer(crytic_user, balance+1); \t\tbool transfer_other = transfer(crytic_user, balance+1);
\t\treturn transfer_other;''', \t\treturn transfer_other;""",
type=PropertyType.HIGH_SEVERITY, type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW, return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
] ]
ERC20_Pausable = [ ERC20_Pausable = [
Property(
Property(name='crytic_revert_transfer_ERC20AlwaysTruePropertiesNotTransferable()', name="crytic_revert_transfer_ERC20AlwaysTruePropertiesNotTransferable()",
description='Cannot transfer.', description="Cannot transfer.",
content=''' content="""
\t\treturn transfer(crytic_user, this.balanceOf(msg.sender));''', \t\treturn transfer(crytic_user, this.balanceOf(msg.sender));""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW, return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_revert_transferFrom_ERC20AlwaysTruePropertiesNotTransferable()', Property(
description='Cannot execute transferFrom.', name="crytic_revert_transferFrom_ERC20AlwaysTruePropertiesNotTransferable()",
content=''' description="Cannot execute transferFrom.",
content="""
\t\tapprove(msg.sender, this.balanceOf(msg.sender)); \t\tapprove(msg.sender, this.balanceOf(msg.sender));
\t\ttransferFrom(msg.sender, msg.sender, this.balanceOf(msg.sender));''', \t\ttransferFrom(msg.sender, msg.sender, this.balanceOf(msg.sender));""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW, return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
Property(name='crytic_constantBalance()', Property(
description='Cannot change the balance.', name="crytic_constantBalance()",
content=''' description="Cannot change the balance.",
\t\treturn this.balanceOf(crytic_user) == initialBalance_user && this.balanceOf(crytic_attacker) == initialBalance_attacker;''', content="""
type=PropertyType.MEDIUM_SEVERITY, \t\treturn this.balanceOf(crytic_user) == initialBalance_user && this.balanceOf(crytic_attacker) == initialBalance_attacker;""",
return_type=PropertyReturn.SUCCESS, type=PropertyType.MEDIUM_SEVERITY,
is_unit_test=True, return_type=PropertyReturn.SUCCESS,
is_property_test=True, is_unit_test=True,
caller=PropertyCaller.ALL), is_property_test=True,
caller=PropertyCaller.ALL,
Property(name='crytic_constantAllowance()', ),
description='Cannot change the allowance.', Property(
content=''' name="crytic_constantAllowance()",
description="Cannot change the allowance.",
content="""
\t\treturn (this.allowance(crytic_user, crytic_attacker) == initialAllowance_user_attacker) && \t\treturn (this.allowance(crytic_user, crytic_attacker) == initialAllowance_user_attacker) &&
\t\t\t(this.allowance(crytic_attacker, crytic_attacker) == initialAllowance_attacker_attacker);''', \t\t\t(this.allowance(crytic_attacker, crytic_attacker) == initialAllowance_attacker_attacker);""",
type=PropertyType.MEDIUM_SEVERITY, type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS, return_type=PropertyReturn.SUCCESS,
is_unit_test=True, is_unit_test=True,
is_property_test=True, is_property_test=True,
caller=PropertyCaller.ALL), caller=PropertyCaller.ALL,
),
] ]

@ -11,26 +11,32 @@ from slither.tools.properties.properties.properties import Property
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
def generate_truffle_test(contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses) -> str: def generate_truffle_test(
test_contract = f'Test{contract.name}{type_property}' contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses
filename_init = f'Initialization{test_contract}.js' ) -> str:
filename = f'{test_contract}.js' test_contract = f"Test{contract.name}{type_property}"
filename_init = f"Initialization{test_contract}.js"
filename = f"{test_contract}.js"
output_dir = Path(contract.slither.crytic_compile.target) output_dir = Path(contract.slither.crytic_compile.target)
generate_migration(test_contract, output_dir, addresses.owner) generate_migration(test_contract, output_dir, addresses.owner)
generate_unit_test(test_contract, generate_unit_test(
filename_init, test_contract,
ERC20_CONFIG, filename_init,
output_dir, ERC20_CONFIG,
addresses, output_dir,
f'Check the constructor of {test_contract}') addresses,
f"Check the constructor of {test_contract}",
generate_unit_test(test_contract, filename, unit_tests, output_dir, addresses,) )
log_info = '\n' generate_unit_test(
log_info += 'To run the unit tests:\n' test_contract, filename, unit_tests, output_dir, addresses,
)
log_info = "\n"
log_info += "To run the unit tests:\n"
log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename_init)}\n" log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename_init)}\n"
log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename)}\n" log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename)}\n"
return log_info return log_info

@ -36,4 +36,4 @@ class Property(NamedTuple):
def property_to_solidity(p: Property): def property_to_solidity(p: Property):
return f'\tfunction {p.name} public returns(bool){{{p.content}\n\t}}\n' return f"\tfunction {p.name} public returns(bool){{{p.content}\n\t}}\n"

@ -9,59 +9,64 @@ from slither.tools.properties.utils import write_file
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
def generate_solidity_properties(contract: Contract, type_property: str, solidity_properties: str, def generate_solidity_properties(
output_dir: Path) -> Path: contract: Contract, type_property: str, solidity_properties: str, output_dir: Path
) -> Path:
solidity_import = f'import "./interfaces.sol";\n' solidity_import = f'import "./interfaces.sol";\n'
solidity_import += f'import "../{contract.source_mapping["filename_short"]}";' solidity_import += f'import "../{contract.source_mapping["filename_short"]}";'
test_contract_name = f'Properties{contract.name}{type_property}' test_contract_name = f"Properties{contract.name}{type_property}"
solidity_content = f'{solidity_import}\ncontract {test_contract_name} is CryticInterface,{contract.name}' solidity_content = (
solidity_content += f'{{\n\n{solidity_properties}\n}}\n' f"{solidity_import}\ncontract {test_contract_name} is CryticInterface,{contract.name}"
)
solidity_content += f"{{\n\n{solidity_properties}\n}}\n"
filename = f'{test_contract_name}.sol' filename = f"{test_contract_name}.sol"
write_file(output_dir, filename, solidity_content) write_file(output_dir, filename, solidity_content)
return Path(filename) return Path(filename)
def generate_test_contract(contract: Contract, def generate_test_contract(
type_property: str, contract: Contract,
output_dir: Path, type_property: str,
property_file: Path, output_dir: Path,
initialization_recommendation: str) -> Tuple[str, str]: property_file: Path,
test_contract_name = f'Test{contract.name}{type_property}' initialization_recommendation: str,
properties_name = f'Properties{contract.name}{type_property}' ) -> Tuple[str, str]:
test_contract_name = f"Test{contract.name}{type_property}"
properties_name = f"Properties{contract.name}{type_property}"
content = '' content = ""
content += f'import "./{property_file}";\n' content += f'import "./{property_file}";\n'
content += f"contract {test_contract_name} is {properties_name} {{\n" content += f"contract {test_contract_name} is {properties_name} {{\n"
content += '\tconstructor() public{\n' content += "\tconstructor() public{\n"
content += '\t\t// Existing addresses:\n' content += "\t\t// Existing addresses:\n"
content += '\t\t// - crytic_owner: If the contract has an owner, it must be crytic_owner\n' content += "\t\t// - crytic_owner: If the contract has an owner, it must be crytic_owner\n"
content += '\t\t// - crytic_user: Legitimate user\n' content += "\t\t// - crytic_user: Legitimate user\n"
content += '\t\t// - crytic_attacker: Attacker\n' content += "\t\t// - crytic_attacker: Attacker\n"
content += '\t\t// \n' content += "\t\t// \n"
content += initialization_recommendation content += initialization_recommendation
content += '\t\t// \n' content += "\t\t// \n"
content += '\t\t// \n' content += "\t\t// \n"
content += '\t\t// Update the following if totalSupply and balanceOf are external functions or state variables:\n\n' content += "\t\t// Update the following if totalSupply and balanceOf are external functions or state variables:\n\n"
content += '\t\tinitialTotalSupply = totalSupply();\n' content += "\t\tinitialTotalSupply = totalSupply();\n"
content += '\t\tinitialBalance_owner = balanceOf(crytic_owner);\n' content += "\t\tinitialBalance_owner = balanceOf(crytic_owner);\n"
content += '\t\tinitialBalance_user = balanceOf(crytic_user);\n' content += "\t\tinitialBalance_user = balanceOf(crytic_user);\n"
content += '\t\tinitialBalance_attacker = balanceOf(crytic_attacker);\n' content += "\t\tinitialBalance_attacker = balanceOf(crytic_attacker);\n"
content += '\t}\n}\n' content += "\t}\n}\n"
filename = f'{test_contract_name}.sol' filename = f"{test_contract_name}.sol"
write_file(output_dir, filename, content, allow_overwrite=False) write_file(output_dir, filename, content, allow_overwrite=False)
return filename, test_contract_name return filename, test_contract_name
def generate_solidity_interface(output_dir: Path, addresses: Addresses): def generate_solidity_interface(output_dir: Path, addresses: Addresses):
content = f''' content = f"""
contract CryticInterface{{ contract CryticInterface{{
address internal crytic_owner = address({addresses.owner}); address internal crytic_owner = address({addresses.owner});
address internal crytic_user = address({addresses.user}); address internal crytic_user = address({addresses.user});
@ -70,7 +75,7 @@ contract CryticInterface{{
uint internal initialBalance_owner; uint internal initialBalance_owner;
uint internal initialBalance_user; uint internal initialBalance_user;
uint internal initialBalance_attacker; uint internal initialBalance_attacker;
}}''' }}"""
# Static file, we discard if it exists as it should never change # Static file, we discard if it exists as it should never change
write_file(output_dir, 'interfaces.sol', content, discard_if_exist=True) write_file(output_dir, "interfaces.sol", content, discard_if_exist=True)

@ -6,11 +6,13 @@ from slither.utils.colors import green, yellow
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
def write_file(output_dir: Path, def write_file(
filename: str, output_dir: Path,
content: str, filename: str,
allow_overwrite: bool = True, content: str,
discard_if_exist: bool = False): allow_overwrite: bool = True,
discard_if_exist: bool = False,
):
""" """
Write the content into output_dir/filename Write the content into output_dir/filename
:param output_dir: :param output_dir:
@ -25,10 +27,10 @@ def write_file(output_dir: Path,
if discard_if_exist: if discard_if_exist:
return return
if not allow_overwrite: if not allow_overwrite:
logger.info(yellow(f'{file_to_write} already exist and will not be overwritten')) logger.info(yellow(f"{file_to_write} already exist and will not be overwritten"))
return return
logger.info(yellow(f'Overwrite {file_to_write}')) logger.info(yellow(f"Overwrite {file_to_write}"))
else: else:
logger.info(green(f'Write {file_to_write}')) logger.info(green(f"Write {file_to_write}"))
with open(file_to_write, 'w') as f: with open(file_to_write, "w") as f:
f.write(content) f.write(content)

@ -8,62 +8,56 @@ import operator
from crytic_compile import cryticparser from crytic_compile import cryticparser
from .info import info from .info import info
from .test import test from .test import test
from .train import train from .train import train
from .plot import plot from .plot import plot
logging.basicConfig() logging.basicConfig()
logger = logging.getLogger("Slither-simil") logger = logging.getLogger("Slither-simil")
modes = ["info", "test", "train", "plot"] modes = ["info", "test", "train", "plot"]
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector') parser = argparse.ArgumentParser(
description="Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector"
parser.add_argument('mode', )
help="|".join(modes))
parser.add_argument("mode", help="|".join(modes))
parser.add_argument('model',
help='model.bin') parser.add_argument("model", help="model.bin")
parser.add_argument('--filename', parser.add_argument("--filename", action="store", dest="filename", help="contract.sol")
action='store',
dest='filename', parser.add_argument("--fname", action="store", dest="fname", help="Target function")
help='contract.sol')
parser.add_argument("--ext", action="store", dest="ext", help="Extension to filter contracts")
parser.add_argument('--fname',
action='store', parser.add_argument(
dest='fname', "--nsamples",
help='Target function') action="store",
type=int,
parser.add_argument('--ext', dest="nsamples",
action='store', help="Number of contract samples used for training",
dest='ext', )
help='Extension to filter contracts')
parser.add_argument(
parser.add_argument('--nsamples', "--ntop",
action='store', action="store",
type=int, type=int,
dest='nsamples', dest="ntop",
help='Number of contract samples used for training') default=10,
help="Number of more similar contracts to show for testing",
parser.add_argument('--ntop', )
action='store',
type=int, parser.add_argument(
dest='ntop', "--input", action="store", dest="input", help="File or directory used as input"
default=10, )
help='Number of more similar contracts to show for testing')
parser.add_argument(
parser.add_argument('--input', "--version", help="displays the current version", version="0.0", action="version"
action='store', )
dest='input',
help='File or directory used as input')
parser.add_argument('--version',
help='displays the current version',
version="0.0",
action='version')
cryticparser.init(parser) cryticparser.init(parser)
@ -74,6 +68,7 @@ def parse_args():
args = parser.parse_args() args = parser.parse_args()
return args return args
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -81,27 +76,29 @@ def parse_args():
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def main(): def main():
args = parse_args() args = parse_args()
default_log = logging.INFO default_log = logging.INFO
logger.setLevel(default_log) logger.setLevel(default_log)
mode = args.mode mode = args.mode
if mode == "info": if mode == "info":
info(args) info(args)
elif mode == "train": elif mode == "train":
train(args) train(args)
elif mode == "test": elif mode == "test":
test(args) test(args)
elif mode == "plot": elif mode == "plot":
plot(args) plot(args)
else: else:
logger.error('Invalid mode!. It should be one of these: %s' % ", ".join(modes)) logger.error("Invalid mode!. It should be one of these: %s" % ", ".join(modes))
sys.exit(-1) sys.exit(-1)
if __name__ == '__main__':
if __name__ == "__main__":
main() main()
# endregion # endregion

@ -7,16 +7,18 @@ except ImportError:
print("$ pip3 install numpy --user\n") print("$ pip3 install numpy --user\n")
sys.exit(-1) sys.exit(-1)
def load_cache(infile, nsamples=None): def load_cache(infile, nsamples=None):
cache = dict() cache = dict()
with np.load(infile, allow_pickle=True) as data: with np.load(infile, allow_pickle=True) as data:
array = data['arr_0'][0] array = data["arr_0"][0]
for i,(x,y) in enumerate(array): for i, (x, y) in enumerate(array):
cache[x] = y cache[x] = y
if i == nsamples: if i == nsamples:
break break
return cache return cache
def save_cache(cache, outfile): def save_cache(cache, outfile):
np.savez(outfile,[np.array(cache)]) np.savez(outfile, [np.array(cache)])

@ -2,16 +2,47 @@ import logging
import os import os
from slither import Slither from slither import Slither
from slither.core.declarations import Structure, Enum, SolidityVariableComposed, SolidityVariable, Function from slither.core.declarations import (
Structure,
Enum,
SolidityVariableComposed,
SolidityVariable,
Function,
)
from slither.core.solidity_types import ElementaryType, ArrayType, MappingType, UserDefinedType from slither.core.solidity_types import ElementaryType, ArrayType, MappingType, UserDefinedType
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import Assignment, Index, Member, Length, Balance, Binary, \ from slither.slithir.operations import (
Unary, Condition, NewArray, NewStructure, NewContract, NewElementaryType, \ Assignment,
SolidityCall, Push, Delete, EventCall, LibraryCall, InternalDynamicCall, \ Index,
HighLevelCall, LowLevelCall, TypeConversion, Return, Transfer, Send, Unpack, InitArray, InternalCall AccessMember,
from slither.slithir.variables import TemporaryVariable, TupleVariable, Constant, ReferenceVariable Length,
Balance,
Binary,
Unary,
Condition,
NewArray,
NewStructure,
NewContract,
NewElementaryType,
SolidityCall,
Push,
Delete,
EventCall,
LibraryCall,
InternalDynamicCall,
HighLevelCall,
LowLevelCall,
TypeConversion,
Return,
Transfer,
Send,
Unpack,
InitArray,
InternalCall,
)
from slither.slithir.variables import TemporaryVariable, TupleVariable, Constant, IndexVariable
from .cache import load_cache from .cache import load_cache
simil_logger = logging.getLogger("Slither-simil") simil_logger = logging.getLogger("Slither-simil")
@ -20,11 +51,12 @@ compiler_logger.setLevel(logging.CRITICAL)
slither_logger = logging.getLogger("Slither") slither_logger = logging.getLogger("Slither")
slither_logger.setLevel(logging.CRITICAL) slither_logger.setLevel(logging.CRITICAL)
def parse_target(target): def parse_target(target):
if target is None: if target is None:
return None, None return None, None
parts = target.split('.') parts = target.split(".")
if len(parts) == 1: if len(parts) == 1:
return None, parts[0] return None, parts[0]
elif len(parts) == 2: elif len(parts) == 2:
@ -32,25 +64,27 @@ def parse_target(target):
else: else:
simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'") simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'")
def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs): def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs):
r = dict() r = dict()
if infile.endswith(".npz"): if infile.endswith(".npz"):
r = load_cache(infile, nsamples=nsamples) r = load_cache(infile, nsamples=nsamples)
else: else:
contracts = load_contracts(infile, ext=ext, nsamples=nsamples) contracts = load_contracts(infile, ext=ext, nsamples=nsamples)
for contract in contracts: for contract in contracts:
for x,ir in encode_contract(contract, **kwargs).items(): for x, ir in encode_contract(contract, **kwargs).items():
if ir != []: if ir != []:
y = " ".join(ir) y = " ".join(ir)
r[x] = vmodel.get_sentence_vector(y) r[x] = vmodel.get_sentence_vector(y)
return r return r
def load_contracts(dirname, ext=None, nsamples=None, **kwargs): def load_contracts(dirname, ext=None, nsamples=None, **kwargs):
r = [] r = []
walk = list(os.walk(dirname)) walk = list(os.walk(dirname))
for x, y, files in walk: for x, y, files in walk:
for f in files: for f in files:
if ext is None or f.endswith(ext): if ext is None or f.endswith(ext):
r.append(x + "/".join(y) + "/" + f) r.append(x + "/".join(y) + "/" + f)
@ -60,6 +94,7 @@ def load_contracts(dirname, ext=None, nsamples=None, **kwargs):
# TODO: shuffle # TODO: shuffle
return r[:nsamples] return r[:nsamples]
def ntype(_type): def ntype(_type):
if isinstance(_type, ElementaryType): if isinstance(_type, ElementaryType):
_type = str(_type) _type = str(_type)
@ -79,8 +114,8 @@ def ntype(_type):
else: else:
_type = str(_type) _type = str(_type)
_type = _type.replace(" memory","") _type = _type.replace(" memory", "")
_type = _type.replace(" storage ref","") _type = _type.replace(" storage ref", "")
if "struct" in _type: if "struct" in _type:
return "struct" return "struct"
@ -93,92 +128,94 @@ def ntype(_type):
elif "mapping" in _type: elif "mapping" in _type:
return "mapping" return "mapping"
else: else:
return _type.replace(" ","_") return _type.replace(" ", "_")
def encode_ir(ir): def encode_ir(ir):
# operations # operations
if isinstance(ir, Assignment): if isinstance(ir, Assignment):
return '({}):=({})'.format(encode_ir(ir.lvalue), encode_ir(ir.rvalue)) return "({}):=({})".format(encode_ir(ir.lvalue), encode_ir(ir.rvalue))
if isinstance(ir, Index): if isinstance(ir, Index):
return 'index({})'.format(ntype(ir._type)) return "index({})".format(ntype(ir._type))
if isinstance(ir, Member): if isinstance(ir, AccessMember):
return 'member' #.format(ntype(ir._type)) return "member" # .format(ntype(ir._type))
if isinstance(ir, Length): if isinstance(ir, Length):
return 'length' return "length"
if isinstance(ir, Balance): if isinstance(ir, Balance):
return 'balance' return "balance"
if isinstance(ir, Binary): if isinstance(ir, Binary):
return 'binary({})'.format(ir.type_str) return "binary({})".format(str(ir.type))
if isinstance(ir, Unary): if isinstance(ir, Unary):
return 'unary({})'.format(ir.type_str) return "unary({})".format(str(ir.type))
if isinstance(ir, Condition): if isinstance(ir, Condition):
return 'condition({})'.format(encode_ir(ir.value)) return "condition({})".format(encode_ir(ir.value))
if isinstance(ir, NewStructure): if isinstance(ir, NewStructure):
return 'new_structure' return "new_structure"
if isinstance(ir, NewContract): if isinstance(ir, NewContract):
return 'new_contract' return "new_contract"
if isinstance(ir, NewArray): if isinstance(ir, NewArray):
return 'new_array({})'.format(ntype(ir._array_type)) return "new_array({})".format(ntype(ir._array_type))
if isinstance(ir, NewElementaryType): if isinstance(ir, NewElementaryType):
return 'new_elementary({})'.format(ntype(ir._type)) return "new_elementary({})".format(ntype(ir._type))
if isinstance(ir, Push): if isinstance(ir, Push):
return 'push({},{})'.format(encode_ir(ir.value), encode_ir(ir.lvalue)) return "push({},{})".format(encode_ir(ir.value), encode_ir(ir.lvalue))
if isinstance(ir, Delete): if isinstance(ir, Delete):
return 'delete({},{})'.format(encode_ir(ir.lvalue), encode_ir(ir.variable)) return "delete({},{})".format(encode_ir(ir.lvalue), encode_ir(ir.variable))
if isinstance(ir, SolidityCall): if isinstance(ir, SolidityCall):
return 'solidity_call({})'.format(ir.function.full_name) return "solidity_call({})".format(ir.function.full_name)
if isinstance(ir, InternalCall): if isinstance(ir, InternalCall):
return 'internal_call({})'.format(ntype(ir._type_call)) return "internal_call({})".format(ntype(ir._type_call))
if isinstance(ir, EventCall): # is this useful? if isinstance(ir, EventCall): # is this useful?
return 'event' return "event"
if isinstance(ir, LibraryCall): if isinstance(ir, LibraryCall):
return 'library_call' return "library_call"
if isinstance(ir, InternalDynamicCall): if isinstance(ir, InternalDynamicCall):
return 'internal_dynamic_call' return "internal_dynamic_call"
if isinstance(ir, HighLevelCall): # TODO: improve if isinstance(ir, HighLevelCall): # TODO: improve
return 'high_level_call' return "high_level_call"
if isinstance(ir, LowLevelCall): # TODO: improve if isinstance(ir, LowLevelCall): # TODO: improve
return 'low_level_call' return "low_level_call"
if isinstance(ir, TypeConversion): if isinstance(ir, TypeConversion):
return 'type_conversion({})'.format(ntype(ir.type)) return "type_conversion({})".format(ntype(ir.type))
if isinstance(ir, Return): # this can be improved using values if isinstance(ir, Return): # this can be improved using values
return 'return' #.format(ntype(ir.type)) return "return" # .format(ntype(ir.type))
if isinstance(ir, Transfer): if isinstance(ir, Transfer):
return 'transfer({})'.format(encode_ir(ir.call_value)) return "transfer({})".format(encode_ir(ir.call_value))
if isinstance(ir, Send): if isinstance(ir, Send):
return 'send({})'.format(encode_ir(ir.call_value)) return "send({})".format(encode_ir(ir.call_value))
if isinstance(ir, Unpack): # TODO: improve if isinstance(ir, Unpack): # TODO: improve
return 'unpack' return "unpack"
if isinstance(ir, InitArray): # TODO: improve if isinstance(ir, InitArray): # TODO: improve
return 'init_array' return "init_array"
if isinstance(ir, Function): # TODO: investigate this if isinstance(ir, Function): # TODO: investigate this
return 'function_solc' return "function_solc"
# variables # variables
if isinstance(ir, Constant): if isinstance(ir, Constant):
return 'constant({})'.format(ntype(ir._type)) return "constant({})".format(ntype(ir._type))
if isinstance(ir, SolidityVariableComposed): if isinstance(ir, SolidityVariableComposed):
return 'solidity_variable_composed({})'.format(ir.name) return "solidity_variable_composed({})".format(ir.name)
if isinstance(ir, SolidityVariable): if isinstance(ir, SolidityVariable):
return 'solidity_variable{}'.format(ir.name) return "solidity_variable{}".format(ir.name)
if isinstance(ir, TemporaryVariable): if isinstance(ir, TemporaryVariable):
return 'temporary_variable' return "temporary_variable"
if isinstance(ir, ReferenceVariable): if isinstance(ir, IndexVariable):
return 'reference({})'.format(ntype(ir._type)) return "reference({})".format(ntype(ir._type))
if isinstance(ir, LocalVariable): if isinstance(ir, LocalVariable):
return 'local_solc_variable({})'.format(ir._location) return "local_solc_variable({})".format(ir._location)
if isinstance(ir, StateVariable): if isinstance(ir, StateVariable):
return 'state_solc_variable({})'.format(ntype(ir._type)) return "state_solc_variable({})".format(ntype(ir._type))
if isinstance(ir, LocalVariableInitFromTuple): if isinstance(ir, LocalVariableInitFromTuple):
return 'local_variable_init_tuple' return "local_variable_init_tuple"
if isinstance(ir, TupleVariable): if isinstance(ir, TupleVariable):
return 'tuple_variable' return "tuple_variable"
# default # default
else: else:
simil_logger.error(type(ir),"is missing encoding!") simil_logger.error(type(ir), "is missing encoding!")
return '' return ""
def encode_contract(cfilename, **kwargs): def encode_contract(cfilename, **kwargs):
r = dict() r = dict()
@ -186,7 +223,7 @@ def encode_contract(cfilename, **kwargs):
try: try:
slither = Slither(cfilename, **kwargs) slither = Slither(cfilename, **kwargs)
except: except:
simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs['solc']) simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs["solc"])
return r return r
# Iterate over all the contracts # Iterate over all the contracts
@ -198,7 +235,7 @@ def encode_contract(cfilename, **kwargs):
if function.nodes == [] or function.is_constructor_variables: if function.nodes == [] or function.is_constructor_variables:
continue continue
x = (cfilename,contract.name,function.name) x = (cfilename, contract.name, function.name)
r[x] = [] r[x] = []
@ -210,5 +247,3 @@ def encode_contract(cfilename, **kwargs):
for ir in node.irs: for ir in node.irs:
r[x].append(encode_ir(ir)) r[x].append(encode_ir(ir))
return r return r

@ -3,18 +3,19 @@ import sys
import os.path import os.path
import traceback import traceback
from .model import load_model from .model import load_model
from .encode import parse_target, encode_contract from .encode import parse_target, encode_contract
logging.basicConfig() logging.basicConfig()
logger = logging.getLogger("Slither-simil") logger = logging.getLogger("Slither-simil")
def info(args): def info(args):
try: try:
model = args.model model = args.model
if os.path.isfile(model): if os.path.isfile(model):
model = load_model(model) model = load_model(model)
else: else:
model = None model = None
@ -22,22 +23,22 @@ def info(args):
filename = args.filename filename = args.filename
contract, fname = parse_target(args.fname) contract, fname = parse_target(args.fname)
solc = args.solc solc = args.solc
if filename is None and contract is None and fname is None: if filename is None and contract is None and fname is None:
logger.info("%s uses the following words:",args.model) logger.info("%s uses the following words:", args.model)
for word in model.get_words(): for word in model.get_words():
logger.info(word) logger.info(word)
sys.exit(0) sys.exit(0)
if filename is None or contract is None or fname is None: if filename is None or contract is None or fname is None:
logger.error('The encode mode requires filename, contract and fname parameters.') logger.error("The encode mode requires filename, contract and fname parameters.")
sys.exit(-1) sys.exit(-1)
irs = encode_contract(filename, **vars(args)) irs = encode_contract(filename, **vars(args))
if len(irs) == 0: if len(irs) == 0:
sys.exit(-1) sys.exit(-1)
x = (filename,contract,fname) x = (filename, contract, fname)
y = " ".join(irs[x]) y = " ".join(irs[x])
logger.info("Function {} in contract {} is encoded as:".format(fname, contract)) logger.info("Function {} in contract {} is encoded as:".format(fname, contract))
@ -47,8 +48,6 @@ def info(args):
logger.info(fvector) logger.info(fvector)
except Exception: except Exception:
logger.error('Error in %s' % args.filename) logger.error("Error in %s" % args.filename)
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
sys.exit(-1) sys.exit(-1)

@ -5,7 +5,7 @@ import operator
import numpy as np import numpy as np
import random import random
from .model import load_model from .model import load_model
from .encode import load_and_encode, parse_target from .encode import load_and_encode, parse_target
try: try:
@ -17,10 +17,13 @@ except ImportError:
logger = logging.getLogger("Slither-simil") logger = logging.getLogger("Slither-simil")
def plot(args): def plot(args):
if decomposition is None or plt is None: if decomposition is None or plt is None:
logger.error("ERROR: In order to use plot mode in slither-simil, you need to install sklearn and matplotlib:") logger.error(
"ERROR: In order to use plot mode in slither-simil, you need to install sklearn and matplotlib:"
)
logger.error("$ pip3 install sklearn matplotlib --user") logger.error("$ pip3 install sklearn matplotlib --user")
sys.exit(-1) sys.exit(-1)
@ -29,50 +32,50 @@ def plot(args):
model = args.model model = args.model
model = load_model(model) model = load_model(model)
filename = args.filename filename = args.filename
#contract = args.contract # contract = args.contract
contract, fname = parse_target(args.fname) contract, fname = parse_target(args.fname)
#solc = args.solc # solc = args.solc
infile = args.input infile = args.input
#ext = args.filter # ext = args.filter
#nsamples = args.nsamples # nsamples = args.nsamples
if fname is None or infile is None: if fname is None or infile is None:
logger.error('The plot mode requieres fname and input parameters.') logger.error("The plot mode requieres fname and input parameters.")
sys.exit(-1) sys.exit(-1)
logger.info('Loading data..') logger.info("Loading data..")
cache = load_and_encode(infile, **vars(args)) cache = load_and_encode(infile, **vars(args))
data = list() data = list()
fs = list() fs = list()
logger.info('Procesing data..') logger.info("Procesing data..")
for (f,c,n),y in cache.items(): for (f, c, n), y in cache.items():
if (c == contract or contract is None) and n == fname: if (c == contract or contract is None) and n == fname:
fs.append(f) fs.append(f)
data.append(y) data.append(y)
if len(data) == 0: if len(data) == 0:
logger.error('No contract was found with function %s', fname) logger.error("No contract was found with function %s", fname)
sys.exit(-1) sys.exit(-1)
data = np.array(data) data = np.array(data)
pca = decomposition.PCA(n_components=2) pca = decomposition.PCA(n_components=2)
tdata = pca.fit_transform(data) tdata = pca.fit_transform(data)
logger.info('Plotting data..') logger.info("Plotting data..")
plt.figure(figsize=(20,10)) plt.figure(figsize=(20, 10))
assert(len(tdata) == len(fs)) assert len(tdata) == len(fs)
for ([x,y],l) in zip(tdata, fs): for ([x, y], l) in zip(tdata, fs):
x = random.gauss(0, 0.01) + x x = random.gauss(0, 0.01) + x
y = random.gauss(0, 0.01) + y y = random.gauss(0, 0.01) + y
plt.scatter(x, y, c='blue') plt.scatter(x, y, c="blue")
plt.text(x-0.001,y+0.001, l) plt.text(x - 0.001, y + 0.001, l)
logger.info("Saving figure to plot.png..")
plt.savefig("plot.png", bbox_inches="tight")
logger.info('Saving figure to plot.png..')
plt.savefig('plot.png', bbox_inches='tight')
except Exception: except Exception:
logger.error('Error in %s' % args.filename) logger.error("Error in %s" % args.filename)
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
sys.exit(-1) sys.exit(-1)

@ -1,5 +1,6 @@
import numpy as np import numpy as np
def similarity(v1, v2): def similarity(v1, v2):
n1 = np.linalg.norm(v1) n1 = np.linalg.norm(v1)
n2 = np.linalg.norm(v2) n2 = np.linalg.norm(v2)

@ -5,50 +5,51 @@ import traceback
import operator import operator
import numpy as np import numpy as np
from .model import load_model from .model import load_model
from .encode import encode_contract, load_and_encode, parse_target from .encode import encode_contract, load_and_encode, parse_target
from .cache import save_cache from .cache import save_cache
from .similarity import similarity from .similarity import similarity
logger = logging.getLogger("Slither-simil") logger = logging.getLogger("Slither-simil")
def test(args): def test(args):
try: try:
model = args.model model = args.model
model = load_model(model) model = load_model(model)
filename = args.filename filename = args.filename
contract, fname = parse_target(args.fname) contract, fname = parse_target(args.fname)
infile = args.input infile = args.input
ntop = args.ntop ntop = args.ntop
if filename is None or contract is None or fname is None or infile is None: if filename is None or contract is None or fname is None or infile is None:
logger.error('The test mode requires filename, contract, fname and input parameters.') logger.error("The test mode requires filename, contract, fname and input parameters.")
sys.exit(-1) sys.exit(-1)
irs = encode_contract(filename, **vars(args)) irs = encode_contract(filename, **vars(args))
if len(irs) == 0: if len(irs) == 0:
sys.exit(-1) sys.exit(-1)
y = " ".join(irs[(filename,contract,fname)]) y = " ".join(irs[(filename, contract, fname)])
fvector = model.get_sentence_vector(y) fvector = model.get_sentence_vector(y)
cache = load_and_encode(infile, model, **vars(args)) cache = load_and_encode(infile, model, **vars(args))
#save_cache("cache.npz", cache) # save_cache("cache.npz", cache)
r = dict() r = dict()
for x,y in cache.items(): for x, y in cache.items():
r[x] = similarity(fvector, y) r[x] = similarity(fvector, y)
r = sorted(r.items(), key=operator.itemgetter(1), reverse=True) r = sorted(r.items(), key=operator.itemgetter(1), reverse=True)
logger.info("Reviewed %d functions, listing the %d most similar ones:", len(r), ntop) logger.info("Reviewed %d functions, listing the %d most similar ones:", len(r), ntop)
format_table = "{: <65} {: <20} {: <20} {: <10}" format_table = "{: <65} {: <20} {: <20} {: <10}"
logger.info(format_table.format(*["filename", "contract", "function", "score"])) logger.info(format_table.format(*["filename", "contract", "function", "score"]))
for x,score in r[:ntop]: for x, score in r[:ntop]:
score = str(round(score, 3)) score = str(round(score, 3))
logger.info(format_table.format(*(list(x)+[score]))) logger.info(format_table.format(*(list(x) + [score])))
except Exception: except Exception:
logger.error('Error in %s' % args.filename) logger.error("Error in %s" % args.filename)
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
sys.exit(-1) sys.exit(-1)

@ -5,12 +5,13 @@ import traceback
import operator import operator
import os import os
from .model import train_unsupervised from .model import train_unsupervised
from .encode import encode_contract, load_contracts from .encode import encode_contract, load_contracts
from .cache import save_cache from .cache import save_cache
logger = logging.getLogger("Slither-simil") logger = logging.getLogger("Slither-simil")
def train(args): def train(args):
try: try:
@ -20,35 +21,37 @@ def train(args):
nsamples = args.nsamples nsamples = args.nsamples
if dirname is None: if dirname is None:
logger.error('The train mode requires the input parameter.') logger.error("The train mode requires the input parameter.")
sys.exit(-1) sys.exit(-1)
contracts = load_contracts(dirname, **vars(args)) contracts = load_contracts(dirname, **vars(args))
logger.info('Saving extracted data into %s', last_data_train_filename) logger.info("Saving extracted data into %s", last_data_train_filename)
cache = [] cache = []
with open(last_data_train_filename, 'w') as f: with open(last_data_train_filename, "w") as f:
for filename in contracts: for filename in contracts:
#cache[filename] = dict() # cache[filename] = dict()
for (filename, contract, function), ir in encode_contract(filename, **vars(args)).items(): for (filename, contract, function), ir in encode_contract(
filename, **vars(args)
).items():
if ir != []: if ir != []:
x = " ".join(ir) x = " ".join(ir)
f.write(x+"\n") f.write(x + "\n")
cache.append((os.path.split(filename)[-1], contract, function, x)) cache.append((os.path.split(filename)[-1], contract, function, x))
logger.info('Starting training') logger.info("Starting training")
model = train_unsupervised(input=last_data_train_filename, model='skipgram') model = train_unsupervised(input=last_data_train_filename, model="skipgram")
logger.info('Training complete') logger.info("Training complete")
logger.info('Saving model') logger.info("Saving model")
model.save_model(model_filename) model.save_model(model_filename)
for i,(filename, contract, function, irs) in enumerate(cache): for i, (filename, contract, function, irs) in enumerate(cache):
cache[i] = ((filename, contract, function), model.get_sentence_vector(irs)) cache[i] = ((filename, contract, function), model.get_sentence_vector(irs))
logger.info('Saving cache in cache.npz') logger.info("Saving cache in cache.npz")
save_cache(cache, "cache.npz") save_cache(cache, "cache.npz")
logger.info('Done!') logger.info("Done!")
except Exception: except Exception:
logger.error('Error in %s' % args.filename) logger.error("Error in %s" % args.filename)
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
sys.exit(-1) sys.exit(-1)

@ -10,63 +10,77 @@ logging.basicConfig()
logger = logging.getLogger("Slither").setLevel(logging.INFO) logger = logging.getLogger("Slither").setLevel(logging.INFO)
# Slither detectors for which slither-format currently works # Slither detectors for which slither-format currently works
available_detectors = ["unused-state", available_detectors = [
"solc-version", "unused-state",
"pragma", "solc-version",
"naming-convention", "pragma",
"external-function", "naming-convention",
"constable-states", "external-function",
"constant-function-asm", "constable-states",
"constatnt-function-state"] "constant-function-asm",
"constatnt-function-state",
]
detectors_to_run = [] detectors_to_run = []
def parse_args(): def parse_args():
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
""" """
parser = argparse.ArgumentParser(description='slither_format', parser = argparse.ArgumentParser(description="slither_format", usage="slither_format filename")
usage='slither_format filename')
parser.add_argument(
parser.add_argument('filename', help='The filename of the contract or truffle directory to analyze.') "filename", help="The filename of the contract or truffle directory to analyze."
parser.add_argument('--verbose-test', '-v', help='verbose mode output for testing',action='store_true',default=False) )
parser.add_argument('--verbose-json', '-j', help='verbose json output',action='store_true',default=False) parser.add_argument(
parser.add_argument('--version', "--verbose-test",
help='displays the current version', "-v",
version='0.1.0', help="verbose mode output for testing",
action='version') action="store_true",
default=False,
parser.add_argument('--config-file', )
help='Provide a config file (default: slither.config.json)', parser.add_argument(
action='store', "--verbose-json", "-j", help="verbose json output", action="store_true", default=False
dest='config_file', )
default='slither.config.json') parser.add_argument(
"--version", help="displays the current version", version="0.1.0", action="version"
)
group_detector = parser.add_argument_group('Detectors')
group_detector.add_argument('--detect', parser.add_argument(
help='Comma-separated list of detectors, defaults to all, ' "--config-file",
'available detectors: {}'.format( help="Provide a config file (default: slither.config.json)",
', '.join(d for d in available_detectors)), action="store",
action='store', dest="config_file",
dest='detectors_to_run', default="slither.config.json",
default='all') )
group_detector.add_argument('--exclude', group_detector = parser.add_argument_group("Detectors")
help='Comma-separated list of detectors to exclude,' group_detector.add_argument(
'available detectors: {}'.format( "--detect",
', '.join(d for d in available_detectors)), help="Comma-separated list of detectors, defaults to all, "
action='store', "available detectors: {}".format(", ".join(d for d in available_detectors)),
dest='detectors_to_exclude', action="store",
default='all') dest="detectors_to_run",
default="all",
cryticparser.init(parser) )
if len(sys.argv) == 1: group_detector.add_argument(
parser.print_help(sys.stderr) "--exclude",
help="Comma-separated list of detectors to exclude,"
"available detectors: {}".format(", ".join(d for d in available_detectors)),
action="store",
dest="detectors_to_exclude",
default="all",
)
cryticparser.init(parser)
if len(sys.argv) == 1:
parser.print_help(sys.stderr)
sys.exit(1) sys.exit(1)
return parser.parse_args() return parser.parse_args()
@ -80,11 +94,12 @@ def main():
read_config_file(args) read_config_file(args)
# Perform slither analysis on the given filename # Perform slither analysis on the given filename
slither = Slither(args.filename, **vars(args)) slither = Slither(args.filename, **vars(args))
# Format the input files based on slither analysis # Format the input files based on slither analysis
slither_format(slither, **vars(args)) slither_format(slither, **vars(args))
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

@ -11,27 +11,29 @@ from slither.detectors.attributes.const_functions_state import ConstantFunctions
from slither.utils.colors import yellow from slither.utils.colors import yellow
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('Slither.Format') logger = logging.getLogger("Slither.Format")
all_detectors = { all_detectors = {
'unused-state': UnusedStateVars, "unused-state": UnusedStateVars,
'solc-version': IncorrectSolc, "solc-version": IncorrectSolc,
'pragma': ConstantPragma, "pragma": ConstantPragma,
'naming-convention': NamingConvention, "naming-convention": NamingConvention,
'external-function': ExternalFunction, "external-function": ExternalFunction,
'constable-states' : ConstCandidateStateVars, "constable-states": ConstCandidateStateVars,
'constant-function-asm': ConstantFunctionsAsm, "constant-function-asm": ConstantFunctionsAsm,
'constant-functions-state': ConstantFunctionsState "constant-functions-state": ConstantFunctionsState,
} }
def slither_format(slither, **kwargs): def slither_format(slither, **kwargs):
'''' """'
Keyword Args: Keyword Args:
detectors_to_run (str): Comma-separated list of detectors, defaults to all detectors_to_run (str): Comma-separated list of detectors, defaults to all
''' """
detectors_to_run = choose_detectors(kwargs.get('detectors_to_run', 'all'), detectors_to_run = choose_detectors(
kwargs.get('detectors_to_exclude', '')) kwargs.get("detectors_to_run", "all"), kwargs.get("detectors_to_exclude", "")
)
for detector in detectors_to_run: for detector in detectors_to_run:
slither.register_detector(detector) slither.register_detector(detector)
@ -42,32 +44,32 @@ def slither_format(slither, **kwargs):
detector_results = [x for x in detector_results if x] # remove empty results detector_results = [x for x in detector_results if x] # remove empty results
detector_results = [item for sublist in detector_results for item in sublist] # flatten detector_results = [item for sublist in detector_results for item in sublist] # flatten
export = Path('crytic-export', 'patches') export = Path("crytic-export", "patches")
export.mkdir(parents=True, exist_ok=True) export.mkdir(parents=True, exist_ok=True)
counter_result = 0 counter_result = 0
logger.info(yellow('slither-format is in beta, carefully review each patch before merging it.')) logger.info(yellow("slither-format is in beta, carefully review each patch before merging it."))
for result in detector_results: for result in detector_results:
if not 'patches' in result: if not "patches" in result:
continue continue
one_line_description = result["description"].split("\n")[0] one_line_description = result["description"].split("\n")[0]
export_result = Path(export, f'{counter_result}') export_result = Path(export, f"{counter_result}")
export_result.mkdir(parents=True, exist_ok=True) export_result.mkdir(parents=True, exist_ok=True)
counter_result += 1 counter_result += 1
counter = 0 counter = 0
logger.info(f'Issue: {one_line_description}') logger.info(f"Issue: {one_line_description}")
logger.info(f'Generated: ({export_result})') logger.info(f"Generated: ({export_result})")
for file, diff, in result['patches_diff'].items(): for file, diff, in result["patches_diff"].items():
filename = f'fix_{counter}.patch' filename = f"fix_{counter}.patch"
path = Path(export_result, filename) path = Path(export_result, filename)
logger.info(f'\t- {filename}') logger.info(f"\t- {filename}")
with open(path, 'w') as f: with open(path, "w") as f:
f.write(diff) f.write(diff)
counter += 1 counter += 1
@ -79,26 +81,28 @@ def slither_format(slither, **kwargs):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def choose_detectors(detectors_to_run, detectors_to_exclude): def choose_detectors(detectors_to_run, detectors_to_exclude):
# If detectors are specified, run only these ones # If detectors are specified, run only these ones
cls_detectors_to_run = [] cls_detectors_to_run = []
exclude = detectors_to_exclude.split(',') exclude = detectors_to_exclude.split(",")
if detectors_to_run == 'all': if detectors_to_run == "all":
for d in all_detectors: for d in all_detectors:
if d in exclude: if d in exclude:
continue continue
cls_detectors_to_run.append(all_detectors[d]) cls_detectors_to_run.append(all_detectors[d])
else: else:
exclude = detectors_to_exclude.split(',') exclude = detectors_to_exclude.split(",")
for d in detectors_to_run.split(','): for d in detectors_to_run.split(","):
if d in all_detectors: if d in all_detectors:
if d in exclude: if d in exclude:
continue continue
cls_detectors_to_run.append(all_detectors[d]) cls_detectors_to_run.append(all_detectors[d])
else: else:
raise Exception('Error: {} is not a detector'.format(d)) raise Exception("Error: {} is not a detector".format(d))
return cls_detectors_to_run return cls_detectors_to_run
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -106,6 +110,7 @@ def choose_detectors(detectors_to_run, detectors_to_exclude):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def print_patches(number_of_slither_results, patches): def print_patches(number_of_slither_results, patches):
logger.info("Number of Slither results: " + str(number_of_slither_results)) logger.info("Number of Slither results: " + str(number_of_slither_results))
number_of_patches = 0 number_of_patches = 0
@ -115,39 +120,38 @@ def print_patches(number_of_slither_results, patches):
for file in patches: for file in patches:
logger.info("Patch file: " + file) logger.info("Patch file: " + file)
for patch in patches[file]: for patch in patches[file]:
logger.info("Detector: " + patch['detector']) logger.info("Detector: " + patch["detector"])
logger.info("Old string: " + patch['old_string'].replace("\n","")) logger.info("Old string: " + patch["old_string"].replace("\n", ""))
logger.info("New string: " + patch['new_string'].replace("\n","")) logger.info("New string: " + patch["new_string"].replace("\n", ""))
logger.info("Location start: " + str(patch['start'])) logger.info("Location start: " + str(patch["start"]))
logger.info("Location end: " + str(patch['end'])) logger.info("Location end: " + str(patch["end"]))
def print_patches_json(number_of_slither_results, patches): def print_patches_json(number_of_slither_results, patches):
print('{',end='') print("{", end="")
print("\"Number of Slither results\":" + '"' + str(number_of_slither_results) + '",') print('"Number of Slither results":' + '"' + str(number_of_slither_results) + '",')
print("\"Number of patchlets\":" + "\"" + str(len(patches)) + "\"", ',') print('"Number of patchlets":' + '"' + str(len(patches)) + '"', ",")
print("\"Patchlets\":" + '[') print('"Patchlets":' + "[")
for index, file in enumerate(patches): for index, file in enumerate(patches):
if index > 0: if index > 0:
print(',') print(",")
print('{',end='') print("{", end="")
print("\"Patch file\":" + '"' + file + '",') print('"Patch file":' + '"' + file + '",')
print("\"Number of patches\":" + "\"" + str(len(patches[file])) + "\"", ',') print('"Number of patches":' + '"' + str(len(patches[file])) + '"', ",")
print("\"Patches\":" + '[') print('"Patches":' + "[")
for index, patch in enumerate(patches[file]): for index, patch in enumerate(patches[file]):
if index > 0: if index > 0:
print(',') print(",")
print('{',end='') print("{", end="")
print("\"Detector\":" + '"' + patch['detector'] + '",') print('"Detector":' + '"' + patch["detector"] + '",')
print("\"Old string\":" + '"' + patch['old_string'].replace("\n","") + '",') print('"Old string":' + '"' + patch["old_string"].replace("\n", "") + '",')
print("\"New string\":" + '"' + patch['new_string'].replace("\n","") + '",') print('"New string":' + '"' + patch["new_string"].replace("\n", "") + '",')
print("\"Location start\":" + '"' + str(patch['start']) + '",') print('"Location start":' + '"' + str(patch["start"]) + '",')
print("\"Location end\":" + '"' + str(patch['end']) + '"') print('"Location end":' + '"' + str(patch["end"]) + '"')
if 'overlaps' in patch: if "overlaps" in patch:
print("\"Overlaps\":" + "Yes") print('"Overlaps":' + "Yes")
print('}',end='') print("}", end="")
print(']',end='') print("]", end="")
print('}',end='') print("}", end="")
print(']',end='') print("]", end="")
print('}') print("}")

@ -12,7 +12,12 @@ from slither.utils.colors import red
from slither.utils.output import output_to_json from slither.utils.output import output_to_json
from .checks import all_checks from .checks import all_checks
from .checks.abstract_checks import AbstractCheck from .checks.abstract_checks import AbstractCheck
from .utils.command_line import output_detectors_json, output_wiki, output_detectors, output_to_markdown from .utils.command_line import (
output_detectors_json,
output_wiki,
output_detectors,
output_to_markdown,
)
logging.basicConfig() logging.basicConfig()
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
@ -21,49 +26,53 @@ logger.setLevel(logging.INFO)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.', description="Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.",
usage="slither-check-upgradeability contract.sol ContractName") usage="slither-check-upgradeability contract.sol ContractName",
)
parser.add_argument('contract.sol', help='Codebase to analyze')
parser.add_argument('ContractName', help='Contract name (logic contract)') parser.add_argument("contract.sol", help="Codebase to analyze")
parser.add_argument("ContractName", help="Contract name (logic contract)")
parser.add_argument('--proxy-name', help='Proxy name')
parser.add_argument('--proxy-filename', help='Proxy filename (if different)') parser.add_argument("--proxy-name", help="Proxy name")
parser.add_argument("--proxy-filename", help="Proxy filename (if different)")
parser.add_argument('--new-contract-name', help='New contract name (if changed)')
parser.add_argument('--new-contract-filename', help='New implementation filename (if different)') parser.add_argument("--new-contract-name", help="New contract name (if changed)")
parser.add_argument(
parser.add_argument('--json', "--new-contract-filename", help="New implementation filename (if different)"
help='Export the results as a JSON file ("--json -" to export to stdout)', )
action='store',
default=False) parser.add_argument(
"--json",
parser.add_argument('--list-detectors', help='Export the results as a JSON file ("--json -" to export to stdout)',
help='List available detectors', action="store",
action=ListDetectors, default=False,
nargs=0, )
default=False)
parser.add_argument(
parser.add_argument('--markdown-root', "--list-detectors",
help='URL for markdown generation', help="List available detectors",
action='store', action=ListDetectors,
default="") nargs=0,
default=False,
parser.add_argument('--wiki-detectors', )
help=argparse.SUPPRESS,
action=OutputWiki, parser.add_argument(
default=False) "--markdown-root", help="URL for markdown generation", action="store", default=""
)
parser.add_argument('--list-detectors-json',
help=argparse.SUPPRESS, parser.add_argument(
action=ListDetectorsJson, "--wiki-detectors", help=argparse.SUPPRESS, action=OutputWiki, default=False
nargs=0, )
default=False)
parser.add_argument(
parser.add_argument('--markdown', "--list-detectors-json",
help=argparse.SUPPRESS, help=argparse.SUPPRESS,
action=OutputMarkdown, action=ListDetectorsJson,
default=False) nargs=0,
default=False,
)
parser.add_argument("--markdown", help=argparse.SUPPRESS, action=OutputMarkdown, default=False)
cryticparser.init(parser) cryticparser.init(parser)
@ -80,6 +89,7 @@ def parse_args():
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def _get_checks(): def _get_checks():
detectors = [getattr(all_checks, name) for name in dir(all_checks)] detectors = [getattr(all_checks, name) for name in dir(all_checks)]
detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractCheck)] detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractCheck)]
@ -123,13 +133,18 @@ def _run_checks(detectors):
def _checks_on_contract(detectors, contract): def _checks_on_contract(detectors, contract):
detectors = [d(logger, contract) for d in detectors if (not d.REQUIRE_PROXY and detectors = [
not d.REQUIRE_CONTRACT_V2)] d(logger, contract)
for d in detectors
if (not d.REQUIRE_PROXY and not d.REQUIRE_CONTRACT_V2)
]
return _run_checks(detectors), len(detectors) return _run_checks(detectors), len(detectors)
def _checks_on_contract_update(detectors, contract_v1, contract_v2): def _checks_on_contract_update(detectors, contract_v1, contract_v2):
detectors = [d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2] detectors = [
d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2
]
return _run_checks(detectors), len(detectors) return _run_checks(detectors), len(detectors)
@ -147,15 +162,11 @@ def _checks_on_contract_and_proxy(detectors, contract, proxy):
def main(): def main():
json_results = { json_results = {"proxy-present": False, "contract_v2-present": False, "detectors": []}
'proxy-present': False,
'contract_v2-present': False,
'detectors': []
}
args = parse_args() args = parse_args()
v1_filename = vars(args)['contract.sol'] v1_filename = vars(args)["contract.sol"]
number_detectors_run = 0 number_detectors_run = 0
detectors = _get_checks() detectors = _get_checks()
try: try:
@ -165,14 +176,14 @@ def main():
v1_name = args.ContractName v1_name = args.ContractName
v1_contract = v1.get_contract_from_name(v1_name) v1_contract = v1.get_contract_from_name(v1_name)
if v1_contract is None: if v1_contract is None:
info = 'Contract {} not found in {}'.format(v1_name, v1.filename) info = "Contract {} not found in {}".format(v1_name, v1.filename)
logger.error(red(info)) logger.error(red(info))
if args.json: if args.json:
output_to_json(args.json, str(info), json_results) output_to_json(args.json, str(info), json_results)
return return
detectors_results, number_detectors = _checks_on_contract(detectors, v1_contract) detectors_results, number_detectors = _checks_on_contract(detectors, v1_contract)
json_results['detectors'] += detectors_results json_results["detectors"] += detectors_results
number_detectors_run += number_detectors number_detectors_run += number_detectors
# Analyze Proxy # Analyze Proxy
@ -185,15 +196,17 @@ def main():
proxy_contract = proxy.get_contract_from_name(args.proxy_name) proxy_contract = proxy.get_contract_from_name(args.proxy_name)
if proxy_contract is None: if proxy_contract is None:
info = 'Proxy {} not found in {}'.format(args.proxy_name, proxy.filename) info = "Proxy {} not found in {}".format(args.proxy_name, proxy.filename)
logger.error(red(info)) logger.error(red(info))
if args.json: if args.json:
output_to_json(args.json, str(info), json_results) output_to_json(args.json, str(info), json_results)
return return
json_results['proxy-present'] = True json_results["proxy-present"] = True
detectors_results, number_detectors = _checks_on_contract_and_proxy(detectors, v1_contract, proxy_contract) detectors_results, number_detectors = _checks_on_contract_and_proxy(
json_results['detectors'] += detectors_results detectors, v1_contract, proxy_contract
)
json_results["detectors"] += detectors_results
number_detectors_run += number_detectors number_detectors_run += number_detectors
# Analyze new version # Analyze new version
if args.new_contract_name: if args.new_contract_name:
@ -204,30 +217,36 @@ def main():
v2_contract = v2.get_contract_from_name(args.new_contract_name) v2_contract = v2.get_contract_from_name(args.new_contract_name)
if v2_contract is None: if v2_contract is None:
info = 'New logic contract {} not found in {}'.format(args.new_contract_name, v2.filename) info = "New logic contract {} not found in {}".format(
args.new_contract_name, v2.filename
)
logger.error(red(info)) logger.error(red(info))
if args.json: if args.json:
output_to_json(args.json, str(info), json_results) output_to_json(args.json, str(info), json_results)
return return
json_results['contract_v2-present'] = True json_results["contract_v2-present"] = True
if proxy_contract: if proxy_contract:
detectors_results, _ = _checks_on_contract_and_proxy(detectors, detectors_results, _ = _checks_on_contract_and_proxy(
v2_contract, detectors, v2_contract, proxy_contract
proxy_contract) )
json_results['detectors'] += detectors_results json_results["detectors"] += detectors_results
detectors_results, number_detectors = _checks_on_contract_update(detectors, v1_contract, v2_contract) detectors_results, number_detectors = _checks_on_contract_update(
json_results['detectors'] += detectors_results detectors, v1_contract, v2_contract
)
json_results["detectors"] += detectors_results
number_detectors_run += number_detectors number_detectors_run += number_detectors
# If there is a V2, we run the contract-only check on the V2 # If there is a V2, we run the contract-only check on the V2
detectors_results, _ = _checks_on_contract(detectors, v2_contract) detectors_results, _ = _checks_on_contract(detectors, v2_contract)
json_results['detectors'] += detectors_results json_results["detectors"] += detectors_results
number_detectors_run += number_detectors number_detectors_run += number_detectors
logger.info(f'{len(json_results["detectors"])} findings, {number_detectors_run} detectors run') logger.info(
f'{len(json_results["detectors"])} findings, {number_detectors_run} detectors run'
)
if args.json: if args.json:
output_to_json(args.json, None, json_results) output_to_json(args.json, None, json_results)
@ -237,4 +256,5 @@ def main():
output_to_json(args.json, str(e), json_results) output_to_json(args.json, str(e), json_results)
return return
# endregion # endregion

@ -19,28 +19,28 @@ classification_colors = {
CheckClassification.INFORMATIONAL: green, CheckClassification.INFORMATIONAL: green,
CheckClassification.LOW: yellow, CheckClassification.LOW: yellow,
CheckClassification.MEDIUM: yellow, CheckClassification.MEDIUM: yellow,
CheckClassification.HIGH: red CheckClassification.HIGH: red,
} }
classification_txt = { classification_txt = {
CheckClassification.INFORMATIONAL: 'Informational', CheckClassification.INFORMATIONAL: "Informational",
CheckClassification.LOW: 'Low', CheckClassification.LOW: "Low",
CheckClassification.MEDIUM: 'Medium', CheckClassification.MEDIUM: "Medium",
CheckClassification.HIGH: 'High', CheckClassification.HIGH: "High",
} }
class AbstractCheck(metaclass=abc.ABCMeta): class AbstractCheck(metaclass=abc.ABCMeta):
ARGUMENT = '' ARGUMENT = ""
HELP = '' HELP = ""
IMPACT = None IMPACT = None
WIKI = '' WIKI = ""
WIKI_TITLE = '' WIKI_TITLE = ""
WIKI_DESCRIPTION = '' WIKI_DESCRIPTION = ""
WIKI_EXPLOIT_SCENARIO = '' WIKI_EXPLOIT_SCENARIO = ""
WIKI_RECOMMENDATION = '' WIKI_RECOMMENDATION = ""
REQUIRE_CONTRACT = False REQUIRE_CONTRACT = False
REQUIRE_PROXY = False REQUIRE_PROXY = False
@ -53,43 +53,69 @@ class AbstractCheck(metaclass=abc.ABCMeta):
self.contract_v2 = contract_v2 self.contract_v2 = contract_v2
if not self.ARGUMENT: if not self.ARGUMENT:
raise IncorrectCheckInitialization('NAME is not initialized {}'.format(self.__class__.__name__)) raise IncorrectCheckInitialization(
"NAME is not initialized {}".format(self.__class__.__name__)
)
if not self.HELP: if not self.HELP:
raise IncorrectCheckInitialization('HELP is not initialized {}'.format(self.__class__.__name__)) raise IncorrectCheckInitialization(
"HELP is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI: if not self.WIKI:
raise IncorrectCheckInitialization('WIKI is not initialized {}'.format(self.__class__.__name__)) raise IncorrectCheckInitialization(
"WIKI is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI_TITLE: if not self.WIKI_TITLE:
raise IncorrectCheckInitialization('WIKI_TITLE is not initialized {}'.format(self.__class__.__name__)) raise IncorrectCheckInitialization(
"WIKI_TITLE is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI_DESCRIPTION: if not self.WIKI_DESCRIPTION:
raise IncorrectCheckInitialization('WIKI_DESCRIPTION is not initialized {}'.format(self.__class__.__name__)) raise IncorrectCheckInitialization(
"WIKI_DESCRIPTION is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI_EXPLOIT_SCENARIO and self.IMPACT not in [CheckClassification.INFORMATIONAL]: if not self.WIKI_EXPLOIT_SCENARIO and self.IMPACT not in [
raise IncorrectCheckInitialization('WIKI_EXPLOIT_SCENARIO is not initialized {}'.format(self.__class__.__name__)) CheckClassification.INFORMATIONAL
]:
raise IncorrectCheckInitialization(
"WIKI_EXPLOIT_SCENARIO is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI_RECOMMENDATION: if not self.WIKI_RECOMMENDATION:
raise IncorrectCheckInitialization('WIKI_RECOMMENDATION is not initialized {}'.format(self.__class__.__name__)) raise IncorrectCheckInitialization(
"WIKI_RECOMMENDATION is not initialized {}".format(self.__class__.__name__)
)
if self.REQUIRE_PROXY and self.REQUIRE_CONTRACT_V2: if self.REQUIRE_PROXY and self.REQUIRE_CONTRACT_V2:
# This is not a fundatemenal issues # This is not a fundatemenal issues
# But it requires to change __main__ to avoid running two times the detectors # But it requires to change __main__ to avoid running two times the detectors
txt = 'REQUIRE_PROXY and REQUIRE_CONTRACT_V2 needs change in __main___ {}'.format(self.__class__.__name__) txt = "REQUIRE_PROXY and REQUIRE_CONTRACT_V2 needs change in __main___ {}".format(
self.__class__.__name__
)
raise IncorrectCheckInitialization(txt) raise IncorrectCheckInitialization(txt)
if self.IMPACT not in [CheckClassification.LOW, if self.IMPACT not in [
CheckClassification.MEDIUM, CheckClassification.LOW,
CheckClassification.HIGH, CheckClassification.MEDIUM,
CheckClassification.INFORMATIONAL]: CheckClassification.HIGH,
raise IncorrectCheckInitialization('IMPACT is not initialized {}'.format(self.__class__.__name__)) CheckClassification.INFORMATIONAL,
]:
raise IncorrectCheckInitialization(
"IMPACT is not initialized {}".format(self.__class__.__name__)
)
if self.REQUIRE_CONTRACT_V2 and contract_v2 is None: if self.REQUIRE_CONTRACT_V2 and contract_v2 is None:
raise IncorrectCheckInitialization('ContractV2 is not initialized {}'.format(self.__class__.__name__)) raise IncorrectCheckInitialization(
"ContractV2 is not initialized {}".format(self.__class__.__name__)
)
if self.REQUIRE_PROXY and proxy is None: if self.REQUIRE_PROXY and proxy is None:
raise IncorrectCheckInitialization('Proxy is not initialized {}'.format(self.__class__.__name__)) raise IncorrectCheckInitialization(
"Proxy is not initialized {}".format(self.__class__.__name__)
)
@abc.abstractmethod @abc.abstractmethod
def _check(self): def _check(self):
@ -102,19 +128,17 @@ class AbstractCheck(metaclass=abc.ABCMeta):
all_results = [r.data for r in all_results] all_results = [r.data for r in all_results]
if all_results: if all_results:
if self.logger: if self.logger:
info = '\n' info = "\n"
for idx, result in enumerate(all_results): for idx, result in enumerate(all_results):
info += result['description'] info += result["description"]
info += 'Reference: {}'.format(self.WIKI) info += "Reference: {}".format(self.WIKI)
self._log(info) self._log(info)
return all_results return all_results
def generate_result(self, info, additional_fields=None): def generate_result(self, info, additional_fields=None):
output = Output(info, output = Output(info, additional_fields, markdown_root=self.contract.slither.markdown_root)
additional_fields,
markdown_root=self.contract.slither.markdown_root)
output.data['check'] = self.ARGUMENT output.data["check"] = self.ARGUMENT
return output return output

@ -1,11 +1,23 @@
from .initialization import (InitializablePresent, InitializableInherited, from .initialization import (
InitializableInitializer, MissingInitializerModifier, MissingCalls, MultipleCalls, InitializeTarget) InitializablePresent,
InitializableInherited,
InitializableInitializer,
MissingInitializerModifier,
MissingCalls,
MultipleCalls,
InitializeTarget,
)
from .functions_ids import IDCollision, FunctionShadowing from .functions_ids import IDCollision, FunctionShadowing
from .variable_initialization import VariableWithInit from .variable_initialization import VariableWithInit
from .variables_order import (MissingVariable, DifferentVariableContractProxy, from .variables_order import (
DifferentVariableContractNewContract, ExtraVariablesProxy, ExtraVariablesNewContract) MissingVariable,
DifferentVariableContractProxy,
DifferentVariableContractNewContract,
ExtraVariablesProxy,
ExtraVariablesNewContract,
)
from .constant import WereConstant, BecameConstant from .constant import WereConstant, BecameConstant

@ -2,17 +2,17 @@ from slither.tools.upgradeability.checks.abstract_checks import AbstractCheck, C
class WereConstant(AbstractCheck): class WereConstant(AbstractCheck):
ARGUMENT = 'were-constant' ARGUMENT = "were-constant"
IMPACT = CheckClassification.HIGH IMPACT = CheckClassification.HIGH
HELP = 'Variables that should be constant' HELP = "Variables that should be constant"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-be-constant' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-be-constant"
WIKI_TITLE = 'Variables that should be constant' WIKI_TITLE = "Variables that should be constant"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect state variables that should be `constant̀`. Detect state variables that should be `constant̀`.
''' """
WIKI_EXPLOIT_SCENARIO = ''' WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Contract{ contract Contract{
uint variable1; uint variable1;
@ -28,11 +28,11 @@ contract ContractV2{
``` ```
Because `variable2` is not anymore a `constant`, the storage location of `variable3` will be different. Because `variable2` is not anymore a `constant`, the storage location of `variable3` will be different.
As a result, `ContractV2` will have a corrupted storage layout. As a result, `ContractV2` will have a corrupted storage layout.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Do not remove `constant` from a state variables during an update. Do not remove `constant` from a state variables during an update.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
REQUIRE_CONTRACT_V2 = True REQUIRE_CONTRACT_V2 = True
@ -66,12 +66,13 @@ Do not remove `constant` from a state variables during an update.
if state_v1.is_constant: if state_v1.is_constant:
if not state_v2.is_constant: if not state_v2.is_constant:
# If v2 has additional non constant variables, we need to skip them # If v2 has additional non constant variables, we need to skip them
if ((state_v1.name != state_v2.name or state_v1.type != state_v2.type) if (
and v2_additional_variables > 0): state_v1.name != state_v2.name or state_v1.type != state_v2.type
) and v2_additional_variables > 0:
v2_additional_variables -= 1 v2_additional_variables -= 1
idx_v2 += 1 idx_v2 += 1
continue continue
info = [state_v1, ' was constant, but ', state_v2, 'is not.\n'] info = [state_v1, " was constant, but ", state_v2, "is not.\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)
@ -80,19 +81,20 @@ Do not remove `constant` from a state variables during an update.
return results return results
class BecameConstant(AbstractCheck): class BecameConstant(AbstractCheck):
ARGUMENT = 'became-constant' ARGUMENT = "became-constant"
IMPACT = CheckClassification.HIGH IMPACT = CheckClassification.HIGH
HELP = 'Variables that should not be constant' HELP = "Variables that should not be constant"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-not-be-constant' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-not-be-constant"
WIKI_TITLE = 'Variables that should not be constant' WIKI_TITLE = "Variables that should not be constant"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect state variables that should not be `constant̀`. Detect state variables that should not be `constant̀`.
''' """
WIKI_EXPLOIT_SCENARIO = ''' WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Contract{ contract Contract{
uint variable1; uint variable1;
@ -108,11 +110,11 @@ contract ContractV2{
``` ```
Because `variable2` is now a `constant`, the storage location of `variable3` will be different. Because `variable2` is now a `constant`, the storage location of `variable3` will be different.
As a result, `ContractV2` will have a corrupted storage layout. As a result, `ContractV2` will have a corrupted storage layout.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Do not make an existing state variable `constant`. Do not make an existing state variable `constant`.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
REQUIRE_CONTRACT_V2 = True REQUIRE_CONTRACT_V2 = True
@ -146,13 +148,14 @@ Do not make an existing state variable `constant`.
if state_v1.is_constant: if state_v1.is_constant:
if not state_v2.is_constant: if not state_v2.is_constant:
# If v2 has additional non constant variables, we need to skip them # If v2 has additional non constant variables, we need to skip them
if ((state_v1.name != state_v2.name or state_v1.type != state_v2.type) if (
and v2_additional_variables > 0): state_v1.name != state_v2.name or state_v1.type != state_v2.type
) and v2_additional_variables > 0:
v2_additional_variables -= 1 v2_additional_variables -= 1
idx_v2 += 1 idx_v2 += 1
continue continue
elif state_v2.is_constant: elif state_v2.is_constant:
info = [state_v1, ' was not constant but ', state_v2, ' is.\n'] info = [state_v1, " was not constant but ", state_v2, " is.\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)

@ -5,11 +5,16 @@ from slither.utils.function import get_function_id
def get_signatures(c): def get_signatures(c):
functions = c.functions functions = c.functions
functions = [f.full_name for f in functions if f.visibility in ['public', 'external'] and functions = [
not f.is_constructor and not f.is_fallback] f.full_name
for f in functions
if f.visibility in ["public", "external"] and not f.is_constructor and not f.is_fallback
]
variables = c.state_variables variables = c.state_variables
variables = [variable.name + '()' for variable in variables if variable.visibility in ['public']] variables = [
variable.name + "()" for variable in variables if variable.visibility in ["public"]
]
return list(set(functions + variables)) return list(set(functions + variables))
@ -21,26 +26,26 @@ def _get_function_or_variable(contract, signature):
for variable in contract.state_variables: for variable in contract.state_variables:
# Todo: can lead to incorrect variable in case of shadowing # Todo: can lead to incorrect variable in case of shadowing
if variable.visibility in ['public']: if variable.visibility in ["public"]:
if variable.name + '()' == signature: if variable.name + "()" == signature:
return variable return variable
raise SlitherError(f'Function id checks: {signature} not found in {contract.name}') raise SlitherError(f"Function id checks: {signature} not found in {contract.name}")
class IDCollision(AbstractCheck): class IDCollision(AbstractCheck):
ARGUMENT = 'function-id-collision' ARGUMENT = "function-id-collision"
IMPACT = CheckClassification.HIGH IMPACT = CheckClassification.HIGH
HELP = 'Functions ids collision' HELP = "Functions ids collision"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-ids-collisions' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-ids-collisions"
WIKI_TITLE = 'Functions ids collisions' WIKI_TITLE = "Functions ids collisions"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect function id collision between the contract and the proxy. Detect function id collision between the contract and the proxy.
''' """
WIKI_EXPLOIT_SCENARIO = ''' WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Contract{ contract Contract{
function gsf() public { function gsf() public {
@ -56,11 +61,11 @@ contract Proxy{
``` ```
`Proxy.tgeo()` and `Contract.gsf()` have the same function id (0x67e43e43). `Proxy.tgeo()` and `Contract.gsf()` have the same function id (0x67e43e43).
As a result, `Proxy.tgeo()` will shadow Contract.gsf()`. As a result, `Proxy.tgeo()` will shadow Contract.gsf()`.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Rename the function. Avoid public functions in the proxy. Rename the function. Avoid public functions in the proxy.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
REQUIRE_PROXY = True REQUIRE_PROXY = True
@ -77,11 +82,18 @@ Rename the function. Avoid public functions in the proxy.
for (k, _) in signatures_ids_implem.items(): for (k, _) in signatures_ids_implem.items():
if k in signatures_ids_proxy: if k in signatures_ids_proxy:
if signatures_ids_implem[k] != signatures_ids_proxy[k]: if signatures_ids_implem[k] != signatures_ids_proxy[k]:
implem_function = _get_function_or_variable(self.contract, signatures_ids_implem[k]) implem_function = _get_function_or_variable(
self.contract, signatures_ids_implem[k]
)
proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k]) proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k])
info = ['Function id collision found: ', implem_function, info = [
' ', proxy_function, '\n'] "Function id collision found: ",
implem_function,
" ",
proxy_function,
"\n",
]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)
@ -89,18 +101,18 @@ Rename the function. Avoid public functions in the proxy.
class FunctionShadowing(AbstractCheck): class FunctionShadowing(AbstractCheck):
ARGUMENT = 'function-shadowing' ARGUMENT = "function-shadowing"
IMPACT = CheckClassification.HIGH IMPACT = CheckClassification.HIGH
HELP = 'Functions shadowing' HELP = "Functions shadowing"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-shadowing' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-shadowing"
WIKI_TITLE = 'Functions shadowing' WIKI_TITLE = "Functions shadowing"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect function shadowing between the contract and the proxy. Detect function shadowing between the contract and the proxy.
''' """
WIKI_EXPLOIT_SCENARIO = ''' WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Contract{ contract Contract{
function get() public { function get() public {
@ -115,11 +127,11 @@ contract Proxy{
} }
``` ```
`Proxy.get` will shadow any call to `get()`. As a result `get()` is never executed in the logic contract and cannot be updated. `Proxy.get` will shadow any call to `get()`. As a result `get()` is never executed in the logic contract and cannot be updated.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Rename the function. Avoid public functions in the proxy. Rename the function. Avoid public functions in the proxy.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
REQUIRE_PROXY = True REQUIRE_PROXY = True
@ -136,11 +148,18 @@ Rename the function. Avoid public functions in the proxy.
for (k, _) in signatures_ids_implem.items(): for (k, _) in signatures_ids_implem.items():
if k in signatures_ids_proxy: if k in signatures_ids_proxy:
if signatures_ids_implem[k] == signatures_ids_proxy[k]: if signatures_ids_implem[k] == signatures_ids_proxy[k]:
implem_function = _get_function_or_variable(self.contract, signatures_ids_implem[k]) implem_function = _get_function_or_variable(
self.contract, signatures_ids_implem[k]
)
proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k]) proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k])
info = ['Function shadowing found: ', implem_function, info = [
' ', proxy_function, '\n'] "Function shadowing found: ",
implem_function,
" ",
proxy_function,
"\n",
]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)

@ -13,16 +13,20 @@ class MultipleInitTarget(Exception):
def _get_initialize_functions(contract): def _get_initialize_functions(contract):
return [f for f in contract.functions if f.name == 'initialize' and f.is_implemented] return [f for f in contract.functions if f.name == "initialize" and f.is_implemented]
def _get_all_internal_calls(function): def _get_all_internal_calls(function):
all_ir = function.all_slithir_operations() all_ir = function.all_slithir_operations()
return [i.function for i in all_ir if isinstance(i, InternalCall) and i.function_name == "initialize"] return [
i.function
for i in all_ir
if isinstance(i, InternalCall) and i.function_name == "initialize"
]
def _get_most_derived_init(contract): def _get_most_derived_init(contract):
init_functions = [f for f in contract.functions if not f.is_shadowed and f.name == 'initialize'] init_functions = [f for f in contract.functions if not f.is_shadowed and f.name == "initialize"]
if len(init_functions) > 1: if len(init_functions) > 1:
if len([f for f in init_functions if f.contract_declarer == contract]) == 1: if len([f for f in init_functions if f.contract_declarer == contract]) == 1:
return next((f for f in init_functions if f.contract_declarer == contract)) return next((f for f in init_functions if f.contract_declarer == contract))
@ -33,80 +37,82 @@ def _get_most_derived_init(contract):
class InitializablePresent(AbstractCheck): class InitializablePresent(AbstractCheck):
ARGUMENT = 'init-missing' ARGUMENT = "init-missing"
IMPACT = CheckClassification.INFORMATIONAL IMPACT = CheckClassification.INFORMATIONAL
HELP = 'Initializable is missing' HELP = "Initializable is missing"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing"
WIKI_TITLE = 'Initializable is missing' WIKI_TITLE = "Initializable is missing"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect if a contract `Initializable` is present. Detect if a contract `Initializable` is present.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Review manually the contract's initialization.. Review manually the contract's initialization..
Consider using a `Initializable` contract to follow [standard practice](https://docs.openzeppelin.com/upgrades/2.7/writing-upgradeable). Consider using a `Initializable` contract to follow [standard practice](https://docs.openzeppelin.com/upgrades/2.7/writing-upgradeable).
''' """
def _check(self): def _check(self):
initializable = self.contract.slither.get_contract_from_name('Initializable') initializable = self.contract.slither.get_contract_from_name("Initializable")
if initializable is None: if initializable is None:
info = ["Initializable contract not found, the contract does not follow a standard initalization schema.\n"] info = [
"Initializable contract not found, the contract does not follow a standard initalization schema.\n"
]
json = self.generate_result(info) json = self.generate_result(info)
return [json] return [json]
return [] return []
class InitializableInherited(AbstractCheck): class InitializableInherited(AbstractCheck):
ARGUMENT = 'init-inherited' ARGUMENT = "init-inherited"
IMPACT = CheckClassification.INFORMATIONAL IMPACT = CheckClassification.INFORMATIONAL
HELP = 'Initializable is not inherited' HELP = "Initializable is not inherited"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-not-inherited' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-not-inherited"
WIKI_TITLE = 'Initializable is not inherited' WIKI_TITLE = "Initializable is not inherited"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect if `Initializable` is inherited. Detect if `Initializable` is inherited.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Review manually the contract's initialization. Consider inheriting `Initializable`. Review manually the contract's initialization. Consider inheriting `Initializable`.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
def _check(self): def _check(self):
initializable = self.contract.slither.get_contract_from_name('Initializable') initializable = self.contract.slither.get_contract_from_name("Initializable")
# See InitializablePresent # See InitializablePresent
if initializable is None: if initializable is None:
return [] return []
if initializable not in self.contract.inheritance: if initializable not in self.contract.inheritance:
info = [self.contract, ' does not inherit from ', initializable, '.\n'] info = [self.contract, " does not inherit from ", initializable, ".\n"]
json = self.generate_result(info) json = self.generate_result(info)
return [json] return [json]
return [] return []
class InitializableInitializer(AbstractCheck): class InitializableInitializer(AbstractCheck):
ARGUMENT = 'initializer-missing' ARGUMENT = "initializer-missing"
IMPACT = CheckClassification.INFORMATIONAL IMPACT = CheckClassification.INFORMATIONAL
HELP = 'initializer() is missing' HELP = "initializer() is missing"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-missing' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-missing"
WIKI_TITLE = 'initializer() is missing' WIKI_TITLE = "initializer() is missing"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect the lack of `Initializable.initializer()` modifier. Detect the lack of `Initializable.initializer()` modifier.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Review manually the contract's initialization. Consider inheriting a `Initializable.initializer()` modifier. Review manually the contract's initialization. Consider inheriting a `Initializable.initializer()` modifier.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
def _check(self): def _check(self):
initializable = self.contract.slither.get_contract_from_name('Initializable') initializable = self.contract.slither.get_contract_from_name("Initializable")
# See InitializablePresent # See InitializablePresent
if initializable is None: if initializable is None:
return [] return []
@ -114,26 +120,26 @@ Review manually the contract's initialization. Consider inheriting a `Initializa
if initializable not in self.contract.inheritance: if initializable not in self.contract.inheritance:
return [] return []
initializer = self.contract.get_modifier_from_canonical_name('Initializable.initializer()') initializer = self.contract.get_modifier_from_canonical_name("Initializable.initializer()")
if initializer is None: if initializer is None:
info = ['Initializable.initializer() does not exist.\n'] info = ["Initializable.initializer() does not exist.\n"]
json = self.generate_result(info) json = self.generate_result(info)
return [json] return [json]
return [] return []
class MissingInitializerModifier(AbstractCheck): class MissingInitializerModifier(AbstractCheck):
ARGUMENT = 'missing-init-modifier' ARGUMENT = "missing-init-modifier"
IMPACT = CheckClassification.HIGH IMPACT = CheckClassification.HIGH
HELP = 'initializer() is not called' HELP = "initializer() is not called"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-not-called' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-not-called"
WIKI_TITLE = 'initializer() is not called' WIKI_TITLE = "initializer() is not called"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect if `Initializable.initializer()` is called. Detect if `Initializable.initializer()` is called.
''' """
WIKI_EXPLOIT_SCENARIO = ''' WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Contract{ contract Contract{
function initialize() public{ function initialize() public{
@ -143,23 +149,23 @@ contract Contract{
``` ```
`initialize` should have the `initializer` modifier to prevent someone from initializing the contract multiple times. `initialize` should have the `initializer` modifier to prevent someone from initializing the contract multiple times.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Use `Initializable.initializer()`. Use `Initializable.initializer()`.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
def _check(self): def _check(self):
initializable = self.contract.slither.get_contract_from_name('Initializable') initializable = self.contract.slither.get_contract_from_name("Initializable")
# See InitializablePresent # See InitializablePresent
if initializable is None: if initializable is None:
return [] return []
# See InitializableInherited # See InitializableInherited
if initializable not in self.contract.inheritance: if initializable not in self.contract.inheritance:
return [] return []
initializer = self.contract.get_modifier_from_canonical_name('Initializable.initializer()') initializer = self.contract.get_modifier_from_canonical_name("Initializable.initializer()")
# InitializableInitializer # InitializableInitializer
if initializer is None: if initializer is None:
return [] return []
@ -168,24 +174,24 @@ Use `Initializable.initializer()`.
all_init_functions = _get_initialize_functions(self.contract) all_init_functions = _get_initialize_functions(self.contract)
for f in all_init_functions: for f in all_init_functions:
if initializer not in f.modifiers: if initializer not in f.modifiers:
info = [f, ' does not call the initializer modifier.\n'] info = [f, " does not call the initializer modifier.\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)
return results return results
class MissingCalls(AbstractCheck): class MissingCalls(AbstractCheck):
ARGUMENT = 'missing-calls' ARGUMENT = "missing-calls"
IMPACT = CheckClassification.HIGH IMPACT = CheckClassification.HIGH
HELP = 'Missing calls to init functions' HELP = "Missing calls to init functions"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-not-called' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-not-called"
WIKI_TITLE = 'Initialize functions are not called' WIKI_TITLE = "Initialize functions are not called"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect missing calls to initialize functions. Detect missing calls to initialize functions.
''' """
WIKI_EXPLOIT_SCENARIO = ''' WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Base{ contract Base{
function initialize() public{ function initialize() public{
@ -200,11 +206,11 @@ contract Derived is Base{
``` ```
`Derived.initialize` does not call `Base.initialize` leading the contract to not be correctly initialized. `Derived.initialize` does not call `Base.initialize` leading the contract to not be correctly initialized.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Ensure all the initialize functions are reached by the most derived initialize function. Ensure all the initialize functions are reached by the most derived initialize function.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
@ -215,7 +221,7 @@ Ensure all the initialize functions are reached by the most derived initialize f
try: try:
most_derived_init = _get_most_derived_init(self.contract) most_derived_init = _get_most_derived_init(self.contract)
except MultipleInitTarget: except MultipleInitTarget:
logger.error(red(f'Too many init targets in {self.contract}')) logger.error(red(f"Too many init targets in {self.contract}"))
return [] return []
if most_derived_init is None: if most_derived_init is None:
@ -225,24 +231,24 @@ Ensure all the initialize functions are reached by the most derived initialize f
all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init] all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init]
missing_calls = [f for f in all_init_functions if not f in all_init_functions_called] missing_calls = [f for f in all_init_functions if not f in all_init_functions_called]
for f in missing_calls: for f in missing_calls:
info = ['Missing call to ', f, ' in ', most_derived_init, '.\n'] info = ["Missing call to ", f, " in ", most_derived_init, ".\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)
return results return results
class MultipleCalls(AbstractCheck): class MultipleCalls(AbstractCheck):
ARGUMENT = 'multiple-calls' ARGUMENT = "multiple-calls"
IMPACT = CheckClassification.HIGH IMPACT = CheckClassification.HIGH
HELP = 'Init functions called multiple times' HELP = "Init functions called multiple times"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-called-multiple-times' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-called-multiple-times"
WIKI_TITLE = 'Initialize functions are called multiple times' WIKI_TITLE = "Initialize functions are called multiple times"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect multiple calls to a initialize function. Detect multiple calls to a initialize function.
''' """
WIKI_EXPLOIT_SCENARIO = ''' WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Base{ contract Base{
function initialize(uint) public{ function initialize(uint) public{
@ -264,11 +270,11 @@ contract DerivedDerived is Derived{
``` ```
`Base.initialize(uint)` is called two times in `DerivedDerived.initiliaze` execution, leading to a potential corruption. `Base.initialize(uint)` is called two times in `DerivedDerived.initiliaze` execution, leading to a potential corruption.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Call only one time every initialize function. Call only one time every initialize function.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
@ -280,38 +286,41 @@ Call only one time every initialize function.
most_derived_init = _get_most_derived_init(self.contract) most_derived_init = _get_most_derived_init(self.contract)
except MultipleInitTarget: except MultipleInitTarget:
# Should be already reported by MissingCalls # Should be already reported by MissingCalls
#logger.error(red(f'Too many init targets in {self.contract}')) # logger.error(red(f'Too many init targets in {self.contract}'))
return [] return []
if most_derived_init is None: if most_derived_init is None:
return [] return []
all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init] all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init]
double_calls = list(set([f for f in all_init_functions_called if all_init_functions_called.count(f) > 1])) double_calls = list(
set([f for f in all_init_functions_called if all_init_functions_called.count(f) > 1])
)
for f in double_calls: for f in double_calls:
info = [f, ' is called multiple times in ', most_derived_init, '.\n'] info = [f, " is called multiple times in ", most_derived_init, ".\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)
return results return results
class InitializeTarget(AbstractCheck): class InitializeTarget(AbstractCheck):
ARGUMENT = 'initialize-target' ARGUMENT = "initialize-target"
IMPACT = CheckClassification.INFORMATIONAL IMPACT = CheckClassification.INFORMATIONAL
HELP = 'Initialize function that must be called' HELP = "Initialize function that must be called"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-function' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-function"
WIKI_TITLE = 'Initialize function' WIKI_TITLE = "Initialize function"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Show the function that must be called at deployment. Show the function that must be called at deployment.
This finding does not have an immediate security impact and is informative. This finding does not have an immediate security impact and is informative.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Ensure that the function is called at deployment. Ensure that the function is called at deployment.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
@ -322,12 +331,12 @@ Ensure that the function is called at deployment.
most_derived_init = _get_most_derived_init(self.contract) most_derived_init = _get_most_derived_init(self.contract)
except MultipleInitTarget: except MultipleInitTarget:
# Should be already reported by MissingCalls # Should be already reported by MissingCalls
#logger.error(red(f'Too many init targets in {self.contract}')) # logger.error(red(f'Too many init targets in {self.contract}'))
return [] return []
if most_derived_init is None: if most_derived_init is None:
return [] return []
info = [self.contract, f' needs to be initialized by ', most_derived_init, '.\n'] info = [self.contract, f" needs to be initialized by ", most_derived_init, ".\n"]
json = self.generate_result(info) json = self.generate_result(info)
return [json] return [json]

@ -2,29 +2,29 @@ from slither.tools.upgradeability.checks.abstract_checks import CheckClassificat
class VariableWithInit(AbstractCheck): class VariableWithInit(AbstractCheck):
ARGUMENT = 'variables-initialized' ARGUMENT = "variables-initialized"
IMPACT = CheckClassification.HIGH IMPACT = CheckClassification.HIGH
HELP = 'State variables with an initial value' HELP = "State variables with an initial value"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#state-variable-initialized' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#state-variable-initialized"
WIKI_TITLE = 'State variable initialized' WIKI_TITLE = "State variable initialized"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect state variables that are initialized. Detect state variables that are initialized.
''' """
WIKI_EXPLOIT_SCENARIO = ''' WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Contract{ contract Contract{
uint variable = 10; uint variable = 10;
} }
``` ```
Using `Contract` will the delegatecall proxy pattern will lead `variable` to be 0 when called through the proxy. Using `Contract` will the delegatecall proxy pattern will lead `variable` to be 0 when called through the proxy.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Using initialize functions to write initial values in state variables. Using initialize functions to write initial values in state variables.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
@ -32,7 +32,7 @@ Using initialize functions to write initial values in state variables.
results = [] results = []
for s in self.contract.state_variables: for s in self.contract.state_variables:
if s.initialized and not s.is_constant: if s.initialized and not s.is_constant:
info = [s, ' is a state variable with an initial value.\n'] info = [s, " is a state variable with an initial value.\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)
return results return results

@ -2,16 +2,16 @@ from slither.tools.upgradeability.checks.abstract_checks import CheckClassificat
class MissingVariable(AbstractCheck): class MissingVariable(AbstractCheck):
ARGUMENT = 'missing-variables' ARGUMENT = "missing-variables"
IMPACT = CheckClassification.MEDIUM IMPACT = CheckClassification.MEDIUM
HELP = 'Variable missing in the v2' HELP = "Variable missing in the v2"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-variables' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-variables"
WIKI_TITLE = 'Missing variables' WIKI_TITLE = "Missing variables"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect variables that were present in the original contracts but are not in the updated one. Detect variables that were present in the original contracts but are not in the updated one.
''' """
WIKI_EXPLOIT_SCENARIO = ''' WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract V1{ contract V1{
uint variable1; uint variable1;
@ -25,11 +25,11 @@ contract V2{
The new version, `V2` does not contain `variable1`. The new version, `V2` does not contain `variable1`.
If a new variable is added in an update of `V2`, this variable will hold the latest value of `variable2` and If a new variable is added in an update of `V2`, this variable will hold the latest value of `variable2` and
will be corrupted. will be corrupted.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Do not change the order of the state variables in the updated contract. Do not change the order of the state variables in the updated contract.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
REQUIRE_CONTRACT_V2 = True REQUIRE_CONTRACT_V2 = True
@ -44,7 +44,7 @@ Do not change the order of the state variables in the updated contract.
for idx in range(0, len(order1)): for idx in range(0, len(order1)):
variable1 = order1[idx] variable1 = order1[idx]
if len(order2) <= idx: if len(order2) <= idx:
info = ['Variable missing in ', contract2, ': ', variable1, '\n'] info = ["Variable missing in ", contract2, ": ", variable1, "\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)
@ -52,18 +52,18 @@ Do not change the order of the state variables in the updated contract.
class DifferentVariableContractProxy(AbstractCheck): class DifferentVariableContractProxy(AbstractCheck):
ARGUMENT = 'order-vars-proxy' ARGUMENT = "order-vars-proxy"
IMPACT = CheckClassification.HIGH IMPACT = CheckClassification.HIGH
HELP = 'Incorrect vars order with the proxy' HELP = "Incorrect vars order with the proxy"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-proxy' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-proxy"
WIKI_TITLE = 'Incorrect variables with the proxy' WIKI_TITLE = "Incorrect variables with the proxy"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect variables that are different between the contract and the proxy. Detect variables that are different between the contract and the proxy.
''' """
WIKI_EXPLOIT_SCENARIO = ''' WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Contract{ contract Contract{
uint variable1; uint variable1;
@ -74,11 +74,11 @@ contract Proxy{
} }
``` ```
`Contract` and `Proxy` do not have the same storage layout. As a result the storage of both contracts can be corrupted. `Contract` and `Proxy` do not have the same storage layout. As a result the storage of both contracts can be corrupted.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Avoid variables in the proxy. If a variable is in the proxy, ensure it has the same layout than in the contract. Avoid variables in the proxy. If a variable is in the proxy, ensure it has the same layout than in the contract.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
REQUIRE_PROXY = True REQUIRE_PROXY = True
@ -104,9 +104,9 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
variable1 = order1[idx] variable1 = order1[idx]
variable2 = order2[idx] variable2 = order2[idx]
if (variable1.name != variable2.name) or (variable1.type != variable2.type): if (variable1.name != variable2.name) or (variable1.type != variable2.type):
info = ['Different variables between ', contract1, ' and ', contract2, '\n'] info = ["Different variables between ", contract1, " and ", contract2, "\n"]
info += [f'\t ', variable1, '\n'] info += [f"\t ", variable1, "\n"]
info += [f'\t ', variable2, '\n'] info += [f"\t ", variable2, "\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)
@ -114,17 +114,17 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
class DifferentVariableContractNewContract(DifferentVariableContractProxy): class DifferentVariableContractNewContract(DifferentVariableContractProxy):
ARGUMENT = 'order-vars-contracts' ARGUMENT = "order-vars-contracts"
HELP = 'Incorrect vars order with the v2' HELP = "Incorrect vars order with the v2"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-v2' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-v2"
WIKI_TITLE = 'Incorrect variables with the v2' WIKI_TITLE = "Incorrect variables with the v2"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect variables that are different between the original contract and the updated one. Detect variables that are different between the original contract and the updated one.
''' """
WIKI_EXPLOIT_SCENARIO = ''' WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Contract{ contract Contract{
uint variable1; uint variable1;
@ -135,11 +135,11 @@ contract ContractV2{
} }
``` ```
`Contract` and `ContractV2` do not have the same storage layout. As a result the storage of both contracts can be corrupted. `Contract` and `ContractV2` do not have the same storage layout. As a result the storage of both contracts can be corrupted.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Respect the variable order of the original contract in the updated contract. Respect the variable order of the original contract in the updated contract.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
REQUIRE_PROXY = False REQUIRE_PROXY = False
@ -150,18 +150,20 @@ Respect the variable order of the original contract in the updated contract.
class ExtraVariablesProxy(AbstractCheck): class ExtraVariablesProxy(AbstractCheck):
ARGUMENT = 'extra-vars-proxy' ARGUMENT = "extra-vars-proxy"
IMPACT = CheckClassification.MEDIUM IMPACT = CheckClassification.MEDIUM
HELP = 'Extra vars in the proxy' HELP = "Extra vars in the proxy"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-proxy' WIKI = (
WIKI_TITLE = 'Extra variables in the proxy' "https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-proxy"
)
WIKI_TITLE = "Extra variables in the proxy"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Detect variables that are in the proxy and not in the contract. Detect variables that are in the proxy and not in the contract.
''' """
WIKI_EXPLOIT_SCENARIO = ''' WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Contract{ contract Contract{
uint variable1; uint variable1;
@ -173,11 +175,11 @@ contract Proxy{
} }
``` ```
`Proxy` contains additional variables. A future update of `Contract` is likely to corrupt the proxy. `Proxy` contains additional variables. A future update of `Contract` is likely to corrupt the proxy.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Avoid variables in the proxy. If a variable is in the proxy, ensure it has the same layout than in the contract. Avoid variables in the proxy. If a variable is in the proxy, ensure it has the same layout than in the contract.
''' """
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
REQUIRE_PROXY = True REQUIRE_PROXY = True
@ -203,7 +205,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
while idx < len(order2): while idx < len(order2):
variable2 = order2[idx] variable2 = order2[idx]
info = ['Extra variables in ', contract2, ': ', variable2, '\n'] info = ["Extra variables in ", contract2, ": ", variable2, "\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)
idx = idx + 1 idx = idx + 1
@ -212,21 +214,21 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
class ExtraVariablesNewContract(ExtraVariablesProxy): class ExtraVariablesNewContract(ExtraVariablesProxy):
ARGUMENT = 'extra-vars-v2' ARGUMENT = "extra-vars-v2"
HELP = 'Extra vars in the v2' HELP = "Extra vars in the v2"
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-v2' WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-v2"
WIKI_TITLE = 'Extra variables in the v2' WIKI_TITLE = "Extra variables in the v2"
WIKI_DESCRIPTION = ''' WIKI_DESCRIPTION = """
Show new variables in the updated contract. Show new variables in the updated contract.
This finding does not have an immediate security impact and is informative. This finding does not have an immediate security impact and is informative.
''' """
WIKI_RECOMMENDATION = ''' WIKI_RECOMMENDATION = """
Ensure that all the new variables are expected. Ensure that all the new variables are expected.
''' """
IMPACT = CheckClassification.INFORMATIONAL IMPACT = CheckClassification.INFORMATIONAL

@ -4,7 +4,9 @@ from slither.utils.myprettytable import MyPrettyTable
def output_wiki(detector_classes, filter_wiki): def output_wiki(detector_classes, filter_wiki):
# Sort by impact, confidence, and name # Sort by impact, confidence, and name
detectors_list = sorted(detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT)) detectors_list = sorted(
detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT)
)
for detector in detectors_list: for detector in detectors_list:
if filter_wiki not in detector.WIKI: if filter_wiki not in detector.WIKI:
@ -16,16 +18,16 @@ def output_wiki(detector_classes, filter_wiki):
exploit_scenario = detector.WIKI_EXPLOIT_SCENARIO exploit_scenario = detector.WIKI_EXPLOIT_SCENARIO
recommendation = detector.WIKI_RECOMMENDATION recommendation = detector.WIKI_RECOMMENDATION
print('\n## {}'.format(title)) print("\n## {}".format(title))
print('### Configuration') print("### Configuration")
print('* Check: `{}`'.format(argument)) print("* Check: `{}`".format(argument))
print('* Severity: `{}`'.format(impact)) print("* Severity: `{}`".format(impact))
print('\n### Description') print("\n### Description")
print(description) print(description)
if exploit_scenario: if exploit_scenario:
print('\n### Exploit Scenario:') print("\n### Exploit Scenario:")
print(exploit_scenario) print(exploit_scenario)
print('\n### Recommendation') print("\n### Recommendation")
print(recommendation) print(recommendation)
@ -38,27 +40,31 @@ def output_detectors(detector_classes):
require_proxy = detector.REQUIRE_PROXY require_proxy = detector.REQUIRE_PROXY
require_v2 = detector.REQUIRE_CONTRACT_V2 require_v2 = detector.REQUIRE_CONTRACT_V2
detectors_list.append((argument, help_info, impact, require_proxy, require_v2)) detectors_list.append((argument, help_info, impact, require_proxy, require_v2))
table = MyPrettyTable(["Num", table = MyPrettyTable(["Num", "Check", "What it Detects", "Impact", "Proxy", "Contract V2"])
"Check",
"What it Detects",
"Impact",
"Proxy",
"Contract V2"])
# Sort by impact, confidence, and name # Sort by impact, confidence, and name
detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0]))
idx = 1 idx = 1
for (argument, help_info, impact, proxy, v2) in detectors_list: for (argument, help_info, impact, proxy, v2) in detectors_list:
table.add_row([idx, argument, help_info, classification_txt[impact], 'X' if proxy else '', 'X' if v2 else '']) table.add_row(
[
idx,
argument,
help_info,
classification_txt[impact],
"X" if proxy else "",
"X" if v2 else "",
]
)
idx = idx + 1 idx = idx + 1
print(table) print(table)
def output_to_markdown(detector_classes, filter_wiki): def output_to_markdown(detector_classes, filter_wiki):
def extract_help(cls): def extract_help(cls):
if cls.WIKI == '': if cls.WIKI == "":
return cls.HELP return cls.HELP
return '[{}]({})'.format(cls.HELP, cls.WIKI) return "[{}]({})".format(cls.HELP, cls.WIKI)
detectors_list = [] detectors_list = []
for detector in detector_classes: for detector in detector_classes:
@ -73,12 +79,16 @@ def output_to_markdown(detector_classes, filter_wiki):
detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0]))
idx = 1 idx = 1
for (argument, help_info, impact, proxy, v2) in detectors_list: for (argument, help_info, impact, proxy, v2) in detectors_list:
print('{} | `{}` | {} | {} | {} | {}'.format(idx, print(
argument, "{} | `{}` | {} | {} | {} | {}".format(
help_info, idx,
classification_txt[impact], argument,
'X' if proxy else '', help_info,
'X' if v2 else '')) classification_txt[impact],
"X" if proxy else "",
"X" if v2 else "",
)
)
idx = idx + 1 idx = idx + 1
@ -92,26 +102,42 @@ def output_detectors_json(detector_classes):
wiki_description = detector.WIKI_DESCRIPTION wiki_description = detector.WIKI_DESCRIPTION
wiki_exploit_scenario = detector.WIKI_EXPLOIT_SCENARIO wiki_exploit_scenario = detector.WIKI_EXPLOIT_SCENARIO
wiki_recommendation = detector.WIKI_RECOMMENDATION wiki_recommendation = detector.WIKI_RECOMMENDATION
detectors_list.append((argument, detectors_list.append(
help_info, (
impact, argument,
wiki_url, help_info,
wiki_description, impact,
wiki_exploit_scenario, wiki_url,
wiki_recommendation)) wiki_description,
wiki_exploit_scenario,
wiki_recommendation,
)
)
# Sort by impact, confidence, and name # Sort by impact, confidence, and name
detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0]))
idx = 1 idx = 1
table = [] table = []
for (argument, help_info, impact, wiki_url, description, exploit, recommendation) in detectors_list: for (
table.append({'index': idx, argument,
'check': argument, help_info,
'title': help_info, impact,
'impact': classification_txt[impact], wiki_url,
'wiki_url': wiki_url, description,
'description': description, exploit,
'exploit_scenario': exploit, recommendation,
'recommendation': recommendation}) ) in detectors_list:
table.append(
{
"index": idx,
"check": argument,
"title": help_info,
"impact": classification_txt[impact],
"wiki_url": wiki_url,
"description": description,
"exploit_scenario": exploit,
"recommendation": recommendation,
}
)
idx = idx + 1 idx = idx + 1
return table return table

@ -3,11 +3,13 @@ import logging
from .expression import ExpressionVisitor from .expression import ExpressionVisitor
from slither.core.expressions import BinaryOperationType, Literal from slither.core.expressions import BinaryOperationType, Literal
class NotConstant(Exception): class NotConstant(Exception):
pass pass
KEY = 'ConstantFolding' KEY = "ConstantFolding"
def get_val(expression): def get_val(expression):
val = expression.context[KEY] val = expression.context[KEY]
@ -15,11 +17,12 @@ def get_val(expression):
del expression.context[KEY] del expression.context[KEY]
return val return val
def set_val(expression, val): def set_val(expression, val):
expression.context[KEY] = val expression.context[KEY] = val
class ConstantFolding(ExpressionVisitor):
class ConstantFolding(ExpressionVisitor):
def __init__(self, expression, type): def __init__(self, expression, type):
self._type = type self._type = type
super(ConstantFolding, self).__init__(expression) super(ConstantFolding, self).__init__(expression)
@ -40,24 +43,24 @@ class ConstantFolding(ExpressionVisitor):
def _post_binary_operation(self, expression): def _post_binary_operation(self, expression):
left = get_val(expression.expression_left) left = get_val(expression.expression_left)
right = get_val(expression.expression_right) right = get_val(expression.expression_right)
if expression.type == BinaryOperationType.POWER: if expression.type == BinaryOperationType.POWER:
set_val(expression, left ** right) set_val(expression, left ** right)
elif expression.type == BinaryOperationType.MULTIPLICATION: elif expression.type == BinaryOperationType.MULTIPLICATION:
set_val(expression, left * right) set_val(expression, left * right)
elif expression.type == BinaryOperationType.DIVISION: elif expression.type == BinaryOperationType.DIVISION:
set_val(expression, left / right) set_val(expression, left / right)
elif expression.type == BinaryOperationType.MODULO: elif expression.type == BinaryOperationType.MODULO:
set_val(expression, left % right) set_val(expression, left % right)
elif expression.type == BinaryOperationType.ADDITION: elif expression.type == BinaryOperationType.ADDITION:
set_val(expression, left + right) set_val(expression, left + right)
elif expression.type == BinaryOperationType.SUBTRACTION: elif expression.type == BinaryOperationType.SUBTRACTION:
if(left-right) <0: if (left - right) < 0:
# Could trigger underflow # Could trigger underflow
raise NotConstant raise NotConstant
set_val(expression, left - right) set_val(expression, left - right)
elif expression.type == BinaryOperationType.LEFT_SHIFT: elif expression.type == BinaryOperationType.LEFT_SHIFT:
set_val(expression, left << right) set_val(expression, left << right)
elif expression.type == BinaryOperationType.RIGHT_SHIFT: elif expression.type == BinaryOperationType.RIGHT_SHIFT:
set_val(expression, left >> right) set_val(expression, left >> right)
else: else:
raise NotConstant raise NotConstant
@ -110,6 +113,3 @@ class ConstantFolding(ExpressionVisitor):
def _post_type_conversion(self, expression): def _post_type_conversion(self, expression):
raise NotConstant raise NotConstant

@ -1,11 +1,11 @@
from slither.visitors.expression.expression import ExpressionVisitor from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.expressions.assignment_operation import AssignmentOperationType from slither.core.expressions.assignment_operation import AssignmentOperationType
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
key = 'ExportValues' key = "ExportValues"
def get(expression): def get(expression):
val = expression.context[key] val = expression.context[key]
@ -13,11 +13,12 @@ def get(expression):
del expression.context[key] del expression.context[key]
return val return val
def set_val(expression, val): def set_val(expression, val):
expression.context[key] = val expression.context[key] = val
class ExportValues(ExpressionVisitor):
class ExportValues(ExpressionVisitor):
def result(self): def result(self):
if self._result is None: if self._result is None:
self._result = list(set(get(self.expression))) self._result = list(set(get(self.expression)))

@ -1,10 +1,12 @@
import logging import logging
from typing import Any
from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.assignment_operation import AssignmentOperation
from slither.core.expressions.binary_operation import BinaryOperation from slither.core.expressions.binary_operation import BinaryOperation
from slither.core.expressions.call_expression import CallExpression from slither.core.expressions.call_expression import CallExpression
from slither.core.expressions.conditional_expression import ConditionalExpression from slither.core.expressions.conditional_expression import ConditionalExpression
from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression
from slither.core.expressions.expression import Expression
from slither.core.expressions.identifier import Identifier from slither.core.expressions.identifier import Identifier
from slither.core.expressions.index_access import IndexAccess from slither.core.expressions.index_access import IndexAccess
from slither.core.expressions.literal import Literal from slither.core.expressions.literal import Literal
@ -19,24 +21,24 @@ from slither.exceptions import SlitherError
logger = logging.getLogger("ExpressionVisitor") logger = logging.getLogger("ExpressionVisitor")
class ExpressionVisitor:
def __init__(self, expression): class ExpressionVisitor:
def __init__(self, expression: Expression):
# Inherited class must declared their variables prior calling super().__init__ # Inherited class must declared their variables prior calling super().__init__
self._expression = expression self._expression = expression
self._result = None self._result: Any = None
self._visit_expression(self.expression) self._visit_expression(self.expression)
def result(self): def result(self):
return self._result return self._result
@property @property
def expression(self): def expression(self) -> Expression:
return self._expression return self._expression
# visit an expression # visit an expression
# call pre_visit, visit_expression_name, post_visit # call pre_visit, visit_expression_name, post_visit
def _visit_expression(self, expression): def _visit_expression(self, expression: Expression):
self._pre_visit(expression) self._pre_visit(expression)
if isinstance(expression, AssignmentOperation): if isinstance(expression, AssignmentOperation):
@ -88,7 +90,7 @@ class ExpressionVisitor:
pass pass
else: else:
raise SlitherError('Expression not handled: {}'.format(expression)) raise SlitherError("Expression not handled: {}".format(expression))
self._post_visit(expression) self._post_visit(expression)
@ -207,7 +209,7 @@ class ExpressionVisitor:
pass pass
else: else:
raise SlitherError('Expression not handled: {}'.format(expression)) raise SlitherError("Expression not handled: {}".format(expression))
# pre_expression_name # pre_expression_name
@ -308,7 +310,7 @@ class ExpressionVisitor:
pass pass
else: else:
raise SlitherError('Expression not handled: {}'.format(expression)) raise SlitherError("Expression not handled: {}".format(expression))
# post_expression_name # post_expression_name
@ -356,5 +358,3 @@ class ExpressionVisitor:
def _post_unary_operation(self, expression): def _post_unary_operation(self, expression):
pass pass

@ -1,17 +1,18 @@
from slither.visitors.expression.expression import ExpressionVisitor from slither.visitors.expression.expression import ExpressionVisitor
def get(expression): def get(expression):
val = expression.context['ExpressionPrinter'] val = expression.context["ExpressionPrinter"]
# we delete the item to reduce memory use # we delete the item to reduce memory use
del expression.context['ExpressionPrinter'] del expression.context["ExpressionPrinter"]
return val return val
def set_val(expression, val): def set_val(expression, val):
expression.context['ExpressionPrinter'] = val expression.context["ExpressionPrinter"] = val
class ExpressionPrinter(ExpressionVisitor):
class ExpressionPrinter(ExpressionVisitor):
def result(self): def result(self):
if not self._result: if not self._result:
self._result = get(self.expression) self._result = get(self.expression)
@ -20,19 +21,19 @@ class ExpressionPrinter(ExpressionVisitor):
def _post_assignement_operation(self, expression): def _post_assignement_operation(self, expression):
left = get(expression.expression_left) left = get(expression.expression_left)
right = get(expression.expression_right) right = get(expression.expression_right)
val = "{} {} {}".format(left, expression.type_str, right) val = "{} {} {}".format(left, expression.type, right)
set_val(expression, val) set_val(expression, val)
def _post_binary_operation(self, expression): def _post_binary_operation(self, expression):
left = get(expression.expression_left) left = get(expression.expression_left)
right = get(expression.expression_right) right = get(expression.expression_right)
val = "{} {} {}".format(left, expression.type_str, right) val = "{} {} {}".format(left, expression.type, right)
set_val(expression, val) set_val(expression, val)
def _post_call_expression(self, expression): def _post_call_expression(self, expression):
called = get(expression.called) called = get(expression.called)
arguments = [get(x) for x in expression.arguments if x] arguments = [get(x) for x in expression.arguments if x]
val = "{}({})".format(called, ','.join(arguments)) val = "{}({})".format(called, ",".join(arguments))
set_val(expression, val) set_val(expression, val)
def _post_conditional_expression(self, expression): def _post_conditional_expression(self, expression):
@ -66,7 +67,7 @@ class ExpressionPrinter(ExpressionVisitor):
def _post_new_array(self, expression): def _post_new_array(self, expression):
array = str(expression.array_type) array = str(expression.array_type)
depth = expression.depth depth = expression.depth
val = "new {}{}".format(array, '[]'*depth) val = "new {}{}".format(array, "[]" * depth)
set_val(expression, val) set_val(expression, val)
def _post_new_contract(self, expression): def _post_new_contract(self, expression):
@ -81,7 +82,7 @@ class ExpressionPrinter(ExpressionVisitor):
def _post_tuple_expression(self, expression): def _post_tuple_expression(self, expression):
expressions = [get(e) for e in expression.expressions if e] expressions = [get(e) for e in expression.expressions if e]
val = "({})".format(','.join(expressions)) val = "({})".format(",".join(expressions))
set_val(expression, val) set_val(expression, val)
def _post_type_conversion(self, expression): def _post_type_conversion(self, expression):

@ -1,11 +1,10 @@
from typing import List
from slither.core.expressions.expression import Expression
from slither.visitors.expression.expression import ExpressionVisitor from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.expressions.assignment_operation import AssignmentOperationType key = "FindCall"
from slither.core.variables.variable import Variable
key = 'FindCall'
def get(expression): def get(expression):
val = expression.context[key] val = expression.context[key]
@ -13,12 +12,13 @@ def get(expression):
del expression.context[key] del expression.context[key]
return val return val
def set_val(expression, val): def set_val(expression, val):
expression.context[key] = val expression.context[key] = val
class FindCalls(ExpressionVisitor):
def result(self): class FindCalls(ExpressionVisitor):
def result(self) -> List[Expression]:
if self._result is None: if self._result is None:
self._result = list(set(get(self.expression))) self._result = list(set(get(self.expression)))
return self._result return self._result

@ -4,7 +4,8 @@ from slither.core.expressions.index_access import IndexAccess
from slither.visitors.expression.right_value import RightValue from slither.visitors.expression.right_value import RightValue
key = 'FindPush' key = "FindPush"
def get(expression): def get(expression):
val = expression.context[key] val = expression.context[key]
@ -12,11 +13,12 @@ def get(expression):
del expression.context[key] del expression.context[key]
return val return val
def set_val(expression, val): def set_val(expression, val):
expression.context[key] = val expression.context[key] = val
class FindPush(ExpressionVisitor):
class FindPush(ExpressionVisitor):
def result(self): def result(self):
if self._result is None: if self._result is None:
self._result = list(set(get(self.expression))) self._result = list(set(get(self.expression)))
@ -66,7 +68,7 @@ class FindPush(ExpressionVisitor):
def _post_member_access(self, expression): def _post_member_access(self, expression):
val = [] val = []
if expression.member_name == 'push': if expression.member_name == "push":
right = RightValue(expression.expression) right = RightValue(expression.expression)
val = right.result() val = right.result()
set_val(expression, val) set_val(expression, val)

@ -1,13 +1,12 @@
from slither.visitors.expression.expression import ExpressionVisitor from slither.visitors.expression.expression import ExpressionVisitor
class HasConditional(ExpressionVisitor):
class HasConditional(ExpressionVisitor):
def result(self): def result(self):
# == True, to convert None to false # == True, to convert None to false
return self._result is True return self._result is True
def _post_conditional_expression(self, expression): def _post_conditional_expression(self, expression):
# if self._result is True: # if self._result is True:
# raise('Slither does not support nested ternary operator') # raise('Slither does not support nested ternary operator')
self._result = True self._result = True

@ -6,7 +6,8 @@ from slither.core.expressions.assignment_operation import AssignmentOperationTyp
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
key = 'LeftValue' key = "LeftValue"
def get(expression): def get(expression):
val = expression.context[key] val = expression.context[key]
@ -14,11 +15,12 @@ def get(expression):
del expression.context[key] del expression.context[key]
return val return val
def set_val(expression, val): def set_val(expression, val):
expression.context[key] = val expression.context[key] = val
class LeftValue(ExpressionVisitor):
class LeftValue(ExpressionVisitor):
def result(self): def result(self):
if self._result is None: if self._result is None:
self._result = list(set(get(self.expression))) self._result = list(set(get(self.expression)))
@ -64,8 +66,8 @@ class LeftValue(ExpressionVisitor):
def _post_identifier(self, expression): def _post_identifier(self, expression):
if isinstance(expression.value, Variable): if isinstance(expression.value, Variable):
set_val(expression, [expression.value]) set_val(expression, [expression.value])
# elif isinstance(expression.value, SolidityInbuilt): # elif isinstance(expression.value, SolidityInbuilt):
# set_val(expression, [expression]) # set_val(expression, [expression])
else: else:
set_val(expression, []) set_val(expression, [])

@ -1,4 +1,3 @@
from slither.visitors.expression.expression import ExpressionVisitor from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.expressions.assignment_operation import AssignmentOperationType from slither.core.expressions.assignment_operation import AssignmentOperationType
@ -6,7 +5,8 @@ from slither.core.expressions.assignment_operation import AssignmentOperationTyp
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.core.declarations.solidity_variables import SolidityVariable from slither.core.declarations.solidity_variables import SolidityVariable
key = 'ReadVar' key = "ReadVar"
def get(expression): def get(expression):
val = expression.context[key] val = expression.context[key]
@ -14,17 +14,17 @@ def get(expression):
del expression.context[key] del expression.context[key]
return val return val
def set_val(expression, val): def set_val(expression, val):
expression.context[key] = val expression.context[key] = val
class ReadVar(ExpressionVisitor):
class ReadVar(ExpressionVisitor):
def result(self): def result(self):
if self._result is None: if self._result is None:
self._result = list(set(get(self.expression))) self._result = list(set(get(self.expression)))
return self._result return self._result
# overide assignement # overide assignement
# dont explore if its direct assignement (we explore if its +=, -=, ...) # dont explore if its direct assignement (we explore if its +=, -=, ...)
def _visit_assignement_operation(self, expression): def _visit_assignement_operation(self, expression):

@ -10,7 +10,8 @@ from slither.core.expressions.expression import Expression
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
key = 'RightValue' key = "RightValue"
def get(expression): def get(expression):
val = expression.context[key] val = expression.context[key]
@ -18,11 +19,12 @@ def get(expression):
del expression.context[key] del expression.context[key]
return val return val
def set_val(expression, val): def set_val(expression, val):
expression.context[key] = val expression.context[key] = val
class RightValue(ExpressionVisitor):
class RightValue(ExpressionVisitor):
def result(self): def result(self):
if self._result is None: if self._result is None:
self._result = list(set(get(self.expression))) self._result = list(set(get(self.expression)))
@ -68,8 +70,8 @@ class RightValue(ExpressionVisitor):
def _post_identifier(self, expression): def _post_identifier(self, expression):
if isinstance(expression.value, Variable): if isinstance(expression.value, Variable):
set_val(expression, [expression.value]) set_val(expression, [expression.value])
# elif isinstance(expression.value, SolidityInbuilt): # elif isinstance(expression.value, SolidityInbuilt):
# set_val(expression, [expression]) # set_val(expression, [expression])
else: else:
set_val(expression, []) set_val(expression, [])

@ -1,4 +1,3 @@
from slither.visitors.expression.expression import ExpressionVisitor from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.assignment_operation import AssignmentOperation
@ -10,7 +9,8 @@ from slither.core.expressions.member_access import MemberAccess
from slither.core.expressions.index_access import IndexAccess from slither.core.expressions.index_access import IndexAccess
key = 'WriteVar' key = "WriteVar"
def get(expression): def get(expression):
val = expression.context[key] val = expression.context[key]
@ -18,11 +18,12 @@ def get(expression):
del expression.context[key] del expression.context[key]
return val return val
def set_val(expression, val): def set_val(expression, val):
expression.context[key] = val expression.context[key] = val
class WriteVar(ExpressionVisitor):
class WriteVar(ExpressionVisitor):
def result(self): def result(self):
if self._result is None: if self._result is None:
self._result = list(set(get(self.expression))) self._result = list(set(get(self.expression)))
@ -71,27 +72,28 @@ class WriteVar(ExpressionVisitor):
set_val(expression, [expression]) set_val(expression, [expression])
else: else:
set_val(expression, []) set_val(expression, [])
# if isinstance(expression.value, Variable):
# set_val(expression, [expression.value]) # if isinstance(expression.value, Variable):
# else: # set_val(expression, [expression.value])
# set_val(expression, []) # else:
# set_val(expression, [])
def _post_index_access(self, expression): def _post_index_access(self, expression):
left = get(expression.expression_left) left = get(expression.expression_left)
right = get(expression.expression_right) right = get(expression.expression_right)
val = left + right val = left + right
if expression.is_lvalue: if expression.is_lvalue:
# val += [expression] # val += [expression]
val += [expression.expression_left] val += [expression.expression_left]
# n = expression.expression_left # n = expression.expression_left
# parse all the a.b[..].c[..] # parse all the a.b[..].c[..]
# while isinstance(n, (IndexAccess, MemberAccess)): # while isinstance(n, (IndexAccess, MemberAccess)):
# if isinstance(n, IndexAccess): # if isinstance(n, IndexAccess):
# val += [n.expression_left] # val += [n.expression_left]
# n = n.expression_left # n = n.expression_left
# else: # else:
# val += [n.expression] # val += [n.expression]
# n = n.expression # n = n.expression
set_val(expression, val) set_val(expression, val)
def _post_literal(self, expression): def _post_literal(self, expression):

Loading…
Cancel
Save