Merge slither/tools slither/visitors/expression from dev-0.7

pull/514/head
Josselin 4 years ago
parent 010d84125a
commit 4319bb3605
  1. 14
      slither/tools/demo/__main__.py
  2. 37
      slither/tools/erc_conformance/__main__.py
  3. 14
      slither/tools/erc_conformance/erc/erc20.py
  4. 95
      slither/tools/erc_conformance/erc/ercs.py
  5. 30
      slither/tools/kspec_coverage/__main__.py
  6. 77
      slither/tools/kspec_coverage/analysis.py
  7. 3
      slither/tools/kspec_coverage/kspec_coverage.py
  8. 21
      slither/tools/possible_paths/__main__.py
  9. 37
      slither/tools/possible_paths/possible_paths.py
  10. 82
      slither/tools/properties/__main__.py
  11. 8
      slither/tools/properties/addresses/address.py
  12. 6
      slither/tools/properties/platforms/echidna.py
  13. 68
      slither/tools/properties/platforms/truffle.py
  14. 130
      slither/tools/properties/properties/erc20.py
  15. 32
      slither/tools/properties/properties/ercs/erc20/properties/burn.py
  16. 75
      slither/tools/properties/properties/ercs/erc20/properties/initialization.py
  17. 19
      slither/tools/properties/properties/ercs/erc20/properties/mint.py
  18. 20
      slither/tools/properties/properties/ercs/erc20/properties/mint_and_burn.py
  19. 207
      slither/tools/properties/properties/ercs/erc20/properties/transfer.py
  20. 24
      slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py
  21. 2
      slither/tools/properties/properties/properties.py
  22. 63
      slither/tools/properties/solidity/generate_properties.py
  23. 14
      slither/tools/properties/utils.py
  24. 67
      slither/tools/similarity/__main__.py
  25. 8
      slither/tools/similarity/cache.py
  26. 147
      slither/tools/similarity/encode.py
  27. 11
      slither/tools/similarity/info.py
  28. 41
      slither/tools/similarity/plot.py
  29. 1
      slither/tools/similarity/similarity.py
  30. 15
      slither/tools/similarity/test.py
  31. 31
      slither/tools/similarity/train.py
  32. 91
      slither/tools/slither_format/__main__.py
  33. 120
      slither/tools/slither_format/slither_format.py
  34. 128
      slither/tools/upgradeability/__main__.py
  35. 90
      slither/tools/upgradeability/checks/abstract_checks.py
  36. 20
      slither/tools/upgradeability/checks/all_checks.py
  37. 55
      slither/tools/upgradeability/checks/constant.py
  38. 83
      slither/tools/upgradeability/checks/functions_ids.py
  39. 173
      slither/tools/upgradeability/checks/initialization.py
  40. 22
      slither/tools/upgradeability/checks/variable_initialization.py
  41. 108
      slither/tools/upgradeability/checks/variables_order.py
  42. 88
      slither/tools/upgradeability/utils/command_line.py
  43. 12
      slither/visitors/expression/constants_folding.py
  44. 7
      slither/visitors/expression/export_values.py
  45. 20
      slither/visitors/expression/expression.py
  46. 21
      slither/visitors/expression/expression_printer.py
  47. 12
      slither/visitors/expression/find_calls.py
  48. 8
      slither/visitors/expression/find_push.py
  49. 7
      slither/visitors/expression/has_conditional.py
  50. 10
      slither/visitors/expression/left_value.py
  51. 8
      slither/visitors/expression/read_var.py
  52. 10
      slither/visitors/expression/right_value.py
  53. 16
      slither/visitors/expression/write_var.py

@ -9,16 +9,17 @@ logging.getLogger("Slither").setLevel(logging.INFO)
logger = logging.getLogger("Slither-demo")
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='Demo',
usage='slither-demo filename')
parser = argparse.ArgumentParser(description="Demo", usage="slither-demo filename")
parser.add_argument('filename',
help='The filename of the contract or truffle directory to analyze.')
parser.add_argument(
"filename", help="The filename of the contract or truffle directory to analyze."
)
# Add default arguments from crytic-compile
cryticparser.init(parser)
@ -32,7 +33,8 @@ def main():
# Perform slither analysis on the given filename
slither = Slither(args.filename, **vars(args))
logger.info('Analysis done!')
logger.info("Analysis done!")
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -17,28 +17,29 @@ logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s')
formatter = logging.Formatter("%(message)s")
logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter)
logger.propagate = False
ADDITIONAL_CHECKS = {
"ERC20": check_erc20
}
ADDITIONAL_CHECKS = {"ERC20": check_erc20}
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='Check the ERC 20 conformance',
usage='slither-erc project contractName')
parser = argparse.ArgumentParser(
description="Check the ERC 20 conformance", usage="slither-erc project contractName"
)
parser.add_argument('project',
help='The codebase to be tested.')
parser.add_argument("project", help="The codebase to be tested.")
parser.add_argument('contract_name',
help='The name of the contract. Specify the first case contract that follow the standard. Derived contracts will be checked.')
parser.add_argument(
"contract_name",
help="The name of the contract. Specify the first case contract that follow the standard. Derived contracts will be checked.",
)
parser.add_argument(
"--erc",
@ -47,22 +48,26 @@ def parse_args():
default="erc20",
)
parser.add_argument('--json',
parser.add_argument(
"--json",
help='Export the results as a JSON file ("--json -" to export to stdout)',
action='store',
default=False)
action="store",
default=False,
)
# Add default arguments from crytic-compile
cryticparser.init(parser)
return parser.parse_args()
def _log_error(err, args):
if args.json:
output_to_json(args.json, str(err), {"upgradeability-check": []})
logger.error(err)
def main():
args = parse_args()
@ -76,7 +81,7 @@ def main():
contract = slither.get_contract_from_name(args.contract_name)
if not contract:
err = f'Contract not found: {args.contract_name}'
err = f"Contract not found: {args.contract_name}"
_log_error(err, args)
return
# First elem is the function, second is the event
@ -87,7 +92,7 @@ def main():
ADDITIONAL_CHECKS[args.erc.upper()](contract, ret)
else:
err = f'Incorrect ERC selected {args.erc}'
err = f"Incorrect ERC selected {args.erc}"
_log_error(err, args)
return
@ -95,5 +100,5 @@ def main():
output_to_json(args.json, None, {"upgradeability-check": ret})
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -6,21 +6,25 @@ logger = logging.getLogger("Slither-conformance")
def approval_race_condition(contract, ret):
increaseAllowance = contract.get_function_from_signature('increaseAllowance(address,uint256)')
increaseAllowance = contract.get_function_from_signature("increaseAllowance(address,uint256)")
if not increaseAllowance:
increaseAllowance = contract.get_function_from_signature('safeIncreaseAllowance(address,uint256)')
increaseAllowance = contract.get_function_from_signature(
"safeIncreaseAllowance(address,uint256)"
)
if increaseAllowance:
txt = f'\t[✓] {contract.name} has {increaseAllowance.full_name}'
txt = f"\t[✓] {contract.name} has {increaseAllowance.full_name}"
logger.info(txt)
else:
txt = f'\t[ ] {contract.name} is not protected for the ERC20 approval race condition'
txt = f"\t[ ] {contract.name} is not protected for the ERC20 approval race condition"
logger.info(txt)
lack_of_erc20_race_condition_protection = output.Output(txt)
lack_of_erc20_race_condition_protection.add(contract)
ret["lack_of_erc20_race_condition_protection"].append(lack_of_erc20_race_condition_protection.data)
ret["lack_of_erc20_race_condition_protection"].append(
lack_of_erc20_race_condition_protection.data
)
def check_erc20(contract, ret, explored=None):

@ -22,13 +22,15 @@ def _check_signature(erc_function, contract, ret):
# The check on state variable is needed until we have a better API to handle state variable getters
state_variable_as_function = contract.get_state_variable_from_name(name)
if not state_variable_as_function or not state_variable_as_function.visibility in ['public', 'external']:
if not state_variable_as_function or not state_variable_as_function.visibility in [
"public",
"external",
]:
txt = f'[ ] {sig} is missing {"" if required else "(optional)"}'
logger.info(txt)
missing_func = output.Output(txt, additional_fields={
"function": sig,
"required": required
})
missing_func = output.Output(
txt, additional_fields={"function": sig, "required": required}
)
missing_func.add(contract)
ret["missing_function"].append(missing_func.data)
return
@ -38,10 +40,9 @@ def _check_signature(erc_function, contract, ret):
if types != parameters:
txt = f'[ ] {sig} is missing {"" if required else "(optional)"}'
logger.info(txt)
missing_func = output.Output(txt, additional_fields={
"function": sig,
"required": required
})
missing_func = output.Output(
txt, additional_fields={"function": sig, "required": required}
)
missing_func.add(contract)
ret["missing_function"].append(missing_func.data)
return
@ -53,45 +54,51 @@ def _check_signature(erc_function, contract, ret):
function_return_type = function.return_type
function_view = function.view
txt = f'[✓] {sig} is present'
txt = f"[✓] {sig} is present"
logger.info(txt)
if function_return_type:
function_return_type = ','.join([str(x) for x in function_return_type])
function_return_type = ",".join([str(x) for x in function_return_type])
if function_return_type == return_type:
txt = f'\t[✓] {sig} -> () (correct return value)'
txt = f"\t[✓] {sig} -> () (correct return value)"
logger.info(txt)
else:
txt = f'\t[ ] {sig} -> () should return {return_type}'
txt = f"\t[ ] {sig} -> () should return {return_type}"
logger.info(txt)
incorrect_return = output.Output(txt, additional_fields={
incorrect_return = output.Output(
txt,
additional_fields={
"expected_return_type": return_type,
"actual_return_type": function_return_type
})
"actual_return_type": function_return_type,
},
)
incorrect_return.add(function)
ret["incorrect_return_type"].append(incorrect_return.data)
elif not return_type:
txt = f'\t[✓] {sig} -> () (correct return type)'
txt = f"\t[✓] {sig} -> () (correct return type)"
logger.info(txt)
else:
txt = f'\t[ ] {sig} -> () should return {return_type}'
txt = f"\t[ ] {sig} -> () should return {return_type}"
logger.info(txt)
incorrect_return = output.Output(txt, additional_fields={
incorrect_return = output.Output(
txt,
additional_fields={
"expected_return_type": return_type,
"actual_return_type": function_return_type
})
"actual_return_type": function_return_type,
},
)
incorrect_return.add(function)
ret["incorrect_return_type"].append(incorrect_return.data)
if view:
if function_view:
txt = f'\t[✓] {sig} is view'
txt = f"\t[✓] {sig} is view"
logger.info(txt)
else:
txt = f'\t[ ] {sig} should be view'
txt = f"\t[ ] {sig} should be view"
logger.info(txt)
should_be_view = output.Output(txt)
@ -103,12 +110,12 @@ def _check_signature(erc_function, contract, ret):
event_sig = f'{event.name}({",".join(event.parameters)})'
if not function:
txt = f'\t[ ] Must emit be view {event_sig}'
txt = f"\t[ ] Must emit be view {event_sig}"
logger.info(txt)
missing_event_emmited = output.Output(txt, additional_fields={
"missing_event": event_sig
})
missing_event_emmited = output.Output(
txt, additional_fields={"missing_event": event_sig}
)
missing_event_emmited.add(function)
ret["missing_event_emmited"].append(missing_event_emmited.data)
@ -121,15 +128,15 @@ def _check_signature(erc_function, contract, ret):
event_found = True
break
if event_found:
txt = f'\t[✓] {event_sig} is emitted'
txt = f"\t[✓] {event_sig} is emitted"
logger.info(txt)
else:
txt = f'\t[ ] Must emit be view {event_sig}'
txt = f"\t[ ] Must emit be view {event_sig}"
logger.info(txt)
missing_event_emmited = output.Output(txt, additional_fields={
"missing_event": event_sig
})
missing_event_emmited = output.Output(
txt, additional_fields={"missing_event": event_sig}
)
missing_event_emmited.add(function)
ret["missing_event_emmited"].append(missing_event_emmited.data)
@ -143,31 +150,27 @@ def _check_events(erc_event, contract, ret):
event = contract.get_event_from_signature(sig)
if not event:
txt = f'[ ] {sig} is missing'
txt = f"[ ] {sig} is missing"
logger.info(txt)
missing_event = output.Output(txt, additional_fields={
"event": sig
})
missing_event = output.Output(txt, additional_fields={"event": sig})
missing_event.add(contract)
ret["missing_event"].append(missing_event.data)
return
txt = f'[✓] {sig} is present'
txt = f"[✓] {sig} is present"
logger.info(txt)
for i, index in enumerate(indexes):
if index:
if event.elems[i].indexed:
txt = f'\t[✓] parameter {i} is indexed'
txt = f"\t[✓] parameter {i} is indexed"
logger.info(txt)
else:
txt = f'\t[ ] parameter {i} should be indexed'
txt = f"\t[ ] parameter {i} should be indexed"
logger.info(txt)
missing_event_index = output.Output(txt, additional_fields={
"missing_index": i
})
missing_event_index = output.Output(txt, additional_fields={"missing_index": i})
missing_event_index.add_event(event)
ret["missing_event_index"].append(missing_event_index.data)
@ -179,16 +182,16 @@ def generic_erc_checks(contract, erc_functions, erc_events, ret, explored=None):
explored.add(contract)
logger.info(f'# Check {contract.name}\n')
logger.info(f"# Check {contract.name}\n")
logger.info(f'## Check functions')
logger.info(f"## Check functions")
for erc_function in erc_functions:
_check_signature(erc_function, contract, ret)
logger.info(f'\n## Check events')
logger.info(f"\n## Check events")
for erc_event in erc_events:
_check_events(erc_event, contract, ret)
logger.info('\n')
logger.info("\n")
for derived_contract in contract.derived_contracts:
generic_erc_checks(derived_contract, erc_functions, erc_events, ret, explored)

@ -11,27 +11,36 @@ logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s')
formatter = logging.Formatter("%(message)s")
logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter)
logger.propagate = False
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='slither-kspec-coverage',
usage='slither-kspec-coverage contract.sol kspec.md')
parser = argparse.ArgumentParser(
description="slither-kspec-coverage", usage="slither-kspec-coverage contract.sol kspec.md"
)
parser.add_argument('contract', help='The filename of the contract or truffle directory to analyze.')
parser.add_argument('kspec', help='The filename of the Klab spec markdown for the analyzed contract(s)')
parser.add_argument(
"contract", help="The filename of the contract or truffle directory to analyze."
)
parser.add_argument(
"kspec", help="The filename of the Klab spec markdown for the analyzed contract(s)"
)
parser.add_argument('--version', help='displays the current version', version='0.1.0',action='version')
parser.add_argument('--json',
parser.add_argument(
"--version", help="displays the current version", version="0.1.0", action="version"
)
parser.add_argument(
"--json",
help='Export the results as a JSON file ("--json -" to export to stdout)',
action='store',
default=False
action="store",
default=False,
)
cryticparser.init(parser)
@ -54,5 +63,6 @@ def main():
kspec_coverage(args)
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -7,25 +7,22 @@ from slither.utils.colors import yellow, green, red
from slither.utils import output
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger('Slither.kspec')
logger = logging.getLogger("Slither.kspec")
def _refactor_type(type):
return {
'uint': 'uint256',
'int': 'int256'
}.get(type, type)
return {"uint": "uint256", "int": "int256"}.get(type, type)
def _get_all_covered_kspec_functions(target):
# Create a set of our discovered functions which are covered
covered_functions = set()
BEHAVIOUR_PATTERN = re.compile('behaviour\s+(\S+)\s+of\s+(\S+)')
INTERFACE_PATTERN = re.compile('interface\s+([^\r\n]+)')
BEHAVIOUR_PATTERN = re.compile("behaviour\s+(\S+)\s+of\s+(\S+)")
INTERFACE_PATTERN = re.compile("interface\s+([^\r\n]+)")
# Read the file contents
with open(target, 'r', encoding='utf8') as target_file:
with open(target, "r", encoding="utf8") as target_file:
lines = target_file.readlines()
# Loop for each line, if a line matches our behaviour regex, and the next one matches our interface regex,
@ -38,10 +35,12 @@ def _get_all_covered_kspec_functions(target):
match = INTERFACE_PATTERN.match(lines[i + 1])
if match:
function_full_name = match.groups()[0]
start, end = function_full_name.index('(') + 1, function_full_name.index(')')
function_arguments = function_full_name[start:end].split(',')
function_arguments = [_refactor_type(arg.strip().split(' ')[0]) for arg in function_arguments]
function_full_name = function_full_name[:start] + ','.join(function_arguments) + ')'
start, end = function_full_name.index("(") + 1, function_full_name.index(")")
function_arguments = function_full_name[start:end].split(",")
function_arguments = [
_refactor_type(arg.strip().split(" ")[0]) for arg in function_arguments
]
function_full_name = function_full_name[:start] + ",".join(function_arguments) + ")"
covered_functions.add((contract_name, function_full_name))
i += 1
i += 1
@ -50,14 +49,25 @@ def _get_all_covered_kspec_functions(target):
def _get_slither_functions(slither):
# Use contract == contract_declarer to avoid dupplicate
all_functions_declared = [f for f in slither.functions if (f.contract == f.contract_declarer
all_functions_declared = [
f
for f in slither.functions
if (
f.contract == f.contract_declarer
and f.is_implemented
and not f.is_constructor
and not f.is_constructor_variables)]
and not f.is_constructor_variables
)
]
# Use list(set()) because same state variable instances can be shared accross contracts
# TODO: integrate state variables
all_functions_declared += list(set([s for s in slither.state_variables if s.visibility in ['public', 'external']]))
slither_functions = {(function.contract.name, function.full_name): function for function in all_functions_declared}
all_functions_declared += list(
set([s for s in slither.state_variables if s.visibility in ["public", "external"]])
)
slither_functions = {
(function.contract.name, function.full_name): function
for function in all_functions_declared
}
return slither_functions
@ -110,35 +120,42 @@ def _run_coverage_analysis(args, slither, kspec_functions):
else:
kspec_missing.append(slither_func)
logger.info('## Check for functions coverage')
logger.info("## Check for functions coverage")
json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json)
json_kspec_missing_functions = _generate_output([f for f in kspec_missing if isinstance(f, Function)],
json_kspec_missing_functions = _generate_output(
[f for f in kspec_missing if isinstance(f, Function)],
"[ ] (Missing function)",
red,
args.json)
json_kspec_missing_variables = _generate_output([f for f in kspec_missing if isinstance(f, Variable)],
args.json,
)
json_kspec_missing_variables = _generate_output(
[f for f in kspec_missing if isinstance(f, Variable)],
"[ ] (Missing variable)",
yellow,
args.json)
json_kspec_unresolved = _generate_output_unresolved(kspec_functions_unresolved,
"[ ] (Unresolved)",
yellow,
args.json)
args.json,
)
json_kspec_unresolved = _generate_output_unresolved(
kspec_functions_unresolved, "[ ] (Unresolved)", yellow, args.json
)
# Handle unresolved kspecs
if args.json:
output.output_to_json(args.json, None, {
output.output_to_json(
args.json,
None,
{
"functions_present": json_kspec_present,
"functions_missing": json_kspec_missing_functions,
"variables_missing": json_kspec_missing_variables,
"functions_unresolved": json_kspec_unresolved
})
"functions_unresolved": json_kspec_unresolved,
},
)
def run_analysis(args, slither, kspec):
# Get all of our kspec'd functions (tuple(contract_name, function_name)).
if ',' in kspec:
kspecs = kspec.split(',')
if "," in kspec:
kspecs = kspec.split(",")
kspec_functions = set()
for kspec in kspecs:
kspec_functions |= _get_all_covered_kspec_functions(kspec)

