Merge slither/tools slither/visitors/expression from dev-0.7

pull/514/head
Josselin 4 years ago
parent 010d84125a
commit 4319bb3605
  1. 14
      slither/tools/demo/__main__.py
  2. 39
      slither/tools/erc_conformance/__main__.py
  3. 14
      slither/tools/erc_conformance/erc/erc20.py
  4. 99
      slither/tools/erc_conformance/erc/ercs.py
  5. 48
      slither/tools/kspec_coverage/__main__.py
  6. 95
      slither/tools/kspec_coverage/analysis.py
  7. 3
      slither/tools/kspec_coverage/kspec_coverage.py
  8. 21
      slither/tools/possible_paths/__main__.py
  9. 39
      slither/tools/possible_paths/possible_paths.py
  10. 98
      slither/tools/properties/__main__.py
  11. 8
      slither/tools/properties/addresses/address.py
  12. 6
      slither/tools/properties/platforms/echidna.py
  13. 68
      slither/tools/properties/platforms/truffle.py
  14. 130
      slither/tools/properties/properties/erc20.py
  15. 48
      slither/tools/properties/properties/ercs/erc20/properties/burn.py
  16. 135
      slither/tools/properties/properties/ercs/erc20/properties/initialization.py
  17. 27
      slither/tools/properties/properties/ercs/erc20/properties/mint.py
  18. 30
      slither/tools/properties/properties/ercs/erc20/properties/mint_and_burn.py
  19. 327
      slither/tools/properties/properties/ercs/erc20/properties/transfer.py
  20. 36
      slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py
  21. 2
      slither/tools/properties/properties/properties.py
  22. 69
      slither/tools/properties/solidity/generate_properties.py
  23. 20
      slither/tools/properties/utils.py
  24. 103
      slither/tools/similarity/__main__.py
  25. 8
      slither/tools/similarity/cache.py
  26. 167
      slither/tools/similarity/encode.py
  27. 19
      slither/tools/similarity/info.py
  28. 45
      slither/tools/similarity/plot.py
  29. 1
      slither/tools/similarity/similarity.py
  30. 25
      slither/tools/similarity/test.py
  31. 39
      slither/tools/similarity/train.py
  32. 115
      slither/tools/slither_format/__main__.py
  33. 120
      slither/tools/slither_format/slither_format.py
  34. 158
      slither/tools/upgradeability/__main__.py
  35. 94
      slither/tools/upgradeability/checks/abstract_checks.py
  36. 22
      slither/tools/upgradeability/checks/all_checks.py
  37. 55
      slither/tools/upgradeability/checks/constant.py
  38. 83
      slither/tools/upgradeability/checks/functions_ids.py
  39. 173
      slither/tools/upgradeability/checks/initialization.py
  40. 22
      slither/tools/upgradeability/checks/variable_initialization.py
  41. 108
      slither/tools/upgradeability/checks/variables_order.py
  42. 104
      slither/tools/upgradeability/utils/command_line.py
  43. 28
      slither/visitors/expression/constants_folding.py
  44. 7
      slither/visitors/expression/export_values.py
  45. 20
      slither/visitors/expression/expression.py
  46. 21
      slither/visitors/expression/expression_printer.py
  47. 12
      slither/visitors/expression/find_calls.py
  48. 8
      slither/visitors/expression/find_push.py
  49. 7
      slither/visitors/expression/has_conditional.py
  50. 10
      slither/visitors/expression/left_value.py
  51. 8
      slither/visitors/expression/read_var.py
  52. 10
      slither/visitors/expression/right_value.py
  53. 36
      slither/visitors/expression/write_var.py

@ -9,16 +9,17 @@ logging.getLogger("Slither").setLevel(logging.INFO)
logger = logging.getLogger("Slither-demo")
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='Demo',
usage='slither-demo filename')
parser = argparse.ArgumentParser(description="Demo", usage="slither-demo filename")
parser.add_argument('filename',
help='The filename of the contract or truffle directory to analyze.')
parser.add_argument(
"filename", help="The filename of the contract or truffle directory to analyze."
)
# Add default arguments from crytic-compile
cryticparser.init(parser)
@ -32,7 +33,8 @@ def main():
# Perform slither analysis on the given filename
slither = Slither(args.filename, **vars(args))
logger.info('Analysis done!')
logger.info("Analysis done!")
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -17,28 +17,29 @@ logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s')
formatter = logging.Formatter("%(message)s")
logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter)
logger.propagate = False
ADDITIONAL_CHECKS = {
"ERC20": check_erc20
}
ADDITIONAL_CHECKS = {"ERC20": check_erc20}
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='Check the ERC 20 conformance',
usage='slither-erc project contractName')
parser = argparse.ArgumentParser(
description="Check the ERC 20 conformance", usage="slither-erc project contractName"
)
parser.add_argument('project',
help='The codebase to be tested.')
parser.add_argument("project", help="The codebase to be tested.")
parser.add_argument('contract_name',
help='The name of the contract. Specify the first case contract that follow the standard. Derived contracts will be checked.')
parser.add_argument(
"contract_name",
help="The name of the contract. Specify the first case contract that follow the standard. Derived contracts will be checked.",
)
parser.add_argument(
"--erc",
@ -47,22 +48,26 @@ def parse_args():
default="erc20",
)
parser.add_argument('--json',
help='Export the results as a JSON file ("--json -" to export to stdout)',
action='store',
default=False)
parser.add_argument(
"--json",
help='Export the results as a JSON file ("--json -" to export to stdout)',
action="store",
default=False,
)
# Add default arguments from crytic-compile
cryticparser.init(parser)
return parser.parse_args()
def _log_error(err, args):
if args.json:
output_to_json(args.json, str(err), {"upgradeability-check": []})
logger.error(err)
def main():
args = parse_args()
@ -76,7 +81,7 @@ def main():
contract = slither.get_contract_from_name(args.contract_name)
if not contract:
err = f'Contract not found: {args.contract_name}'
err = f"Contract not found: {args.contract_name}"
_log_error(err, args)
return
# First elem is the function, second is the event
@ -87,7 +92,7 @@ def main():
ADDITIONAL_CHECKS[args.erc.upper()](contract, ret)
else:
err = f'Incorrect ERC selected {args.erc}'
err = f"Incorrect ERC selected {args.erc}"
_log_error(err, args)
return
@ -95,5 +100,5 @@ def main():
output_to_json(args.json, None, {"upgradeability-check": ret})
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -6,21 +6,25 @@ logger = logging.getLogger("Slither-conformance")
def approval_race_condition(contract, ret):
increaseAllowance = contract.get_function_from_signature('increaseAllowance(address,uint256)')
increaseAllowance = contract.get_function_from_signature("increaseAllowance(address,uint256)")
if not increaseAllowance:
increaseAllowance = contract.get_function_from_signature('safeIncreaseAllowance(address,uint256)')
increaseAllowance = contract.get_function_from_signature(
"safeIncreaseAllowance(address,uint256)"
)
if increaseAllowance:
txt = f'\t[✓] {contract.name} has {increaseAllowance.full_name}'
txt = f"\t[✓] {contract.name} has {increaseAllowance.full_name}"
logger.info(txt)
else:
txt = f'\t[ ] {contract.name} is not protected for the ERC20 approval race condition'
txt = f"\t[ ] {contract.name} is not protected for the ERC20 approval race condition"
logger.info(txt)
lack_of_erc20_race_condition_protection = output.Output(txt)
lack_of_erc20_race_condition_protection.add(contract)
ret["lack_of_erc20_race_condition_protection"].append(lack_of_erc20_race_condition_protection.data)
ret["lack_of_erc20_race_condition_protection"].append(
lack_of_erc20_race_condition_protection.data
)
def check_erc20(contract, ret, explored=None):

@ -22,13 +22,15 @@ def _check_signature(erc_function, contract, ret):
# The check on state variable is needed until we have a better API to handle state variable getters
state_variable_as_function = contract.get_state_variable_from_name(name)
if not state_variable_as_function or not state_variable_as_function.visibility in ['public', 'external']:
if not state_variable_as_function or not state_variable_as_function.visibility in [
"public",
"external",
]:
txt = f'[ ] {sig} is missing {"" if required else "(optional)"}'
logger.info(txt)
missing_func = output.Output(txt, additional_fields={
"function": sig,
"required": required
})
missing_func = output.Output(
txt, additional_fields={"function": sig, "required": required}
)
missing_func.add(contract)
ret["missing_function"].append(missing_func.data)
return
@ -38,10 +40,9 @@ def _check_signature(erc_function, contract, ret):
if types != parameters:
txt = f'[ ] {sig} is missing {"" if required else "(optional)"}'
logger.info(txt)
missing_func = output.Output(txt, additional_fields={
"function": sig,
"required": required
})
missing_func = output.Output(
txt, additional_fields={"function": sig, "required": required}
)
missing_func.add(contract)
ret["missing_function"].append(missing_func.data)
return
@ -53,45 +54,51 @@ def _check_signature(erc_function, contract, ret):
function_return_type = function.return_type
function_view = function.view
txt = f'[✓] {sig} is present'
txt = f"[✓] {sig} is present"
logger.info(txt)
if function_return_type:
function_return_type = ','.join([str(x) for x in function_return_type])
function_return_type = ",".join([str(x) for x in function_return_type])
if function_return_type == return_type:
txt = f'\t[✓] {sig} -> () (correct return value)'
txt = f"\t[✓] {sig} -> () (correct return value)"
logger.info(txt)
else:
txt = f'\t[ ] {sig} -> () should return {return_type}'
txt = f"\t[ ] {sig} -> () should return {return_type}"
logger.info(txt)
incorrect_return = output.Output(txt, additional_fields={
"expected_return_type": return_type,
"actual_return_type": function_return_type
})
incorrect_return = output.Output(
txt,
additional_fields={
"expected_return_type": return_type,
"actual_return_type": function_return_type,
},
)
incorrect_return.add(function)
ret["incorrect_return_type"].append(incorrect_return.data)
elif not return_type:
txt = f'\t[✓] {sig} -> () (correct return type)'
txt = f"\t[✓] {sig} -> () (correct return type)"
logger.info(txt)
else:
txt = f'\t[ ] {sig} -> () should return {return_type}'
txt = f"\t[ ] {sig} -> () should return {return_type}"
logger.info(txt)
incorrect_return = output.Output(txt, additional_fields={
"expected_return_type": return_type,
"actual_return_type": function_return_type
})
incorrect_return = output.Output(
txt,
additional_fields={
"expected_return_type": return_type,
"actual_return_type": function_return_type,
},
)
incorrect_return.add(function)
ret["incorrect_return_type"].append(incorrect_return.data)
if view:
if function_view:
txt = f'\t[✓] {sig} is view'
txt = f"\t[✓] {sig} is view"
logger.info(txt)
else:
txt = f'\t[ ] {sig} should be view'
txt = f"\t[ ] {sig} should be view"
logger.info(txt)
should_be_view = output.Output(txt)
@ -103,12 +110,12 @@ def _check_signature(erc_function, contract, ret):
event_sig = f'{event.name}({",".join(event.parameters)})'
if not function:
txt = f'\t[ ] Must emit be view {event_sig}'
txt = f"\t[ ] Must emit be view {event_sig}"
logger.info(txt)
missing_event_emmited = output.Output(txt, additional_fields={
"missing_event": event_sig
})
missing_event_emmited = output.Output(
txt, additional_fields={"missing_event": event_sig}
)
missing_event_emmited.add(function)
ret["missing_event_emmited"].append(missing_event_emmited.data)
@ -121,15 +128,15 @@ def _check_signature(erc_function, contract, ret):
event_found = True
break
if event_found:
txt = f'\t[✓] {event_sig} is emitted'
txt = f"\t[✓] {event_sig} is emitted"
logger.info(txt)
else:
txt = f'\t[ ] Must emit be view {event_sig}'
txt = f"\t[ ] Must emit be view {event_sig}"
logger.info(txt)
missing_event_emmited = output.Output(txt, additional_fields={
"missing_event": event_sig
})
missing_event_emmited = output.Output(
txt, additional_fields={"missing_event": event_sig}
)
missing_event_emmited.add(function)
ret["missing_event_emmited"].append(missing_event_emmited.data)
@ -143,31 +150,27 @@ def _check_events(erc_event, contract, ret):
event = contract.get_event_from_signature(sig)
if not event:
txt = f'[ ] {sig} is missing'
txt = f"[ ] {sig} is missing"
logger.info(txt)
missing_event = output.Output(txt, additional_fields={
"event": sig
})
missing_event = output.Output(txt, additional_fields={"event": sig})
missing_event.add(contract)
ret["missing_event"].append(missing_event.data)
return
txt = f'[✓] {sig} is present'
txt = f"[✓] {sig} is present"
logger.info(txt)
for i, index in enumerate(indexes):
if index:
if event.elems[i].indexed:
txt = f'\t[✓] parameter {i} is indexed'
txt = f"\t[✓] parameter {i} is indexed"
logger.info(txt)
else:
txt = f'\t[ ] parameter {i} should be indexed'
txt = f"\t[ ] parameter {i} should be indexed"
logger.info(txt)
missing_event_index = output.Output(txt, additional_fields={
"missing_index": i
})
missing_event_index = output.Output(txt, additional_fields={"missing_index": i})
missing_event_index.add_event(event)
ret["missing_event_index"].append(missing_event_index.data)
@ -179,16 +182,16 @@ def generic_erc_checks(contract, erc_functions, erc_events, ret, explored=None):
explored.add(contract)
logger.info(f'# Check {contract.name}\n')
logger.info(f"# Check {contract.name}\n")
logger.info(f'## Check functions')
logger.info(f"## Check functions")
for erc_function in erc_functions:
_check_signature(erc_function, contract, ret)
logger.info(f'\n## Check events')
logger.info(f"\n## Check events")
for erc_event in erc_events:
_check_events(erc_event, contract, ret)
logger.info('\n')
logger.info("\n")
for derived_contract in contract.derived_contracts:
generic_erc_checks(derived_contract, erc_functions, erc_events, ret, explored)

@ -11,35 +11,44 @@ logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s')
formatter = logging.Formatter("%(message)s")
logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter)
logger.propagate = False
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='slither-kspec-coverage',
usage='slither-kspec-coverage contract.sol kspec.md')
parser.add_argument('contract', help='The filename of the contract or truffle directory to analyze.')
parser.add_argument('kspec', help='The filename of the Klab spec markdown for the analyzed contract(s)')
parser.add_argument('--version', help='displays the current version', version='0.1.0',action='version')
parser.add_argument('--json',
help='Export the results as a JSON file ("--json -" to export to stdout)',
action='store',
default=False
parser = argparse.ArgumentParser(
description="slither-kspec-coverage", usage="slither-kspec-coverage contract.sol kspec.md"
)
parser.add_argument(
"contract", help="The filename of the contract or truffle directory to analyze."
)
parser.add_argument(
"kspec", help="The filename of the Klab spec markdown for the analyzed contract(s)"
)
cryticparser.init(parser)
if len(sys.argv) < 2:
parser.print_help(sys.stderr)
parser.add_argument(
"--version", help="displays the current version", version="0.1.0", action="version"
)
parser.add_argument(
"--json",
help='Export the results as a JSON file ("--json -" to export to stdout)',
action="store",
default=False,
)
cryticparser.init(parser)
if len(sys.argv) < 2:
parser.print_help(sys.stderr)
sys.exit(1)
return parser.parse_args()
@ -53,6 +62,7 @@ def main():
args = parse_args()
kspec_coverage(args)
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -7,25 +7,22 @@ from slither.utils.colors import yellow, green, red
from slither.utils import output
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger('Slither.kspec')
logger = logging.getLogger("Slither.kspec")
def _refactor_type(type):
return {
'uint': 'uint256',
'int': 'int256'
}.get(type, type)
return {"uint": "uint256", "int": "int256"}.get(type, type)
def _get_all_covered_kspec_functions(target):
# Create a set of our discovered functions which are covered
covered_functions = set()
BEHAVIOUR_PATTERN = re.compile('behaviour\s+(\S+)\s+of\s+(\S+)')
INTERFACE_PATTERN = re.compile('interface\s+([^\r\n]+)')
BEHAVIOUR_PATTERN = re.compile("behaviour\s+(\S+)\s+of\s+(\S+)")
INTERFACE_PATTERN = re.compile("interface\s+([^\r\n]+)")
# Read the file contents
with open(target, 'r', encoding='utf8') as target_file:
with open(target, "r", encoding="utf8") as target_file:
lines = target_file.readlines()
# Loop for each line, if a line matches our behaviour regex, and the next one matches our interface regex,
@ -38,10 +35,12 @@ def _get_all_covered_kspec_functions(target):
match = INTERFACE_PATTERN.match(lines[i + 1])
if match:
function_full_name = match.groups()[0]
start, end = function_full_name.index('(') + 1, function_full_name.index(')')
function_arguments = function_full_name[start:end].split(',')
function_arguments = [_refactor_type(arg.strip().split(' ')[0]) for arg in function_arguments]
function_full_name = function_full_name[:start] + ','.join(function_arguments) + ')'
start, end = function_full_name.index("(") + 1, function_full_name.index(")")
function_arguments = function_full_name[start:end].split(",")
function_arguments = [
_refactor_type(arg.strip().split(" ")[0]) for arg in function_arguments
]
function_full_name = function_full_name[:start] + ",".join(function_arguments) + ")"
covered_functions.add((contract_name, function_full_name))
i += 1
i += 1
@ -50,14 +49,25 @@ def _get_all_covered_kspec_functions(target):
def _get_slither_functions(slither):
# Use contract == contract_declarer to avoid dupplicate
all_functions_declared = [f for f in slither.functions if (f.contract == f.contract_declarer
and f.is_implemented
and not f.is_constructor
and not f.is_constructor_variables)]
all_functions_declared = [
f
for f in slither.functions
if (
f.contract == f.contract_declarer
and f.is_implemented
and not f.is_constructor
and not f.is_constructor_variables
)
]
# Use list(set()) because same state variable instances can be shared accross contracts
# TODO: integrate state variables
all_functions_declared += list(set([s for s in slither.state_variables if s.visibility in ['public', 'external']]))
slither_functions = {(function.contract.name, function.full_name): function for function in all_functions_declared}
all_functions_declared += list(
set([s for s in slither.state_variables if s.visibility in ["public", "external"]])
)
slither_functions = {
(function.contract.name, function.full_name): function
for function in all_functions_declared
}
return slither_functions
@ -110,35 +120,42 @@ def _run_coverage_analysis(args, slither, kspec_functions):
else:
kspec_missing.append(slither_func)
logger.info('## Check for functions coverage')
logger.info("## Check for functions coverage")
json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json)
json_kspec_missing_functions = _generate_output([f for f in kspec_missing if isinstance(f, Function)],
"[ ] (Missing function)",
red,
args.json)
json_kspec_missing_variables = _generate_output([f for f in kspec_missing if isinstance(f, Variable)],
"[ ] (Missing variable)",
yellow,
args.json)
json_kspec_unresolved = _generate_output_unresolved(kspec_functions_unresolved,
"[ ] (Unresolved)",
yellow,
args.json)
json_kspec_missing_functions = _generate_output(
[f for f in kspec_missing if isinstance(f, Function)],
"[ ] (Missing function)",
red,
args.json,
)
json_kspec_missing_variables = _generate_output(
[f for f in kspec_missing if isinstance(f, Variable)],
"[ ] (Missing variable)",
yellow,
args.json,
)
json_kspec_unresolved = _generate_output_unresolved(
kspec_functions_unresolved, "[ ] (Unresolved)", yellow, args.json
)
# Handle unresolved kspecs
if args.json:
output.output_to_json(args.json, None, {
"functions_present": json_kspec_present,
"functions_missing": json_kspec_missing_functions,
"variables_missing": json_kspec_missing_variables,
"functions_unresolved": json_kspec_unresolved
})
output.output_to_json(
args.json,
None,
{
"functions_present": json_kspec_present,
"functions_missing": json_kspec_missing_functions,
"variables_missing": json_kspec_missing_variables,
"functions_unresolved": json_kspec_unresolved,
},
)
def run_analysis(args, slither, kspec):
# Get all of our kspec'd functions (tuple(contract_name, function_name)).
if ',' in kspec:
kspecs = kspec.split(',')
if "," in kspec:
kspecs = kspec.split(",")
kspec_functions = set()
for kspec in kspecs:
kspec_functions |= _get_all_covered_kspec_functions(kspec)

