Merge branch 'dev' into dev-flattening

pull/328/head
Josselin 5 years ago
commit 2658a6b58d
  1. 3
      setup.py
  2. 6
      slither/__main__.py
  3. 12
      slither/core/children/child_expression.py
  4. 18
      slither/core/slither_core.py
  5. 41
      slither/detectors/abstract_detector.py
  6. 6
      slither/detectors/attributes/const_functions.py
  7. 5
      slither/detectors/attributes/constant_pragma.py
  8. 7
      slither/detectors/attributes/incorrect_solc.py
  9. 8
      slither/detectors/functions/external_function.py
  10. 8
      slither/detectors/naming_convention/naming_convention.py
  11. 6
      slither/detectors/variables/possible_const_state_variables.py
  12. 5
      slither/detectors/variables/unused_state_variables.py
  13. 0
      slither/formatters/__init__.py
  14. 0
      slither/formatters/attributes/__init__.py
  15. 36
      slither/formatters/attributes/const_functions.py
  16. 69
      slither/formatters/attributes/constant_pragma.py
  17. 59
      slither/formatters/attributes/incorrect_solc.py
  18. 5
      slither/formatters/exceptions.py
  19. 0
      slither/formatters/functions/__init__.py
  20. 42
      slither/formatters/functions/external_function.py
  21. 0
      slither/formatters/naming_convention/__init__.py
  22. 609
      slither/formatters/naming_convention/naming_convention.py
  23. 0
      slither/formatters/utils/__init__.py
  24. 43
      slither/formatters/utils/patches.py
  25. 0
      slither/formatters/variables/__init__.py
  26. 38
      slither/formatters/variables/possible_const_state_variables.py
  27. 28
      slither/formatters/variables/unused_state_variables.py
  28. 4
      slither/slither.py
  29. 75
      slither/slithir/convert.py
  30. 3
      slither/slithir/operations/operation.py
  31. 6
      slither/slithir/operations/phi.py
  32. 2
      slither/slithir/utils/ssa.py
  33. 8
      slither/solc_parsing/expressions/expression_parsing.py
  34. 14
      slither/tools/slither_format/.gitignore
  35. 0
      slither/tools/slither_format/__init__.py
  36. 89
      slither/tools/slither_format/__main__.py
  37. 151
      slither/tools/slither_format/slither_format.py
  38. 3
      slither/utils/command_line.py
  39. 30
      slither/visitors/slithir/expression_to_slithir.py

@ -20,7 +20,8 @@ setup(
'slither-check-upgradeability = slither.tools.upgradeability.__main__:main', 'slither-check-upgradeability = slither.tools.upgradeability.__main__:main',
'slither-find-paths = slither.tools.possible_paths.__main__:main', 'slither-find-paths = slither.tools.possible_paths.__main__:main',
'slither-simil = slither.tools.similarity.__main__:main', 'slither-simil = slither.tools.similarity.__main__:main',
'slither-flat = slither.tools.flattening.__main__:main' 'slither-flat = slither.tools.flattening.__main__:main',
'slither-format = slither.tools.slither_format.__main__:main'
] ]
} }
) )

@ -257,6 +257,7 @@ def parse_filter_paths(args):
return args.filter_paths.split(',') return args.filter_paths.split(',')
return [] return []
def parse_args(detector_classes, printer_classes): def parse_args(detector_classes, printer_classes):
parser = argparse.ArgumentParser(description='Slither. For usage information, see https://github.com/crytic/slither/wiki/Usage', parser = argparse.ArgumentParser(description='Slither. For usage information, see https://github.com/crytic/slither/wiki/Usage',
usage="slither.py contract.sol [flag]") usage="slither.py contract.sol [flag]")
@ -379,6 +380,11 @@ def parse_args(detector_classes, printer_classes):
action='store_true', action='store_true',
default=False) default=False)
group_misc.add_argument('--generate-patches',
help='Generate patches (json output only)',
action='store_true',
default=False)
# debugger command # debugger command
parser.add_argument('--debug', parser.add_argument('--debug',
help=argparse.SUPPRESS, help=argparse.SUPPRESS,

@ -0,0 +1,12 @@
class ChildExpression:
def __init__(self):
super(ChildExpression, self).__init__()
self._expression = None
def set_expression(self, expression):
self._expression = expression
@property
def expression(self):
return self._expression

@ -35,6 +35,7 @@ class Slither(Context):
self._crytic_compile = None self._crytic_compile = None
self._generate_patches = False
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -44,7 +45,7 @@ class Slither(Context):
@property @property
def source_code(self): def source_code(self):
""" {filename: source_code}: source code """ """ {filename: source_code (str)}: source code """
return self._raw_source_code return self._raw_source_code
@property @property
@ -234,3 +235,18 @@ class Slither(Context):
def crytic_compile(self): def crytic_compile(self):
return self._crytic_compile return self._crytic_compile
# endregion # endregion
###################################################################################
###################################################################################
# region Format
###################################################################################
###################################################################################
@property
def generate_patches(self):
return self._generate_patches
@generate_patches.setter
def generate_patches(self, p):
self._generate_patches = p
# endregion

@ -1,10 +1,10 @@
import abc import abc
import re import re
from collections import OrderedDict, defaultdict
from slither.utils.colors import green, yellow, red from slither.utils.colors import green, yellow, red
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
from collections import OrderedDict from slither.formatters.exceptions import FormatImpossible
from slither.formatters.utils.patches import apply_patch, create_diff
class IncorrectDetectorInitialization(Exception): class IncorrectDetectorInitialization(Exception):
pass pass
@ -94,7 +94,8 @@ class AbstractDetector(metaclass=abc.ABCMeta):
raise IncorrectDetectorInitialization('CONFIDENCE is not initialized {}'.format(self.__class__.__name__)) raise IncorrectDetectorInitialization('CONFIDENCE is not initialized {}'.format(self.__class__.__name__))
def _log(self, info): def _log(self, info):
self.logger.info(self.color(info)) if self.logger:
self.logger.info(self.color(info))
@abc.abstractmethod @abc.abstractmethod
def _detect(self): def _detect(self):
@ -115,6 +116,33 @@ class AbstractDetector(metaclass=abc.ABCMeta):
info += result['description'] info += result['description']
info += 'Reference: {}'.format(self.WIKI) info += 'Reference: {}'.format(self.WIKI)
self._log(info) self._log(info)
if self.slither.generate_patches:
for result in results:
try:
self._format(self.slither, result)
if not 'patches' in result:
continue
result['patches_diff'] = dict()
for file in result['patches']:
original_txt = self.slither.source_code[file].encode('utf8')
patched_txt = original_txt
offset = 0
patches = result['patches'][file]
patches.sort(key=lambda x: x['start'])
if not all(patches[i]['end'] <= patches[i + 1]['end'] for i in range(len(patches) - 1)):
self._log(f'Impossible to generate patch; patches collisions: {patches}')
continue
for patch in patches:
patched_txt, offset = apply_patch(patched_txt, patch, offset)
diff = create_diff(self.slither, original_txt, patched_txt, file)
if not diff:
self._log(f'Impossible to generate patch; empty {result}')
else:
result['patches_diff'][file] = diff
except FormatImpossible as e:
self._log(f'\nImpossible to patch:\n\t{result["description"]}\t{e}')
if results and self.slither.triage_mode: if results and self.slither.triage_mode:
while True: while True:
indexes = input('Results to hide during next runs: "0,1,..." or "All" (enter to not hide results): '.format(len(results))) indexes = input('Results to hide during next runs: "0,1,..." or "All" (enter to not hide results): '.format(len(results)))
@ -310,3 +338,8 @@ class AbstractDetector(metaclass=abc.ABCMeta):
{}, {},
additional_fields) additional_fields)
d['elements'].append(element) d['elements'].append(element)
@staticmethod
def _format(slither, result):
"""Implement format"""
return

