Fix all pylint issues

pull/616/head
Josselin 4 years ago
parent 683ebb924d
commit 70e609ec28
  1. 87
      examples/scripts/call_graph.py
  2. 39
      examples/scripts/convert_to_evm_ins.py
  3. 4
      examples/scripts/convert_to_ir.py
  4. 4
      examples/scripts/data_dependency.py
  5. 2
      examples/scripts/export_dominator_tree_to_dot.py
  6. 2
      examples/scripts/export_to_dot.py
  7. 2
      examples/scripts/functions_called.py
  8. 2
      examples/scripts/functions_writing.py
  9. 12
      examples/scripts/possible_paths.py
  10. 2
      examples/scripts/slithIR.py
  11. 8
      examples/scripts/taint_mapping.py
  12. 2
      examples/scripts/variable_in_condition.py
  13. 9
      pyproject.toml
  14. 5
      scripts/json_diff.py
  15. 173
      slither/__main__.py
  16. 1
      slither/all_exceptions.py
  17. 47
      slither/analyses/data_dependency/data_dependency.py
  18. 37
      slither/analyses/evm/convert.py
  19. 1
      slither/analyses/evm/evm_cfg_builder.py
  20. 19
      slither/analyses/write/are_variables_written.py
  21. 96
      slither/core/cfg/node.py
  22. 2
      slither/core/children/child_node.py
  23. 2
      slither/core/context/context.py
  24. 6
      slither/core/declarations/__init__.py
  25. 177
      slither/core/declarations/contract.py
  26. 270
      slither/core/declarations/function.py
  27. 5
      slither/core/declarations/pragma_directive.py
  28. 13
      slither/core/declarations/solidity_variables.py
  29. 2
      slither/core/dominators/node_dominator_tree.py
  30. 1
      slither/core/dominators/utils.py
  31. 12
      slither/core/expressions/assignment_operation.py
  32. 16
      slither/core/expressions/binary_operation.py
  33. 2
      slither/core/expressions/call_expression.py
  34. 4
      slither/core/expressions/expression_typed.py
  35. 4
      slither/core/expressions/literal.py
  36. 1
      slither/core/expressions/super_call_expression.py
  37. 13
      slither/core/expressions/unary_operation.py
  38. 60
      slither/core/slither_core.py
  39. 4
      slither/core/solidity_types/elementary_type.py
  40. 4
      slither/core/solidity_types/function_type.py
  41. 1
      slither/core/solidity_types/type_information.py
  42. 9
      slither/core/solidity_types/user_defined_type.py
  43. 12
      slither/core/source_mapping/source_mapping.py
  44. 2
      slither/core/variables/event_variable.py
  45. 2
      slither/core/variables/local_variable.py
  46. 17
      slither/core/variables/state_variable.py
  47. 2
      slither/core/variables/structure_variable.py
  48. 5
      slither/core/variables/variable.py
  49. 30
      slither/detectors/abstract_detector.py
  50. 1
      slither/detectors/all_detectors.py
  51. 8
      slither/detectors/attributes/const_functions_asm.py
  52. 13
      slither/detectors/attributes/const_functions_state.py
  53. 4
      slither/detectors/attributes/constant_pragma.py
  54. 25
      slither/detectors/attributes/incorrect_solc.py
  55. 19
      slither/detectors/attributes/locked_ether.py
  56. 29
      slither/detectors/erc/incorrect_erc20_interface.py
  57. 55
      slither/detectors/erc/incorrect_erc721_interface.py
  58. 6
      slither/detectors/erc/unindexed_event_parameters.py
  59. 4
      slither/detectors/examples/backdoor.py
  60. 132
      slither/detectors/functions/arbitrary_send.py
  61. 105
      slither/detectors/functions/complex_function.py
  62. 40
      slither/detectors/functions/external_function.py
  63. 4
      slither/detectors/functions/suicidal.py
  64. 36
      slither/detectors/naming_convention/naming_convention.py
  65. 23
      slither/detectors/operations/block_timestamp.py
  66. 4
      slither/detectors/operations/low_level_calls.py
  67. 8
      slither/detectors/operations/unchecked_low_level_return_values.py
  68. 8
      slither/detectors/operations/unchecked_send_return_value.py
  69. 20
      slither/detectors/operations/unused_return_values.py
  70. 4
      slither/detectors/operations/void_constructor.py
  71. 63
      slither/detectors/reentrancy/reentrancy.py
  72. 42
      slither/detectors/reentrancy/reentrancy_benign.py
  73. 48
      slither/detectors/reentrancy/reentrancy_eth.py
  74. 38
      slither/detectors/reentrancy/reentrancy_events.py
  75. 52
      slither/detectors/reentrancy/reentrancy_no_gas.py
  76. 46
      slither/detectors/reentrancy/reentrancy_read_before_write.py
  77. 31
      slither/detectors/shadowing/abstract.py
  78. 5
      slither/detectors/shadowing/builtin_symbols.py
  79. 16
      slither/detectors/shadowing/local.py
  80. 29
      slither/detectors/shadowing/state.py
  81. 18
      slither/detectors/slither/name_reused.py
  82. 44
      slither/detectors/source/rtlo.py
  83. 4
      slither/detectors/statements/assembly.py
  84. 20
      slither/detectors/statements/boolean_constant_equality.py
  85. 23
      slither/detectors/statements/boolean_constant_misuse.py
  86. 12
      slither/detectors/statements/calls_in_loop.py
  87. 38
      slither/detectors/statements/controlled_delegatecall.py
  88. 19
      slither/detectors/statements/deprecated_calls.py
  89. 15
      slither/detectors/statements/divide_before_multiply.py
  90. 31
      slither/detectors/statements/incorrect_strict_equality.py
  91. 4
      slither/detectors/statements/too_many_digits.py
  92. 4
      slither/detectors/statements/tx_origin.py
  93. 111
      slither/detectors/statements/type_based_tautology.py
  94. 35
      slither/detectors/variables/possible_const_state_variables.py
  95. 24
      slither/detectors/variables/uninitialized_local_variables.py
  96. 7
      slither/detectors/variables/uninitialized_state_variables.py
  97. 24
      slither/detectors/variables/uninitialized_storage_variables.py
  98. 75
      slither/detectors/variables/unused_state_variables.py
  99. 2
      slither/formatters/attributes/const_functions.py
  100. 31
      slither/formatters/attributes/constant_pragma.py
  101. Some files were not shown because too many files have changed in this diff Show More

@ -1,87 +0,0 @@
import os
import logging
import argparse
from slither import Slither
from slither.printers.all_printers import PrinterCallGraph
from slither.core.declarations.function import Function
logging.basicConfig()
logging.getLogger("Slither").setLevel(logging.INFO)
logging.getLogger("Printers").setLevel(logging.INFO)
class PrinterCallGraphStateChange(PrinterCallGraph):
def _process_function(
self,
contract,
function,
contract_functions,
contract_calls,
solidity_functions,
solidity_calls,
external_calls,
all_contracts,
):
if function.view or function.pure:
return
super()._process_function(
contract,
function,
contract_functions,
contract_calls,
solidity_functions,
solidity_calls,
external_calls,
all_contracts,
)
def _process_internal_call(
self, contract, function, internal_call, contract_calls, solidity_functions, solidity_calls
):
if isinstance(internal_call, Function):
if internal_call.view or internal_call.pure:
return
super()._process_internal_call(
contract, function, internal_call, contract_calls, solidity_functions, solidity_calls
)
def _process_external_call(
self, contract, function, external_call, contract_functions, external_calls, all_contracts
):
if isinstance(external_call[1], Function):
if external_call[1].view or external_call[1].pure:
return
super()._process_external_call(
contract, function, external_call, contract_functions, external_calls, all_contracts
)
def parse_args():
"""
"""
parser = argparse.ArgumentParser(
description="Call graph printer. Similar to --print call-graph, but without printing the view/pure functions",
usage="call_graph.py filename",
)
parser.add_argument(
"filename", help="The filename of the contract or truffle directory to analyze."
)
parser.add_argument("--solc", help="solc path", default="solc")
return parser.parse_args()
def main():
args = parse_args()
slither = Slither(args.filename, is_truffle=os.path.isdir(args.filename), solc=args.solc)
slither.register_printer(PrinterCallGraphStateChange)
slither.run_printers()
if __name__ == "__main__":
main()

@ -1,39 +0,0 @@
import sys
from slither.slither import Slither
from slither.evm.convert import SourceToEVM
if len(sys.argv) != 2:
print("python3 function_called.py functions_called.sol")
exit(-1)
# Init slither
slither = Slither(sys.argv[1])
# Get the contract evm instructions
contract = slither.get_contract_from_name("Test")
contract_ins = SourceToEVM.get_evm_instructions(contract)
print("## Contract evm instructions: {} ##".format(contract.name))
for ins in contract_ins:
print(str(ins))
# Get the constructor evm instructions
constructor = contract.constructor
print("## Function evm instructions: {} ##".format(constructor.name))
constructor_ins = SourceToEVM.get_evm_instructions(constructor)
for ins in constructor_ins:
print(str(ins))
# Get the function evm instructions
function = contract.get_function_from_signature("foo()")
print("## Function evm instructions: {} ##".format(function.name))
function_ins = SourceToEVM.get_evm_instructions(function)
for ins in function_ins:
print(str(ins))
# Get the node evm instructions
nodes = function.nodes
for node in nodes:
node_ins = SourceToEVM.get_evm_instructions(node)
print("Node evm instructions: {}".format(str(node)))
for ins in node_ins:
print(str(ins))

@ -5,7 +5,7 @@ from slither.slithir.convert import convert_expression
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("python function_called.py functions_called.sol") print("python function_called.py functions_called.sol")
exit(-1) sys.exit(-1)
# Init slither # Init slither
slither = Slither(sys.argv[1]) slither = Slither(sys.argv[1])
@ -21,7 +21,7 @@ nodes = test.nodes
for node in nodes: for node in nodes:
if node.expression: if node.expression:
print("Expression:\n\t{}".format(node.expression)) print("Expression:\n\t{}".format(node.expression))
irs = convert_expression(node.expression) irs = convert_expression(node.expression, node)
print("IR expressions:") print("IR expressions:")
for ir in irs: for ir in irs:
print("\t{}".format(ir)) print("\t{}".format(ir))

@ -1,15 +1,15 @@
import sys import sys
from slither import Slither from slither import Slither
from slither.analyses.data_dependency.data_dependency import ( from slither.analyses.data_dependency.data_dependency import (
is_dependent, is_dependent,
is_tainted, is_tainted,
pprint_dependency,
) )
from slither.core.declarations.solidity_variables import SolidityVariableComposed from slither.core.declarations.solidity_variables import SolidityVariableComposed
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("Usage: python data_dependency.py file.sol") print("Usage: python data_dependency.py file.sol")
exit(-1) sys.exit(-1)
slither = Slither(sys.argv[1]) slither = Slither(sys.argv[1])

@ -4,7 +4,7 @@ from slither.slither import Slither
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("python export_dominator_tree_to_dot.py contract.sol") print("python export_dominator_tree_to_dot.py contract.sol")
exit(-1) sys.exit(-1)
# Init slither # Init slither
slither = Slither(sys.argv[1]) slither = Slither(sys.argv[1])

@ -4,7 +4,7 @@ from slither.slither import Slither
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("python function_called.py contract.sol") print("python function_called.py contract.sol")
exit(-1) sys.exit(-1)
# Init slither # Init slither
slither = Slither(sys.argv[1]) slither = Slither(sys.argv[1])

@ -3,7 +3,7 @@ from slither.slither import Slither
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("python functions_called.py functions_called.sol") print("python functions_called.py functions_called.sol")
exit(-1) sys.exit(-1)
# Init slither # Init slither
slither = Slither(sys.argv[1]) slither = Slither(sys.argv[1])

@ -3,7 +3,7 @@ from slither.slither import Slither
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("python function_writing.py functions_writing.sol") print("python function_writing.py functions_writing.sol")
exit(-1) sys.exit(-1)
# Init slither # Init slither
slither = Slither(sys.argv[1]) slither = Slither(sys.argv[1])

@ -84,8 +84,8 @@ def all_function_definitions(function):
] ]
def __find_target_paths(target_function, current_path=[]): def __find_target_paths(target_function, current_path=None):
current_path = current_path if current_path else []
# Create our results list # Create our results list
results = set() results = set()
@ -184,17 +184,17 @@ slither = Slither(args.filename, is_truffle=args.is_truffle)
targets = resolve_functions(args.targets) targets = resolve_functions(args.targets)
# Print out all target functions. # Print out all target functions.
print(f"Target functions:") print("Target functions:")
for target in targets: for target in targets:
print(f"-{target.contract.name}.{target.full_name}") print(f"-{target.contract.name}.{target.full_name}")
print("\n") print("\n")
# Obtain all paths which reach the target functions. # Obtain all paths which reach the target functions.
reaching_paths = find_target_paths(targets) reaching_paths = find_target_paths(targets)
reaching_functions = set([y for x in reaching_paths for y in x if y not in targets]) reaching_functions = {y for x in reaching_paths for y in x if y not in targets}
# Print out all function names which can reach the targets. # Print out all function names which can reach the targets.
print(f"The following functions reach the specified targets:") print("The following functions reach the specified targets:")
for function_desc in sorted([f"{f.canonical_name}" for f in reaching_functions]): for function_desc in sorted([f"{f.canonical_name}" for f in reaching_functions]):
print(f"-{function_desc}") print(f"-{function_desc}")
print("\n") print("\n")
@ -205,6 +205,6 @@ reaching_paths_str = [
] ]
# Print a sorted list of all function paths which can reach the targets. # Print a sorted list of all function paths which can reach the targets.
print(f"The following paths reach the specified targets:") print("The following paths reach the specified targets:")
for reaching_path in sorted(reaching_paths_str): for reaching_path in sorted(reaching_paths_str):
print(f"{reaching_path}\n") print(f"{reaching_path}\n")

@ -3,7 +3,7 @@ from slither import Slither
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("python slithIR.py contract.sol") print("python slithIR.py contract.sol")
exit(-1) sys.exit(-1)
# Init slither # Init slither
slither = Slither(sys.argv[1]) slither = Slither(sys.argv[1])

@ -58,7 +58,7 @@ def check_call(func, taints):
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("python taint_mapping.py taint.sol") print("python taint_mapping.py taint.sol")
exit(-1) sys.exit(-1)
# Init slither # Init slither
slither = Slither(sys.argv[1]) slither = Slither(sys.argv[1])
@ -79,11 +79,11 @@ if __name__ == "__main__":
visit_node(function.entry_point, []) visit_node(function.entry_point, [])
print("All variables tainted : {}".format([str(v) for v in slither.context[KEY]])) print("All variables tainted : {}".format([str(v) for v in slither.context[KEY]]))
for function in contract.functions:
check_call(function, slither.context[KEY])
print( print(
"All state variables tainted : {}".format( "All state variables tainted : {}".format(
[str(v) for v in prev_taints if isinstance(v, StateVariable)] [str(v) for v in prev_taints if isinstance(v, StateVariable)]
) )
) )
for function in contract.functions:
check_call(function, slither.context[KEY])

@ -3,7 +3,7 @@ from slither.slither import Slither
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("python variable_in_condition.py variable_in_condition.sol") print("python variable_in_condition.py variable_in_condition.sol")
exit(-1) sys.exit(-1)
# Init slither # Init slither
slither = Slither(sys.argv[1]) slither = Slither(sys.argv[1])

@ -9,5 +9,12 @@ missing-function-docstring,
unnecessary-lambda, unnecessary-lambda,
bad-continuation, bad-continuation,
cyclic-import, cyclic-import,
line-too-long line-too-long,
invalid-name,
fixme,
too-many-return-statements,
too-many-ancestors,
logging-fstring-interpolation,
logging-not-lazy,
duplicate-code
""" """

@ -1,11 +1,12 @@
import sys import sys
import json import json
from deepdiff import DeepDiff # pip install deepdiff
from pprint import pprint from pprint import pprint
from deepdiff import DeepDiff # pip install deepdiff
if len(sys.argv) != 3: if len(sys.argv) != 3:
print("Usage: python json_diff.py 1.json 2.json") print("Usage: python json_diff.py 1.json 2.json")
exit(-1) sys.exit(-1)
with open(sys.argv[1], encoding="utf8") as f: with open(sys.argv[1], encoding="utf8") as f:
d1 = json.load(f) d1 = json.load(f)

@ -10,8 +10,10 @@ import sys
import traceback import traceback
from pkg_resources import iter_entry_points, require from pkg_resources import iter_entry_points, require
from crytic_compile import cryticparser from crytic_compile import cryticparser
from crytic_compile.platform.standard import generate_standard_export from crytic_compile.platform.standard import generate_standard_export
from crytic_compile import compile_all, is_supported
from slither.detectors import all_detectors from slither.detectors import all_detectors
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
@ -34,7 +36,6 @@ from slither.utils.command_line import (
JSON_OUTPUT_TYPES, JSON_OUTPUT_TYPES,
DEFAULT_JSON_OUTPUT_TYPES, DEFAULT_JSON_OUTPUT_TYPES,
) )
from crytic_compile import compile_all, is_supported
from slither.exceptions import SlitherException from slither.exceptions import SlitherException
logging.basicConfig() logging.basicConfig()
@ -80,7 +81,12 @@ def process_all(target, args, detector_classes, printer_classes):
results_printers.extend(current_results_printers) results_printers.extend(current_results_printers)
slither_instances.append(slither) slither_instances.append(slither)
analyzed_contracts_count += current_analyzed_count analyzed_contracts_count += current_analyzed_count
return slither_instances, results_detectors, results_printers, analyzed_contracts_count return (
slither_instances,
results_detectors,
results_printers,
analyzed_contracts_count,
)
def _process(slither, detector_classes, printer_classes): def _process(slither, detector_classes, printer_classes):
@ -98,7 +104,9 @@ def _process(slither, detector_classes, printer_classes):
if not printer_classes: if not printer_classes:
detector_results = slither.run_detectors() detector_results = slither.run_detectors()
detector_results = [x for x in detector_results if x] # remove empty results detector_results = [x for x in detector_results if x] # remove empty results
detector_results = [item for sublist in detector_results for item in sublist] # flatten detector_results = [
item for sublist in detector_results for item in sublist
] # flatten
results_detectors.extend(detector_results) results_detectors.extend(detector_results)
else: else:
@ -113,8 +121,8 @@ def process_from_asts(filenames, args, detector_classes, printer_classes):
all_contracts = [] all_contracts = []
for filename in filenames: for filename in filenames:
with open(filename, encoding="utf8") as f: with open(filename, encoding="utf8") as file_open:
contract_loaded = json.load(f) contract_loaded = json.load(file_open)
all_contracts.append(contract_loaded["ast"]) all_contracts.append(contract_loaded["ast"])
return process_single(all_contracts, args, detector_classes, printer_classes) return process_single(all_contracts, args, detector_classes, printer_classes)
@ -128,7 +136,7 @@ def process_from_asts(filenames, args, detector_classes, printer_classes):
################################################################################### ###################################################################################
def exit(results): def my_exit(results):
if not results: if not results:
sys.exit(0) sys.exit(0)
sys.exit(len(results)) sys.exit(len(results))
@ -148,10 +156,14 @@ def get_detectors_and_printers():
""" """
detectors = [getattr(all_detectors, name) for name in dir(all_detectors)] detectors = [getattr(all_detectors, name) for name in dir(all_detectors)]
detectors = [d for d in detectors if inspect.isclass(d) and issubclass(d, AbstractDetector)] detectors = [
d for d in detectors if inspect.isclass(d) and issubclass(d, AbstractDetector)
]
printers = [getattr(all_printers, name) for name in dir(all_printers)] printers = [getattr(all_printers, name) for name in dir(all_printers)]
printers = [p for p in printers if inspect.isclass(p) and issubclass(p, AbstractPrinter)] printers = [
p for p in printers if inspect.isclass(p) and issubclass(p, AbstractPrinter)
]
# Handle plugins! # Handle plugins!
for entry_point in iter_entry_points(group="slither_analyzer.plugin", name=None): for entry_point in iter_entry_points(group="slither_analyzer.plugin", name=None):
@ -159,11 +171,16 @@ def get_detectors_and_printers():
plugin_detectors, plugin_printers = make_plugin() plugin_detectors, plugin_printers = make_plugin()
if not all(issubclass(d, AbstractDetector) for d in plugin_detectors): detector = None
raise Exception("Error when loading plugin %s, %r is not a detector" % (entry_point, d)) if not all(issubclass(detector, AbstractDetector) for detector in plugin_detectors):
raise Exception(
if not all(issubclass(p, AbstractPrinter) for p in plugin_printers): "Error when loading plugin %s, %r is not a detector" % (entry_point, detector)
raise Exception("Error when loading plugin %s, %r is not a printer" % (entry_point, p)) )
printer = None
if not all(issubclass(printer, AbstractPrinter) for printer in plugin_printers):
raise Exception(
"Error when loading plugin %s, %r is not a printer" % (entry_point, printer)
)
# We convert those to lists in case someone returns a tuple # We convert those to lists in case someone returns a tuple
detectors += list(plugin_detectors) detectors += list(plugin_detectors)
@ -171,7 +188,7 @@ def get_detectors_and_printers():
return detectors, printers return detectors, printers
# pylint: disable=too-many-branches
def choose_detectors(args, all_detector_classes): def choose_detectors(args, all_detector_classes):
# If detectors are specified, run only these ones # If detectors are specified, run only these ones
@ -182,35 +199,43 @@ def choose_detectors(args, all_detector_classes):
detectors_to_run = all_detector_classes detectors_to_run = all_detector_classes
if args.detectors_to_exclude: if args.detectors_to_exclude:
detectors_excluded = args.detectors_to_exclude.split(",") detectors_excluded = args.detectors_to_exclude.split(",")
for d in detectors: for detector in detectors:
if d in detectors_excluded: if detector in detectors_excluded:
detectors_to_run.remove(detectors[d]) detectors_to_run.remove(detectors[detector])
else: else:
for d in args.detectors_to_run.split(","): for detector in args.detectors_to_run.split(","):
if d in detectors: if detector in detectors:
detectors_to_run.append(detectors[d]) detectors_to_run.append(detectors[detector])
else: else:
raise Exception("Error: {} is not a detector".format(d)) raise Exception("Error: {} is not a detector".format(detector))
detectors_to_run = sorted(detectors_to_run, key=lambda x: x.IMPACT) detectors_to_run = sorted(detectors_to_run, key=lambda x: x.IMPACT)
return detectors_to_run return detectors_to_run
if args.exclude_optimization: if args.exclude_optimization:
detectors_to_run = [ detectors_to_run = [
d for d in detectors_to_run if d.IMPACT != DetectorClassification.OPTIMIZATION d
for d in detectors_to_run
if d.IMPACT != DetectorClassification.OPTIMIZATION
] ]
if args.exclude_informational: if args.exclude_informational:
detectors_to_run = [ detectors_to_run = [
d for d in detectors_to_run if d.IMPACT != DetectorClassification.INFORMATIONAL d
for d in detectors_to_run
if d.IMPACT != DetectorClassification.INFORMATIONAL
] ]
if args.exclude_low: if args.exclude_low:
detectors_to_run = [d for d in detectors_to_run if d.IMPACT != DetectorClassification.LOW] detectors_to_run = [
d for d in detectors_to_run if d.IMPACT != DetectorClassification.LOW
]
if args.exclude_medium: if args.exclude_medium:
detectors_to_run = [ detectors_to_run = [
d for d in detectors_to_run if d.IMPACT != DetectorClassification.MEDIUM d for d in detectors_to_run if d.IMPACT != DetectorClassification.MEDIUM
] ]
if args.exclude_high: if args.exclude_high:
detectors_to_run = [d for d in detectors_to_run if d.IMPACT != DetectorClassification.HIGH] detectors_to_run = [
d for d in detectors_to_run if d.IMPACT != DetectorClassification.HIGH
]
if args.detectors_to_exclude: if args.detectors_to_exclude:
detectors_to_run = [ detectors_to_run = [
d for d in detectors_to_run if d.ARGUMENT not in args.detectors_to_exclude d for d in detectors_to_run if d.ARGUMENT not in args.detectors_to_exclude
@ -232,11 +257,11 @@ def choose_printers(args, all_printer_classes):
return all_printer_classes return all_printer_classes
printers = {p.ARGUMENT: p for p in all_printer_classes} printers = {p.ARGUMENT: p for p in all_printer_classes}
for p in args.printers_to_run.split(","): for printer in args.printers_to_run.split(","):
if p in printers: if printer in printers:
printers_to_run.append(printers[p]) printers_to_run.append(printers[printer])
else: else:
raise Exception("Error: {} is not a printer".format(p)) raise Exception("Error: {} is not a printer".format(printer))
return printers_to_run return printers_to_run
@ -278,7 +303,9 @@ def parse_args(detector_classes, printer_classes):
group_detector.add_argument( group_detector.add_argument(
"--detect", "--detect",
help="Comma-separated list of detectors, defaults to all, " help="Comma-separated list of detectors, defaults to all, "
"available detectors: {}".format(", ".join(d.ARGUMENT for d in detector_classes)), "available detectors: {}".format(
", ".join(d.ARGUMENT for d in detector_classes)
),
action="store", action="store",
dest="detectors_to_run", dest="detectors_to_run",
default=defaults_flag_in_config["detectors_to_run"], default=defaults_flag_in_config["detectors_to_run"],
@ -287,7 +314,7 @@ def parse_args(detector_classes, printer_classes):
group_printer.add_argument( group_printer.add_argument(
"--print", "--print",
help="Comma-separated list fo contract information printers, " help="Comma-separated list fo contract information printers, "
"available printers: {}".format(", ".join(d.ARGUMENT for d in printer_classes)), "available printers: {}".format(", ".join(d.ARGUMENT for d in printer_classes)),
action="store", action="store",
dest="printers_to_run", dest="printers_to_run",
default=defaults_flag_in_config["printers_to_run"], default=defaults_flag_in_config["printers_to_run"],
@ -368,9 +395,9 @@ def parse_args(detector_classes, printer_classes):
group_misc.add_argument( group_misc.add_argument(
"--json-types", "--json-types",
help=f"Comma-separated list of result types to output to JSON, defaults to " help="Comma-separated list of result types to output to JSON, defaults to "
+ f'{",".join(output_type for output_type in DEFAULT_JSON_OUTPUT_TYPES)}. ' + f'{",".join(output_type for output_type in DEFAULT_JSON_OUTPUT_TYPES)}. '
+ f'Available types: {",".join(output_type for output_type in JSON_OUTPUT_TYPES)}', + f'Available types: {",".join(output_type for output_type in JSON_OUTPUT_TYPES)}',
action="store", action="store",
default=defaults_flag_in_config["json-types"], default=defaults_flag_in_config["json-types"],
) )
@ -390,7 +417,10 @@ def parse_args(detector_classes, printer_classes):
) )
group_misc.add_argument( group_misc.add_argument(
"--markdown-root", help="URL for markdown generation", action="store", default="" "--markdown-root",
help="URL for markdown generation",
action="store",
default="",
) )
group_misc.add_argument( group_misc.add_argument(
@ -425,7 +455,10 @@ def parse_args(detector_classes, printer_classes):
) )
group_misc.add_argument( group_misc.add_argument(
"--solc-ast", help="Provide the contract as a json AST", action="store_true", default=False "--solc-ast",
help="Provide the contract as a json AST",
action="store_true",
default=False,
) )
group_misc.add_argument( group_misc.add_argument(
@ -436,9 +469,13 @@ def parse_args(detector_classes, printer_classes):
) )
# debugger command # debugger command
parser.add_argument("--debug", help=argparse.SUPPRESS, action="store_true", default=False) parser.add_argument(
"--debug", help=argparse.SUPPRESS, action="store_true", default=False
)
parser.add_argument("--markdown", help=argparse.SUPPRESS, action=OutputMarkdown, default=False) parser.add_argument(
"--markdown", help=argparse.SUPPRESS, action=OutputMarkdown, default=False
)
group_misc.add_argument( group_misc.add_argument(
"--checklist", help=argparse.SUPPRESS, action="store_true", default=False "--checklist", help=argparse.SUPPRESS, action="store_true", default=False
@ -471,7 +508,9 @@ def parse_args(detector_classes, printer_classes):
) )
# if the json is splitted in different files # if the json is splitted in different files
parser.add_argument("--splitted", help=argparse.SUPPRESS, action="store_true", default=False) parser.add_argument(
"--splitted", help=argparse.SUPPRESS, action="store_true", default=False
)
# Disable the throw/catch on partial analyses # Disable the throw/catch on partial analyses
parser.add_argument( parser.add_argument(
@ -491,41 +530,43 @@ def parse_args(detector_classes, printer_classes):
args.json_types = set(args.json_types.split(",")) args.json_types = set(args.json_types.split(","))
for json_type in args.json_types: for json_type in args.json_types:
if json_type not in JSON_OUTPUT_TYPES: if json_type not in JSON_OUTPUT_TYPES:
raise Exception(f'Error: "{json_type}" is not a valid JSON result output type.') raise Exception(
f'Error: "{json_type}" is not a valid JSON result output type.'
)
return args return args
class ListDetectors(argparse.Action): class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs
detectors, _ = get_detectors_and_printers() detectors, _ = get_detectors_and_printers()
output_detectors(detectors) output_detectors(detectors)
parser.exit() parser.exit()
class ListDetectorsJson(argparse.Action): class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs
detectors, _ = get_detectors_and_printers() detectors, _ = get_detectors_and_printers()
detector_types_json = output_detectors_json(detectors) detector_types_json = output_detectors_json(detectors)
print(json.dumps(detector_types_json)) print(json.dumps(detector_types_json))
parser.exit() parser.exit()
class ListPrinters(argparse.Action): class ListPrinters(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs
_, printers = get_detectors_and_printers() _, printers = get_detectors_and_printers()
output_printers(printers) output_printers(printers)
parser.exit() parser.exit()
class OutputMarkdown(argparse.Action): class OutputMarkdown(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, args, values, option_string=None): def __call__(self, parser, args, values, option_string=None):
detectors, printers = get_detectors_and_printers() detectors, printers = get_detectors_and_printers()
output_to_markdown(detectors, printers, values) output_to_markdown(detectors, printers, values)
parser.exit() parser.exit()
class OutputWiki(argparse.Action): class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, args, values, option_string=None): def __call__(self, parser, args, values, option_string=None):
detectors, _ = get_detectors_and_printers() detectors, _ = get_detectors_and_printers()
output_wiki(detectors, values) output_wiki(detectors, values)
@ -569,6 +610,7 @@ def main():
main_impl(all_detector_classes=detectors, all_printer_classes=printers) main_impl(all_detector_classes=detectors, all_printer_classes=printers)
# pylint: disable=too-many-statements,too-many-branches,too-many-locals
def main_impl(all_detector_classes, all_printer_classes): def main_impl(all_detector_classes, all_printer_classes):
""" """
:param all_detector_classes: A list of all detectors that can be included/excluded. :param all_detector_classes: A list of all detectors that can be included/excluded.
@ -588,9 +630,8 @@ def main_impl(all_detector_classes, all_printer_classes):
outputting_json_stdout = args.json == "-" outputting_json_stdout = args.json == "-"
outputting_zip = args.zip is not None outputting_zip = args.zip is not None
if args.zip_type not in ZIP_TYPES_ACCEPTED.keys(): if args.zip_type not in ZIP_TYPES_ACCEPTED.keys():
logger.error( to_log = f'Zip type not accepted, it must be one of {",".join(ZIP_TYPES_ACCEPTED.keys())}'
f'Zip type not accepted, it must be one of {",".join(ZIP_TYPES_ACCEPTED.keys())}' logger.error(to_log)
)
# If we are outputting JSON, capture all standard output. If we are outputting to stdout, we block typical stdout # If we are outputting JSON, capture all standard output. If we are outputting to stdout, we block typical stdout
# output. # output.
@ -616,8 +657,8 @@ def main_impl(all_detector_classes, all_printer_classes):
("Printers", default_log), ("Printers", default_log),
# ('CryticCompile', default_log) # ('CryticCompile', default_log)
]: ]:
l = logging.getLogger(l_name) logger_level = logging.getLogger(l_name)
l.setLevel(l_level) logger_level.setLevel(l_level)
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO) console_handler.setLevel(logging.INFO)
@ -649,7 +690,9 @@ def main_impl(all_detector_classes, all_printer_classes):
results_detectors, results_detectors,
results_printers, results_printers,
number_contracts, number_contracts,
) = process_from_asts(filenames, args, detector_classes, printer_classes) ) = process_from_asts(
filenames, args, detector_classes, printer_classes
)
slither_instances.append(slither_instance) slither_instances.append(slither_instance)
else: else:
for filename in filenames: for filename in filenames:
@ -658,7 +701,9 @@ def main_impl(all_detector_classes, all_printer_classes):
results_detectors_tmp, results_detectors_tmp,
results_printers_tmp, results_printers_tmp,
number_contracts_tmp, number_contracts_tmp,
) = process_single(filename, args, detector_classes, printer_classes) ) = process_single(
filename, args, detector_classes, printer_classes
)
number_contracts += number_contracts_tmp number_contracts += number_contracts_tmp
results_detectors += results_detectors_tmp results_detectors += results_detectors_tmp
results_printers += results_printers_tmp results_printers += results_printers_tmp
@ -728,17 +773,19 @@ def main_impl(all_detector_classes, all_printer_classes):
if args.ignore_return_value: if args.ignore_return_value:
return return
except SlitherException as se: except SlitherException as slither_exception:
output_error = str(se) output_error = str(slither_exception)
traceback.print_exc() traceback.print_exc()
logging.error(red("Error:")) logging.error(red("Error:"))
logging.error(red(output_error)) logging.error(red(output_error))
logging.error("Please report an issue to https://github.com/crytic/slither/issues") logging.error(
"Please report an issue to https://github.com/crytic/slither/issues"
)
except Exception: except Exception: # pylint: disable=broad-except
output_error = traceback.format_exc() output_error = traceback.format_exc()
logging.error(traceback.print_exc()) logging.error(traceback.print_exc())
logging.error("Error in %s" % args.filename) logging.error(f"Error in {args.filename}") # pylint: disable=logging-fstring-interpolation
logging.error(output_error) logging.error(output_error)
# If we are outputting JSON, capture the redirected output and disable the redirect to output the final JSON. # If we are outputting JSON, capture the redirected output and disable the redirect to output the final JSON.
@ -749,7 +796,9 @@ def main_impl(all_detector_classes, all_printer_classes):
"stderr": StandardOutputCapture.get_stderr_output(), "stderr": StandardOutputCapture.get_stderr_output(),
} }
StandardOutputCapture.disable() StandardOutputCapture.disable()
output_to_json(None if outputting_json_stdout else args.json, output_error, json_results) output_to_json(
None if outputting_json_stdout else args.json, output_error, json_results
)
if outputting_zip: if outputting_zip:
output_to_zip(args.zip, output_error, json_results, args.zip_type) output_to_zip(args.zip, output_error, json_results, args.zip_type)
@ -758,7 +807,7 @@ def main_impl(all_detector_classes, all_printer_classes):
if output_error: if output_error:
sys.exit(-1) sys.exit(-1)
else: else:
exit(results_detectors) my_exit(results_detectors)
if __name__ == "__main__": if __name__ == "__main__":