@ -1,6 +1,7 @@
from slither.tools.kspec_coverage.analysis import run_analysis
from slither import Slither
def kspec_coverage(args):
contract = args.contract
@ -10,5 +11,3 @@ def kspec_coverage(args):
# Run the analysis on the Klab specs
run_analysis(args, slither, kspec)

@ -9,18 +9,21 @@ from crytic_compile import cryticparser
logging.basicConfig()
logging.getLogger("Slither").setLevel(logging.INFO)
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='PossiblePaths',
usage='possible_paths.py filename [contract.function targets]')
parser = argparse.ArgumentParser(
description="PossiblePaths", usage="possible_paths.py filename [contract.function targets]"
)
parser.add_argument('filename',
help='The filename of the contract or truffle directory to analyze.')
parser.add_argument(
"filename", help="The filename of the contract or truffle directory to analyze."
)
parser.add_argument('targets', nargs='+')
parser.add_argument("targets", nargs="+")
cryticparser.init(parser)
@ -62,12 +65,16 @@ def main():
print("\n")
# Format all function paths.
reaching_paths_str = [' -> '.join([f"{f.canonical_name}" for f in reaching_path]) for reaching_path in reaching_paths]
reaching_paths_str = [
" -> ".join([f"{f.canonical_name}" for f in reaching_path])
for reaching_path in reaching_paths
]
# Print a sorted list of all function paths which can reach the targets.
print(f"The following paths reach the specified targets:")
for reaching_path in sorted(reaching_paths_str):
print(f"{reaching_path}\n")
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -1,4 +1,5 @@
class ResolveFunctionException(Exception): pass
class ResolveFunctionException(Exception):
pass
def resolve_function(slither, contract_name, function_name):
@ -16,11 +17,15 @@ def resolve_function(slither, contract_name, function_name):
raise ResolveFunctionException(f"Could not resolve target contract: {contract_name}")
# Obtain the target function
target_function = next((function for function in contract.functions if function.name == function_name), None)
target_function = next(
(function for function in contract.functions if function.name == function_name), None
)
# Verify we have resolved the function specified.
if target_function is None:
raise ResolveFunctionException(f"Could not resolve target function: {contract_name}.{function_name}")
raise ResolveFunctionException(
f"Could not resolve target function: {contract_name}.{function_name}"
)
# Add the resolved function to the new list.
return target_function
@ -44,17 +49,23 @@ def resolve_functions(slither, functions):
for item in functions:
if isinstance(item, str):
# If the item is a single string, we assume it is of form 'ContractName.FunctionName'.
parts = item.split('.')
parts = item.split(".")
if len(parts) < 2:
raise ResolveFunctionException("Provided string descriptor must be of form 'ContractName.FunctionName'")
raise ResolveFunctionException(
"Provided string descriptor must be of form 'ContractName.FunctionName'"
)
resolved.append(resolve_function(slither, parts[0], parts[1]))
elif isinstance(item, tuple):
# If the item is a tuple, it should be a 2-tuple providing contract and function names.
if len(item) != 2:
raise ResolveFunctionException("Provided tuple descriptor must provide a contract and function name.")
raise ResolveFunctionException(
"Provided tuple descriptor must provide a contract and function name."
)
resolved.append(resolve_function(slither, item[0], item[1]))
else:
raise ResolveFunctionException(f"Unexpected function descriptor type to resolve in list: {type(item)}")
raise ResolveFunctionException(
f"Unexpected function descriptor type to resolve in list: {type(item)}"
)
# Return the resolved list.
return resolved
@ -66,9 +77,12 @@ def all_function_definitions(function):
:param function: The function to obtain all definitions at and beneath.
:return: Returns a list composed of the provided function definition and any base definitions.
"""
return [function] + [f for c in function.contract.inheritance
for f in c.functions_and_modifiers_declared
if f.full_name == function.full_name]
return [function] + [
f
for c in function.contract.inheritance
for f in c.functions_and_modifiers_declared
if f.full_name == function.full_name
]
def __find_target_paths(slither, target_function, current_path=[]):
@ -102,7 +116,7 @@ def __find_target_paths(slither, target_function, current_path=[]):
results = results.union(path_results)
# If this path is external accessible from this point, we add the current path to the list.
if target_function.visibility in ['public', 'external'] and len(current_path) > 1:
if target_function.visibility in ["public", "external"] and len(current_path) > 1:
results.add(tuple(current_path))
return results
@ -122,6 +136,3 @@ def find_target_paths(slither, target_functions):
results = results.union(__find_target_paths(slither, target_function))
return results

@ -16,20 +16,21 @@ logging.getLogger("Slither").setLevel(logging.INFO)
logger = logging.getLogger("Slither")
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s')
formatter = logging.Formatter("%(message)s")
logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter)
logger.propagate = False
def _all_scenarios():
txt = '\n'
txt += '#################### ERC20 ####################\n'
txt = "\n"
txt += "#################### ERC20 ####################\n"
for k, value in ERC20_PROPERTIES.items():
txt += f'{k} - {value.description}\n'
txt += f"{k} - {value.description}\n"
return txt
def _all_properties():
table = MyPrettyTable(["Num", "Description", "Scenario"])
idx = 0
@ -39,6 +40,7 @@ def _all_properties():
idx = idx + 1
return table
class ListScenarios(argparse.Action):
def __call__(self, parser, *args, **kwargs):
logger.info(_all_scenarios())
@ -56,43 +58,51 @@ def parse_args():
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='Demo',
usage='slither-demo filename',
formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument('filename',
help='The filename of the contract or truffle directory to analyze.')
parser.add_argument('--contract',
help='The targeted contract.')
parser.add_argument('--scenario',
help=f'Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable',
default='Transferable')
parser.add_argument('--list-scenarios',
help='List available scenarios',
action=ListScenarios,
nargs=0,
default=False)
parser.add_argument('--list-properties',
help='List available properties',
action=ListProperties,
nargs=0,
default=False)
parser.add_argument('--address-owner',
help=f'Owner address. Default {OWNER_ADDRESS}',
default=None)
parser.add_argument('--address-user',
help=f'Owner address. Default {USER_ADDRESS}',
default=None)
parser.add_argument('--address-attacker',
help=f'Attacker address. Default {ATTACKER_ADDRESS}',
default=None)
parser = argparse.ArgumentParser(
description="Demo",
usage="slither-demo filename",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"filename", help="The filename of the contract or truffle directory to analyze."
)
parser.add_argument("--contract", help="The targeted contract.")
parser.add_argument(
"--scenario",
help=f"Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable",
default="Transferable",
)
parser.add_argument(
"--list-scenarios",
help="List available scenarios",
action=ListScenarios,
nargs=0,
default=False,
)
parser.add_argument(
"--list-properties",
help="List available properties",
action=ListProperties,
nargs=0,
default=False,
)
parser.add_argument(
"--address-owner", help=f"Owner address. Default {OWNER_ADDRESS}", default=None
)
parser.add_argument(
"--address-user", help=f"Owner address. Default {USER_ADDRESS}", default=None
)
parser.add_argument(
"--address-attacker", help=f"Attacker address. Default {ATTACKER_ADDRESS}", default=None
)
# Add default arguments from crytic-compile
cryticparser.init(parser)
@ -116,9 +126,9 @@ def main():
contract = slither.contracts[0]
else:
if args.contract is None:
logger.error(f'Specify the target: --contract ContractName')
logger.error(f"Specify the target: --contract ContractName")
else:
logger.error(f'{args.contract} not found')
logger.error(f"{args.contract} not found")
return
addresses = Addresses(args.address_owner, args.address_user, args.address_attacker)
@ -126,5 +136,5 @@ def main():
generate_erc20(contract, args.scenario, addresses)
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -8,8 +8,12 @@ ATTACKER_ADDRESS = "0xC5fdf4076b8F3A5357c5E395ab970B5B54098Fef"
class Addresses:
def __init__(self, owner: Optional[str] = None, user: Optional[str] = None, attacker: Optional[str] = None):
def __init__(
self,
owner: Optional[str] = None,
user: Optional[str] = None,
attacker: Optional[str] = None,
):
self.owner = owner if owner else OWNER_ADDRESS
self.user = user if user else USER_ADDRESS
self.attacker = attacker if attacker else ATTACKER_ADDRESS

@ -11,11 +11,11 @@ def generate_echidna_config(output_dir: Path, addresses: Addresses) -> str:
:param addresses:
:return:
"""
content = 'prefix: crytic_\n'
content = "prefix: crytic_\n"
content += f'deployer: "{addresses.owner}"\n'
content += f'sender: ["{addresses.user}", "{addresses.attacker}"]\n'
content += f'psender: "{addresses.user}"\n'
content += 'coverage: true\n'
filename = 'echidna_config.yaml'
content += "coverage: true\n"
filename = "echidna_config.yaml"
write_file(output_dir, filename, content)
return filename

@ -7,21 +7,21 @@ from slither.tools.properties.addresses.address import Addresses
from slither.tools.properties.properties.properties import PropertyReturn, Property, PropertyCaller
from slither.tools.properties.utils import write_file
PATTERN_TRUFFLE_MIGRATION = re.compile('^[0-9]*_')
PATTERN_TRUFFLE_MIGRATION = re.compile("^[0-9]*_")
logger = logging.getLogger("Slither")
def _extract_caller(p: PropertyCaller):
if p == PropertyCaller.OWNER:
return ['owner']
return ["owner"]
if p == PropertyCaller.SENDER:
return ['user']
return ["user"]
if p == PropertyCaller.ATTACKER:
return ['attacker']
return ["attacker"]
if p == PropertyCaller.ALL:
return ['owner', 'user', 'attacker']
return ["owner", "user", "attacker"]
assert p == PropertyCaller.ANY
return ['user']
return ["user"]
def _helpers():
@ -31,7 +31,7 @@ def _helpers():
- catchRevertThrow: check if the call revert/throw
:return:
"""
return '''
return """
async function catchRevertThrowReturnFalse(promise) {
try {
const ret = await promise;
@ -61,12 +61,17 @@ async function catchRevertThrow(promise) {
}
assert(false, "Expected revert/throw/or return false");
};
'''
"""
def generate_unit_test(test_contract: str, filename: str,
unit_tests: List[Property], output_dir: Path,
addresses: Addresses, assert_message: str = ''):
def generate_unit_test(
test_contract: str,
filename: str,
unit_tests: List[Property],
output_dir: Path,
addresses: Addresses,
assert_message: str = "",
):
"""
Generate unit tests files
:param test_contract:
@ -88,37 +93,37 @@ def generate_unit_test(test_contract: str, filename: str,
content += f'\tlet attacker = "{addresses.attacker}";\n'
for unit_test in unit_tests:
content += f'\tit("{unit_test.description}", async () => {{\n'
content += f'\t\tlet instance = await {test_contract}.deployed();\n'
content += f"\t\tlet instance = await {test_contract}.deployed();\n"
callers = _extract_caller(unit_test.caller)
if unit_test.return_type == PropertyReturn.SUCCESS:
for caller in callers:
content += f'\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n'
content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n"
if assert_message:
content += f'\t\tassert.equal(test_{caller}, true, "{assert_message}");\n'
else:
content += f'\t\tassert.equal(test_{caller}, true);\n'
content += f"\t\tassert.equal(test_{caller}, true);\n"
elif unit_test.return_type == PropertyReturn.FAIL:
for caller in callers:
content += f'\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n'
content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n"
if assert_message:
content += f'\t\tassert.equal(test_{caller}, false, "{assert_message}");\n'
else:
content += f'\t\tassert.equal(test_{caller}, false);\n'
content += f"\t\tassert.equal(test_{caller}, false);\n"
elif unit_test.return_type == PropertyReturn.FAIL_OR_THROW:
for caller in callers:
content += f'\t\tawait catchRevertThrowReturnFalse(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n'
content += f"\t\tawait catchRevertThrowReturnFalse(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n"
elif unit_test.return_type == PropertyReturn.THROW:
callers = _extract_caller(unit_test.caller)
for caller in callers:
content += f'\t\tawait catchRevertThrow(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n'
content += '\t});\n'
content += f"\t\tawait catchRevertThrow(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n"
content += "\t});\n"
content += '});\n'
content += "});\n"
output_dir = Path(output_dir, 'test')
output_dir = Path(output_dir, "test")
output_dir.mkdir(exist_ok=True)
output_dir = Path(output_dir, 'crytic')
output_dir = Path(output_dir, "crytic")
output_dir.mkdir(exist_ok=True)
write_file(output_dir, filename, content)
@ -133,28 +138,31 @@ def generate_migration(test_contract: str, output_dir: Path, owner_address: str)
:param owner_address:
:return:
"""
content = f'''{test_contract} = artifacts.require("{test_contract}");
content = f"""{test_contract} = artifacts.require("{test_contract}");
module.exports = function(deployer) {{
deployer.deploy({test_contract}, {{from: "{owner_address}"}});
}};
'''
"""
output_dir = Path(output_dir, 'migrations')
output_dir = Path(output_dir, "migrations")
output_dir.mkdir(exist_ok=True)
migration_files = [js_file for js_file in output_dir.iterdir() if js_file.suffix == '.js'
and PATTERN_TRUFFLE_MIGRATION.match(js_file.name)]
migration_files = [
js_file
for js_file in output_dir.iterdir()
if js_file.suffix == ".js" and PATTERN_TRUFFLE_MIGRATION.match(js_file.name)
]
idx = len(migration_files)
filename = f'{idx + 1}_{test_contract}.js'
potential_previous_filename = f'{idx}_{test_contract}.js'
filename = f"{idx + 1}_{test_contract}.js"
potential_previous_filename = f"{idx}_{test_contract}.js"
for m in migration_files:
if m.name == potential_previous_filename:
write_file(output_dir, potential_previous_filename, content)
return
if test_contract in m.name:
logger.error(f'Potential conflicts with {m.name}')
logger.error(f"Potential conflicts with {m.name}")
write_file(output_dir, filename, content)

@ -12,27 +12,37 @@ from slither.tools.properties.platforms.echidna import generate_echidna_config
from slither.tools.properties.properties.ercs.erc20.properties.burn import ERC20_NotBurnable
from slither.tools.properties.properties.ercs.erc20.properties.initialization import ERC20_CONFIG
from slither.tools.properties.properties.ercs.erc20.properties.mint import ERC20_NotMintable
from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import ERC20_NotMintableNotBurnable
from slither.tools.properties.properties.ercs.erc20.properties.transfer import ERC20_Transferable, ERC20_Pausable
from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import (
ERC20_NotMintableNotBurnable,
)
from slither.tools.properties.properties.ercs.erc20.properties.transfer import (
ERC20_Transferable,
ERC20_Pausable,
)
from slither.tools.properties.properties.ercs.erc20.unit_tests.truffle import generate_truffle_test
from slither.tools.properties.properties.properties import property_to_solidity, Property
from slither.tools.properties.solidity.generate_properties import generate_solidity_properties, generate_test_contract, \
generate_solidity_interface
from slither.tools.properties.solidity.generate_properties import (
generate_solidity_properties,
generate_test_contract,
generate_solidity_interface,
)
from slither.utils.colors import red, green
logger = logging.getLogger("Slither")
PropertyDescription = namedtuple('PropertyDescription', ['properties', 'description'])
PropertyDescription = namedtuple("PropertyDescription", ["properties", "description"])
ERC20_PROPERTIES = {
"Transferable": PropertyDescription(ERC20_Transferable, 'Test the correct tokens transfer'),
"Pausable": PropertyDescription(ERC20_Pausable, 'Test the pausable functionality'),
"NotMintable": PropertyDescription(ERC20_NotMintable, 'Test that no one can mint tokens'),
"NotMintableNotBurnable": PropertyDescription(ERC20_NotMintableNotBurnable,
'Test that no one can mint or burn tokens'),
"NotBurnable": PropertyDescription(ERC20_NotBurnable, 'Test that no one can burn tokens'),
"Burnable": PropertyDescription(ERC20_NotBurnable,
'Test the burn of tokens. Require the "burn(address) returns()" function')
"Transferable": PropertyDescription(ERC20_Transferable, "Test the correct tokens transfer"),
"Pausable": PropertyDescription(ERC20_Pausable, "Test the pausable functionality"),
"NotMintable": PropertyDescription(ERC20_NotMintable, "Test that no one can mint tokens"),
"NotMintableNotBurnable": PropertyDescription(
ERC20_NotMintableNotBurnable, "Test that no one can mint or burn tokens"
),
"NotBurnable": PropertyDescription(ERC20_NotBurnable, "Test that no one can burn tokens"),
"Burnable": PropertyDescription(
ERC20_NotBurnable, 'Test the burn of tokens. Require the "burn(address) returns()" function'
),
}
@ -54,7 +64,7 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses)
:return:
"""
if contract.slither.crytic_compile.type not in [PlatformType.TRUFFLE, PlatformType.SOLC]:
logging.error(f'{contract.slither.crytic_compile.type} not yet supported by slither-prop')
logging.error(f"{contract.slither.crytic_compile.type} not yet supported by slither-prop")
return
# Check if the contract is an ERC20 contract and if the functions have the correct visibility
@ -65,7 +75,9 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses)
properties = ERC20_PROPERTIES.get(type_property, None)
if properties is None:
logger.error(f'{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}')
logger.error(
f"{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}"
)
return
properties = properties.properties
@ -78,51 +90,53 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses)
# Generate the contract containing the properties
generate_solidity_interface(output_dir, addresses)
property_file = generate_solidity_properties(contract, type_property, solidity_properties, output_dir)
property_file = generate_solidity_properties(
contract, type_property, solidity_properties, output_dir
)
# Generate the Test contract
initialization_recommendation = _initialization_recommendation(type_property)
contract_filename, contract_name = generate_test_contract(contract,
type_property,
output_dir,
property_file,
initialization_recommendation)
contract_filename, contract_name = generate_test_contract(
contract, type_property, output_dir, property_file, initialization_recommendation
)
# Generate Echidna config file
echidna_config_filename = generate_echidna_config(Path(contract.slither.crytic_compile.target).parent, addresses)
echidna_config_filename = generate_echidna_config(
Path(contract.slither.crytic_compile.target).parent, addresses
)
unit_test_info = ''
unit_test_info = ""
# If truffle, generate unit tests
if contract.slither.crytic_compile.type == PlatformType.TRUFFLE:
unit_test_info = generate_truffle_test(contract, type_property, unit_tests, addresses)
logger.info('################################################')
logger.info(green(f'Update the constructor in {Path(output_dir, contract_filename)}'))
logger.info("################################################")
logger.info(green(f"Update the constructor in {Path(output_dir, contract_filename)}"))
if unit_test_info:
logger.info(green(unit_test_info))
logger.info(green('To run Echidna:'))
txt = f'\t echidna-test {contract.slither.crytic_compile.target} '
txt += f'--contract {contract_name} --config {echidna_config_filename}'
logger.info(green("To run Echidna:"))
txt = f"\t echidna-test {contract.slither.crytic_compile.target} "
txt += f"--contract {contract_name} --config {echidna_config_filename}"
logger.info(green(txt))
def _initialization_recommendation(type_property: str) -> str:
content = ''
content += '\t\t// Add below a minimal configuration:\n'
content += '\t\t// - crytic_owner must have some tokens \n'
content += '\t\t// - crytic_user must have some tokens \n'
content += '\t\t// - crytic_attacker must have some tokens \n'
if type_property in ['Pausable']:
content += '\t\t// - The contract must be paused \n'
if type_property in ['NotMintable', 'NotMintableNotBurnable']:
content += '\t\t// - The contract must not be mintable \n'
if type_property in ['NotBurnable', 'NotMintableNotBurnable']:
content += '\t\t// - The contract must not be burnable \n'
content += '\n'
content += '\n'
content = ""
content += "\t\t// Add below a minimal configuration:\n"
content += "\t\t// - crytic_owner must have some tokens \n"
content += "\t\t// - crytic_user must have some tokens \n"
content += "\t\t// - crytic_attacker must have some tokens \n"
if type_property in ["Pausable"]:
content += "\t\t// - The contract must be paused \n"
if type_property in ["NotMintable", "NotMintableNotBurnable"]:
content += "\t\t// - The contract must not be mintable \n"
if type_property in ["NotBurnable", "NotMintableNotBurnable"]:
content += "\t\t// - The contract must not be burnable \n"
content += "\n"
content += "\n"
return content
@ -130,44 +144,44 @@ def _initialization_recommendation(type_property: str) -> str:
# TODO: move this to crytic-compile
def _platform_to_output_dir(platform: AbstractPlatform) -> Path:
if platform.TYPE == PlatformType.TRUFFLE:
return Path(platform.target, 'contracts', 'crytic')
return Path(platform.target, "contracts", "crytic")
if platform.TYPE == PlatformType.SOLC:
return Path(platform.target).parent
def _check_compatibility(contract):
errors = ''
errors = ""
if not contract.is_erc20():
errors = f'{contract} is not ERC20 compliant. Consider checking the contract with slither-check-erc'
errors = f"{contract} is not ERC20 compliant. Consider checking the contract with slither-check-erc"
return errors
transfer = contract.get_function_from_signature('transfer(address,uint256)')
transfer = contract.get_function_from_signature("transfer(address,uint256)")
if transfer.visibility != 'public':
errors = f'slither-prop requires {transfer.canonical_name} to be public. Please change the visibility'
if transfer.visibility != "public":
errors = f"slither-prop requires {transfer.canonical_name} to be public. Please change the visibility"
transfer_from = contract.get_function_from_signature('transferFrom(address,address,uint256)')
if transfer_from.visibility != 'public':
transfer_from = contract.get_function_from_signature("transferFrom(address,address,uint256)")
if transfer_from.visibility != "public":
if errors:
errors += '\n'
errors += f'slither-prop requires {transfer_from.canonical_name} to be public. Please change the visibility'
errors += "\n"
errors += f"slither-prop requires {transfer_from.canonical_name} to be public. Please change the visibility"
approve = contract.get_function_from_signature('approve(address,uint256)')
if approve.visibility != 'public':
approve = contract.get_function_from_signature("approve(address,uint256)")
if approve.visibility != "public":
if errors:
errors += '\n'
errors += f'slither-prop requires {approve.canonical_name} to be public. Please change the visibility'
errors += "\n"
errors += f"slither-prop requires {approve.canonical_name} to be public. Please change the visibility"
return errors
def _get_properties(contract, properties: List[Property]) -> Tuple[str, List[Property]]:
solidity_properties = ''
solidity_properties = ""
if contract.slither.crytic_compile.type == PlatformType.TRUFFLE:
solidity_properties += '\n'.join([property_to_solidity(p) for p in ERC20_CONFIG])
solidity_properties += "\n".join([property_to_solidity(p) for p in ERC20_CONFIG])
solidity_properties += '\n'.join([property_to_solidity(p) for p in properties])
solidity_properties += "\n".join([property_to_solidity(p) for p in properties])
unit_tests = [p for p in properties if p.is_unit_test]
return solidity_properties, unit_tests

@ -1,30 +1,38 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller
from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_NotBurnable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotBurnable()',
description='The total supply does not decrease.',
content='''
\t\treturn initialTotalSupply == this.totalSupply();''',
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY),
Property(
name="crytic_supply_constant_ERC20PropertiesNotBurnable()",
description="The total supply does not decrease.",
content="""
\t\treturn initialTotalSupply == this.totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY,
),
]
# Require burn(address) returns()
ERC20_Burnable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotBurnable()',
description='Cannot burn more than available balance',
content='''
Property(
name="crytic_supply_constant_ERC20PropertiesNotBurnable()",
description="Cannot burn more than available balance",
content="""
\t\tuint balance = balanceOf(msg.sender);
\t\tburn(balance + 1);
\t\treturn false;''',
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL)
\t\treturn false;""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
)
]