@ -1,6 +1,7 @@
from slither.tools.kspec_coverage.analysis import run_analysis
from slither import Slither
def kspec_coverage(args):
contract = args.contract
@ -10,5 +11,3 @@ def kspec_coverage(args):
# Run the analysis on the Klab specs
run_analysis(args, slither, kspec)

@ -9,18 +9,21 @@ from crytic_compile import cryticparser
logging.basicConfig()
logging.getLogger("Slither").setLevel(logging.INFO)
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='PossiblePaths',
usage='possible_paths.py filename [contract.function targets]')
parser = argparse.ArgumentParser(
description="PossiblePaths", usage="possible_paths.py filename [contract.function targets]"
)
parser.add_argument('filename',
help='The filename of the contract or truffle directory to analyze.')
parser.add_argument(
"filename", help="The filename of the contract or truffle directory to analyze."
)
parser.add_argument('targets', nargs='+')
parser.add_argument("targets", nargs="+")
cryticparser.init(parser)
@ -62,12 +65,16 @@ def main():
print("\n")
# Format all function paths.
reaching_paths_str = [' -> '.join([f"{f.canonical_name}" for f in reaching_path]) for reaching_path in reaching_paths]
reaching_paths_str = [
" -> ".join([f"{f.canonical_name}" for f in reaching_path])
for reaching_path in reaching_paths
]
# Print a sorted list of all function paths which can reach the targets.
print(f"The following paths reach the specified targets:")
for reaching_path in sorted(reaching_paths_str):
print(f"{reaching_path}\n")
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -1,4 +1,5 @@
class ResolveFunctionException(Exception): pass
class ResolveFunctionException(Exception):
pass
def resolve_function(slither, contract_name, function_name):
@ -16,11 +17,15 @@ def resolve_function(slither, contract_name, function_name):
raise ResolveFunctionException(f"Could not resolve target contract: {contract_name}")
# Obtain the target function
target_function = next((function for function in contract.functions if function.name == function_name), None)
target_function = next(
(function for function in contract.functions if function.name == function_name), None
)
# Verify we have resolved the function specified.
if target_function is None:
raise ResolveFunctionException(f"Could not resolve target function: {contract_name}.{function_name}")
raise ResolveFunctionException(
f"Could not resolve target function: {contract_name}.{function_name}"
)
# Add the resolved function to the new list.
return target_function
@ -44,17 +49,23 @@ def resolve_functions(slither, functions):
for item in functions:
if isinstance(item, str):
# If the item is a single string, we assume it is of form 'ContractName.FunctionName'.
parts = item.split('.')
parts = item.split(".")
if len(parts) < 2:
raise ResolveFunctionException("Provided string descriptor must be of form 'ContractName.FunctionName'")
raise ResolveFunctionException(
"Provided string descriptor must be of form 'ContractName.FunctionName'"
)
resolved.append(resolve_function(slither, parts[0], parts[1]))
elif isinstance(item, tuple):
# If the item is a tuple, it should be a 2-tuple providing contract and function names.
if len(item) != 2:
raise ResolveFunctionException("Provided tuple descriptor must provide a contract and function name.")
raise ResolveFunctionException(
"Provided tuple descriptor must provide a contract and function name."
)
resolved.append(resolve_function(slither, item[0], item[1]))
else:
raise ResolveFunctionException(f"Unexpected function descriptor type to resolve in list: {type(item)}")
raise ResolveFunctionException(
f"Unexpected function descriptor type to resolve in list: {type(item)}"
)
# Return the resolved list.
return resolved
@ -66,9 +77,12 @@ def all_function_definitions(function):
:param function: The function to obtain all definitions at and beneath.
:return: Returns a list composed of the provided function definition and any base definitions.
"""
return [function] + [f for c in function.contract.inheritance
return [function] + [
f
for c in function.contract.inheritance
for f in c.functions_and_modifiers_declared
if f.full_name == function.full_name]
if f.full_name == function.full_name
]
def __find_target_paths(slither, target_function, current_path=[]):
@ -102,7 +116,7 @@ def __find_target_paths(slither, target_function, current_path=[]):
results = results.union(path_results)
# If this path is external accessible from this point, we add the current path to the list.
if target_function.visibility in ['public', 'external'] and len(current_path) > 1:
if target_function.visibility in ["public", "external"] and len(current_path) > 1:
results.add(tuple(current_path))
return results
@ -122,6 +136,3 @@ def find_target_paths(slither, target_functions):
results = results.union(__find_target_paths(slither, target_function))
return results

@ -16,20 +16,21 @@ logging.getLogger("Slither").setLevel(logging.INFO)
logger = logging.getLogger("Slither")
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s')
formatter = logging.Formatter("%(message)s")
logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter)
logger.propagate = False
def _all_scenarios():
txt = '\n'
txt += '#################### ERC20 ####################\n'
txt = "\n"
txt += "#################### ERC20 ####################\n"
for k, value in ERC20_PROPERTIES.items():
txt += f'{k} - {value.description}\n'
txt += f"{k} - {value.description}\n"
return txt
def _all_properties():
table = MyPrettyTable(["Num", "Description", "Scenario"])
idx = 0
@ -39,6 +40,7 @@ def _all_properties():
idx = idx + 1
return table
class ListScenarios(argparse.Action):
def __call__(self, parser, *args, **kwargs):
logger.info(_all_scenarios())
@ -56,43 +58,51 @@ def parse_args():
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='Demo',
usage='slither-demo filename',
formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument('filename',
help='The filename of the contract or truffle directory to analyze.')
parser.add_argument('--contract',
help='The targeted contract.')
parser.add_argument('--scenario',
help=f'Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable',
default='Transferable')
parser.add_argument('--list-scenarios',
help='List available scenarios',
parser = argparse.ArgumentParser(
description="Demo",
usage="slither-demo filename",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"filename", help="The filename of the contract or truffle directory to analyze."
)
parser.add_argument("--contract", help="The targeted contract.")
parser.add_argument(
"--scenario",
help=f"Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable",
default="Transferable",
)
parser.add_argument(
"--list-scenarios",
help="List available scenarios",
action=ListScenarios,
nargs=0,
default=False)
default=False,
)
parser.add_argument('--list-properties',
help='List available properties',
parser.add_argument(
"--list-properties",
help="List available properties",
action=ListProperties,
nargs=0,
default=False)
default=False,
)
parser.add_argument('--address-owner',
help=f'Owner address. Default {OWNER_ADDRESS}',
default=None)
parser.add_argument(
"--address-owner", help=f"Owner address. Default {OWNER_ADDRESS}", default=None
)
parser.add_argument('--address-user',
help=f'Owner address. Default {USER_ADDRESS}',
default=None)
parser.add_argument(
"--address-user", help=f"Owner address. Default {USER_ADDRESS}", default=None
)
parser.add_argument('--address-attacker',
help=f'Attacker address. Default {ATTACKER_ADDRESS}',
default=None)
parser.add_argument(
"--address-attacker", help=f"Attacker address. Default {ATTACKER_ADDRESS}", default=None
)
# Add default arguments from crytic-compile
cryticparser.init(parser)
@ -116,9 +126,9 @@ def main():
contract = slither.contracts[0]
else:
if args.contract is None:
logger.error(f'Specify the target: --contract ContractName')
logger.error(f"Specify the target: --contract ContractName")
else:
logger.error(f'{args.contract} not found')
logger.error(f"{args.contract} not found")
return
addresses = Addresses(args.address_owner, args.address_user, args.address_attacker)
@ -126,5 +136,5 @@ def main():
generate_erc20(contract, args.scenario, addresses)
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -8,8 +8,12 @@ ATTACKER_ADDRESS = "0xC5fdf4076b8F3A5357c5E395ab970B5B54098Fef"
class Addresses:
def __init__(self, owner: Optional[str] = None, user: Optional[str] = None, attacker: Optional[str] = None):
def __init__(
self,
owner: Optional[str] = None,
user: Optional[str] = None,
attacker: Optional[str] = None,
):
self.owner = owner if owner else OWNER_ADDRESS
self.user = user if user else USER_ADDRESS
self.attacker = attacker if attacker else ATTACKER_ADDRESS

@ -11,11 +11,11 @@ def generate_echidna_config(output_dir: Path, addresses: Addresses) -> str:
:param addresses:
:return:
"""
content = 'prefix: crytic_\n'
content = "prefix: crytic_\n"
content += f'deployer: "{addresses.owner}"\n'
content += f'sender: ["{addresses.user}", "{addresses.attacker}"]\n'
content += f'psender: "{addresses.user}"\n'
content += 'coverage: true\n'
filename = 'echidna_config.yaml'
content += "coverage: true\n"
filename = "echidna_config.yaml"
write_file(output_dir, filename, content)
return filename