@ -1,6 +1,7 @@
""" """
This module import all slither exceptions This module import all slither exceptions
""" """
# pylint: disable=unused-import
from slither.slithir.exceptions import SlithIRError from slither.slithir.exceptions import SlithIRError
from slither.solc_parsing.exceptions import ParsingError, VariableNotFound from slither.solc_parsing.exceptions import ParsingError, VariableNotFound
from slither.core.exceptions import SlitherCoreError from slither.core.exceptions import SlitherCoreError

@ -112,7 +112,9 @@ def is_tainted(variable, context, only_unprotected=False, ignore_generic_taint=F
) )
def is_tainted_ssa(variable, context, only_unprotected=False, ignore_generic_taint=False): def is_tainted_ssa(
variable, context, only_unprotected=False, ignore_generic_taint=False
):
""" """
Args: Args:
variable variable
@ -135,7 +137,9 @@ def is_tainted_ssa(variable, context, only_unprotected=False, ignore_generic_tai
def get_dependencies( def get_dependencies(
variable: Variable, context: Union[Contract, Function], only_unprotected: bool = False variable: Variable,
context: Union[Contract, Function],
only_unprotected: bool = False,
) -> Set[Variable]: ) -> Set[Variable]:
""" """
Return the variables for which `variable` depends on. Return the variables for which `variable` depends on.
@ -170,7 +174,9 @@ def get_all_dependencies(
def get_dependencies_ssa( def get_dependencies_ssa(
variable: Variable, context: Union[Contract, Function], only_unprotected: bool = False variable: Variable,
context: Union[Contract, Function],
only_unprotected: bool = False,
) -> Set[Variable]: ) -> Set[Variable]:
""" """
Return the variables for which `variable` depends on (SSA version). Return the variables for which `variable` depends on (SSA version).
@ -272,8 +278,11 @@ def compute_dependency_contract(contract, slither):
compute_dependency_function(function) compute_dependency_function(function)
propagate_function(contract, function, KEY_SSA, KEY_NON_SSA) propagate_function(contract, function, KEY_SSA, KEY_NON_SSA)
propagate_function(contract, function, KEY_SSA_UNPROTECTED, KEY_NON_SSA_UNPROTECTED) propagate_function(
contract, function, KEY_SSA_UNPROTECTED, KEY_NON_SSA_UNPROTECTED
)
# pylint: disable=expression-not-assigned
if function.visibility in ["public", "external"]: if function.visibility in ["public", "external"]:
[slither.context[KEY_INPUT].add(p) for p in function.parameters] [slither.context[KEY_INPUT].add(p) for p in function.parameters]
[slither.context[KEY_INPUT_SSA].add(p) for p in function.parameters_ssa] [slither.context[KEY_INPUT_SSA].add(p) for p in function.parameters_ssa]
@ -296,11 +305,12 @@ def propagate_function(contract, function, context_key, context_key_non_ssa):
def transitive_close_dependencies(context, context_key, context_key_non_ssa): def transitive_close_dependencies(context, context_key, context_key_non_ssa):
# transitive closure # transitive closure
changed = True changed = True
while changed: while changed: # pylint: disable=too-many-nested-blocks
changed = False changed = False
# Need to create new set() as its changed during iteration # Need to create new set() as its changed during iteration
data_depencencies = { data_depencencies = {
k: set([v for v in values]) for k, values in context.context[context_key].items() k: set(values)
for k, values in context.context[context_key].items()
} }
for key, items in data_depencencies.items(): for key, items in data_depencencies.items():
for item in items: for item in items:
@ -310,7 +320,9 @@ def transitive_close_dependencies(context, context_key, context_key_non_ssa):
if not additional_item in items and additional_item != key: if not additional_item in items and additional_item != key:
changed = True changed = True
context.context[context_key][key].add(additional_item) context.context[context_key][key].add(additional_item)
context.context[context_key_non_ssa] = convert_to_non_ssa(context.context[context_key]) context.context[context_key_non_ssa] = convert_to_non_ssa(
context.context[context_key]
)
def propagate_contract(contract, context_key, context_key_non_ssa): def propagate_contract(contract, context_key, context_key_non_ssa):
@ -328,7 +340,12 @@ def add_dependency(lvalue, function, ir, is_protected):
read = ir.function.return_values_ssa read = ir.function.return_values_ssa
else: else:
read = ir.read read = ir.read
[function.context[KEY_SSA][lvalue].add(v) for v in read if not isinstance(v, Constant)] # pylint: disable=expression-not-assigned
[
function.context[KEY_SSA][lvalue].add(v)
for v in read
if not isinstance(v, Constant)
]
if not is_protected: if not is_protected:
[ [
function.context[KEY_SSA_UNPROTECTED][lvalue].add(v) function.context[KEY_SSA_UNPROTECTED][lvalue].add(v)
@ -375,7 +392,17 @@ def convert_variable_to_non_ssa(v):
): ):
return v.non_ssa_version return v.non_ssa_version
assert isinstance( assert isinstance(
v, (Constant, SolidityVariable, Contract, Enum, SolidityFunction, Structure, Function, Type) v,
(
Constant,
SolidityVariable,
Contract,
Enum,
SolidityFunction,
Structure,
Function,
Type,
),
) )
return v return v
@ -387,6 +414,6 @@ def convert_to_non_ssa(data_depencies):
var = convert_variable_to_non_ssa(k) var = convert_variable_to_non_ssa(k)
if not var in ret: if not var in ret:
ret[var] = set() ret[var] = set()
ret[var] = ret[var].union(set([convert_variable_to_non_ssa(v) for v in values])) ret[var] = ret[var].union({convert_variable_to_non_ssa(v) for v in values})
return ret return ret

@ -3,7 +3,7 @@ from slither.core.declarations import Contract, Function
from slither.core.cfg.node import Node from slither.core.cfg.node import Node
from slither.utils.function import get_function_id from slither.utils.function import get_function_id
from slither.exceptions import SlitherError from slither.exceptions import SlitherError
from .evm_cfg_builder import load_evm_cfg_builder from slither.analyses.evm.evm_cfg_builder import load_evm_cfg_builder
logger = logging.getLogger("ConvertToEVM") logger = logging.getLogger("ConvertToEVM")
@ -101,16 +101,17 @@ def _get_evm_instructions_function(function_info):
# Todo: Could rename it appropriately in evm-cfg-builder # Todo: Could rename it appropriately in evm-cfg-builder
# by detecting that init bytecode is being parsed. # by detecting that init bytecode is being parsed.
name = "_dispatcher" name = "_dispatcher"
hash = "" func_hash = ""
else: else:
cfg = function_info["contract_info"]["cfg"] cfg = function_info["contract_info"]["cfg"]
name = function.name name = function.name
# Get first four bytes of function singature's keccak-256 hash used as function selector # Get first four bytes of function singature's keccak-256 hash used as function selector
hash = str(hex(get_function_id(function.full_name))) func_hash = str(hex(get_function_id(function.full_name)))
function_evm = _get_function_evm(cfg, name, hash) function_evm = _get_function_evm(cfg, name, func_hash)
if function_evm is None: if function_evm is None:
logger.error("Function " + function.name + " not found in the EVM code") to_log = "Function " + function.name + " not found in the EVM code"
logger.error(to_log)
raise SlitherError("Function " + function.name + " not found in the EVM code") raise SlitherError("Function " + function.name + " not found in the EVM code")
function_ins = [] function_ins = []
@ -137,7 +138,10 @@ def _get_evm_instructions_node(node_info):
# Get evm instructions corresponding to node's source line number # Get evm instructions corresponding to node's source line number
node_source_line = ( node_source_line = (
contract_file[0 : node_info["node"].source_mapping["start"]].count("\n".encode("utf-8")) + 1 contract_file[0 : node_info["node"].source_mapping["start"]].count(
"\n".encode("utf-8")
)
+ 1
) )
node_pcs = contract_pcs.get(node_source_line, []) node_pcs = contract_pcs.get(node_source_line, [])
node_ins = [] node_ins = []
@ -153,14 +157,19 @@ def _get_function_evm(cfg, function_name, function_hash):
if function_evm.name[:2] == "0x" and function_evm.name == function_hash: if function_evm.name[:2] == "0x" and function_evm.name == function_hash:
return function_evm return function_evm
# Match function name # Match function name
elif function_evm.name[:2] != "0x" and function_evm.name.split("(")[0] == function_name: if (
function_evm.name[:2] != "0x"
and function_evm.name.split("(")[0] == function_name
):
return function_evm return function_evm
return None return None
# pylint: disable=too-many-locals
def generate_source_to_evm_ins_mapping(evm_instructions, srcmap_runtime, slither, filename): def generate_source_to_evm_ins_mapping(
evm_instructions, srcmap_runtime, slither, filename
):
""" """
Generate Solidity source to EVM instruction mapping using evm_cfg_builder:cfg.instructions Generate Solidity source to EVM instruction mapping using evm_cfg_builder:cfg.instructions
and solc:srcmap_runtime and solc:srcmap_runtime
Returns: Solidity source to EVM instruction mapping Returns: Solidity source to EVM instruction mapping
@ -180,11 +189,11 @@ def generate_source_to_evm_ins_mapping(evm_instructions, srcmap_runtime, slither
mapping_item = mapping.split(":") mapping_item = mapping.split(":")
mapping_item += prev_mapping[len(mapping_item) :] mapping_item += prev_mapping[len(mapping_item) :]
for i in range(len(mapping_item)): for i, _ in enumerate(mapping_item):
if mapping_item[i] == "": if mapping_item[i] == "":
mapping_item[i] = int(prev_mapping[i]) mapping_item[i] = int(prev_mapping[i])
offset, length, file_id, _ = mapping_item offset, _length, file_id, _ = mapping_item
prev_mapping = mapping_item prev_mapping = mapping_item
if file_id == "-1": if file_id == "-1":
@ -198,6 +207,8 @@ def generate_source_to_evm_ins_mapping(evm_instructions, srcmap_runtime, slither
# Append evm instructions to the corresponding source line number # Append evm instructions to the corresponding source line number
# Note: Some evm instructions in mapping are not necessarily in program execution order # Note: Some evm instructions in mapping are not necessarily in program execution order
# Note: The order depends on how solc creates the srcmap_runtime # Note: The order depends on how solc creates the srcmap_runtime
source_to_evm_mapping.setdefault(line_number, []).append(evm_instructions[idx].pc) source_to_evm_mapping.setdefault(line_number, []).append(
evm_instructions[idx].pc
)
return source_to_evm_mapping return source_to_evm_mapping

@ -7,6 +7,7 @@ logger = logging.getLogger("ConvertToEVM")
def load_evm_cfg_builder(): def load_evm_cfg_builder():
try: try:
# Avoiding the addition of evm_cfg_builder as permanent dependency # Avoiding the addition of evm_cfg_builder as permanent dependency
# pylint: disable=import-outside-toplevel
from evm_cfg_builder.cfg import CFG from evm_cfg_builder.cfg import CFG
return CFG return CFG

@ -2,7 +2,7 @@
Detect if all the given variables are written in all the paths of the function Detect if all the given variables are written in all the paths of the function
""" """
from collections import defaultdict from collections import defaultdict
from typing import Dict, Tuple, Set, List, Optional from typing import Dict, Set, List
from slither.core.cfg.node import NodeType, Node from slither.core.cfg.node import NodeType, Node
from slither.core.declarations import SolidityFunction from slither.core.declarations import SolidityFunction
@ -18,7 +18,7 @@ from slither.slithir.operations import (
from slither.slithir.variables import ReferenceVariable, TemporaryVariable from slither.slithir.variables import ReferenceVariable, TemporaryVariable
class State: class State: # pylint: disable=too-few-public-methods
def __init__(self): def __init__(self):
# Map node -> list of variables set # Map node -> list of variables set
# Were each variables set represents a configuration of a path # Were each variables set represents a configuration of a path
@ -33,8 +33,12 @@ class State:
self.nodes: Dict[Node, List[Set[Variable]]] = defaultdict(list) self.nodes: Dict[Node, List[Set[Variable]]] = defaultdict(list)
# pylint: disable=too-many-branches
def _visit( def _visit(
node: Node, state: State, variables_written: Set[Variable], variables_to_write: List[Variable] node: Node,
state: State,
variables_written: Set[Variable],
variables_to_write: List[Variable],
): ):
""" """
Explore all the nodes to look for values not written when the node's function return Explore all the nodes to look for values not written when the node's function return
@ -51,7 +55,10 @@ def _visit(
for ir in node.irs: for ir in node.irs:
if isinstance(ir, SolidityCall): if isinstance(ir, SolidityCall):
# TODO convert the revert to a THROW node # TODO convert the revert to a THROW node
if ir.function in [SolidityFunction("revert(string)"), SolidityFunction("revert()")]: if ir.function in [
SolidityFunction("revert(string)"),
SolidityFunction("revert()"),
]:
return [] return []
if not isinstance(ir, OperationWithLValue): if not isinstance(ir, OperationWithLValue):
@ -61,7 +68,9 @@ def _visit(
if isinstance(ir, (Length, Balance)): if isinstance(ir, (Length, Balance)):
refs[ir.lvalue] = ir.value refs[ir.lvalue] = ir.value
if ir.lvalue and not isinstance(ir.lvalue, (TemporaryVariable, ReferenceVariable)): if ir.lvalue and not isinstance(
ir.lvalue, (TemporaryVariable, ReferenceVariable)
):
variables_written.add(ir.lvalue) variables_written.add(ir.lvalue)
lvalue = ir.lvalue lvalue = ir.lvalue

@ -56,6 +56,8 @@ if TYPE_CHECKING:
) )
# pylint: disable=too-many-lines,too-many-branches,too-many-instance-attributes
################################################################################### ###################################################################################
################################################################################### ###################################################################################
# region NodeType # region NodeType
@ -140,8 +142,9 @@ class NodeType(Enum):
# endregion # endregion
# I am not sure why, but pylint reports a lot of "no-member" issue that are not real (Josselin)
class Node(SourceMapping, ChildFunction): # pylint: disable=no-member
class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-methods
""" """
Node class Node class
@ -166,8 +169,12 @@ class Node(SourceMapping, ChildFunction):
self._dominance_frontier: Set["Node"] = set() self._dominance_frontier: Set["Node"] = set()
# Phi origin # Phi origin
# key are variable name # key are variable name
self._phi_origins_state_variables: Dict[str, Tuple[StateVariable, Set["Node"]]] = {} self._phi_origins_state_variables: Dict[
self._phi_origins_local_variables: Dict[str, Tuple[LocalVariable, Set["Node"]]] = {} str, Tuple[StateVariable, Set["Node"]]
] = {}
self._phi_origins_local_variables: Dict[
str, Tuple[LocalVariable, Set["Node"]]
] = {}
# self._phi_origins_member_variables: Dict[str, Tuple[MemberVariable, Set["Node"]]] = {} # self._phi_origins_member_variables: Dict[str, Tuple[MemberVariable, Set["Node"]]] = {}
self._expression: Optional[Expression] = None self._expression: Optional[Expression] = None
@ -180,7 +187,7 @@ class Node(SourceMapping, ChildFunction):
self._ssa_vars_written: List["SlithIRVariable"] = [] self._ssa_vars_written: List["SlithIRVariable"] = []
self._ssa_vars_read: List["SlithIRVariable"] = [] self._ssa_vars_read: List["SlithIRVariable"] = []
self._internal_calls: List[Function] = [] self._internal_calls: List["Function"] = []
self._solidity_calls: List[SolidityFunction] = [] self._solidity_calls: List[SolidityFunction] = []
self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls
self._library_calls: List["LibraryCallType"] = [] self._library_calls: List["LibraryCallType"] = []
@ -457,6 +464,7 @@ class Node(SourceMapping, ChildFunction):
:param callstack: used internally to check for recursion :param callstack: used internally to check for recursion
:return bool: :return bool:
""" """
# pylint: disable=import-outside-toplevel
from slither.slithir.operations import Call from slither.slithir.operations import Call
if self._can_reenter is None: if self._can_reenter is None:
@ -472,6 +480,7 @@ class Node(SourceMapping, ChildFunction):
Check if the node can send eth Check if the node can send eth
:return bool: :return bool:
""" """
# pylint: disable=import-outside-toplevel
from slither.slithir.operations import Call from slither.slithir.operations import Call
if self._can_send_eth is None: if self._can_send_eth is None:
@ -712,7 +721,9 @@ class Node(SourceMapping, ChildFunction):
@staticmethod @staticmethod
def _is_non_slithir_var(var: Variable): def _is_non_slithir_var(var: Variable):
return not isinstance(var, (Constant, ReferenceVariable, TemporaryVariable, TupleVariable)) return not isinstance(
var, (Constant, ReferenceVariable, TemporaryVariable, TupleVariable)
)
@staticmethod @staticmethod
def _is_valid_slithir_var(var: Variable): def _is_valid_slithir_var(var: Variable):
@ -793,11 +804,15 @@ class Node(SourceMapping, ChildFunction):
################################################################################### ###################################################################################
@property @property
def phi_origins_local_variables(self) -> Dict[str, Tuple[LocalVariable, Set["Node"]]]: def phi_origins_local_variables(
self,
) -> Dict[str, Tuple[LocalVariable, Set["Node"]]]:
return self._phi_origins_local_variables return self._phi_origins_local_variables
@property @property
def phi_origins_state_variables(self) -> Dict[str, Tuple[StateVariable, Set["Node"]]]: def phi_origins_state_variables(
self,
) -> Dict[str, Tuple[StateVariable, Set["Node"]]]:
return self._phi_origins_state_variables return self._phi_origins_state_variables
# @property # @property
@ -835,11 +850,12 @@ class Node(SourceMapping, ChildFunction):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def _find_read_write_call(self): def _find_read_write_call(self): # pylint: disable=too-many-statements
for ir in self.irs: for ir in self.irs:
self._slithir_vars |= set([v for v in ir.read if self._is_valid_slithir_var(v)]) self._slithir_vars |= {v for v in ir.read if self._is_valid_slithir_var(v)}
if isinstance(ir, OperationWithLValue): if isinstance(ir, OperationWithLValue):
var = ir.lvalue var = ir.lvalue
if var and self._is_valid_slithir_var(var): if var and self._is_valid_slithir_var(var):
@ -884,7 +900,9 @@ class Node(SourceMapping, ChildFunction):
self._high_level_calls.append((self.function.contract, ir.function)) self._high_level_calls.append((self.function.contract, ir.function))
else: else:
try: try:
self._high_level_calls.append((ir.destination.type.type, ir.function)) self._high_level_calls.append(
(ir.destination.type.type, ir.function)
)
except AttributeError: except AttributeError:
raise SlitherException( raise SlitherException(
f"Function not found on {ir}. Please try compiling with a recent Solidity version." f"Function not found on {ir}. Please try compiling with a recent Solidity version."
@ -895,12 +913,22 @@ class Node(SourceMapping, ChildFunction):
self._library_calls.append((ir.destination, ir.function)) self._library_calls.append((ir.destination, ir.function))
self._vars_read = list(set(self._vars_read)) self._vars_read = list(set(self._vars_read))
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._state_vars_read = [
self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] v for v in self._vars_read if isinstance(v, StateVariable)
self._solidity_vars_read = [v for v in self._vars_read if isinstance(v, SolidityVariable)] ]
self._local_vars_read = [
v for v in self._vars_read if isinstance(v, LocalVariable)
]
self._solidity_vars_read = [
v for v in self._vars_read if isinstance(v, SolidityVariable)
]
self._vars_written = list(set(self._vars_written)) self._vars_written = list(set(self._vars_written))
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._state_vars_written = [
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] v for v in self._vars_written if isinstance(v, StateVariable)
]
self._local_vars_written = [
v for v in self._vars_written if isinstance(v, LocalVariable)
]
self._internal_calls = list(set(self._internal_calls)) self._internal_calls = list(set(self._internal_calls))
self._solidity_calls = list(set(self._solidity_calls)) self._solidity_calls = list(set(self._solidity_calls))
self._high_level_calls = list(set(self._high_level_calls)) self._high_level_calls = list(set(self._high_level_calls))
@ -926,7 +954,9 @@ class Node(SourceMapping, ChildFunction):
continue continue
if not isinstance(ir, (Phi, Index, Member)): if not isinstance(ir, (Phi, Index, Member)):
self._ssa_vars_read += [ self._ssa_vars_read += [
v for v in ir.read if isinstance(v, (StateIRVariable, LocalIRVariable)) v
for v in ir.read
if isinstance(v, (StateIRVariable, LocalIRVariable))
] ]
for var in ir.read: for var in ir.read:
if isinstance(var, ReferenceVariable): if isinstance(var, ReferenceVariable):
@ -954,8 +984,12 @@ class Node(SourceMapping, ChildFunction):
continue continue
self._ssa_vars_written.append(var) self._ssa_vars_written.append(var)
self._ssa_vars_read = list(set(self._ssa_vars_read)) self._ssa_vars_read = list(set(self._ssa_vars_read))
self._ssa_state_vars_read = [v for v in self._ssa_vars_read if isinstance(v, StateVariable)] self._ssa_state_vars_read = [
self._ssa_local_vars_read = [v for v in self._ssa_vars_read if isinstance(v, LocalVariable)] v for v in self._ssa_vars_read if isinstance(v, StateVariable)
]
self._ssa_local_vars_read = [
v for v in self._ssa_vars_read if isinstance(v, LocalVariable)
]
self._ssa_vars_written = list(set(self._ssa_vars_written)) self._ssa_vars_written = list(set(self._ssa_vars_written))
self._ssa_state_vars_written = [ self._ssa_state_vars_written = [
v for v in self._ssa_vars_written if isinstance(v, StateVariable) v for v in self._ssa_vars_written if isinstance(v, StateVariable)
@ -968,12 +1002,20 @@ class Node(SourceMapping, ChildFunction):
vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written] vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written]
self._vars_read += [v for v in vars_read if v not in self._vars_read] self._vars_read += [v for v in vars_read if v not in self._vars_read]
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._state_vars_read = [
self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] v for v in self._vars_read if isinstance(v, StateVariable)
]
self._local_vars_read = [
v for v in self._vars_read if isinstance(v, LocalVariable)
]
self._vars_written += [v for v in vars_written if v not in self._vars_written] self._vars_written += [v for v in vars_written if v not in self._vars_written]
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._state_vars_written = [
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] v for v in self._vars_written if isinstance(v, StateVariable)
]
self._local_vars_written = [
v for v in self._vars_written if isinstance(v, LocalVariable)
]
# endregion # endregion
################################################################################### ###################################################################################
@ -1024,11 +1066,11 @@ def recheable(node: Node) -> Set[Node]:
nodes = node.sons nodes = node.sons
visited = set() visited = set()
while nodes: while nodes:
next = nodes[0] next_node = nodes[0]
nodes = nodes[1:] nodes = nodes[1:]
if next not in visited: if next_node not in visited:
visited.add(next) visited.add(next_node)
for son in next.sons: for son in next_node.sons:
if son not in visited: if son not in visited:
nodes.append(son) nodes.append(son)
return visited return visited

@ -6,7 +6,7 @@ if TYPE_CHECKING:
from slither.core.declarations import Function, Contract from slither.core.declarations import Function, Contract
class ChildNode(object): class ChildNode:
def __init__(self): def __init__(self):
super(ChildNode, self).__init__() super(ChildNode, self).__init__()
self._node = None self._node = None

@ -2,7 +2,7 @@ from collections import defaultdict
from typing import Dict from typing import Dict
class Context: class Context: # pylint: disable=too-few-public-methods
def __init__(self): def __init__(self):
super(Context, self).__init__() super(Context, self).__init__()
self._context = {"MEMBERS": defaultdict(None)} self._context = {"MEMBERS": defaultdict(None)}

@ -5,5 +5,9 @@ from .function import Function
from .import_directive import Import from .import_directive import Import
from .modifier import Modifier from .modifier import Modifier
from .pragma_directive import Pragma from .pragma_directive import Pragma
from .solidity_variables import SolidityVariable, SolidityVariableComposed, SolidityFunction from .solidity_variables import (
SolidityVariable,
SolidityVariableComposed,
SolidityFunction,
)
from .structure import Structure from .structure import Structure

@ -3,9 +3,9 @@
""" """
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Optional, List, Dict, Callable, Tuple, TYPE_CHECKING, Union
from crytic_compile.platform import Type as PlatformType from crytic_compile.platform import Type as PlatformType
from typing import Optional, List, Dict, Callable, Tuple, TYPE_CHECKING, Union
from slither.core.children.child_slither import ChildSlither from slither.core.children.child_slither import ChildSlither
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
@ -22,6 +22,7 @@ from slither.utils.erc import (
) )
from slither.utils.tests_pattern import is_test_contract from slither.utils.tests_pattern import is_test_contract
# pylint: disable=too-many-lines,too-many-instance-attributes,import-outside-toplevel,too-many-nested-blocks
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.utils.type_helpers import LibraryCallType, HighLevelCallType from slither.utils.type_helpers import LibraryCallType, HighLevelCallType
from slither.core.declarations import Enum, Event, Modifier from slither.core.declarations import Enum, Event, Modifier
@ -33,7 +34,7 @@ if TYPE_CHECKING:
LOGGER = logging.getLogger("Contract") LOGGER = logging.getLogger("Contract")
class Contract(ChildSlither, SourceMapping): class Contract(ChildSlither, SourceMapping): # pylint: disable=too-many-public-methods
""" """
Contract class Contract class
""" """
@ -43,7 +44,9 @@ class Contract(ChildSlither, SourceMapping):
self._name: Optional[str] = None self._name: Optional[str] = None
self._id: Optional[int] = None self._id: Optional[int] = None
self._inheritance: List["Contract"] = [] # all contract inherited, c3 linearization self._inheritance: List[
"Contract"
] = [] # all contract inherited, c3 linearization
self._immediate_inheritance: List["Contract"] = [] # immediate inheritance self._immediate_inheritance: List["Contract"] = [] # immediate inheritance
# Constructors called on contract's definition # Constructors called on contract's definition
@ -338,7 +341,11 @@ class Contract(ChildSlither, SourceMapping):
On "contract B is A(){..}" it returns the constructor of A On "contract B is A(){..}" it returns the constructor of A
""" """
return [c.constructor for c in self._explicit_base_constructor_calls if c.constructor] return [
c.constructor
for c in self._explicit_base_constructor_calls
if c.constructor
]
# endregion # endregion
################################################################################### ###################################################################################
@ -355,12 +362,16 @@ class Contract(ChildSlither, SourceMapping):
""" """
if self._signatures is None: if self._signatures is None:
sigs = [ sigs = [
v.full_name for v in self.state_variables if v.visibility in ["public", "external"] v.full_name
for v in self.state_variables
if v.visibility in ["public", "external"]
] ]
sigs += set( sigs += {
[f.full_name for f in self.functions if f.visibility in ["public", "external"]] f.full_name
) for f in self.functions
if f.visibility in ["public", "external"]
}
self._signatures = list(set(sigs)) self._signatures = list(set(sigs))
return self._signatures return self._signatures
@ -377,13 +388,11 @@ class Contract(ChildSlither, SourceMapping):
if v.visibility in ["public", "external"] if v.visibility in ["public", "external"]
] ]
sigs += set( sigs += {
[ f.full_name
f.full_name for f in self.functions_declared
for f in self.functions_declared if f.visibility in ["public", "external"]
if f.visibility in ["public", "external"] }
]
)
self._signatures_declared = list(set(sigs)) self._signatures_declared = list(set(sigs))
return self._signatures_declared return self._signatures_declared
@ -397,6 +406,9 @@ class Contract(ChildSlither, SourceMapping):
def available_functions_as_dict(self) -> Dict[str, "Function"]: def available_functions_as_dict(self) -> Dict[str, "Function"]:
return {f.full_name: f for f in self._functions.values() if not f.is_shadowed} return {f.full_name: f for f in self._functions.values() if not f.is_shadowed}
def add_function(self, func: "Function"):
self._functions[func.canonical_name] = func
def set_functions(self, functions: Dict[str, "Function"]): def set_functions(self, functions: Dict[str, "Function"]):
""" """
Set the functions Set the functions
@ -569,19 +581,25 @@ class Contract(ChildSlither, SourceMapping):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def get_functions_reading_from_variable(self, variable: "Variable") -> List["Function"]: def get_functions_reading_from_variable(
self, variable: "Variable"
) -> List["Function"]:
""" """
Return the functions reading the variable Return the functions reading the variable
""" """
return [f for f in self.functions if f.is_reading(variable)] return [f for f in self.functions if f.is_reading(variable)]
def get_functions_writing_to_variable(self, variable: "Variable") -> List["Function"]: def get_functions_writing_to_variable(
self, variable: "Variable"
) -> List["Function"]:
""" """
Return the functions writting the variable Return the functions writting the variable
""" """
return [f for f in self.functions if f.is_writing(variable)] return [f for f in self.functions if f.is_writing(variable)]
def get_function_from_signature(self, function_signature: str) -> Optional["Function"]: def get_function_from_signature(
self, function_signature: str
) -> Optional["Function"]:
""" """
Return a function from a signature Return a function from a signature
Args: Args:
@ -590,22 +608,34 @@ class Contract(ChildSlither, SourceMapping):
Function Function
""" """
return next( return next(
(f for f in self.functions if f.full_name == function_signature and not f.is_shadowed), (
f
for f in self.functions
if f.full_name == function_signature and not f.is_shadowed
),
None, None,
) )
def get_modifier_from_signature(self, modifier_signature: str) -> Optional["Modifier"]: def get_modifier_from_signature(
self, modifier_signature: str
) -> Optional["Modifier"]:
""" """
Return a modifier from a signature Return a modifier from a signature
:param modifier_signature: :param modifier_signature:
""" """
return next( return next(
(m for m in self.modifiers if m.full_name == modifier_signature and not m.is_shadowed), (
m
for m in self.modifiers
if m.full_name == modifier_signature and not m.is_shadowed
),
None, None,
) )
def get_function_from_canonical_name(self, canonical_name: str) -> Optional["Function"]: def get_function_from_canonical_name(
self, canonical_name: str
) -> Optional["Function"]:
""" """
Return a function from a a canonical name (contract.signature()) Return a function from a a canonical name (contract.signature())
Args: Args:
@ -613,9 +643,13 @@ class Contract(ChildSlither, SourceMapping):
Returns: Returns:
Function Function
""" """
return next((f for f in self.functions if f.canonical_name == canonical_name), None) return next(
(f for f in self.functions if f.canonical_name == canonical_name), None
)
def get_modifier_from_canonical_name(self, canonical_name: str) -> Optional["Modifier"]: def get_modifier_from_canonical_name(
self, canonical_name: str
) -> Optional["Modifier"]:
""" """
Return a modifier from a canonical name (contract.signature()) Return a modifier from a canonical name (contract.signature())
Args: Args:
@ -623,9 +657,13 @@ class Contract(ChildSlither, SourceMapping):
Returns: Returns:
Modifier Modifier
""" """
return next((m for m in self.modifiers if m.canonical_name == canonical_name), None) return next(
(m for m in self.modifiers if m.canonical_name == canonical_name), None
)
def get_state_variable_from_name(self, variable_name: str) -> Optional["StateVariable"]: def get_state_variable_from_name(
self, variable_name: str
) -> Optional["StateVariable"]:
""" """
Return a state variable from a name Return a state variable from a name
@ -655,7 +693,9 @@ class Contract(ChildSlither, SourceMapping):
""" """
return next((st for st in self.structures if st.name == structure_name), None) return next((st for st in self.structures if st.name == structure_name), None)
def get_structure_from_canonical_name(self, structure_name: str) -> Optional["Structure"]: def get_structure_from_canonical_name(
self, structure_name: str
) -> Optional["Structure"]:
""" """
Return a structure from a canonical name Return a structure from a canonical name
Args: Args:
@ -663,7 +703,9 @@ class Contract(ChildSlither, SourceMapping):
Returns: Returns:
Structure Structure
""" """
return next((st for st in self.structures if st.canonical_name == structure_name), None) return next(
(st for st in self.structures if st.canonical_name == structure_name), None
)
def get_event_from_signature(self, event_signature: str) -> Optional["Event"]: def get_event_from_signature(self, event_signature: str) -> Optional["Event"]:
""" """
@ -675,7 +717,9 @@ class Contract(ChildSlither, SourceMapping):
""" """
return next((e for e in self.events if e.full_name == event_signature), None) return next((e for e in self.events if e.full_name == event_signature), None)
def get_event_from_canonical_name(self, event_canonical_name: str) -> Optional["Event"]: def get_event_from_canonical_name(
self, event_canonical_name: str
) -> Optional["Event"]:
""" """
Return an event from a canonical name Return an event from a canonical name
Args: Args:
@ -683,7 +727,9 @@ class Contract(ChildSlither, SourceMapping):
Returns: Returns:
Event Event
""" """
return next((e for e in self.events if e.canonical_name == event_canonical_name), None) return next(
(e for e in self.events if e.canonical_name == event_canonical_name), None
)
def get_enum_from_name(self, enum_name: str) -> Optional["Enum"]: def get_enum_from_name(self, enum_name: str) -> Optional["Enum"]:
""" """
@ -775,7 +821,9 @@ class Contract(ChildSlither, SourceMapping):
list((Contract, Function): List all of the libraries func called list((Contract, Function): List all of the libraries func called
""" """
all_high_level_calls = [f.all_library_calls() for f in self.functions + self.modifiers] # type: ignore all_high_level_calls = [f.all_library_calls() for f in self.functions + self.modifiers] # type: ignore
all_high_level_calls = [item for sublist in all_high_level_calls for item in sublist] all_high_level_calls = [
item for sublist in all_high_level_calls for item in sublist
]
return list(set(all_high_level_calls)) return list(set(all_high_level_calls))
@property @property
@ -784,7 +832,9 @@ class Contract(ChildSlither, SourceMapping):
list((Contract, Function|Variable)): List all of the external high level calls list((Contract, Function|Variable)): List all of the external high level calls
""" """
all_high_level_calls = [f.all_high_level_calls() for f in self.functions + self.modifiers] # type: ignore all_high_level_calls = [f.all_high_level_calls() for f in self.functions + self.modifiers] # type: ignore
all_high_level_calls = [item for sublist in all_high_level_calls for item in sublist] all_high_level_calls = [
item for sublist in all_high_level_calls for item in sublist
]
return list(set(all_high_level_calls)) return list(set(all_high_level_calls))
# endregion # endregion
@ -804,10 +854,14 @@ class Contract(ChildSlither, SourceMapping):
(str, list, list, list, list): (name, inheritance, variables, fuction summaries, modifier summaries) (str, list, list, list, list): (name, inheritance, variables, fuction summaries, modifier summaries)
""" """
func_summaries = [ func_summaries = [
f.get_summary() for f in self.functions if (not f.is_shadowed or include_shadowed) f.get_summary()
for f in self.functions
if (not f.is_shadowed or include_shadowed)
] ]
modif_summaries = [ modif_summaries = [
f.get_summary() for f in self.modifiers if (not f.is_shadowed or include_shadowed) f.get_summary()
for f in self.modifiers
if (not f.is_shadowed or include_shadowed)
] ]
return ( return (
self.name, self.name,
@ -971,7 +1025,9 @@ class Contract(ChildSlither, SourceMapping):
def is_from_dependency(self) -> bool: def is_from_dependency(self) -> bool:
if self.slither.crytic_compile is None: if self.slither.crytic_compile is None:
return False return False
return self.slither.crytic_compile.is_dependency(self.source_mapping["filename_absolute"]) return self.slither.crytic_compile.is_dependency(
self.source_mapping["filename_absolute"]
)
# endregion # endregion
################################################################################### ###################################################################################
@ -991,7 +1047,9 @@ class Contract(ChildSlither, SourceMapping):
if self.name == "Migrations": if self.name == "Migrations":
paths = Path(self.source_mapping["filename_absolute"]).parts paths = Path(self.source_mapping["filename_absolute"]).parts
if len(paths) >= 2: if len(paths) >= 2:
return paths[-2] == "contracts" and paths[-1] == "migrations.sol" return (
paths[-2] == "contracts" and paths[-1] == "migrations.sol"
)
return False return False
@property @property
@ -1027,7 +1085,10 @@ class Contract(ChildSlither, SourceMapping):
else: else:
for c in self.inheritance + [self]: for c in self.inheritance + [self]:
# This might lead to false positive # This might lead to false positive
if "upgradeable" in c.name.lower() or "upgradable" in c.name.lower(): if (
"upgradeable" in c.name.lower()
or "upgradable" in c.name.lower()
):
self._is_upgradeable = True self._is_upgradeable = True
break break
return self._is_upgradeable return self._is_upgradeable
@ -1043,7 +1104,10 @@ class Contract(ChildSlither, SourceMapping):
if f.is_fallback: if f.is_fallback:
for node in f.all_nodes(): for node in f.all_nodes():
for ir in node.irs: for ir in node.irs:
if isinstance(ir, LowLevelCall) and ir.function_name == "delegatecall": if (
isinstance(ir, LowLevelCall)
and ir.function_name == "delegatecall"
):
self._is_upgradeable_proxy = True self._is_upgradeable_proxy = True
return self._is_upgradeable_proxy return self._is_upgradeable_proxy
if node.type == NodeType.ASSEMBLY: if node.type == NodeType.ASSEMBLY:
@ -1079,21 +1143,29 @@ class Contract(ChildSlither, SourceMapping):
if variable_candidate.expression and not variable_candidate.is_constant: if variable_candidate.expression and not variable_candidate.is_constant:
constructor_variable = Function() constructor_variable = Function()
constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES) constructor_variable.set_function_type(
FunctionType.CONSTRUCTOR_VARIABLES
)
constructor_variable.set_contract(self) constructor_variable.set_contract(self)
constructor_variable.set_contract_declarer(self) constructor_variable.set_contract_declarer(self)
constructor_variable.set_visibility("internal") constructor_variable.set_visibility("internal")
# For now, source mapping of the constructor variable is the whole contract # For now, source mapping of the constructor variable is the whole contract
# Could be improved with a targeted source mapping # Could be improved with a targeted source mapping
constructor_variable.set_offset(self.source_mapping, self.slither) constructor_variable.set_offset(self.source_mapping, self.slither)
self._functions[constructor_variable.canonical_name] = constructor_variable self._functions[
constructor_variable.canonical_name
] = constructor_variable
prev_node = self._create_node(constructor_variable, 0, variable_candidate) prev_node = self._create_node(
constructor_variable, 0, variable_candidate
)
variable_candidate.node_initialization = prev_node variable_candidate.node_initialization = prev_node
counter = 1 counter = 1
for v in self.state_variables[idx + 1 :]: for v in self.state_variables[idx + 1 :]:
if v.expression and not v.is_constant: if v.expression and not v.is_constant:
next_node = self._create_node(constructor_variable, counter, v) next_node = self._create_node(
constructor_variable, counter, v
)
v.node_initialization = next_node v.node_initialization = next_node
prev_node.add_son(next_node) prev_node.add_son(next_node)
next_node.add_father(prev_node) next_node.add_father(prev_node)
@ -1113,14 +1185,20 @@ class Contract(ChildSlither, SourceMapping):
# For now, source mapping of the constructor variable is the whole contract # For now, source mapping of the constructor variable is the whole contract
# Could be improved with a targeted source mapping # Could be improved with a targeted source mapping
constructor_variable.set_offset(self.source_mapping, self.slither) constructor_variable.set_offset(self.source_mapping, self.slither)
self._functions[constructor_variable.canonical_name] = constructor_variable self._functions[
constructor_variable.canonical_name
] = constructor_variable
prev_node = self._create_node(constructor_variable, 0, variable_candidate) prev_node = self._create_node(
constructor_variable, 0, variable_candidate
)
variable_candidate.node_initialization = prev_node variable_candidate.node_initialization = prev_node
counter = 1 counter = 1
for v in self.state_variables[idx + 1 :]: for v in self.state_variables[idx + 1 :]:
if v.expression and v.is_constant: if v.expression and v.is_constant:
next_node = self._create_node(constructor_variable, counter, v) next_node = self._create_node(
constructor_variable, counter, v
)
v.node_initialization = next_node v.node_initialization = next_node
prev_node.add_son(next_node) prev_node.add_son(next_node)
next_node.add_father(prev_node) next_node.add_father(prev_node)
@ -1142,7 +1220,10 @@ class Contract(ChildSlither, SourceMapping):
node.set_function(func) node.set_function(func)
func.add_node(node) func.add_node(node)
expression = AssignmentOperation( expression = AssignmentOperation(
Identifier(variable), variable.expression, AssignmentOperationType.ASSIGN, variable.type Identifier(variable),
variable.expression,
AssignmentOperationType.ASSIGN,
variable.type,
) )
expression.set_offset(variable.source_mapping, self.slither) expression.set_offset(variable.source_mapping, self.slither)
@ -1194,7 +1275,9 @@ class Contract(ChildSlither, SourceMapping):
last_state_variables_instances[variable_name] += instances last_state_variables_instances[variable_name] += instances
for func in self.functions + self.modifiers: for func in self.functions + self.modifiers:
func.fix_phi(last_state_variables_instances, initial_state_variables_instances) func.fix_phi(
last_state_variables_instances, initial_state_variables_instances
)
@property @property
def is_top_level(self) -> bool: def is_top_level(self) -> bool:

@ -14,7 +14,12 @@ from slither.core.declarations.solidity_variables import (
SolidityVariable, SolidityVariable,
SolidityVariableComposed, SolidityVariableComposed,
) )
from slither.core.expressions import Identifier, IndexAccess, MemberAccess, UnaryOperation from slither.core.expressions import (
Identifier,
IndexAccess,
MemberAccess,
UnaryOperation,
)
from slither.core.solidity_types import UserDefinedType from slither.core.solidity_types import UserDefinedType
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
@ -23,6 +28,8 @@ from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.utils.utils import unroll from slither.utils.utils import unroll
# pylint: disable=import-outside-toplevel,too-many-instance-attributes,too-many-statements,too-many-lines
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.utils.type_helpers import ( from slither.utils.type_helpers import (
InternalCallType, InternalCallType,
@ -46,7 +53,10 @@ ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"])
class ModifierStatements: class ModifierStatements:
def __init__( def __init__(
self, modifier: Union["Contract", "Function"], entry_point: "Node", nodes: List["Node"] self,
modifier: Union["Contract", "Function"],
entry_point: "Node",
nodes: List["Node"],
): ):
self._modifier = modifier self._modifier = modifier
self._entry_point = entry_point self._entry_point = entry_point
@ -79,10 +89,26 @@ class FunctionType(Enum):
FALLBACK = 2 FALLBACK = 2
RECEIVE = 3 RECEIVE = 3
CONSTRUCTOR_VARIABLES = 10 # Fake function to hold variable declaration statements CONSTRUCTOR_VARIABLES = 10 # Fake function to hold variable declaration statements
CONSTRUCTOR_CONSTANT_VARIABLES = 11 # Fake function to hold variable declaration statements CONSTRUCTOR_CONSTANT_VARIABLES = (
11 # Fake function to hold variable declaration statements
)
class Function(ChildContract, ChildInheritance, SourceMapping): def _filter_state_variables_written(expressions: List["Expression"]):
ret = []
for expression in expressions:
if isinstance(expression, Identifier):
ret.append(expression)
if isinstance(expression, UnaryOperation):
ret.append(expression.expression)
if isinstance(expression, MemberAccess):
ret.append(expression.expression)
if isinstance(expression, IndexAccess):
ret.append(expression.expression_left)
return ret
class Function(ChildContract, ChildInheritance, SourceMapping): # pylint: disable=too-many-public-methods
""" """
Function class Function class
""" """
@ -147,13 +173,21 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
self._all_state_variables_written: Optional[List["StateVariable"]] = None self._all_state_variables_written: Optional[List["StateVariable"]] = None
self._all_slithir_variables: Optional[List["SlithIRVariable"]] = None self._all_slithir_variables: Optional[List["SlithIRVariable"]] = None
self._all_nodes: Optional[List["Node"]] = None self._all_nodes: Optional[List["Node"]] = None
self._all_conditional_state_variables_read: Optional[List["StateVariable"]] = None self._all_conditional_state_variables_read: Optional[
self._all_conditional_state_variables_read_with_loop: Optional[List["StateVariable"]] = None List["StateVariable"]
self._all_conditional_solidity_variables_read: Optional[List["SolidityVariable"]] = None ] = None
self._all_conditional_state_variables_read_with_loop: Optional[
List["StateVariable"]
] = None
self._all_conditional_solidity_variables_read: Optional[
List["SolidityVariable"]
] = None
self._all_conditional_solidity_variables_read_with_loop: Optional[ self._all_conditional_solidity_variables_read_with_loop: Optional[
List["SolidityVariable"] List["SolidityVariable"]
] = None ] = None
self._all_solidity_variables_used_as_args: Optional[List["SolidityVariable"]] = None self._all_solidity_variables_used_as_args: Optional[
List["SolidityVariable"]
] = None
self._is_shadowed: bool = False self._is_shadowed: bool = False
self._shadows: bool = False self._shadows: bool = False
@ -187,13 +221,13 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
""" """
if self._name == "" and self._function_type == FunctionType.CONSTRUCTOR: if self._name == "" and self._function_type == FunctionType.CONSTRUCTOR:
return "constructor" return "constructor"
elif self._function_type == FunctionType.FALLBACK: if self._function_type == FunctionType.FALLBACK:
return "fallback" return "fallback"
elif self._function_type == FunctionType.RECEIVE: if self._function_type == FunctionType.RECEIVE:
return "receive" return "receive"
elif self._function_type == FunctionType.CONSTRUCTOR_VARIABLES: if self._function_type == FunctionType.CONSTRUCTOR_VARIABLES:
return "slitherConstructorVariables" return "slitherConstructorVariables"
elif self._function_type == FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES: if self._function_type == FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES:
return "slitherConstructorConstantVariables" return "slitherConstructorConstantVariables"
return self._name return self._name
@ -815,15 +849,13 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
if self._return_values is None: if self._return_values is None:
return_values = list() return_values = list()
returns = [n for n in self.nodes if n.type == NodeType.RETURN] returns = [n for n in self.nodes if n.type == NodeType.RETURN]
[ [ # pylint: disable=expression-not-assigned
return_values.extend(ir.values) return_values.extend(ir.values)
for node in returns for node in returns
for ir in node.irs for ir in node.irs
if isinstance(ir, Return) if isinstance(ir, Return)
] ]
self._return_values = list( self._return_values = list({x for x in return_values if not isinstance(x, Constant)})
set([x for x in return_values if not isinstance(x, Constant)])
)
return self._return_values return self._return_values
@property @property
@ -838,15 +870,13 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
if self._return_values_ssa is None: if self._return_values_ssa is None:
return_values_ssa = list() return_values_ssa = list()
returns = [n for n in self.nodes if n.type == NodeType.RETURN] returns = [n for n in self.nodes if n.type == NodeType.RETURN]
[ [ # pylint: disable=expression-not-assigned
return_values_ssa.extend(ir.values) return_values_ssa.extend(ir.values)
for node in returns for node in returns
for ir in node.irs_ssa for ir in node.irs_ssa
if isinstance(ir, Return) if isinstance(ir, Return)
] ]
self._return_values_ssa = list( self._return_values_ssa = list({x for x in return_values_ssa if not isinstance(x, Constant)})
set([x for x in return_values_ssa if not isinstance(x, Constant)])
)
return self._return_values_ssa return self._return_values_ssa
# endregion # endregion
@ -900,7 +930,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
Contract and converted into address Contract and converted into address
:return: the solidity signature :return: the solidity signature
""" """
parameters = [self._convert_type_for_solidity_signature(x.type) for x in self.parameters] parameters = [
self._convert_type_for_solidity_signature(x.type) for x in self.parameters
]
return self.name + "(" + ",".join(parameters) + ")" return self.name + "(" + ",".join(parameters) + ")"
@property @property
@ -922,7 +954,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
Return the function signature as a str (contains the return values) Return the function signature as a str (contains the return values)
""" """
name, parameters, returnVars = self.signature name, parameters, returnVars = self.signature
return name + "(" + ",".join(parameters) + ") returns(" + ",".join(returnVars) + ")" return (
name
+ "("
+ ",".join(parameters)
+ ") returns("
+ ",".join(returnVars)
+ ")"
)
# endregion # endregion
################################################################################### ###################################################################################
@ -977,10 +1016,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
values = f_new_values(self) values = f_new_values(self)
explored = [self] explored = [self]
to_explore = [ to_explore = [
c for c in self.internal_calls if isinstance(c, Function) and c not in explored c
for c in self.internal_calls
if isinstance(c, Function) and c not in explored
] ]
to_explore += [ to_explore += [
c for (_, c) in self.library_calls if isinstance(c, Function) and c not in explored c
for (_, c) in self.library_calls
if isinstance(c, Function) and c not in explored
] ]
to_explore += [m for m in self.modifiers if m not in explored] to_explore += [m for m in self.modifiers if m not in explored]
@ -1003,7 +1046,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
for (_, c) in f.library_calls for (_, c) in f.library_calls
if isinstance(c, Function) and c not in explored and c not in to_explore if isinstance(c, Function) and c not in explored and c not in to_explore
] ]
to_explore += [m for m in f.modifiers if m not in explored and m not in to_explore] to_explore += [
m for m in f.modifiers if m not in explored and m not in to_explore
]
return list(set(values)) return list(set(values))
@ -1029,7 +1074,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
""" recursive version of slithir_variables """ recursive version of slithir_variables
""" """
if self._all_slithir_variables is None: if self._all_slithir_variables is None:
self._all_slithir_variables = self._explore_functions(lambda x: x.slithir_variables) self._all_slithir_variables = self._explore_functions(
lambda x: x.slithir_variables
)
return self._all_slithir_variables return self._all_slithir_variables
def all_nodes(self) -> List["Node"]: def all_nodes(self) -> List["Node"]:
@ -1047,10 +1094,10 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
return self._all_expressions return self._all_expressions
def all_slithir_operations(self) -> List["Operation"]: def all_slithir_operations(self) -> List["Operation"]:
"""
"""
if self._all_slithir_operations is None: if self._all_slithir_operations is None:
self._all_slithir_operations = self._explore_functions(lambda x: x.slithir_operations) self._all_slithir_operations = self._explore_functions(
lambda x: x.slithir_operations
)
return self._all_slithir_operations return self._all_slithir_operations
def all_state_variables_written(self) -> List[StateVariable]: def all_state_variables_written(self) -> List[StateVariable]:
@ -1066,21 +1113,27 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
""" recursive version of internal_calls """ recursive version of internal_calls
""" """
if self._all_internals_calls is None: if self._all_internals_calls is None:
self._all_internals_calls = self._explore_functions(lambda x: x.internal_calls) self._all_internals_calls = self._explore_functions(
lambda x: x.internal_calls
)
return self._all_internals_calls return self._all_internals_calls
def all_low_level_calls(self) -> List["LowLevelCallType"]: def all_low_level_calls(self) -> List["LowLevelCallType"]:
""" recursive version of low_level calls """ recursive version of low_level calls
""" """
if self._all_low_level_calls is None: if self._all_low_level_calls is None:
self._all_low_level_calls = self._explore_functions(lambda x: x.low_level_calls) self._all_low_level_calls = self._explore_functions(
lambda x: x.low_level_calls
)
return self._all_low_level_calls return self._all_low_level_calls
def all_high_level_calls(self) -> List["HighLevelCallType"]: def all_high_level_calls(self) -> List["HighLevelCallType"]:
""" recursive version of high_level calls """ recursive version of high_level calls
""" """
if self._all_high_level_calls is None: if self._all_high_level_calls is None:
self._all_high_level_calls = self._explore_functions(lambda x: x.high_level_calls) self._all_high_level_calls = self._explore_functions(
lambda x: x.high_level_calls
)
return self._all_high_level_calls return self._all_high_level_calls
def all_library_calls(self) -> List["LibraryCallType"]: def all_library_calls(self) -> List["LibraryCallType"]:
@ -1094,15 +1147,23 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
""" recursive version of solidity calls """ recursive version of solidity calls
""" """
if self._all_solidity_calls is None: if self._all_solidity_calls is None:
self._all_solidity_calls = self._explore_functions(lambda x: x.solidity_calls) self._all_solidity_calls = self._explore_functions(
lambda x: x.solidity_calls
)
return self._all_solidity_calls return self._all_solidity_calls
@staticmethod @staticmethod
def _explore_func_cond_read(func: "Function", include_loop: bool) -> List["StateVariable"]: def _explore_func_cond_read(
ret = [n.state_variables_read for n in func.nodes if n.is_conditional(include_loop)] func: "Function", include_loop: bool
) -> List["StateVariable"]:
ret = [
n.state_variables_read for n in func.nodes if n.is_conditional(include_loop)
]
return [item for sublist in ret for item in sublist] return [item for sublist in ret for item in sublist]
def all_conditional_state_variables_read(self, include_loop=True) -> List["StateVariable"]: def all_conditional_state_variables_read(
self, include_loop=True
) -> List["StateVariable"]:
""" """
Return the state variable used in a condition Return the state variable used in a condition
@ -1133,12 +1194,16 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
@staticmethod @staticmethod
def _explore_func_conditional( def _explore_func_conditional(
func: "Function", f: Callable[["Node"], List[SolidityVariable]], include_loop: bool func: "Function",
f: Callable[["Node"], List[SolidityVariable]],
include_loop: bool,
): ):
ret = [f(n) for n in func.nodes if n.is_conditional(include_loop)] ret = [f(n) for n in func.nodes if n.is_conditional(include_loop)]
return [item for sublist in ret for item in sublist] return [item for sublist in ret for item in sublist]
def all_conditional_solidity_variables_read(self, include_loop=True) -> List[SolidityVariable]: def all_conditional_solidity_variables_read(
self, include_loop=True
) -> List[SolidityVariable]:
""" """
Return the Soldiity variables directly used in a condtion Return the Soldiity variables directly used in a condtion
@ -1174,7 +1239,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
return [var for var in ret if isinstance(var, SolidityVariable)] return [var for var in ret if isinstance(var, SolidityVariable)]
@staticmethod @staticmethod
def _explore_func_nodes(func: "Function", f: Callable[["Node"], List[SolidityVariable]]): def _explore_func_nodes(
func: "Function", f: Callable[["Node"], List[SolidityVariable]]
):
ret = [f(n) for n in func.nodes] ret = [f(n) for n in func.nodes]
return [item for sublist in ret for item in sublist] return [item for sublist in ret for item in sublist]
@ -1187,7 +1254,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
""" """
if self._all_solidity_variables_used_as_args is None: if self._all_solidity_variables_used_as_args is None:
self._all_solidity_variables_used_as_args = self._explore_functions( self._all_solidity_variables_used_as_args = self._explore_functions(
lambda x: self._explore_func_nodes(x, self._solidity_variable_in_internal_calls) lambda x: self._explore_func_nodes(
x, self._solidity_variable_in_internal_calls
)
) )
return self._all_solidity_variables_used_as_args return self._all_solidity_variables_used_as_args
@ -1217,7 +1286,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def get_local_variable_from_name(self, variable_name: str) -> Optional[LocalVariable]: def get_local_variable_from_name(
self, variable_name: str
) -> Optional[LocalVariable]:
""" """
Return a local variable from a name Return a local variable from a name
@ -1271,7 +1342,11 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
for node in self.nodes: for node in self.nodes:
f.write('{}[label="{}"];\n'.format(node.node_id, description(node))) f.write('{}[label="{}"];\n'.format(node.node_id, description(node)))
if node.immediate_dominator: if node.immediate_dominator:
f.write("{}->{};\n".format(node.immediate_dominator.node_id, node.node_id)) f.write(
"{}->{};\n".format(
node.immediate_dominator.node_id, node.node_id
)
)
f.write("}\n") f.write("}\n")
@ -1305,10 +1380,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
if node.type in [NodeType.IF, NodeType.IFLOOP]: if node.type in [NodeType.IF, NodeType.IFLOOP]:
true_node = node.son_true true_node = node.son_true
if true_node: if true_node:
content += '{}->{}[label="True"];\n'.format(node.node_id, true_node.node_id) content += '{}->{}[label="True"];\n'.format(
node.node_id, true_node.node_id
)
false_node = node.son_false false_node = node.son_false
if false_node: if false_node:
content += '{}->{}[label="False"];\n'.format(node.node_id, false_node.node_id) content += '{}->{}[label="False"];\n'.format(
node.node_id, false_node.node_id
)
else: else:
for son in node.sons: for son in node.sons:
content += "{}->{};\n".format(node.node_id, son.node_id) content += "{}->{};\n".format(node.node_id, son.node_id)
@ -1353,7 +1432,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
Returns: Returns:
bool: True if the variable is read bool: True if the variable is read
""" """
variables_reads = [n.variables_read for n in self.nodes if n.contains_require_or_assert()] variables_reads = [
n.variables_read for n in self.nodes if n.contains_require_or_assert()
]
variables_read = [item for sublist in variables_reads for item in sublist] variables_read = [item for sublist in variables_reads for item in sublist]
return variable in variables_read return variable in variables_read
@ -1401,7 +1482,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
if self.is_constructor: if self.is_constructor:
return True return True
conditional_vars = self.all_conditional_solidity_variables_read(include_loop=False) conditional_vars = self.all_conditional_solidity_variables_read(
include_loop=False
)
args_vars = self.all_solidity_variables_used_as_args() args_vars = self.all_solidity_variables_used_as_args()
return SolidityVariableComposed("msg.sender") in conditional_vars + args_vars return SolidityVariableComposed("msg.sender") in conditional_vars + args_vars
@ -1412,19 +1495,6 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def _filter_state_variables_written(self, expressions: List["Expression"]):
ret = []
for expression in expressions:
if isinstance(expression, Identifier):
ret.append(expression)
if isinstance(expression, UnaryOperation):
ret.append(expression.expression)
if isinstance(expression, MemberAccess):
ret.append(expression.expression)
if isinstance(expression, IndexAccess):
ret.append(expression.expression_left)
return ret
def _analyze_read_write(self): def _analyze_read_write(self):
""" Compute variables read/written/... """ Compute variables read/written/...
@ -1436,7 +1506,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
# Remove dupplicate if they share the same string representation # Remove dupplicate if they share the same string representation
write_var = [ write_var = [
next(obj) next(obj)
for i, obj in groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x)) for i, obj in groupby(
sorted(write_var, key=lambda x: str(x)), lambda x: str(x)
)
] ]
self._expression_vars_written = write_var self._expression_vars_written = write_var
@ -1447,7 +1519,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
# Remove dupplicate if they share the same string representation # Remove dupplicate if they share the same string representation
write_var = [ write_var = [
next(obj) next(obj)
for i, obj in groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x)) for i, obj in groupby(
sorted(write_var, key=lambda x: str(x)), lambda x: str(x)
)
] ]
self._vars_written = write_var self._vars_written = write_var
@ -1457,7 +1531,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
# Remove dupplicate if they share the same string representation # Remove dupplicate if they share the same string representation
read_var = [ read_var = [
next(obj) next(obj)
for i, obj in groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x)) for i, obj in groupby(
sorted(read_var, key=lambda x: str(x)), lambda x: str(x)
)
] ]
self._expression_vars_read = read_var self._expression_vars_read = read_var
@ -1467,14 +1543,18 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
# Remove dupplicate if they share the same string representation # Remove dupplicate if they share the same string representation
read_var = [ read_var = [
next(obj) next(obj)
for i, obj in groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x)) for i, obj in groupby(
sorted(read_var, key=lambda x: str(x)), lambda x: str(x)
)
] ]
self._vars_read = read_var self._vars_read = read_var
self._state_vars_written = [ self._state_vars_written = [
x for x in self.variables_written if isinstance(x, StateVariable) x for x in self.variables_written if isinstance(x, StateVariable)
] ]
self._state_vars_read = [x for x in self.variables_read if isinstance(x, StateVariable)] self._state_vars_read = [
x for x in self.variables_read if isinstance(x, StateVariable)
]
self._solidity_vars_read = [ self._solidity_vars_read = [
x for x in self.variables_read if isinstance(x, SolidityVariable) x for x in self.variables_read if isinstance(x, SolidityVariable)
] ]
@ -1483,7 +1563,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
slithir_variables = [x.slithir_variables for x in self.nodes] slithir_variables = [x.slithir_variables for x in self.nodes]
slithir_variables = [x for x in slithir_variables if x] slithir_variables = [x for x in slithir_variables if x]
self._slithir_variables = [item for sublist in slithir_variables for item in sublist] self._slithir_variables = [
item for sublist in slithir_variables for item in sublist
]
def _analyze_calls(self): def _analyze_calls(self):
calls = [x.calls_as_expression for x in self.nodes] calls = [x.calls_as_expression for x in self.nodes]
@ -1496,7 +1578,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
internal_calls = [item for sublist in internal_calls for item in sublist] internal_calls = [item for sublist in internal_calls for item in sublist]
self._internal_calls = list(set(internal_calls)) self._internal_calls = list(set(internal_calls))
self._solidity_calls = [c for c in internal_calls if isinstance(c, SolidityFunction)] self._solidity_calls = [
c for c in internal_calls if isinstance(c, SolidityFunction)
]
low_level_calls = [x.low_level_calls for x in self.nodes] low_level_calls = [x.low_level_calls for x in self.nodes]
low_level_calls = [x for x in low_level_calls if x] low_level_calls = [x for x in low_level_calls if x]
@ -1513,7 +1597,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
library_calls = [item for sublist in library_calls for item in sublist] library_calls = [item for sublist in library_calls for item in sublist]
self._library_calls = list(set(library_calls)) self._library_calls = list(set(library_calls))
external_calls_as_expressions = [x.external_calls_as_expressions for x in self.nodes] external_calls_as_expressions = [
x.external_calls_as_expressions for x in self.nodes
]
external_calls_as_expressions = [x for x in external_calls_as_expressions if x] external_calls_as_expressions = [x for x in external_calls_as_expressions if x]
external_calls_as_expressions = [ external_calls_as_expressions = [
item for sublist in external_calls_as_expressions for item in sublist item for sublist in external_calls_as_expressions for item in sublist
@ -1548,6 +1634,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
def _get_last_ssa_variable_instances( def _get_last_ssa_variable_instances(
self, target_state: bool, target_local: bool self, target_state: bool, target_local: bool
) -> Dict[str, Set["SlithIRVariable"]]: ) -> Dict[str, Set["SlithIRVariable"]]:
# pylint: disable=too-many-locals,too-many-branches
from slither.slithir.variables import ReferenceVariable from slither.slithir.variables import ReferenceVariable
from slither.slithir.operations import OperationWithLValue from slither.slithir.operations import OperationWithLValue
from slither.core.cfg.node import NodeType from slither.core.cfg.node import NodeType
@ -1603,11 +1690,19 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
return ret return ret
def get_last_ssa_state_variables_instances(self) -> Dict[str, Set["SlithIRVariable"]]: def get_last_ssa_state_variables_instances(
return self._get_last_ssa_variable_instances(target_state=True, target_local=False) self,
) -> Dict[str, Set["SlithIRVariable"]]:
return self._get_last_ssa_variable_instances(
target_state=True, target_local=False
)
def get_last_ssa_local_variables_instances(self) -> Dict[str, Set["SlithIRVariable"]]: def get_last_ssa_local_variables_instances(
return self._get_last_ssa_variable_instances(target_state=False, target_local=True) self,
) -> Dict[str, Set["SlithIRVariable"]]:
return self._get_last_ssa_variable_instances(
target_state=False, target_local=True
)
@staticmethod @staticmethod
def _unchange_phi(ir: "Operation"): def _unchange_phi(ir: "Operation"):
@ -1619,7 +1714,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
return True return True
return ir.rvalues[0] == ir.lvalue return ir.rvalues[0] == ir.lvalue
def fix_phi(self, last_state_variables_instances, initial_state_variables_instances): def fix_phi(
self, last_state_variables_instances, initial_state_variables_instances
):
from slither.slithir.operations import InternalCall, PhiCallback from slither.slithir.operations import InternalCall, PhiCallback
from slither.slithir.variables import Constant, StateIRVariable from slither.slithir.variables import Constant, StateIRVariable
@ -1627,28 +1724,40 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
for ir in node.irs_ssa: for ir in node.irs_ssa:
if node == self.entry_point: if node == self.entry_point:
if isinstance(ir.lvalue, StateIRVariable): if isinstance(ir.lvalue, StateIRVariable):
additional = [initial_state_variables_instances[ir.lvalue.canonical_name]] additional = [
additional += last_state_variables_instances[ir.lvalue.canonical_name] initial_state_variables_instances[ir.lvalue.canonical_name]
]
additional += last_state_variables_instances[
ir.lvalue.canonical_name
]
ir.rvalues = list(set(additional + ir.rvalues)) ir.rvalues = list(set(additional + ir.rvalues))
# function parameter # function parameter
else: else:
# find index of the parameter # find index of the parameter
idx = self.parameters.index(ir.lvalue.non_ssa_version) idx = self.parameters.index(ir.lvalue.non_ssa_version)
# find non ssa version of that index # find non ssa version of that index
additional = [n.ir.arguments[idx] for n in self.reachable_from_nodes] additional = [
n.ir.arguments[idx] for n in self.reachable_from_nodes
]
additional = unroll(additional) additional = unroll(additional)
additional = [a for a in additional if not isinstance(a, Constant)] additional = [
a for a in additional if not isinstance(a, Constant)
]
ir.rvalues = list(set(additional + ir.rvalues)) ir.rvalues = list(set(additional + ir.rvalues))
if isinstance(ir, PhiCallback): if isinstance(ir, PhiCallback):
callee_ir = ir.callee_ir callee_ir = ir.callee_ir
if isinstance(callee_ir, InternalCall): if isinstance(callee_ir, InternalCall):
last_ssa = callee_ir.function.get_last_ssa_state_variables_instances() last_ssa = (
callee_ir.function.get_last_ssa_state_variables_instances()
)
if ir.lvalue.canonical_name in last_ssa: if ir.lvalue.canonical_name in last_ssa:
ir.rvalues = list(last_ssa[ir.lvalue.canonical_name]) ir.rvalues = list(last_ssa[ir.lvalue.canonical_name])
else: else:
ir.rvalues = [ir.lvalue] ir.rvalues = [ir.lvalue]
else: else:
additional = last_state_variables_instances[ir.lvalue.canonical_name] additional = last_state_variables_instances[
ir.lvalue.canonical_name
]
ir.rvalues = list(set(additional + ir.rvalues)) ir.rvalues = list(set(additional + ir.rvalues))
node.irs_ssa = [ir for ir in node.irs_ssa if not self._unchange_phi(ir)] node.irs_ssa = [ir for ir in node.irs_ssa if not self._unchange_phi(ir)]
@ -1662,7 +1771,10 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
def generate_slithir_ssa(self, all_ssa_state_variables_instances): def generate_slithir_ssa(self, all_ssa_state_variables_instances):
from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa
from slither.core.dominators.utils import compute_dominance_frontier, compute_dominators from slither.core.dominators.utils import (
compute_dominance_frontier,
compute_dominators,
)
compute_dominators(self.nodes) compute_dominators(self.nodes)
compute_dominance_frontier(self.nodes) compute_dominance_frontier(self.nodes)

@ -32,7 +32,10 @@ class Pragma(SourceMapping):
@property @property
def is_abi_encoder_v2(self) -> bool: def is_abi_encoder_v2(self) -> bool:
if len(self._directive) == 2: if len(self._directive) == 2:
return self._directive[0] == "experimental" and self._directive[1] == "ABIEncoderV2" return (
self._directive[0] == "experimental"
and self._directive[1] == "ABIEncoderV2"
)
return False return False
def __str__(self): def __str__(self):

@ -3,6 +3,7 @@ from typing import List, Dict, Union
from slither.core.context.context import Context from slither.core.context.context import Context
from slither.core.solidity_types import ElementaryType, TypeInformation from slither.core.solidity_types import ElementaryType, TypeInformation
from slither.exceptions import SlitherException
SOLIDITY_VARIABLES = { SOLIDITY_VARIABLES = {
"now": "uint256", "now": "uint256",
@ -90,8 +91,12 @@ class SolidityVariable(Context):
self._name = name self._name = name
# dev function, will be removed once the code is stable # dev function, will be removed once the code is stable
def _check_name(self, name: str): def _check_name(self, name: str): # pylint: disable=no-self-use
assert name in SOLIDITY_VARIABLES or name.endswith("_slot") or name.endswith("_offset") assert (
name in SOLIDITY_VARIABLES
or name.endswith("_slot")
or name.endswith("_offset")
)
@property @property
def state_variable(self): def state_variable(self):
@ -99,6 +104,8 @@ class SolidityVariable(Context):
return self._name[:-5] return self._name[:-5]
if self._name.endswith("_offset"): if self._name.endswith("_offset"):
return self._name[:-7] return self._name[:-7]
to_log = f"Incorrect YUL parsing. {self} is not a solidity variable that can be seen as a state variable"
raise SlitherException(to_log)
@property @property
def name(self) -> str: def name(self) -> str:
@ -119,8 +126,6 @@ class SolidityVariable(Context):
class SolidityVariableComposed(SolidityVariable): class SolidityVariableComposed(SolidityVariable):
def __init__(self, name: str):
super(SolidityVariableComposed, self).__init__(name)
def _check_name(self, name: str): def _check_name(self, name: str):
assert name in SOLIDITY_VARIABLES_COMPOSED assert name in SOLIDITY_VARIABLES_COMPOSED

@ -7,7 +7,7 @@ if TYPE_CHECKING:
from slither.core.cfg.node import Node from slither.core.cfg.node import Node
class DominatorNode(object): class DominatorNode:
def __init__(self): def __init__(self):
self._succ: Set["Node"] = set() self._succ: Set["Node"] = set()
self._nodes: List["Node"] = [] self._nodes: List["Node"] = []

@ -43,6 +43,7 @@ def compute_dominators(nodes: List["Node"]):
for dominator in node.dominators: for dominator in node.dominators:
if dominator != node: if dominator != node:
# pylint: disable=expression-not-assigned
[ [
idom_candidates.remove(d) idom_candidates.remove(d)
for d in dominator.dominators for d in dominator.dominators

@ -50,7 +50,9 @@ class AssignmentOperationType(Enum):
if operation_type == "%=": if operation_type == "%=":
return AssignmentOperationType.ASSIGN_MODULO return AssignmentOperationType.ASSIGN_MODULO
raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type)) raise SlitherCoreError(
"get_type: Unknown operation type {})".format(operation_type)
)
def __str__(self): def __str__(self):
if self == AssignmentOperationType.ASSIGN: if self == AssignmentOperationType.ASSIGN:
@ -115,4 +117,10 @@ class AssignmentOperation(ExpressionTyped):
return self._type return self._type
def __str__(self): def __str__(self):
return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right) return (
str(self.expression_left)
+ " "
+ str(self.type)
+ " "
+ str(self.expression_right)
)

@ -41,7 +41,7 @@ class BinaryOperationType(Enum):
RIGHT_SHIFT_ARITHMETIC = 23 RIGHT_SHIFT_ARITHMETIC = 23
@staticmethod @staticmethod
def get_type(operation_type: "BinaryOperation"): def get_type(operation_type: "BinaryOperation"): # pylint: disable=too-many-branches
if operation_type == "**": if operation_type == "**":
return BinaryOperationType.POWER return BinaryOperationType.POWER
if operation_type == "*": if operation_type == "*":
@ -91,9 +91,11 @@ class BinaryOperationType(Enum):
if operation_type == ">>'": if operation_type == ">>'":
return BinaryOperationType.RIGHT_SHIFT_ARITHMETIC return BinaryOperationType.RIGHT_SHIFT_ARITHMETIC
raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type)) raise SlitherCoreError(
"get_type: Unknown operation type {})".format(operation_type)
)
def __str__(self): def __str__(self): # pylint: disable=too-many-branches
if self == BinaryOperationType.POWER: if self == BinaryOperationType.POWER:
return "**" return "**"
if self == BinaryOperationType.MULTIPLICATION: if self == BinaryOperationType.MULTIPLICATION:
@ -170,4 +172,10 @@ class BinaryOperation(ExpressionTyped):
return self._type return self._type
def __str__(self): def __str__(self):
return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right) return (
str(self.expression_left)
+ " "
+ str(self.type)
+ " "
+ str(self.expression_right)
)