@ -1,65 +1,76 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller
from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_CONFIG = [
Property(name='init_total_supply()',
description='The total supply is correctly initialized.',
content='''
\t\treturn this.totalSupply() >= 0 && this.totalSupply() == initialTotalSupply;''',
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY),
Property(name='init_owner_balance()',
description="Owner's balance is correctly initialized.",
content='''
\t\treturn initialBalance_owner == this.balanceOf(crytic_owner);''',
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY),
Property(name='init_user_balance()',
description="User's balance is correctly initialized.",
content='''
\t\treturn initialBalance_user == this.balanceOf(crytic_user);''',
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY),
Property(name='init_attacker_balance()',
description="Attacker's balance is correctly initialized.",
content='''
\t\treturn initialBalance_attacker == this.balanceOf(crytic_attacker);''',
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY),
Property(name='init_caller_balance()',
description="All the users have a positive balance.",
content='''
\t\treturn this.balanceOf(msg.sender) >0 ;''',
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ALL),
Property(
name="init_total_supply()",
description="The total supply is correctly initialized.",
content="""
\t\treturn this.totalSupply() >= 0 && this.totalSupply() == initialTotalSupply;""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY,
),
Property(
name="init_owner_balance()",
description="Owner's balance is correctly initialized.",
content="""
\t\treturn initialBalance_owner == this.balanceOf(crytic_owner);""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY,
),
Property(
name="init_user_balance()",
description="User's balance is correctly initialized.",
content="""
\t\treturn initialBalance_user == this.balanceOf(crytic_user);""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY,
),
Property(
name="init_attacker_balance()",
description="Attacker's balance is correctly initialized.",
content="""
\t\treturn initialBalance_attacker == this.balanceOf(crytic_attacker);""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY,
),
Property(
name="init_caller_balance()",
description="All the users have a positive balance.",
content="""
\t\treturn this.balanceOf(msg.sender) >0 ;""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ALL,
),
# Note: there is a potential overflow on the addition, but we dont consider it
Property(name='init_total_supply_is_balances()',
description="The total supply is the user and owner balance.",
content='''
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) == this.totalSupply();''',
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY),
]
Property(
name="init_total_supply_is_balances()",
description="The total supply is the user and owner balance.",
content="""
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) == this.totalSupply();""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=False,
caller=PropertyCaller.ANY,
),
]

@ -1,13 +1,20 @@
from slither.tools.properties.properties.properties import PropertyType, PropertyReturn, Property, PropertyCaller
from slither.tools.properties.properties.properties import (
PropertyType,
PropertyReturn,
Property,
PropertyCaller,
)
ERC20_NotMintable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotMintable()',
description='The total supply does not increase.',
content='''
\t\treturn initialTotalSupply >= totalSupply();''',
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY),
Property(
name="crytic_supply_constant_ERC20PropertiesNotMintable()",
description="The total supply does not increase.",
content="""
\t\treturn initialTotalSupply >= totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY,
),
]

@ -1,14 +1,20 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller
from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_NotMintableNotBurnable = [
Property(name='crytic_supply_constant_ERC20PropertiesNotMintableNotBurnable()',
description='The total supply does not change.',
content='''
\t\treturn initialTotalSupply == this.totalSupply();''',
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY),
]
Property(
name="crytic_supply_constant_ERC20PropertiesNotMintableNotBurnable()",
description="The total supply does not change.",
content="""
\t\treturn initialTotalSupply == this.totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY,
),
]

@ -1,96 +1,108 @@
from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller
from slither.tools.properties.properties.properties import (
Property,
PropertyType,
PropertyReturn,
PropertyCaller,
)
ERC20_Transferable = [
Property(name='crytic_zero_always_empty_ERC20Properties()',
description='The address 0x0 should not receive tokens.',
content='''
\t\treturn this.balanceOf(address(0x0)) == 0;''',
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY),
Property(name='crytic_approve_overwrites()',
description='Allowance can be changed.',
content='''
Property(
name="crytic_zero_always_empty_ERC20Properties()",
description="The address 0x0 should not receive tokens.",
content="""
\t\treturn this.balanceOf(address(0x0)) == 0;""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY,
),
Property(
name="crytic_approve_overwrites()",
description="Allowance can be changed.",
content="""
\t\tbool approve_return;
\t\tapprove_return = approve(crytic_user, 10);
\t\trequire(approve_return);
\t\tapprove_return = approve(crytic_user, 20);
\t\trequire(approve_return);
\t\treturn this.allowance(msg.sender, crytic_user) == 20;''',
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_less_than_total_ERC20Properties()',
description='Balance of one user must be less or equal to the total supply.',
content='''
\t\treturn this.balanceOf(msg.sender) <= totalSupply();''',
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_totalSupply_consistant_ERC20Properties()',
description='Balance of the crytic users must be less or equal to the total supply.',
content='''
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) <= totalSupply();''',
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY),
Property(name='crytic_revert_transfer_to_zero_ERC20PropertiesTransferable()',
description='No one should be able to send tokens to the address 0x0 (transfer).',
content='''
\t\treturn this.allowance(msg.sender, crytic_user) == 20;""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
Property(
name="crytic_less_than_total_ERC20Properties()",
description="Balance of one user must be less or equal to the total supply.",
content="""
\t\treturn this.balanceOf(msg.sender) <= totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
Property(
name="crytic_totalSupply_consistant_ERC20Properties()",
description="Balance of the crytic users must be less or equal to the total supply.",
content="""
\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) <= totalSupply();""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ANY,
),
Property(
name="crytic_revert_transfer_to_zero_ERC20PropertiesTransferable()",
description="No one should be able to send tokens to the address 0x0 (transfer).",
content="""
\t\tif (this.balanceOf(msg.sender) == 0){
\t\t\trevert();
\t\t}
\t\treturn transfer(address(0x0), this.balanceOf(msg.sender));''',
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_revert_transferFrom_to_zero_ERC20PropertiesTransferable()',
description='No one should be able to send tokens to the address 0x0 (transferFrom).',
content='''
\t\treturn transfer(address(0x0), this.balanceOf(msg.sender));""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
Property(
name="crytic_revert_transferFrom_to_zero_ERC20PropertiesTransferable()",
description="No one should be able to send tokens to the address 0x0 (transferFrom).",
content="""
\t\tuint balance = this.balanceOf(msg.sender);
\t\tif (balance == 0){
\t\t\trevert();
\t\t}
\t\tapprove(msg.sender, balance);
\t\treturn transferFrom(msg.sender, address(0x0), this.balanceOf(msg.sender));''',
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_self_transferFrom_ERC20PropertiesTransferable()',
description='Self transferFrom works.',
content='''
\t\treturn transferFrom(msg.sender, address(0x0), this.balanceOf(msg.sender));""",
type=PropertyType.CODE_QUALITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
Property(
name="crytic_self_transferFrom_ERC20PropertiesTransferable()",
description="Self transferFrom works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender);
\t\tbool approve_return = approve(msg.sender, balance);
\t\tbool transfer_return = transferFrom(msg.sender, msg.sender, balance);
\t\treturn (this.balanceOf(msg.sender) == balance) && approve_return && transfer_return;''',
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_self_transferFrom_to_other_ERC20PropertiesTransferable()',
description='transferFrom works.',
content='''
\t\treturn (this.balanceOf(msg.sender) == balance) && approve_return && transfer_return;""",
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
Property(
name="crytic_self_transferFrom_to_other_ERC20PropertiesTransferable()",
description="transferFrom works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender);
\t\tbool approve_return = approve(msg.sender, balance);
\t\taddress other = crytic_user;
@ -98,29 +110,30 @@ ERC20_Transferable = [
\t\t\tother = crytic_owner;
\t\t}
\t\tbool transfer_return = transferFrom(msg.sender, other, balance);
\t\treturn (this.balanceOf(msg.sender) == 0) && approve_return && transfer_return;''',
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_self_transfer_ERC20PropertiesTransferable()',
description='Self transfer works.',
content='''
\t\treturn (this.balanceOf(msg.sender) == 0) && approve_return && transfer_return;""",
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
Property(
name="crytic_self_transfer_ERC20PropertiesTransferable()",
description="Self transfer works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender);
\t\tbool transfer_return = transfer(msg.sender, balance);
\t\treturn (this.balanceOf(msg.sender) == balance) && transfer_return;''',
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_transfer_to_other_ERC20PropertiesTransferable()',
description='transfer works.',
content='''
\t\treturn (this.balanceOf(msg.sender) == balance) && transfer_return;""",
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
Property(
name="crytic_transfer_to_other_ERC20PropertiesTransferable()",
description="transfer works.",
content="""
\t\tuint balance = this.balanceOf(msg.sender);
\t\taddress other = crytic_user;
\t\tif (other == msg.sender) {
@ -130,74 +143,76 @@ ERC20_Transferable = [
\t\t\tbool transfer_other = transfer(other, 1);
\t\t\treturn (this.balanceOf(msg.sender) == balance-1) && (this.balanceOf(other) >= 1) && transfer_other;
\t\t}
\t\treturn true;''',
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_revert_transfer_to_user_ERC20PropertiesTransferable()',
description='Cannot transfer more than the balance.',
content='''
\t\treturn true;""",
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
Property(
name="crytic_revert_transfer_to_user_ERC20PropertiesTransferable()",
description="Cannot transfer more than the balance.",
content="""
\t\tuint balance = this.balanceOf(msg.sender);
\t\tif (balance == (2 ** 256 - 1))
\t\t\treturn true;
\t\tbool transfer_other = transfer(crytic_user, balance+1);
\t\treturn transfer_other;''',
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
\t\treturn transfer_other;""",
type=PropertyType.HIGH_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
]
ERC20_Pausable = [
Property(name='crytic_revert_transfer_ERC20AlwaysTruePropertiesNotTransferable()',
description='Cannot transfer.',
content='''
\t\treturn transfer(crytic_user, this.balanceOf(msg.sender));''',
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_revert_transferFrom_ERC20AlwaysTruePropertiesNotTransferable()',
description='Cannot execute transferFrom.',
content='''
Property(
name="crytic_revert_transfer_ERC20AlwaysTruePropertiesNotTransferable()",
description="Cannot transfer.",
content="""
\t\treturn transfer(crytic_user, this.balanceOf(msg.sender));""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
Property(
name="crytic_revert_transferFrom_ERC20AlwaysTruePropertiesNotTransferable()",
description="Cannot execute transferFrom.",
content="""
\t\tapprove(msg.sender, this.balanceOf(msg.sender));
\t\ttransferFrom(msg.sender, msg.sender, this.balanceOf(msg.sender));''',
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_constantBalance()',
description='Cannot change the balance.',
content='''
\t\treturn this.balanceOf(crytic_user) == initialBalance_user && this.balanceOf(crytic_attacker) == initialBalance_attacker;''',
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
Property(name='crytic_constantAllowance()',
description='Cannot change the allowance.',
content='''
\t\ttransferFrom(msg.sender, msg.sender, this.balanceOf(msg.sender));""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.FAIL_OR_THROW,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
Property(
name="crytic_constantBalance()",
description="Cannot change the balance.",
content="""
\t\treturn this.balanceOf(crytic_user) == initialBalance_user && this.balanceOf(crytic_attacker) == initialBalance_attacker;""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
Property(
name="crytic_constantAllowance()",
description="Cannot change the allowance.",
content="""
\t\treturn (this.allowance(crytic_user, crytic_attacker) == initialAllowance_user_attacker) &&
\t\t\t(this.allowance(crytic_attacker, crytic_attacker) == initialAllowance_attacker_attacker);''',
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL),
\t\t\t(this.allowance(crytic_attacker, crytic_attacker) == initialAllowance_attacker_attacker);""",
type=PropertyType.MEDIUM_SEVERITY,
return_type=PropertyReturn.SUCCESS,
is_unit_test=True,
is_property_test=True,
caller=PropertyCaller.ALL,
),
]

@ -11,26 +11,32 @@ from slither.tools.properties.properties.properties import Property
logger = logging.getLogger("Slither")
def generate_truffle_test(contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses) -> str:
test_contract = f'Test{contract.name}{type_property}'
filename_init = f'Initialization{test_contract}.js'
filename = f'{test_contract}.js'
def generate_truffle_test(
contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses
) -> str:
test_contract = f"Test{contract.name}{type_property}"
filename_init = f"Initialization{test_contract}.js"
filename = f"{test_contract}.js"
output_dir = Path(contract.slither.crytic_compile.target)
generate_migration(test_contract, output_dir, addresses.owner)
generate_unit_test(test_contract,
filename_init,
ERC20_CONFIG,
output_dir,
addresses,
f'Check the constructor of {test_contract}')
generate_unit_test(test_contract, filename, unit_tests, output_dir, addresses,)
log_info = '\n'
log_info += 'To run the unit tests:\n'
generate_unit_test(
test_contract,
filename_init,
ERC20_CONFIG,
output_dir,
addresses,
f"Check the constructor of {test_contract}",
)
generate_unit_test(
test_contract, filename, unit_tests, output_dir, addresses,
)
log_info = "\n"
log_info += "To run the unit tests:\n"
log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename_init)}\n"
log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename)}\n"
return log_info

