Merge branch 'dev' into dev-flattening

pull/328/head
Josselin 5 years ago
commit 2658a6b58d
  1. 3
      setup.py
  2. 6
      slither/__main__.py
  3. 12
      slither/core/children/child_expression.py
  4. 18
      slither/core/slither_core.py
  5. 41
      slither/detectors/abstract_detector.py
  6. 6
      slither/detectors/attributes/const_functions.py
  7. 5
      slither/detectors/attributes/constant_pragma.py
  8. 7
      slither/detectors/attributes/incorrect_solc.py
  9. 8
      slither/detectors/functions/external_function.py
  10. 8
      slither/detectors/naming_convention/naming_convention.py
  11. 6
      slither/detectors/variables/possible_const_state_variables.py
  12. 5
      slither/detectors/variables/unused_state_variables.py
  13. 0
      slither/formatters/__init__.py
  14. 0
      slither/formatters/attributes/__init__.py
  15. 36
      slither/formatters/attributes/const_functions.py
  16. 69
      slither/formatters/attributes/constant_pragma.py
  17. 59
      slither/formatters/attributes/incorrect_solc.py
  18. 5
      slither/formatters/exceptions.py
  19. 0
      slither/formatters/functions/__init__.py
  20. 42
      slither/formatters/functions/external_function.py
  21. 0
      slither/formatters/naming_convention/__init__.py
  22. 609
      slither/formatters/naming_convention/naming_convention.py
  23. 0
      slither/formatters/utils/__init__.py
  24. 43
      slither/formatters/utils/patches.py
  25. 0
      slither/formatters/variables/__init__.py
  26. 38
      slither/formatters/variables/possible_const_state_variables.py
  27. 28
      slither/formatters/variables/unused_state_variables.py
  28. 4
      slither/slither.py
  29. 75
      slither/slithir/convert.py
  30. 3
      slither/slithir/operations/operation.py
  31. 6
      slither/slithir/operations/phi.py
  32. 2
      slither/slithir/utils/ssa.py
  33. 8
      slither/solc_parsing/expressions/expression_parsing.py
  34. 14
      slither/tools/slither_format/.gitignore
  35. 0
      slither/tools/slither_format/__init__.py
  36. 89
      slither/tools/slither_format/__main__.py
  37. 151
      slither/tools/slither_format/slither_format.py
  38. 3
      slither/utils/command_line.py
  39. 30
      slither/visitors/slithir/expression_to_slithir.py

@ -20,7 +20,8 @@ setup(
'slither-check-upgradeability = slither.tools.upgradeability.__main__:main',
'slither-find-paths = slither.tools.possible_paths.__main__:main',
'slither-simil = slither.tools.similarity.__main__:main',
'slither-flat = slither.tools.flattening.__main__:main'
'slither-flat = slither.tools.flattening.__main__:main',
'slither-format = slither.tools.slither_format.__main__:main'
]
}
)

@ -257,6 +257,7 @@ def parse_filter_paths(args):
return args.filter_paths.split(',')
return []
def parse_args(detector_classes, printer_classes):
parser = argparse.ArgumentParser(description='Slither. For usage information, see https://github.com/crytic/slither/wiki/Usage',
usage="slither.py contract.sol [flag]")
@ -379,6 +380,11 @@ def parse_args(detector_classes, printer_classes):
action='store_true',
default=False)
group_misc.add_argument('--generate-patches',
help='Generate patches (json output only)',
action='store_true',
default=False)
# debugger command
parser.add_argument('--debug',
help=argparse.SUPPRESS,

@ -0,0 +1,12 @@
class ChildExpression:
def __init__(self):
super(ChildExpression, self).__init__()
self._expression = None
def set_expression(self, expression):
self._expression = expression
@property
def expression(self):
return self._expression

@ -35,6 +35,7 @@ class Slither(Context):
self._crytic_compile = None
self._generate_patches = False
###################################################################################
###################################################################################
@ -44,7 +45,7 @@ class Slither(Context):
@property
def source_code(self):
""" {filename: source_code}: source code """
""" {filename: source_code (str)}: source code """
return self._raw_source_code
@property
@ -233,4 +234,19 @@ class Slither(Context):
@property
def crytic_compile(self):
return self._crytic_compile
# endregion
###################################################################################
###################################################################################
# region Format
###################################################################################
###################################################################################
@property
def generate_patches(self):
return self._generate_patches
@generate_patches.setter
def generate_patches(self, p):
self._generate_patches = p
# endregion

@ -1,10 +1,10 @@
import abc
import re
from collections import OrderedDict, defaultdict
from slither.utils.colors import green, yellow, red
from slither.core.source_mapping.source_mapping import SourceMapping
from collections import OrderedDict
from slither.formatters.exceptions import FormatImpossible
from slither.formatters.utils.patches import apply_patch, create_diff
class IncorrectDetectorInitialization(Exception):
pass
@ -94,7 +94,8 @@ class AbstractDetector(metaclass=abc.ABCMeta):
raise IncorrectDetectorInitialization('CONFIDENCE is not initialized {}'.format(self.__class__.__name__))
def _log(self, info):
self.logger.info(self.color(info))
if self.logger:
self.logger.info(self.color(info))
@abc.abstractmethod
def _detect(self):
@ -115,6 +116,33 @@ class AbstractDetector(metaclass=abc.ABCMeta):
info += result['description']
info += 'Reference: {}'.format(self.WIKI)
self._log(info)
if self.slither.generate_patches:
for result in results:
try:
self._format(self.slither, result)
if not 'patches' in result:
continue
result['patches_diff'] = dict()
for file in result['patches']:
original_txt = self.slither.source_code[file].encode('utf8')
patched_txt = original_txt
offset = 0
patches = result['patches'][file]
patches.sort(key=lambda x: x['start'])
if not all(patches[i]['end'] <= patches[i + 1]['end'] for i in range(len(patches) - 1)):
self._log(f'Impossible to generate patch; patches collisions: {patches}')
continue
for patch in patches:
patched_txt, offset = apply_patch(patched_txt, patch, offset)
diff = create_diff(self.slither, original_txt, patched_txt, file)
if not diff:
self._log(f'Impossible to generate patch; empty {result}')
else:
result['patches_diff'][file] = diff
except FormatImpossible as e:
self._log(f'\nImpossible to patch:\n\t{result["description"]}\t{e}')
if results and self.slither.triage_mode:
while True:
indexes = input('Results to hide during next runs: "0,1,..." or "All" (enter to not hide results): '.format(len(results)))
@ -310,3 +338,8 @@ class AbstractDetector(metaclass=abc.ABCMeta):
{},
additional_fields)
d['elements'].append(element)
@staticmethod
def _format(slither, result):
"""Implement format"""
return