@ -3,7 +3,7 @@ from typing import Optional, List
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
class CallExpression(Expression): class CallExpression(Expression): # pylint: disable=too-many-instance-attributes
def __init__(self, called, arguments, type_call): def __init__(self, called, arguments, type_call):
assert isinstance(called, Expression) assert isinstance(called, Expression)
super(CallExpression, self).__init__() super(CallExpression, self).__init__()

@ -14,3 +14,7 @@ class ExpressionTyped(Expression):
@property @property
def type(self): def type(self):
return self._type return self._type
@type.setter
def type(self, new_type: "Type"):
self._type = new_type

@ -8,10 +8,10 @@ if TYPE_CHECKING:
class Literal(Expression): class Literal(Expression):
def __init__(self, value, type, subdenomination=None): def __init__(self, value, custom_type, subdenomination=None):
super(Literal, self).__init__() super(Literal, self).__init__()
self._value: Union[int, str] = value self._value: Union[int, str] = value
self._type = type self._type = custom_type
self._subdenomination: Optional[str] = subdenomination self._subdenomination: Optional[str] = subdenomination
@property @property

@ -1,4 +1,3 @@
from slither.core.expressions.expression import Expression
from slither.core.expressions.call_expression import CallExpression from slither.core.expressions.call_expression import CallExpression