@ -3,7 +3,7 @@ Module detecting constant functions
Recursively check the called functions Recursively check the called functions
""" """
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.attributes.const_functions import format
class ConstantFunctions(AbstractDetector): class ConstantFunctions(AbstractDetector):
""" """
@ -76,3 +76,7 @@ All the calls to `get` revert, breaking Bob's smart contract execution.'''
results.append(json) results.append(json)
return results return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -3,6 +3,7 @@
""" """
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.attributes.constant_pragma import format
class ConstantPragma(AbstractDetector): class ConstantPragma(AbstractDetector):
@ -42,3 +43,7 @@ class ConstantPragma(AbstractDetector):
results.append(json) results.append(json)
return results return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -2,8 +2,9 @@
Check if an incorrect version of solc is used Check if an incorrect version of solc is used
""" """
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
import re import re
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.attributes.incorrect_solc import format
# group: # group:
# 0: ^ > >= < <= (optional) # 0: ^ > >= < <= (optional)
@ -105,3 +106,7 @@ Use Solidity 0.4.25 or 0.5.3. Consider using the latest version of Solidity for
results.append(json) results.append(json)
return results return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -1,7 +1,9 @@
from slither.detectors.abstract_detector import (AbstractDetector, from slither.detectors.abstract_detector import (AbstractDetector,
DetectorClassification) DetectorClassification)
from slither.slithir.operations import (HighLevelCall, SolidityCall ) from slither.slithir.operations import SolidityCall
from slither.slithir.operations import (InternalCall, InternalDynamicCall) from slither.slithir.operations import (InternalCall, InternalDynamicCall)
from slither.formatters.functions.external_function import format
class ExternalFunction(AbstractDetector): class ExternalFunction(AbstractDetector):
""" """
@ -193,3 +195,7 @@ class ExternalFunction(AbstractDetector):
results.append(json) results.append(json)
return results return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -1,5 +1,7 @@
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
import re import re
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.formatters.naming_convention.naming_convention import format
class NamingConvention(AbstractDetector): class NamingConvention(AbstractDetector):
@ -202,3 +204,7 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2
results.append(json) results.append(json)
return results return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -7,6 +7,8 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi
from slither.visitors.expression.export_values import ExportValues from slither.visitors.expression.export_values import ExportValues
from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.formatters.variables.possible_const_state_variables import format
class ConstCandidateStateVars(AbstractDetector): class ConstCandidateStateVars(AbstractDetector):
""" """
@ -94,3 +96,7 @@ class ConstCandidateStateVars(AbstractDetector):
results.append(json) results.append(json)
return results return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -6,6 +6,7 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi
from slither.core.solidity_types import ArrayType from slither.core.solidity_types import ArrayType
from slither.visitors.expression.export_values import ExportValues from slither.visitors.expression.export_values import ExportValues
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.formatters.variables.unused_state_variables import format
class UnusedStateVars(AbstractDetector): class UnusedStateVars(AbstractDetector):
""" """
@ -70,3 +71,7 @@ class UnusedStateVars(AbstractDetector):
results.append(json) results.append(json)
return results return results
@staticmethod
def _format(slither, result):
format(slither, result)

@ -0,0 +1,36 @@
import re
from slither.formatters.exceptions import FormatError
from slither.formatters.utils.patches import create_patch
def format(slither, result):
elements = result['elements']
for element in elements:
if element['type'] != "function":
# Skip variable elements
continue
target_contract = slither.get_contract_from_name(element['type_specific_fields']['parent']['name'])
if target_contract:
function = target_contract.get_function_from_signature(element['type_specific_fields']['signature'])
if function:
_patch(slither,
result,
element['source_mapping']['filename_absolute'],
int(function.parameters_src.source_mapping['start'] +
function.parameters_src.source_mapping['length']),
int(function.returns_src.source_mapping['start']))
def _patch(slither, result, in_file, modify_loc_start, modify_loc_end):
in_file_str = slither.source_code[in_file].encode('utf8')
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
# Find the keywords view|pure|constant and remove them
m = re.search("(view|pure|constant)", old_str_of_interest.decode('utf-8'))
if m:
create_patch(result,
in_file,
modify_loc_start + m.span()[0],
modify_loc_start + m.span()[1],
m.groups(0)[0], # this is view|pure|constant
"")
else:
raise FormatError("No view/pure/constant specifier exists. Regex failed to remove specifier!")

@ -0,0 +1,69 @@
import re
from slither.formatters.exceptions import FormatImpossible
from slither.formatters.utils.patches import create_patch
# Indicates the recommended versions for replacement
REPLACEMENT_VERSIONS = ["^0.4.25", "^0.5.3"]
# group:
# 0: ^ > >= < <= (optional)
# 1: ' ' (optional)
# 2: version number
# 3: version number
# 4: version number
PATTERN = re.compile('(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)')
def format(slither, result):
elements = result['elements']
versions_used = []
for element in elements:
versions_used.append(''.join(element['type_specific_fields']['directive'][1:]))
solc_version_replace = _analyse_versions(versions_used)
for element in elements:
_patch(slither, result, element['source_mapping']['filename_absolute'], solc_version_replace,
element['source_mapping']['start'],
element['source_mapping']['start'] + element['source_mapping']['length'])
def _analyse_versions(used_solc_versions):
replace_solc_versions = list()
for version in used_solc_versions:
replace_solc_versions.append(_determine_solc_version_replacement(version))
if not all(version == replace_solc_versions[0] for version in replace_solc_versions):
raise FormatImpossible("Multiple incompatible versions!")
else:
return replace_solc_versions[0]
def _determine_solc_version_replacement(used_solc_version):
versions = PATTERN.findall(used_solc_version)
if len(versions) == 1:
version = versions[0]
minor_version = '.'.join(version[2:])[2]
if minor_version == '4':
return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ';'
elif minor_version == '5':
return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ';'
else:
raise FormatImpossible("Unknown version!")
elif len(versions) == 2:
version_right = versions[1]
minor_version_right = '.'.join(version_right[2:])[2]
if minor_version_right == '4':
# Replace with 0.4.25
return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ';'
elif minor_version_right in ['5', '6']:
# Replace with 0.5.3
return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ';'
def _patch(slither, result, in_file, pragma, modify_loc_start, modify_loc_end):
in_file_str = slither.source_code[in_file].encode('utf8')
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
create_patch(result,
in_file,
int(modify_loc_start),
int(modify_loc_end),
old_str_of_interest,
pragma)

@ -0,0 +1,59 @@
import re
from slither.formatters.exceptions import FormatImpossible
from slither.formatters.utils.patches import create_patch
# Indicates the recommended versions for replacement
REPLACEMENT_VERSIONS = ["^0.4.25", "^0.5.3"]
# group:
# 0: ^ > >= < <= (optional)
# 1: ' ' (optional)
# 2: version number
# 3: version number
# 4: version number
PATTERN = re.compile('(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)')
def format(slither, result):
elements = result['elements']
for element in elements:
solc_version_replace = _determine_solc_version_replacement(
''.join(element['type_specific_fields']['directive'][1:]))
_patch(slither, result, element['source_mapping']['filename_absolute'], solc_version_replace,
element['source_mapping']['start'], element['source_mapping']['start'] +
element['source_mapping']['length'])
def _determine_solc_version_replacement(used_solc_version):
versions = PATTERN.findall(used_solc_version)
if len(versions) == 1:
version = versions[0]
minor_version = '.'.join(version[2:])[2]
if minor_version == '4':
# Replace with 0.4.25
return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ';'
elif minor_version == '5':
# Replace with 0.5.3
return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ';'
else:
raise FormatImpossible(f"Unknown version {versions}")
elif len(versions) == 2:
version_right = versions[1]
minor_version_right = '.'.join(version_right[2:])[2]
if minor_version_right == '4':
# Replace with 0.4.25
return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ';'
elif minor_version_right in ['5','6']:
# Replace with 0.5.3
return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ';'
def _patch(slither, result, in_file, solc_version, modify_loc_start, modify_loc_end):
in_file_str = slither.source_code[in_file].encode('utf8')
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
create_patch(result,
in_file,
int(modify_loc_start),
int(modify_loc_end),
old_str_of_interest,
solc_version)

@ -0,0 +1,5 @@
from slither.exceptions import SlitherException
class FormatImpossible(SlitherException): pass
class FormatError(SlitherException): pass

@ -0,0 +1,42 @@
import re
from slither.formatters.utils.patches import create_patch
def format(slither, result):
elements = result['elements']
for element in elements:
target_contract = slither.get_contract_from_name(element['type_specific_fields']['parent']['name'])
if target_contract:
function = target_contract.get_function_from_signature(element['type_specific_fields']['signature'])
if function:
_patch(slither,
result,
element['source_mapping']['filename_absolute'],
int(function.parameters_src.source_mapping['start']),
int(function.returns_src.source_mapping['start']))
def _patch(slither, result, in_file, modify_loc_start, modify_loc_end):
in_file_str = slither.source_code[in_file].encode('utf8')
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
# Search for 'public' keyword which is in-between the function name and modifier name (if present)
# regex: 'public' could have spaces around or be at the end of the line
m = re.search(r'((\spublic)\s+)|(\spublic)$|(\)public)$', old_str_of_interest.decode('utf-8'))
if m is None:
# No visibility specifier exists; public by default.
create_patch(result,
in_file,
# start after the function definition's closing paranthesis
modify_loc_start + len(old_str_of_interest.decode('utf-8').split(')')[0]) + 1,
# end is same as start because we insert the keyword `external` at that location
modify_loc_start + len(old_str_of_interest.decode('utf-8').split(')')[0]) + 1,
"",
" external") # replace_text is `external`
else:
create_patch(result,
in_file,
# start at the keyword `public`
modify_loc_start + m.span()[0] + 1,
# end after the keyword `public` = start + len('public'')
modify_loc_start + m.span()[0] + 1 + len('public'),
"public",
"external")