@ -36,4 +36,4 @@ class Property(NamedTuple):
def property_to_solidity(p: Property):
return f'\tfunction {p.name} public returns(bool){{{p.content}\n\t}}\n'
return f"\tfunction {p.name} public returns(bool){{{p.content}\n\t}}\n"

@ -9,59 +9,64 @@ from slither.tools.properties.utils import write_file
logger = logging.getLogger("Slither")
def generate_solidity_properties(contract: Contract, type_property: str, solidity_properties: str,
output_dir: Path) -> Path:
def generate_solidity_properties(
contract: Contract, type_property: str, solidity_properties: str, output_dir: Path
) -> Path:
solidity_import = f'import "./interfaces.sol";\n'
solidity_import += f'import "../{contract.source_mapping["filename_short"]}";'
test_contract_name = f'Properties{contract.name}{type_property}'
test_contract_name = f"Properties{contract.name}{type_property}"
solidity_content = f'{solidity_import}\ncontract {test_contract_name} is CryticInterface,{contract.name}'
solidity_content += f'{{\n\n{solidity_properties}\n}}\n'
solidity_content = (
f"{solidity_import}\ncontract {test_contract_name} is CryticInterface,{contract.name}"
)
solidity_content += f"{{\n\n{solidity_properties}\n}}\n"
filename = f'{test_contract_name}.sol'
filename = f"{test_contract_name}.sol"
write_file(output_dir, filename, solidity_content)
return Path(filename)
def generate_test_contract(contract: Contract,
type_property: str,
output_dir: Path,
property_file: Path,
initialization_recommendation: str) -> Tuple[str, str]:
test_contract_name = f'Test{contract.name}{type_property}'
properties_name = f'Properties{contract.name}{type_property}'
def generate_test_contract(
contract: Contract,
type_property: str,
output_dir: Path,
property_file: Path,
initialization_recommendation: str,
) -> Tuple[str, str]:
test_contract_name = f"Test{contract.name}{type_property}"
properties_name = f"Properties{contract.name}{type_property}"
content = ''
content = ""
content += f'import "./{property_file}";\n'
content += f"contract {test_contract_name} is {properties_name} {{\n"
content += '\tconstructor() public{\n'
content += '\t\t// Existing addresses:\n'
content += '\t\t// - crytic_owner: If the contract has an owner, it must be crytic_owner\n'
content += '\t\t// - crytic_user: Legitimate user\n'
content += '\t\t// - crytic_attacker: Attacker\n'
content += '\t\t// \n'
content += "\tconstructor() public{\n"
content += "\t\t// Existing addresses:\n"
content += "\t\t// - crytic_owner: If the contract has an owner, it must be crytic_owner\n"
content += "\t\t// - crytic_user: Legitimate user\n"
content += "\t\t// - crytic_attacker: Attacker\n"
content += "\t\t// \n"
content += initialization_recommendation
content += '\t\t// \n'
content += '\t\t// \n'
content += '\t\t// Update the following if totalSupply and balanceOf are external functions or state variables:\n\n'
content += '\t\tinitialTotalSupply = totalSupply();\n'
content += '\t\tinitialBalance_owner = balanceOf(crytic_owner);\n'
content += '\t\tinitialBalance_user = balanceOf(crytic_user);\n'
content += '\t\tinitialBalance_attacker = balanceOf(crytic_attacker);\n'
content += "\t\t// \n"
content += "\t\t// \n"
content += "\t\t// Update the following if totalSupply and balanceOf are external functions or state variables:\n\n"
content += "\t\tinitialTotalSupply = totalSupply();\n"
content += "\t\tinitialBalance_owner = balanceOf(crytic_owner);\n"
content += "\t\tinitialBalance_user = balanceOf(crytic_user);\n"
content += "\t\tinitialBalance_attacker = balanceOf(crytic_attacker);\n"
content += '\t}\n}\n'
content += "\t}\n}\n"
filename = f'{test_contract_name}.sol'
filename = f"{test_contract_name}.sol"
write_file(output_dir, filename, content, allow_overwrite=False)
return filename, test_contract_name
def generate_solidity_interface(output_dir: Path, addresses: Addresses):
content = f'''
content = f"""
contract CryticInterface{{
address internal crytic_owner = address({addresses.owner});
address internal crytic_user = address({addresses.user});
@ -70,7 +75,7 @@ contract CryticInterface{{
uint internal initialBalance_owner;
uint internal initialBalance_user;
uint internal initialBalance_attacker;
}}'''
}}"""
# Static file, we discard if it exists as it should never change
write_file(output_dir, 'interfaces.sol', content, discard_if_exist=True)
write_file(output_dir, "interfaces.sol", content, discard_if_exist=True)

@ -6,11 +6,13 @@ from slither.utils.colors import green, yellow
logger = logging.getLogger("Slither")
def write_file(output_dir: Path,
filename: str,
content: str,
allow_overwrite: bool = True,
discard_if_exist: bool = False):
def write_file(
output_dir: Path,
filename: str,
content: str,
allow_overwrite: bool = True,
discard_if_exist: bool = False,
):
"""
Write the content into output_dir/filename
:param output_dir:
@ -25,10 +27,10 @@ def write_file(output_dir: Path,
if discard_if_exist:
return
if not allow_overwrite:
logger.info(yellow(f'{file_to_write} already exist and will not be overwritten'))
logger.info(yellow(f"{file_to_write} already exist and will not be overwritten"))
return
logger.info(yellow(f'Overwrite {file_to_write}'))
logger.info(yellow(f"Overwrite {file_to_write}"))
else:
logger.info(green(f'Write {file_to_write}'))
with open(file_to_write, 'w') as f:
logger.info(green(f"Write {file_to_write}"))
with open(file_to_write, "w") as f:
f.write(content)

@ -8,62 +8,56 @@ import operator
from crytic_compile import cryticparser
from .info import info
from .test import test
from .train import train
from .plot import plot
from .info import info
from .test import test
from .train import train
from .plot import plot
logging.basicConfig()
logger = logging.getLogger("Slither-simil")
modes = ["info", "test", "train", "plot"]
def parse_args():
parser = argparse.ArgumentParser(description='Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector')
parser.add_argument('mode',
help="|".join(modes))
parser.add_argument('model',
help='model.bin')
parser.add_argument('--filename',
action='store',
dest='filename',
help='contract.sol')
parser.add_argument('--fname',
action='store',
dest='fname',
help='Target function')
parser.add_argument('--ext',
action='store',
dest='ext',
help='Extension to filter contracts')
parser.add_argument('--nsamples',
action='store',
type=int,
dest='nsamples',
help='Number of contract samples used for training')
parser.add_argument('--ntop',
action='store',
type=int,
dest='ntop',
default=10,
help='Number of more similar contracts to show for testing')
parser.add_argument('--input',
action='store',
dest='input',
help='File or directory used as input')
parser.add_argument('--version',
help='displays the current version',
version="0.0",
action='version')
parser = argparse.ArgumentParser(
description="Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector"
)
parser.add_argument("mode", help="|".join(modes))
parser.add_argument("model", help="model.bin")
parser.add_argument("--filename", action="store", dest="filename", help="contract.sol")
parser.add_argument("--fname", action="store", dest="fname", help="Target function")
parser.add_argument("--ext", action="store", dest="ext", help="Extension to filter contracts")
parser.add_argument(
"--nsamples",
action="store",
type=int,
dest="nsamples",
help="Number of contract samples used for training",
)
parser.add_argument(
"--ntop",
action="store",
type=int,
dest="ntop",
default=10,
help="Number of more similar contracts to show for testing",
)
parser.add_argument(
"--input", action="store", dest="input", help="File or directory used as input"
)
parser.add_argument(
"--version", help="displays the current version", version="0.0", action="version"
)
cryticparser.init(parser)
@ -74,6 +68,7 @@ def parse_args():
args = parser.parse_args()
return args
# endregion
###################################################################################
###################################################################################
@ -81,27 +76,29 @@ def parse_args():
###################################################################################
###################################################################################
def main():
args = parse_args()
default_log = logging.INFO
logger.setLevel(default_log)
mode = args.mode
if mode == "info":
info(args)
elif mode == "train":
train(args)
train(args)
elif mode == "test":
test(args)
elif mode == "plot":
plot(args)
else:
logger.error('Invalid mode!. It should be one of these: %s' % ", ".join(modes))
logger.error("Invalid mode!. It should be one of these: %s" % ", ".join(modes))
sys.exit(-1)
if __name__ == '__main__':
if __name__ == "__main__":
main()
# endregion

@ -7,16 +7,18 @@ except ImportError:
print("$ pip3 install numpy --user\n")
sys.exit(-1)
def load_cache(infile, nsamples=None):
cache = dict()
with np.load(infile, allow_pickle=True) as data:
array = data['arr_0'][0]
for i,(x,y) in enumerate(array):
array = data["arr_0"][0]
for i, (x, y) in enumerate(array):
cache[x] = y
if i == nsamples:
break
return cache
def save_cache(cache, outfile):
np.savez(outfile,[np.array(cache)])
np.savez(outfile, [np.array(cache)])

@ -2,16 +2,47 @@ import logging
import os
from slither import Slither
from slither.core.declarations import Structure, Enum, SolidityVariableComposed, SolidityVariable, Function
from slither.core.declarations import (
Structure,
Enum,
SolidityVariableComposed,
SolidityVariable,
Function,
)
from slither.core.solidity_types import ElementaryType, ArrayType, MappingType, UserDefinedType
from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple
from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import Assignment, Index, Member, Length, Balance, Binary, \
Unary, Condition, NewArray, NewStructure, NewContract, NewElementaryType, \
SolidityCall, Push, Delete, EventCall, LibraryCall, InternalDynamicCall, \
HighLevelCall, LowLevelCall, TypeConversion, Return, Transfer, Send, Unpack, InitArray, InternalCall
from slither.slithir.variables import TemporaryVariable, TupleVariable, Constant, ReferenceVariable
from slither.slithir.operations import (
Assignment,
Index,
AccessMember,
Length,
Balance,
Binary,
Unary,
Condition,
NewArray,
NewStructure,
NewContract,
NewElementaryType,
SolidityCall,
Push,
Delete,
EventCall,
LibraryCall,
InternalDynamicCall,
HighLevelCall,
LowLevelCall,
TypeConversion,
Return,
Transfer,
Send,
Unpack,
InitArray,
InternalCall,
)
from slither.slithir.variables import TemporaryVariable, TupleVariable, Constant, IndexVariable
from .cache import load_cache
simil_logger = logging.getLogger("Slither-simil")
@ -20,11 +51,12 @@ compiler_logger.setLevel(logging.CRITICAL)
slither_logger = logging.getLogger("Slither")
slither_logger.setLevel(logging.CRITICAL)
def parse_target(target):
if target is None:
return None, None
parts = target.split('.')
parts = target.split(".")
if len(parts) == 1:
return None, parts[0]
elif len(parts) == 2:
@ -32,25 +64,27 @@ def parse_target(target):
else:
simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'")
def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs):
r = dict()
if infile.endswith(".npz"):
r = load_cache(infile, nsamples=nsamples)
else:
else:
contracts = load_contracts(infile, ext=ext, nsamples=nsamples)
for contract in contracts:
for x,ir in encode_contract(contract, **kwargs).items():
for x, ir in encode_contract(contract, **kwargs).items():
if ir != []:
y = " ".join(ir)
r[x] = vmodel.get_sentence_vector(y)
return r
def load_contracts(dirname, ext=None, nsamples=None, **kwargs):
r = []
walk = list(os.walk(dirname))
for x, y, files in walk:
for f in files:
for f in files:
if ext is None or f.endswith(ext):
r.append(x + "/".join(y) + "/" + f)
@ -60,6 +94,7 @@ def load_contracts(dirname, ext=None, nsamples=None, **kwargs):
# TODO: shuffle
return r[:nsamples]
def ntype(_type):
if isinstance(_type, ElementaryType):
_type = str(_type)
@ -79,8 +114,8 @@ def ntype(_type):
else:
_type = str(_type)
_type = _type.replace(" memory","")
_type = _type.replace(" storage ref","")
_type = _type.replace(" memory", "")
_type = _type.replace(" storage ref", "")
if "struct" in _type:
return "struct"
@ -93,92 +128,94 @@ def ntype(_type):
elif "mapping" in _type:
return "mapping"
else:
return _type.replace(" ","_")
return _type.replace(" ", "_")
def encode_ir(ir):
# operations
if isinstance(ir, Assignment):
return '({}):=({})'.format(encode_ir(ir.lvalue), encode_ir(ir.rvalue))
return "({}):=({})".format(encode_ir(ir.lvalue), encode_ir(ir.rvalue))
if isinstance(ir, Index):
return 'index({})'.format(ntype(ir._type))
if isinstance(ir, Member):
return 'member' #.format(ntype(ir._type))
return "index({})".format(ntype(ir._type))
if isinstance(ir, AccessMember):
return "member" # .format(ntype(ir._type))
if isinstance(ir, Length):
return 'length'
return "length"
if isinstance(ir, Balance):
return 'balance'
return "balance"
if isinstance(ir, Binary):
return 'binary({})'.format(ir.type_str)
return "binary({})".format(str(ir.type))
if isinstance(ir, Unary):
return 'unary({})'.format(ir.type_str)
return "unary({})".format(str(ir.type))
if isinstance(ir, Condition):
return 'condition({})'.format(encode_ir(ir.value))
return "condition({})".format(encode_ir(ir.value))
if isinstance(ir, NewStructure):
return 'new_structure'
return "new_structure"
if isinstance(ir, NewContract):
return 'new_contract'
return "new_contract"
if isinstance(ir, NewArray):
return 'new_array({})'.format(ntype(ir._array_type))
return "new_array({})".format(ntype(ir._array_type))
if isinstance(ir, NewElementaryType):
return 'new_elementary({})'.format(ntype(ir._type))
return "new_elementary({})".format(ntype(ir._type))
if isinstance(ir, Push):
return 'push({},{})'.format(encode_ir(ir.value), encode_ir(ir.lvalue))
return "push({},{})".format(encode_ir(ir.value), encode_ir(ir.lvalue))
if isinstance(ir, Delete):
return 'delete({},{})'.format(encode_ir(ir.lvalue), encode_ir(ir.variable))
return "delete({},{})".format(encode_ir(ir.lvalue), encode_ir(ir.variable))
if isinstance(ir, SolidityCall):
return 'solidity_call({})'.format(ir.function.full_name)
return "solidity_call({})".format(ir.function.full_name)
if isinstance(ir, InternalCall):
return 'internal_call({})'.format(ntype(ir._type_call))
if isinstance(ir, EventCall): # is this useful?
return 'event'
return "internal_call({})".format(ntype(ir._type_call))
if isinstance(ir, EventCall): # is this useful?
return "event"
if isinstance(ir, LibraryCall):
return 'library_call'
return "library_call"
if isinstance(ir, InternalDynamicCall):
return 'internal_dynamic_call'
if isinstance(ir, HighLevelCall): # TODO: improve
return 'high_level_call'
if isinstance(ir, LowLevelCall): # TODO: improve
return 'low_level_call'
return "internal_dynamic_call"
if isinstance(ir, HighLevelCall): # TODO: improve
return "high_level_call"
if isinstance(ir, LowLevelCall): # TODO: improve
return "low_level_call"
if isinstance(ir, TypeConversion):
return 'type_conversion({})'.format(ntype(ir.type))
if isinstance(ir, Return): # this can be improved using values
return 'return' #.format(ntype(ir.type))
return "type_conversion({})".format(ntype(ir.type))
if isinstance(ir, Return): # this can be improved using values
return "return" # .format(ntype(ir.type))
if isinstance(ir, Transfer):
return 'transfer({})'.format(encode_ir(ir.call_value))
return "transfer({})".format(encode_ir(ir.call_value))
if isinstance(ir, Send):
return 'send({})'.format(encode_ir(ir.call_value))
if isinstance(ir, Unpack): # TODO: improve
return 'unpack'
if isinstance(ir, InitArray): # TODO: improve
return 'init_array'
if isinstance(ir, Function): # TODO: investigate this
return 'function_solc'
return "send({})".format(encode_ir(ir.call_value))
if isinstance(ir, Unpack): # TODO: improve
return "unpack"
if isinstance(ir, InitArray): # TODO: improve
return "init_array"
if isinstance(ir, Function): # TODO: investigate this
return "function_solc"
# variables
if isinstance(ir, Constant):
return 'constant({})'.format(ntype(ir._type))
return "constant({})".format(ntype(ir._type))
if isinstance(ir, SolidityVariableComposed):
return 'solidity_variable_composed({})'.format(ir.name)
return "solidity_variable_composed({})".format(ir.name)
if isinstance(ir, SolidityVariable):
return 'solidity_variable{}'.format(ir.name)
return "solidity_variable{}".format(ir.name)
if isinstance(ir, TemporaryVariable):
return 'temporary_variable'
if isinstance(ir, ReferenceVariable):
return 'reference({})'.format(ntype(ir._type))
return "temporary_variable"
if isinstance(ir, IndexVariable):
return "reference({})".format(ntype(ir._type))
if isinstance(ir, LocalVariable):
return 'local_solc_variable({})'.format(ir._location)
return "local_solc_variable({})".format(ir._location)
if isinstance(ir, StateVariable):
return 'state_solc_variable({})'.format(ntype(ir._type))
return "state_solc_variable({})".format(ntype(ir._type))
if isinstance(ir, LocalVariableInitFromTuple):
return 'local_variable_init_tuple'
return "local_variable_init_tuple"
if isinstance(ir, TupleVariable):
return 'tuple_variable'
return "tuple_variable"
# default
else:
simil_logger.error(type(ir),"is missing encoding!")
return ''
simil_logger.error(type(ir), "is missing encoding!")
return ""
def encode_contract(cfilename, **kwargs):
r = dict()
@ -186,7 +223,7 @@ def encode_contract(cfilename, **kwargs):
try:
slither = Slither(cfilename, **kwargs)
except:
simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs['solc'])
simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs["solc"])
return r
# Iterate over all the contracts
@ -198,7 +235,7 @@ def encode_contract(cfilename, **kwargs):
if function.nodes == [] or function.is_constructor_variables:
continue
x = (cfilename,contract.name,function.name)
x = (cfilename, contract.name, function.name)
r[x] = []
@ -210,5 +247,3 @@ def encode_contract(cfilename, **kwargs):
for ir in node.irs:
r[x].append(encode_ir(ir))
return r