@ -3,7 +3,7 @@ Module detecting constant functions
Recursively check the called functions
"""
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.attributes.const_functions import format
class ConstantFunctions(AbstractDetector):
"""
@ -76,3 +76,7 @@ All the calls to `get` revert, breaking Bob's smart contract execution.'''
results.append(json)
return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -3,6 +3,7 @@
"""
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.attributes.constant_pragma import format
class ConstantPragma(AbstractDetector):
@ -42,3 +43,7 @@ class ConstantPragma(AbstractDetector):
results.append(json)
return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -2,8 +2,9 @@
Check if an incorrect version of solc is used
"""
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
import re
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.attributes.incorrect_solc import format
# group:
# 0: ^ > >= < <= (optional)
@ -105,3 +106,7 @@ Use Solidity 0.4.25 or 0.5.3. Consider using the latest version of Solidity for
results.append(json)
return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -1,7 +1,9 @@
from slither.detectors.abstract_detector import (AbstractDetector,
DetectorClassification)
from slither.slithir.operations import (HighLevelCall, SolidityCall )
from slither.slithir.operations import SolidityCall
from slither.slithir.operations import (InternalCall, InternalDynamicCall)
from slither.formatters.functions.external_function import format
class ExternalFunction(AbstractDetector):
"""
@ -193,3 +195,7 @@ class ExternalFunction(AbstractDetector):
results.append(json)
return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -1,5 +1,7 @@
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
import re
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.naming_convention.naming_convention import format
class NamingConvention(AbstractDetector):
@ -202,3 +204,7 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2
results.append(json)
return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -7,6 +7,8 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi
from slither.visitors.expression.export_values import ExportValues
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.variables.state_variable import StateVariable
from slither.formatters.variables.possible_const_state_variables import format
class ConstCandidateStateVars(AbstractDetector):
"""
@ -94,3 +96,7 @@ class ConstCandidateStateVars(AbstractDetector):
results.append(json)
return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -6,6 +6,7 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi
from slither.core.solidity_types import ArrayType
from slither.visitors.expression.export_values import ExportValues
from slither.core.variables.state_variable import StateVariable
from slither.formatters.variables.unused_state_variables import format
class UnusedStateVars(AbstractDetector):
"""
@ -70,3 +71,7 @@ class UnusedStateVars(AbstractDetector):
results.append(json)
return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -0,0 +1,36 @@
import re
from slither.formatters.exceptions import FormatError
from slither.formatters.utils.patches import create_patch
def format(slither, result):
elements = result['elements']
for element in elements:
if element['type'] != "function":
# Skip variable elements
continue
target_contract = slither.get_contract_from_name(element['type_specific_fields']['parent']['name'])
if target_contract:
function = target_contract.get_function_from_signature(element['type_specific_fields']['signature'])
if function:
_patch(slither,
result,
element['source_mapping']['filename_absolute'],
int(function.parameters_src.source_mapping['start'] +
function.parameters_src.source_mapping['length']),
int(function.returns_src.source_mapping['start']))
def _patch(slither, result, in_file, modify_loc_start, modify_loc_end):
in_file_str = slither.source_code[in_file].encode('utf8')
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
# Find the keywords view|pure|constant and remove them
m = re.search("(view|pure|constant)", old_str_of_interest.decode('utf-8'))
if m:
create_patch(result,
in_file,
modify_loc_start + m.span()[0],
modify_loc_start + m.span()[1],
m.groups(0)[0], # this is view|pure|constant
"")
else:
raise FormatError("No view/pure/constant specifier exists. Regex failed to remove specifier!")

@ -0,0 +1,69 @@
import re
from slither.formatters.exceptions import FormatImpossible
from slither.formatters.utils.patches import create_patch
# Indicates the recommended versions for replacement
REPLACEMENT_VERSIONS = ["^0.4.25", "^0.5.3"]
# group:
# 0: ^ > >= < <= (optional)
# 1: ' ' (optional)
# 2: version number
# 3: version number
# 4: version number
PATTERN = re.compile('(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)')
def format(slither, result):
elements = result['elements']
versions_used = []
for element in elements:
versions_used.append(''.join(element['type_specific_fields']['directive'][1:]))
solc_version_replace = _analyse_versions(versions_used)
for element in elements:
_patch(slither, result, element['source_mapping']['filename_absolute'], solc_version_replace,
element['source_mapping']['start'],
element['source_mapping']['start'] + element['source_mapping']['length'])
def _analyse_versions(used_solc_versions):
replace_solc_versions = list()
for version in used_solc_versions:
replace_solc_versions.append(_determine_solc_version_replacement(version))
if not all(version == replace_solc_versions[0] for version in replace_solc_versions):
raise FormatImpossible("Multiple incompatible versions!")
else:
return replace_solc_versions[0]
def _determine_solc_version_replacement(used_solc_version):
versions = PATTERN.findall(used_solc_version)
if len(versions) == 1:
version = versions[0]
minor_version = '.'.join(version[2:])[2]
if minor_version == '4':
return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ';'
elif minor_version == '5':
return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ';'
else:
raise FormatImpossible("Unknown version!")
elif len(versions) == 2:
version_right = versions[1]
minor_version_right = '.'.join(version_right[2:])[2]
if minor_version_right == '4':
# Replace with 0.4.25
return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ';'
elif minor_version_right in ['5', '6']:
# Replace with 0.5.3
return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ';'
def _patch(slither, result, in_file, pragma, modify_loc_start, modify_loc_end):
in_file_str = slither.source_code[in_file].encode('utf8')
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
create_patch(result,
in_file,
int(modify_loc_start),
int(modify_loc_end),
old_str_of_interest,
pragma)

@ -0,0 +1,59 @@
import re
from slither.formatters.exceptions import FormatImpossible
from slither.formatters.utils.patches import create_patch
# Indicates the recommended versions for replacement
REPLACEMENT_VERSIONS = ["^0.4.25", "^0.5.3"]
# group:
# 0: ^ > >= < <= (optional)
# 1: ' ' (optional)
# 2: version number
# 3: version number
# 4: version number
PATTERN = re.compile('(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)')
def format(slither, result):
elements = result['elements']
for element in elements:
solc_version_replace = _determine_solc_version_replacement(
''.join(element['type_specific_fields']['directive'][1:]))
_patch(slither, result, element['source_mapping']['filename_absolute'], solc_version_replace,
element['source_mapping']['start'], element['source_mapping']['start'] +
element['source_mapping']['length'])
def _determine_solc_version_replacement(used_solc_version):
versions = PATTERN.findall(used_solc_version)
if len(versions) == 1:
version = versions[0]
minor_version = '.'.join(version[2:])[2]
if minor_version == '4':
# Replace with 0.4.25
return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ';'
elif minor_version == '5':
# Replace with 0.5.3
return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ';'
else:
raise FormatImpossible(f"Unknown version {versions}")
elif len(versions) == 2:
version_right = versions[1]
minor_version_right = '.'.join(version_right[2:])[2]
if minor_version_right == '4':
# Replace with 0.4.25
return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ';'
elif minor_version_right in ['5','6']:
# Replace with 0.5.3
return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ';'
def _patch(slither, result, in_file, solc_version, modify_loc_start, modify_loc_end):
in_file_str = slither.source_code[in_file].encode('utf8')
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
create_patch(result,
in_file,
int(modify_loc_start),
int(modify_loc_end),
old_str_of_interest,
solc_version)

@ -0,0 +1,5 @@
from slither.exceptions import SlitherException
class FormatImpossible(SlitherException): pass
class FormatError(SlitherException): pass

@ -0,0 +1,42 @@
import re
from slither.formatters.utils.patches import create_patch
def format(slither, result):
elements = result['elements']
for element in elements:
target_contract = slither.get_contract_from_name(element['type_specific_fields']['parent']['name'])
if target_contract:
function = target_contract.get_function_from_signature(element['type_specific_fields']['signature'])
if function:
_patch(slither,
result,
element['source_mapping']['filename_absolute'],
int(function.parameters_src.source_mapping['start']),
int(function.returns_src.source_mapping['start']))
def _patch(slither, result, in_file, modify_loc_start, modify_loc_end):
in_file_str = slither.source_code[in_file].encode('utf8')
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
# Search for 'public' keyword which is in-between the function name and modifier name (if present)
# regex: 'public' could have spaces around or be at the end of the line
m = re.search(r'((\spublic)\s+)|(\spublic)$|(\)public)$', old_str_of_interest.decode('utf-8'))
if m is None:
# No visibility specifier exists; public by default.
create_patch(result,
in_file,
# start after the function definition's closing paranthesis
modify_loc_start + len(old_str_of_interest.decode('utf-8').split(')')[0]) + 1,
# end is same as start because we insert the keyword `external` at that location
modify_loc_start + len(old_str_of_interest.decode('utf-8').split(')')[0]) + 1,
"",
" external") # replace_text is `external`
else:
create_patch(result,
in_file,
# start at the keyword `public`
modify_loc_start + m.span()[0] + 1,
# end after the keyword `public` = start + len('public'')
modify_loc_start + m.span()[0] + 1 + len('public'),
"public",
"external")

@ -0,0 +1,609 @@
import re
import logging
from slither.slithir.operations import Send, Transfer, OperationWithLValue, HighLevelCall, LowLevelCall, \
InternalCall, InternalDynamicCall
from slither.core.declarations import Modifier
from slither.core.solidity_types import UserDefinedType, MappingType
from slither.core.declarations import Enum, Contract, Structure, Function
from slither.core.solidity_types.elementary_type import ElementaryTypeName
from slither.core.variables.local_variable import LocalVariable
from slither.formatters.exceptions import FormatError, FormatImpossible
from slither.formatters.utils.patches import create_patch
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('Slither.Format')
def format(slither, result):
elements = result['elements']
for element in elements:
target = element['additional_fields']['target']
convention = element['additional_fields']['convention']
if convention == "l_O_I_should_not_be_used":
# l_O_I_should_not_be_used cannot be automatically patched
logger.info(f'The following naming convention cannot be patched: \n{result["description"]}')
continue
_patch(slither, result, element, target)
# endregion
###################################################################################
###################################################################################
# region Conventions
###################################################################################
###################################################################################
KEY = 'ALL_NAMES_USED'
# https://solidity.readthedocs.io/en/v0.5.11/miscellaneous.html#reserved-keywords
SOLIDITY_KEYWORDS = ['abstract', 'after', 'alias', 'apply', 'auto', 'case', 'catch', 'copyof', 'default', 'define',
'final', 'immutable', 'implements', 'in', 'inline', 'let', 'macro', 'match', 'mutable', 'null',
'of', 'override', 'partial', 'promise', 'reference', 'relocatable', 'sealed', 'sizeof', 'static',
'supports', 'switch', 'try', 'typedef', 'typeof', 'unchecked']
# https://solidity.readthedocs.io/en/v0.5.11/miscellaneous.html#language-grammar
SOLIDITY_KEYWORDS += ['pragma', 'import', 'contract', 'library', 'contract', 'function', 'using', 'struct', 'enum',
'public', 'private', 'internal', 'external', 'calldata', 'memory', 'modifier', 'view', 'pure',
'constant', 'storage', 'for', 'if', 'while', 'break', 'return', 'throw', 'else', 'type']
SOLIDITY_KEYWORDS += ElementaryTypeName
def _name_already_use(slither, name):
# Do not convert to a name used somewhere else
if not KEY in slither.context:
all_names = set()
for contract in slither.contracts_derived:
all_names = all_names.union(set([st.name for st in contract.structures]))
all_names = all_names.union(set([f.name for f in contract.functions_and_modifiers]))
all_names = all_names.union(set([e.name for e in contract.enums]))
all_names = all_names.union(set([s.name for s in contract.state_variables]))
for function in contract.functions:
all_names = all_names.union(set([v.name for v in function.variables]))
slither.context[KEY] = all_names
return name in slither.context[KEY]
def _convert_CapWords(original_name, slither):
name = original_name.capitalize()
while '_' in name:
offset = name.find('_')
if len(name) > offset:
name = name[0:offset] + name[offset+1].upper() + name[offset+1:]
if _name_already_use(slither, name):
raise FormatImpossible(f'{original_name} cannot be converted to {name} (already used)')
if name in SOLIDITY_KEYWORDS:
raise FormatImpossible(f'{original_name} cannot be converted to {name} (Solidity keyword)')
return name
def _convert_mixedCase(original_name, slither):
name = original_name
if isinstance(name, bytes):
name = name.decode('utf8')
while '_' in name:
offset = name.find('_')
if len(name) > offset:
name = name[0:offset] + name[offset + 1].upper() + name[offset + 2:]
name = name[0].lower() + name[1:]
if _name_already_use(slither, name):
raise FormatImpossible(f'{original_name} cannot be converted to {name} (already used)')
if name in SOLIDITY_KEYWORDS:
raise FormatImpossible(f'{original_name} cannot be converted to {name} (Solidity keyword)')
return name
def _convert_UPPER_CASE_WITH_UNDERSCORES(name, slither):
if _name_already_use(slither, name.upper()):
raise FormatImpossible(f'{name} cannot be converted to {name.upper()} (already used)')
if name.upper() in SOLIDITY_KEYWORDS:
raise FormatImpossible(f'{name} cannot be converted to {name.upper()} (Solidity keyword)')
return name.upper()
conventions ={
"CapWords":_convert_CapWords,
"mixedCase":_convert_mixedCase,
"UPPER_CASE_WITH_UNDERSCORES":_convert_UPPER_CASE_WITH_UNDERSCORES
}
# endregion
###################################################################################
###################################################################################
# region Helpers
###################################################################################
###################################################################################
def _get_from_contract(slither, element, name, getter):
contract_name = element['type_specific_fields']['parent']['name']
contract = slither.get_contract_from_name(contract_name)
return getattr(contract, getter)(name)
# endregion
###################################################################################
###################################################################################
# region Patch dispatcher
###################################################################################
###################################################################################
def _patch(slither, result, element, _target):
if _target == "contract":
target = slither.get_contract_from_name(element['name'])
elif _target == "structure":
target = _get_from_contract(slither, element, element['name'], 'get_structure_from_name')
elif _target == "event":
target = _get_from_contract(slither, element, element['name'], 'get_event_from_name')
elif _target == "function":
# Avoid constructor (FP?)
if element['name'] != element['type_specific_fields']['parent']['name']:
function_sig = element['type_specific_fields']['signature']
target = _get_from_contract(slither, element, function_sig, 'get_function_from_signature')
elif _target == "modifier":
modifier_sig = element['type_specific_fields']['signature']
target = _get_from_contract(slither, element, modifier_sig, 'get_modifier_from_signature')
elif _target == "parameter":
contract_name = element['type_specific_fields']['parent']['type_specific_fields']['parent']['name']
function_sig = element['type_specific_fields']['parent']['type_specific_fields']['signature']
param_name = element['name']
contract = slither.get_contract_from_name(contract_name)
function = contract.get_function_from_signature(function_sig)
target = function.get_local_variable_from_name(param_name)
elif _target in ["variable", "variable_constant"]:
# Local variable
if element['type_specific_fields']['parent'] == 'function':
contract_name = element['type_specific_fields']['parent']['type_specific_fields']['parent']['name']
function_sig = element['type_specific_fields']['parent']['type_specific_fields']['signature']
var_name = element['name']
contract = slither.get_contract_from_name(contract_name)
function = contract.get_function_from_signature(function_sig)
target = function.get_local_variable_from_name(var_name)
# State variable
else:
target = _get_from_contract(slither, element, element['name'], 'get_state_variable_from_name')
elif _target == "enum":
target = _get_from_contract(slither, element, element['name'], 'get_enum_from_canonical_name')
else:
raise FormatError("Unknown naming convention! " + _target)
_explore(slither,
result,
target,
conventions[element['additional_fields']['convention']])
# endregion
###################################################################################
###################################################################################
# region Explore functions
###################################################################################
###################################################################################
# group 1: beginning of the from type
# group 2: beginning of the to type
# nested mapping are within the group 1
#RE_MAPPING = '[ ]*mapping[ ]*\([ ]*([\=\>\(\) a-zA-Z0-9\._\[\]]*)[ ]*=>[ ]*([a-zA-Z0-9\._\[\]]*)\)'
RE_MAPPING_FROM = b'([a-zA-Z0-9\._\[\]]*)'
RE_MAPPING_TO = b'([\=\>\(\) a-zA-Z0-9\._\[\]\ ]*)'
RE_MAPPING = b'[ ]*mapping[ ]*\([ ]*' + RE_MAPPING_FROM + b'[ ]*' + b'=>' + b'[ ]*'+ RE_MAPPING_TO + b'\)'
def _is_var_declaration(slither, filename, start):
'''
Detect usage of 'var ' for Solidity < 0.5
:param slither:
:param filename:
:param start:
:return:
'''
v = 'var '
return slither.source_code[filename][start:start + len(v)] == v
def _explore_type(slither, result, target, convert, type, filename_source_code, start, end):
if isinstance(type, UserDefinedType):
# Patch type based on contract/enum
if isinstance(type.type, (Enum, Contract)):
if type.type == target:
old_str = type.type.name
new_str = convert(old_str, slither)
loc_start = start
if _is_var_declaration(slither, filename_source_code, start):
loc_end = loc_start + len('var')
else:
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
else:
# Patch type based on structure
assert isinstance(type.type, Structure)
if type.type == target:
old_str = type.type.name
new_str = convert(old_str, slither)
loc_start = start
if _is_var_declaration(slither, filename_source_code, start):
loc_end = loc_start + len('var')
else:
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
# Structure contain a list of elements, that might need patching
# .elems return a list of VariableStructure
_explore_variables_declaration(slither,
type.type.elems.values(),
result,
target,
convert)
if isinstance(type, MappingType):
# Mapping has three steps:
# Convert the "from" type
# Convert the "to" type
# Convert nested type in the "to"
# Ex: mapping (mapping (badName => uint) => uint)
# Do the comparison twice, so we can factor together the re matching
# mapping can only have elementary type in type_from
if isinstance(type.type_to, (UserDefinedType, MappingType)) or target in [type.type_from, type.type_to]:
full_txt_start = start
full_txt_end = end
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
re_match = re.match(RE_MAPPING, full_txt)
assert re_match
if type.type_from == target:
old_str = type.type_from.name
new_str = convert(old_str, slither)
loc_start = start + re_match.start(1)
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
if type.type_to == target:
old_str = type.type_to.name
new_str = convert(old_str, slither)
loc_start = start + re_match.start(2)
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
if isinstance(type.type_to, (UserDefinedType, MappingType)):
loc_start = start + re_match.start(2)
loc_end = start + re_match.end(2)
_explore_type(slither,
result,
target,
convert,
type.type_to,
filename_source_code,
loc_start,
loc_end)
def _explore_variables_declaration(slither, variables, result, target, convert, patch_comment=False):
for variable in variables:
# First explore the type of the variable
filename_source_code = variable.source_mapping['filename_absolute']
full_txt_start = variable.source_mapping['start']
full_txt_end = full_txt_start + variable.source_mapping['length']
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
_explore_type(slither,
result,
target,
convert,
variable.type,
filename_source_code,
full_txt_start,
variable.source_mapping['start'] + variable.source_mapping['length'])
# If the variable is the target
if variable == target:
old_str = variable.name
new_str = convert(old_str, slither)
loc_start = full_txt_start + full_txt.find(old_str.encode('utf8'))
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
# Patch comment only makes sense for local variable declaration in the parameter list
if patch_comment and isinstance(variable, LocalVariable):
if 'lines' in variable.source_mapping and variable.source_mapping['lines']:
func = variable.function
end_line = func.source_mapping['lines'][0]
if variable in func.parameters:
idx = len(func.parameters) - func.parameters.index(variable) + 1
first_line = end_line - idx - 2
potential_comments = slither.source_code[filename_source_code].encode('utf8')
potential_comments = potential_comments.splitlines(keepends=True)[first_line:end_line-1]
idx_beginning = func.source_mapping['start']
idx_beginning += - func.source_mapping['starting_column'] + 1
idx_beginning += - sum([len(c) for c in potential_comments])
old_comment = f'@param {old_str}'.encode('utf8')
for line in potential_comments:
idx = line.find(old_comment)
if idx >=0:
loc_start = idx + idx_beginning
loc_end = loc_start + len(old_comment)
new_comment = f'@param {new_str}'.encode('utf8')
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_comment,
new_comment)
break
idx_beginning += len(line)
def _explore_modifiers_calls(slither, function, result, target, convert):
for modifier in function.modifiers_statements:
for node in modifier.nodes:
if node.irs:
_explore_irs(slither, node.irs, result, target, convert)
for modifier in function.explicit_base_constructor_calls_statements:
for node in modifier.nodes:
if node.irs:
_explore_irs(slither, node.irs, result, target, convert)
def _explore_structures_declaration(slither, structures, result, target, convert):
for st in structures:
# Explore the variable declared within the structure (VariableStructure)
_explore_variables_declaration(slither, st.elems.values(), result, target, convert)
# If the structure is the target
if st == target:
old_str = st.name
new_str = convert(old_str, slither)
filename_source_code = st.source_mapping['filename_absolute']
full_txt_start = st.source_mapping['start']
full_txt_end = full_txt_start + st.source_mapping['length']
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
# The name is after the space
matches = re.finditer(b'struct[ ]*', full_txt)
# Look for the end offset of the largest list of ' '
loc_start = full_txt_start + max(matches, key=lambda x: len(x.group())).end()
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
def _explore_events_declaration(slither, events, result, target, convert):
for event in events:
# Explore the parameters
_explore_variables_declaration(slither, event.elems, result, target, convert)
# If the event is the target
if event == target:
filename_source_code = event.source_mapping['filename_absolute']
old_str = event.name
new_str = convert(old_str, slither)
loc_start = event.source_mapping['start']
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
def get_ir_variables(ir):
vars = ir.read
if isinstance(ir, (InternalCall, InternalDynamicCall, HighLevelCall)):
vars += [ir.function]
if isinstance(ir, (HighLevelCall, Send, LowLevelCall, Transfer)):
vars += [ir.call_value]
if isinstance(ir, (HighLevelCall, LowLevelCall)):
vars += [ir.call_gas]
if isinstance(ir, OperationWithLValue):
vars += [ir.lvalue]
return [v for v in vars if v]
def _explore_irs(slither, irs, result, target, convert):
if irs is None:
return
for ir in irs:
for v in get_ir_variables(ir):
if target == v or (
isinstance(target, Function) and isinstance(v, Function) and
v.canonical_name == target.canonical_name):
source_mapping = ir.expression.source_mapping
filename_source_code = source_mapping['filename_absolute']
full_txt_start = source_mapping['start']
full_txt_end = full_txt_start + source_mapping['length']
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
if not target.name.encode('utf8') in full_txt:
raise FormatError(f'{target} not found in {full_txt} ({source_mapping}')
old_str = target.name.encode('utf8')
new_str = convert(old_str, slither)
counter = 0
# Can be found multiple time on the same IR
# We patch one by one
while old_str in full_txt:
target_found_at = full_txt.find((old_str))
full_txt = full_txt[target_found_at+1:]
counter += target_found_at
loc_start = full_txt_start + counter
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
def _explore_functions(slither, functions, result, target, convert):
for function in functions:
_explore_variables_declaration(slither, function.variables, result, target, convert, True)
_explore_modifiers_calls(slither, function, result, target, convert)
_explore_irs(slither, function.all_slithir_operations(), result, target, convert)
if isinstance(target, Function) and function.canonical_name == target.canonical_name:
old_str = function.name
new_str = convert(old_str, slither)
filename_source_code = function.source_mapping['filename_absolute']
full_txt_start = function.source_mapping['start']
full_txt_end = full_txt_start + function.source_mapping['length']
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
# The name is after the space
if isinstance(target, Modifier):
matches = re.finditer(b'modifier([ ]*)', full_txt)
else:
matches = re.finditer(b'function([ ]*)', full_txt)
# Look for the end offset of the largest list of ' '
loc_start = full_txt_start + max(matches, key=lambda x: len(x.group())).end()
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
def _explore_enums(slither, enums, result, target, convert):
for enum in enums:
if enum == target:
old_str = enum.name
new_str = convert(old_str, slither)
filename_source_code = enum.source_mapping['filename_absolute']
full_txt_start = enum.source_mapping['start']
full_txt_end = full_txt_start + enum.source_mapping['length']
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
# The name is after the space
matches = re.finditer(b'enum([ ]*)', full_txt)
# Look for the end offset of the largest list of ' '
loc_start = full_txt_start + max(matches, key=lambda x: len(x.group())).end()
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
def _explore_contract(slither, contract, result, target, convert):
_explore_variables_declaration(slither, contract.state_variables, result, target, convert)
_explore_structures_declaration(slither, contract.structures, result, target, convert)
_explore_functions(slither, contract.functions_and_modifiers, result, target, convert)
_explore_enums(slither, contract.enums, result, target, convert)
if contract == target:
filename_source_code = contract.source_mapping['filename_absolute']
full_txt_start = contract.source_mapping['start']
full_txt_end = full_txt_start + contract.source_mapping['length']
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
old_str = contract.name
new_str = convert(old_str, slither)
# The name is after the space
matches = re.finditer(b'contract[ ]*', full_txt)
# Look for the end offset of the largest list of ' '
loc_start = full_txt_start + max(matches, key=lambda x: len(x.group())).end()
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
def _explore(slither, result, target, convert):
for contract in slither.contracts_derived:
_explore_contract(slither, contract, result, target, convert)
# endregion

@ -0,0 +1,43 @@
import os
import difflib
from collections import defaultdict
def create_patch(result, file, start, end, old_str, new_str):
if isinstance(old_str, bytes):
old_str = old_str.decode('utf8')
if isinstance(new_str, bytes):
new_str = new_str.decode('utf8')
p = {"start": start,
"end": end,
"old_string": old_str,
"new_string": new_str
}
if 'patches' not in result:
result['patches'] = defaultdict(list)
if p not in result['patches'][file]:
result['patches'][file].append(p)
def apply_patch(original_txt, patch, offset):
patched_txt = original_txt[:int(patch['start'] + offset)]
patched_txt += patch['new_string'].encode('utf8')
patched_txt += original_txt[int(patch['end'] + offset):]
# Keep the diff of text added or sub, in case of multiple patches
patch_length_diff = len(patch['new_string']) - (patch['end'] - patch['start'])
return patched_txt, patch_length_diff + offset
def create_diff(slither, original_txt, patched_txt, filename):
if slither.crytic_compile:
relative_path = slither.crytic_compile.filename_lookup(filename).relative
relative_path = os.path.join('.', relative_path)
else:
relative_path = filename
diff = difflib.unified_diff(original_txt.decode('utf8').splitlines(False),
patched_txt.decode('utf8').splitlines(False),
fromfile=relative_path,
tofile=relative_path,
lineterm='')
return '\n'.join(list(diff)) + '\n'

@ -0,0 +1,38 @@
import re
from slither.formatters.exceptions import FormatError, FormatImpossible
from slither.formatters.utils.patches import create_patch
def format(slither, result):
elements = result['elements']
for element in elements:
# TODO: decide if this should be changed in the constant detector
contract_name = element['type_specific_fields']['parent']['name']
contract = slither.get_contract_from_name(contract_name)
var = contract.get_state_variable_from_name(element['name'])
if not var.expression:
raise FormatImpossible(f'{var.name} is uninitialized and cannot become constant.')
_patch(slither, result, element['source_mapping']['filename_absolute'],
element['name'],
"constant " + element['name'],
element['source_mapping']['start'],
element['source_mapping']['start'] + element['source_mapping']['length'])
def _patch(slither, result, in_file, match_text, replace_text, modify_loc_start, modify_loc_end):
in_file_str = slither.source_code[in_file].encode('utf8')
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
# Add keyword `constant` before the variable name
(new_str_of_interest, num_repl) = re.subn(match_text, replace_text, old_str_of_interest.decode('utf-8'), 1)
if num_repl != 0:
create_patch(result,
in_file,
modify_loc_start,
modify_loc_end,
old_str_of_interest,
new_str_of_interest)
else:
raise FormatError("State variable not found?!")

@ -0,0 +1,28 @@
from slither.formatters.utils.patches import create_patch
def format(slither, result):
elements = result['elements']
for element in elements:
if element['type'] == "variable":
_patch(slither,
result,
element['source_mapping']['filename_absolute'],
element['source_mapping']['start'])
def _patch(slither, result, in_file, modify_loc_start):
in_file_str = slither.source_code[in_file].encode('utf8')
old_str_of_interest = in_file_str[modify_loc_start:]
old_str = old_str_of_interest.decode('utf-8').partition(';')[0]\
+ old_str_of_interest.decode('utf-8').partition(';')[1]
create_patch(result,
in_file,
int(modify_loc_start),
# Remove the entire declaration until the semicolon
int(modify_loc_start + len(old_str_of_interest.decode('utf-8').partition(';')[0]) + 1),
old_str,
"")

@ -34,6 +34,7 @@ class Slither(SlitherSolc):
filter_paths (list(str)): list of path to filter (default [])
triage_mode (bool): if true, switch to triage mode (default false)
exclude_dependencies (bool): if true, exclude results that are only related to dependencies
generate_patches (bool): if true, patches are generated (json output only)
truffle_ignore (bool): ignore truffle.js presence (default false)
truffle_build_directory (str): build truffle directory (default 'build/contracts')
@ -64,6 +65,9 @@ class Slither(SlitherSolc):
self._parse_contracts_from_loaded_json(ast, path)
self._add_source_code(path)
if kwargs.get('generate_patches', False):
self.generate_patches = True
self._detectors = []
self._printers = []

@ -45,10 +45,14 @@ def convert_expression(expression, node):
if isinstance(expression, Literal) and node.type in [NodeType.IF, NodeType.IFLOOP]:
cst = Constant(expression.value, expression.type)
result = [Condition(cst)]
cond = Condition(cst)
cond.set_expression(expression)
result = [cond]
return result
if isinstance(expression, Identifier) and node.type in [NodeType.IF, NodeType.IFLOOP]:
result = [Condition(expression.value)]
cond = Condition(expression.value)
cond.set_expression(expression)
result = [cond]
return result
@ -60,11 +64,15 @@ def convert_expression(expression, node):
if result:
if node.type in [NodeType.IF, NodeType.IFLOOP]:
assert isinstance(result[-1], (OperationWithLValue))
result.append(Condition(result[-1].lvalue))
cond = Condition(result[-1].lvalue)
cond.set_expression(expression)
result.append(cond)
elif node.type == NodeType.RETURN:
# May return None
if isinstance(result[-1], (OperationWithLValue)):
result.append(Return(result[-1].lvalue))
r = Return(result[-1].lvalue)
r.set_expression(expression)
result.append(r)
return result
@ -326,6 +334,7 @@ def _convert_type_contract(ir, slither):
assignment = Assignment(ir.lvalue,
Constant(str(bytecode)),
ElementaryType('bytes'))
assignment.set_expression(ir.expression)
assignment.lvalue.set_type(ElementaryType('bytes'))
return assignment
if ir.variable_right == 'runtimeCode':
@ -338,12 +347,14 @@ def _convert_type_contract(ir, slither):
assignment = Assignment(ir.lvalue,
Constant(str(bytecode)),
ElementaryType('bytes'))
assignment.set_expression(ir.expression)
assignment.lvalue.set_type(ElementaryType('bytes'))
return assignment
if ir.variable_right == 'name':
assignment = Assignment(ir.lvalue,
Constant(contract.name),
ElementaryType('string'))
assignment.set_expression(ir.expression)
assignment.lvalue.set_type(ElementaryType('string'))
return assignment
@ -446,14 +457,18 @@ def propagate_types(ir, node):
# TODO we should convert the reference to a temporary if the member is a length or a balance
if ir.variable_right == 'length' and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, (ElementaryType, ArrayType)):
length = Length(ir.variable_left, ir.lvalue)
length.set_expression(ir.expression)
length.lvalue.points_to = ir.variable_left
return length
if ir.variable_right == 'balance'and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, ElementaryType):
return Balance(ir.variable_left, ir.lvalue)
b = Balance(ir.variable_left, ir.lvalue)
b.set_expression(ir.expression)
return b
if ir.variable_right == 'selector' and isinstance(ir.variable_left.type, Function):
assignment = Assignment(ir.lvalue,
Constant(str(get_function_id(ir.variable_left.type.full_name))),
ElementaryType('bytes4'))
assignment.set_expression(ir.expression)
assignment.lvalue.set_type(ElementaryType('bytes4'))
return assignment
if isinstance(ir.variable_left, TemporaryVariable) and isinstance(ir.variable_left.type, TypeInformation):
@ -534,6 +549,7 @@ def extract_tmp_call(ins, contract):
if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType):
call = InternalDynamicCall(ins.lvalue, ins.called, ins.called.type)
call.set_expression(ins.expression)
call.call_id = ins.call_id
return call
if isinstance(ins.ori, Member):
@ -541,23 +557,28 @@ def extract_tmp_call(ins, contract):
if ins.ori.variable_left in contract.inheritance + [contract]:
if str(ins.ori.variable_right) in [f.name for f in contract.functions]:
internalcall = InternalCall((ins.ori.variable_right, ins.ori.variable_left.name), ins.nbr_arguments, ins.lvalue, ins.type_call)
internalcall.set_expression(ins.expression)
internalcall.call_id = ins.call_id
return internalcall
if str(ins.ori.variable_right) in [f.name for f in contract.events]:
eventcall = EventCall(ins.ori.variable_right)
eventcall.call_id = ins.call_id
return eventcall
eventcall = EventCall(ins.ori.variable_right)
eventcall.set_expression(ins.expression)
eventcall.call_id = ins.call_id
return eventcall
if isinstance(ins.ori.variable_left, Contract):
st = ins.ori.variable_left.get_structure_from_name(ins.ori.variable_right)
if st:
op = NewStructure(st, ins.lvalue)
op.set_expression(ins.expression)
op.call_id = ins.call_id
return op
libcall = LibraryCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call)
libcall.set_expression(ins.expression)
libcall.call_id = ins.call_id
return libcall
msgcall = HighLevelCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call)
msgcall.call_id = ins.call_id
msgcall.set_expression(ins.expression)
return msgcall
if isinstance(ins.ori, TmpCall):
@ -567,29 +588,42 @@ def extract_tmp_call(ins, contract):
if str(ins.called) == 'block.blockhash':
ins.called = SolidityFunction('blockhash(uint256)')
elif str(ins.called) == 'this.balance':
return SolidityCall(SolidityFunction('this.balance()'), ins.nbr_arguments, ins.lvalue, ins.type_call)
s = SolidityCall(SolidityFunction('this.balance()'), ins.nbr_arguments, ins.lvalue, ins.type_call)
s.set_expression(ins.expression)
return s
if isinstance(ins.called, SolidityFunction):
return SolidityCall(ins.called, ins.nbr_arguments, ins.lvalue, ins.type_call)
s = SolidityCall(ins.called, ins.nbr_arguments, ins.lvalue, ins.type_call)
s.set_expression(ins.expression)
return s
if isinstance(ins.ori, TmpNewElementaryType):
return NewElementaryType(ins.ori.type, ins.lvalue)
n = NewElementaryType(ins.ori.type, ins.lvalue)
n.set_expression(ins.expression)
return n
if isinstance(ins.ori, TmpNewContract):
op = NewContract(Constant(ins.ori.contract_name), ins.lvalue)
op.set_expression(ins.expression)
op.call_id = ins.call_id
return op
if isinstance(ins.ori, TmpNewArray):
return NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue)
n = NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue)
n.set_expression(ins.expression)
return n
if isinstance(ins.called, Structure):
op = NewStructure(ins.called, ins.lvalue)
op.set_expression(ins.expression)
op.call_id = ins.call_id
op.set_expression(ins.expression)
return op
if isinstance(ins.called, Event):
return EventCall(ins.called.name)
e = EventCall(ins.called.name)
e.set_expression(ins.expression)
return e
if isinstance(ins.called, Contract):
# Called a base constructor, where there is no constructor
@ -598,6 +632,7 @@ def extract_tmp_call(ins, contract):
internalcall = InternalCall(ins.called.constructor, ins.nbr_arguments, ins.lvalue,
ins.type_call)
internalcall.call_id = ins.call_id
internalcall.set_expression(ins.expression)
return internalcall
@ -628,11 +663,15 @@ def convert_to_low_level(ir):
"""
if ir.function_name == 'transfer':
assert len(ir.arguments) == 1
prev_ir = ir
ir = Transfer(ir.destination, ir.arguments[0])
ir.set_expression(prev_ir.expression)
return ir
elif ir.function_name == 'send':
assert len(ir.arguments) == 1
prev_ir = ir
ir = Send(ir.destination, ir.arguments[0], ir.lvalue)
ir.set_expression(prev_ir.expression)
ir.lvalue.set_type(ElementaryType('bool'))
return ir
elif ir.function_name in ['call',
@ -648,6 +687,7 @@ def convert_to_low_level(ir):
new_ir.call_value = ir.call_value
new_ir.arguments = ir.arguments
new_ir.lvalue.set_type(ElementaryType('bool'))
new_ir.set_expression(ir.expression)
return new_ir
raise SlithIRError('Incorrect conversion to low level {}'.format(ir))
@ -668,6 +708,7 @@ def convert_to_solidity_func(ir):
call = SolidityFunction('abi.{}()'.format(ir.function_name))
new_ir = SolidityCall(call, ir.nbr_arguments, ir.lvalue, ir.type_call)
new_ir.arguments = ir.arguments
new_ir.set_expression(ir.expression)
if isinstance(call.return_type, list) and len(call.return_type) == 1:
new_ir.lvalue.set_type(call.return_type[0])
else:
@ -693,9 +734,12 @@ def convert_to_push(ir, node):
val = TemporaryVariable(node)
operation = InitArray(ir.arguments[0], val)
operation.set_expression(ir.expression)
ret.append(operation)
prev_ir = ir
ir = Push(ir.destination, val)
ir.set_expression(prev_ir.expression)
length = Literal(len(operation.init_values), 'uint256')
t = operation.init_values[0].type
@ -705,18 +749,22 @@ def convert_to_push(ir, node):
if lvalue:
length = Length(ir.array, lvalue)
length.set_expression(ir.expression)
length.lvalue.points_to = ir.lvalue
ret.append(length)
return ret
prev_ir = ir
ir = Push(ir.destination, ir.arguments[0])
ir.set_expression(prev_ir.expression)
if lvalue:
ret = []
ret.append(ir)
length = Length(ir.array, lvalue)
length.set_expression(ir.expression)
length.lvalue.points_to = ir.lvalue
ret.append(length)
return ret
@ -732,6 +780,7 @@ def look_for_library(contract, ir, node, using_for, t):
ir.nbr_arguments,
ir.lvalue,
ir.type_call)
lib_call.set_expression(ir.expression)
lib_call.call_gas = ir.call_gas
lib_call.arguments = [ir.destination] + ir.arguments
new_ir = convert_type_library_call(lib_call, lib_contract)