@ -0,0 +1,609 @@
import re
import logging
from slither.slithir.operations import Send, Transfer, OperationWithLValue, HighLevelCall, LowLevelCall, \
InternalCall, InternalDynamicCall
from slither.core.declarations import Modifier
from slither.core.solidity_types import UserDefinedType, MappingType
from slither.core.declarations import Enum, Contract, Structure, Function
from slither.core.solidity_types.elementary_type import ElementaryTypeName
from slither.core.variables.local_variable import LocalVariable
from slither.formatters.exceptions import FormatError, FormatImpossible
from slither.formatters.utils.patches import create_patch
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('Slither.Format')
def format(slither, result):
elements = result['elements']
for element in elements:
target = element['additional_fields']['target']
convention = element['additional_fields']['convention']
if convention == "l_O_I_should_not_be_used":
# l_O_I_should_not_be_used cannot be automatically patched
logger.info(f'The following naming convention cannot be patched: \n{result["description"]}')
continue
_patch(slither, result, element, target)
# endregion
###################################################################################
###################################################################################
# region Conventions
###################################################################################
###################################################################################
KEY = 'ALL_NAMES_USED'
# https://solidity.readthedocs.io/en/v0.5.11/miscellaneous.html#reserved-keywords
SOLIDITY_KEYWORDS = ['abstract', 'after', 'alias', 'apply', 'auto', 'case', 'catch', 'copyof', 'default', 'define',
'final', 'immutable', 'implements', 'in', 'inline', 'let', 'macro', 'match', 'mutable', 'null',
'of', 'override', 'partial', 'promise', 'reference', 'relocatable', 'sealed', 'sizeof', 'static',
'supports', 'switch', 'try', 'typedef', 'typeof', 'unchecked']
# https://solidity.readthedocs.io/en/v0.5.11/miscellaneous.html#language-grammar
SOLIDITY_KEYWORDS += ['pragma', 'import', 'contract', 'library', 'contract', 'function', 'using', 'struct', 'enum',
'public', 'private', 'internal', 'external', 'calldata', 'memory', 'modifier', 'view', 'pure',
'constant', 'storage', 'for', 'if', 'while', 'break', 'return', 'throw', 'else', 'type']
SOLIDITY_KEYWORDS += ElementaryTypeName
def _name_already_use(slither, name):
# Do not convert to a name used somewhere else
if not KEY in slither.context:
all_names = set()
for contract in slither.contracts_derived:
all_names = all_names.union(set([st.name for st in contract.structures]))
all_names = all_names.union(set([f.name for f in contract.functions_and_modifiers]))
all_names = all_names.union(set([e.name for e in contract.enums]))
all_names = all_names.union(set([s.name for s in contract.state_variables]))
for function in contract.functions:
all_names = all_names.union(set([v.name for v in function.variables]))
slither.context[KEY] = all_names
return name in slither.context[KEY]
def _convert_CapWords(original_name, slither):
name = original_name.capitalize()
while '_' in name:
offset = name.find('_')
if len(name) > offset:
name = name[0:offset] + name[offset+1].upper() + name[offset+1:]
if _name_already_use(slither, name):
raise FormatImpossible(f'{original_name} cannot be converted to {name} (already used)')
if name in SOLIDITY_KEYWORDS:
raise FormatImpossible(f'{original_name} cannot be converted to {name} (Solidity keyword)')
return name
def _convert_mixedCase(original_name, slither):
name = original_name
if isinstance(name, bytes):
name = name.decode('utf8')
while '_' in name:
offset = name.find('_')
if len(name) > offset:
name = name[0:offset] + name[offset + 1].upper() + name[offset + 2:]
name = name[0].lower() + name[1:]
if _name_already_use(slither, name):
raise FormatImpossible(f'{original_name} cannot be converted to {name} (already used)')
if name in SOLIDITY_KEYWORDS:
raise FormatImpossible(f'{original_name} cannot be converted to {name} (Solidity keyword)')
return name
def _convert_UPPER_CASE_WITH_UNDERSCORES(name, slither):
if _name_already_use(slither, name.upper()):
raise FormatImpossible(f'{name} cannot be converted to {name.upper()} (already used)')
if name.upper() in SOLIDITY_KEYWORDS:
raise FormatImpossible(f'{name} cannot be converted to {name.upper()} (Solidity keyword)')
return name.upper()
conventions ={
"CapWords":_convert_CapWords,
"mixedCase":_convert_mixedCase,
"UPPER_CASE_WITH_UNDERSCORES":_convert_UPPER_CASE_WITH_UNDERSCORES
}
# endregion
###################################################################################
###################################################################################
# region Helpers
###################################################################################
###################################################################################
def _get_from_contract(slither, element, name, getter):
contract_name = element['type_specific_fields']['parent']['name']
contract = slither.get_contract_from_name(contract_name)
return getattr(contract, getter)(name)
# endregion
###################################################################################
###################################################################################
# region Patch dispatcher
###################################################################################
###################################################################################
def _patch(slither, result, element, _target):
if _target == "contract":
target = slither.get_contract_from_name(element['name'])
elif _target == "structure":
target = _get_from_contract(slither, element, element['name'], 'get_structure_from_name')
elif _target == "event":
target = _get_from_contract(slither, element, element['name'], 'get_event_from_name')
elif _target == "function":
# Avoid constructor (FP?)
if element['name'] != element['type_specific_fields']['parent']['name']:
function_sig = element['type_specific_fields']['signature']
target = _get_from_contract(slither, element, function_sig, 'get_function_from_signature')
elif _target == "modifier":
modifier_sig = element['type_specific_fields']['signature']
target = _get_from_contract(slither, element, modifier_sig, 'get_modifier_from_signature')
elif _target == "parameter":
contract_name = element['type_specific_fields']['parent']['type_specific_fields']['parent']['name']
function_sig = element['type_specific_fields']['parent']['type_specific_fields']['signature']
param_name = element['name']
contract = slither.get_contract_from_name(contract_name)
function = contract.get_function_from_signature(function_sig)
target = function.get_local_variable_from_name(param_name)
elif _target in ["variable", "variable_constant"]:
# Local variable
if element['type_specific_fields']['parent'] == 'function':
contract_name = element['type_specific_fields']['parent']['type_specific_fields']['parent']['name']
function_sig = element['type_specific_fields']['parent']['type_specific_fields']['signature']
var_name = element['name']
contract = slither.get_contract_from_name(contract_name)
function = contract.get_function_from_signature(function_sig)
target = function.get_local_variable_from_name(var_name)
# State variable
else:
target = _get_from_contract(slither, element, element['name'], 'get_state_variable_from_name')
elif _target == "enum":
target = _get_from_contract(slither, element, element['name'], 'get_enum_from_canonical_name')
else:
raise FormatError("Unknown naming convention! " + _target)
_explore(slither,
result,
target,
conventions[element['additional_fields']['convention']])
# endregion
###################################################################################
###################################################################################
# region Explore functions
###################################################################################
###################################################################################
# group 1: beginning of the from type
# group 2: beginning of the to type
# nested mapping are within the group 1
#RE_MAPPING = '[ ]*mapping[ ]*\([ ]*([\=\>\(\) a-zA-Z0-9\._\[\]]*)[ ]*=>[ ]*([a-zA-Z0-9\._\[\]]*)\)'
RE_MAPPING_FROM = b'([a-zA-Z0-9\._\[\]]*)'
RE_MAPPING_TO = b'([\=\>\(\) a-zA-Z0-9\._\[\]\ ]*)'
RE_MAPPING = b'[ ]*mapping[ ]*\([ ]*' + RE_MAPPING_FROM + b'[ ]*' + b'=>' + b'[ ]*'+ RE_MAPPING_TO + b'\)'
def _is_var_declaration(slither, filename, start):
'''
Detect usage of 'var ' for Solidity < 0.5
:param slither:
:param filename:
:param start:
:return:
'''
v = 'var '
return slither.source_code[filename][start:start + len(v)] == v
def _explore_type(slither, result, target, convert, type, filename_source_code, start, end):
if isinstance(type, UserDefinedType):
# Patch type based on contract/enum
if isinstance(type.type, (Enum, Contract)):
if type.type == target:
old_str = type.type.name
new_str = convert(old_str, slither)
loc_start = start
if _is_var_declaration(slither, filename_source_code, start):
loc_end = loc_start + len('var')
else:
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
else:
# Patch type based on structure
assert isinstance(type.type, Structure)
if type.type == target:
old_str = type.type.name
new_str = convert(old_str, slither)
loc_start = start
if _is_var_declaration(slither, filename_source_code, start):
loc_end = loc_start + len('var')
else:
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
# Structure contain a list of elements, that might need patching
# .elems return a list of VariableStructure
_explore_variables_declaration(slither,
type.type.elems.values(),
result,
target,
convert)
if isinstance(type, MappingType):
# Mapping has three steps:
# Convert the "from" type
# Convert the "to" type
# Convert nested type in the "to"
# Ex: mapping (mapping (badName => uint) => uint)
# Do the comparison twice, so we can factor together the re matching
# mapping can only have elementary type in type_from
if isinstance(type.type_to, (UserDefinedType, MappingType)) or target in [type.type_from, type.type_to]:
full_txt_start = start
full_txt_end = end
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
re_match = re.match(RE_MAPPING, full_txt)
assert re_match
if type.type_from == target:
old_str = type.type_from.name
new_str = convert(old_str, slither)
loc_start = start + re_match.start(1)
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
if type.type_to == target:
old_str = type.type_to.name
new_str = convert(old_str, slither)
loc_start = start + re_match.start(2)
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
if isinstance(type.type_to, (UserDefinedType, MappingType)):
loc_start = start + re_match.start(2)
loc_end = start + re_match.end(2)
_explore_type(slither,
result,
target,
convert,
type.type_to,
filename_source_code,
loc_start,
loc_end)
def _explore_variables_declaration(slither, variables, result, target, convert, patch_comment=False):
for variable in variables:
# First explore the type of the variable
filename_source_code = variable.source_mapping['filename_absolute']
full_txt_start = variable.source_mapping['start']
full_txt_end = full_txt_start + variable.source_mapping['length']
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
_explore_type(slither,
result,
target,
convert,
variable.type,
filename_source_code,
full_txt_start,
variable.source_mapping['start'] + variable.source_mapping['length'])
# If the variable is the target
if variable == target:
old_str = variable.name
new_str = convert(old_str, slither)
loc_start = full_txt_start + full_txt.find(old_str.encode('utf8'))
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
# Patch comment only makes sense for local variable declaration in the parameter list
if patch_comment and isinstance(variable, LocalVariable):
if 'lines' in variable.source_mapping and variable.source_mapping['lines']:
func = variable.function
end_line = func.source_mapping['lines'][0]
if variable in func.parameters:
idx = len(func.parameters) - func.parameters.index(variable) + 1
first_line = end_line - idx - 2
potential_comments = slither.source_code[filename_source_code].encode('utf8')
potential_comments = potential_comments.splitlines(keepends=True)[first_line:end_line-1]
idx_beginning = func.source_mapping['start']
idx_beginning += - func.source_mapping['starting_column'] + 1
idx_beginning += - sum([len(c) for c in potential_comments])
old_comment = f'@param {old_str}'.encode('utf8')
for line in potential_comments:
idx = line.find(old_comment)
if idx >=0:
loc_start = idx + idx_beginning
loc_end = loc_start + len(old_comment)
new_comment = f'@param {new_str}'.encode('utf8')
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_comment,
new_comment)
break
idx_beginning += len(line)
def _explore_modifiers_calls(slither, function, result, target, convert):
for modifier in function.modifiers_statements:
for node in modifier.nodes:
if node.irs:
_explore_irs(slither, node.irs, result, target, convert)
for modifier in function.explicit_base_constructor_calls_statements:
for node in modifier.nodes:
if node.irs:
_explore_irs(slither, node.irs, result, target, convert)
def _explore_structures_declaration(slither, structures, result, target, convert):
for st in structures:
# Explore the variable declared within the structure (VariableStructure)
_explore_variables_declaration(slither, st.elems.values(), result, target, convert)
# If the structure is the target
if st == target:
old_str = st.name
new_str = convert(old_str, slither)
filename_source_code = st.source_mapping['filename_absolute']
full_txt_start = st.source_mapping['start']
full_txt_end = full_txt_start + st.source_mapping['length']
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
# The name is after the space
matches = re.finditer(b'struct[ ]*', full_txt)
# Look for the end offset of the largest list of ' '
loc_start = full_txt_start + max(matches, key=lambda x: len(x.group())).end()
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
def _explore_events_declaration(slither, events, result, target, convert):
for event in events:
# Explore the parameters
_explore_variables_declaration(slither, event.elems, result, target, convert)
# If the event is the target
if event == target:
filename_source_code = event.source_mapping['filename_absolute']
old_str = event.name
new_str = convert(old_str, slither)
loc_start = event.source_mapping['start']
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
def get_ir_variables(ir):
vars = ir.read
if isinstance(ir, (InternalCall, InternalDynamicCall, HighLevelCall)):
vars += [ir.function]
if isinstance(ir, (HighLevelCall, Send, LowLevelCall, Transfer)):
vars += [ir.call_value]
if isinstance(ir, (HighLevelCall, LowLevelCall)):
vars += [ir.call_gas]
if isinstance(ir, OperationWithLValue):
vars += [ir.lvalue]
return [v for v in vars if v]
def _explore_irs(slither, irs, result, target, convert):
if irs is None:
return
for ir in irs:
for v in get_ir_variables(ir):
if target == v or (
isinstance(target, Function) and isinstance(v, Function) and
v.canonical_name == target.canonical_name):
source_mapping = ir.expression.source_mapping
filename_source_code = source_mapping['filename_absolute']
full_txt_start = source_mapping['start']
full_txt_end = full_txt_start + source_mapping['length']
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
if not target.name.encode('utf8') in full_txt:
raise FormatError(f'{target} not found in {full_txt} ({source_mapping}')
old_str = target.name.encode('utf8')
new_str = convert(old_str, slither)
counter = 0
# Can be found multiple time on the same IR
# We patch one by one
while old_str in full_txt:
target_found_at = full_txt.find((old_str))
full_txt = full_txt[target_found_at+1:]
counter += target_found_at
loc_start = full_txt_start + counter
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
def _explore_functions(slither, functions, result, target, convert):
for function in functions:
_explore_variables_declaration(slither, function.variables, result, target, convert, True)
_explore_modifiers_calls(slither, function, result, target, convert)
_explore_irs(slither, function.all_slithir_operations(), result, target, convert)
if isinstance(target, Function) and function.canonical_name == target.canonical_name:
old_str = function.name
new_str = convert(old_str, slither)
filename_source_code = function.source_mapping['filename_absolute']
full_txt_start = function.source_mapping['start']
full_txt_end = full_txt_start + function.source_mapping['length']
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
# The name is after the space
if isinstance(target, Modifier):
matches = re.finditer(b'modifier([ ]*)', full_txt)
else:
matches = re.finditer(b'function([ ]*)', full_txt)
# Look for the end offset of the largest list of ' '
loc_start = full_txt_start + max(matches, key=lambda x: len(x.group())).end()
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
def _explore_enums(slither, enums, result, target, convert):
for enum in enums:
if enum == target:
old_str = enum.name
new_str = convert(old_str, slither)
filename_source_code = enum.source_mapping['filename_absolute']
full_txt_start = enum.source_mapping['start']
full_txt_end = full_txt_start + enum.source_mapping['length']
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
# The name is after the space
matches = re.finditer(b'enum([ ]*)', full_txt)
# Look for the end offset of the largest list of ' '
loc_start = full_txt_start + max(matches, key=lambda x: len(x.group())).end()
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
def _explore_contract(slither, contract, result, target, convert):
_explore_variables_declaration(slither, contract.state_variables, result, target, convert)
_explore_structures_declaration(slither, contract.structures, result, target, convert)
_explore_functions(slither, contract.functions_and_modifiers, result, target, convert)
_explore_enums(slither, contract.enums, result, target, convert)
if contract == target:
filename_source_code = contract.source_mapping['filename_absolute']
full_txt_start = contract.source_mapping['start']
full_txt_end = full_txt_start + contract.source_mapping['length']
full_txt = slither.source_code[filename_source_code].encode('utf8')[full_txt_start:full_txt_end]
old_str = contract.name
new_str = convert(old_str, slither)
# The name is after the space
matches = re.finditer(b'contract[ ]*', full_txt)
# Look for the end offset of the largest list of ' '
loc_start = full_txt_start + max(matches, key=lambda x: len(x.group())).end()
loc_end = loc_start + len(old_str)
create_patch(result,
filename_source_code,
loc_start,
loc_end,
old_str,
new_str)
def _explore(slither, result, target, convert):
for contract in slither.contracts_derived:
_explore_contract(slither, contract, result, target, convert)
# endregion