@ -3,18 +3,19 @@ import sys
import os.path
import traceback
from .model import load_model
from .model import load_model
from .encode import parse_target, encode_contract
logging.basicConfig()
logger = logging.getLogger("Slither-simil")
def info(args):
try:
model = args.model
if os.path.isfile(model):
if os.path.isfile(model):
model = load_model(model)
else:
model = None
@ -22,22 +23,22 @@ def info(args):
filename = args.filename
contract, fname = parse_target(args.fname)
solc = args.solc
if filename is None and contract is None and fname is None:
logger.info("%s uses the following words:",args.model)
logger.info("%s uses the following words:", args.model)
for word in model.get_words():
logger.info(word)
sys.exit(0)
if filename is None or contract is None or fname is None:
logger.error('The encode mode requires filename, contract and fname parameters.')
logger.error("The encode mode requires filename, contract and fname parameters.")
sys.exit(-1)
irs = encode_contract(filename, **vars(args))
if len(irs) == 0:
sys.exit(-1)
x = (filename,contract,fname)
x = (filename, contract, fname)
y = " ".join(irs[x])
logger.info("Function {} in contract {} is encoded as:".format(fname, contract))
@ -47,8 +48,6 @@ def info(args):
logger.info(fvector)
except Exception:
logger.error('Error in %s' % args.filename)
logger.error("Error in %s" % args.filename)
logger.error(traceback.format_exc())
sys.exit(-1)

@ -5,7 +5,7 @@ import operator
import numpy as np
import random
from .model import load_model
from .model import load_model
from .encode import load_and_encode, parse_target
try:
@ -17,10 +17,13 @@ except ImportError:
logger = logging.getLogger("Slither-simil")
def plot(args):
if decomposition is None or plt is None:
logger.error("ERROR: In order to use plot mode in slither-simil, you need to install sklearn and matplotlib:")
logger.error(
"ERROR: In order to use plot mode in slither-simil, you need to install sklearn and matplotlib:"
)
logger.error("$ pip3 install sklearn matplotlib --user")
sys.exit(-1)
@ -29,50 +32,50 @@ def plot(args):
model = args.model
model = load_model(model)
filename = args.filename
#contract = args.contract
# contract = args.contract
contract, fname = parse_target(args.fname)
#solc = args.solc
# solc = args.solc
infile = args.input
#ext = args.filter
#nsamples = args.nsamples
# ext = args.filter
# nsamples = args.nsamples
if fname is None or infile is None:
logger.error('The plot mode requieres fname and input parameters.')
logger.error("The plot mode requieres fname and input parameters.")
sys.exit(-1)
logger.info('Loading data..')
logger.info("Loading data..")
cache = load_and_encode(infile, **vars(args))
data = list()
fs = list()
logger.info('Procesing data..')
for (f,c,n),y in cache.items():
logger.info("Procesing data..")
for (f, c, n), y in cache.items():
if (c == contract or contract is None) and n == fname:
fs.append(f)
data.append(y)
if len(data) == 0:
logger.error('No contract was found with function %s', fname)
logger.error("No contract was found with function %s", fname)
sys.exit(-1)
data = np.array(data)
pca = decomposition.PCA(n_components=2)
tdata = pca.fit_transform(data)
logger.info('Plotting data..')
plt.figure(figsize=(20,10))
assert(len(tdata) == len(fs))
for ([x,y],l) in zip(tdata, fs):
logger.info("Plotting data..")
plt.figure(figsize=(20, 10))
assert len(tdata) == len(fs)
for ([x, y], l) in zip(tdata, fs):
x = random.gauss(0, 0.01) + x
y = random.gauss(0, 0.01) + y
plt.scatter(x, y, c='blue')
plt.text(x-0.001,y+0.001, l)
plt.scatter(x, y, c="blue")
plt.text(x - 0.001, y + 0.001, l)
logger.info("Saving figure to plot.png..")
plt.savefig("plot.png", bbox_inches="tight")
logger.info('Saving figure to plot.png..')
plt.savefig('plot.png', bbox_inches='tight')
except Exception:
logger.error('Error in %s' % args.filename)
logger.error("Error in %s" % args.filename)
logger.error(traceback.format_exc())
sys.exit(-1)

@ -1,5 +1,6 @@
import numpy as np
def similarity(v1, v2):
n1 = np.linalg.norm(v1)
n2 = np.linalg.norm(v2)

@ -5,50 +5,51 @@ import traceback
import operator
import numpy as np
from .model import load_model
from .encode import encode_contract, load_and_encode, parse_target
from .cache import save_cache
from .model import load_model
from .encode import encode_contract, load_and_encode, parse_target
from .cache import save_cache
from .similarity import similarity
logger = logging.getLogger("Slither-simil")
def test(args):
try:
model = args.model
model = load_model(model)
filename = args.filename
contract, fname = parse_target(args.fname)
contract, fname = parse_target(args.fname)
infile = args.input
ntop = args.ntop
if filename is None or contract is None or fname is None or infile is None:
logger.error('The test mode requires filename, contract, fname and input parameters.')
logger.error("The test mode requires filename, contract, fname and input parameters.")
sys.exit(-1)
irs = encode_contract(filename, **vars(args))
if len(irs) == 0:
sys.exit(-1)
y = " ".join(irs[(filename,contract,fname)])
y = " ".join(irs[(filename, contract, fname)])
fvector = model.get_sentence_vector(y)
cache = load_and_encode(infile, model, **vars(args))
#save_cache("cache.npz", cache)
# save_cache("cache.npz", cache)
r = dict()
for x,y in cache.items():
for x, y in cache.items():
r[x] = similarity(fvector, y)
r = sorted(r.items(), key=operator.itemgetter(1), reverse=True)
logger.info("Reviewed %d functions, listing the %d most similar ones:", len(r), ntop)
format_table = "{: <65} {: <20} {: <20} {: <10}"
logger.info(format_table.format(*["filename", "contract", "function", "score"]))
for x,score in r[:ntop]:
for x, score in r[:ntop]:
score = str(round(score, 3))
logger.info(format_table.format(*(list(x)+[score])))
logger.info(format_table.format(*(list(x) + [score])))
except Exception:
logger.error('Error in %s' % args.filename)
logger.error("Error in %s" % args.filename)
logger.error(traceback.format_exc())
sys.exit(-1)

@ -5,12 +5,13 @@ import traceback
import operator
import os
from .model import train_unsupervised
from .encode import encode_contract, load_contracts
from .cache import save_cache
from .model import train_unsupervised
from .encode import encode_contract, load_contracts
from .cache import save_cache
logger = logging.getLogger("Slither-simil")
def train(args):
try:
@ -20,35 +21,37 @@ def train(args):
nsamples = args.nsamples
if dirname is None:
logger.error('The train mode requires the input parameter.')
logger.error("The train mode requires the input parameter.")
sys.exit(-1)
contracts = load_contracts(dirname, **vars(args))
logger.info('Saving extracted data into %s', last_data_train_filename)
logger.info("Saving extracted data into %s", last_data_train_filename)
cache = []
with open(last_data_train_filename, 'w') as f:
with open(last_data_train_filename, "w") as f:
for filename in contracts:
#cache[filename] = dict()
for (filename, contract, function), ir in encode_contract(filename, **vars(args)).items():
# cache[filename] = dict()
for (filename, contract, function), ir in encode_contract(
filename, **vars(args)
).items():
if ir != []:
x = " ".join(ir)
f.write(x+"\n")
f.write(x + "\n")
cache.append((os.path.split(filename)[-1], contract, function, x))
logger.info('Starting training')
model = train_unsupervised(input=last_data_train_filename, model='skipgram')
logger.info('Training complete')
logger.info('Saving model')
logger.info("Starting training")
model = train_unsupervised(input=last_data_train_filename, model="skipgram")
logger.info("Training complete")
logger.info("Saving model")
model.save_model(model_filename)
for i,(filename, contract, function, irs) in enumerate(cache):
for i, (filename, contract, function, irs) in enumerate(cache):
cache[i] = ((filename, contract, function), model.get_sentence_vector(irs))
logger.info('Saving cache in cache.npz')
logger.info("Saving cache in cache.npz")
save_cache(cache, "cache.npz")
logger.info('Done!')
logger.info("Done!")
except Exception:
logger.error('Error in %s' % args.filename)
logger.error("Error in %s" % args.filename)
logger.error(traceback.format_exc())
sys.exit(-1)