@ -1,5 +1,6 @@
import abc
from slither.core.context.context import Context
from slither.core.children.child_expression import ChildExpression
from slither.core.children.child_node import ChildNode
from slither.utils.utils import unroll
@ -21,7 +22,7 @@ class AbstractOperation(abc.ABC):
"""
pass
class Operation(Context, ChildNode, AbstractOperation):
class Operation(Context, ChildExpression, ChildNode, AbstractOperation):
@property
def used(self):

@ -1,10 +1,6 @@
import logging
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.core.variables.variable import Variable
from slither.slithir.variables import TupleVariable
from slither.core.declarations.function import Function
from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue
from slither.slithir.utils.utils import is_valid_lvalue
class Phi(OperationWithLValue):

@ -181,6 +181,8 @@ def generate_ssa_irs(node, local_variables_instances, all_local_variables_instan
tuple_variables_instances,
all_local_variables_instances)
new_ir.set_expression(ir.expression)
update_lvalue(new_ir,
node,
local_variables_instances,

@ -275,7 +275,9 @@ def parse_call(expression, caller_context):
arguments = [parse_expression(a, caller_context) for a in children[1::]]
if isinstance(called, SuperCallExpression):
return SuperCallExpression(called, arguments, type_return)
sp = SuperCallExpression(called, arguments, type_return)
sp.set_offset(expression['src'], caller_context.slither)
return sp
call_expression = CallExpression(called, arguments, type_return)
call_expression.set_offset(src, caller_context.slither)
return call_expression
@ -386,7 +388,7 @@ def parse_expression(expression, caller_context):
return binary_op
elif name == 'FunctionCall':
return parse_call(expression, caller_context)
return parse_call(expression, caller_context)
elif name == 'TupleExpression':
"""
@ -401,7 +403,7 @@ def parse_expression(expression, caller_context):
Note: this is only possible with Solidity >= 0.4.12
"""
if is_compact_ast:
expressions = [parse_expression(e, caller_context) if e else None for e in expression['components']]
expressions = [parse_expression(e, caller_context) if e else None for e in expression['components']]
else:
if 'children' not in expression :
attributes = expression['attributes']

@ -0,0 +1,14 @@
# .format files are the output files produced by slither-format
# .patch files are the output files produced by slither-format
*.format
*.patch
# Temporary files (Emacs backup files ending in tilde and others)
*~
*.err
*.out

@ -0,0 +1,89 @@
import sys
import argparse
from slither import Slither
from slither.utils.command_line import read_config_file
import logging
from .slither_format import slither_format
from crytic_compile import cryticparser
logging.basicConfig()
logger = logging.getLogger("Slither").setLevel(logging.INFO)
# Slither detectors for which slither-format currently works
available_detectors = ["unused-state",
"solc-version",
"pragma",
"naming-convention",
"external-function",
"constable-states",
"constant-function"]
detectors_to_run = []
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='slither_format',
usage='slither_format filename')
parser.add_argument('filename', help='The filename of the contract or truffle directory to analyze.')
parser.add_argument('--verbose-test', '-v', help='verbose mode output for testing',action='store_true',default=False)
parser.add_argument('--verbose-json', '-j', help='verbose json output',action='store_true',default=False)
parser.add_argument('--version',
help='displays the current version',
version='0.1.0',
action='version')
parser.add_argument('--config-file',
help='Provide a config file (default: slither.config.json)',
action='store',
dest='config_file',
default='slither.config.json')
group_detector = parser.add_argument_group('Detectors')
group_detector.add_argument('--detect',
help='Comma-separated list of detectors, defaults to all, '
'available detectors: {}'.format(
', '.join(d for d in available_detectors)),
action='store',
dest='detectors_to_run',
default='all')
group_detector.add_argument('--exclude',
help='Comma-separated list of detectors to exclude,'
'available detectors: {}'.format(
', '.join(d for d in available_detectors)),
action='store',
dest='detectors_to_exclude',
default='all')
cryticparser.init(parser)
if len(sys.argv) == 1:
parser.print_help(sys.stderr)
sys.exit(1)
return parser.parse_args()
def main():
# ------------------------------
# Usage: python3 -m slither_format filename
# Example: python3 -m slither_format contract.sol
# ------------------------------
# Parse all arguments
args = parse_args()
read_config_file(args)
# Perform slither analysis on the given filename
slither = Slither(args.filename, **vars(args))
# Format the input files based on slither analysis
slither_format(slither, **vars(args))
if __name__ == '__main__':
main()

@ -0,0 +1,151 @@
import logging
from pathlib import Path
from slither.detectors.variables.unused_state_variables import UnusedStateVars
from slither.detectors.attributes.incorrect_solc import IncorrectSolc
from slither.detectors.attributes.constant_pragma import ConstantPragma
from slither.detectors.naming_convention.naming_convention import NamingConvention
from slither.detectors.functions.external_function import ExternalFunction
from slither.detectors.variables.possible_const_state_variables import ConstCandidateStateVars
from slither.detectors.attributes.const_functions import ConstantFunctions
from slither.utils.colors import yellow
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('Slither.Format')
all_detectors = {
'unused-state': UnusedStateVars,
'solc-version': IncorrectSolc,
'pragma': ConstantPragma,
'naming-convention': NamingConvention,
'external-function': ExternalFunction,
'constable-states' : ConstCandidateStateVars,
'constant-function': ConstantFunctions
}
def slither_format(slither, **kwargs):
''''
Keyword Args:
detectors_to_run (str): Comma-separated list of detectors, defaults to all
'''
detectors_to_run = choose_detectors(kwargs.get('detectors_to_run', 'all'),
kwargs.get('detectors_to_exclude', ''))
for detector in detectors_to_run:
slither.register_detector(detector)
slither.generate_patches = True
detector_results = slither.run_detectors()
detector_results = [x for x in detector_results if x] # remove empty results
detector_results = [item for sublist in detector_results for item in sublist] # flatten
export = Path('crytic-export', 'patches')
export.mkdir(parents=True, exist_ok=True)
counter_result = 0
logger.info(yellow('slither-format is in beta, carefully review each patch before merging it.'))
for result in detector_results:
if not 'patches' in result:
continue
one_line_description = result["description"].split("\n")[0]
export_result = Path(export, f'{counter_result}')
export_result.mkdir(parents=True, exist_ok=True)
counter_result += 1
counter = 0
logger.info(f'Issue: {one_line_description}')
logger.info(f'Generated: ({export_result})')
for file, diff, in result['patches_diff'].items():
filename = f'fix_{counter}.patch'
path = Path(export_result, filename)
logger.info(f'\t- {filename}')
with open(path, 'w') as f:
f.write(diff)
counter += 1
# endregion
###################################################################################
###################################################################################
# region Detectors
###################################################################################
###################################################################################
def choose_detectors(detectors_to_run, detectors_to_exclude):
# If detectors are specified, run only these ones
cls_detectors_to_run = []
exclude = detectors_to_exclude.split(',')
if detectors_to_run == 'all':
for d in all_detectors:
if d in exclude:
continue
cls_detectors_to_run.append(all_detectors[d])
else:
exclude = detectors_to_exclude.split(',')
for d in detectors_to_run.split(','):
if d in all_detectors:
if d in exclude:
continue
cls_detectors_to_run.append(all_detectors[d])
else:
raise Exception('Error: {} is not a detector'.format(d))
return cls_detectors_to_run
# endregion
###################################################################################
###################################################################################
# region Debug functions
###################################################################################
###################################################################################
def print_patches(number_of_slither_results, patches):
logger.info("Number of Slither results: " + str(number_of_slither_results))
number_of_patches = 0
for file in patches:
number_of_patches += len(patches[file])
logger.info("Number of patches: " + str(number_of_patches))
for file in patches:
logger.info("Patch file: " + file)
for patch in patches[file]:
logger.info("Detector: " + patch['detector'])
logger.info("Old string: " + patch['old_string'].replace("\n",""))
logger.info("New string: " + patch['new_string'].replace("\n",""))
logger.info("Location start: " + str(patch['start']))
logger.info("Location end: " + str(patch['end']))
def print_patches_json(number_of_slither_results, patches):
print('{',end='')
print("\"Number of Slither results\":" + '"' + str(number_of_slither_results) + '",')
print("\"Number of patchlets\":" + "\"" + str(len(patches)) + "\"", ',')
print("\"Patchlets\":" + '[')
for index, file in enumerate(patches):
if index > 0:
print(',')
print('{',end='')
print("\"Patch file\":" + '"' + file + '",')
print("\"Number of patches\":" + "\"" + str(len(patches[file])) + "\"", ',')
print("\"Patches\":" + '[')
for index, patch in enumerate(patches[file]):
if index > 0:
print(',')
print('{',end='')
print("\"Detector\":" + '"' + patch['detector'] + '",')
print("\"Old string\":" + '"' + patch['old_string'].replace("\n","") + '",')
print("\"New string\":" + '"' + patch['new_string'].replace("\n","") + '",')
print("\"Location start\":" + '"' + str(patch['start']) + '",')
print("\"Location end\":" + '"' + str(patch['end']) + '"')
if 'overlaps' in patch:
print("\"Overlaps\":" + "Yes")
print('}',end='')
print(']',end='')
print('}',end='')
print(']',end='')
print('}')

@ -1,6 +1,8 @@
import os
import logging
import json
import os
import logging
from collections import defaultdict
from prettytable import PrettyTable
from crytic_compile.cryticparser.defaults import defaults_flag_in_config as defaults_flag_in_config_crytic_compile
@ -29,6 +31,7 @@ defaults_flag_in_config = {
'json-types': ','.join(DEFAULT_JSON_OUTPUT_TYPES),
'disable_color': False,
'filter_paths': None,
'generate_patches': False,
# debug command
'legacy_ast': False,
'ignore_return_value': False,

@ -67,7 +67,9 @@ class ExpressionToSlithIR(ExpressionVisitor):
self._result = []
self._visit_expression(self.expression)
if node.type == NodeType.RETURN:
self._result.append(Return(get(self.expression)))
r = Return(get(self.expression))
r.set_expression(expression)
self._result.append(r)
for ir in self._result:
ir.set_node(node)
@ -83,6 +85,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
for idx in range(len(left)):
if not left[idx] is None:
operation = convert_assignment(left[idx], right[idx], expression.type, expression.expression_return_type)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, None)
else:
@ -90,6 +93,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
for idx in range(len(left)):
if not left[idx] is None:
operation = Unpack(left[idx], right, idx)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, None)
else:
@ -97,10 +101,12 @@ class ExpressionToSlithIR(ExpressionVisitor):
# uint8[2] var = [1,2];
if isinstance(right, list):
operation = InitArray(right, left)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, left)
else:
operation = convert_assignment(left, right, expression.type, expression.expression_return_type)
operation.set_expression(expression)
self._result.append(operation)
# Return left to handle
# a = b = 1;
@ -112,6 +118,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
val = TemporaryVariable(self._node)
operation = Binary(val, left, right, expression.type)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, val)
@ -120,6 +127,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
args = [get(a) for a in expression.arguments if a]
for arg in args:
arg_ = Argument(arg)
arg_.set_expression(expression)
self._result.append(arg_)
if isinstance(called, Function):
# internal call
@ -130,11 +138,10 @@ class ExpressionToSlithIR(ExpressionVisitor):
else:
val = TemporaryVariable(self._node)
internal_call = InternalCall(called, len(args), val, expression.type_call)
internal_call.set_expression(expression)
self._result.append(internal_call)
set_val(expression, val)
else:
val = TemporaryVariable(self._node)
# If tuple
if expression.type_call.startswith('tuple(') and expression.type_call != 'tuple()':
val = TupleVariable(self._node)
@ -142,6 +149,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
val = TemporaryVariable(self._node)
message_call = TmpCall(called, len(args), val, expression.type_call)
message_call.set_expression(expression)
self._result.append(message_call)
set_val(expression, val)
@ -165,8 +173,10 @@ class ExpressionToSlithIR(ExpressionVisitor):
init_array_right = left
left = init_array_val
operation = InitArray(init_array_right, init_array_val)
operation.set_expression(expression)
self._result.append(operation)
operation = Index(val, left, right, expression.type)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, val)
@ -178,18 +188,21 @@ class ExpressionToSlithIR(ExpressionVisitor):
expr = get(expression.expression)
val = ReferenceVariable(self._node)
member = Member(expr, Constant(expression.member_name), val)
member.set_expression(expression)
self._result.append(member)
set_val(expression, val)
def _post_new_array(self, expression):
val = TemporaryVariable(self._node)
operation = TmpNewArray(expression.depth, expression.array_type, val)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, val)
def _post_new_contract(self, expression):
val = TemporaryVariable(self._node)
operation = TmpNewContract(expression.contract_name, val)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, val)
@ -197,6 +210,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
# TODO unclear if this is ever used?
val = TemporaryVariable(self._node)
operation = TmpNewElementaryType(expression.type, val)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, val)
@ -212,6 +226,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
expr = get(expression.expression)
val = TemporaryVariable(self._node)
operation = TypeConversion(val, expr, expression.type)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, val)
@ -220,32 +235,40 @@ class ExpressionToSlithIR(ExpressionVisitor):
if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]:
lvalue = TemporaryVariable(self._node)
operation = Unary(lvalue, value, expression.type)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, lvalue)
elif expression.type in [UnaryOperationType.DELETE]:
operation = Delete(value, value)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, value)
elif expression.type in [UnaryOperationType.PLUSPLUS_PRE]:
operation = Binary(value, value, Constant("1", value.type), BinaryType.ADDITION)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, value)
elif expression.type in [UnaryOperationType.MINUSMINUS_PRE]:
operation = Binary(value, value, Constant("1", value.type), BinaryType.SUBTRACTION)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, value)
elif expression.type in [UnaryOperationType.PLUSPLUS_POST]:
lvalue = TemporaryVariable(self._node)
operation = Assignment(lvalue, value, value.type)
operation.set_expression(expression)
self._result.append(operation)
operation = Binary(value, value, Constant("1", value.type), BinaryType.ADDITION)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, lvalue)
elif expression.type in [UnaryOperationType.MINUSMINUS_POST]:
lvalue = TemporaryVariable(self._node)
operation = Assignment(lvalue, value, value.type)
operation.set_expression(expression)
self._result.append(operation)
operation = Binary(value, value, Constant("1", value.type), BinaryType.SUBTRACTION)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, lvalue)
elif expression.type in [UnaryOperationType.PLUS_PRE]:
@ -253,6 +276,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
elif expression.type in [UnaryOperationType.MINUS_PRE]:
lvalue = TemporaryVariable(self._node)
operation = Binary(lvalue, Constant("0", value.type), value, BinaryType.SUBTRACTION)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, lvalue)
else:

Loading…
Cancel
Save