@ -0,0 +1,43 @@
import os
import difflib
from collections import defaultdict
def create_patch(result, file, start, end, old_str, new_str):
if isinstance(old_str, bytes):
old_str = old_str.decode('utf8')
if isinstance(new_str, bytes):
new_str = new_str.decode('utf8')
p = {"start": start,
"end": end,
"old_string": old_str,
"new_string": new_str
}
if 'patches' not in result:
result['patches'] = defaultdict(list)
if p not in result['patches'][file]:
result['patches'][file].append(p)
def apply_patch(original_txt, patch, offset):
patched_txt = original_txt[:int(patch['start'] + offset)]
patched_txt += patch['new_string'].encode('utf8')
patched_txt += original_txt[int(patch['end'] + offset):]
# Keep the diff of text added or sub, in case of multiple patches
patch_length_diff = len(patch['new_string']) - (patch['end'] - patch['start'])
return patched_txt, patch_length_diff + offset
def create_diff(slither, original_txt, patched_txt, filename):
if slither.crytic_compile:
relative_path = slither.crytic_compile.filename_lookup(filename).relative
relative_path = os.path.join('.', relative_path)
else:
relative_path = filename
diff = difflib.unified_diff(original_txt.decode('utf8').splitlines(False),
patched_txt.decode('utf8').splitlines(False),
fromfile=relative_path,
tofile=relative_path,
lineterm='')
return '\n'.join(list(diff)) + '\n'

@ -0,0 +1,38 @@
import re
from slither.formatters.exceptions import FormatError, FormatImpossible
from slither.formatters.utils.patches import create_patch
def format(slither, result):
elements = result['elements']
for element in elements:
# TODO: decide if this should be changed in the constant detector
contract_name = element['type_specific_fields']['parent']['name']
contract = slither.get_contract_from_name(contract_name)
var = contract.get_state_variable_from_name(element['name'])
if not var.expression:
raise FormatImpossible(f'{var.name} is uninitialized and cannot become constant.')
_patch(slither, result, element['source_mapping']['filename_absolute'],
element['name'],
"constant " + element['name'],
element['source_mapping']['start'],
element['source_mapping']['start'] + element['source_mapping']['length'])
def _patch(slither, result, in_file, match_text, replace_text, modify_loc_start, modify_loc_end):
in_file_str = slither.source_code[in_file].encode('utf8')
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
# Add keyword `constant` before the variable name
(new_str_of_interest, num_repl) = re.subn(match_text, replace_text, old_str_of_interest.decode('utf-8'), 1)
if num_repl != 0:
create_patch(result,
in_file,
modify_loc_start,
modify_loc_end,
old_str_of_interest,
new_str_of_interest)
else:
raise FormatError("State variable not found?!")