@ -41,7 +41,9 @@ class UnaryOperationType(Enum):
return UnaryOperationType.PLUSPLUS_POST return UnaryOperationType.PLUSPLUS_POST
if operation_type == "--": if operation_type == "--":
return UnaryOperationType.MINUSMINUS_POST return UnaryOperationType.MINUSMINUS_POST
raise SlitherCoreError("get_type: Unknown operation type {}".format(operation_type)) raise SlitherCoreError(
"get_type: Unknown operation type {}".format(operation_type)
)
def __str__(self): def __str__(self):
if self == UnaryOperationType.BANG: if self == UnaryOperationType.BANG:
@ -76,13 +78,15 @@ class UnaryOperationType(Enum):
UnaryOperationType.MINUS_PRE, UnaryOperationType.MINUS_PRE,
]: ]:
return True return True
elif operation_type in [ if operation_type in [
UnaryOperationType.PLUSPLUS_POST, UnaryOperationType.PLUSPLUS_POST,
UnaryOperationType.MINUSMINUS_POST, UnaryOperationType.MINUSMINUS_POST,
]: ]:
return False return False
raise SlitherCoreError("is_prefix: Unknown operation type {}".format(operation_type)) raise SlitherCoreError(
"is_prefix: Unknown operation type {}".format(operation_type)
)
class UnaryOperation(ExpressionTyped): class UnaryOperation(ExpressionTyped):
@ -117,5 +121,4 @@ class UnaryOperation(ExpressionTyped):
def __str__(self): def __str__(self):
if self.is_prefix: if self.is_prefix:
return str(self.type) + " " + str(self._expression) return str(self.type) + " " + str(self._expression)
else: return str(self._expression) + " " + str(self.type)
return str(self._expression) + " " + str(self.type)

@ -12,7 +12,15 @@ from typing import Optional, Dict, List, Set, Union, Tuple
from crytic_compile import CryticCompile from crytic_compile import CryticCompile
from slither.core.context.context import Context from slither.core.context.context import Context
from slither.core.declarations import Contract, Pragma, Import, Function, Modifier, Structure, Enum from slither.core.declarations import (
Contract,
Pragma,
Import,
Function,
Modifier,
Structure,
Enum,
)
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import InternalCall from slither.slithir.operations import InternalCall
from slither.slithir.variables import Constant from slither.slithir.variables import Constant
@ -22,7 +30,14 @@ logger = logging.getLogger("Slither")
logging.basicConfig() logging.basicConfig()
class SlitherCore(Context): def _relative_path_format(path: str) -> str:
"""
Strip relative paths of "." and ".."
"""
return path.split("..")[-1].strip(".").strip("/")
class SlitherCore(Context): # pylint: disable=too-many-instance-attributes,too-many-public-methods
""" """
Slither static analyzer Slither static analyzer
""" """
@ -115,6 +130,10 @@ class SlitherCore(Context):
return self.crytic_compile.compiler_version.version return self.crytic_compile.compiler_version.version
return self._solc_version return self._solc_version
@solc_version.setter
def solc_version(self, version: str):
self._solc_version = version
@property @property
def pragma_directives(self) -> List[Pragma]: def pragma_directives(self) -> List[Pragma]:
""" list(core.declarations.Pragma): Pragma directives.""" """ list(core.declarations.Pragma): Pragma directives."""
@ -142,14 +161,20 @@ class SlitherCore(Context):
"""list(Contract): List of contracts that are derived and not inherited.""" """list(Contract): List of contracts that are derived and not inherited."""
inheritance = (x.inheritance for x in self.contracts) inheritance = (x.inheritance for x in self.contracts)
inheritance = [item for sublist in inheritance for item in sublist] inheritance = [item for sublist in inheritance for item in sublist]
return [c for c in self._contracts.values() if c not in inheritance and not c.is_top_level] return [
c
for c in self._contracts.values()
if c not in inheritance and not c.is_top_level
]
@property @property
def contracts_as_dict(self) -> Dict[str, Contract]: def contracts_as_dict(self) -> Dict[str, Contract]:
"""list(dict(str: Contract): List of contracts as dict: name -> Contract.""" """list(dict(str: Contract): List of contracts as dict: name -> Contract."""
return self._contracts return self._contracts
def get_contract_from_name(self, contract_name: Union[str, Constant]) -> Optional[Contract]: def get_contract_from_name(
self, contract_name: Union[str, Constant]
) -> Optional[Contract]:
""" """
Return a contract from a name Return a contract from a name
Args: Args:
@ -245,12 +270,6 @@ class SlitherCore(Context):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def relative_path_format(self, path: str) -> str:
"""
Strip relative paths of "." and ".."
"""
return path.split("..")[-1].strip(".").strip("/")
def valid_result(self, r: Dict) -> bool: def valid_result(self, r: Dict) -> bool:
""" """
Check if the result is valid Check if the result is valid
@ -272,7 +291,7 @@ class SlitherCore(Context):
for path in self._paths_to_filter: for path in self._paths_to_filter:
try: try:
if any( if any(
bool(re.search(self.relative_path_format(path), src_mapping)) bool(re.search(_relative_path_format(path), src_mapping))
for src_mapping in source_mapping_elements for src_mapping in source_mapping_elements
): ):
matching = True matching = True
@ -287,11 +306,15 @@ class SlitherCore(Context):
if r["elements"] and matching: if r["elements"] and matching:
return False return False
if r["elements"] and self._exclude_dependencies: if r["elements"] and self._exclude_dependencies:
return not all(element["source_mapping"]["is_dependency"] for element in r["elements"]) return not all(
element["source_mapping"]["is_dependency"] for element in r["elements"]
)
if r["id"] in self._previous_results_ids: if r["id"] in self._previous_results_ids:
return False return False
# Conserve previous result filtering. This is conserved for compatibility, but is meant to be removed # Conserve previous result filtering. This is conserved for compatibility, but is meant to be removed
return not r["description"] in [pr["description"] for pr in self._previous_results] return not r["description"] in [
pr["description"] for pr in self._previous_results
]
def load_previous_results(self): def load_previous_results(self):
filename = self._previous_results_filename filename = self._previous_results_filename
@ -305,7 +328,11 @@ class SlitherCore(Context):
self._previous_results_ids.add(r["id"]) self._previous_results_ids.add(r["id"])
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
logger.error( logger.error(
red("Impossible to decode {}. Consider removing the file".format(filename)) red(
"Impossible to decode {}. Consider removing the file".format(
filename
)
)
) )
def write_results_to_hide(self): def write_results_to_hide(self):
@ -404,7 +431,10 @@ class SlitherCore(Context):
slot += 1 slot += 1
offset = 0 offset = 0
self._storage_layouts[contract.name][var.canonical_name] = (slot, offset) self._storage_layouts[contract.name][var.canonical_name] = (
slot,
offset,
)
if new_slot: if new_slot:
slot += math.ceil(size / 32) slot += math.ceil(size / 32)
else: else:

@ -123,7 +123,9 @@ MN = list(itertools.product(M, N))
Fixed = ["fixed{}x{}".format(m, n) for (m, n) in MN] + ["fixed"] Fixed = ["fixed{}x{}".format(m, n) for (m, n) in MN] + ["fixed"]
Ufixed = ["ufixed{}x{}".format(m, n) for (m, n) in MN] + ["ufixed"] Ufixed = ["ufixed{}x{}".format(m, n) for (m, n) in MN] + ["ufixed"]
ElementaryTypeName = ["address", "bool", "string", "var"] + Int + Uint + Byte + Fixed + Ufixed ElementaryTypeName = (
["address", "bool", "string", "var"] + Int + Uint + Byte + Fixed + Ufixed
)
class NonElementaryType(Exception): class NonElementaryType(Exception):

@ -6,7 +6,9 @@ from slither.core.variables.function_type_variable import FunctionTypeVariable
class FunctionType(Type): class FunctionType(Type):
def __init__( def __init__(
self, params: List[FunctionTypeVariable], return_values: List[FunctionTypeVariable] self,
params: List[FunctionTypeVariable],
return_values: List[FunctionTypeVariable],
): ):
assert all(isinstance(x, FunctionTypeVariable) for x in params) assert all(isinstance(x, FunctionTypeVariable) for x in params)
assert all(isinstance(x, FunctionTypeVariable) for x in return_values) assert all(isinstance(x, FunctionTypeVariable) for x in return_values)

@ -10,6 +10,7 @@ if TYPE_CHECKING:
# https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information # https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information
class TypeInformation(Type): class TypeInformation(Type):
def __init__(self, c): def __init__(self, c):
# pylint: disable=import-outside-toplevel
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
assert isinstance(c, Contract) assert isinstance(c, Contract)

@ -2,13 +2,14 @@ from typing import Union, TYPE_CHECKING, Tuple
import math import math
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
from slither.exceptions import SlitherException
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.declarations.structure import Structure from slither.core.declarations.structure import Structure
from slither.core.declarations.enum import Enum from slither.core.declarations.enum import Enum
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
# pylint: disable=import-outside-toplevel
class UserDefinedType(Type): class UserDefinedType(Type):
def __init__(self, t): def __init__(self, t):
from slither.core.declarations.structure import Structure from slither.core.declarations.structure import Structure
@ -31,9 +32,9 @@ class UserDefinedType(Type):
if isinstance(self._type, Contract): if isinstance(self._type, Contract):
return 20, False return 20, False
elif isinstance(self._type, Enum): if isinstance(self._type, Enum):
return int(math.ceil(math.log2(len(self._type.values)) / 8)), False return int(math.ceil(math.log2(len(self._type.values)) / 8)), False
elif isinstance(self._type, Structure): if isinstance(self._type, Structure):
# todo there's some duplicate logic here and slither_core, can we refactor this? # todo there's some duplicate logic here and slither_core, can we refactor this?
slot = 0 slot = 0
offset = 0 offset = 0
@ -54,6 +55,8 @@ class UserDefinedType(Type):
if offset > 0: if offset > 0:
slot += 1 slot += 1
return slot * 32, True return slot * 32, True
to_log = f"{self} does not have storage size"
raise SlitherException(to_log)
def __str__(self): def __str__(self):
from slither.core.declarations.structure import Structure from slither.core.declarations.structure import Structure

@ -58,7 +58,7 @@ class SourceMapping(Context):
return lines, starting_column, ending_column return lines, starting_column, ending_column
@staticmethod @staticmethod
def _convert_source_mapping(offset: str, slither): def _convert_source_mapping(offset: str, slither): # pylint: disable=too-many-locals
""" """
Convert a text offset to a real offset Convert a text offset to a real offset
see https://solidity.readthedocs.io/en/develop/miscellaneous.html#source-mappings see https://solidity.readthedocs.io/en/develop/miscellaneous.html#source-mappings
@ -112,10 +112,14 @@ class SourceMapping(Context):
if slither.crytic_compile and filename in slither.crytic_compile.src_content: if slither.crytic_compile and filename in slither.crytic_compile.src_content:
source_code = slither.crytic_compile.src_content[filename] source_code = slither.crytic_compile.src_content[filename]
(lines, starting_column, ending_column) = SourceMapping._compute_line(source_code, s, l) (lines, starting_column, ending_column) = SourceMapping._compute_line(
source_code, s, l
)
elif filename in slither.source_code: elif filename in slither.source_code:
source_code = slither.source_code[filename] source_code = slither.source_code[filename]
(lines, starting_column, ending_column) = SourceMapping._compute_line(source_code, s, l) (lines, starting_column, ending_column) = SourceMapping._compute_line(
source_code, s, l
)
else: else:
(lines, starting_column, ending_column) = ([], None, None) (lines, starting_column, ending_column) = ([], None, None)
@ -145,7 +149,7 @@ class SourceMapping(Context):
elif len(lines) == 1: elif len(lines) == 1:
lines = "#{}{}".format(line_descr, lines[0]) lines = "#{}{}".format(line_descr, lines[0])
else: else:
lines = "#{}{}-{}{}".format(line_descr, lines[0], line_descr, lines[-1]) lines = f"#{line_descr}{lines[0]}-{line_descr}{lines[-1]}"
return lines return lines
def source_mapping_to_markdown(self, markdown_root: str) -> str: def source_mapping_to_markdown(self, markdown_root: str) -> str:

@ -1,4 +1,4 @@
from .variable import Variable from slither.core.variables.variable import Variable
from slither.core.children.child_event import ChildEvent from slither.core.children.child_event import ChildEvent

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from .variable import Variable from slither.core.variables.variable import Variable
from slither.core.children.child_function import ChildFunction from slither.core.children.child_function import ChildFunction
from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.user_defined_type import UserDefinedType
from slither.core.solidity_types.array_type import ArrayType from slither.core.solidity_types.array_type import ArrayType

@ -1,6 +1,6 @@
from typing import Optional, TYPE_CHECKING, Tuple, List from typing import Optional, TYPE_CHECKING, Tuple, List
from .variable import Variable from slither.core.variables.variable import Variable
from slither.core.children.child_contract import ChildContract from slither.core.children.child_contract import ChildContract
from slither.utils.type import export_nested_types_from_variable from slither.utils.type import export_nested_types_from_variable
@ -34,7 +34,11 @@ class StateVariable(ChildContract, Variable):
Return the signature of the state variable as a function signature Return the signature of the state variable as a function signature
:return: (str, list(str), list(str)), as (name, list parameters type, list return values type) :return: (str, list(str), list(str)), as (name, list parameters type, list return values type)
""" """
return self.name, [str(x) for x in export_nested_types_from_variable(self)], str(self.type) return (
self.name,
[str(x) for x in export_nested_types_from_variable(self)],
str(self.type),
)
@property @property
def signature_str(self) -> str: def signature_str(self) -> str:
@ -43,7 +47,14 @@ class StateVariable(ChildContract, Variable):
:return: str: func_name(type1,type2) returns(type3) :return: str: func_name(type1,type2) returns(type3)
""" """
name, parameters, returnVars = self.signature name, parameters, returnVars = self.signature
return name + "(" + ",".join(parameters) + ") returns(" + ",".join(returnVars) + ")" return (
name
+ "("
+ ",".join(parameters)
+ ") returns("
+ ",".join(returnVars)
+ ")"
)
# endregion # endregion
################################################################################### ###################################################################################

@ -1,4 +1,4 @@
from .variable import Variable from slither.core.variables.variable import Variable
from slither.core.children.child_structure import ChildStructure from slither.core.children.child_structure import ChildStructure