@ -7,21 +7,21 @@ from slither.tools.properties.addresses.address import Addresses
from slither.tools.properties.properties.properties import PropertyReturn, Property, PropertyCaller
from slither.tools.properties.utils import write_file
PATTERN_TRUFFLE_MIGRATION = re.compile('^[0-9]*_')
PATTERN_TRUFFLE_MIGRATION = re.compile("^[0-9]*_")
logger = logging.getLogger("Slither")
def _extract_caller(p: PropertyCaller):
if p == PropertyCaller.OWNER:
return ['owner']
return ["owner"]
if p == PropertyCaller.SENDER:
return ['user']
return ["user"]
if p == PropertyCaller.ATTACKER:
return ['attacker']
return ["attacker"]
if p == PropertyCaller.ALL:
return ['owner', 'user', 'attacker']
return ["owner", "user", "attacker"]
assert p == PropertyCaller.ANY
return ['user']
return ["user"]
def _helpers():
@ -31,7 +31,7 @@ def _helpers():
- catchRevertThrow: check if the call revert/throw
:return:
"""
return '''
return """
async function catchRevertThrowReturnFalse(promise) {
try {
const ret = await promise;
@ -61,12 +61,17 @@ async function catchRevertThrow(promise) {
}
assert(false, "Expected revert/throw/or return false");
};
'''
"""
def generate_unit_test(test_contract: str, filename: str,
unit_tests: List[Property], output_dir: Path,
addresses: Addresses, assert_message: str = ''):
def generate_unit_test(
test_contract: str,
filename: str,
unit_tests: List[Property],
output_dir: Path,
addresses: Addresses,
assert_message: str = "",
):
"""
Generate unit tests files
:param test_contract:
@ -88,37 +93,37 @@ def generate_unit_test(test_contract: str, filename: str,
content += f'\tlet attacker = "{addresses.attacker}";\n'
for unit_test in unit_tests:
content += f'\tit("{unit_test.description}", async () => {{\n'
content += f'\t\tlet instance = await {test_contract}.deployed();\n'
content += f"\t\tlet instance = await {test_contract}.deployed();\n"
callers = _extract_caller(unit_test.caller)
if unit_test.return_type == PropertyReturn.SUCCESS:
for caller in callers:
content += f'\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n'
content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n"
if assert_message:
content += f'\t\tassert.equal(test_{caller}, true, "{assert_message}");\n'
else:
content += f'\t\tassert.equal(test_{caller}, true);\n'
content += f"\t\tassert.equal(test_{caller}, true);\n"
elif unit_test.return_type == PropertyReturn.FAIL:
for caller in callers:
content += f'\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n'
content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n"
if assert_message:
content += f'\t\tassert.equal(test_{caller}, false, "{assert_message}");\n'
else:
content += f'\t\tassert.equal(test_{caller}, false);\n'
content += f"\t\tassert.equal(test_{caller}, false);\n"
elif unit_test.return_type == PropertyReturn.FAIL_OR_THROW:
for caller in callers:
content += f'\t\tawait catchRevertThrowReturnFalse(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n'
content += f"\t\tawait catchRevertThrowReturnFalse(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n"
elif unit_test.return_type == PropertyReturn.THROW:
callers = _extract_caller(unit_test.caller)
for caller in callers:
content += f'\t\tawait catchRevertThrow(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n'
content += '\t});\n'
content += f"\t\tawait catchRevertThrow(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n"
content += "\t});\n"
content += '});\n'
content += "});\n"
output_dir = Path(output_dir, 'test')
output_dir = Path(output_dir, "test")
output_dir.mkdir(exist_ok=True)
output_dir = Path(output_dir, 'crytic')
output_dir = Path(output_dir, "crytic")
output_dir.mkdir(exist_ok=True)
write_file(output_dir, filename, content)
@ -133,28 +138,31 @@ def generate_migration(test_contract: str, output_dir: Path, owner_address: str)
:param owner_address:
:return:
"""
content = f'''{test_contract} = artifacts.require("{test_contract}");
content = f"""{test_contract} = artifacts.require("{test_contract}");
module.exports = function(deployer) {{
deployer.deploy({test_contract}, {{from: "{owner_address}"}});
}};
'''
"""
output_dir = Path(output_dir, 'migrations')
output_dir = Path(output_dir, "migrations")
output_dir.mkdir(exist_ok=True)
migration_files = [js_file for js_file in output_dir.iterdir() if js_file.suffix == '.js'
and PATTERN_TRUFFLE_MIGRATION.match(js_file.name)]
migration_files = [
js_file
for js_file in output_dir.iterdir()
if js_file.suffix == ".js" and PATTERN_TRUFFLE_MIGRATION.match(js_file.name)
]
idx = len(migration_files)
filename = f'{idx + 1}_{test_contract}.js'
potential_previous_filename = f'{idx}_{test_contract}.js'
filename = f"{idx + 1}_{test_contract}.js"
potential_previous_filename = f"{idx}_{test_contract}.js"
for m in migration_files:
if m.name == potential_previous_filename:
write_file(output_dir, potential_previous_filename, content)
return
if test_contract in m.name:
logger.error(f'Potential conflicts with {m.name}')
logger.error(f"Potential conflicts with {m.name}")
write_file(output_dir, filename, content)

@ -12,27 +12,37 @@ from slither.tools.properties.platforms.echidna import generate_echidna_config
from slither.tools.properties.properties.ercs.erc20.properties.burn import ERC20_NotBurnable
from slither.tools.properties.properties.ercs.erc20.properties.initialization import ERC20_CONFIG
from slither.tools.properties.properties.ercs.erc20.properties.mint import ERC20_NotMintable
from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import ERC20_NotMintableNotBurnable
from slither.tools.properties.properties.ercs.erc20.properties.transfer import ERC20_Transferable, ERC20_Pausable
from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import (
ERC20_NotMintableNotBurnable,
)
from slither.tools.properties.properties.ercs.erc20.properties.transfer import (
ERC20_Transferable,
ERC20_Pausable,
)
from slither.tools.properties.properties.ercs.erc20.unit_tests.truffle import generate_truffle_test
from slither.tools.properties.properties.properties import property_to_solidity, Property
from slither.tools.properties.solidity.generate_properties import generate_solidity_properties, generate_test_contract, \
generate_solidity_interface
from slither.tools.properties.solidity.generate_properties import (
generate_solidity_properties,
generate_test_contract,
generate_solidity_interface,
)
from slither.utils.colors import red, green
logger = logging.getLogger("Slither")
PropertyDescription = namedtuple('PropertyDescription', ['properties', 'description'])
PropertyDescription = namedtuple("PropertyDescription", ["properties", "description"])
ERC20_PROPERTIES = {
"Transferable": PropertyDescription(ERC20_Transferable, 'Test the correct tokens transfer'),
"Pausable": PropertyDescription(ERC20_Pausable, 'Test the pausable functionality'),
"NotMintable": PropertyDescription(ERC20_NotMintable, 'Test that no one can mint tokens'),
"NotMintableNotBurnable": PropertyDescription(ERC20_NotMintableNotBurnable,
'Test that no one can mint or burn tokens'),
"NotBurnable": PropertyDescription(ERC20_NotBurnable, 'Test that no one can burn tokens'),
"Burnable": PropertyDescription(ERC20_NotBurnable,
'Test the burn of tokens. Require the "burn(address) returns()" function')
"Transferable": PropertyDescription(ERC20_Transferable, "Test the correct tokens transfer"),
"Pausable": PropertyDescription(ERC20_Pausable, "Test the pausable functionality"),
"NotMintable": PropertyDescription(ERC20_NotMintable, "Test that no one can mint tokens"),
"NotMintableNotBurnable": PropertyDescription(
ERC20_NotMintableNotBurnable, "Test that no one can mint or burn tokens"
),
"NotBurnable": PropertyDescription(ERC20_NotBurnable, "Test that no one can burn tokens"),
"Burnable": PropertyDescription(
ERC20_NotBurnable, 'Test the burn of tokens. Require the "burn(address) returns()" function'
),
}
@ -54,7 +64,7 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses)
:return:
"""
if contract.slither.crytic_compile.type not in [PlatformType.TRUFFLE, PlatformType.SOLC]:
logging.error(f'{contract.slither.crytic_compile.type} not yet supported by slither-prop')
logging.error(f"{contract.slither.crytic_compile.type} not yet supported by slither-prop")
return
# Check if the contract is an ERC20 contract and if the functions have the correct visibility
@ -65,7 +75,9 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses)
properties = ERC20_PROPERTIES.get(type_property, None)
if properties is None:
logger.error(f'{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}')
logger.error(
f"{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}"
)
return
properties = properties.properties
@ -78,51 +90,53 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses)
# Generate the contract containing the properties
generate_solidity_interface(output_dir, addresses)
property_file = generate_solidity_properties(contract, type_property, solidity_properties, output_dir)
property_file = generate_solidity_properties(
contract, type_property, solidity_properties, output_dir
)
# Generate the Test contract
initialization_recommendation = _initialization_recommendation(type_property)
contract_filename, contract_name = generate_test_contract(contract,
type_property,
output_dir,
property_file,
initialization_recommendation)
contract_filename, contract_name = generate_test_contract(
contract, type_property, output_dir, property_file, initialization_recommendation
)
# Generate Echidna config file
echidna_config_filename = generate_echidna_config(Path(contract.slither.crytic_compile.target).parent, addresses)
echidna_config_filename = generate_echidna_config(
Path(contract.slither.crytic_compile.target).parent, addresses
)
unit_test_info = ''
unit_test_info = ""
# If truffle, generate unit tests
if contract.slither.crytic_compile.type == PlatformType.TRUFFLE:
unit_test_info = generate_truffle_test(contract, type_property, unit_tests, addresses)
logger.info('################################################')
logger.info(green(f'Update the constructor in {Path(output_dir, contract_filename)}'))
logger.info("################################################")
logger.info(green(f"Update the constructor in {Path(output_dir, contract_filename)}"))
if unit_test_info:
logger.info(green(unit_test_info))
logger.info(green('To run Echidna:'))
txt = f'\t echidna-test {contract.slither.crytic_compile.target} '
txt += f'--contract {contract_name} --config {echidna_config_filename}'
logger.info(green("To run Echidna:"))
txt = f"\t echidna-test {contract.slither.crytic_compile.target} "
txt += f"--contract {contract_name} --config {echidna_config_filename}"
logger.info(green(txt))
def _initialization_recommendation(type_property: str) -> str:
content = ''
content += '\t\t// Add below a minimal configuration:\n'
content += '\t\t// - crytic_owner must have some tokens \n'
content += '\t\t// - crytic_user must have some tokens \n'
content += '\t\t// - crytic_attacker must have some tokens \n'
if type_property in ['Pausable']:
content += '\t\t// - The contract must be paused \n'
if type_property in ['NotMintable', 'NotMintableNotBurnable']:
content += '\t\t// - The contract must not be mintable \n'
if type_property in ['NotBurnable', 'NotMintableNotBurnable']:
content += '\t\t// - The contract must not be burnable \n'
content += '\n'
content += '\n'
content = ""
content += "\t\t// Add below a minimal configuration:\n"
content += "\t\t// - crytic_owner must have some tokens \n"
content += "\t\t// - crytic_user must have some tokens \n"
content += "\t\t// - crytic_attacker must have some tokens \n"
if type_property in ["Pausable"]:
content += "\t\t// - The contract must be paused \n"
if type_property in ["NotMintable", "NotMintableNotBurnable"]:
content += "\t\t// - The contract must not be mintable \n"
if type_property in ["NotBurnable", "NotMintableNotBurnable"]:
content += "\t\t// - The contract must not be burnable \n"
content += "\n"
content += "\n"
return content
@ -130,44 +144,44 @@ def _initialization_recommendation(type_property: str) -> str:
# TODO: move this to crytic-compile
def _platform_to_output_dir(platform: AbstractPlatform) -> Path:
if platform.TYPE == PlatformType.TRUFFLE:
return Path(platform.target, 'contracts', 'crytic')
return Path(platform.target, "contracts", "crytic")
if platform.TYPE == PlatformType.SOLC:
return Path(platform.target).parent
def _check_compatibility(contract):
errors = ''
errors = ""
if not contract.is_erc20():
errors = f'{contract} is not ERC20 compliant. Consider checking the contract with slither-check-erc'
errors = f"{contract} is not ERC20 compliant. Consider checking the contract with slither-check-erc"
return errors
transfer = contract.get_function_from_signature('transfer(address,uint256)')
transfer = contract.get_function_from_signature("transfer(address,uint256)")
if transfer.visibility != 'public':
errors = f'slither-prop requires {transfer.canonical_name} to be public. Please change the visibility'
if transfer.visibility != "public":
errors = f"slither-prop requires {transfer.canonical_name} to be public. Please change the visibility"
transfer_from = contract.get_function_from_signature('transferFrom(address,address,uint256)')
if transfer_from.visibility != 'public':
transfer_from = contract.get_function_from_signature("transferFrom(address,address,uint256)")
if transfer_from.visibility != "public":
if errors:
errors += '\n'
errors += f'slither-prop requires {transfer_from.canonical_name} to be public. Please change the visibility'
errors += "\n"
errors += f"slither-prop requires {transfer_from.canonical_name} to be public. Please change the visibility"
approve = contract.get_function_from_signature('approve(address,uint256)')
if approve.visibility != 'public':
approve = contract.get_function_from_signature("approve(address,uint256)")
if approve.visibility != "public":
if errors:
errors += '\n'
errors += f'slither-prop requires {approve.canonical_name} to be public. Please change the visibility'
errors += "\n"
errors += f"slither-prop requires {approve.canonical_name} to be public. Please change the visibility"
return errors
def _get_properties(contract, properties: List[Property]) -> Tuple[str, List[Property]]:
solidity_properties = ''
solidity_properties = ""
if contract.slither.crytic_compile.type == PlatformType.TRUFFLE:
solidity_properties += '\n'.join([property_to_solidity(p) for p in ERC20_CONFIG])
solidity_properties += "\n".join([property_to_solidity(p) for p in ERC20_CONFIG])
solidity_properties += '\n'.join([property_to_solidity(p) for p in properties])
solidity_properties += "\n".join([property_to_solidity(p) for p in properties])
unit_tests = [p for p in properties if p.is_unit_test]
return solidity_properties, unit_tests

@ -1,30 +1,38 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller
from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_NotBurnable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotBurnable()',
description='The total supply does not decrease.',
content='''
\t\treturn initialTotalSupply == this.totalSupply();''',
Property(
name="crytic_supply_constant_ERC20PropertiesNotBurnable()",
description="The total supply does not decrease.",
content="""
\t\treturn initialTotalSupply == this.totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY),
caller=PropertyCaller.ANY,
),
]
# Require burn(address) returns()
ERC20_Burnable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotBurnable()',
description='Cannot burn more than available balance',
content='''
Property(
name="crytic_supply_constant_ERC20PropertiesNotBurnable()",
description="Cannot burn more than available balance",
content="""
\t\tuint balance = balanceOf(msg.sender);
\t\tburn(balance + 1);
\t\treturn false;''',
\t\treturn false;""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL)
caller=PropertyCaller.ALL,
)
]

@ -1,65 +1,76 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller
from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_CONFIG = [
Property(name='init_total_supply()',
description='The total supply is correctly initialized.',
content='''
\t\treturn this.totalSupply() >= 0 && this.totalSupply() == initialTotalSupply;''',
Property(
name="init_total_supply()",
description="The total supply is correctly initialized.",
content="""
\t\treturn this.totalSupply() >= 0 && this.totalSupply() == initialTotalSupply;""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY),
Property(name='init_owner_balance()',
caller=PropertyCaller.ANY,
),
Property(
name="init_owner_balance()",
description="Owner's balance is correctly initialized.",
content='''
\t\treturn initialBalance_owner == this.balanceOf(crytic_owner);''',
content="""
\t\treturn initialBalance_owner == this.balanceOf(crytic_owner);""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY),
Property(name='init_user_balance()',
caller=PropertyCaller.ANY,
),
Property(
name="init_user_balance()",
description="User's balance is correctly initialized.",
content='''
\t\treturn initialBalance_user == this.balanceOf(crytic_user);''',
content="""
\t\treturn initialBalance_user == this.balanceOf(crytic_user);""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY),
Property(name='init_attacker_balance()',
caller=PropertyCaller.ANY,
),
Property(
name="init_attacker_balance()",
description="Attacker's balance is correctly initialized.",
content='''
\t\treturn initialBalance_attacker == this.balanceOf(crytic_attacker);''',
content="""
\t\treturn initialBalance_attacker == this.balanceOf(crytic_attacker);""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY),
Property(name='init_caller_balance()',
caller=PropertyCaller.ANY,
),
Property(
name="init_caller_balance()",
description="All the users have a positive balance.",
content='''
\t\treturn this.balanceOf(msg.sender) >0 ;''',
content="""
\t\treturn this.balanceOf(msg.sender) >0 ;""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ALL),
caller=PropertyCaller.ALL,
),
# Note: there is a potential overflow on the addition, but we dont consider it
Property(name='init_total_supply_is_balances()',
Property(
name="init_total_supply_is_balances()",
description="The total supply is the user and owner balance.",
content='''
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) == this.totalSupply();''',
content="""
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) == this.totalSupply();""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY),
caller=PropertyCaller.ANY,
),
]

@ -1,13 +1,20 @@
from slither.tools.properties.properties.properties import PropertyType, PropertyReturn, Property, PropertyCaller
from slither.tools.properties.properties.properties import (
PropertyType,
PropertyReturn,
Property,
PropertyCaller,
)
ERC20_NotMintable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotMintable()',
description='The total supply does not increase.',
content='''
\t\treturn initialTotalSupply >= totalSupply();''',
Property(
name="crytic_supply_constant_ERC20PropertiesNotMintable()",
description="The total supply does not increase.",
content="""
\t\treturn initialTotalSupply >= totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY),
caller=PropertyCaller.ANY,
),
]

@ -1,14 +1,20 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller
from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_NotMintableNotBurnable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotMintableNotBurnable()',
description='The total supply does not change.',
content='''
\t\treturn initialTotalSupply == this.totalSupply();''',
Property(
name="crytic_supply_constant_ERC20PropertiesNotMintableNotBurnable()",
description="The total supply does not change.",
content="""
\t\treturn initialTotalSupply == this.totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY),
caller=PropertyCaller.ANY,
),
]