@ -0,0 +1,28 @@
from slither.formatters.utils.patches import create_patch
def format(slither, result):
elements = result['elements']
for element in elements:
if element['type'] == "variable":
_patch(slither,
result,
element['source_mapping']['filename_absolute'],
element['source_mapping']['start'])
def _patch(slither, result, in_file, modify_loc_start):
in_file_str = slither.source_code[in_file].encode('utf8')
old_str_of_interest = in_file_str[modify_loc_start:]
old_str = old_str_of_interest.decode('utf-8').partition(';')[0]\
+ old_str_of_interest.decode('utf-8').partition(';')[1]
create_patch(result,
in_file,
int(modify_loc_start),
# Remove the entire declaration until the semicolon
int(modify_loc_start + len(old_str_of_interest.decode('utf-8').partition(';')[0]) + 1),
old_str,
"")

@ -34,6 +34,7 @@ class Slither(SlitherSolc):
filter_paths (list(str)): list of path to filter (default []) filter_paths (list(str)): list of path to filter (default [])
triage_mode (bool): if true, switch to triage mode (default false) triage_mode (bool): if true, switch to triage mode (default false)
exclude_dependencies (bool): if true, exclude results that are only related to dependencies exclude_dependencies (bool): if true, exclude results that are only related to dependencies
generate_patches (bool): if true, patches are generated (json output only)
truffle_ignore (bool): ignore truffle.js presence (default false) truffle_ignore (bool): ignore truffle.js presence (default false)
truffle_build_directory (str): build truffle directory (default 'build/contracts') truffle_build_directory (str): build truffle directory (default 'build/contracts')
@ -64,6 +65,9 @@ class Slither(SlitherSolc):
self._parse_contracts_from_loaded_json(ast, path) self._parse_contracts_from_loaded_json(ast, path)
self._add_source_code(path) self._add_source_code(path)
if kwargs.get('generate_patches', False):
self.generate_patches = True
self._detectors = [] self._detectors = []
self._printers = [] self._printers = []