@ -112,6 +112,7 @@ class Variable(SourceMapping):
Return the name of the variable as a function signature Return the name of the variable as a function signature
:return: :return:
""" """
# pylint: disable=import-outside-toplevel
from slither.core.solidity_types import ArrayType, MappingType from slither.core.solidity_types import ArrayType, MappingType
from slither.utils.type import export_nested_types_from_variable from slither.utils.type import export_nested_types_from_variable
@ -120,7 +121,9 @@ class Variable(SourceMapping):
assert return_type assert return_type
if isinstance(return_type, (ArrayType, MappingType)): if isinstance(return_type, (ArrayType, MappingType)):
variable_getter_args = ",".join(map(str, export_nested_types_from_variable(self))) variable_getter_args = ",".join(
map(str, export_nested_types_from_variable(self))
)
return f"{self.name}({variable_getter_args})" return f"{self.name}({variable_getter_args})"

@ -11,7 +11,7 @@ class IncorrectDetectorInitialization(Exception):
pass pass
class DetectorClassification: class DetectorClassification: # pylint: disable=too-few-public-methods
HIGH = 0 HIGH = 0
MEDIUM = 1 MEDIUM = 1
LOW = 2 LOW = 2
@ -87,12 +87,16 @@ class AbstractDetector(metaclass=abc.ABCMeta):
DetectorClassification.OPTIMIZATION, DetectorClassification.OPTIMIZATION,
]: ]:
raise IncorrectDetectorInitialization( raise IncorrectDetectorInitialization(
"WIKI_EXPLOIT_SCENARIO is not initialized {}".format(self.__class__.__name__) "WIKI_EXPLOIT_SCENARIO is not initialized {}".format(
self.__class__.__name__
)
) )
if not self.WIKI_RECOMMENDATION: if not self.WIKI_RECOMMENDATION:
raise IncorrectDetectorInitialization( raise IncorrectDetectorInitialization(
"WIKI_RECOMMENDATION is not initialized {}".format(self.__class__.__name__) "WIKI_RECOMMENDATION is not initialized {}".format(
self.__class__.__name__
)
) )
if re.match("^[a-zA-Z0-9_-]*$", self.ARGUMENT) is None: if re.match("^[a-zA-Z0-9_-]*$", self.ARGUMENT) is None:
@ -131,12 +135,14 @@ class AbstractDetector(metaclass=abc.ABCMeta):
"""TODO Documentation""" """TODO Documentation"""
return [] return []
# pylint: disable=too-many-branches
def detect(self): def detect(self):
all_results = self._detect() all_results = self._detect()
# Keep only dictionaries # Keep only dictionaries
all_results = [r.data for r in all_results] all_results = [r.data for r in all_results]
results = [] results = []
# only keep valid result, and remove dupplicate # only keep valid result, and remove dupplicate
# pylint: disable=expression-not-assigned
[ [
results.append(r) results.append(r)
for r in all_results for r in all_results
@ -173,15 +179,19 @@ class AbstractDetector(metaclass=abc.ABCMeta):
) )
continue continue
for patch in patches: for patch in patches:
patched_txt, offset = apply_patch(patched_txt, patch, offset) patched_txt, offset = apply_patch(
diff = create_diff(self.slither, original_txt, patched_txt, file) patched_txt, patch, offset
)
diff = create_diff(
self.slither, original_txt, patched_txt, file
)
if not diff: if not diff:
self._log(f"Impossible to generate patch; empty {result}") self._log(f"Impossible to generate patch; empty {result}")
else: else:
result["patches_diff"][file] = diff result["patches_diff"][file] = diff
except FormatImpossible as e: except FormatImpossible as exception:
self._log(f'\nImpossible to patch:\n\t{result["description"]}\t{e}') self._log(f'\nImpossible to patch:\n\t{result["description"]}\t{exception}')
if results and self.slither.triage_mode: if results and self.slither.triage_mode:
while True: while True:
@ -206,7 +216,9 @@ class AbstractDetector(metaclass=abc.ABCMeta):
) )
return [r for (idx, r) in enumerate(results) if idx not in indexes] return [r for (idx, r) in enumerate(results) if idx not in indexes]
except ValueError: except ValueError:
self.logger.error(yellow("Malformed input. Example of valid input: 0,1,2,3")) self.logger.error(
yellow("Malformed input. Example of valid input: 0,1,2,3")
)
return results return results
@property @property
@ -228,6 +240,6 @@ class AbstractDetector(metaclass=abc.ABCMeta):
return output return output
@staticmethod @staticmethod
def _format(slither, result): def _format(_slither, _result):
"""Implement format""" """Implement format"""
return return

@ -1,3 +1,4 @@
# pylint: disable=unused-import,relative-beyond-top-level
from .examples.backdoor import Backdoor from .examples.backdoor import Backdoor
from .variables.uninitialized_state_variables import UninitializedStateVarsDetection from .variables.uninitialized_state_variables import UninitializedStateVarsDetection
from .variables.uninitialized_storage_variables import UninitializedStorageVars from .variables.uninitialized_storage_variables import UninitializedStorageVars

@ -3,7 +3,7 @@ Module detecting constant functions
Recursively check the called functions Recursively check the called functions
""" """
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.attributes.const_functions import format from slither.formatters.attributes.const_functions import custom_format
class ConstantFunctionsAsm(AbstractDetector): class ConstantFunctionsAsm(AbstractDetector):
@ -40,9 +40,7 @@ contract Constant{
`Constant` was deployed with Solidity 0.4.25. Bob writes a smart contract that interacts with `Constant` in Solidity 0.5.0. `Constant` was deployed with Solidity 0.4.25. Bob writes a smart contract that interacts with `Constant` in Solidity 0.5.0.
All the calls to `get` revert, breaking Bob's smart contract execution.""" All the calls to `get` revert, breaking Bob's smart contract execution."""
WIKI_RECOMMENDATION = ( WIKI_RECOMMENDATION = "Ensure the attributes of contracts compiled prior to Solidity 0.5.0 are correct."
"Ensure the attributes of contracts compiled prior to Solidity 0.5.0 are correct."
)
def _detect(self): def _detect(self):
""" Detect the constant function using assembly code """ Detect the constant function using assembly code
@ -71,4 +69,4 @@ All the calls to `get` revert, breaking Bob's smart contract execution."""
@staticmethod @staticmethod
def _format(slither, result): def _format(slither, result):
format(slither, result) custom_format(slither, result)

@ -3,7 +3,7 @@ Module detecting constant functions
Recursively check the called functions Recursively check the called functions
""" """
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.attributes.const_functions import format from slither.formatters.attributes.const_functions import custom_format
class ConstantFunctionsState(AbstractDetector): class ConstantFunctionsState(AbstractDetector):
@ -40,9 +40,7 @@ contract Constant{
`Constant` was deployed with Solidity 0.4.25. Bob writes a smart contract that interacts with `Constant` in Solidity 0.5.0. `Constant` was deployed with Solidity 0.4.25. Bob writes a smart contract that interacts with `Constant` in Solidity 0.5.0.
All the calls to `get` revert, breaking Bob's smart contract execution.""" All the calls to `get` revert, breaking Bob's smart contract execution."""
WIKI_RECOMMENDATION = ( WIKI_RECOMMENDATION = "Ensure that attributes of contracts compiled prior to Solidity 0.5.0 are correct."
"Ensure that attributes of contracts compiled prior to Solidity 0.5.0 are correct."
)
def _detect(self): def _detect(self):
""" Detect the constant function changing the state """ Detect the constant function changing the state
@ -63,7 +61,10 @@ All the calls to `get` revert, breaking Bob's smart contract execution."""
if variables_written: if variables_written:
attr = "view" if f.view else "pure" attr = "view" if f.view else "pure"
info = [f, f" is declared {attr} but changes state variables:\n"] info = [
f,
f" is declared {attr} but changes state variables:\n",
]
for variable_written in variables_written: for variable_written in variables_written:
info += ["\t- ", variable_written, "\n"] info += ["\t- ", variable_written, "\n"]
@ -76,4 +77,4 @@ All the calls to `get` revert, breaking Bob's smart contract execution."""
@staticmethod @staticmethod
def _format(slither, result): def _format(slither, result):
format(slither, result) custom_format(slither, result)

@ -3,7 +3,7 @@
""" """
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.attributes.constant_pragma import format from slither.formatters.attributes.constant_pragma import custom_format
class ConstantPragma(AbstractDetector): class ConstantPragma(AbstractDetector):
@ -43,4 +43,4 @@ class ConstantPragma(AbstractDetector):
@staticmethod @staticmethod
def _format(slither, result): def _format(slither, result):
format(slither, result) custom_format(slither, result)

@ -4,7 +4,7 @@
import re import re
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.attributes.incorrect_solc import format from slither.formatters.attributes.incorrect_solc import custom_format
# group: # group:
# 0: ^ > >= < <= (optional) # 0: ^ > >= < <= (optional)
@ -13,6 +13,7 @@ from slither.formatters.attributes.incorrect_solc import format
# 3: version number # 3: version number
# 4: version number # 4: version number
# pylint: disable=anomalous-backslash-in-string
PATTERN = re.compile("(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)") PATTERN = re.compile("(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)")
@ -45,12 +46,8 @@ Consider using the latest version of Solidity for testing."""
OLD_VERSION_TXT = "allows old versions" OLD_VERSION_TXT = "allows old versions"
LESS_THAN_TXT = "uses lesser than" LESS_THAN_TXT = "uses lesser than"
TOO_RECENT_VERSION_TXT = ( TOO_RECENT_VERSION_TXT = "necessitates a version too recent to be trusted. Consider deploying with 0.6.11"
"necessitates a version too recent to be trusted. Consider deploying with 0.6.11" BUGGY_VERSION_TXT = "is known to contain severe issues (https://solidity.readthedocs.io/en/latest/bugs.html)"
)
BUGGY_VERSION_TXT = (
"is known to contain severe issues (https://solidity.readthedocs.io/en/latest/bugs.html)"
)
# Indicates the allowed versions. Must be formatted in increasing order. # Indicates the allowed versions. Must be formatted in increasing order.
ALLOWED_VERSIONS = [ ALLOWED_VERSIONS = [
@ -85,7 +82,9 @@ Consider using the latest version of Solidity for testing."""
return self.LESS_THAN_TXT return self.LESS_THAN_TXT
version_number = ".".join(version[2:]) version_number = ".".join(version[2:])
if version_number not in self.ALLOWED_VERSIONS: if version_number not in self.ALLOWED_VERSIONS:
if list(map(int, version[2:])) > list(map(int, self.ALLOWED_VERSIONS[-1].split("."))): if list(map(int, version[2:])) > list(
map(int, self.ALLOWED_VERSIONS[-1].split("."))
):
return self.TOO_RECENT_VERSION_TXT return self.TOO_RECENT_VERSION_TXT
return self.OLD_VERSION_TXT return self.OLD_VERSION_TXT
return None return None
@ -97,7 +96,7 @@ Consider using the latest version of Solidity for testing."""
if len(versions) == 1: if len(versions) == 1:
version = versions[0] version = versions[0]
return self._check_version(version) return self._check_version(version)
elif len(versions) == 2: if len(versions) == 2:
version_left = versions[0] version_left = versions[0]
version_right = versions[1] version_right = versions[1]
# Only allow two elements if the second one is # Only allow two elements if the second one is
@ -109,8 +108,7 @@ Consider using the latest version of Solidity for testing."""
]: ]:
return self.COMPLEX_PRAGMA_TXT return self.COMPLEX_PRAGMA_TXT
return self._check_version(version_left) return self._check_version(version_left)
else: return self.COMPLEX_PRAGMA_TXT
return self.COMPLEX_PRAGMA_TXT
def _detect(self): def _detect(self):
""" """
@ -121,14 +119,13 @@ Consider using the latest version of Solidity for testing."""
results = [] results = []
pragma = self.slither.pragma_directives pragma = self.slither.pragma_directives
disallowed_pragmas = [] disallowed_pragmas = []
detected_version = False
for p in pragma: for p in pragma:
# Skip any pragma directives which do not refer to version # Skip any pragma directives which do not refer to version
if len(p.directive) < 1 or p.directive[0] != "solidity": if len(p.directive) < 1 or p.directive[0] != "solidity":
continue continue
# This is version, so we test if this is disallowed. # This is version, so we test if this is disallowed.
detected_version = True
reason = self._check_pragma(p.version) reason = self._check_pragma(p.version)
if reason: if reason:
disallowed_pragmas.append((reason, p)) disallowed_pragmas.append((reason, p))
@ -162,4 +159,4 @@ Consider using the latest version of Solidity for testing."""
@staticmethod @staticmethod
def _format(slither, result): def _format(slither, result):
format(slither, result) custom_format(slither, result)

@ -14,9 +14,7 @@ from slither.slithir.operations import (
) )
class LockedEther(AbstractDetector): class LockedEther(AbstractDetector): # pylint: disable=too-many-nested-blocks
"""
"""
ARGUMENT = "locked-ether" ARGUMENT = "locked-ether"
HELP = "Contracts that lock ether" HELP = "Contracts that lock ether"
@ -26,7 +24,9 @@ class LockedEther(AbstractDetector):
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#contracts-that-lock-ether" WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#contracts-that-lock-ether"
WIKI_TITLE = "Contracts that lock Ether" WIKI_TITLE = "Contracts that lock Ether"
WIKI_DESCRIPTION = "Contract with a `payable` function, but without a withdrawal capacity." WIKI_DESCRIPTION = (
"Contract with a `payable` function, but without a withdrawal capacity."
)
WIKI_EXPLOIT_SCENARIO = """ WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
pragma solidity 0.4.24; pragma solidity 0.4.24;
@ -44,7 +44,7 @@ Every Ether sent to `Locked` will be lost."""
functions = contract.all_functions_called functions = contract.all_functions_called
to_explore = functions to_explore = functions
explored = [] explored = []
while to_explore: while to_explore: # pylint: disable=too-many-nested-blocks
functions = to_explore functions = to_explore
explored += to_explore explored += to_explore
to_explore = [] to_explore = []
@ -55,7 +55,8 @@ Every Ether sent to `Locked` will be lost."""
for node in function.nodes: for node in function.nodes:
for ir in node.irs: for ir in node.irs:
if isinstance( if isinstance(
ir, (Send, Transfer, HighLevelCall, LowLevelCall, NewContract) ir,
(Send, Transfer, HighLevelCall, LowLevelCall, NewContract),
): ):
if ir.call_value and ir.call_value != 0: if ir.call_value and ir.call_value != 0:
return False return False
@ -77,13 +78,15 @@ Every Ether sent to `Locked` will be lost."""
for contract in self.slither.contracts_derived: for contract in self.slither.contracts_derived:
if contract.is_signature_only(): if contract.is_signature_only():
continue continue
funcs_payable = [function for function in contract.functions if function.payable] funcs_payable = [
function for function in contract.functions if function.payable
]
if funcs_payable: if funcs_payable:
if self.do_no_send_ether(contract): if self.do_no_send_ether(contract):
info = [f"Contract locking ether found in {self.filename}:\n"] info = [f"Contract locking ether found in {self.filename}:\n"]
info += ["\tContract ", contract, " has payable functions:\n"] info += ["\tContract ", contract, " has payable functions:\n"]
for function in funcs_payable: for function in funcs_payable:
info += [f"\t - ", function, "\n"] info += ["\t - ", function, "\n"]
info += "\tBut does not have a function to withdraw the ether\n" info += "\tBut does not have a function to withdraw the ether\n"
json = self.generate_result(info) json = self.generate_result(info)