@ -1,96 +1,108 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller
from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_Transferable = [
Property(name='crytic_zero_always_empty_ERC20Properties()',
description='The address 0x0 should not receive tokens.',
content='''
\t\treturn this.balanceOf(address(0x0)) == 0;''',
Property(
name="crytic_zero_always_empty_ERC20Properties()",
description="The address 0x0 should not receive tokens.",
content="""
\t\treturn this.balanceOf(address(0x0)) == 0;""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY),
Property(name='crytic_approve_overwrites()',
description='Allowance can be changed.',
content='''
caller=PropertyCaller.ANY,
),
Property(
name="crytic_approve_overwrites()",
description="Allowance can be changed.",
content="""
\t\tbool approve_return;
\t\tapprove_return = approve(crytic_user, 10);
\t\trequire(approve_return);
\t\tapprove_return = approve(crytic_user, 20);
\t\trequire(approve_return);
\t\treturn this.allowance(msg.sender, crytic_user) == 20;''',
\t\treturn this.allowance(msg.sender, crytic_user) == 20;""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_less_than_total_ERC20Properties()',
description='Balance of one user must be less or equal to the total supply.',
content='''
\t\treturn this.balanceOf(msg.sender) <= totalSupply();''',
caller=PropertyCaller.ALL,
),
Property(
name="crytic_less_than_total_ERC20Properties()",
description="Balance of one user must be less or equal to the total supply.",
content="""
\t\treturn this.balanceOf(msg.sender) <= totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_totalSupply_consistant_ERC20Properties()',
description='Balance of the crytic users must be less or equal to the total supply.',
content='''
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) <= totalSupply();''',
caller=PropertyCaller.ALL,
),
Property(
name="crytic_totalSupply_consistant_ERC20Properties()",
description="Balance of the crytic users must be less or equal to the total supply.",
content="""
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) <= totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY),
Property(name='crytic_revert_transfer_to_zero_ERC20PropertiesTransferable()',
description='No one should be able to send tokens to the address 0x0 (transfer).',
content='''
caller=PropertyCaller.ANY,
),
Property(
name="crytic_revert_transfer_to_zero_ERC20PropertiesTransferable()",
description="No one should be able to send tokens to the address 0x0 (transfer).",
content="""
\t\tif (this.balanceOf(msg.sender) == 0){
\t\t\trevert();
\t\t}
\t\treturn transfer(address(0x0), this.balanceOf(msg.sender));''',
\t\treturn transfer(address(0x0), this.balanceOf(msg.sender));""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_revert_transferFrom_to_zero_ERC20PropertiesTransferable()',
description='No one should be able to send tokens to the address 0x0 (transferFrom).',
content='''
caller=PropertyCaller.ALL,
),
Property(
name="crytic_revert_transferFrom_to_zero_ERC20PropertiesTransferable()",
description="No one should be able to send tokens to the address 0x0 (transferFrom).",
content="""
\t\tuint balance = this.balanceOf(msg.sender);
\t\tif (balance == 0){
\t\t\trevert();
\t\t}
\t\tapprove(msg.sender, balance);
\t\treturn transferFrom(msg.sender, address(0x0), this.balanceOf(msg.sender));''',
\t\treturn transferFrom(msg.sender, address(0x0), this.balanceOf(msg.sender));""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_self_transferFrom_ERC20PropertiesTransferable()',
description='Self transferFrom works.',
content='''
caller=PropertyCaller.ALL,
),
Property(
name="crytic_self_transferFrom_ERC20PropertiesTransferable()",
description="Self transferFrom works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender);
\t\tbool approve_return = approve(msg.sender, balance);
\t\tbool transfer_return = transferFrom(msg.sender, msg.sender, balance);
\t\treturn (this.balanceOf(msg.sender) == balance) && approve_return && transfer_return;''',
\t\treturn (this.balanceOf(msg.sender) == balance) && approve_return && transfer_return;""",
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_self_transferFrom_to_other_ERC20PropertiesTransferable()',
description='transferFrom works.',
content='''
caller=PropertyCaller.ALL,
),
Property(
name="crytic_self_transferFrom_to_other_ERC20PropertiesTransferable()",
description="transferFrom works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender);
\t\tbool approve_return = approve(msg.sender, balance);
\t\taddress other = crytic_user;
@ -98,29 +110,30 @@ ERC20_Transferable = [
\t\t\tother = crytic_owner;
\t\t}
\t\tbool transfer_return = transferFrom(msg.sender, other, balance);
\t\treturn (this.balanceOf(msg.sender) == 0) && approve_return && transfer_return;''',
\t\treturn (this.balanceOf(msg.sender) == 0) && approve_return && transfer_return;""",
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_self_transfer_ERC20PropertiesTransferable()',
description='Self transfer works.',
content='''
caller=PropertyCaller.ALL,
),
Property(
name="crytic_self_transfer_ERC20PropertiesTransferable()",
description="Self transfer works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender);
\t\tbool transfer_return = transfer(msg.sender, balance);
\t\treturn (this.balanceOf(msg.sender) == balance) && transfer_return;''',
\t\treturn (this.balanceOf(msg.sender) == balance) && transfer_return;""",
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_transfer_to_other_ERC20PropertiesTransferable()',
description='transfer works.',
content='''
caller=PropertyCaller.ALL,
),
Property(
name="crytic_transfer_to_other_ERC20PropertiesTransferable()",
description="transfer works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender);
\t\taddress other = crytic_user;
\t\tif (other == msg.sender) {
@ -130,74 +143,76 @@ ERC20_Transferable = [
\t\t\tbool transfer_other = transfer(other, 1);
\t\t\treturn (this.balanceOf(msg.sender) == balance-1) && (this.balanceOf(other) >= 1) && transfer_other;
\t\t}
\t\treturn true;''',
\t\treturn true;""",
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_revert_transfer_to_user_ERC20PropertiesTransferable()',
description='Cannot transfer more than the balance.',
content='''
caller=PropertyCaller.ALL,
),
Property(
name="crytic_revert_transfer_to_user_ERC20PropertiesTransferable()",
description="Cannot transfer more than the balance.",
content="""
\t\tuint balance = this.balanceOf(msg.sender);
\t\tif (balance == (2 ** 256 - 1))
\t\t\treturn true;
\t\tbool transfer_other = transfer(crytic_user, balance+1);
\t\treturn transfer_other;''',
\t\treturn transfer_other;""",
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
caller=PropertyCaller.ALL,
),
]
ERC20_Pausable = [
Property(name='crytic_revert_transfer_ERC20AlwaysTruePropertiesNotTransferable()',
description='Cannot transfer.',
content='''
\t\treturn transfer(crytic_user, this.balanceOf(msg.sender));''',
Property(
name="crytic_revert_transfer_ERC20AlwaysTruePropertiesNotTransferable()",
description="Cannot transfer.",
content="""
\t\treturn transfer(crytic_user, this.balanceOf(msg.sender));""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_revert_transferFrom_ERC20AlwaysTruePropertiesNotTransferable()',
description='Cannot execute transferFrom.',
content='''
caller=PropertyCaller.ALL,
),
Property(
name="crytic_revert_transferFrom_ERC20AlwaysTruePropertiesNotTransferable()",
description="Cannot execute transferFrom.",
content="""
\t\tapprove(msg.sender, this.balanceOf(msg.sender));
\t\ttransferFrom(msg.sender, msg.sender, this.balanceOf(msg.sender));''',
\t\ttransferFrom(msg.sender, msg.sender, this.balanceOf(msg.sender));""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_constantBalance()',
description='Cannot change the balance.',
content='''
\t\treturn this.balanceOf(crytic_user) == initialBalance_user && this.balanceOf(crytic_attacker) == initialBalance_attacker;''',
caller=PropertyCaller.ALL,
),
Property(
name="crytic_constantBalance()",
description="Cannot change the balance.",
content="""
\t\treturn this.balanceOf(crytic_user) == initialBalance_user && this.balanceOf(crytic_attacker) == initialBalance_attacker;""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_constantAllowance()',
description='Cannot change the allowance.',
content='''
caller=PropertyCaller.ALL,
),
Property(
name="crytic_constantAllowance()",
description="Cannot change the allowance.",
content="""
\t\treturn (this.allowance(crytic_user, crytic_attacker) == initialAllowance_user_attacker) &&
\t\t\t(this.allowance(crytic_attacker, crytic_attacker) == initialAllowance_attacker_attacker);''',
\t\t\t(this.allowance(crytic_attacker, crytic_attacker) == initialAllowance_attacker_attacker);""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
caller=PropertyCaller.ALL,
),
]

@ -11,26 +11,32 @@ from slither.tools.properties.properties.properties import Property
logger = logging.getLogger("Slither")
def generate_truffle_test(contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses) -> str:
test_contract = f'Test{contract.name}{type_property}'
filename_init = f'Initialization{test_contract}.js'
filename = f'{test_contract}.js'
def generate_truffle_test(
contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses
) -> str:
test_contract = f"Test{contract.name}{type_property}"
filename_init = f"Initialization{test_contract}.js"
filename = f"{test_contract}.js"
output_dir = Path(contract.slither.crytic_compile.target)
generate_migration(test_contract, output_dir, addresses.owner)
generate_unit_test(test_contract,
generate_unit_test(
test_contract,
filename_init,
ERC20_CONFIG,
output_dir,
addresses,
f'Check the constructor of {test_contract}')
f"Check the constructor of {test_contract}",
)
generate_unit_test(test_contract, filename, unit_tests, output_dir, addresses,)
generate_unit_test(
test_contract, filename, unit_tests, output_dir, addresses,
)
log_info = '\n'
log_info += 'To run the unit tests:\n'
log_info = "\n"
log_info += "To run the unit tests:\n"
log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename_init)}\n"
log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename)}\n"
return log_info

@ -36,4 +36,4 @@ class Property(NamedTuple):
def property_to_solidity(p: Property):
return f'\tfunction {p.name} public returns(bool){{{p.content}\n\t}}\n'
return f"\tfunction {p.name} public returns(bool){{{p.content}\n\t}}\n"

@ -9,59 +9,64 @@ from slither.tools.properties.utils import write_file
logger = logging.getLogger("Slither")
def generate_solidity_properties(contract: Contract, type_property: str, solidity_properties: str,
output_dir: Path) -> Path:
def generate_solidity_properties(
contract: Contract, type_property: str, solidity_properties: str, output_dir: Path
) -> Path:
solidity_import = f'import "./interfaces.sol";\n'
solidity_import += f'import "../{contract.source_mapping["filename_short"]}";'
test_contract_name = f'Properties{contract.name}{type_property}'
test_contract_name = f"Properties{contract.name}{type_property}"
solidity_content = f'{solidity_import}\ncontract {test_contract_name} is CryticInterface,{contract.name}'
solidity_content += f'{{\n\n{solidity_properties}\n}}\n'
solidity_content = (
f"{solidity_import}\ncontract {test_contract_name} is CryticInterface,{contract.name}"
)
solidity_content += f"{{\n\n{solidity_properties}\n}}\n"
filename = f'{test_contract_name}.sol'
filename = f"{test_contract_name}.sol"
write_file(output_dir, filename, solidity_content)
return Path(filename)
def generate_test_contract(contract: Contract,
def generate_test_contract(
contract: Contract,
type_property: str,
output_dir: Path,
property_file: Path,
initialization_recommendation: str) -> Tuple[str, str]:
test_contract_name = f'Test{contract.name}{type_property}'
properties_name = f'Properties{contract.name}{type_property}'
initialization_recommendation: str,
) -> Tuple[str, str]:
test_contract_name = f"Test{contract.name}{type_property}"
properties_name = f"Properties{contract.name}{type_property}"
content = ''
content = ""
content += f'import "./{property_file}";\n'
content += f"contract {test_contract_name} is {properties_name} {{\n"
content += '\tconstructor() public{\n'
content += '\t\t// Existing addresses:\n'
content += '\t\t// - crytic_owner: If the contract has an owner, it must be crytic_owner\n'
content += '\t\t// - crytic_user: Legitimate user\n'
content += '\t\t// - crytic_attacker: Attacker\n'
content += '\t\t// \n'
content += "\tconstructor() public{\n"
content += "\t\t// Existing addresses:\n"
content += "\t\t// - crytic_owner: If the contract has an owner, it must be crytic_owner\n"
content += "\t\t// - crytic_user: Legitimate user\n"
content += "\t\t// - crytic_attacker: Attacker\n"
content += "\t\t// \n"
content += initialization_recommendation
content += '\t\t// \n'
content += '\t\t// \n'
content += '\t\t// Update the following if totalSupply and balanceOf are external functions or state variables:\n\n'
content += '\t\tinitialTotalSupply = totalSupply();\n'
content += '\t\tinitialBalance_owner = balanceOf(crytic_owner);\n'
content += '\t\tinitialBalance_user = balanceOf(crytic_user);\n'
content += '\t\tinitialBalance_attacker = balanceOf(crytic_attacker);\n'
content += "\t\t// \n"
content += "\t\t// \n"
content += "\t\t// Update the following if totalSupply and balanceOf are external functions or state variables:\n\n"
content += "\t\tinitialTotalSupply = totalSupply();\n"
content += "\t\tinitialBalance_owner = balanceOf(crytic_owner);\n"
content += "\t\tinitialBalance_user = balanceOf(crytic_user);\n"
content += "\t\tinitialBalance_attacker = balanceOf(crytic_attacker);\n"
content += '\t}\n}\n'
content += "\t}\n}\n"
filename = f'{test_contract_name}.sol'
filename = f"{test_contract_name}.sol"
write_file(output_dir, filename, content, allow_overwrite=False)
return filename, test_contract_name
def generate_solidity_interface(output_dir: Path, addresses: Addresses):
content = f'''
content = f"""
contract CryticInterface{{
address internal crytic_owner = address({addresses.owner});
address internal crytic_user = address({addresses.user});
@ -70,7 +75,7 @@ contract CryticInterface{{
uint internal initialBalance_owner;
uint internal initialBalance_user;
uint internal initialBalance_attacker;
}}'''
}}"""
# Static file, we discard if it exists as it should never change
write_file(output_dir, 'interfaces.sol', content, discard_if_exist=True)
write_file(output_dir, "interfaces.sol", content, discard_if_exist=True)

@ -6,11 +6,13 @@ from slither.utils.colors import green, yellow
logger = logging.getLogger("Slither")
def write_file(output_dir: Path,
def write_file(
output_dir: Path,
filename: str,
content: str,
allow_overwrite: bool = True,
discard_if_exist: bool = False):
discard_if_exist: bool = False,
):
"""
Write the content into output_dir/filename
:param output_dir:
@ -25,10 +27,10 @@ def write_file(output_dir: Path,
if discard_if_exist:
return
if not allow_overwrite:
logger.info(yellow(f'{file_to_write} already exist and will not be overwritten'))
logger.info(yellow(f"{file_to_write} already exist and will not be overwritten"))
return
logger.info(yellow(f'Overwrite {file_to_write}'))
logger.info(yellow(f"Overwrite {file_to_write}"))
else:
logger.info(green(f'Write {file_to_write}'))
with open(file_to_write, 'w') as f:
logger.info(green(f"Write {file_to_write}"))
with open(file_to_write, "w") as f:
f.write(content)

@ -18,52 +18,46 @@ logger = logging.getLogger("Slither-simil")
modes = ["info", "test", "train", "plot"]
def parse_args():
parser = argparse.ArgumentParser(description='Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector')
parser = argparse.ArgumentParser(
description="Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector"
)
parser.add_argument('mode',
help="|".join(modes))
parser.add_argument("mode", help="|".join(modes))
parser.add_argument('model',
help='model.bin')
parser.add_argument("model", help="model.bin")
parser.add_argument('--filename',
action='store',
dest='filename',
help='contract.sol')
parser.add_argument("--filename", action="store", dest="filename", help="contract.sol")
parser.add_argument('--fname',
action='store',
dest='fname',
help='Target function')
parser.add_argument("--fname", action="store", dest="fname", help="Target function")
parser.add_argument('--ext',
action='store',
dest='ext',
help='Extension to filter contracts')
parser.add_argument("--ext", action="store", dest="ext", help="Extension to filter contracts")
parser.add_argument('--nsamples',
action='store',
parser.add_argument(
"--nsamples",
action="store",
type=int,
dest='nsamples',
help='Number of contract samples used for training')
dest="nsamples",
help="Number of contract samples used for training",
)
parser.add_argument('--ntop',
action='store',
parser.add_argument(
"--ntop",
action="store",
type=int,
dest='ntop',
dest="ntop",
default=10,
help='Number of more similar contracts to show for testing')
help="Number of more similar contracts to show for testing",
)
parser.add_argument('--input',
action='store',
dest='input',
help='File or directory used as input')
parser.add_argument(
"--input", action="store", dest="input", help="File or directory used as input"
)
parser.add_argument('--version',
help='displays the current version',
version="0.0",
action='version')
parser.add_argument(
"--version", help="displays the current version", version="0.0", action="version"
)
cryticparser.init(parser)
@ -74,6 +68,7 @@ def parse_args():
args = parser.parse_args()
return args
# endregion
###################################################################################
###################################################################################
@ -81,6 +76,7 @@ def parse_args():
###################################################################################
###################################################################################
def main():
args = parse_args()
@ -98,10 +94,11 @@ def main():
elif mode == "plot":
plot(args)
else:
logger.error('Invalid mode!. It should be one of these: %s' % ", ".join(modes))
logger.error("Invalid mode!. It should be one of these: %s" % ", ".join(modes))
sys.exit(-1)
if __name__ == '__main__':
if __name__ == "__main__":
main()
# endregion

@ -7,16 +7,18 @@ except ImportError:
print("$ pip3 install numpy --user\n")
sys.exit(-1)
def load_cache(infile, nsamples=None):
cache = dict()
with np.load(infile, allow_pickle=True) as data:
array = data['arr_0'][0]
for i,(x,y) in enumerate(array):
array = data["arr_0"][0]
for i, (x, y) in enumerate(array):
cache[x] = y
if i == nsamples:
break
return cache
def save_cache(cache, outfile):
np.savez(outfile,[np.array(cache)])
np.savez(outfile, [np.array(cache)])

@ -2,16 +2,47 @@ import logging
import os
from slither import Slither
from slither.core.declarations import Structure, Enum, SolidityVariableComposed, SolidityVariable, Function
from slither.core.declarations import (
Structure,
Enum,
SolidityVariableComposed,
SolidityVariable,
Function,
)
from slither.core.solidity_types import ElementaryType, ArrayType, MappingType, UserDefinedType
from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple
from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import Assignment, Index, Member, Length, Balance, Binary, \
Unary, Condition, NewArray, NewStructure, NewContract, NewElementaryType, \
SolidityCall, Push, Delete, EventCall, LibraryCall, InternalDynamicCall, \
HighLevelCall, LowLevelCall, TypeConversion, Return, Transfer, Send, Unpack, InitArray, InternalCall
from slither.slithir.variables import TemporaryVariable, TupleVariable, Constant, ReferenceVariable
from slither.slithir.operations import (
Assignment,
Index,
AccessMember,
Length,
Balance,
Binary,
Unary,
Condition,
NewArray,
NewStructure,
NewContract,
NewElementaryType,
SolidityCall,
Push,
Delete,
EventCall,
LibraryCall,
InternalDynamicCall,
HighLevelCall,
LowLevelCall,
TypeConversion,
Return,
Transfer,
Send,
Unpack,
InitArray,
InternalCall,
)
from slither.slithir.variables import TemporaryVariable, TupleVariable, Constant, IndexVariable
from .cache import load_cache
simil_logger = logging.getLogger("Slither-simil")
@ -20,11 +51,12 @@ compiler_logger.setLevel(logging.CRITICAL)
slither_logger = logging.getLogger("Slither")
slither_logger.setLevel(logging.CRITICAL)
def parse_target(target):
if target is None:
return None, None
parts = target.split('.')
parts = target.split(".")
if len(parts) == 1:
return None, parts[0]
elif len(parts) == 2:
@ -32,6 +64,7 @@ def parse_target(target):
else:
simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'")
def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs):
r = dict()
if infile.endswith(".npz"):
@ -39,13 +72,14 @@ def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs):
else:
contracts = load_contracts(infile, ext=ext, nsamples=nsamples)
for contract in contracts:
for x,ir in encode_contract(contract, **kwargs).items():
for x, ir in encode_contract(contract, **kwargs).items():
if ir != []:
y = " ".join(ir)
r[x] = vmodel.get_sentence_vector(y)
return r
def load_contracts(dirname, ext=None, nsamples=None, **kwargs):
r = []
walk = list(os.walk(dirname))
@ -60,6 +94,7 @@ def load_contracts(dirname, ext=None, nsamples=None, **kwargs):
# TODO: shuffle
return r[:nsamples]
def ntype(_type):
if isinstance(_type, ElementaryType):
_type = str(_type)
@ -79,8 +114,8 @@ def ntype(_type):
else:
_type = str(_type)
_type = _type.replace(" memory","")
_type = _type.replace(" storage ref","")
_type = _type.replace(" memory", "")
_type = _type.replace(" storage ref", "")
if "struct" in _type:
return "struct"
@ -93,91 +128,93 @@ def ntype(_type):
elif "mapping" in _type:
return "mapping"
else:
return _type.replace(" ","_")
return _type.replace(" ", "_")
def encode_ir(ir):
# operations
if isinstance(ir, Assignment):
return '({}):=({})'.format(encode_ir(ir.lvalue), encode_ir(ir.rvalue))
return "({}):=({})".format(encode_ir(ir.lvalue), encode_ir(ir.rvalue))
if isinstance(ir, Index):
return 'index({})'.format(ntype(ir._type))
if isinstance(ir, Member):
return 'member' #.format(ntype(ir._type))
return "index({})".format(ntype(ir._type))
if isinstance(ir, AccessMember):
return "member" # .format(ntype(ir._type))
if isinstance(ir, Length):
return 'length'
return "length"
if isinstance(ir, Balance):
return 'balance'
return "balance"
if isinstance(ir, Binary):
return 'binary({})'.format(ir.type_str)
return "binary({})".format(str(ir.type))
if isinstance(ir, Unary):
return 'unary({})'.format(ir.type_str)
return "unary({})".format(str(ir.type))
if isinstance(ir, Condition):
return 'condition({})'.format(encode_ir(ir.value))
return "condition({})".format(encode_ir(ir.value))
if isinstance(ir, NewStructure):
return 'new_structure'
return "new_structure"
if isinstance(ir, NewContract):
return 'new_contract'
return "new_contract"
if isinstance(ir, NewArray):
return 'new_array({})'.format(ntype(ir._array_type))
return "new_array({})".format(ntype(ir._array_type))
if isinstance(ir, NewElementaryType):
return 'new_elementary({})'.format(ntype(ir._type))
return "new_elementary({})".format(ntype(ir._type))
if isinstance(ir, Push):
return 'push({},{})'.format(encode_ir(ir.value), encode_ir(ir.lvalue))
return "push({},{})".format(encode_ir(ir.value), encode_ir(ir.lvalue))
if isinstance(ir, Delete):
return 'delete({},{})'.format(encode_ir(ir.lvalue), encode_ir(ir.variable))
return "delete({},{})".format(encode_ir(ir.lvalue), encode_ir(ir.variable))
if isinstance(ir, SolidityCall):
return 'solidity_call({})'.format(ir.function.full_name)
return "solidity_call({})".format(ir.function.full_name)
if isinstance(ir, InternalCall):
return 'internal_call({})'.format(ntype(ir._type_call))
return "internal_call({})".format(ntype(ir._type_call))
if isinstance(ir, EventCall): # is this useful?
return 'event'
return "event"
if isinstance(ir, LibraryCall):
return 'library_call'
return "library_call"
if isinstance(ir, InternalDynamicCall):
return 'internal_dynamic_call'
return "internal_dynamic_call"
if isinstance(ir, HighLevelCall): # TODO: improve
return 'high_level_call'
return "high_level_call"
if isinstance(ir, LowLevelCall): # TODO: improve
return 'low_level_call'
return "low_level_call"
if isinstance(ir, TypeConversion):
return 'type_conversion({})'.format(ntype(ir.type))
return "type_conversion({})".format(ntype(ir.type))
if isinstance(ir, Return): # this can be improved using values
return 'return' #.format(ntype(ir.type))
return "return" # .format(ntype(ir.type))
if isinstance(ir, Transfer):
return 'transfer({})'.format(encode_ir(ir.call_value))
return "transfer({})".format(encode_ir(ir.call_value))
if isinstance(ir, Send):
return 'send({})'.format(encode_ir(ir.call_value))
return "send({})".format(encode_ir(ir.call_value))
if isinstance(ir, Unpack): # TODO: improve
return 'unpack'
return "unpack"
if isinstance(ir, InitArray): # TODO: improve
return 'init_array'
return "init_array"
if isinstance(ir, Function): # TODO: investigate this
return 'function_solc'
return "function_solc"
# variables
if isinstance(ir, Constant):
return 'constant({})'.format(ntype(ir._type))
return "constant({})".format(ntype(ir._type))
if isinstance(ir, SolidityVariableComposed):
return 'solidity_variable_composed({})'.format(ir.name)
return "solidity_variable_composed({})".format(ir.name)
if isinstance(ir, SolidityVariable):
return 'solidity_variable{}'.format(ir.name)
return "solidity_variable{}".format(ir.name)
if isinstance(ir, TemporaryVariable):
return 'temporary_variable'
if isinstance(ir, ReferenceVariable):
return 'reference({})'.format(ntype(ir._type))
return "temporary_variable"
if isinstance(ir, IndexVariable):
return "reference({})".format(ntype(ir._type))
if isinstance(ir, LocalVariable):
return 'local_solc_variable({})'.format(ir._location)
return "local_solc_variable({})".format(ir._location)
if isinstance(ir, StateVariable):
return 'state_solc_variable({})'.format(ntype(ir._type))
return "state_solc_variable({})".format(ntype(ir._type))
if isinstance(ir, LocalVariableInitFromTuple):
return 'local_variable_init_tuple'
return "local_variable_init_tuple"
if isinstance(ir, TupleVariable):
return 'tuple_variable'
return "tuple_variable"
# default
else:
simil_logger.error(type(ir),"is missing encoding!")
return ''
simil_logger.error(type(ir), "is missing encoding!")
return ""
def encode_contract(cfilename, **kwargs):
r = dict()
@ -186,7 +223,7 @@ def encode_contract(cfilename, **kwargs):
try:
slither = Slither(cfilename, **kwargs)
except:
simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs['solc'])
simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs["solc"])
return r
# Iterate over all the contracts
@ -198,7 +235,7 @@ def encode_contract(cfilename, **kwargs):
if function.nodes == [] or function.is_constructor_variables:
continue
x = (cfilename,contract.name,function.name)
x = (cfilename, contract.name, function.name)
r[x] = []
@ -210,5 +247,3 @@ def encode_contract(cfilename, **kwargs):
for ir in node.irs:
r[x].append(encode_ir(ir))
return r

@ -9,6 +9,7 @@ from .encode import parse_target, encode_contract
logging.basicConfig()
logger = logging.getLogger("Slither-simil")
def info(args):
try:
@ -24,20 +25,20 @@ def info(args):
solc = args.solc
if filename is None and contract is None and fname is None:
logger.info("%s uses the following words:",args.model)
logger.info("%s uses the following words:", args.model)
for word in model.get_words():
logger.info(word)
sys.exit(0)
if filename is None or contract is None or fname is None:
logger.error('The encode mode requires filename, contract and fname parameters.')
logger.error("The encode mode requires filename, contract and fname parameters.")
sys.exit(-1)
irs = encode_contract(filename, **vars(args))
if len(irs) == 0:
sys.exit(-1)
x = (filename,contract,fname)
x = (filename, contract, fname)
y = " ".join(irs[x])
logger.info("Function {} in contract {} is encoded as:".format(fname, contract))
@ -47,8 +48,6 @@ def info(args):
logger.info(fvector)
except Exception:
logger.error('Error in %s' % args.filename)
logger.error("Error in %s" % args.filename)
logger.error(traceback.format_exc())
sys.exit(-1)

@ -17,10 +17,13 @@ except ImportError:
logger = logging.getLogger("Slither-simil")
def plot(args):
if decomposition is None or plt is None:
logger.error("ERROR: In order to use plot mode in slither-simil, you need to install sklearn and matplotlib:")
logger.error(
"ERROR: In order to use plot mode in slither-simil, you need to install sklearn and matplotlib:"
)
logger.error("$ pip3 install sklearn matplotlib --user")
sys.exit(-1)
@ -29,50 +32,50 @@ def plot(args):
model = args.model
model = load_model(model)
filename = args.filename
#contract = args.contract
# contract = args.contract
contract, fname = parse_target(args.fname)
#solc = args.solc
# solc = args.solc
infile = args.input
#ext = args.filter
#nsamples = args.nsamples
# ext = args.filter
# nsamples = args.nsamples
if fname is None or infile is None:
logger.error('The plot mode requieres fname and input parameters.')
logger.error("The plot mode requieres fname and input parameters.")
sys.exit(-1)
logger.info('Loading data..')
logger.info("Loading data..")
cache = load_and_encode(infile, **vars(args))
data = list()
fs = list()
logger.info('Procesing data..')
for (f,c,n),y in cache.items():
logger.info("Procesing data..")
for (f, c, n), y in cache.items():
if (c == contract or contract is None) and n == fname:
fs.append(f)
data.append(y)
if len(data) == 0:
logger.error('No contract was found with function %s', fname)
logger.error("No contract was found with function %s", fname)
sys.exit(-1)
data = np.array(data)
pca = decomposition.PCA(n_components=2)
tdata = pca.fit_transform(data)
logger.info('Plotting data..')
plt.figure(figsize=(20,10))
assert(len(tdata) == len(fs))
for ([x,y],l) in zip(tdata, fs):
logger.info("Plotting data..")
plt.figure(figsize=(20, 10))
assert len(tdata) == len(fs)
for ([x, y], l) in zip(tdata, fs):
x = random.gauss(0, 0.01) + x
y = random.gauss(0, 0.01) + y
plt.scatter(x, y, c='blue')
plt.text(x-0.001,y+0.001, l)
plt.scatter(x, y, c="blue")
plt.text(x - 0.001, y + 0.001, l)
logger.info('Saving figure to plot.png..')
plt.savefig('plot.png', bbox_inches='tight')
logger.info("Saving figure to plot.png..")
plt.savefig("plot.png", bbox_inches="tight")
except Exception:
logger.error('Error in %s' % args.filename)
logger.error("Error in %s" % args.filename)
logger.error(traceback.format_exc())
sys.exit(-1)

@ -1,5 +1,6 @@
import numpy as np
def similarity(v1, v2):
n1 = np.linalg.norm(v1)
n2 = np.linalg.norm(v2)

@ -12,6 +12,7 @@ from .similarity import similarity
logger = logging.getLogger("Slither-simil")
def test(args):
try:
@ -23,32 +24,32 @@ def test(args):
ntop = args.ntop
if filename is None or contract is None or fname is None or infile is None:
logger.error('The test mode requires filename, contract, fname and input parameters.')
logger.error("The test mode requires filename, contract, fname and input parameters.")
sys.exit(-1)
irs = encode_contract(filename, **vars(args))
if len(irs) == 0:
sys.exit(-1)
y = " ".join(irs[(filename,contract,fname)])
y = " ".join(irs[(filename, contract, fname)])
fvector = model.get_sentence_vector(y)
cache = load_and_encode(infile, model, **vars(args))
#save_cache("cache.npz", cache)
# save_cache("cache.npz", cache)
r = dict()
for x,y in cache.items():
for x, y in cache.items():
r[x] = similarity(fvector, y)
r = sorted(r.items(), key=operator.itemgetter(1), reverse=True)
logger.info("Reviewed %d functions, listing the %d most similar ones:", len(r), ntop)
format_table = "{: <65} {: <20} {: <20} {: <10}"
logger.info(format_table.format(*["filename", "contract", "function", "score"]))
for x,score in r[:ntop]:
for x, score in r[:ntop]:
score = str(round(score, 3))
logger.info(format_table.format(*(list(x)+[score])))
logger.info(format_table.format(*(list(x) + [score])))
except Exception:
logger.error('Error in %s' % args.filename)
logger.error("Error in %s" % args.filename)
logger.error(traceback.format_exc())
sys.exit(-1)

@ -11,6 +11,7 @@ from .cache import save_cache
logger = logging.getLogger("Slither-simil")
def train(args):
try:
@ -20,35 +21,37 @@ def train(args):
nsamples = args.nsamples
if dirname is None:
logger.error('The train mode requires the input parameter.')
logger.error("The train mode requires the input parameter.")
sys.exit(-1)
contracts = load_contracts(dirname, **vars(args))
logger.info('Saving extracted data into %s', last_data_train_filename)
logger.info("Saving extracted data into %s", last_data_train_filename)
cache = []
with open(last_data_train_filename, 'w') as f:
with open(last_data_train_filename, "w") as f:
for filename in contracts:
#cache[filename] = dict()
for (filename, contract, function), ir in encode_contract(filename, **vars(args)).items():
# cache[filename] = dict()
for (filename, contract, function), ir in encode_contract(
filename, **vars(args)
).items():
if ir != []:
x = " ".join(ir)
f.write(x+"\n")
f.write(x + "\n")
cache.append((os.path.split(filename)[-1], contract, function, x))
logger.info('Starting training')
model = train_unsupervised(input=last_data_train_filename, model='skipgram')
logger.info('Training complete')
logger.info('Saving model')
logger.info("Starting training")
model = train_unsupervised(input=last_data_train_filename, model="skipgram")
logger.info("Training complete")
logger.info("Saving model")
model.save_model(model_filename)
for i,(filename, contract, function, irs) in enumerate(cache):
for i, (filename, contract, function, irs) in enumerate(cache):
cache[i] = ((filename, contract, function), model.get_sentence_vector(irs))
logger.info('Saving cache in cache.npz')
logger.info("Saving cache in cache.npz")
save_cache(cache, "cache.npz")
logger.info('Done!')
logger.info("Done!")
except Exception:
logger.error('Error in %s' % args.filename)
logger.error("Error in %s" % args.filename)
logger.error(traceback.format_exc())
sys.exit(-1)

@ -10,56 +10,70 @@ logging.basicConfig()
logger = logging.getLogger("Slither").setLevel(logging.INFO)
# Slither detectors for which slither-format currently works
available_detectors = ["unused-state",
available_detectors = [
"unused-state",
"solc-version",
"pragma",
"naming-convention",
"external-function",
"constable-states",
"constant-function-asm",
"constatnt-function-state"]
"constatnt-function-state",
]
detectors_to_run = []
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='slither_format',
usage='slither_format filename')
parser.add_argument('filename', help='The filename of the contract or truffle directory to analyze.')
parser.add_argument('--verbose-test', '-v', help='verbose mode output for testing',action='store_true',default=False)
parser.add_argument('--verbose-json', '-j', help='verbose json output',action='store_true',default=False)
parser.add_argument('--version',
help='displays the current version',
version='0.1.0',
action='version')
parser.add_argument('--config-file',
help='Provide a config file (default: slither.config.json)',
action='store',
dest='config_file',
default='slither.config.json')
group_detector = parser.add_argument_group('Detectors')
group_detector.add_argument('--detect',
help='Comma-separated list of detectors, defaults to all, '
'available detectors: {}'.format(
', '.join(d for d in available_detectors)),
action='store',
dest='detectors_to_run',
default='all')
group_detector.add_argument('--exclude',
help='Comma-separated list of detectors to exclude,'
'available detectors: {}'.format(
', '.join(d for d in available_detectors)),
action='store',
dest='detectors_to_exclude',
default='all')
parser = argparse.ArgumentParser(description="slither_format", usage="slither_format filename")
parser.add_argument(
"filename", help="The filename of the contract or truffle directory to analyze."
)
parser.add_argument(
"--verbose-test",
"-v",
help="verbose mode output for testing",
action="store_true",
default=False,
)
parser.add_argument(
"--verbose-json", "-j", help="verbose json output", action="store_true", default=False
)
parser.add_argument(
"--version", help="displays the current version", version="0.1.0", action="version"
)
parser.add_argument(
"--config-file",
help="Provide a config file (default: slither.config.json)",
action="store",
dest="config_file",
default="slither.config.json",
)
group_detector = parser.add_argument_group("Detectors")
group_detector.add_argument(
"--detect",
help="Comma-separated list of detectors, defaults to all, "
"available detectors: {}".format(", ".join(d for d in available_detectors)),
action="store",
dest="detectors_to_run",
default="all",
)
group_detector.add_argument(
"--exclude",
help="Comma-separated list of detectors to exclude,"
"available detectors: {}".format(", ".join(d for d in available_detectors)),
action="store",
dest="detectors_to_exclude",
default="all",
)
cryticparser.init(parser)
@ -80,11 +94,12 @@ def main():
read_config_file(args)
# Perform slither analysis on the given filename
slither = Slither(args.filename, **vars(args))
# Format the input files based on slither analysis
slither_format(slither, **vars(args))
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -11,27 +11,29 @@ from slither.detectors.attributes.const_functions_state import ConstantFunctions
from slither.utils.colors import yellow
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('Slither.Format')
logger = logging.getLogger("Slither.Format")
all_detectors = {
'unused-state': UnusedStateVars,
'solc-version': IncorrectSolc,
'pragma': ConstantPragma,
'naming-convention': NamingConvention,
'external-function': ExternalFunction,
'constable-states' : ConstCandidateStateVars,
'constant-function-asm': ConstantFunctionsAsm,
'constant-functions-state': ConstantFunctionsState
"unused-state": UnusedStateVars,
"solc-version": IncorrectSolc,
"pragma": ConstantPragma,
"naming-convention": NamingConvention,
"external-function": ExternalFunction,
"constable-states": ConstCandidateStateVars,
"constant-function-asm": ConstantFunctionsAsm,
"constant-functions-state": ConstantFunctionsState,
}
def slither_format(slither, **kwargs):
''''
"""'
Keyword Args:
detectors_to_run (str): Comma-separated list of detectors, defaults to all
'''
"""
detectors_to_run = choose_detectors(kwargs.get('detectors_to_run', 'all'),
kwargs.get('detectors_to_exclude', ''))
detectors_to_run = choose_detectors(
kwargs.get("detectors_to_run", "all"), kwargs.get("detectors_to_exclude", "")
)
for detector in detectors_to_run:
slither.register_detector(detector)
@ -42,32 +44,32 @@ def slither_format(slither, **kwargs):
detector_results = [x for x in detector_results if x] # remove empty results
detector_results = [item for sublist in detector_results for item in sublist] # flatten
export = Path('crytic-export', 'patches')
export = Path("crytic-export", "patches")
export.mkdir(parents=True, exist_ok=True)
counter_result = 0
logger.info(yellow('slither-format is in beta, carefully review each patch before merging it.'))
logger.info(yellow("slither-format is in beta, carefully review each patch before merging it."))
for result in detector_results:
if not 'patches' in result:
if not "patches" in result:
continue
one_line_description = result["description"].split("\n")[0]
export_result = Path(export, f'{counter_result}')
export_result = Path(export, f"{counter_result}")
export_result.mkdir(parents=True, exist_ok=True)
counter_result += 1
counter = 0
logger.info(f'Issue: {one_line_description}')
logger.info(f'Generated: ({export_result})')
logger.info(f"Issue: {one_line_description}")
logger.info(f"Generated: ({export_result})")
for file, diff, in result['patches_diff'].items():
filename = f'fix_{counter}.patch'
for file, diff, in result["patches_diff"].items():
filename = f"fix_{counter}.patch"
path = Path(export_result, filename)
logger.info(f'\t- {filename}')
with open(path, 'w') as f:
logger.info(f"\t- {filename}")
with open(path, "w") as f:
f.write(diff)
counter += 1
@ -79,26 +81,28 @@ def slither_format(slither, **kwargs):
###################################################################################
###################################################################################
def choose_detectors(detectors_to_run, detectors_to_exclude):
# If detectors are specified, run only these ones
cls_detectors_to_run = []
exclude = detectors_to_exclude.split(',')
if detectors_to_run == 'all':
exclude = detectors_to_exclude.split(",")
if detectors_to_run == "all":
for d in all_detectors:
if d in exclude:
continue
cls_detectors_to_run.append(all_detectors[d])
else:
exclude = detectors_to_exclude.split(',')
for d in detectors_to_run.split(','):
exclude = detectors_to_exclude.split(",")
for d in detectors_to_run.split(","):
if d in all_detectors:
if d in exclude:
continue
cls_detectors_to_run.append(all_detectors[d])
else:
raise Exception('Error: {} is not a detector'.format(d))
raise Exception("Error: {} is not a detector".format(d))
return cls_detectors_to_run
# endregion
###################################################################################
###################################################################################
@ -106,6 +110,7 @@ def choose_detectors(detectors_to_run, detectors_to_exclude):
###################################################################################
###################################################################################
def print_patches(number_of_slither_results, patches):
logger.info("Number of Slither results: " + str(number_of_slither_results))
number_of_patches = 0
@ -115,39 +120,38 @@ def print_patches(number_of_slither_results, patches):
for file in patches:
logger.info("Patch file: " + file)
for patch in patches[file]:
logger.info("Detector: " + patch['detector'])
logger.info("Old string: " + patch['old_string'].replace("\n",""))
logger.info("New string: " + patch['new_string'].replace("\n",""))
logger.info("Location start: " + str(patch['start']))
logger.info("Location end: " + str(patch['end']))
logger.info("Detector: " + patch["detector"])
logger.info("Old string: " + patch["old_string"].replace("\n", ""))
logger.info("New string: " + patch["new_string"].replace("\n", ""))
logger.info("Location start: " + str(patch["start"]))
logger.info("Location end: " + str(patch["end"]))
def print_patches_json(number_of_slither_results, patches):
print('{',end='')
print("\"Number of Slither results\":" + '"' + str(number_of_slither_results) + '",')
print("\"Number of patchlets\":" + "\"" + str(len(patches)) + "\"", ',')
print("\"Patchlets\":" + '[')
print("{", end="")
print('"Number of Slither results":' + '"' + str(number_of_slither_results) + '",')
print('"Number of patchlets":' + '"' + str(len(patches)) + '"', ",")
print('"Patchlets":' + "[")
for index, file in enumerate(patches):
if index > 0:
print(',')
print('{',end='')
print("\"Patch file\":" + '"' + file + '",')
print("\"Number of patches\":" + "\"" + str(len(patches[file])) + "\"", ',')
print("\"Patches\":" + '[')
print(",")
print("{", end="")
print('"Patch file":' + '"' + file + '",')
print('"Number of patches":' + '"' + str(len(patches[file])) + '"', ",")
print('"Patches":' + "[")
for index, patch in enumerate(patches[file]):
if index > 0:
print(',')
print('{',end='')
print("\"Detector\":" + '"' + patch['detector'] + '",')
print("\"Old string\":" + '"' + patch['old_string'].replace("\n","") + '",')
print("\"New string\":" + '"' + patch['new_string'].replace("\n","") + '",')
print("\"Location start\":" + '"' + str(patch['start']) + '",')
print("\"Location end\":" + '"' + str(patch['end']) + '"')
if 'overlaps' in patch:
print("\"Overlaps\":" + "Yes")
print('}',end='')
print(']',end='')
print('}',end='')
print(']',end='')
print('}')
print(",")
print("{", end="")
print('"Detector":' + '"' + patch["detector"] + '",')
print('"Old string":' + '"' + patch["old_string"].replace("\n", "") + '",')
print('"New string":' + '"' + patch["new_string"].replace("\n", "") + '",')
print('"Location start":' + '"' + str(patch["start"]) + '",')
print('"Location end":' + '"' + str(patch["end"]) + '"')
if "overlaps" in patch:
print('"Overlaps":' + "Yes")
print("}", end="")
print("]", end="")
print("}", end="")
print("]", end="")
print("}")

@ -12,7 +12,12 @@ from slither.utils.colors import red
from slither.utils.output import output_to_json
from .checks import all_checks
from .checks.abstract_checks import AbstractCheck
from .utils.command_line import output_detectors_json, output_wiki, output_detectors, output_to_markdown
from .utils.command_line import (
output_detectors_json,
output_wiki,
output_detectors,
output_to_markdown,
)
logging.basicConfig()
logger = logging.getLogger("Slither")
@ -21,49 +26,53 @@ logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(
description='Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.',
usage="slither-check-upgradeability contract.sol ContractName")
description="Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.",
usage="slither-check-upgradeability contract.sol ContractName",
)
parser.add_argument('contract.sol', help='Codebase to analyze')
parser.add_argument('ContractName', help='Contract name (logic contract)')
parser.add_argument("contract.sol", help="Codebase to analyze")
parser.add_argument("ContractName", help="Contract name (logic contract)")
parser.add_argument('--proxy-name', help='Proxy name')
parser.add_argument('--proxy-filename', help='Proxy filename (if different)')
parser.add_argument("--proxy-name", help="Proxy name")
parser.add_argument("--proxy-filename", help="Proxy filename (if different)")
parser.add_argument('--new-contract-name', help='New contract name (if changed)')
parser.add_argument('--new-contract-filename', help='New implementation filename (if different)')
parser.add_argument("--new-contract-name", help="New contract name (if changed)")
parser.add_argument(
"--new-contract-filename", help="New implementation filename (if different)"
)
parser.add_argument('--json',
parser.add_argument(
"--json",
help='Export the results as a JSON file ("--json -" to export to stdout)',
action='store',
default=False)
action="store",
default=False,
)
parser.add_argument('--list-detectors',
help='List available detectors',
parser.add_argument(
"--list-detectors",
help="List available detectors",
action=ListDetectors,
nargs=0,
default=False)
default=False,
)
parser.add_argument('--markdown-root',
help='URL for markdown generation',
action='store',
default="")
parser.add_argument(
"--markdown-root", help="URL for markdown generation", action="store", default=""
)
parser.add_argument('--wiki-detectors',
help=argparse.SUPPRESS,
action=OutputWiki,
default=False)
parser.add_argument(
"--wiki-detectors", help=argparse.SUPPRESS, action=OutputWiki, default=False
)
parser.add_argument('--list-detectors-json',
parser.add_argument(
"--list-detectors-json",
help=argparse.SUPPRESS,
action=ListDetectorsJson,
nargs=0,
default=False)
default=False,
)
parser.add_argument('--markdown',
help=argparse.SUPPRESS,
action=OutputMarkdown,
default=False)
parser.add_argument("--markdown", help=argparse.SUPPRESS, action=OutputMarkdown, default=False)
cryticparser.init(parser)
@ -80,6 +89,7 @@ def parse_args():
###################################################################################
###################################################################################
def _get_checks():
detectors = [getattr(all_checks, name) for name in dir(all_checks)]
detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractCheck)]
@ -123,13 +133,18 @@ def _run_checks(detectors):
def _checks_on_contract(detectors, contract):
detectors = [d(logger, contract) for d in detectors if (not d.REQUIRE_PROXY and
not d.REQUIRE_CONTRACT_V2)]
detectors = [
d(logger, contract)
for d in detectors
if (not d.REQUIRE_PROXY and not d.REQUIRE_CONTRACT_V2)
]
return _run_checks(detectors), len(detectors)
def _checks_on_contract_update(detectors, contract_v1, contract_v2):
detectors = [d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2]
detectors = [
d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2
]
return _run_checks(detectors), len(detectors)
@ -147,15 +162,11 @@ def _checks_on_contract_and_proxy(detectors, contract, proxy):
def main():
json_results = {
'proxy-present': False,
'contract_v2-present': False,
'detectors': []
}
json_results = {"proxy-present": False, "contract_v2-present": False, "detectors": []}
args = parse_args()
v1_filename = vars(args)['contract.sol']
v1_filename = vars(args)["contract.sol"]
number_detectors_run = 0
detectors = _get_checks()
try:
@ -165,14 +176,14 @@ def main():
v1_name = args.ContractName
v1_contract = v1.get_contract_from_name(v1_name)
if v1_contract is None:
info = 'Contract {} not found in {}'.format(v1_name, v1.filename)
info = "Contract {} not found in {}".format(v1_name, v1.filename)
logger.error(red(info))
if args.json:
output_to_json(args.json, str(info), json_results)
return
detectors_results, number_detectors = _checks_on_contract(detectors, v1_contract)
json_results['detectors'] += detectors_results
json_results["detectors"] += detectors_results
number_detectors_run += number_detectors
# Analyze Proxy
@ -185,15 +196,17 @@ def main():
proxy_contract = proxy.get_contract_from_name(args.proxy_name)
if proxy_contract is None:
info = 'Proxy {} not found in {}'.format(args.proxy_name, proxy.filename)
info = "Proxy {} not found in {}".format(args.proxy_name, proxy.filename)
logger.error(red(info))
if args.json:
output_to_json(args.json, str(info), json_results)
return
json_results['proxy-present'] = True
json_results["proxy-present"] = True
detectors_results, number_detectors = _checks_on_contract_and_proxy(detectors, v1_contract, proxy_contract)
json_results['detectors'] += detectors_results
detectors_results, number_detectors = _checks_on_contract_and_proxy(
detectors, v1_contract, proxy_contract
)
json_results["detectors"] += detectors_results
number_detectors_run += number_detectors
# Analyze new version
if args.new_contract_name:
@ -204,30 +217,36 @@ def main():
v2_contract = v2.get_contract_from_name(args.new_contract_name)
if v2_contract is None:
info = 'New logic contract {} not found in {}'.format(args.new_contract_name, v2.filename)
info = "New logic contract {} not found in {}".format(
args.new_contract_name, v2.filename
)
logger.error(red(info))
if args.json:
output_to_json(args.json, str(info), json_results)
return
json_results['contract_v2-present'] = True
json_results["contract_v2-present"] = True
if proxy_contract:
detectors_results, _ = _checks_on_contract_and_proxy(detectors,
v2_contract,
proxy_contract)
detectors_results, _ = _checks_on_contract_and_proxy(
detectors, v2_contract, proxy_contract
)
json_results['detectors'] += detectors_results
json_results["detectors"] += detectors_results
detectors_results, number_detectors = _checks_on_contract_update(detectors, v1_contract, v2_contract)
json_results['detectors'] += detectors_results
detectors_results, number_detectors = _checks_on_contract_update(
detectors, v1_contract, v2_contract
)
json_results["detectors"] += detectors_results
number_detectors_run += number_detectors
# If there is a V2, we run the contract-only check on the V2
detectors_results, _ = _checks_on_contract(detectors, v2_contract)
json_results['detectors'] += detectors_results
json_results["detectors"] += detectors_results
number_detectors_run += number_detectors
logger.info(f'{len(json_results["detectors"])} findings, {number_detectors_run} detectors run')
logger.info(
f'{len(json_results["detectors"])} findings, {number_detectors_run} detectors run'
)
if args.json:
output_to_json(args.json, None, json_results)
@ -237,4 +256,5 @@ def main():
output_to_json(args.json, str(e), json_results)
return
# endregion

@ -19,28 +19,28 @@ classification_colors = {
CheckClassification.INFORMATIONAL: green,
CheckClassification.LOW: yellow,
CheckClassification.MEDIUM: yellow,
CheckClassification.HIGH: red
CheckClassification.HIGH: red,
}
classification_txt = {
CheckClassification.INFORMATIONAL: 'Informational',
CheckClassification.LOW: 'Low',
CheckClassification.MEDIUM: 'Medium',
CheckClassification.HIGH: 'High',
CheckClassification.INFORMATIONAL: "Informational",
CheckClassification.LOW: "Low",
CheckClassification.MEDIUM: "Medium",
CheckClassification.HIGH: "High",
}
class AbstractCheck(metaclass=abc.ABCMeta):
ARGUMENT = ''
HELP = ''
ARGUMENT = ""
HELP = ""
IMPACT = None
WIKI = ''
WIKI = ""
WIKI_TITLE = ''
WIKI_DESCRIPTION = ''
WIKI_EXPLOIT_SCENARIO = ''
WIKI_RECOMMENDATION = ''
WIKI_TITLE = ""
WIKI_DESCRIPTION = ""
WIKI_EXPLOIT_SCENARIO = ""
WIKI_RECOMMENDATION = ""
REQUIRE_CONTRACT = False
REQUIRE_PROXY = False
@ -53,43 +53,69 @@ class AbstractCheck(metaclass=abc.ABCMeta):
self.contract_v2 = contract_v2
if not self.ARGUMENT:
raise IncorrectCheckInitialization('NAME is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"NAME is not initialized {}".format(self.__class__.__name__)
)
if not self.HELP:
raise IncorrectCheckInitialization('HELP is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"HELP is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI:
raise IncorrectCheckInitialization('WIKI is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"WIKI is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI_TITLE:
raise IncorrectCheckInitialization('WIKI_TITLE is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"WIKI_TITLE is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI_DESCRIPTION:
raise IncorrectCheckInitialization('WIKI_DESCRIPTION is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"WIKI_DESCRIPTION is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI_EXPLOIT_SCENARIO and self.IMPACT not in [CheckClassification.INFORMATIONAL]:
raise IncorrectCheckInitialization('WIKI_EXPLOIT_SCENARIO is not initialized {}'.format(self.__class__.__name__))
if not self.WIKI_EXPLOIT_SCENARIO and self.IMPACT not in [
CheckClassification.INFORMATIONAL
]:
raise IncorrectCheckInitialization(
"WIKI_EXPLOIT_SCENARIO is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI_RECOMMENDATION:
raise IncorrectCheckInitialization('WIKI_RECOMMENDATION is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"WIKI_RECOMMENDATION is not initialized {}".format(self.__class__.__name__)
)
if self.REQUIRE_PROXY and self.REQUIRE_CONTRACT_V2:
# This is not a fundatemenal issues
# But it requires to change __main__ to avoid running two times the detectors
txt = 'REQUIRE_PROXY and REQUIRE_CONTRACT_V2 needs change in __main___ {}'.format(self.__class__.__name__)
txt = "REQUIRE_PROXY and REQUIRE_CONTRACT_V2 needs change in __main___ {}".format(
self.__class__.__name__
)
raise IncorrectCheckInitialization(txt)
if self.IMPACT not in [CheckClassification.LOW,
if self.IMPACT not in [
CheckClassification.LOW,
CheckClassification.MEDIUM,
CheckClassification.HIGH,
CheckClassification.INFORMATIONAL]:
raise IncorrectCheckInitialization('IMPACT is not initialized {}'.format(self.__class__.__name__))
CheckClassification.INFORMATIONAL,
]:
raise IncorrectCheckInitialization(
"IMPACT is not initialized {}".format(self.__class__.__name__)
)
if self.REQUIRE_CONTRACT_V2 and contract_v2 is None:
raise IncorrectCheckInitialization('ContractV2 is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"ContractV2 is not initialized {}".format(self.__class__.__name__)
)
if self.REQUIRE_PROXY and proxy is None:
raise IncorrectCheckInitialization('Proxy is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"Proxy is not initialized {}".format(self.__class__.__name__)
)
@abc.abstractmethod
def _check(self):
@ -102,19 +128,17 @@ class AbstractCheck(metaclass=abc.ABCMeta):
all_results = [r.data for r in all_results]
if all_results:
if self.logger:
info = '\n'
info = "\n"
for idx, result in enumerate(all_results):
info += result['description']
info += 'Reference: {}'.format(self.WIKI)
info += result["description"]
info += "Reference: {}".format(self.WIKI)
self._log(info)
return all_results
def generate_result(self, info, additional_fields=None):
output = Output(info,
additional_fields,
markdown_root=self.contract.slither.markdown_root)
output = Output(info, additional_fields, markdown_root=self.contract.slither.markdown_root)
output.data['check'] = self.ARGUMENT
output.data["check"] = self.ARGUMENT
return output

@ -1,11 +1,23 @@
from .initialization import (InitializablePresent, InitializableInherited,
InitializableInitializer, MissingInitializerModifier, MissingCalls, MultipleCalls, InitializeTarget)
from .initialization import (
InitializablePresent,
InitializableInherited,
InitializableInitializer,
MissingInitializerModifier,
MissingCalls,
MultipleCalls,
InitializeTarget,
)
from .functions_ids import IDCollision, FunctionShadowing
from .variable_initialization import VariableWithInit
from .variables_order import (MissingVariable, DifferentVariableContractProxy,
DifferentVariableContractNewContract, ExtraVariablesProxy, ExtraVariablesNewContract)
from .variables_order import (
MissingVariable,
DifferentVariableContractProxy,
DifferentVariableContractNewContract,
ExtraVariablesProxy,
ExtraVariablesNewContract,
)
from .constant import WereConstant, BecameConstant

@ -2,17 +2,17 @@ from slither.tools.upgradeability.checks.abstract_checks import AbstractCheck, C
class WereConstant(AbstractCheck):
ARGUMENT = 'were-constant'
ARGUMENT = "were-constant"
IMPACT = CheckClassification.HIGH
HELP = 'Variables that should be constant'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-be-constant'
WIKI_TITLE = 'Variables that should be constant'
WIKI_DESCRIPTION = '''
HELP = "Variables that should be constant"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-be-constant"
WIKI_TITLE = "Variables that should be constant"
WIKI_DESCRIPTION = """
Detect state variables that should be `constant̀`.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
uint variable1;
@ -28,11 +28,11 @@ contract ContractV2{
```
Because `variable2` is not anymore a `constant`, the storage location of `variable3` will be different.
As a result, `ContractV2` will have a corrupted storage layout.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Do not remove `constant` from a state variables during an update.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_CONTRACT_V2 = True
@ -66,12 +66,13 @@ Do not remove `constant` from a state variables during an update.
if state_v1.is_constant:
if not state_v2.is_constant:
# If v2 has additional non constant variables, we need to skip them
if ((state_v1.name != state_v2.name or state_v1.type != state_v2.type)
and v2_additional_variables > 0):
if (
state_v1.name != state_v2.name or state_v1.type != state_v2.type
) and v2_additional_variables > 0:
v2_additional_variables -= 1
idx_v2 += 1
continue
info = [state_v1, ' was constant, but ', state_v2, 'is not.\n']
info = [state_v1, " was constant, but ", state_v2, "is not.\n"]
json = self.generate_result(info)
results.append(json)
@ -80,19 +81,20 @@ Do not remove `constant` from a state variables during an update.
return results
class BecameConstant(AbstractCheck):
ARGUMENT = 'became-constant'
ARGUMENT = "became-constant"
IMPACT = CheckClassification.HIGH
HELP = 'Variables that should not be constant'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-not-be-constant'
WIKI_TITLE = 'Variables that should not be constant'
HELP = "Variables that should not be constant"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-not-be-constant"
WIKI_TITLE = "Variables that should not be constant"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect state variables that should not be `constant̀`.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
uint variable1;
@ -108,11 +110,11 @@ contract ContractV2{
```
Because `variable2` is now a `constant`, the storage location of `variable3` will be different.
As a result, `ContractV2` will have a corrupted storage layout.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Do not make an existing state variable `constant`.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_CONTRACT_V2 = True
@ -146,13 +148,14 @@ Do not make an existing state variable `constant`.
if state_v1.is_constant:
if not state_v2.is_constant:
# If v2 has additional non constant variables, we need to skip them
if ((state_v1.name != state_v2.name or state_v1.type != state_v2.type)
and v2_additional_variables > 0):
if (
state_v1.name != state_v2.name or state_v1.type != state_v2.type
) and v2_additional_variables > 0:
v2_additional_variables -= 1
idx_v2 += 1
continue
elif state_v2.is_constant:
info = [state_v1, ' was not constant but ', state_v2, ' is.\n']
info = [state_v1, " was not constant but ", state_v2, " is.\n"]
json = self.generate_result(info)
results.append(json)

@ -5,11 +5,16 @@ from slither.utils.function import get_function_id
def get_signatures(c):
functions = c.functions
functions = [f.full_name for f in functions if f.visibility in ['public', 'external'] and
not f.is_constructor and not f.is_fallback]
functions = [
f.full_name
for f in functions
if f.visibility in ["public", "external"] and not f.is_constructor and not f.is_fallback
]
variables = c.state_variables
variables = [variable.name + '()' for variable in variables if variable.visibility in ['public']]
variables = [
variable.name + "()" for variable in variables if variable.visibility in ["public"]
]
return list(set(functions + variables))
@ -21,26 +26,26 @@ def _get_function_or_variable(contract, signature):
for variable in contract.state_variables:
# Todo: can lead to incorrect variable in case of shadowing
if variable.visibility in ['public']:
if variable.name + '()' == signature:
if variable.visibility in ["public"]:
if variable.name + "()" == signature:
return variable
raise SlitherError(f'Function id checks: {signature} not found in {contract.name}')
raise SlitherError(f"Function id checks: {signature} not found in {contract.name}")
class IDCollision(AbstractCheck):
ARGUMENT = 'function-id-collision'
ARGUMENT = "function-id-collision"
IMPACT = CheckClassification.HIGH
HELP = 'Functions ids collision'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-ids-collisions'
WIKI_TITLE = 'Functions ids collisions'
HELP = "Functions ids collision"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-ids-collisions"
WIKI_TITLE = "Functions ids collisions"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect function id collision between the contract and the proxy.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
function gsf() public {
@ -56,11 +61,11 @@ contract Proxy{
```
`Proxy.tgeo()` and `Contract.gsf()` have the same function id (0x67e43e43).
As a result, `Proxy.tgeo()` will shadow Contract.gsf()`.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Rename the function. Avoid public functions in the proxy.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_PROXY = True
@ -77,11 +82,18 @@ Rename the function. Avoid public functions in the proxy.
for (k, _) in signatures_ids_implem.items():
if k in signatures_ids_proxy:
if signatures_ids_implem[k] != signatures_ids_proxy[k]:
implem_function = _get_function_or_variable(self.contract, signatures_ids_implem[k])
implem_function = _get_function_or_variable(
self.contract, signatures_ids_implem[k]
)
proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k])
info = ['Function id collision found: ', implem_function,
' ', proxy_function, '\n']
info = [
"Function id collision found: ",
implem_function,
" ",
proxy_function,
"\n",
]
json = self.generate_result(info)
results.append(json)
@ -89,18 +101,18 @@ Rename the function. Avoid public functions in the proxy.
class FunctionShadowing(AbstractCheck):
ARGUMENT = 'function-shadowing'
ARGUMENT = "function-shadowing"
IMPACT = CheckClassification.HIGH
HELP = 'Functions shadowing'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-shadowing'
WIKI_TITLE = 'Functions shadowing'
HELP = "Functions shadowing"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-shadowing"
WIKI_TITLE = "Functions shadowing"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect function shadowing between the contract and the proxy.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
function get() public {
@ -115,11 +127,11 @@ contract Proxy{
}
```
`Proxy.get` will shadow any call to `get()`. As a result `get()` is never executed in the logic contract and cannot be updated.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Rename the function. Avoid public functions in the proxy.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_PROXY = True
@ -136,11 +148,18 @@ Rename the function. Avoid public functions in the proxy.
for (k, _) in signatures_ids_implem.items():
if k in signatures_ids_proxy:
if signatures_ids_implem[k] == signatures_ids_proxy[k]:
implem_function = _get_function_or_variable(self.contract, signatures_ids_implem[k])
implem_function = _get_function_or_variable(
self.contract, signatures_ids_implem[k]
)
proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k])
info = ['Function shadowing found: ', implem_function,
' ', proxy_function, '\n']
info = [
"Function shadowing found: ",
implem_function,
" ",
proxy_function,
"\n",
]
json = self.generate_result(info)
results.append(json)

@ -13,16 +13,20 @@ class MultipleInitTarget(Exception):
def _get_initialize_functions(contract):
return [f for f in contract.functions if f.name == 'initialize' and f.is_implemented]
return [f for f in contract.functions if f.name == "initialize" and f.is_implemented]
def _get_all_internal_calls(function):
all_ir = function.all_slithir_operations()
return [i.function for i in all_ir if isinstance(i, InternalCall) and i.function_name == "initialize"]
return [
i.function
for i in all_ir
if isinstance(i, InternalCall) and i.function_name == "initialize"
]
def _get_most_derived_init(contract):
init_functions = [f for f in contract.functions if not f.is_shadowed and f.name == 'initialize']
init_functions = [f for f in contract.functions if not f.is_shadowed and f.name == "initialize"]
if len(init_functions) > 1:
if len([f for f in init_functions if f.contract_declarer == contract]) == 1:
return next((f for f in init_functions if f.contract_declarer == contract))
@ -33,80 +37,82 @@ def _get_most_derived_init(contract):
class InitializablePresent(AbstractCheck):
ARGUMENT = 'init-missing'
ARGUMENT = "init-missing"
IMPACT = CheckClassification.INFORMATIONAL
HELP = 'Initializable is missing'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing'
WIKI_TITLE = 'Initializable is missing'
WIKI_DESCRIPTION = '''
HELP = "Initializable is missing"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing"
WIKI_TITLE = "Initializable is missing"
WIKI_DESCRIPTION = """
Detect if a contract `Initializable` is present.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Review manually the contract's initialization..
Consider using a `Initializable` contract to follow [standard practice](https://docs.openzeppelin.com/upgrades/2.7/writing-upgradeable).
'''
"""
def _check(self):
initializable = self.contract.slither.get_contract_from_name('Initializable')
initializable = self.contract.slither.get_contract_from_name("Initializable")
if initializable is None:
info = ["Initializable contract not found, the contract does not follow a standard initalization schema.\n"]
info = [
"Initializable contract not found, the contract does not follow a standard initalization schema.\n"
]
json = self.generate_result(info)
return [json]
return []
class InitializableInherited(AbstractCheck):
ARGUMENT = 'init-inherited'
ARGUMENT = "init-inherited"
IMPACT = CheckClassification.INFORMATIONAL
HELP = 'Initializable is not inherited'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-not-inherited'
WIKI_TITLE = 'Initializable is not inherited'
HELP = "Initializable is not inherited"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-not-inherited"
WIKI_TITLE = "Initializable is not inherited"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect if `Initializable` is inherited.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Review manually the contract's initialization. Consider inheriting `Initializable`.
'''
"""
REQUIRE_CONTRACT = True
def _check(self):
initializable = self.contract.slither.get_contract_from_name('Initializable')
initializable = self.contract.slither.get_contract_from_name("Initializable")
# See InitializablePresent
if initializable is None:
return []
if initializable not in self.contract.inheritance:
info = [self.contract, ' does not inherit from ', initializable, '.\n']
info = [self.contract, " does not inherit from ", initializable, ".\n"]
json = self.generate_result(info)
return [json]
return []
class InitializableInitializer(AbstractCheck):
ARGUMENT = 'initializer-missing'
ARGUMENT = "initializer-missing"
IMPACT = CheckClassification.INFORMATIONAL
HELP = 'initializer() is missing'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-missing'
WIKI_TITLE = 'initializer() is missing'
HELP = "initializer() is missing"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-missing"
WIKI_TITLE = "initializer() is missing"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect the lack of `Initializable.initializer()` modifier.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Review manually the contract's initialization. Consider inheriting a `Initializable.initializer()` modifier.
'''
"""
REQUIRE_CONTRACT = True
def _check(self):
initializable = self.contract.slither.get_contract_from_name('Initializable')
initializable = self.contract.slither.get_contract_from_name("Initializable")
# See InitializablePresent
if initializable is None:
return []
@ -114,26 +120,26 @@ Review manually the contract's initialization. Consider inheriting a `Initializa
if initializable not in self.contract.inheritance:
return []
initializer = self.contract.get_modifier_from_canonical_name('Initializable.initializer()')
initializer = self.contract.get_modifier_from_canonical_name("Initializable.initializer()")
if initializer is None:
info = ['Initializable.initializer() does not exist.\n']
info = ["Initializable.initializer() does not exist.\n"]
json = self.generate_result(info)
return [json]
return []
class MissingInitializerModifier(AbstractCheck):
ARGUMENT = 'missing-init-modifier'
ARGUMENT = "missing-init-modifier"
IMPACT = CheckClassification.HIGH
HELP = 'initializer() is not called'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-not-called'
WIKI_TITLE = 'initializer() is not called'
WIKI_DESCRIPTION = '''
HELP = "initializer() is not called"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-not-called"
WIKI_TITLE = "initializer() is not called"
WIKI_DESCRIPTION = """
Detect if `Initializable.initializer()` is called.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
function initialize() public{
@ -143,23 +149,23 @@ contract Contract{
```
`initialize` should have the `initializer` modifier to prevent someone from initializing the contract multiple times.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Use `Initializable.initializer()`.
'''
"""
REQUIRE_CONTRACT = True
def _check(self):
initializable = self.contract.slither.get_contract_from_name('Initializable')
initializable = self.contract.slither.get_contract_from_name("Initializable")
# See InitializablePresent
if initializable is None:
return []
# See InitializableInherited
if initializable not in self.contract.inheritance:
return []
initializer = self.contract.get_modifier_from_canonical_name('Initializable.initializer()')
initializer = self.contract.get_modifier_from_canonical_name("Initializable.initializer()")
# InitializableInitializer
if initializer is None:
return []
@ -168,24 +174,24 @@ Use `Initializable.initializer()`.
all_init_functions = _get_initialize_functions(self.contract)
for f in all_init_functions:
if initializer not in f.modifiers:
info = [f, ' does not call the initializer modifier.\n']
info = [f, " does not call the initializer modifier.\n"]
json = self.generate_result(info)
results.append(json)
return results
class MissingCalls(AbstractCheck):
ARGUMENT = 'missing-calls'
ARGUMENT = "missing-calls"
IMPACT = CheckClassification.HIGH
HELP = 'Missing calls to init functions'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-not-called'
WIKI_TITLE = 'Initialize functions are not called'
WIKI_DESCRIPTION = '''
HELP = "Missing calls to init functions"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-not-called"
WIKI_TITLE = "Initialize functions are not called"
WIKI_DESCRIPTION = """
Detect missing calls to initialize functions.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Base{
function initialize() public{
@ -200,11 +206,11 @@ contract Derived is Base{
```
`Derived.initialize` does not call `Base.initialize` leading the contract to not be correctly initialized.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Ensure all the initialize functions are reached by the most derived initialize function.
'''
"""
REQUIRE_CONTRACT = True
@ -215,7 +221,7 @@ Ensure all the initialize functions are reached by the most derived initialize f
try:
most_derived_init = _get_most_derived_init(self.contract)
except MultipleInitTarget:
logger.error(red(f'Too many init targets in {self.contract}'))
logger.error(red(f"Too many init targets in {self.contract}"))
return []
if most_derived_init is None:
@ -225,24 +231,24 @@ Ensure all the initialize functions are reached by the most derived initialize f
all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init]
missing_calls = [f for f in all_init_functions if not f in all_init_functions_called]
for f in missing_calls:
info = ['Missing call to ', f, ' in ', most_derived_init, '.\n']
info = ["Missing call to ", f, " in ", most_derived_init, ".\n"]
json = self.generate_result(info)
results.append(json)
return results
class MultipleCalls(AbstractCheck):
ARGUMENT = 'multiple-calls'
ARGUMENT = "multiple-calls"
IMPACT = CheckClassification.HIGH
HELP = 'Init functions called multiple times'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-called-multiple-times'
WIKI_TITLE = 'Initialize functions are called multiple times'
WIKI_DESCRIPTION = '''
HELP = "Init functions called multiple times"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-called-multiple-times"
WIKI_TITLE = "Initialize functions are called multiple times"
WIKI_DESCRIPTION = """
Detect multiple calls to a initialize function.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Base{
function initialize(uint) public{
@ -264,11 +270,11 @@ contract DerivedDerived is Derived{
```
`Base.initialize(uint)` is called two times in `DerivedDerived.initiliaze` execution, leading to a potential corruption.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Call only one time every initialize function.
'''
"""
REQUIRE_CONTRACT = True
@ -280,38 +286,41 @@ Call only one time every initialize function.
most_derived_init = _get_most_derived_init(self.contract)
except MultipleInitTarget:
# Should be already reported by MissingCalls
#logger.error(red(f'Too many init targets in {self.contract}'))
# logger.error(red(f'Too many init targets in {self.contract}'))
return []
if most_derived_init is None:
return []
all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init]
double_calls = list(set([f for f in all_init_functions_called if all_init_functions_called.count(f) > 1]))
double_calls = list(
set([f for f in all_init_functions_called if all_init_functions_called.count(f) > 1])
)
for f in double_calls:
info = [f, ' is called multiple times in ', most_derived_init, '.\n']
info = [f, " is called multiple times in ", most_derived_init, ".\n"]
json = self.generate_result(info)
results.append(json)
return results
class InitializeTarget(AbstractCheck):
ARGUMENT = 'initialize-target'
ARGUMENT = "initialize-target"
IMPACT = CheckClassification.INFORMATIONAL
HELP = 'Initialize function that must be called'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-function'
WIKI_TITLE = 'Initialize function'
HELP = "Initialize function that must be called"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-function"
WIKI_TITLE = "Initialize function"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Show the function that must be called at deployment.
This finding does not have an immediate security impact and is informative.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Ensure that the function is called at deployment.
'''
"""
REQUIRE_CONTRACT = True
@ -322,12 +331,12 @@ Ensure that the function is called at deployment.
most_derived_init = _get_most_derived_init(self.contract)
except MultipleInitTarget:
# Should be already reported by MissingCalls
#logger.error(red(f'Too many init targets in {self.contract}'))
# logger.error(red(f'Too many init targets in {self.contract}'))
return []
if most_derived_init is None:
return []
info = [self.contract, f' needs to be initialized by ', most_derived_init, '.\n']
info = [self.contract, f" needs to be initialized by ", most_derived_init, ".\n"]
json = self.generate_result(info)
return [json]

@ -2,29 +2,29 @@ from slither.tools.upgradeability.checks.abstract_checks import CheckClassificat
class VariableWithInit(AbstractCheck):
ARGUMENT = 'variables-initialized'
ARGUMENT = "variables-initialized"
IMPACT = CheckClassification.HIGH
HELP = 'State variables with an initial value'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#state-variable-initialized'
WIKI_TITLE = 'State variable initialized'
HELP = "State variables with an initial value"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#state-variable-initialized"
WIKI_TITLE = "State variable initialized"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect state variables that are initialized.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
uint variable = 10;
}
```
Using `Contract` will the delegatecall proxy pattern will lead `variable` to be 0 when called through the proxy.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Using initialize functions to write initial values in state variables.
'''
"""
REQUIRE_CONTRACT = True
@ -32,7 +32,7 @@ Using initialize functions to write initial values in state variables.
results = []
for s in self.contract.state_variables:
if s.initialized and not s.is_constant:
info = [s, ' is a state variable with an initial value.\n']
info = [s, " is a state variable with an initial value.\n"]
json = self.generate_result(info)
results.append(json)
return results

@ -2,16 +2,16 @@ from slither.tools.upgradeability.checks.abstract_checks import CheckClassificat
class MissingVariable(AbstractCheck):
ARGUMENT = 'missing-variables'
ARGUMENT = "missing-variables"
IMPACT = CheckClassification.MEDIUM
HELP = 'Variable missing in the v2'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-variables'
WIKI_TITLE = 'Missing variables'
WIKI_DESCRIPTION = '''
HELP = "Variable missing in the v2"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-variables"
WIKI_TITLE = "Missing variables"
WIKI_DESCRIPTION = """
Detect variables that were present in the original contracts but are not in the updated one.
'''
WIKI_EXPLOIT_SCENARIO = '''
"""
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract V1{
uint variable1;
@ -25,11 +25,11 @@ contract V2{
The new version, `V2` does not contain `variable1`.
If a new variable is added in an update of `V2`, this variable will hold the latest value of `variable2` and
will be corrupted.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Do not change the order of the state variables in the updated contract.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_CONTRACT_V2 = True
@ -44,7 +44,7 @@ Do not change the order of the state variables in the updated contract.
for idx in range(0, len(order1)):
variable1 = order1[idx]
if len(order2) <= idx:
info = ['Variable missing in ', contract2, ': ', variable1, '\n']
info = ["Variable missing in ", contract2, ": ", variable1, "\n"]
json = self.generate_result(info)
results.append(json)
@ -52,18 +52,18 @@ Do not change the order of the state variables in the updated contract.
class DifferentVariableContractProxy(AbstractCheck):
ARGUMENT = 'order-vars-proxy'
ARGUMENT = "order-vars-proxy"
IMPACT = CheckClassification.HIGH
HELP = 'Incorrect vars order with the proxy'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-proxy'
WIKI_TITLE = 'Incorrect variables with the proxy'
HELP = "Incorrect vars order with the proxy"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-proxy"
WIKI_TITLE = "Incorrect variables with the proxy"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect variables that are different between the contract and the proxy.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
uint variable1;
@ -74,11 +74,11 @@ contract Proxy{
}
```
`Contract` and `Proxy` do not have the same storage layout. As a result the storage of both contracts can be corrupted.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Avoid variables in the proxy. If a variable is in the proxy, ensure it has the same layout than in the contract.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_PROXY = True
@ -104,9 +104,9 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
variable1 = order1[idx]
variable2 = order2[idx]
if (variable1.name != variable2.name) or (variable1.type != variable2.type):
info = ['Different variables between ', contract1, ' and ', contract2, '\n']
info += [f'\t ', variable1, '\n']
info += [f'\t ', variable2, '\n']
info = ["Different variables between ", contract1, " and ", contract2, "\n"]
info += [f"\t ", variable1, "\n"]
info += [f"\t ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)
@ -114,17 +114,17 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
class DifferentVariableContractNewContract(DifferentVariableContractProxy):
ARGUMENT = 'order-vars-contracts'
ARGUMENT = "order-vars-contracts"
HELP = 'Incorrect vars order with the v2'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-v2'
WIKI_TITLE = 'Incorrect variables with the v2'
HELP = "Incorrect vars order with the v2"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-v2"
WIKI_TITLE = "Incorrect variables with the v2"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect variables that are different between the original contract and the updated one.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
uint variable1;
@ -135,11 +135,11 @@ contract ContractV2{
}
```
`Contract` and `ContractV2` do not have the same storage layout. As a result the storage of both contracts can be corrupted.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Respect the variable order of the original contract in the updated contract.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_PROXY = False
@ -150,18 +150,20 @@ Respect the variable order of the original contract in the updated contract.
class ExtraVariablesProxy(AbstractCheck):
ARGUMENT = 'extra-vars-proxy'
ARGUMENT = "extra-vars-proxy"
IMPACT = CheckClassification.MEDIUM
HELP = 'Extra vars in the proxy'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-proxy'
WIKI_TITLE = 'Extra variables in the proxy'
HELP = "Extra vars in the proxy"
WIKI = (
"https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-proxy"
)
WIKI_TITLE = "Extra variables in the proxy"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect variables that are in the proxy and not in the contract.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
uint variable1;
@ -173,11 +175,11 @@ contract Proxy{
}
```
`Proxy` contains additional variables. A future update of `Contract` is likely to corrupt the proxy.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Avoid variables in the proxy. If a variable is in the proxy, ensure it has the same layout than in the contract.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_PROXY = True
@ -203,7 +205,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
while idx < len(order2):
variable2 = order2[idx]
info = ['Extra variables in ', contract2, ': ', variable2, '\n']
info = ["Extra variables in ", contract2, ": ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)
idx = idx + 1
@ -212,21 +214,21 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
class ExtraVariablesNewContract(ExtraVariablesProxy):
ARGUMENT = 'extra-vars-v2'
ARGUMENT = "extra-vars-v2"
HELP = 'Extra vars in the v2'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-v2'
WIKI_TITLE = 'Extra variables in the v2'
HELP = "Extra vars in the v2"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-v2"
WIKI_TITLE = "Extra variables in the v2"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Show new variables in the updated contract.
This finding does not have an immediate security impact and is informative.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Ensure that all the new variables are expected.
'''
"""
IMPACT = CheckClassification.INFORMATIONAL

@ -4,7 +4,9 @@ from slither.utils.myprettytable import MyPrettyTable
def output_wiki(detector_classes, filter_wiki):
# Sort by impact, confidence, and name
detectors_list = sorted(detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT))
detectors_list = sorted(
detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT)
)
for detector in detectors_list:
if filter_wiki not in detector.WIKI:
@ -16,16 +18,16 @@ def output_wiki(detector_classes, filter_wiki):
exploit_scenario = detector.WIKI_EXPLOIT_SCENARIO
recommendation = detector.WIKI_RECOMMENDATION
print('\n## {}'.format(title))
print('### Configuration')
print('* Check: `{}`'.format(argument))
print('* Severity: `{}`'.format(impact))
print('\n### Description')
print("\n## {}".format(title))
print("### Configuration")
print("* Check: `{}`".format(argument))
print("* Severity: `{}`".format(impact))
print("\n### Description")
print(description)
if exploit_scenario:
print('\n### Exploit Scenario:')
print("\n### Exploit Scenario:")
print(exploit_scenario)
print('\n### Recommendation')
print("\n### Recommendation")
print(recommendation)
@ -38,27 +40,31 @@ def output_detectors(detector_classes):
require_proxy = detector.REQUIRE_PROXY
require_v2 = detector.REQUIRE_CONTRACT_V2
detectors_list.append((argument, help_info, impact, require_proxy, require_v2))
table = MyPrettyTable(["Num",
"Check",
"What it Detects",
"Impact",
"Proxy",
"Contract V2"])
table = MyPrettyTable(["Num", "Check", "What it Detects", "Impact", "Proxy", "Contract V2"])
# Sort by impact, confidence, and name
detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0]))
idx = 1
for (argument, help_info, impact, proxy, v2) in detectors_list:
table.add_row([idx, argument, help_info, classification_txt[impact], 'X' if proxy else '', 'X' if v2 else ''])
table.add_row(
[
idx,
argument,
help_info,
classification_txt[impact],
"X" if proxy else "",
"X" if v2 else "",
]
)
idx = idx + 1
print(table)
def output_to_markdown(detector_classes, filter_wiki):
def extract_help(cls):
if cls.WIKI == '':
if cls.WIKI == "":
return cls.HELP
return '[{}]({})'.format(cls.HELP, cls.WIKI)
return "[{}]({})".format(cls.HELP, cls.WIKI)
detectors_list = []
for detector in detector_classes:
@ -73,12 +79,16 @@ def output_to_markdown(detector_classes, filter_wiki):
detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0]))
idx = 1
for (argument, help_info, impact, proxy, v2) in detectors_list:
print('{} | `{}` | {} | {} | {} | {}'.format(idx,
print(
"{} | `{}` | {} | {} | {} | {}".format(
idx,
argument,
help_info,
classification_txt[impact],
'X' if proxy else '',
'X' if v2 else ''))
"X" if proxy else "",
"X" if v2 else "",
)
)
idx = idx + 1
@ -92,26 +102,42 @@ def output_detectors_json(detector_classes):
wiki_description = detector.WIKI_DESCRIPTION
wiki_exploit_scenario = detector.WIKI_EXPLOIT_SCENARIO
wiki_recommendation = detector.WIKI_RECOMMENDATION
detectors_list.append((argument,
detectors_list.append(
(
argument,
help_info,
impact,
wiki_url,
wiki_description,
wiki_exploit_scenario,
wiki_recommendation))
wiki_recommendation,
)
)
# Sort by impact, confidence, and name
detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0]))
idx = 1
table = []
for (argument, help_info, impact, wiki_url, description, exploit, recommendation) in detectors_list:
table.append({'index': idx,
'check': argument,
'title': help_info,
'impact': classification_txt[impact],
'wiki_url': wiki_url,
'description': description,
'exploit_scenario': exploit,
'recommendation': recommendation})
for (
argument,
help_info,
impact,
wiki_url,
description,
exploit,
recommendation,
) in detectors_list:
table.append(
{
"index": idx,
"check": argument,
"title": help_info,
"impact": classification_txt[impact],
"wiki_url": wiki_url,
"description": description,
"exploit_scenario": exploit,
"recommendation": recommendation,
}
)
idx = idx + 1
return table