@ -45,10 +45,14 @@ def convert_expression(expression, node):
if isinstance(expression, Literal) and node.type in [NodeType.IF, NodeType.IFLOOP]: if isinstance(expression, Literal) and node.type in [NodeType.IF, NodeType.IFLOOP]:
cst = Constant(expression.value, expression.type) cst = Constant(expression.value, expression.type)
result = [Condition(cst)] cond = Condition(cst)
cond.set_expression(expression)
result = [cond]
return result return result
if isinstance(expression, Identifier) and node.type in [NodeType.IF, NodeType.IFLOOP]: if isinstance(expression, Identifier) and node.type in [NodeType.IF, NodeType.IFLOOP]:
result = [Condition(expression.value)] cond = Condition(expression.value)
cond.set_expression(expression)
result = [cond]
return result return result
@ -60,11 +64,15 @@ def convert_expression(expression, node):
if result: if result:
if node.type in [NodeType.IF, NodeType.IFLOOP]: if node.type in [NodeType.IF, NodeType.IFLOOP]:
assert isinstance(result[-1], (OperationWithLValue)) assert isinstance(result[-1], (OperationWithLValue))
result.append(Condition(result[-1].lvalue)) cond = Condition(result[-1].lvalue)
cond.set_expression(expression)
result.append(cond)
elif node.type == NodeType.RETURN: elif node.type == NodeType.RETURN:
# May return None # May return None
if isinstance(result[-1], (OperationWithLValue)): if isinstance(result[-1], (OperationWithLValue)):
result.append(Return(result[-1].lvalue)) r = Return(result[-1].lvalue)
r.set_expression(expression)
result.append(r)
return result return result
@ -326,6 +334,7 @@ def _convert_type_contract(ir, slither):
assignment = Assignment(ir.lvalue, assignment = Assignment(ir.lvalue,
Constant(str(bytecode)), Constant(str(bytecode)),
ElementaryType('bytes')) ElementaryType('bytes'))
assignment.set_expression(ir.expression)
assignment.lvalue.set_type(ElementaryType('bytes')) assignment.lvalue.set_type(ElementaryType('bytes'))
return assignment return assignment
if ir.variable_right == 'runtimeCode': if ir.variable_right == 'runtimeCode':
@ -338,12 +347,14 @@ def _convert_type_contract(ir, slither):
assignment = Assignment(ir.lvalue, assignment = Assignment(ir.lvalue,
Constant(str(bytecode)), Constant(str(bytecode)),
ElementaryType('bytes')) ElementaryType('bytes'))
assignment.set_expression(ir.expression)
assignment.lvalue.set_type(ElementaryType('bytes')) assignment.lvalue.set_type(ElementaryType('bytes'))
return assignment return assignment
if ir.variable_right == 'name': if ir.variable_right == 'name':
assignment = Assignment(ir.lvalue, assignment = Assignment(ir.lvalue,
Constant(contract.name), Constant(contract.name),
ElementaryType('string')) ElementaryType('string'))
assignment.set_expression(ir.expression)
assignment.lvalue.set_type(ElementaryType('string')) assignment.lvalue.set_type(ElementaryType('string'))
return assignment return assignment
@ -446,14 +457,18 @@ def propagate_types(ir, node):
# TODO we should convert the reference to a temporary if the member is a length or a balance # TODO we should convert the reference to a temporary if the member is a length or a balance
if ir.variable_right == 'length' and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, (ElementaryType, ArrayType)): if ir.variable_right == 'length' and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, (ElementaryType, ArrayType)):
length = Length(ir.variable_left, ir.lvalue) length = Length(ir.variable_left, ir.lvalue)
length.set_expression(ir.expression)
length.lvalue.points_to = ir.variable_left length.lvalue.points_to = ir.variable_left
return length return length
if ir.variable_right == 'balance'and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, ElementaryType): if ir.variable_right == 'balance'and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, ElementaryType):
return Balance(ir.variable_left, ir.lvalue) b = Balance(ir.variable_left, ir.lvalue)
b.set_expression(ir.expression)
return b
if ir.variable_right == 'selector' and isinstance(ir.variable_left.type, Function): if ir.variable_right == 'selector' and isinstance(ir.variable_left.type, Function):
assignment = Assignment(ir.lvalue, assignment = Assignment(ir.lvalue,
Constant(str(get_function_id(ir.variable_left.type.full_name))), Constant(str(get_function_id(ir.variable_left.type.full_name))),
ElementaryType('bytes4')) ElementaryType('bytes4'))
assignment.set_expression(ir.expression)
assignment.lvalue.set_type(ElementaryType('bytes4')) assignment.lvalue.set_type(ElementaryType('bytes4'))
return assignment return assignment
if isinstance(ir.variable_left, TemporaryVariable) and isinstance(ir.variable_left.type, TypeInformation): if isinstance(ir.variable_left, TemporaryVariable) and isinstance(ir.variable_left.type, TypeInformation):
@ -534,6 +549,7 @@ def extract_tmp_call(ins, contract):
if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType): if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType):
call = InternalDynamicCall(ins.lvalue, ins.called, ins.called.type) call = InternalDynamicCall(ins.lvalue, ins.called, ins.called.type)
call.set_expression(ins.expression)
call.call_id = ins.call_id call.call_id = ins.call_id
return call return call
if isinstance(ins.ori, Member): if isinstance(ins.ori, Member):
@ -541,23 +557,28 @@ def extract_tmp_call(ins, contract):
if ins.ori.variable_left in contract.inheritance + [contract]: if ins.ori.variable_left in contract.inheritance + [contract]:
if str(ins.ori.variable_right) in [f.name for f in contract.functions]: if str(ins.ori.variable_right) in [f.name for f in contract.functions]:
internalcall = InternalCall((ins.ori.variable_right, ins.ori.variable_left.name), ins.nbr_arguments, ins.lvalue, ins.type_call) internalcall = InternalCall((ins.ori.variable_right, ins.ori.variable_left.name), ins.nbr_arguments, ins.lvalue, ins.type_call)
internalcall.set_expression(ins.expression)
internalcall.call_id = ins.call_id internalcall.call_id = ins.call_id
return internalcall return internalcall
if str(ins.ori.variable_right) in [f.name for f in contract.events]: if str(ins.ori.variable_right) in [f.name for f in contract.events]:
eventcall = EventCall(ins.ori.variable_right) eventcall = EventCall(ins.ori.variable_right)
eventcall.call_id = ins.call_id eventcall.set_expression(ins.expression)
return eventcall eventcall.call_id = ins.call_id
return eventcall
if isinstance(ins.ori.variable_left, Contract): if isinstance(ins.ori.variable_left, Contract):
st = ins.ori.variable_left.get_structure_from_name(ins.ori.variable_right) st = ins.ori.variable_left.get_structure_from_name(ins.ori.variable_right)
if st: if st:
op = NewStructure(st, ins.lvalue) op = NewStructure(st, ins.lvalue)
op.set_expression(ins.expression)
op.call_id = ins.call_id op.call_id = ins.call_id
return op return op
libcall = LibraryCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call) libcall = LibraryCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call)
libcall.set_expression(ins.expression)
libcall.call_id = ins.call_id libcall.call_id = ins.call_id
return libcall return libcall
msgcall = HighLevelCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call) msgcall = HighLevelCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call)
msgcall.call_id = ins.call_id msgcall.call_id = ins.call_id
msgcall.set_expression(ins.expression)
return msgcall return msgcall
if isinstance(ins.ori, TmpCall): if isinstance(ins.ori, TmpCall):
@ -567,29 +588,42 @@ def extract_tmp_call(ins, contract):
if str(ins.called) == 'block.blockhash': if str(ins.called) == 'block.blockhash':
ins.called = SolidityFunction('blockhash(uint256)') ins.called = SolidityFunction('blockhash(uint256)')
elif str(ins.called) == 'this.balance': elif str(ins.called) == 'this.balance':
return SolidityCall(SolidityFunction('this.balance()'), ins.nbr_arguments, ins.lvalue, ins.type_call) s = SolidityCall(SolidityFunction('this.balance()'), ins.nbr_arguments, ins.lvalue, ins.type_call)
s.set_expression(ins.expression)
return s
if isinstance(ins.called, SolidityFunction): if isinstance(ins.called, SolidityFunction):
return SolidityCall(ins.called, ins.nbr_arguments, ins.lvalue, ins.type_call) s = SolidityCall(ins.called, ins.nbr_arguments, ins.lvalue, ins.type_call)
s.set_expression(ins.expression)
return s
if isinstance(ins.ori, TmpNewElementaryType): if isinstance(ins.ori, TmpNewElementaryType):
return NewElementaryType(ins.ori.type, ins.lvalue) n = NewElementaryType(ins.ori.type, ins.lvalue)
n.set_expression(ins.expression)
return n
if isinstance(ins.ori, TmpNewContract): if isinstance(ins.ori, TmpNewContract):
op = NewContract(Constant(ins.ori.contract_name), ins.lvalue) op = NewContract(Constant(ins.ori.contract_name), ins.lvalue)
op.set_expression(ins.expression)
op.call_id = ins.call_id op.call_id = ins.call_id
return op return op
if isinstance(ins.ori, TmpNewArray): if isinstance(ins.ori, TmpNewArray):
return NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue) n = NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue)
n.set_expression(ins.expression)
return n
if isinstance(ins.called, Structure): if isinstance(ins.called, Structure):
op = NewStructure(ins.called, ins.lvalue) op = NewStructure(ins.called, ins.lvalue)
op.set_expression(ins.expression)
op.call_id = ins.call_id op.call_id = ins.call_id
op.set_expression(ins.expression)
return op return op
if isinstance(ins.called, Event): if isinstance(ins.called, Event):
return EventCall(ins.called.name) e = EventCall(ins.called.name)
e.set_expression(ins.expression)
return e
if isinstance(ins.called, Contract): if isinstance(ins.called, Contract):
# Called a base constructor, where there is no constructor # Called a base constructor, where there is no constructor
@ -598,6 +632,7 @@ def extract_tmp_call(ins, contract):
internalcall = InternalCall(ins.called.constructor, ins.nbr_arguments, ins.lvalue, internalcall = InternalCall(ins.called.constructor, ins.nbr_arguments, ins.lvalue,
ins.type_call) ins.type_call)
internalcall.call_id = ins.call_id internalcall.call_id = ins.call_id
internalcall.set_expression(ins.expression)
return internalcall return internalcall
@ -628,11 +663,15 @@ def convert_to_low_level(ir):
""" """
if ir.function_name == 'transfer': if ir.function_name == 'transfer':
assert len(ir.arguments) == 1 assert len(ir.arguments) == 1
prev_ir = ir
ir = Transfer(ir.destination, ir.arguments[0]) ir = Transfer(ir.destination, ir.arguments[0])
ir.set_expression(prev_ir.expression)
return ir return ir
elif ir.function_name == 'send': elif ir.function_name == 'send':
assert len(ir.arguments) == 1 assert len(ir.arguments) == 1
prev_ir = ir
ir = Send(ir.destination, ir.arguments[0], ir.lvalue) ir = Send(ir.destination, ir.arguments[0], ir.lvalue)
ir.set_expression(prev_ir.expression)
ir.lvalue.set_type(ElementaryType('bool')) ir.lvalue.set_type(ElementaryType('bool'))
return ir return ir
elif ir.function_name in ['call', elif ir.function_name in ['call',
@ -648,6 +687,7 @@ def convert_to_low_level(ir):
new_ir.call_value = ir.call_value new_ir.call_value = ir.call_value
new_ir.arguments = ir.arguments new_ir.arguments = ir.arguments
new_ir.lvalue.set_type(ElementaryType('bool')) new_ir.lvalue.set_type(ElementaryType('bool'))
new_ir.set_expression(ir.expression)
return new_ir return new_ir
raise SlithIRError('Incorrect conversion to low level {}'.format(ir)) raise SlithIRError('Incorrect conversion to low level {}'.format(ir))
@ -668,6 +708,7 @@ def convert_to_solidity_func(ir):
call = SolidityFunction('abi.{}()'.format(ir.function_name)) call = SolidityFunction('abi.{}()'.format(ir.function_name))
new_ir = SolidityCall(call, ir.nbr_arguments, ir.lvalue, ir.type_call) new_ir = SolidityCall(call, ir.nbr_arguments, ir.lvalue, ir.type_call)
new_ir.arguments = ir.arguments new_ir.arguments = ir.arguments
new_ir.set_expression(ir.expression)
if isinstance(call.return_type, list) and len(call.return_type) == 1: if isinstance(call.return_type, list) and len(call.return_type) == 1:
new_ir.lvalue.set_type(call.return_type[0]) new_ir.lvalue.set_type(call.return_type[0])
else: else:
@ -693,9 +734,12 @@ def convert_to_push(ir, node):
val = TemporaryVariable(node) val = TemporaryVariable(node)
operation = InitArray(ir.arguments[0], val) operation = InitArray(ir.arguments[0], val)
operation.set_expression(ir.expression)
ret.append(operation) ret.append(operation)
prev_ir = ir
ir = Push(ir.destination, val) ir = Push(ir.destination, val)
ir.set_expression(prev_ir.expression)
length = Literal(len(operation.init_values), 'uint256') length = Literal(len(operation.init_values), 'uint256')
t = operation.init_values[0].type t = operation.init_values[0].type
@ -705,18 +749,22 @@ def convert_to_push(ir, node):
if lvalue: if lvalue:
length = Length(ir.array, lvalue) length = Length(ir.array, lvalue)
length.set_expression(ir.expression)
length.lvalue.points_to = ir.lvalue length.lvalue.points_to = ir.lvalue
ret.append(length) ret.append(length)
return ret return ret
prev_ir = ir
ir = Push(ir.destination, ir.arguments[0]) ir = Push(ir.destination, ir.arguments[0])
ir.set_expression(prev_ir.expression)
if lvalue: if lvalue:
ret = [] ret = []
ret.append(ir) ret.append(ir)
length = Length(ir.array, lvalue) length = Length(ir.array, lvalue)
length.set_expression(ir.expression)
length.lvalue.points_to = ir.lvalue length.lvalue.points_to = ir.lvalue
ret.append(length) ret.append(length)
return ret return ret
@ -732,6 +780,7 @@ def look_for_library(contract, ir, node, using_for, t):
ir.nbr_arguments, ir.nbr_arguments,
ir.lvalue, ir.lvalue,
ir.type_call) ir.type_call)
lib_call.set_expression(ir.expression)
lib_call.call_gas = ir.call_gas lib_call.call_gas = ir.call_gas
lib_call.arguments = [ir.destination] + ir.arguments lib_call.arguments = [ir.destination] + ir.arguments
new_ir = convert_type_library_call(lib_call, lib_contract) new_ir = convert_type_library_call(lib_call, lib_contract)

@ -1,5 +1,6 @@
import abc import abc
from slither.core.context.context import Context from slither.core.context.context import Context
from slither.core.children.child_expression import ChildExpression
from slither.core.children.child_node import ChildNode from slither.core.children.child_node import ChildNode
from slither.utils.utils import unroll from slither.utils.utils import unroll
@ -21,7 +22,7 @@ class AbstractOperation(abc.ABC):
""" """
pass pass
class Operation(Context, ChildNode, AbstractOperation): class Operation(Context, ChildExpression, ChildNode, AbstractOperation):
@property @property
def used(self): def used(self):

@ -1,10 +1,6 @@
import logging
from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.operations.lvalue import OperationWithLValue
from slither.core.variables.variable import Variable from slither.slithir.utils.utils import is_valid_lvalue
from slither.slithir.variables import TupleVariable
from slither.core.declarations.function import Function
from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue
class Phi(OperationWithLValue): class Phi(OperationWithLValue):