@ -36,7 +36,11 @@ contract Token{
def incorrect_erc20_interface(signature): def incorrect_erc20_interface(signature):
(name, parameters, returnVars) = signature (name, parameters, returnVars) = signature
if name == "transfer" and parameters == ["address", "uint256"] and returnVars != ["bool"]: if (
name == "transfer"
and parameters == ["address", "uint256"]
and returnVars != ["bool"]
):
return True return True
if ( if (
@ -46,7 +50,11 @@ contract Token{
): ):
return True return True
if name == "approve" and parameters == ["address", "uint256"] and returnVars != ["bool"]: if (
name == "approve"
and parameters == ["address", "uint256"]
and returnVars != ["bool"]
):
return True return True
if ( if (
@ -56,7 +64,11 @@ contract Token{
): ):
return True return True
if name == "balanceOf" and parameters == ["address"] and returnVars != ["uint256"]: if (
name == "balanceOf"
and parameters == ["address"]
and returnVars != ["uint256"]
):
return True return True
if name == "totalSupply" and parameters == [] and returnVars != ["uint256"]: if name == "totalSupply" and parameters == [] and returnVars != ["uint256"]:
@ -98,10 +110,17 @@ contract Token{
""" """
results = [] results = []
for c in self.slither.contracts_derived: for c in self.slither.contracts_derived:
functions = IncorrectERC20InterfaceDetection.detect_incorrect_erc20_interface(c) functions = IncorrectERC20InterfaceDetection.detect_incorrect_erc20_interface(
c
)
if functions: if functions:
for function in functions: for function in functions:
info = [c, " has incorrect ERC20 function interface:", function, "\n"] info = [
c,
" has incorrect ERC20 function interface:",
function,
"\n",
]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)

@ -14,9 +14,7 @@ class IncorrectERC721InterfaceDetection(AbstractDetector):
IMPACT = DetectorClassification.MEDIUM IMPACT = DetectorClassification.MEDIUM
CONFIDENCE = DetectorClassification.HIGH CONFIDENCE = DetectorClassification.HIGH
WIKI = ( WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#incorrect-erc721-interface"
"https://github.com/crytic/slither/wiki/Detector-Documentation#incorrect-erc721-interface"
)
WIKI_TITLE = "Incorrect erc721 interface" WIKI_TITLE = "Incorrect erc721 interface"
WIKI_DESCRIPTION = "Incorrect return values for `ERC721` functions. A contract compiled with solidity > 0.4.22 interacting with these functions will fail to execute them, as the return value is missing." WIKI_DESCRIPTION = "Incorrect return values for `ERC721` functions. A contract compiled with solidity > 0.4.22 interacting with these functions will fail to execute them, as the return value is missing."
@ -29,18 +27,24 @@ contract Token{
``` ```
`Token.ownerOf` does not return an address like `ERC721` expects. Bob deploys the token. Alice creates a contract that interacts with it but assumes a correct `ERC721` interface implementation. Alice's contract is unable to interact with Bob's contract.""" `Token.ownerOf` does not return an address like `ERC721` expects. Bob deploys the token. Alice creates a contract that interacts with it but assumes a correct `ERC721` interface implementation. Alice's contract is unable to interact with Bob's contract."""
WIKI_RECOMMENDATION = ( WIKI_RECOMMENDATION = "Set the appropriate return values and vtypes for the defined `ERC721` functions."
"Set the appropriate return values and vtypes for the defined `ERC721` functions."
)
@staticmethod @staticmethod
def incorrect_erc721_interface(signature): def incorrect_erc721_interface(signature):
(name, parameters, returnVars) = signature (name, parameters, returnVars) = signature
# ERC721 # ERC721
if name == "balanceOf" and parameters == ["address"] and returnVars != ["uint256"]: if (
name == "balanceOf"
and parameters == ["address"]
and returnVars != ["uint256"]
):
return True return True
if name == "ownerOf" and parameters == ["uint256"] and returnVars != ["address"]: if (
name == "ownerOf"
and parameters == ["uint256"]
and returnVars != ["address"]
):
return True return True
if ( if (
name == "safeTransferFrom" name == "safeTransferFrom"
@ -60,11 +64,23 @@ contract Token{
and returnVars != [] and returnVars != []
): ):
return True return True
if name == "approve" and parameters == ["address", "uint256"] and returnVars != []: if (
name == "approve"
and parameters == ["address", "uint256"]
and returnVars != []
):
return True return True
if name == "setApprovalForAll" and parameters == ["address", "bool"] and returnVars != []: if (
name == "setApprovalForAll"
and parameters == ["address", "bool"]
and returnVars != []
):
return True return True
if name == "getApproved" and parameters == ["uint256"] and returnVars != ["address"]: if (
name == "getApproved"
and parameters == ["uint256"]
and returnVars != ["address"]
):
return True return True
if ( if (
name == "isApprovedForAll" name == "isApprovedForAll"
@ -74,7 +90,11 @@ contract Token{
return True return True
# ERC165 (dependency) # ERC165 (dependency)
if name == "supportsInterface" and parameters == ["bytes4"] and returnVars != ["bool"]: if (
name == "supportsInterface"
and parameters == ["bytes4"]
and returnVars != ["bool"]
):
return True return True
return False return False
@ -107,10 +127,17 @@ contract Token{
""" """
results = [] results = []
for c in self.slither.contracts_derived: for c in self.slither.contracts_derived:
functions = IncorrectERC721InterfaceDetection.detect_incorrect_erc721_interface(c) functions = IncorrectERC721InterfaceDetection.detect_incorrect_erc721_interface(
c
)
if functions: if functions:
for function in functions: for function in functions:
info = [c, " has incorrect ERC721 function interface:", function, "\n"] info = [
c,
" has incorrect ERC721 function interface:",
function,
"\n",
]
res = self.generate_result(info) res = self.generate_result(info)
results.append(res) results.append(res)

@ -76,7 +76,11 @@ Failure to include these keywords will exclude the parameter data in the transac
# Add each problematic event definition to our result list # Add each problematic event definition to our result list
for (event, parameter) in unindexed_params: for (event, parameter) in unindexed_params:
info = ["ERC20 event ", event, f"does not index parameter {parameter}\n"] info = [
"ERC20 event ",
event,
f"does not index parameter {parameter}\n",
]
# Add the events to the JSON (note: we do not add the params/vars as they have no source mapping). # Add the events to the JSON (note: we do not add the params/vars as they have no source mapping).
res = self.generate_result(info) res = self.generate_result(info)

@ -6,7 +6,9 @@ class Backdoor(AbstractDetector):
Detect function named backdoor Detect function named backdoor
""" """
ARGUMENT = "backdoor" # slither will launch the detector with slither.py --mydetector ARGUMENT = (
"backdoor" # slither will launch the detector with slither.py --mydetector
)
HELP = "Function named backdoor (detector example)" HELP = "Function named backdoor (detector example)"
IMPACT = DetectorClassification.HIGH IMPACT = DetectorClassification.HIGH
CONFIDENCE = DetectorClassification.HIGH CONFIDENCE = DetectorClassification.HIGH

@ -11,7 +11,10 @@
""" """
from slither.core.declarations import Function from slither.core.declarations import Function
from slither.analyses.data_dependency.data_dependency import is_tainted, is_dependent from slither.analyses.data_dependency.data_dependency import is_tainted, is_dependent
from slither.core.declarations.solidity_variables import SolidityFunction, SolidityVariableComposed from slither.core.declarations.solidity_variables import (
SolidityFunction,
SolidityVariableComposed,
)
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.slithir.operations import ( from slither.slithir.operations import (
HighLevelCall, HighLevelCall,
@ -23,10 +26,70 @@ from slither.slithir.operations import (
) )
class ArbitrarySend(AbstractDetector): # pylint: disable=too-many-nested-blocks,too-many-branches
def arbitrary_send(func):
if func.is_protected():
return []
ret = []
for node in func.nodes:
for ir in node.irs:
if isinstance(ir, SolidityCall):
if ir.function == SolidityFunction(
"ecrecover(bytes32,uint8,bytes32,bytes32)"
):
return False
if isinstance(ir, Index):
if ir.variable_right == SolidityVariableComposed("msg.sender"):
return False
if is_dependent(
ir.variable_right,
SolidityVariableComposed("msg.sender"),
func.contract,
):
return False
if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)):
if isinstance(ir, (HighLevelCall)):
if isinstance(ir.function, Function):
if (
ir.function.full_name
== "transferFrom(address,address,uint256)"
):
return False
if ir.call_value is None:
continue
if ir.call_value == SolidityVariableComposed("msg.value"):
continue
if is_dependent(
ir.call_value,
SolidityVariableComposed("msg.value"),
func.contract,
):
continue
if is_tainted(ir.destination, func.contract):
ret.append(node)
return ret
def detect_arbitrary_send(contract):
""" """
Detect arbitrary send
Args:
contract (Contract)
Returns:
list((Function), (list (Node)))
""" """
ret = []
for f in [f for f in contract.functions if f.contract_declarer == contract]:
nodes = arbitrary_send(f)
if nodes:
ret.append((f, nodes))
return ret
class ArbitrarySend(AbstractDetector):
ARGUMENT = "arbitrary-send" ARGUMENT = "arbitrary-send"
HELP = "Functions that send Ether to arbitrary destinations" HELP = "Functions that send Ether to arbitrary destinations"
IMPACT = DetectorClassification.HIGH IMPACT = DetectorClassification.HIGH
@ -35,7 +98,9 @@ class ArbitrarySend(AbstractDetector):
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#functions-that-send-ether-to-arbitrary-destinations" WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#functions-that-send-ether-to-arbitrary-destinations"
WIKI_TITLE = "Functions that send Ether to arbitrary destinations" WIKI_TITLE = "Functions that send Ether to arbitrary destinations"
WIKI_DESCRIPTION = "Unprotected call to a function sending Ether to an arbitrary address." WIKI_DESCRIPTION = (
"Unprotected call to a function sending Ether to an arbitrary address."
)
WIKI_EXPLOIT_SCENARIO = """ WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract ArbitrarySend{ contract ArbitrarySend{
@ -51,60 +116,9 @@ contract ArbitrarySend{
``` ```
Bob calls `setDestination` and `withdraw`. As a result he withdraws the contract's balance.""" Bob calls `setDestination` and `withdraw`. As a result he withdraws the contract's balance."""
WIKI_RECOMMENDATION = "Ensure that an arbitrary user cannot withdraw unauthorized funds." WIKI_RECOMMENDATION = (
"Ensure that an arbitrary user cannot withdraw unauthorized funds."
def arbitrary_send(self, func): )
"""
"""
if func.is_protected():
return []
ret = []
for node in func.nodes:
for ir in node.irs:
if isinstance(ir, SolidityCall):
if ir.function == SolidityFunction("ecrecover(bytes32,uint8,bytes32,bytes32)"):
return False
if isinstance(ir, Index):
if ir.variable_right == SolidityVariableComposed("msg.sender"):
return False
if is_dependent(
ir.variable_right, SolidityVariableComposed("msg.sender"), func.contract
):
return False
if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)):
if isinstance(ir, (HighLevelCall)):
if isinstance(ir.function, Function):
if ir.function.full_name == "transferFrom(address,address,uint256)":
return False
if ir.call_value is None:
continue
if ir.call_value == SolidityVariableComposed("msg.value"):
continue
if is_dependent(
ir.call_value, SolidityVariableComposed("msg.value"), func.contract
):
continue
if is_tainted(ir.destination, func.contract):
ret.append(node)
return ret
def detect_arbitrary_send(self, contract):
"""
Detect arbitrary send
Args:
contract (Contract)
Returns:
list((Function), (list (Node)))
"""
ret = []
for f in [f for f in contract.functions if f.contract_declarer == contract]:
nodes = self.arbitrary_send(f)
if nodes:
ret.append((f, nodes))
return ret
def _detect(self): def _detect(self):
""" """
@ -112,8 +126,8 @@ Bob calls `setDestination` and `withdraw`. As a result he withdraws the contract
results = [] results = []
for c in self.contracts: for c in self.contracts:
arbitrary_send = self.detect_arbitrary_send(c) arbitrary_send_result = detect_arbitrary_send(c)
for (func, nodes) in arbitrary_send: for (func, nodes) in arbitrary_send_result:
info = [func, " sends eth to arbitrary user\n"] info = [func, " sends eth to arbitrary user\n"]
info += ["\tDangerous calls:\n"] info += ["\tDangerous calls:\n"]

@ -1,105 +0,0 @@
from slither.core.declarations.solidity_variables import SolidityFunction, SolidityVariableComposed
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.slithir.operations import HighLevelCall, LowLevelCall, LibraryCall
from slither.utils.code_complexity import compute_cyclomatic_complexity
class ComplexFunction(AbstractDetector):
"""
Module detecting complex functions
A complex function is defined by:
- high cyclomatic complexity
- numerous writes to state variables
- numerous external calls
"""
ARGUMENT = "complex-function"
HELP = "Complex functions"
IMPACT = DetectorClassification.INFORMATIONAL
CONFIDENCE = DetectorClassification.MEDIUM
MAX_STATE_VARIABLES = 10
MAX_EXTERNAL_CALLS = 5
MAX_CYCLOMATIC_COMPLEXITY = 7
CAUSE_CYCLOMATIC = "cyclomatic"
CAUSE_EXTERNAL_CALL = "external_calls"
CAUSE_STATE_VARS = "state_vars"
STANDARD_JSON = True
@staticmethod
def detect_complex_func(func):
"""Detect the cyclomatic complexity of the contract functions
shouldn't be greater than 7
"""
result = []
code_complexity = compute_cyclomatic_complexity(func)
if code_complexity > ComplexFunction.MAX_CYCLOMATIC_COMPLEXITY:
result.append({"func": func, "cause": ComplexFunction.CAUSE_CYCLOMATIC})
"""Detect the number of external calls in the func
shouldn't be greater than 5
"""
count = 0
for node in func.nodes:
for ir in node.irs:
if isinstance(ir, (HighLevelCall, LowLevelCall, LibraryCall)):
count += 1
if count > ComplexFunction.MAX_EXTERNAL_CALLS:
result.append({"func": func, "cause": ComplexFunction.CAUSE_EXTERNAL_CALL})
"""Checks the number of the state variables written
shouldn't be greater than 10
"""
if len(func.state_variables_written) > ComplexFunction.MAX_STATE_VARIABLES:
result.append({"func": func, "cause": ComplexFunction.CAUSE_STATE_VARS})
return result
def detect_complex(self, contract):
ret = []
for func in contract.all_functions_called:
result = self.detect_complex_func(func)
ret.extend(result)
return ret
def detect(self):
results = []
for contract in self.contracts:
issues = self.detect_complex(contract)
for issue in issues:
func, cause = issue.values()
txt = "{} ({}) is a complex function:\n"
if cause == self.CAUSE_EXTERNAL_CALL:
txt += "\t- Reason: High number of external calls"
if cause == self.CAUSE_CYCLOMATIC:
txt += "\t- Reason: High number of branches"
if cause == self.CAUSE_STATE_VARS:
txt += "\t- Reason: High number of modified state variables"
info = txt.format(func.canonical_name, func.source_mapping_str)
info = info + "\n"
self.log(info)
res = self.generate_result(info)
res.add(
func,
{
"high_number_of_external_calls": cause == self.CAUSE_EXTERNAL_CALL,
"high_number_of_branches": cause == self.CAUSE_CYCLOMATIC,
"high_number_of_state_variables": cause == self.CAUSE_STATE_VARS,
},
)
results.append(res)
return results

@ -1,7 +1,7 @@
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.slithir.operations import SolidityCall from slither.slithir.operations import SolidityCall
from slither.slithir.operations import InternalCall, InternalDynamicCall from slither.slithir.operations import InternalCall, InternalDynamicCall
from slither.formatters.functions.external_function import format from slither.formatters.functions.external_function import custom_format
class ExternalFunction(AbstractDetector): class ExternalFunction(AbstractDetector):
@ -81,7 +81,9 @@ class ExternalFunction(AbstractDetector):
# Somehow we couldn't resolve it, which shouldn't happen, as the provided function should be found if we could # Somehow we couldn't resolve it, which shouldn't happen, as the provided function should be found if we could
# not find some any more basic. # not find some any more basic.
raise Exception("Could not resolve the base-most function for the provided function.") raise Exception(
"Could not resolve the base-most function for the provided function."
)
@staticmethod @staticmethod
def get_all_function_definitions(base_most_function): def get_all_function_definitions(base_most_function):
@ -105,7 +107,7 @@ class ExternalFunction(AbstractDetector):
def function_parameters_written(function): def function_parameters_written(function):
return any(p in function.variables_written for p in function.parameters) return any(p in function.variables_written for p in function.parameters)
def _detect(self): def _detect(self): # pylint: disable=too-many-locals,too-many-branches
results = [] results = []
# Create a set to track contracts with dynamic calls. All contracts with dynamic calls could potentially be # Create a set to track contracts with dynamic calls. All contracts with dynamic calls could potentially be
@ -157,20 +159,22 @@ class ExternalFunction(AbstractDetector):
all_function_definitions = set( all_function_definitions = set(
self.get_all_function_definitions(base_most_function) self.get_all_function_definitions(base_most_function)
) )
completed_functions = completed_functions.union(all_function_definitions) completed_functions = completed_functions.union(
all_function_definitions
)
# Filter false-positives: Determine if any of these sources have dynamic calls, if so, flag all of these # Filter false-positives: Determine if any of these sources have dynamic calls, if so, flag all of these
# function definitions, and then flag all functions in all contracts that make dynamic calls. # function definitions, and then flag all functions in all contracts that make dynamic calls.
sources_with_dynamic_calls = set(all_possible_sources) & dynamic_call_contracts sources_with_dynamic_calls = (
set(all_possible_sources) & dynamic_call_contracts
)
if sources_with_dynamic_calls: if sources_with_dynamic_calls:
functions_in_dynamic_call_sources = set( functions_in_dynamic_call_sources = {
[ f
f for dyn_contract in sources_with_dynamic_calls
for dyn_contract in sources_with_dynamic_calls for f in dyn_contract.functions
for f in dyn_contract.functions if not f.is_constructor
if not f.is_constructor }
]
)
completed_functions = completed_functions.union( completed_functions = completed_functions.union(
functions_in_dynamic_call_sources functions_in_dynamic_call_sources
) )
@ -200,10 +204,12 @@ class ExternalFunction(AbstractDetector):
function_definition = all_function_definitions[0] function_definition = all_function_definitions[0]
all_function_definitions = all_function_definitions[1:] all_function_definitions = all_function_definitions[1:]
info = [f"{function_definition.full_name} should be declared external:\n"] info = [
info += [f"\t- ", function_definition, "\n"] f"{function_definition.full_name} should be declared external:\n"
]
info += ["\t- ", function_definition, "\n"]
for other_function_definition in all_function_definitions: for other_function_definition in all_function_definitions:
info += [f"\t- ", other_function_definition, "\n"] info += ["\t- ", other_function_definition, "\n"]
res = self.generate_result(info) res = self.generate_result(info)
@ -213,4 +219,4 @@ class ExternalFunction(AbstractDetector):
@staticmethod @staticmethod
def _format(slither, result): def _format(slither, result):
format(slither, result) custom_format(slither, result)

@ -20,7 +20,9 @@ class Suicidal(AbstractDetector):
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#suicidal" WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#suicidal"
WIKI_TITLE = "Suicidal" WIKI_TITLE = "Suicidal"
WIKI_DESCRIPTION = "Unprotected call to a function executing `selfdestruct`/`suicide`." WIKI_DESCRIPTION = (
"Unprotected call to a function executing `selfdestruct`/`suicide`."
)
WIKI_EXPLOIT_SCENARIO = """ WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Suicidal{ contract Suicidal{

@ -1,6 +1,6 @@
import re import re
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.naming_convention.naming_convention import format from slither.formatters.naming_convention.naming_convention import custom_format
class NamingConvention(AbstractDetector): class NamingConvention(AbstractDetector):
@ -54,7 +54,7 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2
def should_avoid_name(name): def should_avoid_name(name):
return re.search("^[lOI]$", name) is not None return re.search("^[lOI]$", name) is not None
def _detect(self): def _detect(self): # pylint: disable=too-many-branches,too-many-statements
results = [] results = []
for contract in self.contracts: for contract in self.contracts:
@ -91,7 +91,9 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2
"private", "private",
] and self.is_mixed_case_with_underscore(func.name): ] and self.is_mixed_case_with_underscore(func.name):
continue continue
if func.name.startswith("echidna_") or func.name.startswith("crytic_"): if func.name.startswith("echidna_") or func.name.startswith(
"crytic_"
):
continue continue
info = ["Function ", func, " is not in mixedCase\n"] info = ["Function ", func, " is not in mixedCase\n"]
@ -106,22 +108,34 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2
if argument in func.variables_read_or_written: if argument in func.variables_read_or_written:
correct_naming = self.is_mixed_case(argument.name) correct_naming = self.is_mixed_case(argument.name)
else: else:
correct_naming = self.is_mixed_case_with_underscore(argument.name) correct_naming = self.is_mixed_case_with_underscore(
argument.name
)
if not correct_naming: if not correct_naming:
info = ["Parameter ", argument, " is not in mixedCase\n"] info = ["Parameter ", argument, " is not in mixedCase\n"]
res = self.generate_result(info) res = self.generate_result(info)
res.add(argument, {"target": "parameter", "convention": "mixedCase"}) res.add(
argument, {"target": "parameter", "convention": "mixedCase"}
)
results.append(res) results.append(res)
for var in contract.state_variables_declared: for var in contract.state_variables_declared:
if self.should_avoid_name(var.name): if self.should_avoid_name(var.name):
if not self.is_upper_case_with_underscores(var.name): if not self.is_upper_case_with_underscores(var.name):
info = ["Variable ", var, " used l, O, I, which should not be used\n"] info = [
"Variable ",
var,
" used l, O, I, which should not be used\n",
]
res = self.generate_result(info) res = self.generate_result(info)
res.add( res.add(
var, {"target": "variable", "convention": "l_O_I_should_not_be_used"} var,
{
"target": "variable",
"convention": "l_O_I_should_not_be_used",
},
) )
results.append(res) results.append(res)
@ -131,7 +145,11 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2
continue continue
if not self.is_upper_case_with_underscores(var.name): if not self.is_upper_case_with_underscores(var.name):
info = ["Constant ", var, " is not in UPPER_CASE_WITH_UNDERSCORES\n"] info = [
"Constant ",
var,
" is not in UPPER_CASE_WITH_UNDERSCORES\n",
]
res = self.generate_result(info) res = self.generate_result(info)
res.add( res.add(
@ -175,4 +193,4 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2
@staticmethod @staticmethod
def _format(slither, result): def _format(slither, result):
format(slither, result) custom_format(slither, result)

@ -7,7 +7,10 @@ from typing import List, Tuple
from slither.analyses.data_dependency.data_dependency import is_dependent from slither.analyses.data_dependency.data_dependency import is_dependent
from slither.core.cfg.node import Node from slither.core.cfg.node import Node
from slither.core.declarations import Function, Contract from slither.core.declarations import Function, Contract
from slither.core.declarations.solidity_variables import SolidityVariableComposed, SolidityVariable from slither.core.declarations.solidity_variables import (
SolidityVariableComposed,
SolidityVariable,
)
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.slithir.operations import Binary, BinaryType from slither.slithir.operations import Binary, BinaryType
@ -17,7 +20,9 @@ def _timestamp(func: Function) -> List[Node]:
for node in func.nodes: for node in func.nodes:
if node.contains_require_or_assert(): if node.contains_require_or_assert():
for var in node.variables_read: for var in node.variables_read:
if is_dependent(var, SolidityVariableComposed("block.timestamp"), func.contract): if is_dependent(
var, SolidityVariableComposed("block.timestamp"), func.contract
):
ret.add(node) ret.add(node)
if is_dependent(var, SolidityVariable("now"), func.contract): if is_dependent(var, SolidityVariable("now"), func.contract):
ret.add(node) ret.add(node)
@ -33,7 +38,9 @@ def _timestamp(func: Function) -> List[Node]:
return sorted(list(ret), key=lambda x: x.node_id) return sorted(list(ret), key=lambda x: x.node_id)
def _detect_dangerous_timestamp(contract: Contract) -> List[Tuple[Function, List[Node]]]: def _detect_dangerous_timestamp(
contract: Contract,
) -> List[Tuple[Function, List[Node]]]:
""" """
Args: Args:
contract (Contract) contract (Contract)
@ -49,20 +56,18 @@ def _detect_dangerous_timestamp(contract: Contract) -> List[Tuple[Function, List
class Timestamp(AbstractDetector): class Timestamp(AbstractDetector):
"""
"""
ARGUMENT = "timestamp" ARGUMENT = "timestamp"
HELP = "Dangerous usage of `block.timestamp`" HELP = "Dangerous usage of `block.timestamp`"
IMPACT = DetectorClassification.LOW IMPACT = DetectorClassification.LOW
CONFIDENCE = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#block-timestamp" WIKI = (
"https://github.com/crytic/slither/wiki/Detector-Documentation#block-timestamp"
)
WIKI_TITLE = "Block timestamp" WIKI_TITLE = "Block timestamp"
WIKI_DESCRIPTION = ( WIKI_DESCRIPTION = "Dangerous usage of `block.timestamp`. `block.timestamp` can be manipulated by miners."
"Dangerous usage of `block.timestamp`. `block.timestamp` can be manipulated by miners."
)
WIKI_EXPLOIT_SCENARIO = """"Bob's contract relies on `block.timestamp` for its randomness. Eve is a miner and manipulates `block.timestamp` to exploit Bob's contract.""" WIKI_EXPLOIT_SCENARIO = """"Bob's contract relies on `block.timestamp` for its randomness. Eve is a miner and manipulates `block.timestamp` to exploit Bob's contract."""
WIKI_RECOMMENDATION = "Avoid relying on `block.timestamp`." WIKI_RECOMMENDATION = "Avoid relying on `block.timestamp`."

@ -16,7 +16,9 @@ class LowLevelCalls(AbstractDetector):
IMPACT = DetectorClassification.INFORMATIONAL IMPACT = DetectorClassification.INFORMATIONAL
CONFIDENCE = DetectorClassification.HIGH CONFIDENCE = DetectorClassification.HIGH
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#low-level-calls" WIKI = (
"https://github.com/crytic/slither/wiki/Detector-Documentation#low-level-calls"
)
WIKI_TITLE = "Low-level calls" WIKI_TITLE = "Low-level calls"
WIKI_DESCRIPTION = "The use of low-level calls is error-prone. Low-level calls do not check for [code existence](https://solidity.readthedocs.io/en/v0.4.25/control-structures.html#error-handling-assert-require-revert-and-exceptions) or call success." WIKI_DESCRIPTION = "The use of low-level calls is error-prone. Low-level calls do not check for [code existence](https://solidity.readthedocs.io/en/v0.4.25/control-structures.html#error-handling-assert-require-revert-and-exceptions) or call success."

@ -2,7 +2,7 @@
Module detecting unused return values from low level Module detecting unused return values from low level
""" """
from slither.detectors.abstract_detector import DetectorClassification from slither.detectors.abstract_detector import DetectorClassification
from .unused_return_values import UnusedReturnValues from slither.detectors.operations.unused_return_values import UnusedReturnValues
from slither.slithir.operations import LowLevelCall from slither.slithir.operations import LowLevelCall
@ -32,9 +32,11 @@ The return value of the low-level call is not checked, so if the call fails, the
If the low level is used to prevent blocking operations, consider logging failed calls. If the low level is used to prevent blocking operations, consider logging failed calls.
""" """
WIKI_RECOMMENDATION = "Ensure that the return value of a low-level call is checked or logged." WIKI_RECOMMENDATION = (
"Ensure that the return value of a low-level call is checked or logged."
)
_txt_description = "low-level calls" _txt_description = "low-level calls"
def _is_instance(self, ir): def _is_instance(self, ir): # pylint: disable=no-self-use
return isinstance(ir, LowLevelCall) return isinstance(ir, LowLevelCall)

@ -3,7 +3,7 @@ Module detecting unused return values from send
""" """
from slither.detectors.abstract_detector import DetectorClassification from slither.detectors.abstract_detector import DetectorClassification
from .unused_return_values import UnusedReturnValues from slither.detectors.operations.unused_return_values import UnusedReturnValues
from slither.slithir.operations import Send from slither.slithir.operations import Send
@ -17,7 +17,9 @@ class UncheckedSend(UnusedReturnValues):
IMPACT = DetectorClassification.MEDIUM IMPACT = DetectorClassification.MEDIUM
CONFIDENCE = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#unchecked-send" WIKI = (
"https://github.com/crytic/slither/wiki/Detector-Documentation#unchecked-send"
)
WIKI_TITLE = "Unchecked Send" WIKI_TITLE = "Unchecked Send"
WIKI_DESCRIPTION = "The return value of a `send` is not checked." WIKI_DESCRIPTION = "The return value of a `send` is not checked."
@ -37,5 +39,5 @@ If `send` is used to prevent blocking operations, consider logging the failed `s
_txt_description = "send calls" _txt_description = "send calls"
def _is_instance(self, ir): def _is_instance(self, ir): # pylint: disable=no-self-use
return isinstance(ir, Send) return isinstance(ir, Send)

@ -2,9 +2,11 @@
Module detecting unused return values from external calls Module detecting unused return values from external calls
""" """
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.slithir.operations import HighLevelCall, InternalCall, InternalDynamicCall
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.slithir.operations import HighLevelCall
class UnusedReturnValues(AbstractDetector): class UnusedReturnValues(AbstractDetector):
@ -20,9 +22,7 @@ class UnusedReturnValues(AbstractDetector):
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#unused-return" WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#unused-return"
WIKI_TITLE = "Unused return" WIKI_TITLE = "Unused return"
WIKI_DESCRIPTION = ( WIKI_DESCRIPTION = "The return value of an external call is not stored in a local or state variable."
"The return value of an external call is not stored in a local or state variable."
)
WIKI_EXPLOIT_SCENARIO = """ WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract MyConc{ contract MyConc{
@ -34,14 +34,16 @@ contract MyConc{
``` ```
`MyConc` calls `add` of `SafeMath`, but does not store the result in `a`. As a result, the computation has no effect.""" `MyConc` calls `add` of `SafeMath`, but does not store the result in `a`. As a result, the computation has no effect."""
WIKI_RECOMMENDATION = "Ensure that all the return values of the function calls are used." WIKI_RECOMMENDATION = (
"Ensure that all the return values of the function calls are used."
)
_txt_description = "external calls" _txt_description = "external calls"
def _is_instance(self, ir): def _is_instance(self, ir): # pylint: disable=no-self-use
return isinstance(ir, HighLevelCall) return isinstance(ir, HighLevelCall)
def detect_unused_return_values(self, f): def detect_unused_return_values(self, f): # pylint: disable=no-self-use
""" """
Return the nodes where the return value of a call is unused Return the nodes where the return value of a call is unused
Args: Args:
@ -76,7 +78,7 @@ contract MyConc{
if unused_return: if unused_return:
for node in unused_return: for node in unused_return:
info = [f, f" ignores return value by ", node, "\n"] info = [f, " ignores return value by ", node, "\n"]
res = self.generate_result(info) res = self.generate_result(info)

@ -9,7 +9,9 @@ class VoidConstructor(AbstractDetector):
IMPACT = DetectorClassification.LOW IMPACT = DetectorClassification.LOW
CONFIDENCE = DetectorClassification.HIGH CONFIDENCE = DetectorClassification.HIGH
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#void-constructor" WIKI = (
"https://github.com/crytic/slither/wiki/Detector-Documentation#void-constructor"
)
WIKI_TITLE = "Void constructor" WIKI_TITLE = "Void constructor"
WIKI_DESCRIPTION = "Detect the call to a constructor that is not implemented" WIKI_DESCRIPTION = "Detect the call to a constructor that is not implemented"

@ -16,7 +16,10 @@ from slither.slithir.operations import Call, EventCall
def union_dict(d1, d2): def union_dict(d1, d2):
d3 = {k: d1.get(k, set()) | d2.get(k, set()) for k in set(list(d1.keys()) + list(d2.keys()))} d3 = {
k: d1.get(k, set()) | d2.get(k, set())
for k in set(list(d1.keys()) + list(d2.keys()))
}
return defaultdict(set, d3) return defaultdict(set, d3)
@ -40,7 +43,8 @@ def is_subset(
def to_hashable(d: Dict[Node, Set[Node]]): def to_hashable(d: Dict[Node, Set[Node]]):
list_tuple = list( list_tuple = list(
tuple((k, tuple(sorted(values, key=lambda x: x.node_id)))) for k, values in d.items() tuple((k, tuple(sorted(values, key=lambda x: x.node_id))))
for k, values in d.items()
) )
return tuple(sorted(list_tuple, key=lambda x: x[0].node_id)) return tuple(sorted(list_tuple, key=lambda x: x[0].node_id))
@ -125,9 +129,12 @@ class AbstractState:
if key != skip_father if key != skip_father
}, },
) )
self._reads = union_dict(self._reads, father.context[detector.KEY].reads) self._reads = union_dict(
self._reads, father.context[detector.KEY].reads
)
self._reads_prior_calls = union_dict( self._reads_prior_calls = union_dict(
self.reads_prior_calls, father.context[detector.KEY].reads_prior_calls self.reads_prior_calls,
father.context[detector.KEY].reads_prior_calls,
) )
def analyze_node(self, node, detector): def analyze_node(self, node, detector):
@ -178,14 +185,36 @@ class AbstractState:
self._send_eth = union_dict(self._send_eth, fathers.send_eth) self._send_eth = union_dict(self._send_eth, fathers.send_eth)
self._calls = union_dict(self._calls, fathers.calls) self._calls = union_dict(self._calls, fathers.calls)
self._reads = union_dict(self._reads, fathers.reads) self._reads = union_dict(self._reads, fathers.reads)
self._reads_prior_calls = union_dict(self._reads_prior_calls, fathers.reads_prior_calls) self._reads_prior_calls = union_dict(
self._reads_prior_calls, fathers.reads_prior_calls
)
def does_not_bring_new_info(self, new_info): def does_not_bring_new_info(self, new_info):
if is_subset(new_info.calls, self.calls): if is_subset(new_info.calls, self.calls):
if is_subset(new_info.send_eth, self.send_eth): if is_subset(new_info.send_eth, self.send_eth):
if is_subset(new_info.reads, self.reads): if is_subset(new_info.reads, self.reads):
if dict_are_equal(new_info.reads_prior_calls, self.reads_prior_calls): if dict_are_equal(
new_info.reads_prior_calls, self.reads_prior_calls
):
return True return True
return False
def _filter_if(node):
"""
Check if the node is a condtional node where
there is an external call checked
Heuristic:
- The call is a IF node
- It contains a, external call
- The condition is the negation (!)
This will work only on naive implementation
"""
return (
isinstance(node.expression, UnaryOperation)
and node.expression.type == UnaryOperationType.BANG
)
class Reentrancy(AbstractDetector): class Reentrancy(AbstractDetector):
@ -215,22 +244,6 @@ class Reentrancy(AbstractDetector):
""" """
return isinstance(ir, Call) and ir.can_send_eth() return isinstance(ir, Call) and ir.can_send_eth()
def _filter_if(self, node):
"""
Check if the node is a condtional node where
there is an external call checked
Heuristic:
- The call is a IF node
- It contains a, external call
- The condition is the negation (!)
This will work only on naive implementation
"""
return (
isinstance(node.expression, UnaryOperation)
and node.expression.type == UnaryOperationType.BANG
)
def _explore(self, node, visited, skip_father=None): def _explore(self, node, visited, skip_father=None):
""" """
Explore the CFG and look for re-entrancy Explore the CFG and look for re-entrancy
@ -266,7 +279,7 @@ class Reentrancy(AbstractDetector):
sons = node.sons sons = node.sons
if contains_call and node.type in [NodeType.IF, NodeType.IFLOOP]: if contains_call and node.type in [NodeType.IF, NodeType.IFLOOP]:
if self._filter_if(node): if _filter_if(node):
son = sons[0] son = sons[0]
self._explore(son, visited, node) self._explore(son, visited, node)
sons = sons[1:] sons = sons[1:]
@ -279,8 +292,6 @@ class Reentrancy(AbstractDetector):
self._explore(son, visited) self._explore(son, visited)
def detect_reentrancy(self, contract): def detect_reentrancy(self, contract):
"""
"""
for function in contract.functions_and_modifiers_declared: for function in contract.functions_and_modifiers_declared:
if function.is_implemented: if function.is_implemented:
if self.KEY in function.context: if self.KEY in function.context:
@ -296,7 +307,7 @@ class Reentrancy(AbstractDetector):
# new variables written # new variables written
# This speedup the exploration through a light fixpoint # This speedup the exploration through a light fixpoint
# Its particular useful on 'complex' functions with several loops and conditions # Its particular useful on 'complex' functions with several loops and conditions
self.visited_all_paths = {} self.visited_all_paths = {} # pylint: disable=attribute-defined-outside-init
for c in self.contracts: for c in self.contracts:
self.detect_reentrancy(c) self.detect_reentrancy(c)

@ -20,9 +20,7 @@ class ReentrancyBenign(Reentrancy):
IMPACT = DetectorClassification.LOW IMPACT = DetectorClassification.LOW
CONFIDENCE = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM
WIKI = ( WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-2"
"https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-2"
)
WIKI_TITLE = "Reentrancy vulnerabilities" WIKI_TITLE = "Reentrancy vulnerabilities"
WIKI_DESCRIPTION = """ WIKI_DESCRIPTION = """
@ -62,13 +60,15 @@ Only report reentrancy that acts as a double call (see `reentrancy-eth`, `reentr
for v in node.context[self.KEY].written for v in node.context[self.KEY].written
if v in node.context[self.KEY].reads_prior_calls[c] if v in node.context[self.KEY].reads_prior_calls[c]
] ]
not_read_then_written = set( not_read_then_written = {
[ FindingValue(
FindingValue(v, node, tuple(sorted(nodes, key=lambda x: x.node_id))) v,
for (v, nodes) in node.context[self.KEY].written.items() node,
if v not in read_then_written tuple(sorted(nodes, key=lambda x: x.node_id)),
] )
) for (v, nodes) in node.context[self.KEY].written.items()
if v not in read_then_written
}
if not_read_then_written: if not_read_then_written:
# calls are ordered # calls are ordered
finding_key = FindingKey( finding_key = FindingKey(
@ -79,7 +79,7 @@ Only report reentrancy that acts as a double call (see `reentrancy-eth`, `reentr
result[finding_key] |= not_read_then_written result[finding_key] |= not_read_then_written
return result return result
def _detect(self): def _detect(self): # pylint: disable=too-many-branches
""" """
""" """
@ -88,12 +88,16 @@ Only report reentrancy that acts as a double call (see `reentrancy-eth`, `reentr
results = [] results = []
result_sorted = sorted(list(reentrancies.items()), key=lambda x: x[0].function.name) result_sorted = sorted(
list(reentrancies.items()), key=lambda x: x[0].function.name
)
varsWritten: List[FindingValue] varsWritten: List[FindingValue]
for (func, calls, send_eth), varsWritten in result_sorted: for (func, calls, send_eth), varsWritten in result_sorted:
calls = sorted(list(set(calls)), key=lambda x: x[0].node_id) calls = sorted(list(set(calls)), key=lambda x: x[0].node_id)
send_eth = sorted(list(set(send_eth)), key=lambda x: x[0].node_id) send_eth = sorted(list(set(send_eth)), key=lambda x: x[0].node_id)
varsWritten = sorted(varsWritten, key=lambda x: (x.variable.name, x.node.node_id)) varsWritten = sorted(
varsWritten, key=lambda x: (x.variable.name, x.node.node_id)
)
info = ["Reentrancy in ", func, ":\n"] info = ["Reentrancy in ", func, ":\n"]
@ -128,18 +132,24 @@ Only report reentrancy that acts as a double call (see `reentrancy-eth`, `reentr
res.add(call_info, {"underlying_type": "external_calls"}) res.add(call_info, {"underlying_type": "external_calls"})
for call_list_info in calls_list: for call_list_info in calls_list:
if call_list_info != call_info: if call_list_info != call_info:
res.add(call_list_info, {"underlying_type": "external_calls_sending_eth"}) res.add(
call_list_info,
{"underlying_type": "external_calls_sending_eth"},
)
# #
# If the calls are not the same ones that send eth, add the eth sending nodes. # If the calls are not the same ones that send eth, add the eth sending nodes.
if calls != send_eth: if calls != send_eth:
for (call_info, calls_list) in calls: for (call_info, calls_list) in calls:
res.add(call_info, {"underlying_type": "external_calls_sending_eth"}) res.add(
call_info, {"underlying_type": "external_calls_sending_eth"}
)
for call_list_info in calls_list: for call_list_info in calls_list:
if call_list_info != call_info: if call_list_info != call_info:
res.add( res.add(
call_list_info, {"underlying_type": "external_calls_sending_eth"} call_list_info,
{"underlying_type": "external_calls_sending_eth"},
) )
# Add all variables written via nodes which write them. # Add all variables written via nodes which write them.

@ -20,9 +20,7 @@ class ReentrancyEth(Reentrancy):
IMPACT = DetectorClassification.HIGH IMPACT = DetectorClassification.HIGH
CONFIDENCE = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM
WIKI = ( WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities"
"https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities"
)
WIKI_TITLE = "Reentrancy vulnerabilities" WIKI_TITLE = "Reentrancy vulnerabilities"
WIKI_DESCRIPTION = """ WIKI_DESCRIPTION = """
@ -48,7 +46,7 @@ Bob uses the re-entrancy bug to call `withdrawBalance` two times, and withdraw m
def find_reentrancies(self): def find_reentrancies(self):
result = defaultdict(set) result = defaultdict(set)
for contract in self.contracts: for contract in self.contracts: # pylint: disable=too-many-nested-blocks
for f in contract.functions_and_modifiers_declared: for f in contract.functions_and_modifiers_declared:
for node in f.nodes: for node in f.nodes:
# dead code # dead code
@ -61,15 +59,17 @@ Bob uses the re-entrancy bug to call `withdrawBalance` two times, and withdraw m
for c in node.context[self.KEY].calls: for c in node.context[self.KEY].calls:
if c == node: if c == node:
continue continue
read_then_written |= set( read_then_written |= {
[ FindingValue(
FindingValue( v,
v, node, tuple(sorted(nodes, key=lambda x: x.node_id)) node,
) tuple(sorted(nodes, key=lambda x: x.node_id)),
for (v, nodes) in node.context[self.KEY].written.items() )
if v in node.context[self.KEY].reads_prior_calls[c] for (v, nodes) in node.context[
] self.KEY
) ].written.items()
if v in node.context[self.KEY].reads_prior_calls[c]
}
if read_then_written: if read_then_written:
# calls are ordered # calls are ordered
@ -82,7 +82,7 @@ Bob uses the re-entrancy bug to call `withdrawBalance` two times, and withdraw m
result[finding_key] |= set(read_then_written) result[finding_key] |= set(read_then_written)
return result return result
def _detect(self): def _detect(self): # pylint: disable=too-many-branches
""" """
""" """
super()._detect() super()._detect()
@ -91,12 +91,16 @@ Bob uses the re-entrancy bug to call `withdrawBalance` two times, and withdraw m
results = [] results = []
result_sorted = sorted(list(reentrancies.items()), key=lambda x: x[0].function.name) result_sorted = sorted(
list(reentrancies.items()), key=lambda x: x[0].function.name
)
varsWritten: List[FindingValue] varsWritten: List[FindingValue]
for (func, calls, send_eth), varsWritten in result_sorted: for (func, calls, send_eth), varsWritten in result_sorted:
calls = sorted(list(set(calls)), key=lambda x: x[0].node_id) calls = sorted(list(set(calls)), key=lambda x: x[0].node_id)
send_eth = sorted(list(set(send_eth)), key=lambda x: x[0].node_id) send_eth = sorted(list(set(send_eth)), key=lambda x: x[0].node_id)
varsWritten = sorted(varsWritten, key=lambda x: (x.variable.name, x.node.node_id)) varsWritten = sorted(
varsWritten, key=lambda x: (x.variable.name, x.node.node_id)
)
info = ["Reentrancy in ", func, ":\n"] info = ["Reentrancy in ", func, ":\n"]
info += ["\tExternal calls:\n"] info += ["\tExternal calls:\n"]
@ -130,16 +134,22 @@ Bob uses the re-entrancy bug to call `withdrawBalance` two times, and withdraw m
res.add(call_info, {"underlying_type": "external_calls"}) res.add(call_info, {"underlying_type": "external_calls"})
for call_list_info in calls_list: for call_list_info in calls_list:
if call_list_info != call_info: if call_list_info != call_info:
res.add(call_list_info, {"underlying_type": "external_calls_sending_eth"}) res.add(
call_list_info,
{"underlying_type": "external_calls_sending_eth"},
)
# If the calls are not the same ones that send eth, add the eth sending nodes. # If the calls are not the same ones that send eth, add the eth sending nodes.
if calls != send_eth: if calls != send_eth:
for (call_info, calls_list) in send_eth: for (call_info, calls_list) in send_eth:
res.add(call_info, {"underlying_type": "external_calls_sending_eth"}) res.add(
call_info, {"underlying_type": "external_calls_sending_eth"}
)
for call_list_info in calls_list: for call_list_info in calls_list:
if call_list_info != call_info: if call_list_info != call_info:
res.add( res.add(
call_list_info, {"underlying_type": "external_calls_sending_eth"} call_list_info,
{"underlying_type": "external_calls_sending_eth"},
) )
# Add all variables written via nodes which write them. # Add all variables written via nodes which write them.

@ -19,9 +19,7 @@ class ReentrancyEvent(Reentrancy):
IMPACT = DetectorClassification.LOW IMPACT = DetectorClassification.LOW
CONFIDENCE = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM
WIKI = ( WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-3"
"https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-3"
)
WIKI_TITLE = "Reentrancy vulnerabilities" WIKI_TITLE = "Reentrancy vulnerabilities"
WIKI_DESCRIPTION = """ WIKI_DESCRIPTION = """
@ -60,19 +58,19 @@ If `d.()` re-enters, the `Counter` events will be shown in an incorrect order, w
calls=to_hashable(node.context[self.KEY].calls), calls=to_hashable(node.context[self.KEY].calls),
send_eth=to_hashable(node.context[self.KEY].send_eth), send_eth=to_hashable(node.context[self.KEY].send_eth),
) )
finding_vars = set( finding_vars = {
[ FindingValue(
FindingValue( e,
e, e.node, tuple(sorted(nodes, key=lambda x: x.node_id)) e.node,
) tuple(sorted(nodes, key=lambda x: x.node_id)),
for (e, nodes) in node.context[self.KEY].events.items() )
] for (e, nodes) in node.context[self.KEY].events.items()
) }
if finding_vars: if finding_vars:
result[finding_key] |= finding_vars result[finding_key] |= finding_vars
return result return result
def _detect(self): def _detect(self): # pylint: disable=too-many-branches
""" """
""" """
super()._detect() super()._detect()
@ -85,7 +83,9 @@ If `d.()` re-enters, the `Counter` events will be shown in an incorrect order, w
for (func, calls, send_eth), events in result_sorted: for (func, calls, send_eth), events in result_sorted:
calls = sorted(list(set(calls)), key=lambda x: x[0].node_id) calls = sorted(list(set(calls)), key=lambda x: x[0].node_id)
send_eth = sorted(list(set(send_eth)), key=lambda x: x[0].node_id) send_eth = sorted(list(set(send_eth)), key=lambda x: x[0].node_id)
events = sorted(events, key=lambda x: (str(x.variable.name), x.node.node_id)) events = sorted(
events, key=lambda x: (str(x.variable.name), x.node.node_id)
)
info = ["Reentrancy in ", func, ":\n"] info = ["Reentrancy in ", func, ":\n"]
info += ["\tExternal calls:\n"] info += ["\tExternal calls:\n"]
@ -119,18 +119,24 @@ If `d.()` re-enters, the `Counter` events will be shown in an incorrect order, w
res.add(call_info, {"underlying_type": "external_calls"}) res.add(call_info, {"underlying_type": "external_calls"})
for call_list_info in calls_list: for call_list_info in calls_list:
if call_list_info != call_info: if call_list_info != call_info:
res.add(call_list_info, {"underlying_type": "external_calls_sending_eth"}) res.add(
call_list_info,
{"underlying_type": "external_calls_sending_eth"},
)
# #
# If the calls are not the same ones that send eth, add the eth sending nodes. # If the calls are not the same ones that send eth, add the eth sending nodes.
if calls != send_eth: if calls != send_eth:
for (call_info, calls_list) in send_eth: for (call_info, calls_list) in send_eth:
res.add(call_info, {"underlying_type": "external_calls_sending_eth"}) res.add(
call_info, {"underlying_type": "external_calls_sending_eth"}
)
for call_list_info in calls_list: for call_list_info in calls_list:
if call_list_info != call_info: if call_list_info != call_info:
res.add( res.add(
call_list_info, {"underlying_type": "external_calls_sending_eth"} call_list_info,
{"underlying_type": "external_calls_sending_eth"},
) )
for finding_value in events: for finding_value in events:

@ -23,9 +23,7 @@ class ReentrancyNoGas(Reentrancy):
IMPACT = DetectorClassification.INFORMATIONAL IMPACT = DetectorClassification.INFORMATIONAL
CONFIDENCE = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM
WIKI = ( WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-4"
"https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-4"
)
WIKI_TITLE = "Reentrancy vulnerabilities" WIKI_TITLE = "Reentrancy vulnerabilities"
WIKI_DESCRIPTION = """ WIKI_DESCRIPTION = """
@ -71,25 +69,27 @@ Only report reentrancy that is based on `transfer` or `send`."""
calls=to_hashable(node.context[self.KEY].calls), calls=to_hashable(node.context[self.KEY].calls),
send_eth=to_hashable(node.context[self.KEY].send_eth), send_eth=to_hashable(node.context[self.KEY].send_eth),
) )
finding_vars = set( finding_vars = {
[ FindingValue(
FindingValue(v, node, tuple(sorted(nodes, key=lambda x: x.node_id))) v,
for (v, nodes) in node.context[self.KEY].written.items() node,
] tuple(sorted(nodes, key=lambda x: x.node_id)),
) )
finding_vars |= set( for (v, nodes) in node.context[self.KEY].written.items()
[ }
FindingValue( finding_vars |= {
e, e.node, tuple(sorted(nodes, key=lambda x: x.node_id)) FindingValue(
) e,
for (e, nodes) in node.context[self.KEY].events.items() e.node,
] tuple(sorted(nodes, key=lambda x: x.node_id)),
) )
for (e, nodes) in node.context[self.KEY].events.items()
}
if finding_vars: if finding_vars:
result[finding_key] |= finding_vars result[finding_key] |= finding_vars
return result return result
def _detect(self): def _detect(self): # pylint: disable=too-many-branches,too-many-locals
""" """
""" """
@ -123,7 +123,9 @@ Only report reentrancy that is based on `transfer` or `send`."""
for (v, node, nodes) in varsWrittenOrEvent for (v, node, nodes) in varsWrittenOrEvent
if isinstance(v, Variable) if isinstance(v, Variable)
] ]
varsWritten = sorted(varsWritten, key=lambda x: (x.variable.name, x.node.node_id)) varsWritten = sorted(
varsWritten, key=lambda x: (x.variable.name, x.node.node_id)
)
if varsWritten: if varsWritten:
info += ["\tState variables written after the call(s):\n"] info += ["\tState variables written after the call(s):\n"]
for finding_value in varsWritten: for finding_value in varsWritten:
@ -157,18 +159,24 @@ Only report reentrancy that is based on `transfer` or `send`."""
res.add(call_info, {"underlying_type": "external_calls"}) res.add(call_info, {"underlying_type": "external_calls"})
for call_list_info in calls_list: for call_list_info in calls_list:
if call_list_info != call_info: if call_list_info != call_info:
res.add(call_list_info, {"underlying_type": "external_calls_sending_eth"}) res.add(
call_list_info,
{"underlying_type": "external_calls_sending_eth"},
)
# #
# If the calls are not the same ones that send eth, add the eth sending nodes. # If the calls are not the same ones that send eth, add the eth sending nodes.
if calls != send_eth: if calls != send_eth:
for (call_info, calls_list) in send_eth: for (call_info, calls_list) in send_eth:
res.add(call_info, {"underlying_type": "external_calls_sending_eth"}) res.add(
call_info, {"underlying_type": "external_calls_sending_eth"}
)
for call_list_info in calls_list: for call_list_info in calls_list:
if call_list_info != call_info: if call_list_info != call_info:
res.add( res.add(
call_list_info, {"underlying_type": "external_calls_sending_eth"} call_list_info,
{"underlying_type": "external_calls_sending_eth"},
) )
# Add all variables written via nodes which write them. # Add all variables written via nodes which write them.

@ -19,9 +19,7 @@ class ReentrancyReadBeforeWritten(Reentrancy):
IMPACT = DetectorClassification.MEDIUM IMPACT = DetectorClassification.MEDIUM
CONFIDENCE = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM
WIKI = ( WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-1"
"https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-1"
)
WIKI_TITLE = "Reentrancy vulnerabilities" WIKI_TITLE = "Reentrancy vulnerabilities"
WIKI_DESCRIPTION = """ WIKI_DESCRIPTION = """
@ -45,26 +43,31 @@ Do not report reentrancies that involve Ether (see `reentrancy-eth`)."""
def find_reentrancies(self): def find_reentrancies(self):
result = defaultdict(set) result = defaultdict(set)
for contract in self.contracts: for contract in self.contracts: # pylint: disable=too-many-nested-blocks
for f in contract.functions_and_modifiers_declared: for f in contract.functions_and_modifiers_declared:
for node in f.nodes: for node in f.nodes:
# dead code # dead code
if self.KEY not in node.context: if self.KEY not in node.context:
continue continue
if node.context[self.KEY].calls and not node.context[self.KEY].send_eth: if (
node.context[self.KEY].calls
and not node.context[self.KEY].send_eth
):
read_then_written = set() read_then_written = set()
for c in node.context[self.KEY].calls: for c in node.context[self.KEY].calls:
if c == node: if c == node:
continue continue
read_then_written |= set( read_then_written |= {
[ FindingValue(
FindingValue( v,
v, node, tuple(sorted(nodes, key=lambda x: x.node_id)) node,
) tuple(sorted(nodes, key=lambda x: x.node_id)),
for (v, nodes) in node.context[self.KEY].written.items() )
if v in node.context[self.KEY].reads_prior_calls[c] for (v, nodes) in node.context[
] self.KEY
) ].written.items()
if v in node.context[self.KEY].reads_prior_calls[c]
}
# We found a potential re-entrancy bug # We found a potential re-entrancy bug
if read_then_written: if read_then_written:
@ -76,7 +79,7 @@ Do not report reentrancies that involve Ether (see `reentrancy-eth`)."""
result[finding_key] |= read_then_written result[finding_key] |= read_then_written
return result return result
def _detect(self): def _detect(self): # pylint: disable=too-many-branches
""" """
""" """
@ -85,10 +88,14 @@ Do not report reentrancies that involve Ether (see `reentrancy-eth`)."""
results = [] results = []
result_sorted = sorted(list(reentrancies.items()), key=lambda x: x[0].function.name) result_sorted = sorted(
list(reentrancies.items()), key=lambda x: x[0].function.name
)
for (func, calls), varsWritten in result_sorted: for (func, calls), varsWritten in result_sorted:
calls = sorted(list(set(calls)), key=lambda x: x[0].node_id) calls = sorted(list(set(calls)), key=lambda x: x[0].node_id)
varsWritten = sorted(varsWritten, key=lambda x: (x.variable.name, x.node.node_id)) varsWritten = sorted(
varsWritten, key=lambda x: (x.variable.name, x.node.node_id)
)
info = ["Reentrancy in ", func, ":\n"] info = ["Reentrancy in ", func, ":\n"]
@ -116,7 +123,10 @@ Do not report reentrancies that involve Ether (see `reentrancy-eth`)."""
res.add(call_info, {"underlying_type": "external_calls"}) res.add(call_info, {"underlying_type": "external_calls"})
for call_list_info in calls_list: for call_list_info in calls_list:
if call_list_info != call_info: if call_list_info != call_info:
res.add(call_list_info, {"underlying_type": "external_calls_sending_eth"}) res.add(
call_list_info,
{"underlying_type": "external_calls_sending_eth"},
)
# Add all variables written via nodes which write them. # Add all variables written via nodes which write them.
for finding_value in varsWritten: for finding_value in varsWritten:

@ -6,6 +6,20 @@ Recursively check the called functions
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
def detect_shadowing(contract):
ret = []
variables_fathers = []
for father in contract.inheritance:
if all(not f.is_implemented for f in father.functions + father.modifiers):
variables_fathers += father.state_variables_declared
for var in contract.state_variables_declared:
shadow = [v for v in variables_fathers if v.name == var.name]
if shadow:
ret.append([var] + shadow)
return ret
class ShadowingAbstractDetection(AbstractDetector): class ShadowingAbstractDetection(AbstractDetector):
""" """
Shadowing detection Shadowing detection
@ -34,19 +48,6 @@ contract DerivedContract is BaseContract{
WIKI_RECOMMENDATION = "Remove the state variable shadowing." WIKI_RECOMMENDATION = "Remove the state variable shadowing."
def detect_shadowing(self, contract):
ret = []
variables_fathers = []
for father in contract.inheritance:
if all(not f.is_implemented for f in father.functions + father.modifiers):
variables_fathers += father.state_variables_declared
for var in contract.state_variables_declared:
shadow = [v for v in variables_fathers if v.name == var.name]
if shadow:
ret.append([var] + shadow)
return ret
def _detect(self): def _detect(self):
""" Detect shadowing """ Detect shadowing
@ -56,8 +57,8 @@ contract DerivedContract is BaseContract{
""" """
results = [] results = []
for c in self.contracts: for contract in self.contracts:
shadowing = self.detect_shadowing(c) shadowing = detect_shadowing(contract)
if shadowing: if shadowing:
for all_variables in shadowing: for all_variables in shadowing:
shadow = all_variables[0] shadow = all_variables[0]

@ -179,7 +179,10 @@ contract Bug {
shadow_type = shadow[0] shadow_type = shadow[0]
shadow_object = shadow[1] shadow_object = shadow[1]
info = [shadow_object, f' ({shadow_type}) shadows built-in symbol"\n'] info = [
shadow_object,
f' ({shadow_type}) shadows built-in symbol"\n',
]
res = self.generate_result(info) res = self.generate_result(info)
results.append(res) results.append(res)

@ -47,7 +47,7 @@ contract Bug {
OVERSHADOWED_STATE_VARIABLE = "state variable" OVERSHADOWED_STATE_VARIABLE = "state variable"
OVERSHADOWED_EVENT = "event" OVERSHADOWED_EVENT = "event"
def detect_shadowing_definitions(self, contract): def detect_shadowing_definitions(self, contract): # pylint: disable=too-many-branches
""" Detects if functions, access modifiers, events, state variables, and local variables are named after """ Detects if functions, access modifiers, events, state variables, and local variables are named after
reserved keywords. Any such definitions are returned in a list. reserved keywords. Any such definitions are returned in a list.
@ -68,11 +68,15 @@ contract Bug {
# Check functions # Check functions
for scope_function in scope_contract.functions_declared: for scope_function in scope_contract.functions_declared:
if variable.name == scope_function.name: if variable.name == scope_function.name:
overshadowed.append((self.OVERSHADOWED_FUNCTION, scope_function)) overshadowed.append(
(self.OVERSHADOWED_FUNCTION, scope_function)
)
# Check modifiers # Check modifiers
for scope_modifier in scope_contract.modifiers_declared: for scope_modifier in scope_contract.modifiers_declared:
if variable.name == scope_modifier.name: if variable.name == scope_modifier.name:
overshadowed.append((self.OVERSHADOWED_MODIFIER, scope_modifier)) overshadowed.append(
(self.OVERSHADOWED_MODIFIER, scope_modifier)
)
# Check events # Check events
for scope_event in scope_contract.events_declared: for scope_event in scope_contract.events_declared:
if variable.name == scope_event.name: if variable.name == scope_event.name:
@ -108,7 +112,11 @@ contract Bug {
overshadowed = shadow[1] overshadowed = shadow[1]
info = [local_variable, " shadows:\n"] info = [local_variable, " shadows:\n"]
for overshadowed_entry in overshadowed: for overshadowed_entry in overshadowed:
info += ["\t- ", overshadowed_entry[1], f" ({overshadowed_entry[0]})\n"] info += [
"\t- ",
overshadowed_entry[1],
f" ({overshadowed_entry[0]})\n",
]
# Generate relevant JSON data for this shadowing definition. # Generate relevant JSON data for this shadowing definition.
res = self.generate_result(info) res = self.generate_result(info)

@ -5,6 +5,20 @@ Module detecting shadowing of state variables
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
def detect_shadowing(contract):
ret = []
variables_fathers = []
for father in contract.inheritance:
if any(f.is_implemented for f in father.functions + father.modifiers):
variables_fathers += father.state_variables_declared
for var in contract.state_variables_declared:
shadow = [v for v in variables_fathers if v.name == var.name]
if shadow:
ret.append([var] + shadow)
return ret
class StateShadowing(AbstractDetector): class StateShadowing(AbstractDetector):
""" """
Shadowing of state variable Shadowing of state variable
@ -47,19 +61,6 @@ contract DerivedContract is BaseContract{
WIKI_RECOMMENDATION = "Remove the state variable shadowing." WIKI_RECOMMENDATION = "Remove the state variable shadowing."
def detect_shadowing(self, contract):
ret = []
variables_fathers = []
for father in contract.inheritance:
if any(f.is_implemented for f in father.functions + father.modifiers):
variables_fathers += father.state_variables_declared
for var in contract.state_variables_declared:
shadow = [v for v in variables_fathers if v.name == var.name]
if shadow:
ret.append([var] + shadow)
return ret
def _detect(self): def _detect(self):
""" Detect shadowing """ Detect shadowing
@ -70,7 +71,7 @@ contract DerivedContract is BaseContract{
""" """
results = [] results = []
for c in self.contracts: for c in self.contracts:
shadowing = self.detect_shadowing(c) shadowing = detect_shadowing(c)
if shadowing: if shadowing:
for all_variables in shadowing: for all_variables in shadowing:
shadow = all_variables[0] shadow = all_variables[0]

@ -42,14 +42,16 @@ As a result, the second contract cannot be analyzed.
""" """
WIKI_RECOMMENDATION = "Rename the contract." WIKI_RECOMMENDATION = "Rename the contract."
def _detect(self): def _detect(self): # pylint: disable=too-many-locals,too-many-branches
results = [] results = []
names_reused = self.slither.contract_name_collisions names_reused = self.slither.contract_name_collisions
# First show the contracts that we know are missing # First show the contracts that we know are missing
incorrectly_constructed = [ incorrectly_constructed = [
contract for contract in self.contracts if contract.is_incorrectly_constructed contract
for contract in self.contracts
if contract.is_incorrectly_constructed
] ]
inheritance_corrupted = defaultdict(list) inheritance_corrupted = defaultdict(list)
@ -66,7 +68,9 @@ As a result, the second contract cannot be analyzed.
info += ["\t- ", file, "\n"] info += ["\t- ", file, "\n"]
if contract_name in inheritance_corrupted: if contract_name in inheritance_corrupted:
info += ["\tAs a result, the inherited contracts are not correctly analyzed:\n"] info += [
"\tAs a result, the inherited contracts are not correctly analyzed:\n"
]
for corrupted in inheritance_corrupted[contract_name]: for corrupted in inheritance_corrupted[contract_name]:
info += ["\t\t- ", corrupted, "\n"] info += ["\t\t- ", corrupted, "\n"]
res = self.generate_result(info) res = self.generate_result(info)
@ -79,14 +83,18 @@ As a result, the second contract cannot be analyzed.
for b in most_base_with_missing_inheritance: for b in most_base_with_missing_inheritance:
info = [b, " inherits from a contract for which the name is reused.\n"] info = [b, " inherits from a contract for which the name is reused.\n"]
if b.inheritance: if b.inheritance:
info += ["\t- Slither could not determine which contract has a duplicate name:\n"] info += [
"\t- Slither could not determine which contract has a duplicate name:\n"
]
for inheritance in b.inheritance: for inheritance in b.inheritance:
info += ["\t\t-", inheritance, "\n"] info += ["\t\t-", inheritance, "\n"]
info += ["\t- Check if:\n"] info += ["\t- Check if:\n"]
info += ["\t\t- A inherited contract is missing from this list,\n"] info += ["\t\t- A inherited contract is missing from this list,\n"]
info += ["\t\t- The contract are imported from the correct files.\n"] info += ["\t\t- The contract are imported from the correct files.\n"]
if b.derived_contracts: if b.derived_contracts:
info += [f"\t- This issue impacts the contracts inheriting from {b.name}:\n"] info += [
f"\t- This issue impacts the contracts inheriting from {b.name}:\n"
]
for derived in b.derived_contracts: for derived in b.derived_contracts:
info += ["\t\t-", derived, "\n"] info += ["\t\t-", derived, "\n"]
res = self.generate_result(info) res = self.generate_result(info)

@ -1,5 +1,5 @@
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
import re import re
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
class RightToLeftOverride(AbstractDetector): class RightToLeftOverride(AbstractDetector):
@ -65,25 +65,27 @@ contract Token
# If we couldn't find the character in the remainder of source, stop. # If we couldn't find the character in the remainder of source, stop.
if result_index == -1: if result_index == -1:
break break
else:
# We found another instance of the character, define our output # We found another instance of the character, define our output
idx = start_index + result_index idx = start_index + result_index
relative = self.slither.crytic_compile.filename_lookup(filename).relative relative = self.slither.crytic_compile.filename_lookup(
info = f"{relative} contains a unicode right-to-left-override character at byte offset {idx}:\n" filename
).relative
# We have a patch, so pattern.find will return at least one result info = f"{relative} contains a unicode right-to-left-override character at byte offset {idx}:\n"
info += f"\t- {pattern.findall(source_encoded)[0]}\n" # We have a patch, so pattern.find will return at least one result
res = self.generate_result(info)
res.add_other( info += f"\t- {pattern.findall(source_encoded)[0]}\n"
"rtlo-character", res = self.generate_result(info)
(filename, idx, len(self.RTLO_CHARACTER_ENCODED)), res.add_other(
self.slither, "rtlo-character",
) (filename, idx, len(self.RTLO_CHARACTER_ENCODED)),
results.append(res) self.slither,
)
# Advance the start index for the next iteration results.append(res)
start_index = result_index + 1
# Advance the start index for the next iteration
start_index = result_index + 1
return results return results

@ -16,7 +16,9 @@ class Assembly(AbstractDetector):
IMPACT = DetectorClassification.INFORMATIONAL IMPACT = DetectorClassification.INFORMATIONAL
CONFIDENCE = DetectorClassification.HIGH CONFIDENCE = DetectorClassification.HIGH
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#assembly-usage" WIKI = (
"https://github.com/crytic/slither/wiki/Detector-Documentation#assembly-usage"
)
WIKI_TITLE = "Assembly usage" WIKI_TITLE = "Assembly usage"
WIKI_DESCRIPTION = "The use of assembly is error-prone and should be avoided." WIKI_DESCRIPTION = "The use of assembly is error-prone and should be avoided."

@ -3,7 +3,10 @@ Module detecting misuse of Boolean constants
""" """
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.slithir.operations import Assignment, Call, Return, InitArray, Binary, BinaryType from slither.slithir.operations import (
Binary,
BinaryType,
)
from slither.slithir.variables import Constant from slither.slithir.variables import Constant
@ -17,7 +20,9 @@ class BooleanEquality(AbstractDetector):
IMPACT = DetectorClassification.INFORMATIONAL IMPACT = DetectorClassification.INFORMATIONAL
CONFIDENCE = DetectorClassification.HIGH CONFIDENCE = DetectorClassification.HIGH
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#boolean-equality" WIKI = (
"https://github.com/crytic/slither/wiki/Detector-Documentation#boolean-equality"
)
WIKI_TITLE = "Boolean equality" WIKI_TITLE = "Boolean equality"
WIKI_DESCRIPTION = """Detects the comparison to boolean constants.""" WIKI_DESCRIPTION = """Detects the comparison to boolean constants."""
@ -44,7 +49,7 @@ Boolean constants can be used directly and do not need to be compare to `true` o
results = [] results = []
# Loop for each function and modifier. # Loop for each function and modifier.
for function in contract.functions_and_modifiers_declared: for function in contract.functions_and_modifiers_declared: # pylint: disable=too-many-nested-blocks
f_results = set() f_results = set()
# Loop for every node in this function, looking for boolean constants # Loop for every node in this function, looking for boolean constants
@ -54,7 +59,7 @@ Boolean constants can be used directly and do not need to be compare to `true` o
if ir.type in [BinaryType.EQUAL, BinaryType.NOT_EQUAL]: if ir.type in [BinaryType.EQUAL, BinaryType.NOT_EQUAL]:
for r in ir.read: for r in ir.read:
if isinstance(r, Constant): if isinstance(r, Constant):
if type(r.value) is bool: if isinstance(r.value, bool):
f_results.add(node) f_results.add(node)
results.append((function, f_results)) results.append((function, f_results))
@ -71,7 +76,12 @@ Boolean constants can be used directly and do not need to be compare to `true` o
if boolean_constant_misuses: if boolean_constant_misuses:
for (func, nodes) in boolean_constant_misuses: for (func, nodes) in boolean_constant_misuses:
for node in nodes: for node in nodes:
info = [func, " compares to a boolean constant:\n\t-", node, "\n"] info = [
func,
" compares to a boolean constant:\n\t-",
node,
"\n",
]
res = self.generate_result(info) res = self.generate_result(info)
results.append(res) results.append(res)

@ -26,9 +26,7 @@ class BooleanConstantMisuse(AbstractDetector):
IMPACT = DetectorClassification.MEDIUM IMPACT = DetectorClassification.MEDIUM
CONFIDENCE = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM
WIKI = ( WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#misuse-of-a-boolean-constant"
"https://github.com/crytic/slither/wiki/Detector-Documentation#misuse-of-a-boolean-constant"
)
WIKI_TITLE = "Misuse of a Boolean constant" WIKI_TITLE = "Misuse of a Boolean constant"
WIKI_DESCRIPTION = """Detects the misuse of a Boolean constant.""" WIKI_DESCRIPTION = """Detects the misuse of a Boolean constant."""
@ -56,7 +54,7 @@ Other uses (in complex expressions, as conditionals) indicate either an error or
WIKI_RECOMMENDATION = """Verify and simplify the condition.""" WIKI_RECOMMENDATION = """Verify and simplify the condition."""
@staticmethod @staticmethod
def _detect_boolean_constant_misuses(contract): def _detect_boolean_constant_misuses(contract): # pylint: disable=too-many-branches
""" """
Detects and returns all nodes which misuse a Boolean constant. Detects and returns all nodes which misuse a Boolean constant.
:param contract: Contract to detect assignment within. :param contract: Contract to detect assignment within.
@ -67,7 +65,7 @@ Other uses (in complex expressions, as conditionals) indicate either an error or
results = [] results = []
# Loop for each function and modifier. # Loop for each function and modifier.
for function in contract.functions_declared: for function in contract.functions_declared: # pylint: disable=too-many-nested-blocks
f_results = set() f_results = set()
# Loop for every node in this function, looking for boolean constants # Loop for every node in this function, looking for boolean constants
@ -88,13 +86,17 @@ Other uses (in complex expressions, as conditionals) indicate either an error or
# It's ok to use a bare boolean constant in these contexts # It's ok to use a bare boolean constant in these contexts
continue continue
if isinstance(ir, Binary): if isinstance(ir, Binary):
if ir.type in [BinaryType.ADDITION, BinaryType.EQUAL, BinaryType.NOT_EQUAL]: if ir.type in [
BinaryType.ADDITION,
BinaryType.EQUAL,
BinaryType.NOT_EQUAL,
]:
# Comparing to a Boolean constant is dubious style, but harmless # Comparing to a Boolean constant is dubious style, but harmless
# Equal is catch by another detector (informational severity) # Equal is catch by another detector (informational severity)
continue continue
for r in ir.read: for r in ir.read:
if isinstance(r, Constant): if isinstance(r, Constant):
if type(r.value) is bool: if isinstance(r.value, bool):
f_results.add(node) f_results.add(node)
results.append((function, f_results)) results.append((function, f_results))
@ -111,7 +113,12 @@ Other uses (in complex expressions, as conditionals) indicate either an error or
if boolean_constant_misuses: if boolean_constant_misuses:
for (func, nodes) in boolean_constant_misuses: for (func, nodes) in boolean_constant_misuses:
for node in nodes: for node in nodes:
info = [func, " uses a Boolean constant improperly:\n\t-", node, "\n"] info = [
func,
" uses a Boolean constant improperly:\n\t-",
node,
"\n",
]
res = self.generate_result(info) res = self.generate_result(info)
results.append(res) results.append(res)

@ -1,13 +1,15 @@
"""
"""
from slither.core.cfg.node import NodeType from slither.core.cfg.node import NodeType
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.slithir.operations import HighLevelCall, LibraryCall, LowLevelCall, Send, Transfer from slither.slithir.operations import (
HighLevelCall,
LibraryCall,
LowLevelCall,
Send,
Transfer,
)
class MultipleCallsInLoop(AbstractDetector): class MultipleCallsInLoop(AbstractDetector):
"""
"""
ARGUMENT = "calls-loop" ARGUMENT = "calls-loop"
HELP = "Multiple calls in a loop" HELP = "Multiple calls in a loop"

@ -3,9 +3,20 @@ from slither.slithir.operations import LowLevelCall
from slither.analyses.data_dependency.data_dependency import is_tainted from slither.analyses.data_dependency.data_dependency import is_tainted
def controlled_delegatecall(function):
ret = []
for node in function.nodes:
for ir in node.irs:
if isinstance(ir, LowLevelCall) and ir.function_name in [
"delegatecall",
"callcode",
]:
if is_tainted(ir.destination, function.contract):
ret.append(node)
return ret
class ControlledDelegateCall(AbstractDetector): class ControlledDelegateCall(AbstractDetector):
"""
"""
ARGUMENT = "controlled-delegatecall" ARGUMENT = "controlled-delegatecall"
HELP = "Controlled delegatecall destination" HELP = "Controlled delegatecall destination"
@ -15,7 +26,9 @@ class ControlledDelegateCall(AbstractDetector):
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#controlled-delegatecall" WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#controlled-delegatecall"
WIKI_TITLE = "Controlled Delegatecall" WIKI_TITLE = "Controlled Delegatecall"
WIKI_DESCRIPTION = "`Delegatecall` or `callcode` to an address controlled by the user." WIKI_DESCRIPTION = (
"`Delegatecall` or `callcode` to an address controlled by the user."
)
WIKI_EXPLOIT_SCENARIO = """ WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Delegatecall{ contract Delegatecall{
@ -28,18 +41,6 @@ Bob calls `delegate` and delegates the execution to his malicious contract. As a
WIKI_RECOMMENDATION = "Avoid using `delegatecall`. Use only trusted destinations." WIKI_RECOMMENDATION = "Avoid using `delegatecall`. Use only trusted destinations."
def controlled_delegatecall(self, function):
ret = []
for node in function.nodes:
for ir in node.irs:
if isinstance(ir, LowLevelCall) and ir.function_name in [
"delegatecall",
"callcode",
]:
if is_tainted(ir.destination, function.contract):
ret.append(node)
return ret
def _detect(self): def _detect(self):
results = [] results = []
@ -49,9 +50,12 @@ Bob calls `delegate` and delegates the execution to his malicious contract. As a
# As functions to upgrades the destination lead to too many FPs # As functions to upgrades the destination lead to too many FPs
if contract.is_upgradeable_proxy and f.is_protected(): if contract.is_upgradeable_proxy and f.is_protected():
continue continue
nodes = self.controlled_delegatecall(f) nodes = controlled_delegatecall(f)
if nodes: if nodes:
func_info = [f, " uses delegatecall to a input-controlled function id\n"] func_info = [
f,
" uses delegatecall to a input-controlled function id\n",
]
for node in nodes: for node in nodes:
node_info = func_info + ["\t- ", node, "\n"] node_info = func_info + ["\t- ", node, "\n"]

@ -2,12 +2,15 @@
Module detecting deprecated standards. Module detecting deprecated standards.
""" """
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.visitors.expression.export_values import ExportValues
from slither.core.declarations.solidity_variables import SolidityVariableComposed, SolidityFunction
from slither.core.cfg.node import NodeType from slither.core.cfg.node import NodeType
from slither.core.declarations.solidity_variables import (
SolidityVariableComposed,
SolidityFunction,
)
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.slithir.operations import LowLevelCall from slither.slithir.operations import LowLevelCall
from slither.solc_parsing.variables.state_variable import StateVariableSolc, StateVariable from slither.visitors.expression.export_values import ExportValues
# Reference: https://smartcontractsecurity.github.io/SWC-registry/docs/SWC-111 # Reference: https://smartcontractsecurity.github.io/SWC-registry/docs/SWC-111
class DeprecatedStandards(AbstractDetector): class DeprecatedStandards(AbstractDetector):
@ -124,7 +127,7 @@ contract ContractWithDeprecatedReferences {
results.append((state_variable, deprecated_results)) results.append((state_variable, deprecated_results))
# Loop through all functions + modifiers in this contract. # Loop through all functions + modifiers in this contract.
for function in contract.functions_and_modifiers_declared: for function in contract.functions_and_modifiers_declared: # pylint: disable=too-many-nested-blocks
# Loop through each node in this function. # Loop through each node in this function.
for node in function.nodes: for node in function.nodes:
# Detect deprecated references in the node. # Detect deprecated references in the node.
@ -153,14 +156,16 @@ contract ContractWithDeprecatedReferences {
""" """
results = [] results = []
for contract in self.contracts: for contract in self.contracts:
deprecated_references = self.detect_deprecated_references_in_contract(contract) deprecated_references = self.detect_deprecated_references_in_contract(
contract
)
if deprecated_references: if deprecated_references:
for deprecated_reference in deprecated_references: for deprecated_reference in deprecated_references:
source_object = deprecated_reference[0] source_object = deprecated_reference[0]
deprecated_entries = deprecated_reference[1] deprecated_entries = deprecated_reference[1]
info = ["Deprecated standard detected ", source_object, ":\n"] info = ["Deprecated standard detected ", source_object, ":\n"]
for (dep_id, original_desc, recommended_disc) in deprecated_entries: for (_dep_id, original_desc, recommended_disc) in deprecated_entries:
info += [ info += [
f'\t- Usage of "{original_desc}" should be replaced with "{recommended_disc}"\n' f'\t- Usage of "{original_desc}" should be replaced with "{recommended_disc}"\n'
] ]

@ -79,7 +79,7 @@ In general, it's usually a good idea to re-arrange arithmetic to perform multipl
WIKI_RECOMMENDATION = """Consider ordering multiplication before division.""" WIKI_RECOMMENDATION = """Consider ordering multiplication before division."""
def _explore(self, node, explored, f_results, divisions): def _explore(self, node, explored, f_results, divisions): # pylint: disable=too-many-branches
if node in explored: if node in explored:
return return
explored.add(node) explored.add(node)
@ -111,7 +111,9 @@ In general, it's usually a good idea to re-arrange arithmetic to perform multipl
if node in divisions[r]: if node in divisions[r]:
nodes += [n for n in divisions[r] if n not in nodes] nodes += [n for n in divisions[r] if n not in nodes]
else: else:
nodes += [n for n in divisions[r] + [node] if n not in nodes] nodes += [
n for n in divisions[r] + [node] if n not in nodes
]
if nodes: if nodes:
node_results = nodes node_results = nodes
@ -170,11 +172,16 @@ In general, it's usually a good idea to re-arrange arithmetic to perform multipl
""" """
results = [] results = []
for contract in self.contracts: for contract in self.contracts:
divisions_before_multiplications = self.detect_divide_before_multiply(contract) divisions_before_multiplications = self.detect_divide_before_multiply(
contract
)
if divisions_before_multiplications: if divisions_before_multiplications:
for (func, nodes) in divisions_before_multiplications: for (func, nodes) in divisions_before_multiplications:
info = [func, " performs a multiplication on the result of a division:\n"] info = [
func,
" performs a multiplication on the result of a division:\n",
]
for node in nodes: for node in nodes:
info += ["\t-", node, "\n"] info += ["\t-", node, "\n"]

@ -6,12 +6,21 @@
from slither.analyses.data_dependency.data_dependency import is_dependent_ssa from slither.analyses.data_dependency.data_dependency import is_dependent_ssa
from slither.core.declarations import Function from slither.core.declarations import Function
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.slithir.operations import Assignment, Balance, Binary, BinaryType, HighLevelCall from slither.slithir.operations import (
Assignment,
Balance,
Binary,
BinaryType,
HighLevelCall,
)
from slither.core.solidity_types import MappingType, ElementaryType from slither.core.solidity_types import MappingType, ElementaryType
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.core.declarations.solidity_variables import SolidityVariable, SolidityVariableComposed from slither.core.declarations.solidity_variables import (
SolidityVariable,
SolidityVariableComposed,
)
class IncorrectStrictEquality(AbstractDetector): class IncorrectStrictEquality(AbstractDetector):
@ -20,12 +29,12 @@ class IncorrectStrictEquality(AbstractDetector):
IMPACT = DetectorClassification.MEDIUM IMPACT = DetectorClassification.MEDIUM
CONFIDENCE = DetectorClassification.HIGH CONFIDENCE = DetectorClassification.HIGH
WIKI = ( WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#dangerous-strict-equalities"
"https://github.com/crytic/slither/wiki/Detector-Documentation#dangerous-strict-equalities"
)
WIKI_TITLE = "Dangerous strict equalities" WIKI_TITLE = "Dangerous strict equalities"
WIKI_DESCRIPTION = "Use of strict equalities that can be easily manipulated by an attacker." WIKI_DESCRIPTION = (
"Use of strict equalities that can be easily manipulated by an attacker."
)
WIKI_EXPLOIT_SCENARIO = """ WIKI_EXPLOIT_SCENARIO = """
```solidity ```solidity
contract Crowdsale{ contract Crowdsale{
@ -36,9 +45,7 @@ contract Crowdsale{
`Crowdsale` relies on `fund_reached` to know when to stop the sale of tokens. `Crowdsale` relies on `fund_reached` to know when to stop the sale of tokens.
`Crowdsale` reaches 100 Ether. Bob sends 0.1 Ether. As a result, `fund_reached` is always false and the `crowdsale` never ends.""" `Crowdsale` reaches 100 Ether. Bob sends 0.1 Ether. As a result, `fund_reached` is always false and the `crowdsale` never ends."""
WIKI_RECOMMENDATION = ( WIKI_RECOMMENDATION = """Don't use strict equality to determine if an account has enough Ether or tokens."""
"""Don't use strict equality to determine if an account has enough Ether or tokens."""
)
sources_taint = [ sources_taint = [
SolidityVariable("now"), SolidityVariable("now"),
@ -98,7 +105,9 @@ contract Crowdsale{
for ir in node.irs_ssa: for ir in node.irs_ssa:
# Filter to only tainted equality (==) comparisons # Filter to only tainted equality (==) comparisons
if self.is_direct_comparison(ir) and self.is_any_tainted(ir.used, taints, func): if self.is_direct_comparison(ir) and self.is_any_tainted(
ir.used, taints, func
):
if func not in results: if func not in results:
results[func] = [] results[func] = []
results[func].append(node) results[func].append(node)
@ -133,7 +142,7 @@ contract Crowdsale{
# Output each node with the function info header as a separate result. # Output each node with the function info header as a separate result.
for node in nodes: for node in nodes:
node_info = func_info + [f"\t- ", node, "\n"] node_info = func_info + ["\t- ", node, "\n"]
res = self.generate_result(node_info) res = self.generate_result(node_info)
results.append(res) results.append(res)

@ -16,7 +16,9 @@ class TooManyDigits(AbstractDetector):
IMPACT = DetectorClassification.INFORMATIONAL IMPACT = DetectorClassification.INFORMATIONAL
CONFIDENCE = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#too-many-digits" WIKI = (
"https://github.com/crytic/slither/wiki/Detector-Documentation#too-many-digits"
)
WIKI_TITLE = "Too many digits" WIKI_TITLE = "Too many digits"
WIKI_DESCRIPTION = """ WIKI_DESCRIPTION = """
Literals with many digits are difficult to read and review. Literals with many digits are difficult to read and review.

@ -15,9 +15,7 @@ class TxOrigin(AbstractDetector):
IMPACT = DetectorClassification.MEDIUM IMPACT = DetectorClassification.MEDIUM
CONFIDENCE = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM
WIKI = ( WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#dangerous-usage-of-txorigin"
"https://github.com/crytic/slither/wiki/Detector-Documentation#dangerous-usage-of-txorigin"
)
WIKI_TITLE = "Dangerous usage of `tx.origin`" WIKI_TITLE = "Dangerous usage of `tx.origin`"
WIKI_DESCRIPTION = "`tx.origin`-based protection can be abused by a malicious contract if a legitimate user interacts with the malicious contract." WIKI_DESCRIPTION = "`tx.origin`-based protection can be abused by a malicious contract if a legitimate user interacts with the malicious contract."

@ -8,6 +8,56 @@ from slither.slithir.variables import Constant
from slither.core.solidity_types.elementary_type import Int, Uint from slither.core.solidity_types.elementary_type import Int, Uint
def typeRange(t):
bits = int(t.split("int")[1])
if t in Uint:
return 0, (2 ** bits) - 1
if t in Int:
v = (2 ** (bits - 1)) - 1
return -v, v
return None
def _detect_tautology_or_contradiction(low, high, cval, op):
"""
Return true if "[low high] op cval " is always true or always false
:param low:
:param high:
:param cval:
:param op:
:return:
"""
if op == BinaryType.LESS:
# a < cval
# its a tautology if
# high(a) < cval
# its a contradiction if
# low(a) >= cval
return high < cval or low >= cval
if op == BinaryType.GREATER:
# a > cval
# its a tautology if
# low(a) > cval
# its a contradiction if
# high(a) <= cval
return low > cval or high <= cval
if op == BinaryType.LESS_EQUAL:
# a <= cval
# its a tautology if
# high(a) <= cval
# its a contradiction if
# low(a) > cval
return (high <= cval) or (low > cval)
if op == BinaryType.GREATER_EQUAL:
# a >= cval
# its a tautology if
# low(a) >= cval
# its a contradiction if
# high(a) < cval
return (low >= cval) or (high < cval)
return False
class TypeBasedTautology(AbstractDetector): class TypeBasedTautology(AbstractDetector):
""" """
Type-based tautology or contradiction Type-based tautology or contradiction
@ -18,9 +68,7 @@ class TypeBasedTautology(AbstractDetector):
IMPACT = DetectorClassification.MEDIUM IMPACT = DetectorClassification.MEDIUM
CONFIDENCE = DetectorClassification.HIGH CONFIDENCE = DetectorClassification.HIGH
WIKI = ( WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#tautology-or-contradiction"
"https://github.com/crytic/slither/wiki/Detector-Documentation#tautology-or-contradiction"
)
WIKI_TITLE = "Tautology or contradiction" WIKI_TITLE = "Tautology or contradiction"
WIKI_DESCRIPTION = """Detects expressions that are tautologies or contradictions.""" WIKI_DESCRIPTION = """Detects expressions that are tautologies or contradictions."""
@ -50,14 +98,6 @@ contract A {
"""Fix the incorrect comparison by changing the value type or the comparison.""" """Fix the incorrect comparison by changing the value type or the comparison."""
) )
def typeRange(self, t):
bits = int(t.split("int")[1])
if t in Uint:
return (0, (2 ** bits) - 1)
if t in Int:
v = (2 ** (bits - 1)) - 1
return (-v, v)
flip_table = { flip_table = {
BinaryType.GREATER: BinaryType.LESS, BinaryType.GREATER: BinaryType.LESS,
BinaryType.GREATER_EQUAL: BinaryType.LESS_EQUAL, BinaryType.GREATER_EQUAL: BinaryType.LESS_EQUAL,
@ -65,45 +105,6 @@ contract A {
BinaryType.LESS_EQUAL: BinaryType.GREATER_EQUAL, BinaryType.LESS_EQUAL: BinaryType.GREATER_EQUAL,
} }
def _detect_tautology_or_contradiction(self, low, high, cval, op):
"""
Return true if "[low high] op cval " is always true or always false
:param low:
:param high:
:param cval:
:param op:
:return:
"""
if op == BinaryType.LESS:
# a < cval
# its a tautology if
# high(a) < cval
# its a contradiction if
# low(a) >= cval
return high < cval or low >= cval
elif op == BinaryType.GREATER:
# a > cval
# its a tautology if
# low(a) > cval
# its a contradiction if
# high(a) <= cval
return low > cval or high <= cval
elif op == BinaryType.LESS_EQUAL:
# a <= cval
# its a tautology if
# high(a) <= cval
# its a contradiction if
# low(a) > cval
return (high <= cval) or (low > cval)
elif op == BinaryType.GREATER_EQUAL:
# a >= cval
# its a tautology if
# low(a) >= cval
# its a contradiction if
# high(a) < cval
return (low >= cval) or (high < cval)
return False
def detect_type_based_tautologies(self, contract): def detect_type_based_tautologies(self, contract):
""" """
Detects and returns all nodes with tautology/contradiction comparisons (based on type alone). Detects and returns all nodes with tautology/contradiction comparisons (based on type alone).
@ -116,7 +117,7 @@ contract A {
allInts = Int + Uint allInts = Int + Uint
# Loop for each function and modifier. # Loop for each function and modifier.
for function in contract.functions_declared: for function in contract.functions_declared: # pylint: disable=too-many-nested-blocks
f_results = set() f_results = set()
for node in function.nodes: for node in function.nodes:
@ -127,8 +128,8 @@ contract A {
cval = ir.variable_left.value cval = ir.variable_left.value
rtype = str(ir.variable_right.type) rtype = str(ir.variable_right.type)
if rtype in allInts: if rtype in allInts:
(low, high) = self.typeRange(rtype) (low, high) = typeRange(rtype)
if self._detect_tautology_or_contradiction( if _detect_tautology_or_contradiction(
low, high, cval, self.flip_table[ir.type] low, high, cval, self.flip_table[ir.type]
): ):
f_results.add(node) f_results.add(node)
@ -137,8 +138,8 @@ contract A {
cval = ir.variable_right.value cval = ir.variable_right.value
ltype = str(ir.variable_left.type) ltype = str(ir.variable_left.type)
if ltype in allInts: if ltype in allInts:
(low, high) = self.typeRange(ltype) (low, high) = typeRange(ltype)
if self._detect_tautology_or_contradiction( if _detect_tautology_or_contradiction(
low, high, cval, ir.type low, high, cval, ir.type
): ):
f_results.add(node) f_results.add(node)

@ -7,7 +7,7 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi
from slither.visitors.expression.export_values import ExportValues from slither.visitors.expression.export_values import ExportValues
from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.formatters.variables.possible_const_state_variables import format from slither.formatters.variables.possible_const_state_variables import custom_format
class ConstCandidateStateVars(AbstractDetector): class ConstCandidateStateVars(AbstractDetector):
@ -26,8 +26,12 @@ class ConstCandidateStateVars(AbstractDetector):
WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#state-variables-that-could-be-declared-constant" WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#state-variables-that-could-be-declared-constant"
WIKI_TITLE = "State variables that could be declared constant" WIKI_TITLE = "State variables that could be declared constant"
WIKI_DESCRIPTION = "Constant state variables should be declared constant to save gas." WIKI_DESCRIPTION = (
WIKI_RECOMMENDATION = "Add the `constant` attributes to state variables that never change." "Constant state variables should be declared constant to save gas."
)
WIKI_RECOMMENDATION = (
"Add the `constant` attributes to state variables that never change."
)
@staticmethod @staticmethod
def _valid_candidate(v): def _valid_candidate(v):
@ -61,7 +65,10 @@ class ConstCandidateStateVars(AbstractDetector):
if not values: if not values:
return True return True
if all( if all(
(val in self.valid_solidity_function or self._is_constant_var(val) for val in values) (
val in self.valid_solidity_function or self._is_constant_var(val)
for val in values
)
): ):
return True return True
return False return False
@ -72,18 +79,18 @@ class ConstCandidateStateVars(AbstractDetector):
results = [] results = []
all_variables = [c.state_variables for c in self.slither.contracts] all_variables = [c.state_variables for c in self.slither.contracts]
all_variables = set([item for sublist in all_variables for item in sublist]) all_variables = {item for sublist in all_variables for item in sublist}
all_non_constant_elementary_variables = set( all_non_constant_elementary_variables = {v for v in all_variables if self._valid_candidate(v)}
[v for v in all_variables if self._valid_candidate(v)]
)
all_functions = [c.all_functions_called for c in self.slither.contracts] all_functions = [c.all_functions_called for c in self.slither.contracts]
all_functions = list(set([item for sublist in all_functions for item in sublist])) all_functions = list({item for sublist in all_functions for item in sublist})
all_variables_written = [ all_variables_written = [
f.state_variables_written for f in all_functions if not f.is_constructor_variables f.state_variables_written
for f in all_functions
if not f.is_constructor_variables
] ]
all_variables_written = set([item for sublist in all_variables_written for item in sublist]) all_variables_written = {item for sublist in all_variables_written for item in sublist}
constable_variables = [ constable_variables = [
v v
@ -91,7 +98,9 @@ class ConstCandidateStateVars(AbstractDetector):
if (not v in all_variables_written) and self._constant_initial_expression(v) if (not v in all_variables_written) and self._constant_initial_expression(v)
] ]
# Order for deterministic results # Order for deterministic results
constable_variables = sorted(constable_variables, key=lambda x: x.canonical_name) constable_variables = sorted(
constable_variables, key=lambda x: x.canonical_name
)
# Create a result for each finding # Create a result for each finding
for v in constable_variables: for v in constable_variables:
@ -103,4 +112,4 @@ class ConstCandidateStateVars(AbstractDetector):
@staticmethod @staticmethod
def _format(slither, result): def _format(slither, result):
format(slither, result) custom_format(slither, result)

@ -9,8 +9,6 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi
class UninitializedLocalVars(AbstractDetector): class UninitializedLocalVars(AbstractDetector):
"""
"""
ARGUMENT = "uninitialized-local" ARGUMENT = "uninitialized-local"
HELP = "Uninitialized local variables" HELP = "Uninitialized local variables"
@ -55,7 +53,9 @@ Bob calls `transfer`. As a result, all Ether is sent to the address `0x0` and is
else: else:
self.visited_all_paths[node] = [] self.visited_all_paths[node] = []
self.visited_all_paths[node] = list(set(self.visited_all_paths[node] + fathers_context)) self.visited_all_paths[node] = list(
set(self.visited_all_paths[node] + fathers_context)
)
if self.key in node.context: if self.key in node.context:
fathers_context += node.context[self.key] fathers_context += node.context[self.key]
@ -66,7 +66,9 @@ Bob calls `transfer`. As a result, all Ether is sent to the address `0x0` and is
self.results.append((function, uninitialized_local_variable)) self.results.append((function, uninitialized_local_variable))
# Only save the local variables that are not yet written # Only save the local variables that are not yet written
uninitialized_local_variables = list(set(fathers_context) - set(node.variables_written)) uninitialized_local_variables = list(
set(fathers_context) - set(node.variables_written)
)
node.context[self.key] = uninitialized_local_variables node.context[self.key] = uninitialized_local_variables
for son in node.sons: for son in node.sons:
@ -81,6 +83,7 @@ Bob calls `transfer`. As a result, all Ether is sent to the address `0x0` and is
""" """
results = [] results = []
# pylint: disable=attribute-defined-outside-init
self.results = [] self.results = []
self.visited_all_paths = {} self.visited_all_paths = {}
@ -91,14 +94,21 @@ Bob calls `transfer`. As a result, all Ether is sent to the address `0x0` and is
continue continue
# dont consider storage variable, as they are detected by another detector # dont consider storage variable, as they are detected by another detector
uninitialized_local_variables = [ uninitialized_local_variables = [
v for v in function.local_variables if not v.is_storage and v.uninitialized v
for v in function.local_variables
if not v.is_storage and v.uninitialized
] ]
function.entry_point.context[self.key] = uninitialized_local_variables function.entry_point.context[
self.key
] = uninitialized_local_variables
self._detect_uninitialized(function, function.entry_point, []) self._detect_uninitialized(function, function.entry_point, [])
all_results = list(set(self.results)) all_results = list(set(self.results))
for (function, uninitialized_local_variable) in all_results: for (function, uninitialized_local_variable) in all_results:
info = [uninitialized_local_variable, " is a local variable never initialized\n"] info = [
uninitialized_local_variable,
" is a local variable never initialized\n",
]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)

@ -47,11 +47,12 @@ Initialize all the variables. If a variable is meant to be initialized to zero,
@staticmethod @staticmethod
def _written_variables(contract): def _written_variables(contract):
ret = [] ret = []
# pylint: disable=too-many-nested-blocks
for f in contract.all_functions_called + contract.modifiers: for f in contract.all_functions_called + contract.modifiers:
for n in f.nodes: for n in f.nodes:
ret += n.state_variables_written ret += n.state_variables_written
for ir in n.irs: for ir in n.irs:
if isinstance(ir, LibraryCall) or isinstance(ir, InternalCall): if isinstance(ir, (LibraryCall, InternalCall)):
idx = 0 idx = 0
if ir.function: if ir.function:
for param in ir.function.parameters: for param in ir.function.parameters:
@ -69,6 +70,7 @@ Initialize all the variables. If a variable is meant to be initialized to zero,
def _variable_written_in_proxy(self): def _variable_written_in_proxy(self):
# Hack to memoize without having it define in the init # Hack to memoize without having it define in the init
if hasattr(self, "__variables_written_in_proxy"): if hasattr(self, "__variables_written_in_proxy"):
# pylint: disable=access-member-before-definition
return self.__variables_written_in_proxy return self.__variables_written_in_proxy
variables_written_in_proxy = [] variables_written_in_proxy = []
@ -76,7 +78,8 @@ Initialize all the variables. If a variable is meant to be initialized to zero,
if c.is_upgradeable_proxy: if c.is_upgradeable_proxy:
variables_written_in_proxy += self._written_variables(c) variables_written_in_proxy += self._written_variables(c)
self.__variables_written_in_proxy = list(set([v.name for v in variables_written_in_proxy])) # pylint: disable=attribute-defined-outside-init
self.__variables_written_in_proxy = list({v.name for v in variables_written_in_proxy})
return self.__variables_written_in_proxy return self.__variables_written_in_proxy
def _written_variables_in_proxy(self, contract): def _written_variables_in_proxy(self, contract):

@ -9,8 +9,6 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi
class UninitializedStorageVars(AbstractDetector): class UninitializedStorageVars(AbstractDetector):
"""
"""
ARGUMENT = "uninitialized-storage" ARGUMENT = "uninitialized-storage"
HELP = "Uninitialized storage variables" HELP = "Uninitialized storage variables"
@ -63,7 +61,9 @@ Bob calls `func`. As a result, `owner` is overridden to `0`.
else: else:
self.visited_all_paths[node] = [] self.visited_all_paths[node] = []
self.visited_all_paths[node] = list(set(self.visited_all_paths[node] + fathers_context)) self.visited_all_paths[node] = list(
set(self.visited_all_paths[node] + fathers_context)
)
if self.key in node.context: if self.key in node.context:
fathers_context += node.context[self.key] fathers_context += node.context[self.key]
@ -74,7 +74,9 @@ Bob calls `func`. As a result, `owner` is overridden to `0`.
self.results.append((function, uninitialized_storage_variable)) self.results.append((function, uninitialized_storage_variable))
# Only save the storage variables that are not yet written # Only save the storage variables that are not yet written
uninitialized_storage_variables = list(set(fathers_context) - set(node.variables_written)) uninitialized_storage_variables = list(
set(fathers_context) - set(node.variables_written)
)
node.context[self.key] = uninitialized_storage_variables node.context[self.key] = uninitialized_storage_variables
for son in node.sons: for son in node.sons:
@ -89,6 +91,7 @@ Bob calls `func`. As a result, `owner` is overridden to `0`.
""" """
results = [] results = []
# pylint: disable=attribute-defined-outside-init
self.results = [] self.results = []
self.visited_all_paths = {} self.visited_all_paths = {}
@ -96,13 +99,20 @@ Bob calls `func`. As a result, `owner` is overridden to `0`.
for function in contract.functions: for function in contract.functions:
if function.is_implemented: if function.is_implemented:
uninitialized_storage_variables = [ uninitialized_storage_variables = [
v for v in function.local_variables if v.is_storage and v.uninitialized v
for v in function.local_variables
if v.is_storage and v.uninitialized
] ]
function.entry_point.context[self.key] = uninitialized_storage_variables function.entry_point.context[
self.key
] = uninitialized_storage_variables
self._detect_uninitialized(function, function.entry_point, []) self._detect_uninitialized(function, function.entry_point, [])
for (function, uninitialized_storage_variable) in self.results: for (function, uninitialized_storage_variable) in self.results:
info = [uninitialized_storage_variable, " is a storage variable never initialized\n"] info = [
uninitialized_storage_variable,
" is a storage variable never initialized\n",
]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)

@ -6,7 +6,45 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi
from slither.core.solidity_types import ArrayType from slither.core.solidity_types import ArrayType
from slither.visitors.expression.export_values import ExportValues from slither.visitors.expression.export_values import ExportValues
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.formatters.variables.unused_state_variables import format from slither.formatters.variables.unused_state_variables import custom_format
def detect_unused(contract):
if contract.is_signature_only():
return None
# Get all the variables read in all the functions and modifiers
all_functions = contract.all_functions_called + contract.modifiers
variables_used = [x.state_variables_read for x in all_functions]
variables_used += [
x.state_variables_written
for x in all_functions
if not x.is_constructor_variables
]
array_candidates = [x.variables for x in all_functions]
array_candidates = [
i for sl in array_candidates for i in sl
] + contract.state_variables
array_candidates = [
x.type.length
for x in array_candidates
if isinstance(x.type, ArrayType) and x.type.length
]
array_candidates = [ExportValues(x).result() for x in array_candidates]
array_candidates = [i for sl in array_candidates for i in sl]
array_candidates = [v for v in array_candidates if isinstance(v, StateVariable)]
# Flat list
variables_used = [item for sublist in variables_used for item in sublist]
variables_used = list(set(variables_used + array_candidates))
# Return the variables unused that are not public
return [
x
for x in contract.variables
if x not in variables_used and x.visibility != "public"
]
class UnusedStateVars(AbstractDetector): class UnusedStateVars(AbstractDetector):
@ -26,43 +64,12 @@ class UnusedStateVars(AbstractDetector):
WIKI_EXPLOIT_SCENARIO = "" WIKI_EXPLOIT_SCENARIO = ""
WIKI_RECOMMENDATION = "Remove unused state variables." WIKI_RECOMMENDATION = "Remove unused state variables."
def detect_unused(self, contract):
if contract.is_signature_only():
return None
# Get all the variables read in all the functions and modifiers
all_functions = contract.all_functions_called + contract.modifiers
variables_used = [x.state_variables_read for x in all_functions]
variables_used += [
x.state_variables_written for x in all_functions if not x.is_constructor_variables
]
array_candidates = [x.variables for x in all_functions]
array_candidates = [i for sl in array_candidates for i in sl] + contract.state_variables
array_candidates = [
x.type.length
for x in array_candidates
if isinstance(x.type, ArrayType) and x.type.length
]
array_candidates = [ExportValues(x).result() for x in array_candidates]
array_candidates = [i for sl in array_candidates for i in sl]
array_candidates = [v for v in array_candidates if isinstance(v, StateVariable)]
# Flat list
variables_used = [item for sublist in variables_used for item in sublist]
variables_used = list(set(variables_used + array_candidates))
# Return the variables unused that are not public
return [
x for x in contract.variables if x not in variables_used and x.visibility != "public"
]
def _detect(self): def _detect(self):
""" Detect unused state variables """ Detect unused state variables
""" """
results = [] results = []
for c in self.slither.contracts_derived: for c in self.slither.contracts_derived:
unusedVars = self.detect_unused(c) unusedVars = detect_unused(c)
if unusedVars: if unusedVars:
for var in unusedVars: for var in unusedVars:
info = [var, " is never used in ", c, "\n"] info = [var, " is never used in ", c, "\n"]
@ -73,4 +80,4 @@ class UnusedStateVars(AbstractDetector):
@staticmethod @staticmethod
def _format(slither, result): def _format(slither, result):
format(slither, result) custom_format(slither, result)

@ -3,7 +3,7 @@ from slither.formatters.exceptions import FormatError
from slither.formatters.utils.patches import create_patch from slither.formatters.utils.patches import create_patch
def format(slither, result): def custom_format(slither, result):
elements = result["elements"] elements = result["elements"]
for element in elements: for element in elements:
if element["type"] != "function": if element["type"] != "function":

@ -5,6 +5,8 @@ from slither.formatters.utils.patches import create_patch
# Indicates the recommended versions for replacement # Indicates the recommended versions for replacement
REPLACEMENT_VERSIONS = ["^0.4.25", "^0.5.3"] REPLACEMENT_VERSIONS = ["^0.4.25", "^0.5.3"]
# pylint: disable=anomalous-backslash-in-string
# group: # group:
# 0: ^ > >= < <= (optional) # 0: ^ > >= < <= (optional)
# 1: ' ' (optional) # 1: ' ' (optional)
@ -14,7 +16,7 @@ REPLACEMENT_VERSIONS = ["^0.4.25", "^0.5.3"]
PATTERN = re.compile("(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)") PATTERN = re.compile("(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)")
def format(slither, result): def custom_format(slither, result):
elements = result["elements"] elements = result["elements"]
versions_used = [] versions_used = []
for element in elements: for element in elements:
@ -35,10 +37,11 @@ def _analyse_versions(used_solc_versions):
replace_solc_versions = list() replace_solc_versions = list()
for version in used_solc_versions: for version in used_solc_versions:
replace_solc_versions.append(_determine_solc_version_replacement(version)) replace_solc_versions.append(_determine_solc_version_replacement(version))
if not all(version == replace_solc_versions[0] for version in replace_solc_versions): if not all(
version == replace_solc_versions[0] for version in replace_solc_versions
):
raise FormatImpossible("Multiple incompatible versions!") raise FormatImpossible("Multiple incompatible versions!")
else: return replace_solc_versions[0]
return replace_solc_versions[0]
def _determine_solc_version_replacement(used_solc_version): def _determine_solc_version_replacement(used_solc_version):
@ -48,24 +51,28 @@ def _determine_solc_version_replacement(used_solc_version):
minor_version = ".".join(version[2:])[2] minor_version = ".".join(version[2:])[2]
if minor_version == "4": if minor_version == "4":
return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ";" return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ";"
elif minor_version == "5": if minor_version == "5":
return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ";" return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ";"
else: raise FormatImpossible("Unknown version!")
raise FormatImpossible("Unknown version!") if len(versions) == 2:
elif len(versions) == 2:
version_right = versions[1] version_right = versions[1]
minor_version_right = ".".join(version_right[2:])[2] minor_version_right = ".".join(version_right[2:])[2]
if minor_version_right == "4": if minor_version_right == "4":
# Replace with 0.4.25 # Replace with 0.4.25
return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ";" return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ";"
elif minor_version_right in ["5", "6"]: if minor_version_right in ["5", "6"]:
# Replace with 0.5.3 # Replace with 0.5.3
return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ";" return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ";"
raise FormatImpossible("Unknown version!")
def _patch(slither, result, in_file, pragma, modify_loc_start, modify_loc_end): # pylint: disable=too-many-arguments
def _patch(slither, result, in_file, pragma, modify_loc_start, modify_loc_end):
in_file_str = slither.source_code[in_file].encode("utf8") in_file_str = slither.source_code[in_file].encode("utf8")
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end] old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
create_patch( create_patch(
result, in_file, int(modify_loc_start), int(modify_loc_end), old_str_of_interest, pragma result,
in_file,
int(modify_loc_start),
int(modify_loc_end),
old_str_of_interest,
pragma,
) )

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save