@ -3,11 +3,13 @@ import logging
from .expression import ExpressionVisitor
from slither.core.expressions import BinaryOperationType, Literal
class NotConstant(Exception):
pass
KEY = 'ConstantFolding'
KEY = "ConstantFolding"
def get_val(expression):
val = expression.context[KEY]
@ -15,11 +17,12 @@ def get_val(expression):
del expression.context[KEY]
return val
def set_val(expression, val):
expression.context[KEY] = val
class ConstantFolding(ExpressionVisitor):
class ConstantFolding(ExpressionVisitor):
def __init__(self, expression, type):
self._type = type
super(ConstantFolding, self).__init__(expression)
@ -51,7 +54,7 @@ class ConstantFolding(ExpressionVisitor):
elif expression.type == BinaryOperationType.ADDITION:
set_val(expression, left + right)
elif expression.type == BinaryOperationType.SUBTRACTION:
if(left-right) <0:
if (left - right) < 0:
# Could trigger underflow
raise NotConstant
set_val(expression, left - right)
@ -110,6 +113,3 @@ class ConstantFolding(ExpressionVisitor):
def _post_type_conversion(self, expression):
raise NotConstant

@ -1,11 +1,11 @@
from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.expressions.assignment_operation import AssignmentOperationType
from slither.core.variables.variable import Variable
key = 'ExportValues'
key = "ExportValues"
def get(expression):
val = expression.context[key]
@ -13,11 +13,12 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class ExportValues(ExpressionVisitor):
class ExportValues(ExpressionVisitor):
def result(self):
if self._result is None:
self._result = list(set(get(self.expression)))