@ -181,6 +181,8 @@ def generate_ssa_irs(node, local_variables_instances, all_local_variables_instan
tuple_variables_instances, tuple_variables_instances,
all_local_variables_instances) all_local_variables_instances)
new_ir.set_expression(ir.expression)
update_lvalue(new_ir, update_lvalue(new_ir,
node, node,
local_variables_instances, local_variables_instances,

@ -275,7 +275,9 @@ def parse_call(expression, caller_context):
arguments = [parse_expression(a, caller_context) for a in children[1::]] arguments = [parse_expression(a, caller_context) for a in children[1::]]
if isinstance(called, SuperCallExpression): if isinstance(called, SuperCallExpression):
return SuperCallExpression(called, arguments, type_return) sp = SuperCallExpression(called, arguments, type_return)
sp.set_offset(expression['src'], caller_context.slither)
return sp
call_expression = CallExpression(called, arguments, type_return) call_expression = CallExpression(called, arguments, type_return)
call_expression.set_offset(src, caller_context.slither) call_expression.set_offset(src, caller_context.slither)
return call_expression return call_expression
@ -386,7 +388,7 @@ def parse_expression(expression, caller_context):
return binary_op return binary_op
elif name == 'FunctionCall': elif name == 'FunctionCall':
return parse_call(expression, caller_context) return parse_call(expression, caller_context)
elif name == 'TupleExpression': elif name == 'TupleExpression':
""" """
@ -401,7 +403,7 @@ def parse_expression(expression, caller_context):
Note: this is only possible with Solidity >= 0.4.12 Note: this is only possible with Solidity >= 0.4.12
""" """
if is_compact_ast: if is_compact_ast:
expressions = [parse_expression(e, caller_context) if e else None for e in expression['components']] expressions = [parse_expression(e, caller_context) if e else None for e in expression['components']]
else: else:
if 'children' not in expression : if 'children' not in expression :
attributes = expression['attributes'] attributes = expression['attributes']

@ -0,0 +1,14 @@
# .format files are the output files produced by slither-format
# .patch files are the output files produced by slither-format
*.format
*.patch
# Temporary files (Emacs backup files ending in tilde and others)
*~
*.err
*.out

@ -0,0 +1,89 @@
import sys
import argparse
from slither import Slither
from slither.utils.command_line import read_config_file
import logging
from .slither_format import slither_format
from crytic_compile import cryticparser
logging.basicConfig()
logger = logging.getLogger("Slither").setLevel(logging.INFO)
# Slither detectors for which slither-format currently works
available_detectors = ["unused-state",
"solc-version",
"pragma",
"naming-convention",
"external-function",
"constable-states",
"constant-function"]
detectors_to_run = []
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='slither_format',
usage='slither_format filename')
parser.add_argument('filename', help='The filename of the contract or truffle directory to analyze.')
parser.add_argument('--verbose-test', '-v', help='verbose mode output for testing',action='store_true',default=False)
parser.add_argument('--verbose-json', '-j', help='verbose json output',action='store_true',default=False)
parser.add_argument('--version',
help='displays the current version',
version='0.1.0',
action='version')
parser.add_argument('--config-file',
help='Provide a config file (default: slither.config.json)',
action='store',
dest='config_file',
default='slither.config.json')
group_detector = parser.add_argument_group('Detectors')
group_detector.add_argument('--detect',
help='Comma-separated list of detectors, defaults to all, '
'available detectors: {}'.format(
', '.join(d for d in available_detectors)),
action='store',
dest='detectors_to_run',
default='all')
group_detector.add_argument('--exclude',
help='Comma-separated list of detectors to exclude,'
'available detectors: {}'.format(
', '.join(d for d in available_detectors)),
action='store',
dest='detectors_to_exclude',
default='all')
cryticparser.init(parser)
if len(sys.argv) == 1:
parser.print_help(sys.stderr)
sys.exit(1)
return parser.parse_args()
def main():
# ------------------------------
# Usage: python3 -m slither_format filename
# Example: python3 -m slither_format contract.sol
# ------------------------------
# Parse all arguments
args = parse_args()
read_config_file(args)
# Perform slither analysis on the given filename
slither = Slither(args.filename, **vars(args))
# Format the input files based on slither analysis
slither_format(slither, **vars(args))
if __name__ == '__main__':
main()

@ -0,0 +1,151 @@
import logging
from pathlib import Path
from slither.detectors.variables.unused_state_variables import UnusedStateVars
from slither.detectors.attributes.incorrect_solc import IncorrectSolc
from slither.detectors.attributes.constant_pragma import ConstantPragma
from slither.detectors.naming_convention.naming_convention import NamingConvention
from slither.detectors.functions.external_function import ExternalFunction
from slither.detectors.variables.possible_const_state_variables import ConstCandidateStateVars
from slither.detectors.attributes.const_functions import ConstantFunctions
from slither.utils.colors import yellow
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('Slither.Format')
all_detectors = {
'unused-state': UnusedStateVars,
'solc-version': IncorrectSolc,
'pragma': ConstantPragma,
'naming-convention': NamingConvention,
'external-function': ExternalFunction,
'constable-states' : ConstCandidateStateVars,
'constant-function': ConstantFunctions
}
def slither_format(slither, **kwargs):
''''
Keyword Args:
detectors_to_run (str): Comma-separated list of detectors, defaults to all
'''
detectors_to_run = choose_detectors(kwargs.get('detectors_to_run', 'all'),
kwargs.get('detectors_to_exclude', ''))
for detector in detectors_to_run:
slither.register_detector(detector)
slither.generate_patches = True
detector_results = slither.run_detectors()
detector_results = [x for x in detector_results if x] # remove empty results
detector_results = [item for sublist in detector_results for item in sublist] # flatten
export = Path('crytic-export', 'patches')
export.mkdir(parents=True, exist_ok=True)
counter_result = 0
logger.info(yellow('slither-format is in beta, carefully review each patch before merging it.'))
for result in detector_results:
if not 'patches' in result:
continue
one_line_description = result["description"].split("\n")[0]
export_result = Path(export, f'{counter_result}')
export_result.mkdir(parents=True, exist_ok=True)
counter_result += 1
counter = 0
logger.info(f'Issue: {one_line_description}')
logger.info(f'Generated: ({export_result})')
for file, diff, in result['patches_diff'].items():
filename = f'fix_{counter}.patch'
path = Path(export_result, filename)
logger.info(f'\t- {filename}')
with open(path, 'w') as f:
f.write(diff)
counter += 1
# endregion
###################################################################################
###################################################################################
# region Detectors
###################################################################################
###################################################################################
def choose_detectors(detectors_to_run, detectors_to_exclude):
# If detectors are specified, run only these ones
cls_detectors_to_run = []
exclude = detectors_to_exclude.split(',')
if detectors_to_run == 'all':
for d in all_detectors:
if d in exclude:
continue
cls_detectors_to_run.append(all_detectors[d])
else:
exclude = detectors_to_exclude.split(',')
for d in detectors_to_run.split(','):
if d in all_detectors:
if d in exclude:
continue
cls_detectors_to_run.append(all_detectors[d])
else:
raise Exception('Error: {} is not a detector'.format(d))
return cls_detectors_to_run
# endregion
###################################################################################
###################################################################################
# region Debug functions
###################################################################################
###################################################################################
def print_patches(number_of_slither_results, patches):
logger.info("Number of Slither results: " + str(number_of_slither_results))
number_of_patches = 0
for file in patches:
number_of_patches += len(patches[file])
logger.info("Number of patches: " + str(number_of_patches))
for file in patches:
logger.info("Patch file: " + file)
for patch in patches[file]:
logger.info("Detector: " + patch['detector'])
logger.info("Old string: " + patch['old_string'].replace("\n",""))
logger.info("New string: " + patch['new_string'].replace("\n",""))
logger.info("Location start: " + str(patch['start']))
logger.info("Location end: " + str(patch['end']))
def print_patches_json(number_of_slither_results, patches):
print('{',end='')
print("\"Number of Slither results\":" + '"' + str(number_of_slither_results) + '",')
print("\"Number of patchlets\":" + "\"" + str(len(patches)) + "\"", ',')
print("\"Patchlets\":" + '[')
for index, file in enumerate(patches):
if index > 0:
print(',')
print('{',end='')
print("\"Patch file\":" + '"' + file + '",')
print("\"Number of patches\":" + "\"" + str(len(patches[file])) + "\"", ',')
print("\"Patches\":" + '[')
for index, patch in enumerate(patches[file]):
if index > 0:
print(',')
print('{',end='')
print("\"Detector\":" + '"' + patch['detector'] + '",')
print("\"Old string\":" + '"' + patch['old_string'].replace("\n","") + '",')
print("\"New string\":" + '"' + patch['new_string'].replace("\n","") + '",')
print("\"Location start\":" + '"' + str(patch['start']) + '",')
print("\"Location end\":" + '"' + str(patch['end']) + '"')
if 'overlaps' in patch:
print("\"Overlaps\":" + "Yes")
print('}',end='')
print(']',end='')
print('}',end='')
print(']',end='')
print('}')

@ -1,6 +1,8 @@
import os import os
import logging import logging
import json import json
import os
import logging
from collections import defaultdict from collections import defaultdict
from prettytable import PrettyTable from prettytable import PrettyTable
from crytic_compile.cryticparser.defaults import defaults_flag_in_config as defaults_flag_in_config_crytic_compile from crytic_compile.cryticparser.defaults import defaults_flag_in_config as defaults_flag_in_config_crytic_compile
@ -29,6 +31,7 @@ defaults_flag_in_config = {
'json-types': ','.join(DEFAULT_JSON_OUTPUT_TYPES), 'json-types': ','.join(DEFAULT_JSON_OUTPUT_TYPES),
'disable_color': False, 'disable_color': False,
'filter_paths': None, 'filter_paths': None,
'generate_patches': False,
# debug command # debug command
'legacy_ast': False, 'legacy_ast': False,
'ignore_return_value': False, 'ignore_return_value': False,

@ -67,7 +67,9 @@ class ExpressionToSlithIR(ExpressionVisitor):
self._result = [] self._result = []
self._visit_expression(self.expression) self._visit_expression(self.expression)
if node.type == NodeType.RETURN: if node.type == NodeType.RETURN:
self._result.append(Return(get(self.expression))) r = Return(get(self.expression))
r.set_expression(expression)
self._result.append(r)
for ir in self._result: for ir in self._result:
ir.set_node(node) ir.set_node(node)
@ -83,6 +85,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
for idx in range(len(left)): for idx in range(len(left)):
if not left[idx] is None: if not left[idx] is None:
operation = convert_assignment(left[idx], right[idx], expression.type, expression.expression_return_type) operation = convert_assignment(left[idx], right[idx], expression.type, expression.expression_return_type)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, None) set_val(expression, None)
else: else:
@ -90,6 +93,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
for idx in range(len(left)): for idx in range(len(left)):
if not left[idx] is None: if not left[idx] is None:
operation = Unpack(left[idx], right, idx) operation = Unpack(left[idx], right, idx)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, None) set_val(expression, None)
else: else:
@ -97,10 +101,12 @@ class ExpressionToSlithIR(ExpressionVisitor):
# uint8[2] var = [1,2]; # uint8[2] var = [1,2];
if isinstance(right, list): if isinstance(right, list):
operation = InitArray(right, left) operation = InitArray(right, left)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, left) set_val(expression, left)
else: else:
operation = convert_assignment(left, right, expression.type, expression.expression_return_type) operation = convert_assignment(left, right, expression.type, expression.expression_return_type)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
# Return left to handle # Return left to handle
# a = b = 1; # a = b = 1;
@ -112,6 +118,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
val = TemporaryVariable(self._node) val = TemporaryVariable(self._node)
operation = Binary(val, left, right, expression.type) operation = Binary(val, left, right, expression.type)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, val) set_val(expression, val)
@ -120,6 +127,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
args = [get(a) for a in expression.arguments if a] args = [get(a) for a in expression.arguments if a]
for arg in args: for arg in args:
arg_ = Argument(arg) arg_ = Argument(arg)
arg_.set_expression(expression)
self._result.append(arg_) self._result.append(arg_)
if isinstance(called, Function): if isinstance(called, Function):
# internal call # internal call
@ -130,11 +138,10 @@ class ExpressionToSlithIR(ExpressionVisitor):
else: else:
val = TemporaryVariable(self._node) val = TemporaryVariable(self._node)
internal_call = InternalCall(called, len(args), val, expression.type_call) internal_call = InternalCall(called, len(args), val, expression.type_call)
internal_call.set_expression(expression)
self._result.append(internal_call) self._result.append(internal_call)
set_val(expression, val) set_val(expression, val)
else: else:
val = TemporaryVariable(self._node)
# If tuple # If tuple
if expression.type_call.startswith('tuple(') and expression.type_call != 'tuple()': if expression.type_call.startswith('tuple(') and expression.type_call != 'tuple()':
val = TupleVariable(self._node) val = TupleVariable(self._node)
@ -142,6 +149,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
val = TemporaryVariable(self._node) val = TemporaryVariable(self._node)
message_call = TmpCall(called, len(args), val, expression.type_call) message_call = TmpCall(called, len(args), val, expression.type_call)
message_call.set_expression(expression)
self._result.append(message_call) self._result.append(message_call)
set_val(expression, val) set_val(expression, val)
@ -165,8 +173,10 @@ class ExpressionToSlithIR(ExpressionVisitor):
init_array_right = left init_array_right = left
left = init_array_val left = init_array_val
operation = InitArray(init_array_right, init_array_val) operation = InitArray(init_array_right, init_array_val)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
operation = Index(val, left, right, expression.type) operation = Index(val, left, right, expression.type)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, val) set_val(expression, val)
@ -178,18 +188,21 @@ class ExpressionToSlithIR(ExpressionVisitor):
expr = get(expression.expression) expr = get(expression.expression)
val = ReferenceVariable(self._node) val = ReferenceVariable(self._node)
member = Member(expr, Constant(expression.member_name), val) member = Member(expr, Constant(expression.member_name), val)
member.set_expression(expression)
self._result.append(member) self._result.append(member)
set_val(expression, val) set_val(expression, val)
def _post_new_array(self, expression): def _post_new_array(self, expression):
val = TemporaryVariable(self._node) val = TemporaryVariable(self._node)
operation = TmpNewArray(expression.depth, expression.array_type, val) operation = TmpNewArray(expression.depth, expression.array_type, val)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, val) set_val(expression, val)
def _post_new_contract(self, expression): def _post_new_contract(self, expression):
val = TemporaryVariable(self._node) val = TemporaryVariable(self._node)
operation = TmpNewContract(expression.contract_name, val) operation = TmpNewContract(expression.contract_name, val)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, val) set_val(expression, val)
@ -197,6 +210,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
# TODO unclear if this is ever used? # TODO unclear if this is ever used?
val = TemporaryVariable(self._node) val = TemporaryVariable(self._node)
operation = TmpNewElementaryType(expression.type, val) operation = TmpNewElementaryType(expression.type, val)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, val) set_val(expression, val)
@ -212,6 +226,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
expr = get(expression.expression) expr = get(expression.expression)
val = TemporaryVariable(self._node) val = TemporaryVariable(self._node)
operation = TypeConversion(val, expr, expression.type) operation = TypeConversion(val, expr, expression.type)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, val) set_val(expression, val)
@ -220,32 +235,40 @@ class ExpressionToSlithIR(ExpressionVisitor):
if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]: if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]:
lvalue = TemporaryVariable(self._node) lvalue = TemporaryVariable(self._node)
operation = Unary(lvalue, value, expression.type) operation = Unary(lvalue, value, expression.type)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, lvalue) set_val(expression, lvalue)
elif expression.type in [UnaryOperationType.DELETE]: elif expression.type in [UnaryOperationType.DELETE]:
operation = Delete(value, value) operation = Delete(value, value)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, value) set_val(expression, value)
elif expression.type in [UnaryOperationType.PLUSPLUS_PRE]: elif expression.type in [UnaryOperationType.PLUSPLUS_PRE]:
operation = Binary(value, value, Constant("1", value.type), BinaryType.ADDITION) operation = Binary(value, value, Constant("1", value.type), BinaryType.ADDITION)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, value) set_val(expression, value)
elif expression.type in [UnaryOperationType.MINUSMINUS_PRE]: elif expression.type in [UnaryOperationType.MINUSMINUS_PRE]:
operation = Binary(value, value, Constant("1", value.type), BinaryType.SUBTRACTION) operation = Binary(value, value, Constant("1", value.type), BinaryType.SUBTRACTION)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, value) set_val(expression, value)
elif expression.type in [UnaryOperationType.PLUSPLUS_POST]: elif expression.type in [UnaryOperationType.PLUSPLUS_POST]:
lvalue = TemporaryVariable(self._node) lvalue = TemporaryVariable(self._node)
operation = Assignment(lvalue, value, value.type) operation = Assignment(lvalue, value, value.type)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
operation = Binary(value, value, Constant("1", value.type), BinaryType.ADDITION) operation = Binary(value, value, Constant("1", value.type), BinaryType.ADDITION)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, lvalue) set_val(expression, lvalue)
elif expression.type in [UnaryOperationType.MINUSMINUS_POST]: elif expression.type in [UnaryOperationType.MINUSMINUS_POST]:
lvalue = TemporaryVariable(self._node) lvalue = TemporaryVariable(self._node)
operation = Assignment(lvalue, value, value.type) operation = Assignment(lvalue, value, value.type)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
operation = Binary(value, value, Constant("1", value.type), BinaryType.SUBTRACTION) operation = Binary(value, value, Constant("1", value.type), BinaryType.SUBTRACTION)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, lvalue) set_val(expression, lvalue)
elif expression.type in [UnaryOperationType.PLUS_PRE]: elif expression.type in [UnaryOperationType.PLUS_PRE]:
@ -253,6 +276,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
elif expression.type in [UnaryOperationType.MINUS_PRE]: elif expression.type in [UnaryOperationType.MINUS_PRE]:
lvalue = TemporaryVariable(self._node) lvalue = TemporaryVariable(self._node)
operation = Binary(lvalue, Constant("0", value.type), value, BinaryType.SUBTRACTION) operation = Binary(lvalue, Constant("0", value.type), value, BinaryType.SUBTRACTION)
operation.set_expression(expression)
self._result.append(operation) self._result.append(operation)
set_val(expression, lvalue) set_val(expression, lvalue)
else: else:

Loading…
Cancel
Save