@ -10,63 +10,77 @@ logging.basicConfig()
logger = logging.getLogger("Slither").setLevel(logging.INFO)
# Slither detectors for which slither-format currently works
available_detectors = ["unused-state",
"solc-version",
"pragma",
"naming-convention",
"external-function",
"constable-states",
"constant-function-asm",
"constatnt-function-state"]
available_detectors = [
"unused-state",
"solc-version",
"pragma",
"naming-convention",
"external-function",
"constable-states",
"constant-function-asm",
"constatnt-function-state",
]
detectors_to_run = []
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='slither_format',
usage='slither_format filename')
parser.add_argument('filename', help='The filename of the contract or truffle directory to analyze.')
parser.add_argument('--verbose-test', '-v', help='verbose mode output for testing',action='store_true',default=False)
parser.add_argument('--verbose-json', '-j', help='verbose json output',action='store_true',default=False)
parser.add_argument('--version',
help='displays the current version',
version='0.1.0',
action='version')
parser.add_argument('--config-file',
help='Provide a config file (default: slither.config.json)',
action='store',
dest='config_file',
default='slither.config.json')
group_detector = parser.add_argument_group('Detectors')
group_detector.add_argument('--detect',
help='Comma-separated list of detectors, defaults to all, '
'available detectors: {}'.format(
', '.join(d for d in available_detectors)),
action='store',
dest='detectors_to_run',
default='all')
group_detector.add_argument('--exclude',
help='Comma-separated list of detectors to exclude,'
'available detectors: {}'.format(
', '.join(d for d in available_detectors)),
action='store',
dest='detectors_to_exclude',
default='all')
cryticparser.init(parser)
if len(sys.argv) == 1:
parser.print_help(sys.stderr)
parser = argparse.ArgumentParser(description="slither_format", usage="slither_format filename")
parser.add_argument(
"filename", help="The filename of the contract or truffle directory to analyze."
)
parser.add_argument(
"--verbose-test",
"-v",
help="verbose mode output for testing",
action="store_true",
default=False,
)
parser.add_argument(
"--verbose-json", "-j", help="verbose json output", action="store_true", default=False
)
parser.add_argument(
"--version", help="displays the current version", version="0.1.0", action="version"
)
parser.add_argument(
"--config-file",
help="Provide a config file (default: slither.config.json)",
action="store",
dest="config_file",
default="slither.config.json",
)
group_detector = parser.add_argument_group("Detectors")
group_detector.add_argument(
"--detect",
help="Comma-separated list of detectors, defaults to all, "
"available detectors: {}".format(", ".join(d for d in available_detectors)),
action="store",
dest="detectors_to_run",
default="all",
)
group_detector.add_argument(
"--exclude",
help="Comma-separated list of detectors to exclude,"
"available detectors: {}".format(", ".join(d for d in available_detectors)),
action="store",
dest="detectors_to_exclude",
default="all",
)
cryticparser.init(parser)
if len(sys.argv) == 1:
parser.print_help(sys.stderr)
sys.exit(1)
return parser.parse_args()
@ -80,11 +94,12 @@ def main():
read_config_file(args)
# Perform slither analysis on the given filename
slither = Slither(args.filename, **vars(args))
# Format the input files based on slither analysis
slither_format(slither, **vars(args))
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -11,27 +11,29 @@ from slither.detectors.attributes.const_functions_state import ConstantFunctions
from slither.utils.colors import yellow
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('Slither.Format')
logger = logging.getLogger("Slither.Format")
all_detectors = {
'unused-state': UnusedStateVars,
'solc-version': IncorrectSolc,
'pragma': ConstantPragma,
'naming-convention': NamingConvention,
'external-function': ExternalFunction,
'constable-states' : ConstCandidateStateVars,
'constant-function-asm': ConstantFunctionsAsm,
'constant-functions-state': ConstantFunctionsState
"unused-state": UnusedStateVars,
"solc-version": IncorrectSolc,
"pragma": ConstantPragma,
"naming-convention": NamingConvention,
"external-function": ExternalFunction,
"constable-states": ConstCandidateStateVars,
"constant-function-asm": ConstantFunctionsAsm,
"constant-functions-state": ConstantFunctionsState,
}
def slither_format(slither, **kwargs):
''''
"""'
Keyword Args:
detectors_to_run (str): Comma-separated list of detectors, defaults to all
'''
"""
detectors_to_run = choose_detectors(kwargs.get('detectors_to_run', 'all'),
kwargs.get('detectors_to_exclude', ''))
detectors_to_run = choose_detectors(
kwargs.get("detectors_to_run", "all"), kwargs.get("detectors_to_exclude", "")
)
for detector in detectors_to_run:
slither.register_detector(detector)
@ -42,32 +44,32 @@ def slither_format(slither, **kwargs):
detector_results = [x for x in detector_results if x] # remove empty results
detector_results = [item for sublist in detector_results for item in sublist] # flatten
export = Path('crytic-export', 'patches')
export = Path("crytic-export", "patches")
export.mkdir(parents=True, exist_ok=True)
counter_result = 0
logger.info(yellow('slither-format is in beta, carefully review each patch before merging it.'))
logger.info(yellow("slither-format is in beta, carefully review each patch before merging it."))
for result in detector_results:
if not 'patches' in result:
if not "patches" in result:
continue
one_line_description = result["description"].split("\n")[0]
export_result = Path(export, f'{counter_result}')
export_result = Path(export, f"{counter_result}")
export_result.mkdir(parents=True, exist_ok=True)
counter_result += 1
counter = 0
logger.info(f'Issue: {one_line_description}')
logger.info(f'Generated: ({export_result})')
logger.info(f"Issue: {one_line_description}")
logger.info(f"Generated: ({export_result})")
for file, diff, in result['patches_diff'].items():
filename = f'fix_{counter}.patch'
for file, diff, in result["patches_diff"].items():
filename = f"fix_{counter}.patch"
path = Path(export_result, filename)
logger.info(f'\t- {filename}')
with open(path, 'w') as f:
logger.info(f"\t- {filename}")
with open(path, "w") as f:
f.write(diff)
counter += 1
@ -79,26 +81,28 @@ def slither_format(slither, **kwargs):
###################################################################################
###################################################################################
def choose_detectors(detectors_to_run, detectors_to_exclude):
# If detectors are specified, run only these ones
cls_detectors_to_run = []
exclude = detectors_to_exclude.split(',')
if detectors_to_run == 'all':
exclude = detectors_to_exclude.split(",")
if detectors_to_run == "all":
for d in all_detectors:
if d in exclude:
continue
cls_detectors_to_run.append(all_detectors[d])
else:
exclude = detectors_to_exclude.split(',')
for d in detectors_to_run.split(','):
exclude = detectors_to_exclude.split(",")
for d in detectors_to_run.split(","):
if d in all_detectors:
if d in exclude:
continue
cls_detectors_to_run.append(all_detectors[d])
else:
raise Exception('Error: {} is not a detector'.format(d))
raise Exception("Error: {} is not a detector".format(d))
return cls_detectors_to_run
# endregion
###################################################################################
###################################################################################
@ -106,6 +110,7 @@ def choose_detectors(detectors_to_run, detectors_to_exclude):
###################################################################################
###################################################################################
def print_patches(number_of_slither_results, patches):
logger.info("Number of Slither results: " + str(number_of_slither_results))
number_of_patches = 0
@ -115,39 +120,38 @@ def print_patches(number_of_slither_results, patches):
for file in patches:
logger.info("Patch file: " + file)
for patch in patches[file]:
logger.info("Detector: " + patch['detector'])
logger.info("Old string: " + patch['old_string'].replace("\n",""))
logger.info("New string: " + patch['new_string'].replace("\n",""))
logger.info("Location start: " + str(patch['start']))
logger.info("Location end: " + str(patch['end']))
logger.info("Detector: " + patch["detector"])
logger.info("Old string: " + patch["old_string"].replace("\n", ""))
logger.info("New string: " + patch["new_string"].replace("\n", ""))
logger.info("Location start: " + str(patch["start"]))
logger.info("Location end: " + str(patch["end"]))
def print_patches_json(number_of_slither_results, patches):
print('{',end='')
print("\"Number of Slither results\":" + '"' + str(number_of_slither_results) + '",')
print("\"Number of patchlets\":" + "\"" + str(len(patches)) + "\"", ',')
print("\"Patchlets\":" + '[')
print("{", end="")
print('"Number of Slither results":' + '"' + str(number_of_slither_results) + '",')
print('"Number of patchlets":' + '"' + str(len(patches)) + '"', ",")
print('"Patchlets":' + "[")
for index, file in enumerate(patches):
if index > 0:
print(',')
print('{',end='')
print("\"Patch file\":" + '"' + file + '",')
print("\"Number of patches\":" + "\"" + str(len(patches[file])) + "\"", ',')
print("\"Patches\":" + '[')
print(",")
print("{", end="")
print('"Patch file":' + '"' + file + '",')
print('"Number of patches":' + '"' + str(len(patches[file])) + '"', ",")
print('"Patches":' + "[")
for index, patch in enumerate(patches[file]):
if index > 0:
print(',')
print('{',end='')
print("\"Detector\":" + '"' + patch['detector'] + '",')
print("\"Old string\":" + '"' + patch['old_string'].replace("\n","") + '",')
print("\"New string\":" + '"' + patch['new_string'].replace("\n","") + '",')
print("\"Location start\":" + '"' + str(patch['start']) + '",')
print("\"Location end\":" + '"' + str(patch['end']) + '"')
if 'overlaps' in patch:
print("\"Overlaps\":" + "Yes")
print('}',end='')
print(']',end='')
print('}',end='')
print(']',end='')
print('}')
print(",")
print("{", end="")
print('"Detector":' + '"' + patch["detector"] + '",')
print('"Old string":' + '"' + patch["old_string"].replace("\n", "") + '",')
print('"New string":' + '"' + patch["new_string"].replace("\n", "") + '",')
print('"Location start":' + '"' + str(patch["start"]) + '",')
print('"Location end":' + '"' + str(patch["end"]) + '"')
if "overlaps" in patch:
print('"Overlaps":' + "Yes")
print("}", end="")
print("]", end="")
print("}", end="")
print("]", end="")
print("}")

@ -12,7 +12,12 @@ from slither.utils.colors import red
from slither.utils.output import output_to_json
from .checks import all_checks
from .checks.abstract_checks import AbstractCheck
from .utils.command_line import output_detectors_json, output_wiki, output_detectors, output_to_markdown
from .utils.command_line import (
output_detectors_json,
output_wiki,
output_detectors,
output_to_markdown,
)
logging.basicConfig()
logger = logging.getLogger("Slither")
@ -21,49 +26,53 @@ logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(
description='Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.',
usage="slither-check-upgradeability contract.sol ContractName")
parser.add_argument('contract.sol', help='Codebase to analyze')
parser.add_argument('ContractName', help='Contract name (logic contract)')
parser.add_argument('--proxy-name', help='Proxy name')
parser.add_argument('--proxy-filename', help='Proxy filename (if different)')
parser.add_argument('--new-contract-name', help='New contract name (if changed)')
parser.add_argument('--new-contract-filename', help='New implementation filename (if different)')
parser.add_argument('--json',
help='Export the results as a JSON file ("--json -" to export to stdout)',
action='store',
default=False)
parser.add_argument('--list-detectors',
help='List available detectors',
action=ListDetectors,
nargs=0,
default=False)
parser.add_argument('--markdown-root',
help='URL for markdown generation',
action='store',
default="")
parser.add_argument('--wiki-detectors',
help=argparse.SUPPRESS,
action=OutputWiki,
default=False)
parser.add_argument('--list-detectors-json',
help=argparse.SUPPRESS,
action=ListDetectorsJson,
nargs=0,
default=False)
parser.add_argument('--markdown',
help=argparse.SUPPRESS,
action=OutputMarkdown,
default=False)
description="Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.",
usage="slither-check-upgradeability contract.sol ContractName",
)
parser.add_argument("contract.sol", help="Codebase to analyze")
parser.add_argument("ContractName", help="Contract name (logic contract)")
parser.add_argument("--proxy-name", help="Proxy name")
parser.add_argument("--proxy-filename", help="Proxy filename (if different)")
parser.add_argument("--new-contract-name", help="New contract name (if changed)")
parser.add_argument(
"--new-contract-filename", help="New implementation filename (if different)"
)
parser.add_argument(
"--json",
help='Export the results as a JSON file ("--json -" to export to stdout)',
action="store",
default=False,
)
parser.add_argument(
"--list-detectors",
help="List available detectors",
action=ListDetectors,
nargs=0,
default=False,
)
parser.add_argument(
"--markdown-root", help="URL for markdown generation", action="store", default=""
)
parser.add_argument(
"--wiki-detectors", help=argparse.SUPPRESS, action=OutputWiki, default=False
)
parser.add_argument(
"--list-detectors-json",
help=argparse.SUPPRESS,
action=ListDetectorsJson,
nargs=0,
default=False,
)
parser.add_argument("--markdown", help=argparse.SUPPRESS, action=OutputMarkdown, default=False)
cryticparser.init(parser)
@ -80,6 +89,7 @@ def parse_args():
###################################################################################
###################################################################################
def _get_checks():
detectors = [getattr(all_checks, name) for name in dir(all_checks)]
detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractCheck)]
@ -123,13 +133,18 @@ def _run_checks(detectors):
def _checks_on_contract(detectors, contract):
detectors = [d(logger, contract) for d in detectors if (not d.REQUIRE_PROXY and
not d.REQUIRE_CONTRACT_V2)]
detectors = [
d(logger, contract)
for d in detectors
if (not d.REQUIRE_PROXY and not d.REQUIRE_CONTRACT_V2)
]
return _run_checks(detectors), len(detectors)
def _checks_on_contract_update(detectors, contract_v1, contract_v2):
detectors = [d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2]
detectors = [
d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2
]
return _run_checks(detectors), len(detectors)
@ -147,15 +162,11 @@ def _checks_on_contract_and_proxy(detectors, contract, proxy):
def main():
json_results = {
'proxy-present': False,
'contract_v2-present': False,
'detectors': []
}
json_results = {"proxy-present": False, "contract_v2-present": False, "detectors": []}
args = parse_args()
v1_filename = vars(args)['contract.sol']
v1_filename = vars(args)["contract.sol"]
number_detectors_run = 0
detectors = _get_checks()
try:
@ -165,14 +176,14 @@ def main():
v1_name = args.ContractName
v1_contract = v1.get_contract_from_name(v1_name)
if v1_contract is None:
info = 'Contract {} not found in {}'.format(v1_name, v1.filename)
info = "Contract {} not found in {}".format(v1_name, v1.filename)
logger.error(red(info))
if args.json:
output_to_json(args.json, str(info), json_results)
return
detectors_results, number_detectors = _checks_on_contract(detectors, v1_contract)
json_results['detectors'] += detectors_results
json_results["detectors"] += detectors_results
number_detectors_run += number_detectors
# Analyze Proxy
@ -185,15 +196,17 @@ def main():
proxy_contract = proxy.get_contract_from_name(args.proxy_name)
if proxy_contract is None:
info = 'Proxy {} not found in {}'.format(args.proxy_name, proxy.filename)
info = "Proxy {} not found in {}".format(args.proxy_name, proxy.filename)
logger.error(red(info))
if args.json:
output_to_json(args.json, str(info), json_results)
return
json_results['proxy-present'] = True
json_results["proxy-present"] = True
detectors_results, number_detectors = _checks_on_contract_and_proxy(detectors, v1_contract, proxy_contract)
json_results['detectors'] += detectors_results
detectors_results, number_detectors = _checks_on_contract_and_proxy(
detectors, v1_contract, proxy_contract
)
json_results["detectors"] += detectors_results
number_detectors_run += number_detectors
# Analyze new version
if args.new_contract_name:
@ -204,30 +217,36 @@ def main():
v2_contract = v2.get_contract_from_name(args.new_contract_name)
if v2_contract is None:
info = 'New logic contract {} not found in {}'.format(args.new_contract_name, v2.filename)
info = "New logic contract {} not found in {}".format(
args.new_contract_name, v2.filename
)
logger.error(red(info))
if args.json:
output_to_json(args.json, str(info), json_results)
return
json_results['contract_v2-present'] = True
json_results["contract_v2-present"] = True
if proxy_contract:
detectors_results, _ = _checks_on_contract_and_proxy(detectors,
v2_contract,
proxy_contract)
detectors_results, _ = _checks_on_contract_and_proxy(
detectors, v2_contract, proxy_contract
)
json_results['detectors'] += detectors_results
json_results["detectors"] += detectors_results
detectors_results, number_detectors = _checks_on_contract_update(detectors, v1_contract, v2_contract)
json_results['detectors'] += detectors_results
detectors_results, number_detectors = _checks_on_contract_update(
detectors, v1_contract, v2_contract
)
json_results["detectors"] += detectors_results
number_detectors_run += number_detectors
# If there is a V2, we run the contract-only check on the V2
detectors_results, _ = _checks_on_contract(detectors, v2_contract)
json_results['detectors'] += detectors_results
json_results["detectors"] += detectors_results
number_detectors_run += number_detectors
logger.info(f'{len(json_results["detectors"])} findings, {number_detectors_run} detectors run')
logger.info(
f'{len(json_results["detectors"])} findings, {number_detectors_run} detectors run'
)
if args.json:
output_to_json(args.json, None, json_results)
@ -237,4 +256,5 @@ def main():
output_to_json(args.json, str(e), json_results)
return
# endregion

@ -19,28 +19,28 @@ classification_colors = {
CheckClassification.INFORMATIONAL: green,
CheckClassification.LOW: yellow,
CheckClassification.MEDIUM: yellow,
CheckClassification.HIGH: red
CheckClassification.HIGH: red,
}
classification_txt = {
CheckClassification.INFORMATIONAL: 'Informational',
CheckClassification.LOW: 'Low',
CheckClassification.MEDIUM: 'Medium',
CheckClassification.HIGH: 'High',
CheckClassification.INFORMATIONAL: "Informational",
CheckClassification.LOW: "Low",
CheckClassification.MEDIUM: "Medium",
CheckClassification.HIGH: "High",
}
class AbstractCheck(metaclass=abc.ABCMeta):
ARGUMENT = ''
HELP = ''
ARGUMENT = ""
HELP = ""
IMPACT = None
WIKI = ''
WIKI = ""
WIKI_TITLE = ''
WIKI_DESCRIPTION = ''
WIKI_EXPLOIT_SCENARIO = ''
WIKI_RECOMMENDATION = ''
WIKI_TITLE = ""
WIKI_DESCRIPTION = ""
WIKI_EXPLOIT_SCENARIO = ""
WIKI_RECOMMENDATION = ""
REQUIRE_CONTRACT = False
REQUIRE_PROXY = False
@ -53,43 +53,69 @@ class AbstractCheck(metaclass=abc.ABCMeta):
self.contract_v2 = contract_v2
if not self.ARGUMENT:
raise IncorrectCheckInitialization('NAME is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"NAME is not initialized {}".format(self.__class__.__name__)
)
if not self.HELP:
raise IncorrectCheckInitialization('HELP is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"HELP is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI:
raise IncorrectCheckInitialization('WIKI is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"WIKI is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI_TITLE:
raise IncorrectCheckInitialization('WIKI_TITLE is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"WIKI_TITLE is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI_DESCRIPTION:
raise IncorrectCheckInitialization('WIKI_DESCRIPTION is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"WIKI_DESCRIPTION is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI_EXPLOIT_SCENARIO and self.IMPACT not in [CheckClassification.INFORMATIONAL]:
raise IncorrectCheckInitialization('WIKI_EXPLOIT_SCENARIO is not initialized {}'.format(self.__class__.__name__))
if not self.WIKI_EXPLOIT_SCENARIO and self.IMPACT not in [
CheckClassification.INFORMATIONAL
]:
raise IncorrectCheckInitialization(
"WIKI_EXPLOIT_SCENARIO is not initialized {}".format(self.__class__.__name__)
)
if not self.WIKI_RECOMMENDATION:
raise IncorrectCheckInitialization('WIKI_RECOMMENDATION is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"WIKI_RECOMMENDATION is not initialized {}".format(self.__class__.__name__)
)
if self.REQUIRE_PROXY and self.REQUIRE_CONTRACT_V2:
# This is not a fundatemenal issues
# But it requires to change __main__ to avoid running two times the detectors
txt = 'REQUIRE_PROXY and REQUIRE_CONTRACT_V2 needs change in __main___ {}'.format(self.__class__.__name__)
txt = "REQUIRE_PROXY and REQUIRE_CONTRACT_V2 needs change in __main___ {}".format(
self.__class__.__name__
)
raise IncorrectCheckInitialization(txt)
if self.IMPACT not in [CheckClassification.LOW,
CheckClassification.MEDIUM,
CheckClassification.HIGH,
CheckClassification.INFORMATIONAL]:
raise IncorrectCheckInitialization('IMPACT is not initialized {}'.format(self.__class__.__name__))
if self.IMPACT not in [
CheckClassification.LOW,
CheckClassification.MEDIUM,
CheckClassification.HIGH,
CheckClassification.INFORMATIONAL,
]:
raise IncorrectCheckInitialization(
"IMPACT is not initialized {}".format(self.__class__.__name__)
)
if self.REQUIRE_CONTRACT_V2 and contract_v2 is None:
raise IncorrectCheckInitialization('ContractV2 is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"ContractV2 is not initialized {}".format(self.__class__.__name__)
)
if self.REQUIRE_PROXY and proxy is None:
raise IncorrectCheckInitialization('Proxy is not initialized {}'.format(self.__class__.__name__))
raise IncorrectCheckInitialization(
"Proxy is not initialized {}".format(self.__class__.__name__)
)
@abc.abstractmethod
def _check(self):
@ -102,19 +128,17 @@ class AbstractCheck(metaclass=abc.ABCMeta):
all_results = [r.data for r in all_results]
if all_results:
if self.logger:
info = '\n'
info = "\n"
for idx, result in enumerate(all_results):
info += result['description']
info += 'Reference: {}'.format(self.WIKI)
info += result["description"]
info += "Reference: {}".format(self.WIKI)
self._log(info)
return all_results
def generate_result(self, info, additional_fields=None):
output = Output(info,
additional_fields,
markdown_root=self.contract.slither.markdown_root)
output = Output(info, additional_fields, markdown_root=self.contract.slither.markdown_root)
output.data['check'] = self.ARGUMENT
output.data["check"] = self.ARGUMENT
return output

@ -1,11 +1,23 @@
from .initialization import (InitializablePresent, InitializableInherited,
InitializableInitializer, MissingInitializerModifier, MissingCalls, MultipleCalls, InitializeTarget)
from .initialization import (
InitializablePresent,
InitializableInherited,
InitializableInitializer,
MissingInitializerModifier,
MissingCalls,
MultipleCalls,
InitializeTarget,
)
from .functions_ids import IDCollision, FunctionShadowing
from .variable_initialization import VariableWithInit
from .variables_order import (MissingVariable, DifferentVariableContractProxy,
DifferentVariableContractNewContract, ExtraVariablesProxy, ExtraVariablesNewContract)
from .variables_order import (
MissingVariable,
DifferentVariableContractProxy,
DifferentVariableContractNewContract,
ExtraVariablesProxy,
ExtraVariablesNewContract,
)
from .constant import WereConstant, BecameConstant
from .constant import WereConstant, BecameConstant

@ -2,17 +2,17 @@ from slither.tools.upgradeability.checks.abstract_checks import AbstractCheck, C
class WereConstant(AbstractCheck):
ARGUMENT = 'were-constant'
ARGUMENT = "were-constant"
IMPACT = CheckClassification.HIGH
HELP = 'Variables that should be constant'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-be-constant'
WIKI_TITLE = 'Variables that should be constant'
WIKI_DESCRIPTION = '''
HELP = "Variables that should be constant"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-be-constant"
WIKI_TITLE = "Variables that should be constant"
WIKI_DESCRIPTION = """
Detect state variables that should be `constant̀`.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
uint variable1;
@ -28,11 +28,11 @@ contract ContractV2{
```
Because `variable2` is not anymore a `constant`, the storage location of `variable3` will be different.
As a result, `ContractV2` will have a corrupted storage layout.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Do not remove `constant` from a state variables during an update.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_CONTRACT_V2 = True
@ -66,12 +66,13 @@ Do not remove `constant` from a state variables during an update.
if state_v1.is_constant:
if not state_v2.is_constant:
# If v2 has additional non constant variables, we need to skip them
if ((state_v1.name != state_v2.name or state_v1.type != state_v2.type)
and v2_additional_variables > 0):
if (
state_v1.name != state_v2.name or state_v1.type != state_v2.type
) and v2_additional_variables > 0:
v2_additional_variables -= 1
idx_v2 += 1
continue
info = [state_v1, ' was constant, but ', state_v2, 'is not.\n']
info = [state_v1, " was constant, but ", state_v2, "is not.\n"]
json = self.generate_result(info)
results.append(json)
@ -80,19 +81,20 @@ Do not remove `constant` from a state variables during an update.
return results
class BecameConstant(AbstractCheck):
ARGUMENT = 'became-constant'
ARGUMENT = "became-constant"
IMPACT = CheckClassification.HIGH
HELP = 'Variables that should not be constant'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-not-be-constant'
WIKI_TITLE = 'Variables that should not be constant'
HELP = "Variables that should not be constant"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-not-be-constant"
WIKI_TITLE = "Variables that should not be constant"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect state variables that should not be `constant̀`.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
uint variable1;
@ -108,11 +110,11 @@ contract ContractV2{
```
Because `variable2` is now a `constant`, the storage location of `variable3` will be different.
As a result, `ContractV2` will have a corrupted storage layout.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Do not make an existing state variable `constant`.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_CONTRACT_V2 = True
@ -146,13 +148,14 @@ Do not make an existing state variable `constant`.
if state_v1.is_constant:
if not state_v2.is_constant:
# If v2 has additional non constant variables, we need to skip them
if ((state_v1.name != state_v2.name or state_v1.type != state_v2.type)
and v2_additional_variables > 0):
if (
state_v1.name != state_v2.name or state_v1.type != state_v2.type
) and v2_additional_variables > 0:
v2_additional_variables -= 1
idx_v2 += 1
continue
elif state_v2.is_constant:
info = [state_v1, ' was not constant but ', state_v2, ' is.\n']
info = [state_v1, " was not constant but ", state_v2, " is.\n"]
json = self.generate_result(info)
results.append(json)

@ -5,11 +5,16 @@ from slither.utils.function import get_function_id
def get_signatures(c):
functions = c.functions
functions = [f.full_name for f in functions if f.visibility in ['public', 'external'] and
not f.is_constructor and not f.is_fallback]
functions = [
f.full_name
for f in functions
if f.visibility in ["public", "external"] and not f.is_constructor and not f.is_fallback
]
variables = c.state_variables
variables = [variable.name + '()' for variable in variables if variable.visibility in ['public']]
variables = [
variable.name + "()" for variable in variables if variable.visibility in ["public"]
]
return list(set(functions + variables))
@ -21,26 +26,26 @@ def _get_function_or_variable(contract, signature):
for variable in contract.state_variables:
# Todo: can lead to incorrect variable in case of shadowing
if variable.visibility in ['public']:
if variable.name + '()' == signature:
if variable.visibility in ["public"]:
if variable.name + "()" == signature:
return variable
raise SlitherError(f'Function id checks: {signature} not found in {contract.name}')
raise SlitherError(f"Function id checks: {signature} not found in {contract.name}")
class IDCollision(AbstractCheck):
ARGUMENT = 'function-id-collision'
ARGUMENT = "function-id-collision"
IMPACT = CheckClassification.HIGH
HELP = 'Functions ids collision'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-ids-collisions'
WIKI_TITLE = 'Functions ids collisions'
HELP = "Functions ids collision"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-ids-collisions"
WIKI_TITLE = "Functions ids collisions"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect function id collision between the contract and the proxy.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
function gsf() public {
@ -56,11 +61,11 @@ contract Proxy{
```
`Proxy.tgeo()` and `Contract.gsf()` have the same function id (0x67e43e43).
As a result, `Proxy.tgeo()` will shadow Contract.gsf()`.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Rename the function. Avoid public functions in the proxy.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_PROXY = True
@ -77,11 +82,18 @@ Rename the function. Avoid public functions in the proxy.
for (k, _) in signatures_ids_implem.items():
if k in signatures_ids_proxy:
if signatures_ids_implem[k] != signatures_ids_proxy[k]:
implem_function = _get_function_or_variable(self.contract, signatures_ids_implem[k])
implem_function = _get_function_or_variable(
self.contract, signatures_ids_implem[k]
)
proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k])
info = ['Function id collision found: ', implem_function,
' ', proxy_function, '\n']
info = [
"Function id collision found: ",
implem_function,
" ",
proxy_function,
"\n",
]
json = self.generate_result(info)
results.append(json)
@ -89,18 +101,18 @@ Rename the function. Avoid public functions in the proxy.
class FunctionShadowing(AbstractCheck):
ARGUMENT = 'function-shadowing'
ARGUMENT = "function-shadowing"
IMPACT = CheckClassification.HIGH
HELP = 'Functions shadowing'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-shadowing'
WIKI_TITLE = 'Functions shadowing'
HELP = "Functions shadowing"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-shadowing"
WIKI_TITLE = "Functions shadowing"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect function shadowing between the contract and the proxy.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
function get() public {
@ -115,11 +127,11 @@ contract Proxy{
}
```
`Proxy.get` will shadow any call to `get()`. As a result `get()` is never executed in the logic contract and cannot be updated.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Rename the function. Avoid public functions in the proxy.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_PROXY = True
@ -136,11 +148,18 @@ Rename the function. Avoid public functions in the proxy.
for (k, _) in signatures_ids_implem.items():
if k in signatures_ids_proxy:
if signatures_ids_implem[k] == signatures_ids_proxy[k]:
implem_function = _get_function_or_variable(self.contract, signatures_ids_implem[k])
implem_function = _get_function_or_variable(
self.contract, signatures_ids_implem[k]
)
proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k])
info = ['Function shadowing found: ', implem_function,
' ', proxy_function, '\n']
info = [
"Function shadowing found: ",
implem_function,
" ",
proxy_function,
"\n",
]
json = self.generate_result(info)
results.append(json)

@ -13,16 +13,20 @@ class MultipleInitTarget(Exception):
def _get_initialize_functions(contract):
return [f for f in contract.functions if f.name == 'initialize' and f.is_implemented]
return [f for f in contract.functions if f.name == "initialize" and f.is_implemented]
def _get_all_internal_calls(function):
all_ir = function.all_slithir_operations()
return [i.function for i in all_ir if isinstance(i, InternalCall) and i.function_name == "initialize"]
return [
i.function
for i in all_ir
if isinstance(i, InternalCall) and i.function_name == "initialize"
]
def _get_most_derived_init(contract):
init_functions = [f for f in contract.functions if not f.is_shadowed and f.name == 'initialize']
init_functions = [f for f in contract.functions if not f.is_shadowed and f.name == "initialize"]
if len(init_functions) > 1:
if len([f for f in init_functions if f.contract_declarer == contract]) == 1:
return next((f for f in init_functions if f.contract_declarer == contract))
@ -33,80 +37,82 @@ def _get_most_derived_init(contract):
class InitializablePresent(AbstractCheck):
ARGUMENT = 'init-missing'
ARGUMENT = "init-missing"
IMPACT = CheckClassification.INFORMATIONAL
HELP = 'Initializable is missing'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing'
WIKI_TITLE = 'Initializable is missing'
WIKI_DESCRIPTION = '''
HELP = "Initializable is missing"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing"
WIKI_TITLE = "Initializable is missing"
WIKI_DESCRIPTION = """
Detect if a contract `Initializable` is present.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Review manually the contract's initialization..
Consider using a `Initializable` contract to follow [standard practice](https://docs.openzeppelin.com/upgrades/2.7/writing-upgradeable).
'''
"""
def _check(self):
initializable = self.contract.slither.get_contract_from_name('Initializable')
initializable = self.contract.slither.get_contract_from_name("Initializable")
if initializable is None:
info = ["Initializable contract not found, the contract does not follow a standard initalization schema.\n"]
info = [
"Initializable contract not found, the contract does not follow a standard initalization schema.\n"
]
json = self.generate_result(info)
return [json]
return []
class InitializableInherited(AbstractCheck):
ARGUMENT = 'init-inherited'
ARGUMENT = "init-inherited"
IMPACT = CheckClassification.INFORMATIONAL
HELP = 'Initializable is not inherited'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-not-inherited'
WIKI_TITLE = 'Initializable is not inherited'
HELP = "Initializable is not inherited"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-not-inherited"
WIKI_TITLE = "Initializable is not inherited"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect if `Initializable` is inherited.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Review manually the contract's initialization. Consider inheriting `Initializable`.
'''
"""
REQUIRE_CONTRACT = True
def _check(self):
initializable = self.contract.slither.get_contract_from_name('Initializable')
initializable = self.contract.slither.get_contract_from_name("Initializable")
# See InitializablePresent
if initializable is None:
return []
if initializable not in self.contract.inheritance:
info = [self.contract, ' does not inherit from ', initializable, '.\n']
info = [self.contract, " does not inherit from ", initializable, ".\n"]
json = self.generate_result(info)
return [json]
return []
class InitializableInitializer(AbstractCheck):
ARGUMENT = 'initializer-missing'
ARGUMENT = "initializer-missing"
IMPACT = CheckClassification.INFORMATIONAL
HELP = 'initializer() is missing'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-missing'
WIKI_TITLE = 'initializer() is missing'
HELP = "initializer() is missing"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-missing"
WIKI_TITLE = "initializer() is missing"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect the lack of `Initializable.initializer()` modifier.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Review manually the contract's initialization. Consider inheriting a `Initializable.initializer()` modifier.
'''
"""
REQUIRE_CONTRACT = True
def _check(self):
initializable = self.contract.slither.get_contract_from_name('Initializable')
initializable = self.contract.slither.get_contract_from_name("Initializable")
# See InitializablePresent
if initializable is None:
return []
@ -114,26 +120,26 @@ Review manually the contract's initialization. Consider inheriting a `Initializa
if initializable not in self.contract.inheritance:
return []
initializer = self.contract.get_modifier_from_canonical_name('Initializable.initializer()')
initializer = self.contract.get_modifier_from_canonical_name("Initializable.initializer()")
if initializer is None:
info = ['Initializable.initializer() does not exist.\n']
info = ["Initializable.initializer() does not exist.\n"]
json = self.generate_result(info)
return [json]
return []
class MissingInitializerModifier(AbstractCheck):
ARGUMENT = 'missing-init-modifier'
ARGUMENT = "missing-init-modifier"
IMPACT = CheckClassification.HIGH
HELP = 'initializer() is not called'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-not-called'
WIKI_TITLE = 'initializer() is not called'
WIKI_DESCRIPTION = '''
HELP = "initializer() is not called"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-not-called"
WIKI_TITLE = "initializer() is not called"
WIKI_DESCRIPTION = """
Detect if `Initializable.initializer()` is called.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
function initialize() public{
@ -143,23 +149,23 @@ contract Contract{
```
`initialize` should have the `initializer` modifier to prevent someone from initializing the contract multiple times.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Use `Initializable.initializer()`.
'''
"""
REQUIRE_CONTRACT = True
def _check(self):
initializable = self.contract.slither.get_contract_from_name('Initializable')
initializable = self.contract.slither.get_contract_from_name("Initializable")
# See InitializablePresent
if initializable is None:
return []
# See InitializableInherited
if initializable not in self.contract.inheritance:
return []
initializer = self.contract.get_modifier_from_canonical_name('Initializable.initializer()')
initializer = self.contract.get_modifier_from_canonical_name("Initializable.initializer()")
# InitializableInitializer
if initializer is None:
return []
@ -168,24 +174,24 @@ Use `Initializable.initializer()`.
all_init_functions = _get_initialize_functions(self.contract)
for f in all_init_functions:
if initializer not in f.modifiers:
info = [f, ' does not call the initializer modifier.\n']
info = [f, " does not call the initializer modifier.\n"]
json = self.generate_result(info)
results.append(json)
return results
class MissingCalls(AbstractCheck):
ARGUMENT = 'missing-calls'
ARGUMENT = "missing-calls"
IMPACT = CheckClassification.HIGH
HELP = 'Missing calls to init functions'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-not-called'
WIKI_TITLE = 'Initialize functions are not called'
WIKI_DESCRIPTION = '''
HELP = "Missing calls to init functions"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-not-called"
WIKI_TITLE = "Initialize functions are not called"
WIKI_DESCRIPTION = """
Detect missing calls to initialize functions.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Base{
function initialize() public{
@ -200,11 +206,11 @@ contract Derived is Base{
```
`Derived.initialize` does not call `Base.initialize` leading the contract to not be correctly initialized.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Ensure all the initialize functions are reached by the most derived initialize function.
'''
"""
REQUIRE_CONTRACT = True
@ -215,7 +221,7 @@ Ensure all the initialize functions are reached by the most derived initialize f
try:
most_derived_init = _get_most_derived_init(self.contract)
except MultipleInitTarget:
logger.error(red(f'Too many init targets in {self.contract}'))
logger.error(red(f"Too many init targets in {self.contract}"))
return []
if most_derived_init is None:
@ -225,24 +231,24 @@ Ensure all the initialize functions are reached by the most derived initialize f
all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init]
missing_calls = [f for f in all_init_functions if not f in all_init_functions_called]
for f in missing_calls:
info = ['Missing call to ', f, ' in ', most_derived_init, '.\n']
info = ["Missing call to ", f, " in ", most_derived_init, ".\n"]
json = self.generate_result(info)
results.append(json)
return results
class MultipleCalls(AbstractCheck):
ARGUMENT = 'multiple-calls'
ARGUMENT = "multiple-calls"
IMPACT = CheckClassification.HIGH
HELP = 'Init functions called multiple times'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-called-multiple-times'
WIKI_TITLE = 'Initialize functions are called multiple times'
WIKI_DESCRIPTION = '''
HELP = "Init functions called multiple times"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-called-multiple-times"
WIKI_TITLE = "Initialize functions are called multiple times"
WIKI_DESCRIPTION = """
Detect multiple calls to a initialize function.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Base{
function initialize(uint) public{
@ -264,11 +270,11 @@ contract DerivedDerived is Derived{
```
`Base.initialize(uint)` is called two times in `DerivedDerived.initiliaze` execution, leading to a potential corruption.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Call only one time every initialize function.
'''
"""
REQUIRE_CONTRACT = True
@ -280,38 +286,41 @@ Call only one time every initialize function.
most_derived_init = _get_most_derived_init(self.contract)
except MultipleInitTarget:
# Should be already reported by MissingCalls
#logger.error(red(f'Too many init targets in {self.contract}'))
# logger.error(red(f'Too many init targets in {self.contract}'))
return []
if most_derived_init is None:
return []
all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init]
double_calls = list(set([f for f in all_init_functions_called if all_init_functions_called.count(f) > 1]))
double_calls = list(
set([f for f in all_init_functions_called if all_init_functions_called.count(f) > 1])
)
for f in double_calls:
info = [f, ' is called multiple times in ', most_derived_init, '.\n']
info = [f, " is called multiple times in ", most_derived_init, ".\n"]
json = self.generate_result(info)
results.append(json)
return results
class InitializeTarget(AbstractCheck):
ARGUMENT = 'initialize-target'
ARGUMENT = "initialize-target"
IMPACT = CheckClassification.INFORMATIONAL
HELP = 'Initialize function that must be called'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-function'
WIKI_TITLE = 'Initialize function'
HELP = "Initialize function that must be called"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-function"
WIKI_TITLE = "Initialize function"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Show the function that must be called at deployment.
This finding does not have an immediate security impact and is informative.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Ensure that the function is called at deployment.
'''
"""
REQUIRE_CONTRACT = True
@ -322,12 +331,12 @@ Ensure that the function is called at deployment.
most_derived_init = _get_most_derived_init(self.contract)
except MultipleInitTarget:
# Should be already reported by MissingCalls
#logger.error(red(f'Too many init targets in {self.contract}'))
# logger.error(red(f'Too many init targets in {self.contract}'))
return []
if most_derived_init is None:
return []
info = [self.contract, f' needs to be initialized by ', most_derived_init, '.\n']
info = [self.contract, f" needs to be initialized by ", most_derived_init, ".\n"]
json = self.generate_result(info)
return [json]

@ -2,29 +2,29 @@ from slither.tools.upgradeability.checks.abstract_checks import CheckClassificat
class VariableWithInit(AbstractCheck):
ARGUMENT = 'variables-initialized'
ARGUMENT = "variables-initialized"
IMPACT = CheckClassification.HIGH
HELP = 'State variables with an initial value'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#state-variable-initialized'
WIKI_TITLE = 'State variable initialized'
HELP = "State variables with an initial value"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#state-variable-initialized"
WIKI_TITLE = "State variable initialized"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect state variables that are initialized.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
uint variable = 10;
}
```
Using `Contract` will the delegatecall proxy pattern will lead `variable` to be 0 when called through the proxy.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Using initialize functions to write initial values in state variables.
'''
"""
REQUIRE_CONTRACT = True
@ -32,7 +32,7 @@ Using initialize functions to write initial values in state variables.
results = []
for s in self.contract.state_variables:
if s.initialized and not s.is_constant:
info = [s, ' is a state variable with an initial value.\n']
info = [s, " is a state variable with an initial value.\n"]
json = self.generate_result(info)
results.append(json)
return results

@ -2,16 +2,16 @@ from slither.tools.upgradeability.checks.abstract_checks import CheckClassificat
class MissingVariable(AbstractCheck):
ARGUMENT = 'missing-variables'
ARGUMENT = "missing-variables"
IMPACT = CheckClassification.MEDIUM
HELP = 'Variable missing in the v2'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-variables'
WIKI_TITLE = 'Missing variables'
WIKI_DESCRIPTION = '''
HELP = "Variable missing in the v2"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-variables"
WIKI_TITLE = "Missing variables"
WIKI_DESCRIPTION = """
Detect variables that were present in the original contracts but are not in the updated one.
'''
WIKI_EXPLOIT_SCENARIO = '''
"""
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract V1{
uint variable1;
@ -25,11 +25,11 @@ contract V2{
The new version, `V2` does not contain `variable1`.
If a new variable is added in an update of `V2`, this variable will hold the latest value of `variable2` and
will be corrupted.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Do not change the order of the state variables in the updated contract.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_CONTRACT_V2 = True
@ -44,7 +44,7 @@ Do not change the order of the state variables in the updated contract.
for idx in range(0, len(order1)):
variable1 = order1[idx]
if len(order2) <= idx:
info = ['Variable missing in ', contract2, ': ', variable1, '\n']
info = ["Variable missing in ", contract2, ": ", variable1, "\n"]
json = self.generate_result(info)
results.append(json)
@ -52,18 +52,18 @@ Do not change the order of the state variables in the updated contract.
class DifferentVariableContractProxy(AbstractCheck):
ARGUMENT = 'order-vars-proxy'
ARGUMENT = "order-vars-proxy"
IMPACT = CheckClassification.HIGH
HELP = 'Incorrect vars order with the proxy'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-proxy'
WIKI_TITLE = 'Incorrect variables with the proxy'
HELP = "Incorrect vars order with the proxy"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-proxy"
WIKI_TITLE = "Incorrect variables with the proxy"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect variables that are different between the contract and the proxy.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
uint variable1;
@ -74,11 +74,11 @@ contract Proxy{
}
```
`Contract` and `Proxy` do not have the same storage layout. As a result the storage of both contracts can be corrupted.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Avoid variables in the proxy. If a variable is in the proxy, ensure it has the same layout than in the contract.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_PROXY = True
@ -104,9 +104,9 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
variable1 = order1[idx]
variable2 = order2[idx]
if (variable1.name != variable2.name) or (variable1.type != variable2.type):
info = ['Different variables between ', contract1, ' and ', contract2, '\n']
info += [f'\t ', variable1, '\n']
info += [f'\t ', variable2, '\n']
info = ["Different variables between ", contract1, " and ", contract2, "\n"]
info += [f"\t ", variable1, "\n"]
info += [f"\t ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)
@ -114,17 +114,17 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
class DifferentVariableContractNewContract(DifferentVariableContractProxy):
ARGUMENT = 'order-vars-contracts'
ARGUMENT = "order-vars-contracts"
HELP = 'Incorrect vars order with the v2'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-v2'
WIKI_TITLE = 'Incorrect variables with the v2'
HELP = "Incorrect vars order with the v2"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-v2"
WIKI_TITLE = "Incorrect variables with the v2"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect variables that are different between the original contract and the updated one.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
uint variable1;
@ -135,11 +135,11 @@ contract ContractV2{
}
```
`Contract` and `ContractV2` do not have the same storage layout. As a result the storage of both contracts can be corrupted.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Respect the variable order of the original contract in the updated contract.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_PROXY = False
@ -150,18 +150,20 @@ Respect the variable order of the original contract in the updated contract.
class ExtraVariablesProxy(AbstractCheck):
ARGUMENT = 'extra-vars-proxy'
ARGUMENT = "extra-vars-proxy"
IMPACT = CheckClassification.MEDIUM
HELP = 'Extra vars in the proxy'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-proxy'
WIKI_TITLE = 'Extra variables in the proxy'
HELP = "Extra vars in the proxy"
WIKI = (
"https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-proxy"
)
WIKI_TITLE = "Extra variables in the proxy"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Detect variables that are in the proxy and not in the contract.
'''
"""
WIKI_EXPLOIT_SCENARIO = '''
WIKI_EXPLOIT_SCENARIO = """
```solidity
contract Contract{
uint variable1;
@ -173,11 +175,11 @@ contract Proxy{
}
```
`Proxy` contains additional variables. A future update of `Contract` is likely to corrupt the proxy.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Avoid variables in the proxy. If a variable is in the proxy, ensure it has the same layout than in the contract.
'''
"""
REQUIRE_CONTRACT = True
REQUIRE_PROXY = True
@ -203,7 +205,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
while idx < len(order2):
variable2 = order2[idx]
info = ['Extra variables in ', contract2, ': ', variable2, '\n']
info = ["Extra variables in ", contract2, ": ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)
idx = idx + 1
@ -212,21 +214,21 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
class ExtraVariablesNewContract(ExtraVariablesProxy):
ARGUMENT = 'extra-vars-v2'
ARGUMENT = "extra-vars-v2"
HELP = 'Extra vars in the v2'
WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-v2'
WIKI_TITLE = 'Extra variables in the v2'
HELP = "Extra vars in the v2"
WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-v2"
WIKI_TITLE = "Extra variables in the v2"
WIKI_DESCRIPTION = '''
WIKI_DESCRIPTION = """
Show new variables in the updated contract.
This finding does not have an immediate security impact and is informative.
'''
"""
WIKI_RECOMMENDATION = '''
WIKI_RECOMMENDATION = """
Ensure that all the new variables are expected.
'''
"""
IMPACT = CheckClassification.INFORMATIONAL

@ -4,7 +4,9 @@ from slither.utils.myprettytable import MyPrettyTable
def output_wiki(detector_classes, filter_wiki):
# Sort by impact, confidence, and name
detectors_list = sorted(detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT))
detectors_list = sorted(
detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT)
)
for detector in detectors_list:
if filter_wiki not in detector.WIKI:
@ -16,16 +18,16 @@ def output_wiki(detector_classes, filter_wiki):
exploit_scenario = detector.WIKI_EXPLOIT_SCENARIO
recommendation = detector.WIKI_RECOMMENDATION
print('\n## {}'.format(title))
print('### Configuration')
print('* Check: `{}`'.format(argument))
print('* Severity: `{}`'.format(impact))
print('\n### Description')
print("\n## {}".format(title))
print("### Configuration")
print("* Check: `{}`".format(argument))
print("* Severity: `{}`".format(impact))
print("\n### Description")
print(description)
if exploit_scenario:
print('\n### Exploit Scenario:')
print("\n### Exploit Scenario:")
print(exploit_scenario)
print('\n### Recommendation')
print("\n### Recommendation")
print(recommendation)
@ -38,27 +40,31 @@ def output_detectors(detector_classes):
require_proxy = detector.REQUIRE_PROXY
require_v2 = detector.REQUIRE_CONTRACT_V2
detectors_list.append((argument, help_info, impact, require_proxy, require_v2))
table = MyPrettyTable(["Num",
"Check",
"What it Detects",
"Impact",
"Proxy",
"Contract V2"])
table = MyPrettyTable(["Num", "Check", "What it Detects", "Impact", "Proxy", "Contract V2"])
# Sort by impact, confidence, and name
detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0]))
idx = 1
for (argument, help_info, impact, proxy, v2) in detectors_list:
table.add_row([idx, argument, help_info, classification_txt[impact], 'X' if proxy else '', 'X' if v2 else ''])
table.add_row(
[
idx,
argument,
help_info,
classification_txt[impact],
"X" if proxy else "",
"X" if v2 else "",
]
)
idx = idx + 1
print(table)
def output_to_markdown(detector_classes, filter_wiki):
def extract_help(cls):
if cls.WIKI == '':
if cls.WIKI == "":
return cls.HELP
return '[{}]({})'.format(cls.HELP, cls.WIKI)
return "[{}]({})".format(cls.HELP, cls.WIKI)
detectors_list = []
for detector in detector_classes:
@ -73,12 +79,16 @@ def output_to_markdown(detector_classes, filter_wiki):
detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0]))
idx = 1
for (argument, help_info, impact, proxy, v2) in detectors_list:
print('{} | `{}` | {} | {} | {} | {}'.format(idx,
argument,
help_info,
classification_txt[impact],
'X' if proxy else '',
'X' if v2 else ''))
print(
"{} | `{}` | {} | {} | {} | {}".format(
idx,
argument,
help_info,
classification_txt[impact],
"X" if proxy else "",
"X" if v2 else "",
)
)
idx = idx + 1
@ -92,26 +102,42 @@ def output_detectors_json(detector_classes):
wiki_description = detector.WIKI_DESCRIPTION
wiki_exploit_scenario = detector.WIKI_EXPLOIT_SCENARIO
wiki_recommendation = detector.WIKI_RECOMMENDATION
detectors_list.append((argument,
help_info,
impact,
wiki_url,
wiki_description,
wiki_exploit_scenario,
wiki_recommendation))
detectors_list.append(
(
argument,
help_info,
impact,
wiki_url,
wiki_description,
wiki_exploit_scenario,
wiki_recommendation,
)
)
# Sort by impact, confidence, and name
detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0]))
idx = 1
table = []
for (argument, help_info, impact, wiki_url, description, exploit, recommendation) in detectors_list:
table.append({'index': idx,
'check': argument,
'title': help_info,
'impact': classification_txt[impact],
'wiki_url': wiki_url,
'description': description,
'exploit_scenario': exploit,
'recommendation': recommendation})
for (
argument,
help_info,
impact,
wiki_url,
description,
exploit,
recommendation,
) in detectors_list:
table.append(
{
"index": idx,
"check": argument,
"title": help_info,
"impact": classification_txt[impact],
"wiki_url": wiki_url,
"description": description,
"exploit_scenario": exploit,
"recommendation": recommendation,
}
)
idx = idx + 1
return table

@ -3,11 +3,13 @@ import logging
from .expression import ExpressionVisitor
from slither.core.expressions import BinaryOperationType, Literal
class NotConstant(Exception):
pass
KEY = 'ConstantFolding'
KEY = "ConstantFolding"
def get_val(expression):
val = expression.context[KEY]
@ -15,11 +17,12 @@ def get_val(expression):
del expression.context[KEY]
return val
def set_val(expression, val):
expression.context[KEY] = val
class ConstantFolding(ExpressionVisitor):
class ConstantFolding(ExpressionVisitor):
def __init__(self, expression, type):
self._type = type
super(ConstantFolding, self).__init__(expression)
@ -40,24 +43,24 @@ class ConstantFolding(ExpressionVisitor):
def _post_binary_operation(self, expression):
left = get_val(expression.expression_left)
right = get_val(expression.expression_right)
if expression.type == BinaryOperationType.POWER:
if expression.type == BinaryOperationType.POWER:
set_val(expression, left ** right)
elif expression.type == BinaryOperationType.MULTIPLICATION:
elif expression.type == BinaryOperationType.MULTIPLICATION:
set_val(expression, left * right)
elif expression.type == BinaryOperationType.DIVISION:
elif expression.type == BinaryOperationType.DIVISION:
set_val(expression, left / right)
elif expression.type == BinaryOperationType.MODULO:
elif expression.type == BinaryOperationType.MODULO:
set_val(expression, left % right)
elif expression.type == BinaryOperationType.ADDITION:
elif expression.type == BinaryOperationType.ADDITION:
set_val(expression, left + right)
elif expression.type == BinaryOperationType.SUBTRACTION:
if(left-right) <0:
elif expression.type == BinaryOperationType.SUBTRACTION:
if (left - right) < 0:
# Could trigger underflow
raise NotConstant
set_val(expression, left - right)
elif expression.type == BinaryOperationType.LEFT_SHIFT:
elif expression.type == BinaryOperationType.LEFT_SHIFT:
set_val(expression, left << right)
elif expression.type == BinaryOperationType.RIGHT_SHIFT:
elif expression.type == BinaryOperationType.RIGHT_SHIFT:
set_val(expression, left >> right)
else:
raise NotConstant
@ -110,6 +113,3 @@ class ConstantFolding(ExpressionVisitor):
def _post_type_conversion(self, expression):
raise NotConstant

@ -1,11 +1,11 @@
from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.expressions.assignment_operation import AssignmentOperationType
from slither.core.variables.variable import Variable
key = 'ExportValues'
key = "ExportValues"
def get(expression):
val = expression.context[key]
@ -13,11 +13,12 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class ExportValues(ExpressionVisitor):
class ExportValues(ExpressionVisitor):
def result(self):
if self._result is None:
self._result = list(set(get(self.expression)))

@ -1,10 +1,12 @@
import logging
from typing import Any
from slither.core.expressions.assignment_operation import AssignmentOperation
from slither.core.expressions.binary_operation import BinaryOperation
from slither.core.expressions.call_expression import CallExpression
from slither.core.expressions.conditional_expression import ConditionalExpression
from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression
from slither.core.expressions.expression import Expression
from slither.core.expressions.identifier import Identifier
from slither.core.expressions.index_access import IndexAccess
from slither.core.expressions.literal import Literal
@ -19,24 +21,24 @@ from slither.exceptions import SlitherError
logger = logging.getLogger("ExpressionVisitor")
class ExpressionVisitor:
def __init__(self, expression):
class ExpressionVisitor:
def __init__(self, expression: Expression):
# Inherited class must declared their variables prior calling super().__init__
self._expression = expression
self._result = None
self._result: Any = None
self._visit_expression(self.expression)
def result(self):
return self._result
@property
def expression(self):
def expression(self) -> Expression:
return self._expression
# visit an expression
# call pre_visit, visit_expression_name, post_visit
def _visit_expression(self, expression):
def _visit_expression(self, expression: Expression):
self._pre_visit(expression)
if isinstance(expression, AssignmentOperation):
@ -88,7 +90,7 @@ class ExpressionVisitor:
pass
else:
raise SlitherError('Expression not handled: {}'.format(expression))
raise SlitherError("Expression not handled: {}".format(expression))
self._post_visit(expression)
@ -207,7 +209,7 @@ class ExpressionVisitor:
pass
else:
raise SlitherError('Expression not handled: {}'.format(expression))
raise SlitherError("Expression not handled: {}".format(expression))
# pre_expression_name
@ -308,7 +310,7 @@ class ExpressionVisitor:
pass
else:
raise SlitherError('Expression not handled: {}'.format(expression))
raise SlitherError("Expression not handled: {}".format(expression))
# post_expression_name
@ -356,5 +358,3 @@ class ExpressionVisitor:
def _post_unary_operation(self, expression):
pass

@ -1,17 +1,18 @@
from slither.visitors.expression.expression import ExpressionVisitor
def get(expression):
val = expression.context['ExpressionPrinter']
val = expression.context["ExpressionPrinter"]
# we delete the item to reduce memory use
del expression.context['ExpressionPrinter']
del expression.context["ExpressionPrinter"]
return val
def set_val(expression, val):
expression.context['ExpressionPrinter'] = val
expression.context["ExpressionPrinter"] = val
class ExpressionPrinter(ExpressionVisitor):
class ExpressionPrinter(ExpressionVisitor):
def result(self):
if not self._result:
self._result = get(self.expression)
@ -20,19 +21,19 @@ class ExpressionPrinter(ExpressionVisitor):
def _post_assignement_operation(self, expression):
left = get(expression.expression_left)
right = get(expression.expression_right)
val = "{} {} {}".format(left, expression.type_str, right)
val = "{} {} {}".format(left, expression.type, right)
set_val(expression, val)
def _post_binary_operation(self, expression):
left = get(expression.expression_left)
right = get(expression.expression_right)
val = "{} {} {}".format(left, expression.type_str, right)
val = "{} {} {}".format(left, expression.type, right)
set_val(expression, val)
def _post_call_expression(self, expression):
called = get(expression.called)
arguments = [get(x) for x in expression.arguments if x]
val = "{}({})".format(called, ','.join(arguments))
val = "{}({})".format(called, ",".join(arguments))
set_val(expression, val)
def _post_conditional_expression(self, expression):
@ -66,7 +67,7 @@ class ExpressionPrinter(ExpressionVisitor):
def _post_new_array(self, expression):
array = str(expression.array_type)
depth = expression.depth
val = "new {}{}".format(array, '[]'*depth)
val = "new {}{}".format(array, "[]" * depth)
set_val(expression, val)
def _post_new_contract(self, expression):
@ -81,7 +82,7 @@ class ExpressionPrinter(ExpressionVisitor):
def _post_tuple_expression(self, expression):
expressions = [get(e) for e in expression.expressions if e]
val = "({})".format(','.join(expressions))
val = "({})".format(",".join(expressions))
set_val(expression, val)
def _post_type_conversion(self, expression):

@ -1,11 +1,10 @@
from typing import List
from slither.core.expressions.expression import Expression
from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.expressions.assignment_operation import AssignmentOperationType
key = "FindCall"
from slither.core.variables.variable import Variable
key = 'FindCall'
def get(expression):
val = expression.context[key]
@ -13,12 +12,13 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class FindCalls(ExpressionVisitor):
def result(self):
class FindCalls(ExpressionVisitor):
def result(self) -> List[Expression]:
if self._result is None:
self._result = list(set(get(self.expression)))
return self._result

@ -4,7 +4,8 @@ from slither.core.expressions.index_access import IndexAccess
from slither.visitors.expression.right_value import RightValue
key = 'FindPush'
key = "FindPush"
def get(expression):
val = expression.context[key]
@ -12,11 +13,12 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class FindPush(ExpressionVisitor):
class FindPush(ExpressionVisitor):
def result(self):
if self._result is None:
self._result = list(set(get(self.expression)))
@ -66,7 +68,7 @@ class FindPush(ExpressionVisitor):
def _post_member_access(self, expression):
val = []
if expression.member_name == 'push':
if expression.member_name == "push":
right = RightValue(expression.expression)
val = right.result()
set_val(expression, val)

@ -1,13 +1,12 @@
from slither.visitors.expression.expression import ExpressionVisitor
class HasConditional(ExpressionVisitor):
class HasConditional(ExpressionVisitor):
def result(self):
# == True, to convert None to false
return self._result is True
def _post_conditional_expression(self, expression):
# if self._result is True:
# raise('Slither does not support nested ternary operator')
# if self._result is True:
# raise('Slither does not support nested ternary operator')
self._result = True

@ -6,7 +6,8 @@ from slither.core.expressions.assignment_operation import AssignmentOperationTyp
from slither.core.variables.variable import Variable
key = 'LeftValue'
key = "LeftValue"
def get(expression):
val = expression.context[key]
@ -14,11 +15,12 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class LeftValue(ExpressionVisitor):
class LeftValue(ExpressionVisitor):
def result(self):
if self._result is None:
self._result = list(set(get(self.expression)))
@ -64,8 +66,8 @@ class LeftValue(ExpressionVisitor):
def _post_identifier(self, expression):
if isinstance(expression.value, Variable):
set_val(expression, [expression.value])
# elif isinstance(expression.value, SolidityInbuilt):
# set_val(expression, [expression])
# elif isinstance(expression.value, SolidityInbuilt):
# set_val(expression, [expression])
else:
set_val(expression, [])

@ -1,4 +1,3 @@
from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.expressions.assignment_operation import AssignmentOperationType
@ -6,7 +5,8 @@ from slither.core.expressions.assignment_operation import AssignmentOperationTyp
from slither.core.variables.variable import Variable
from slither.core.declarations.solidity_variables import SolidityVariable
key = 'ReadVar'
key = "ReadVar"
def get(expression):
val = expression.context[key]
@ -14,17 +14,17 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class ReadVar(ExpressionVisitor):
class ReadVar(ExpressionVisitor):
def result(self):
if self._result is None:
self._result = list(set(get(self.expression)))
return self._result
# overide assignement
# dont explore if its direct assignement (we explore if its +=, -=, ...)
def _visit_assignement_operation(self, expression):

@ -10,7 +10,8 @@ from slither.core.expressions.expression import Expression
from slither.core.variables.variable import Variable
key = 'RightValue'
key = "RightValue"
def get(expression):
val = expression.context[key]
@ -18,11 +19,12 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class RightValue(ExpressionVisitor):
class RightValue(ExpressionVisitor):
def result(self):
if self._result is None:
self._result = list(set(get(self.expression)))
@ -68,8 +70,8 @@ class RightValue(ExpressionVisitor):
def _post_identifier(self, expression):
if isinstance(expression.value, Variable):
set_val(expression, [expression.value])
# elif isinstance(expression.value, SolidityInbuilt):
# set_val(expression, [expression])
# elif isinstance(expression.value, SolidityInbuilt):
# set_val(expression, [expression])
else:
set_val(expression, [])

@ -1,4 +1,3 @@
from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.expressions.assignment_operation import AssignmentOperation
@ -10,7 +9,8 @@ from slither.core.expressions.member_access import MemberAccess
from slither.core.expressions.index_access import IndexAccess
key = 'WriteVar'
key = "WriteVar"
def get(expression):
val = expression.context[key]
@ -18,11 +18,12 @@ def get(expression):
del expression.context[key]
return val
def set_val(expression, val):
expression.context[key] = val
class WriteVar(ExpressionVisitor):
class WriteVar(ExpressionVisitor):
def result(self):
if self._result is None:
self._result = list(set(get(self.expression)))
@ -71,27 +72,28 @@ class WriteVar(ExpressionVisitor):
set_val(expression, [expression])
else:
set_val(expression, [])
# if isinstance(expression.value, Variable):
# set_val(expression, [expression.value])
# else:
# set_val(expression, [])
# if isinstance(expression.value, Variable):
# set_val(expression, [expression.value])
# else:
# set_val(expression, [])
def _post_index_access(self, expression):
left = get(expression.expression_left)
right = get(expression.expression_right)
val = left + right
if expression.is_lvalue:
# val += [expression]
# val += [expression]
val += [expression.expression_left]
# n = expression.expression_left
# parse all the a.b[..].c[..]
# while isinstance(n, (IndexAccess, MemberAccess)):
# if isinstance(n, IndexAccess):
# val += [n.expression_left]
# n = n.expression_left
# else:
# val += [n.expression]
# n = n.expression
# n = expression.expression_left
# parse all the a.b[..].c[..]
# while isinstance(n, (IndexAccess, MemberAccess)):
# if isinstance(n, IndexAccess):
# val += [n.expression_left]
# n = n.expression_left
# else:
# val += [n.expression]
# n = n.expression
set_val(expression, val)
def _post_literal(self, expression):

Loading…
Cancel
Save