@ -1,10 +1,12 @@
import logging
from typing import Any
from slither.core.expressions.assignment_operation import AssignmentOperation
from slither.core.expressions.binary_operation import BinaryOperation
from slither.core.expressions.call_expression import CallExpression
from slither.core.expressions.conditional_expression import ConditionalExpression
from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression
from slither.core.expressions.expression import Expression
from slither.core.expressions.identifier import Identifier
from slither.core.expressions.index_access import IndexAccess
from slither.core.expressions.literal import Literal
@ -19,24 +21,24 @@ from slither.exceptions import SlitherError
logger = logging.getLogger("ExpressionVisitor")
class ExpressionVisitor:
def __init__(self, expression):
class ExpressionVisitor:
def __init__(self, expression: Expression):
# Inherited class must declared their variables prior calling super().__init__
self._expression = expression
self._result = None
self._result: Any = None
self._visit_expression(self.expression)
def result(self):
return self._result
@property
def expression(self):
def expression(self) -> Expression:
return self._expression
# visit an expression
# call pre_visit, visit_expression_name, post_visit
def _visit_expression(self, expression):
def _visit_expression(self, expression: Expression):
self._pre_visit(expression)
if isinstance(expression, AssignmentOperation):
@ -88,7 +90,7 @@ class ExpressionVisitor:
pass
else:
raise SlitherError('Expression not handled: {}'.format(expression))
raise SlitherError("Expression not handled: {}".format(expression))
self._post_visit(expression)
@ -207,7 +209,7 @@ class ExpressionVisitor:
pass
else:
raise SlitherError('Expression not handled: {}'.format(expression))
raise SlitherError("Expression not handled: {}".format(expression))
# pre_expression_name
@ -308,7 +310,7 @@ class ExpressionVisitor:
pass
else:
raise SlitherError('Expression not handled: {}'.format(expression))
raise SlitherError("Expression not handled: {}".format(expression))
# post_expression_name
@ -356,5 +358,3 @@ class ExpressionVisitor:
def _post_unary_operation(self, expression):
pass

@ -1,17 +1,18 @@
from slither.visitors.expression.expression import ExpressionVisitor
def get(expression):
val = expression.context['ExpressionPrinter']
val = expression.context["ExpressionPrinter"]
# we delete the item to reduce memory use
del expression.context['ExpressionPrinter']
del expression.context["ExpressionPrinter"]
return val
def set_val(expression, val):
expression.context['ExpressionPrinter'] = val
expression.context["ExpressionPrinter"] = val
class ExpressionPrinter(ExpressionVisitor):
class ExpressionPrinter(ExpressionVisitor):
def result(self):
if not self._result:
self._result = get(self.expression)
@ -20,19 +21,19 @@ class ExpressionPrinter(ExpressionVisitor):
def _post_assignement_operation(self, expression):
left = get(expression.expression_left)
right = get(expression.expression_right)
val = "{} {} {}".format(left, expression.type_str, right)
val = "{} {} {}".format(left, expression.type, right)
set_val(expression, val)
def _post_binary_operation(self, expression):
left = get(expression.expression_left)
right = get(expression.expression_right)
val = "{} {} {}".format(left, expression.type_str, right)
val = "{} {} {}".format(left, expression.type, right)
set_val(expression, val)
def _post_call_expression(self, expression):
called = get(expression.called)
arguments = [get(x) for x in expression.arguments if x]
val = "{}({})".format(called, ','.join(arguments))
val = "{}({})".format(called, ",".join(arguments))
set_val(expression, val)
def _post_conditional_expression(self, expression):
@ -66,7 +67,7 @@ class ExpressionPrinter(ExpressionVisitor):
def _post_new_array(self, expression):
array = str(expression.array_type)
depth = expression.depth
val = "new {}{}".format(array, '[]'*depth)
val = "new {}{}".format(array, "[]" * depth)
set_val(expression, val)
def _post_new_contract(self, expression):
@ -81,7 +82,7 @@ class ExpressionPrinter(ExpressionVisitor):
def _post_tuple_expression(self, expression):
expressions = [get(e) for e in expression.expressions if e]
val = "({})".format(','.join(expressions))
val = "({})".format(",".join(expressions))
set_val(expression, val)
def _post_type_conversion(self, expression):

@ -1,11 +1,10 @@
from typing import List
from slither.core.expressions.expression import Expression
from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.expressions.assignment_operation import AssignmentOperationType
key = "FindCall"
from slither.core.variables.variable import Variable
key = 'FindCall'
def get(expression):
val = expression.context[key]
@ -13,12 +12,13 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class FindCalls(ExpressionVisitor):
def result(self):
class FindCalls(ExpressionVisitor):
def result(self) -> List[Expression]:
if self._result is None:
self._result = list(set(get(self.expression)))
return self._result

@ -4,7 +4,8 @@ from slither.core.expressions.index_access import IndexAccess
from slither.visitors.expression.right_value import RightValue
key = 'FindPush'
key = "FindPush"
def get(expression):
val = expression.context[key]
@ -12,11 +13,12 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class FindPush(ExpressionVisitor):
class FindPush(ExpressionVisitor):
def result(self):
if self._result is None:
self._result = list(set(get(self.expression)))
@ -66,7 +68,7 @@ class FindPush(ExpressionVisitor):
def _post_member_access(self, expression):
val = []
if expression.member_name == 'push':
if expression.member_name == "push":
right = RightValue(expression.expression)
val = right.result()
set_val(expression, val)

@ -1,13 +1,12 @@
from slither.visitors.expression.expression import ExpressionVisitor
class HasConditional(ExpressionVisitor):
class HasConditional(ExpressionVisitor):
def result(self):
# == True, to convert None to false
return self._result is True
def _post_conditional_expression(self, expression):
# if self._result is True:
# raise('Slither does not support nested ternary operator')
# if self._result is True:
# raise('Slither does not support nested ternary operator')
self._result = True

@ -6,7 +6,8 @@ from slither.core.expressions.assignment_operation import AssignmentOperationTyp
from slither.core.variables.variable import Variable
key = 'LeftValue'
key = "LeftValue"
def get(expression):
val = expression.context[key]
@ -14,11 +15,12 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class LeftValue(ExpressionVisitor):
class LeftValue(ExpressionVisitor):
def result(self):
if self._result is None:
self._result = list(set(get(self.expression)))
@ -64,8 +66,8 @@ class LeftValue(ExpressionVisitor):
def _post_identifier(self, expression):
if isinstance(expression.value, Variable):
set_val(expression, [expression.value])
# elif isinstance(expression.value, SolidityInbuilt):
# set_val(expression, [expression])
# elif isinstance(expression.value, SolidityInbuilt):
# set_val(expression, [expression])
else:
set_val(expression, [])

@ -1,4 +1,3 @@
from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.expressions.assignment_operation import AssignmentOperationType
@ -6,7 +5,8 @@ from slither.core.expressions.assignment_operation import AssignmentOperationTyp
from slither.core.variables.variable import Variable
from slither.core.declarations.solidity_variables import SolidityVariable
key = 'ReadVar'
key = "ReadVar"
def get(expression):
val = expression.context[key]
@ -14,17 +14,17 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class ReadVar(ExpressionVisitor):
class ReadVar(ExpressionVisitor):
def result(self):
if self._result is None:
self._result = list(set(get(self.expression)))
return self._result
# overide assignement
# dont explore if its direct assignement (we explore if its +=, -=, ...)
def _visit_assignement_operation(self, expression):

@ -10,7 +10,8 @@ from slither.core.expressions.expression import Expression
from slither.core.variables.variable import Variable
key = 'RightValue'
key = "RightValue"
def get(expression):
val = expression.context[key]
@ -18,11 +19,12 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class RightValue(ExpressionVisitor):
class RightValue(ExpressionVisitor):
def result(self):
if self._result is None:
self._result = list(set(get(self.expression)))
@ -68,8 +70,8 @@ class RightValue(ExpressionVisitor):
def _post_identifier(self, expression):
if isinstance(expression.value, Variable):
set_val(expression, [expression.value])
# elif isinstance(expression.value, SolidityInbuilt):
# set_val(expression, [expression])
# elif isinstance(expression.value, SolidityInbuilt):
# set_val(expression, [expression])
else:
set_val(expression, [])

@ -1,4 +1,3 @@
from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.expressions.assignment_operation import AssignmentOperation
@ -10,7 +9,8 @@ from slither.core.expressions.member_access import MemberAccess
from slither.core.expressions.index_access import IndexAccess
key = 'WriteVar'
key = "WriteVar"
def get(expression):
val = expression.context[key]
@ -18,11 +18,12 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class WriteVar(ExpressionVisitor):
class WriteVar(ExpressionVisitor):
def result(self):
if self._result is None:
self._result = list(set(get(self.expression)))
@ -71,10 +72,11 @@ class WriteVar(ExpressionVisitor):
set_val(expression, [expression])
else:
set_val(expression, [])
# if isinstance(expression.value, Variable):
# set_val(expression, [expression.value])
# else:
# set_val(expression, [])
# if isinstance(expression.value, Variable):
# set_val(expression, [expression.value])
# else:
# set_val(expression, [])
def _post_index_access(self, expression):
left = get(expression.expression_left)

Loading…
Cancel
Save