Improve code readability of major modules:

- Group methods together
 - Use region/endregion format
Remove unused import
pull/172/head
Josselin 6 years ago
parent d22e53fdfe
commit 5f0cae387f
  1. 290
      slither/__main__.py
  2. 87
      slither/analyses/data_dependency/data_dependency.py
  3. 249
      slither/core/cfg/node.py
  4. 405
      slither/core/declarations/contract.py
  5. 568
      slither/core/declarations/function.py
  6. 1
      slither/core/solidity_types/function_type.py
  7. 909
      slither/slithir/convert.py
  8. 257
      slither/slithir/utils/ssa.py
  9. 273
      slither/solc_parsing/declarations/contract.py
  10. 379
      slither/solc_parsing/declarations/function.py
  11. 215
      slither/solc_parsing/expressions/expression_parsing.py
  12. 31
      slither/solc_parsing/slitherSolc.py

@ -1,32 +1,38 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import inspect
import argparse import argparse
import glob import glob
import inspect
import json import json
import logging import logging
import os import os
import subprocess
import sys import sys
import traceback import traceback
import subprocess
from pkg_resources import iter_entry_points, require from pkg_resources import iter_entry_points, require
from slither.detectors import all_detectors from slither.detectors import all_detectors
from slither.printers import all_printers
from slither.detectors.abstract_detector import (AbstractDetector, from slither.detectors.abstract_detector import (AbstractDetector,
DetectorClassification) DetectorClassification)
from slither.printers import all_printers
from slither.printers.abstract_printer import AbstractPrinter from slither.printers.abstract_printer import AbstractPrinter
from slither.slither import Slither from slither.slither import Slither
from slither.utils.colors import red from slither.utils.colors import red, set_colorization_enabled
from slither.utils.command_line import output_to_markdown, output_detectors, output_printers, output_detectors_json, output_wiki from slither.utils.command_line import (output_detectors,
from slither.utils.colors import set_colorization_enabled output_detectors_json, output_printers,
output_to_markdown, output_wiki)
logging.basicConfig() logging.basicConfig()
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
###################################################################################
###################################################################################
# region Process functions
###################################################################################
###################################################################################
def process(filename, args, detector_classes, printer_classes): def process(filename, args, detector_classes, printer_classes):
""" """
The core high-level code for running Slither static analysis. The core high-level code for running Slither static analysis.
@ -104,10 +110,23 @@ def process_files(filenames, args, detector_classes, printer_classes):
slither = Slither(all_contracts, args.solc, args.disable_solc_warnings, args.solc_args) slither = Slither(all_contracts, args.solc, args.disable_solc_warnings, args.solc_args)
return _process(slither, detector_classes, printer_classes) return _process(slither, detector_classes, printer_classes)
# endregion
###################################################################################
###################################################################################
# region Output
###################################################################################
###################################################################################
def output_json(results, filename): def output_json(results, filename):
with open(filename, 'w', encoding='utf8') as f: with open(filename, 'w', encoding='utf8') as f:
json.dump(results, f) json.dump(results, f)
# endregion
###################################################################################
###################################################################################
# region Exit
###################################################################################
###################################################################################
def exit(results): def exit(results):
if not results: if not results:
@ -115,6 +134,13 @@ def exit(results):
sys.exit(len(results)) sys.exit(len(results))
# endregion
###################################################################################
###################################################################################
# region Detectors and printers
###################################################################################
###################################################################################
def get_detectors_and_printers(): def get_detectors_and_printers():
""" """
NOTE: This contains just a few detectors and printers that we made public. NOTE: This contains just a few detectors and printers that we made public.
@ -144,87 +170,69 @@ def get_detectors_and_printers():
return detectors, printers return detectors, printers
def main(): def choose_detectors(args, all_detector_classes):
detectors, printers = get_detectors_and_printers() # If detectors are specified, run only these ones
main_impl(all_detector_classes=detectors, all_printer_classes=printers)
def main_impl(all_detector_classes, all_printer_classes):
"""
:param all_detector_classes: A list of all detectors that can be included/excluded.
:param all_printer_classes: A list of all printers that can be included.
"""
args = parse_args(all_detector_classes, all_printer_classes)
# Set colorization option
set_colorization_enabled(not args.disable_color)
printer_classes = choose_printers(args, all_printer_classes)
detector_classes = choose_detectors(args, all_detector_classes)
default_log = logging.INFO if not args.debug else logging.DEBUG
for (l_name, l_level) in [('Slither', default_log), detectors_to_run = []
('Contract', default_log), detectors = {d.ARGUMENT: d for d in all_detector_classes}
('Function', default_log),
('Node', default_log),
('Parsing', default_log),
('Detectors', default_log),
('FunctionSolc', default_log),
('ExpressionParsing', default_log),
('TypeParsing', default_log),
('SSA_Conversion', default_log),
('Printers', default_log)]:
l = logging.getLogger(l_name)
l.setLevel(l_level)
try: if args.detectors_to_run == 'all':
filename = args.filename detectors_to_run = all_detector_classes
detectors_excluded = args.detectors_to_exclude.split(',')
for d in detectors:
if d in detectors_excluded:
detectors_to_run.remove(detectors[d])
else:
for d in args.detectors_to_run.split(','):
if d in detectors:
detectors_to_run.append(detectors[d])
else:
raise Exception('Error: {} is not a detector'.format(d))
detectors_to_run = sorted(detectors_to_run, key=lambda x: x.IMPACT)
return detectors_to_run
globbed_filenames = glob.glob(filename, recursive=True) if args.exclude_informational:
detectors_to_run = [d for d in detectors_to_run if
d.IMPACT != DetectorClassification.INFORMATIONAL]
if args.exclude_low:
detectors_to_run = [d for d in detectors_to_run if
d.IMPACT != DetectorClassification.LOW]
if args.exclude_medium:
detectors_to_run = [d for d in detectors_to_run if
d.IMPACT != DetectorClassification.MEDIUM]
if args.exclude_high:
detectors_to_run = [d for d in detectors_to_run if
d.IMPACT != DetectorClassification.HIGH]
if args.detectors_to_exclude:
detectors_to_run = [d for d in detectors_to_run if
d.ARGUMENT not in args.detectors_to_exclude]
if os.path.isfile(filename): detectors_to_run = sorted(detectors_to_run, key=lambda x: x.IMPACT)
(results, number_contracts) = process(filename, args, detector_classes, printer_classes)
elif os.path.isfile(os.path.join(filename, 'truffle.js')) or os.path.isfile(os.path.join(filename, 'truffle-config.js')): return detectors_to_run
(results, number_contracts) = process_truffle(filename, args, detector_classes, printer_classes)
elif os.path.isdir(filename) or len(globbed_filenames) > 0:
extension = "*.sol" if not args.solc_ast else "*.json"
filenames = glob.glob(os.path.join(filename, extension))
if not filenames:
filenames = globbed_filenames
number_contracts = 0
results = []
if args.splitted and args.solc_ast:
(results, number_contracts) = process_files(filenames, args, detector_classes, printer_classes)
else:
for filename in filenames:
(results_tmp, number_contracts_tmp) = process(filename, args, detector_classes, printer_classes)
number_contracts += number_contracts_tmp
results += results_tmp
def choose_printers(args, all_printer_classes):
printers_to_run = []
else: # disable default printer
raise Exception("Unrecognised file/dir path: '#{filename}'".format(filename=filename)) if args.printers_to_run == '':
return []
if args.json: printers = {p.ARGUMENT: p for p in all_printer_classes}
output_json(results, args.json) for p in args.printers_to_run.split(','):
# Dont print the number of result for printers if p in printers:
if number_contracts == 0: printers_to_run.append(printers[p])
logger.warn(red('No contract was analyzed'))
if printer_classes:
logger.info('%s analyzed (%d contracts)', filename, number_contracts)
else: else:
logger.info('%s analyzed (%d contracts), %d result(s) found', filename, number_contracts, len(results)) raise Exception('Error: {} is not a printer'.format(p))
exit(results) return printers_to_run
except Exception:
logging.error('Error in %s' % args.filename)
logging.error(traceback.format_exc())
sys.exit(-1)
# endregion
###################################################################################
###################################################################################
# region Command line parsing
###################################################################################
###################################################################################
def parse_args(detector_classes, printer_classes): def parse_args(detector_classes, printer_classes):
parser = argparse.ArgumentParser(description='Slither', parser = argparse.ArgumentParser(description='Slither',
@ -405,63 +413,99 @@ class OutputWiki(argparse.Action):
output_wiki(detectors, values) output_wiki(detectors, values)
parser.exit() parser.exit()
def choose_detectors(args, all_detector_classes):
# If detectors are specified, run only these ones
detectors_to_run = [] # endregion
detectors = {d.ARGUMENT: d for d in all_detector_classes} ###################################################################################
###################################################################################
# region Main
###################################################################################
###################################################################################
if args.detectors_to_run == 'all': def main():
detectors_to_run = all_detector_classes detectors, printers = get_detectors_and_printers()
detectors_excluded = args.detectors_to_exclude.split(',')
for d in detectors:
if d in detectors_excluded:
detectors_to_run.remove(detectors[d])
else:
for d in args.detectors_to_run.split(','):
if d in detectors:
detectors_to_run.append(detectors[d])
else:
raise Exception('Error: {} is not a detector'.format(d))
detectors_to_run = sorted(detectors_to_run, key=lambda x: x.IMPACT)
return detectors_to_run
if args.exclude_informational: main_impl(all_detector_classes=detectors, all_printer_classes=printers)
detectors_to_run = [d for d in detectors_to_run if
d.IMPACT != DetectorClassification.INFORMATIONAL]
if args.exclude_low:
detectors_to_run = [d for d in detectors_to_run if
d.IMPACT != DetectorClassification.LOW]
if args.exclude_medium:
detectors_to_run = [d for d in detectors_to_run if
d.IMPACT != DetectorClassification.MEDIUM]
if args.exclude_high:
detectors_to_run = [d for d in detectors_to_run if
d.IMPACT != DetectorClassification.HIGH]
if args.detectors_to_exclude:
detectors_to_run = [d for d in detectors_to_run if
d.ARGUMENT not in args.detectors_to_exclude]
detectors_to_run = sorted(detectors_to_run, key=lambda x: x.IMPACT)
return detectors_to_run def main_impl(all_detector_classes, all_printer_classes):
"""
:param all_detector_classes: A list of all detectors that can be included/excluded.
:param all_printer_classes: A list of all printers that can be included.
"""
args = parse_args(all_detector_classes, all_printer_classes)
# Set colorization option
set_colorization_enabled(not args.disable_color)
def choose_printers(args, all_printer_classes): printer_classes = choose_printers(args, all_printer_classes)
printers_to_run = [] detector_classes = choose_detectors(args, all_detector_classes)
default_log = logging.INFO if not args.debug else logging.DEBUG
for (l_name, l_level) in [('Slither', default_log),
('Contract', default_log),
('Function', default_log),
('Node', default_log),
('Parsing', default_log),
('Detectors', default_log),
('FunctionSolc', default_log),
('ExpressionParsing', default_log),
('TypeParsing', default_log),
('SSA_Conversion', default_log),
('Printers', default_log)]:
l = logging.getLogger(l_name)
l.setLevel(l_level)
try:
filename = args.filename
globbed_filenames = glob.glob(filename, recursive=True)
if os.path.isfile(filename):
(results, number_contracts) = process(filename, args, detector_classes, printer_classes)
elif os.path.isfile(os.path.join(filename, 'truffle.js')) or os.path.isfile(os.path.join(filename, 'truffle-config.js')):
(results, number_contracts) = process_truffle(filename, args, detector_classes, printer_classes)
elif os.path.isdir(filename) or len(globbed_filenames) > 0:
extension = "*.sol" if not args.solc_ast else "*.json"
filenames = glob.glob(os.path.join(filename, extension))
if not filenames:
filenames = globbed_filenames
number_contracts = 0
results = []
if args.splitted and args.solc_ast:
(results, number_contracts) = process_files(filenames, args, detector_classes, printer_classes)
else:
for filename in filenames:
(results_tmp, number_contracts_tmp) = process(filename, args, detector_classes, printer_classes)
number_contracts += number_contracts_tmp
results += results_tmp
# disable default printer
if args.printers_to_run == '':
return []
printers = {p.ARGUMENT: p for p in all_printer_classes}
for p in args.printers_to_run.split(','):
if p in printers:
printers_to_run.append(printers[p])
else: else:
raise Exception('Error: {} is not a printer'.format(p)) raise Exception("Unrecognised file/dir path: '#{filename}'".format(filename=filename))
return printers_to_run
if args.json:
output_json(results, args.json)
# Dont print the number of result for printers
if number_contracts == 0:
logger.warn(red('No contract was analyzed'))
if printer_classes:
logger.info('%s analyzed (%d contracts)', filename, number_contracts)
else:
logger.info('%s analyzed (%d contracts), %d result(s) found', filename, number_contracts, len(results))
exit(results)
except Exception:
logging.error('Error in %s' % args.filename)
logging.error(traceback.format_exc())
sys.exit(-1)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
# endregion

@ -2,40 +2,18 @@
Compute the data depenency between all the SSA variables Compute the data depenency between all the SSA variables
""" """
from slither.core.declarations import Contract, Function from slither.core.declarations import Contract, Function
from slither.slithir.operations import Index, Member, OperationWithLValue
from slither.slithir.variables import ReferenceVariable, Constant
from slither.slithir.variables import (Constant, LocalIRVariable, StateIRVariable,
ReferenceVariable, TemporaryVariable,
TupleVariable)
from slither.core.declarations.solidity_variables import \ from slither.core.declarations.solidity_variables import \
SolidityVariableComposed SolidityVariableComposed
from slither.slithir.operations import Index, OperationWithLValue
from slither.slithir.variables import (Constant, LocalIRVariable,
ReferenceVariable, StateIRVariable,
TemporaryVariable)
KEY_SSA = "DATA_DEPENDENCY_SSA" ###################################################################################
KEY_NON_SSA = "DATA_DEPENDENCY" ###################################################################################
# region User APIs
# Only for unprotected functions ###################################################################################
KEY_SSA_UNPROTECTED = "DATA_DEPENDENCY_SSA_UNPROTECTED" ###################################################################################
KEY_NON_SSA_UNPROTECTED = "DATA_DEPENDENCY_UNPROTECTED"
KEY_INPUT = "DATA_DEPENDENCY_INPUT"
KEY_INPUT_SSA = "DATA_DEPENDENCY_INPUT_SSA"
def pprint_dependency(context):
print('#### SSA ####')
context = context.context
for k, values in context[KEY_SSA].items():
print('{} ({}):'.format(k, id(k)))
for v in values:
print('\t- {}'.format(v))
print('#### NON SSA ####')
for k, values in context[KEY_NON_SSA].items():
print('{} ({}):'.format(k, hex(id(k))))
for v in values:
print('\t- {} ({})'.format(v, hex(id(v))))
def is_dependent(variable, source, context, only_unprotected=False): def is_dependent(variable, source, context, only_unprotected=False):
''' '''
@ -119,6 +97,53 @@ def is_tainted_ssa(variable, context, only_unprotected=False):
taints |= GENERIC_TAINT taints |= GENERIC_TAINT
return variable in taints or any(is_dependent_ssa(variable, t, context, only_unprotected) for t in taints) return variable in taints or any(is_dependent_ssa(variable, t, context, only_unprotected) for t in taints)
# endregion
###################################################################################
###################################################################################
# region Module constants
###################################################################################
###################################################################################
KEY_SSA = "DATA_DEPENDENCY_SSA"
KEY_NON_SSA = "DATA_DEPENDENCY"
# Only for unprotected functions
KEY_SSA_UNPROTECTED = "DATA_DEPENDENCY_SSA_UNPROTECTED"
KEY_NON_SSA_UNPROTECTED = "DATA_DEPENDENCY_UNPROTECTED"
KEY_INPUT = "DATA_DEPENDENCY_INPUT"
KEY_INPUT_SSA = "DATA_DEPENDENCY_INPUT_SSA"
# endregion
###################################################################################
###################################################################################
# region Debug
###################################################################################
###################################################################################
def pprint_dependency(context):
print('#### SSA ####')
context = context.context
for k, values in context[KEY_SSA].items():
print('{} ({}):'.format(k, id(k)))
for v in values:
print('\t- {}'.format(v))
print('#### NON SSA ####')
for k, values in context[KEY_NON_SSA].items():
print('{} ({}):'.format(k, hex(id(k))))
for v in values:
print('\t- {} ({})'.format(v, hex(id(v))))
# endregion
###################################################################################
###################################################################################
# region Analyses
###################################################################################
###################################################################################
def compute_dependency(slither): def compute_dependency(slither):
slither.context[KEY_INPUT] = set() slither.context[KEY_INPUT] = set()

@ -5,25 +5,29 @@ import logging
from slither.core.children.child_function import ChildFunction from slither.core.children.child_function import ChildFunction
from slither.core.declarations import Contract from slither.core.declarations import Contract
from slither.core.declarations.solidity_variables import (SolidityFunction, from slither.core.declarations.solidity_variables import SolidityVariable
SolidityVariable)
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.state_variable import StateVariable
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.state_variable import StateVariable
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.slithir.convert import convert_expression from slither.slithir.convert import convert_expression
from slither.slithir.operations import (Balance, HighLevelCall, Index, from slither.slithir.operations import (Balance, HighLevelCall, Index,
InternalCall, Length, LibraryCall, InternalCall, Length, LibraryCall,
LowLevelCall, Member, LowLevelCall, Member,
OperationWithLValue, SolidityCall, Phi, PhiCallback) OperationWithLValue, Phi, PhiCallback,
from slither.slithir.variables import (Constant, ReferenceVariable, SolidityCall)
TemporaryVariable, TupleVariable, StateIRVariable, LocalIRVariable) from slither.slithir.variables import (Constant, LocalIRVariable,
from slither.visitors.expression.expression_printer import ExpressionPrinter ReferenceVariable, StateIRVariable,
from slither.visitors.expression.read_var import ReadVar TemporaryVariable, TupleVariable)
from slither.visitors.expression.write_var import WriteVar
logger = logging.getLogger("Node") logger = logging.getLogger("Node")
###################################################################################
###################################################################################
# region NodeType
###################################################################################
###################################################################################
class NodeType: class NodeType:
ENTRYPOINT = 0x0 # no expression ENTRYPOINT = 0x0 # no expression
@ -89,10 +93,22 @@ class NodeType:
return 'END_LOOP' return 'END_LOOP'
return 'Unknown type {}'.format(hex(t)) return 'Unknown type {}'.format(hex(t))
# endregion
###################################################################################
###################################################################################
# region Utils
###################################################################################
###################################################################################
def link_nodes(n1, n2): def link_nodes(n1, n2):
n1.add_son(n2) n1.add_son(n2)
n2.add_father(n1) n2.add_father(n1)
# endregion
class Node(SourceMapping, ChildFunction): class Node(SourceMapping, ChildFunction):
""" """
Node class Node class
@ -159,68 +175,11 @@ class Node(SourceMapping, ChildFunction):
self._expression_vars_read = [] self._expression_vars_read = []
self._expression_calls = [] self._expression_calls = []
###################################################################################
@property ###################################################################################
def dominators(self): # region General's properties
''' ###################################################################################
Returns: ###################################################################################
set(Node)
'''
return self._dominators
@property
def immediate_dominator(self):
'''
Returns:
Node or None
'''
return self._immediate_dominator
@property
def dominance_frontier(self):
'''
Returns:
set(Node)
'''
return self._dominance_frontier
@property
def dominator_successors(self):
return self._dom_successors
@dominators.setter
def dominators(self, dom):
self._dominators = dom
@immediate_dominator.setter
def immediate_dominator(self, idom):
self._immediate_dominator = idom
@dominance_frontier.setter
def dominance_frontier(self, dom):
self._dominance_frontier = dom
@property
def phi_origins_local_variables(self):
return self._phi_origins_local_variables
@property
def phi_origins_state_variables(self):
return self._phi_origins_state_variables
def add_phi_origin_local_variable(self, variable, node):
if variable.name not in self._phi_origins_local_variables:
self._phi_origins_local_variables[variable.name] = (variable, set())
(v, nodes) = self._phi_origins_local_variables[variable.name]
assert v == variable
nodes.add(node)
def add_phi_origin_state_variable(self, variable, node):
if variable.canonical_name not in self._phi_origins_state_variables:
self._phi_origins_state_variables[variable.canonical_name] = (variable, set())
(v, nodes) = self._phi_origins_state_variables[variable.canonical_name]
assert v == variable
nodes.add(node)
@property @property
def slither(self): def slither(self):
@ -242,6 +201,13 @@ class Node(SourceMapping, ChildFunction):
def type(self, t): def type(self, t):
self._node_type = t self._node_type = t
# endregion
###################################################################################
###################################################################################
# region Variables
###################################################################################
###################################################################################
@property @property
def variables_read(self): def variables_read(self):
""" """
@ -291,8 +257,6 @@ class Node(SourceMapping, ChildFunction):
""" """
return list(self._ssa_local_vars_read) return list(self._ssa_local_vars_read)
@property @property
def variables_read_as_expression(self): def variables_read_as_expression(self):
return self._expression_vars_read return self._expression_vars_read
@ -347,6 +311,13 @@ class Node(SourceMapping, ChildFunction):
def variables_written_as_expression(self): def variables_written_as_expression(self):
return self._expression_vars_written return self._expression_vars_written
# endregion
###################################################################################
###################################################################################
# region Calls
###################################################################################
###################################################################################
@property @property
def internal_calls(self): def internal_calls(self):
""" """
@ -392,6 +363,13 @@ class Node(SourceMapping, ChildFunction):
def calls_as_expression(self): def calls_as_expression(self):
return list(self._expression_calls) return list(self._expression_calls)
# endregion
###################################################################################
###################################################################################
# region Expressions
###################################################################################
###################################################################################
@property @property
def expression(self): def expression(self):
""" """
@ -418,9 +396,12 @@ class Node(SourceMapping, ChildFunction):
""" """
return self._variable_declaration return self._variable_declaration
def __str__(self): # endregion
txt = NodeType.str(self._node_type) + ' '+ str(self.expression) ###################################################################################
return txt ###################################################################################
# region Summary information
###################################################################################
###################################################################################
def contains_require_or_assert(self): def contains_require_or_assert(self):
""" """
@ -449,6 +430,14 @@ class Node(SourceMapping, ChildFunction):
""" """
return self.contains_if(include_loop) or self.contains_require_or_assert() return self.contains_if(include_loop) or self.contains_require_or_assert()
# endregion
###################################################################################
###################################################################################
# region Graph
###################################################################################
###################################################################################
def add_father(self, father): def add_father(self, father):
""" Add a father node """ Add a father node
@ -474,7 +463,6 @@ class Node(SourceMapping, ChildFunction):
""" """
return list(self._fathers) return list(self._fathers)
def remove_father(self, father): def remove_father(self, father):
""" Remove the father node. Do nothing if the node is not a father """ Remove the father node. Do nothing if the node is not a father
@ -516,6 +504,13 @@ class Node(SourceMapping, ChildFunction):
""" """
return list(self._sons) return list(self._sons)
# endregion
###################################################################################
###################################################################################
# region SlithIR
###################################################################################
###################################################################################
@property @property
def irs(self): def irs(self):
""" Returns the slithIR representation """ Returns the slithIR representation
@ -559,6 +554,92 @@ class Node(SourceMapping, ChildFunction):
def _is_valid_slithir_var(var): def _is_valid_slithir_var(var):
return isinstance(var, (ReferenceVariable, TemporaryVariable, TupleVariable)) return isinstance(var, (ReferenceVariable, TemporaryVariable, TupleVariable))
# endregion
###################################################################################
###################################################################################
# region Dominators
###################################################################################
###################################################################################
@property
def dominators(self):
'''
Returns:
set(Node)
'''
return self._dominators
@property
def immediate_dominator(self):
'''
Returns:
Node or None
'''
return self._immediate_dominator
@property
def dominance_frontier(self):
'''
Returns:
set(Node)
'''
return self._dominance_frontier
@property
def dominator_successors(self):
return self._dom_successors
@dominators.setter
def dominators(self, dom):
self._dominators = dom
@immediate_dominator.setter
def immediate_dominator(self, idom):
self._immediate_dominator = idom
@dominance_frontier.setter
def dominance_frontier(self, dom):
self._dominance_frontier = dom
# endregion
###################################################################################
###################################################################################
# region Phi operation
###################################################################################
###################################################################################
@property
def phi_origins_local_variables(self):
return self._phi_origins_local_variables
@property
def phi_origins_state_variables(self):
return self._phi_origins_state_variables
def add_phi_origin_local_variable(self, variable, node):
if variable.name not in self._phi_origins_local_variables:
self._phi_origins_local_variables[variable.name] = (variable, set())
(v, nodes) = self._phi_origins_local_variables[variable.name]
assert v == variable
nodes.add(node)
def add_phi_origin_state_variable(self, variable, node):
if variable.canonical_name not in self._phi_origins_state_variables:
self._phi_origins_state_variables[variable.canonical_name] = (variable, set())
(v, nodes) = self._phi_origins_state_variables[variable.canonical_name]
assert v == variable
nodes.add(node)
# endregion
###################################################################################
###################################################################################
# region Analyses
###################################################################################
###################################################################################
def _find_read_write_call(self): def _find_read_write_call(self):
for ir in self.irs: for ir in self.irs:
@ -686,3 +767,17 @@ class Node(SourceMapping, ChildFunction):
self._vars_written += [v for v in vars_written if v not in self._vars_written] self._vars_written += [v for v in vars_written if v not in self._vars_written]
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)]
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)]
# endregion
###################################################################################
###################################################################################
# region Built in definitions
###################################################################################
###################################################################################
def __str__(self):
txt = NodeType.str(self._node_type) + ' '+ str(self.expression)
return txt
# endregion

@ -38,15 +38,11 @@ class Contract(ChildSlither, SourceMapping):
self._initial_state_variables = [] # ssa self._initial_state_variables = [] # ssa
def __eq__(self, other): ###################################################################################
if isinstance(other, str): ###################################################################################
return other == self.name # region General's properties
return NotImplemented ###################################################################################
###################################################################################
def __neq__(self, other):
if isinstance(other, str):
return other != self.name
return NotImplemented
@property @property
def name(self): def name(self):
@ -59,59 +55,114 @@ class Contract(ChildSlither, SourceMapping):
return self._id return self._id
@property @property
def inheritance(self): def contract_kind(self):
''' return self._kind
list(Contract): Inheritance list. Order: the first elem is the first father to be executed
''' # endregion
return list(self._inheritance) ###################################################################################
###################################################################################
# region Structures
###################################################################################
###################################################################################
@property @property
def immediate_inheritance(self): def structures(self):
''' '''
list(Contract): List of contracts immediately inherited from (fathers). Order: order of declaration. list(Structure): List of the structures
''' '''
return list(self._immediate_inheritance) return list(self._structures.values())
def structures_as_dict(self):
return self._structures
# endregion
###################################################################################
###################################################################################
# region Enums
###################################################################################
###################################################################################
@property @property
def inheritance_reverse(self): def enums(self):
return list(self._enums.values())
def enums_as_dict(self):
return self._enums
# endregion
###################################################################################
###################################################################################
# region Events
###################################################################################
###################################################################################
@property
def events(self):
''' '''
list(Contract): Inheritance list. Order: the last elem is the first father to be executed list(Event): List of the events
''' '''
return reversed(self._inheritance) return list(self._events.values())
def setInheritance(self, inheritance, immediate_inheritance, called_base_constructor_contracts): def events_as_dict(self):
self._inheritance = inheritance return self._events
self._immediate_inheritance = immediate_inheritance
self._explicit_base_constructor_calls = called_base_constructor_contracts # endregion
###################################################################################
###################################################################################
# region Using for
###################################################################################
###################################################################################
@property @property
def derived_contracts(self): def using_for(self):
return self._using_for
def reverse_using_for(self, name):
''' '''
list(Contract): Return the list of contracts derived from self Returns:
(list)
''' '''
candidates = self.slither.contracts return self._using_for[name]
return [c for c in candidates if self in c.inheritance]
# endregion
###################################################################################
###################################################################################
# region Variables
###################################################################################
###################################################################################
@property @property
def structures(self): def variables(self):
''' '''
list(Structure): List of the structures list(StateVariable): List of the state variables. Alias to self.state_variables
''' '''
return list(self._structures.values()) return list(self.state_variables)
def structures_as_dict(self): def variables_as_dict(self):
return self._structures return self._variables
@property @property
def enums(self): def state_variables(self):
return list(self._enums.values()) '''
list(StateVariable): List of the state variables.
def enums_as_dict(self): '''
return self._enums return list(self._variables.values())
@property
def slithir_variables(self):
'''
List all of the slithir variables (non SSA)
'''
slithir_variables = [f.slithir_variables for f in self.functions + self.modifiers]
slithir_variables = [item for sublist in slithir_variables for item in sublist]
return list(set(slithir_variables))
def modifiers_as_dict(self): # endregion
return self._modifiers ###################################################################################
###################################################################################
# region Constructors
###################################################################################
###################################################################################
@property @property
def constructor(self): def constructor(self):
@ -141,6 +192,25 @@ class Contract(ChildSlither, SourceMapping):
''' '''
return [func for func in self.functions if func.is_constructor] return [func for func in self.functions if func.is_constructor]
@property
def explicit_base_constructor_calls(self):
"""
list(Function): List of the base constructors called explicitly by this contract definition.
Base constructors called by any constructor definition will not be included.
Base constructors implicitly called by the contract definition (without
parenthesis) will not be included.
On "contract B is A(){..}" it returns the constructor of A
"""
return [c.constructor for c in self._explicit_base_constructor_calls if c.constructor]
# endregion
###################################################################################
###################################################################################
# region Functions and Modifiers
###################################################################################
###################################################################################
@property @property
def functions(self): def functions(self):
@ -149,6 +219,9 @@ class Contract(ChildSlither, SourceMapping):
''' '''
return list(self._functions.values()) return list(self._functions.values())
def functions_as_dict(self):
return self._functions
@property @property
def functions_inherited(self): def functions_inherited(self):
''' '''
@ -170,19 +243,6 @@ class Contract(ChildSlither, SourceMapping):
''' '''
return [f for f in self.functions if f.visibility in ['public', 'external']] return [f for f in self.functions if f.visibility in ['public', 'external']]
@property
def explicit_base_constructor_calls(self):
"""
list(Function): List of the base constructors called explicitly by this contract definition.
Base constructors called by any constructor definition will not be included.
Base constructors implicitly called by the contract definition (without
parenthesis) will not be included.
On "contract B is A(){..}" it returns the constructor of A
"""
return [c.constructor for c in self._explicit_base_constructor_calls if c.constructor]
@property @property
def modifiers(self): def modifiers(self):
''' '''
@ -190,6 +250,9 @@ class Contract(ChildSlither, SourceMapping):
''' '''
return list(self._modifiers.values()) return list(self._modifiers.values())
def modifiers_as_dict(self):
return self._modifiers
@property @property
def modifiers_inherited(self): def modifiers_inherited(self):
''' '''
@ -225,109 +288,53 @@ class Contract(ChildSlither, SourceMapping):
''' '''
return self.functions_not_inherited + self.modifiers_not_inherited return self.functions_not_inherited + self.modifiers_not_inherited
def get_functions_overridden_by(self, function): # endregion
''' ###################################################################################
Return the list of functions overriden by the function ###################################################################################
Args: # region Inheritance
(core.Function) ###################################################################################
Returns: ###################################################################################
list(core.Function)
'''
candidates = [c.functions_not_inherited for c in self.inheritance]
candidates = [candidate for sublist in candidates for candidate in sublist]
return [f for f in candidates if f.full_name == function.full_name]
@property
def all_functions_called(self):
'''
list(Function): List of functions reachable from the contract (include super)
'''
all_calls = [f.all_internal_calls() for f in self.functions + self.modifiers] + [self.functions + self.modifiers]
all_calls = [item for sublist in all_calls for item in sublist] + self.functions
all_calls = list(set(all_calls))
all_constructors = [c.constructor for c in self.inheritance]
all_constructors = list(set([c for c in all_constructors if c]))
all_calls = set(all_calls+all_constructors)
return [c for c in all_calls if isinstance(c, Function)]
@property
def all_state_variables_written(self):
'''
list(StateVariable): List all of the state variables written
'''
all_state_variables_written = [f.all_state_variables_written() for f in self.functions + self.modifiers]
all_state_variables_written = [item for sublist in all_state_variables_written for item in sublist]
return list(set(all_state_variables_written))
@property @property
def all_state_variables_read(self): def inheritance(self):
'''
list(StateVariable): List all of the state variables read
'''
all_state_variables_read = [f.all_state_variables_read() for f in self.functions + self.modifiers]
all_state_variables_read = [item for sublist in all_state_variables_read for item in sublist]
return list(set(all_state_variables_read))
@property
def slithir_variables(self):
'''
List all of the slithir variables (non SSA)
'''
slithir_variables = [f.slithir_variables for f in self.functions + self.modifiers]
slithir_variables = [item for sublist in slithir_variables for item in sublist]
return list(set(slithir_variables))
def functions_as_dict(self):
return self._functions
@property
def events(self):
''' '''
list(Event): List of the events list(Contract): Inheritance list. Order: the first elem is the first father to be executed
''' '''
return list(self._events.values()) return list(self._inheritance)
def events_as_dict(self):
return self._events
@property @property
def state_variables(self): def immediate_inheritance(self):
''' '''
list(StateVariable): List of the state variables. list(Contract): List of contracts immediately inherited from (fathers). Order: order of declaration.
''' '''
return list(self._variables.values()) return list(self._immediate_inheritance)
@property @property
def variables(self): def inheritance_reverse(self):
''' '''
list(StateVariable): List of the state variables. Alias to self.state_variables list(Contract): Inheritance list. Order: the last elem is the first father to be executed
''' '''
return list(self.state_variables) return reversed(self._inheritance)
def variables_as_dict(self): def setInheritance(self, inheritance, immediate_inheritance, called_base_constructor_contracts):
return self._variables self._inheritance = inheritance
self._immediate_inheritance = immediate_inheritance
self._explicit_base_constructor_calls = called_base_constructor_contracts
@property @property
def using_for(self): def derived_contracts(self):
return self._using_for
def reverse_using_for(self, name):
''' '''
Returns: list(Contract): Return the list of contracts derived from self
(list)
''' '''
return self._using_for[name] candidates = self.slither.contracts
return [c for c in candidates if self in c.inheritance]
@property
def contract_kind(self):
return self._kind
def __str__(self): # endregion
return self.name ###################################################################################
###################################################################################
# region Getters from/to object
###################################################################################
###################################################################################
def get_functions_reading_from_variable(self, variable): def get_functions_reading_from_variable(self, variable):
''' '''
@ -341,14 +348,6 @@ class Contract(ChildSlither, SourceMapping):
''' '''
return [f for f in self.functions if f.is_writing(variable)] return [f for f in self.functions if f.is_writing(variable)]
def is_signature_only(self):
""" Detect if the contract has only abstract functions
Returns:
bool: true if the function are abstract functions
"""
return all((not f.is_implemented) for f in self.functions)
def get_source_var_declaration(self, var): def get_source_var_declaration(self, var):
""" Return the source mapping where the variable is declared """ Return the source mapping where the variable is declared
@ -449,6 +448,85 @@ class Contract(ChildSlither, SourceMapping):
""" """
return next((e for e in self.enums if e.canonical_name == enum_name), None) return next((e for e in self.enums if e.canonical_name == enum_name), None)
def get_functions_overridden_by(self, function):
'''
Return the list of functions overriden by the function
Args:
(core.Function)
Returns:
list(core.Function)
'''
candidates = [c.functions_not_inherited for c in self.inheritance]
candidates = [candidate for sublist in candidates for candidate in sublist]
return [f for f in candidates if f.full_name == function.full_name]
# endregion
###################################################################################
###################################################################################
# region Recursive getters
###################################################################################
###################################################################################
@property
def all_functions_called(self):
'''
list(Function): List of functions reachable from the contract (include super)
'''
all_calls = [f.all_internal_calls() for f in self.functions + self.modifiers] + [self.functions + self.modifiers]
all_calls = [item for sublist in all_calls for item in sublist] + self.functions
all_calls = list(set(all_calls))
all_constructors = [c.constructor for c in self.inheritance]
all_constructors = list(set([c for c in all_constructors if c]))
all_calls = set(all_calls+all_constructors)
return [c for c in all_calls if isinstance(c, Function)]
@property
def all_state_variables_written(self):
'''
list(StateVariable): List all of the state variables written
'''
all_state_variables_written = [f.all_state_variables_written() for f in self.functions + self.modifiers]
all_state_variables_written = [item for sublist in all_state_variables_written for item in sublist]
return list(set(all_state_variables_written))
@property
def all_state_variables_read(self):
'''
list(StateVariable): List all of the state variables read
'''
all_state_variables_read = [f.all_state_variables_read() for f in self.functions + self.modifiers]
all_state_variables_read = [item for sublist in all_state_variables_read for item in sublist]
return list(set(all_state_variables_read))
# endregion
###################################################################################
###################################################################################
# region Summary information
###################################################################################
###################################################################################
def get_summary(self):
""" Return the function summary
Returns:
(str, list, list, list, list): (name, inheritance, variables, fuction summaries, modifier summaries)
"""
func_summaries = [f.get_summary() for f in self.functions]
modif_summaries = [f.get_summary() for f in self.modifiers]
return (self.name, [str(x) for x in self.inheritance], [str(x) for x in self.variables], func_summaries, modif_summaries)
def is_signature_only(self):
""" Detect if the contract has only abstract functions
Returns:
bool: true if the function are abstract functions
"""
return all((not f.is_implemented) for f in self.functions)
def is_erc20(self): def is_erc20(self):
""" """
Check if the contract is an erc20 token Check if the contract is an erc20 token
@ -462,16 +540,35 @@ class Contract(ChildSlither, SourceMapping):
'transferFrom(address,address,uint256)' in full_names and\ 'transferFrom(address,address,uint256)' in full_names and\
'approve(address,uint256)' in full_names 'approve(address,uint256)' in full_names
# endregion
###################################################################################
###################################################################################
# region Function analyses
###################################################################################
###################################################################################
def update_read_write_using_ssa(self): def update_read_write_using_ssa(self):
for function in self.functions + self.modifiers: for function in self.functions + self.modifiers:
function.update_read_write_using_ssa() function.update_read_write_using_ssa()
def get_summary(self): # endregion
""" Return the function summary ###################################################################################
###################################################################################
# region Built in definitions
###################################################################################
###################################################################################
Returns: def __eq__(self, other):
(str, list, list, list, list): (name, inheritance, variables, fuction summaries, modifier summaries) if isinstance(other, str):
""" return other == self.name
func_summaries = [f.get_summary() for f in self.functions] return NotImplemented
modif_summaries = [f.get_summary() for f in self.modifiers]
return (self.name, [str(x) for x in self.inheritance], [str(x) for x in self.variables], func_summaries, modif_summaries) def __neq__(self, other):
if isinstance(other, str):
return other != self.name
return NotImplemented
def __str__(self):
return self.name
# endregion

@ -82,28 +82,11 @@ class Function(ChildContract, SourceMapping):
self._reachable_from_nodes = set() self._reachable_from_nodes = set()
self._reachable_from_functions = set() self._reachable_from_functions = set()
@property ###################################################################################
def contains_assembly(self): ###################################################################################
return self._contains_assembly # region General properties
###################################################################################
@property ###################################################################################
def return_type(self):
"""
Return the list of return type
If no return, return None
"""
returns = self.returns
if returns:
return [r.type for r in returns]
return None
@property
def type(self):
"""
Return the list of return type
If no return, return None
"""
return self.return_type
@property @property
def name(self): def name(self):
@ -118,25 +101,35 @@ class Function(ChildContract, SourceMapping):
return self._name return self._name
@property @property
def nodes(self): def full_name(self):
""" """
list(Node): List of the nodes str: func_name(type1,type2)
Return the function signature without the return values
""" """
return list(self._nodes) name, parameters, _ = self.signature
return name+'('+','.join(parameters)+')'
@property @property
def entry_point(self): def is_constructor(self):
""" """
Node: Entry point of the function bool: True if the function is the constructor
""" """
return self._entry_point return self._is_constructor or self._name == self.contract.name
@property @property
def visibility(self): def contains_assembly(self):
""" return self._contains_assembly
str: Function visibility
""" @property
return self._visibility def slither(self):
return self.contract.slither
# endregion
###################################################################################
###################################################################################
# region Payable
###################################################################################
###################################################################################
@property @property
def payable(self): def payable(self):
@ -145,12 +138,19 @@ class Function(ChildContract, SourceMapping):
""" """
return self._payable return self._payable
# endregion
###################################################################################
###################################################################################
# region Visibility
###################################################################################
###################################################################################
@property @property
def is_constructor(self): def visibility(self):
""" """
bool: True if the function is the constructor str: Function visibility
""" """
return self._is_constructor or self._name == self.contract.name return self._visibility
@property @property
def view(self): def view(self):
@ -166,6 +166,13 @@ class Function(ChildContract, SourceMapping):
""" """
return self._pure return self._pure
# endregion
###################################################################################
###################################################################################
# region Function's body
###################################################################################
###################################################################################
@property @property
def is_implemented(self): def is_implemented(self):
""" """
@ -180,6 +187,36 @@ class Function(ChildContract, SourceMapping):
""" """
return self._is_empty return self._is_empty
# endregion
###################################################################################
###################################################################################
# region Nodes
###################################################################################
###################################################################################
@property
def nodes(self):
"""
list(Node): List of the nodes
"""
return list(self._nodes)
@property
def entry_point(self):
"""
Node: Entry point of the function
"""
return self._entry_point
# endregion
###################################################################################
###################################################################################
# region Parameters
###################################################################################
###################################################################################
@property @property
def parameters(self): def parameters(self):
""" """
@ -197,6 +234,32 @@ class Function(ChildContract, SourceMapping):
def add_parameter_ssa(self, var): def add_parameter_ssa(self, var):
self._parameters_ssa.append(var) self._parameters_ssa.append(var)
# endregion
###################################################################################
###################################################################################
# region Return values
###################################################################################
###################################################################################
@property
def return_type(self):
"""
Return the list of return type
If no return, return None
"""
returns = self.returns
if returns:
return [r.type for r in returns]
return None
@property
def type(self):
"""
Return the list of return type
If no return, return None
"""
return self.return_type
@property @property
def returns(self): def returns(self):
""" """
@ -214,6 +277,13 @@ class Function(ChildContract, SourceMapping):
def add_return_ssa(self, var): def add_return_ssa(self, var):
self._returns_ssa.append(var) self._returns_ssa.append(var)
# endregion
###################################################################################
###################################################################################
# region Modifiers
###################################################################################
###################################################################################
@property @property
def modifiers(self): def modifiers(self):
""" """
@ -232,8 +302,13 @@ class Function(ChildContract, SourceMapping):
# This is a list of contracts internally, so we convert it to a list of constructor functions. # This is a list of contracts internally, so we convert it to a list of constructor functions.
return [c.constructor_not_inherited for c in self._explicit_base_constructor_calls if c.constructor_not_inherited] return [c.constructor_not_inherited for c in self._explicit_base_constructor_calls if c.constructor_not_inherited]
def __str__(self):
return self._name # endregion
###################################################################################
###################################################################################
# region Variables
###################################################################################
###################################################################################
@property @property
def variables(self): def variables(self):
@ -311,6 +386,13 @@ class Function(ChildContract, SourceMapping):
return list(self._slithir_variables) return list(self._slithir_variables)
# endregion
###################################################################################
###################################################################################
# region Calls
###################################################################################
###################################################################################
@property @property
def internal_calls(self): def internal_calls(self):
""" """
@ -353,6 +435,13 @@ class Function(ChildContract, SourceMapping):
""" """
return list(self._external_calls_as_expressions) return list(self._external_calls_as_expressions)
# endregion
###################################################################################
###################################################################################
# region Expressions
###################################################################################
###################################################################################
@property @property
def calls_as_expressions(self): def calls_as_expressions(self):
return self._expression_calls return self._expression_calls
@ -368,6 +457,13 @@ class Function(ChildContract, SourceMapping):
self._expressions = expressions self._expressions = expressions
return self._expressions return self._expressions
# endregion
###################################################################################
###################################################################################
# region SlithIR
###################################################################################
###################################################################################
@property @property
def slithir_operations(self): def slithir_operations(self):
""" """
@ -379,6 +475,13 @@ class Function(ChildContract, SourceMapping):
self._slithir_operations = operations self._slithir_operations = operations
return self._slithir_operations return self._slithir_operations
# endregion
###################################################################################
###################################################################################
# region Signature
###################################################################################
###################################################################################
@property @property
def signature(self): def signature(self):
""" """
@ -396,14 +499,12 @@ class Function(ChildContract, SourceMapping):
name, parameters, returnVars = self.signature name, parameters, returnVars = self.signature
return name+'('+','.join(parameters)+') returns('+','.join(returnVars)+')' return name+'('+','.join(parameters)+') returns('+','.join(returnVars)+')'
@property # endregion
def full_name(self): ###################################################################################
""" ###################################################################################
str: func_name(type1,type2) # region Functions
Return the function signature without the return values ###################################################################################
""" ###################################################################################
name, parameters, _ = self.signature
return name+'('+','.join(parameters)+')'
@property @property
def functions_shadowed(self): def functions_shadowed(self):
@ -418,9 +519,12 @@ class Function(ChildContract, SourceMapping):
return [f for f in candidates if f.full_name == self.full_name] return [f for f in candidates if f.full_name == self.full_name]
@property # endregion
def slither(self): ###################################################################################
return self.contract.slither ###################################################################################
# region Reachable
###################################################################################
###################################################################################
@property @property
def reachable_from_nodes(self): def reachable_from_nodes(self):
@ -438,111 +542,12 @@ class Function(ChildContract, SourceMapping):
self._reachable_from_nodes.add(ReacheableNode(n, ir)) self._reachable_from_nodes.add(ReacheableNode(n, ir))
self._reachable_from_functions.add(n.function) self._reachable_from_functions.add(n.function)
def _filter_state_variables_written(self, expressions): # endregion
ret = [] ###################################################################################
for expression in expressions: ###################################################################################
if isinstance(expression, Identifier): # region Recursive getters
ret.append(expression) ###################################################################################
if isinstance(expression, UnaryOperation): ###################################################################################
ret.append(expression.expression)
if isinstance(expression, MemberAccess):
ret.append(expression.expression)
if isinstance(expression, IndexAccess):
ret.append(expression.expression_left)
return ret
def _analyze_read_write(self):
""" Compute variables read/written/...
"""
write_var = [x.variables_written_as_expression for x in self.nodes]
write_var = [x for x in write_var if x]
write_var = [item for sublist in write_var for item in sublist]
write_var = list(set(write_var))
# Remove dupplicate if they share the same string representation
write_var = [next(obj) for i, obj in groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x))]
self._expression_vars_written = write_var
write_var = [x.variables_written for x in self.nodes]
write_var = [x for x in write_var if x]
write_var = [item for sublist in write_var for item in sublist]
write_var = list(set(write_var))
# Remove dupplicate if they share the same string representation
write_var = [next(obj) for i, obj in\
groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x))]
self._vars_written = write_var
read_var = [x.variables_read_as_expression for x in self.nodes]
read_var = [x for x in read_var if x]
read_var = [item for sublist in read_var for item in sublist]
# Remove dupplicate if they share the same string representation
read_var = [next(obj) for i, obj in\
groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x))]
self._expression_vars_read = read_var
read_var = [x.variables_read for x in self.nodes]
read_var = [x for x in read_var if x]
read_var = [item for sublist in read_var for item in sublist]
# Remove dupplicate if they share the same string representation
read_var = [next(obj) for i, obj in\
groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x))]
self._vars_read = read_var
self._state_vars_written = [x for x in self.variables_written if\
isinstance(x, StateVariable)]
self._state_vars_read = [x for x in self.variables_read if\
isinstance(x, (StateVariable))]
self._solidity_vars_read = [x for x in self.variables_read if\
isinstance(x, (SolidityVariable))]
self._vars_read_or_written = self._vars_written + self._vars_read
slithir_variables = [x.slithir_variables for x in self.nodes]
slithir_variables = [x for x in slithir_variables if x]
self._slithir_variables = [item for sublist in slithir_variables for item in sublist]
def _analyze_calls(self):
calls = [x.calls_as_expression for x in self.nodes]
calls = [x for x in calls if x]
calls = [item for sublist in calls for item in sublist]
# Remove dupplicate if they share the same string representation
# TODO: check if groupby is still necessary here
calls = [next(obj) for i, obj in\
groupby(sorted(calls, key=lambda x: str(x)), lambda x: str(x))]
self._expression_calls = calls
internal_calls = [x.internal_calls for x in self.nodes]
internal_calls = [x for x in internal_calls if x]
internal_calls = [item for sublist in internal_calls for item in sublist]
internal_calls = [next(obj) for i, obj in
groupby(sorted(internal_calls, key=lambda x: str(x)), lambda x: str(x))]
self._internal_calls = internal_calls
self._solidity_calls = [c for c in internal_calls if isinstance(c, SolidityFunction)]
low_level_calls = [x.low_level_calls for x in self.nodes]
low_level_calls = [x for x in low_level_calls if x]
low_level_calls = [item for sublist in low_level_calls for item in sublist]
low_level_calls = [next(obj) for i, obj in
groupby(sorted(low_level_calls, key=lambda x: str(x)), lambda x: str(x))]
self._low_level_calls = low_level_calls
high_level_calls = [x.high_level_calls for x in self.nodes]
high_level_calls = [x for x in high_level_calls if x]
high_level_calls = [item for sublist in high_level_calls for item in sublist]
high_level_calls = [next(obj) for i, obj in
groupby(sorted(high_level_calls, key=lambda x: str(x)), lambda x: str(x))]
self._high_level_calls = high_level_calls
external_calls_as_expressions = [x.external_calls_as_expressions for x in self.nodes]
external_calls_as_expressions = [x for x in external_calls_as_expressions if x]
external_calls_as_expressions = [item for sublist in external_calls_as_expressions for item in sublist]
external_calls_as_expressions = [next(obj) for i, obj in
groupby(sorted(external_calls_as_expressions, key=lambda x: str(x)), lambda x: str(x))]
self._external_calls_as_expressions = external_calls_as_expressions
def _explore_functions(self, f_new_values): def _explore_functions(self, f_new_values):
values = f_new_values(self) values = f_new_values(self)
@ -698,49 +703,12 @@ class Function(ChildContract, SourceMapping):
lambda x: self._explore_func_nodes(x, self._solidity_variable_in_internal_calls)) lambda x: self._explore_func_nodes(x, self._solidity_variable_in_internal_calls))
return self._all_solidity_variables_used_as_args return self._all_solidity_variables_used_as_args
def is_reading(self, variable): # endregion
""" ###################################################################################
Check if the function reads the variable ###################################################################################
Args: # region Visitor
variable (Variable): ###################################################################################
Returns: ###################################################################################
bool: True if the variable is read
"""
return variable in self.variables_read
def is_reading_in_conditional_node(self, variable):
"""
Check if the function reads the variable in a IF node
Args:
variable (Variable):
Returns:
bool: True if the variable is read
"""
variables_read = [n.variables_read for n in self.nodes if n.contains_if()]
variables_read = [item for sublist in variables_read for item in sublist]
return variable in variables_read
def is_reading_in_require_or_assert(self, variable):
"""
Check if the function reads the variable in an require or assert
Args:
variable (Variable):
Returns:
bool: True if the variable is read
"""
variables_read = [n.variables_read for n in self.nodes if n.contains_require_or_assert()]
variables_read = [item for sublist in variables_read for item in sublist]
return variable in variables_read
def is_writing(self, variable):
"""
Check if the function writes the variable
Args:
variable (Variable):
Returns:
bool: True if the variable is written
"""
return variable in self.variables_written
def apply_visitor(self, Visitor): def apply_visitor(self, Visitor):
""" """
@ -754,6 +722,29 @@ class Function(ChildContract, SourceMapping):
v = [Visitor(e).result() for e in expressions] v = [Visitor(e).result() for e in expressions]
return [item for sublist in v for item in sublist] return [item for sublist in v for item in sublist]
# endregion
###################################################################################
###################################################################################
# region Getters from/to object
###################################################################################
###################################################################################
def get_local_variable_from_name(self, variable_name):
"""
Return a local variable from a name
Args:
varible_name (str): name of the variable
Returns:
LocalVariable
"""
return next((v for v in self.variables if v.name == variable_name), None)
# endregion
###################################################################################
###################################################################################
# region Export
###################################################################################
###################################################################################
def cfg_to_dot(self, filename): def cfg_to_dot(self, filename):
""" """
@ -812,6 +803,57 @@ class Function(ChildContract, SourceMapping):
f.write("}\n") f.write("}\n")
# endregion
###################################################################################
###################################################################################
# region Summary information
###################################################################################
###################################################################################
def is_reading(self, variable):
"""
Check if the function reads the variable
Args:
variable (Variable):
Returns:
bool: True if the variable is read
"""
return variable in self.variables_read
def is_reading_in_conditional_node(self, variable):
"""
Check if the function reads the variable in a IF node
Args:
variable (Variable):
Returns:
bool: True if the variable is read
"""
variables_read = [n.variables_read for n in self.nodes if n.contains_if()]
variables_read = [item for sublist in variables_read for item in sublist]
return variable in variables_read
def is_reading_in_require_or_assert(self, variable):
"""
Check if the function reads the variable in an require or assert
Args:
variable (Variable):
Returns:
bool: True if the variable is read
"""
variables_read = [n.variables_read for n in self.nodes if n.contains_require_or_assert()]
variables_read = [item for sublist in variables_read for item in sublist]
return variable in variables_read
def is_writing(self, variable):
"""
Check if the function writes the variable
Args:
variable (Variable):
Returns:
bool: True if the variable is written
"""
return variable in self.variables_written
def get_summary(self): def get_summary(self):
""" """
Return the function summary Return the function summary
@ -844,12 +886,128 @@ class Function(ChildContract, SourceMapping):
args_vars = self.all_solidity_variables_used_as_args() args_vars = self.all_solidity_variables_used_as_args()
return SolidityVariableComposed('msg.sender') in conditional_vars + args_vars return SolidityVariableComposed('msg.sender') in conditional_vars + args_vars
def get_local_variable_from_name(self, variable_name): # endregion
""" ###################################################################################
Return a local variable from a name ###################################################################################
Args: # region Analyses
varible_name (str): name of the variable ###################################################################################
Returns: ###################################################################################
LocalVariable
def _filter_state_variables_written(self, expressions):
ret = []
for expression in expressions:
if isinstance(expression, Identifier):
ret.append(expression)
if isinstance(expression, UnaryOperation):
ret.append(expression.expression)
if isinstance(expression, MemberAccess):
ret.append(expression.expression)
if isinstance(expression, IndexAccess):
ret.append(expression.expression_left)
return ret
def _analyze_read_write(self):
""" Compute variables read/written/...
""" """
return next((v for v in self.variables if v.name == variable_name), None) write_var = [x.variables_written_as_expression for x in self.nodes]
write_var = [x for x in write_var if x]
write_var = [item for sublist in write_var for item in sublist]
write_var = list(set(write_var))
# Remove dupplicate if they share the same string representation
write_var = [next(obj) for i, obj in groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x))]
self._expression_vars_written = write_var
write_var = [x.variables_written for x in self.nodes]
write_var = [x for x in write_var if x]
write_var = [item for sublist in write_var for item in sublist]
write_var = list(set(write_var))
# Remove dupplicate if they share the same string representation
write_var = [next(obj) for i, obj in\
groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x))]
self._vars_written = write_var
read_var = [x.variables_read_as_expression for x in self.nodes]
read_var = [x for x in read_var if x]
read_var = [item for sublist in read_var for item in sublist]
# Remove dupplicate if they share the same string representation
read_var = [next(obj) for i, obj in\
groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x))]
self._expression_vars_read = read_var
read_var = [x.variables_read for x in self.nodes]
read_var = [x for x in read_var if x]
read_var = [item for sublist in read_var for item in sublist]
# Remove dupplicate if they share the same string representation
read_var = [next(obj) for i, obj in\
groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x))]
self._vars_read = read_var
self._state_vars_written = [x for x in self.variables_written if\
isinstance(x, StateVariable)]
self._state_vars_read = [x for x in self.variables_read if\
isinstance(x, (StateVariable))]
self._solidity_vars_read = [x for x in self.variables_read if\
isinstance(x, (SolidityVariable))]
self._vars_read_or_written = self._vars_written + self._vars_read
slithir_variables = [x.slithir_variables for x in self.nodes]
slithir_variables = [x for x in slithir_variables if x]
self._slithir_variables = [item for sublist in slithir_variables for item in sublist]
def _analyze_calls(self):
calls = [x.calls_as_expression for x in self.nodes]
calls = [x for x in calls if x]
calls = [item for sublist in calls for item in sublist]
# Remove dupplicate if they share the same string representation
# TODO: check if groupby is still necessary here
calls = [next(obj) for i, obj in\
groupby(sorted(calls, key=lambda x: str(x)), lambda x: str(x))]
self._expression_calls = calls
internal_calls = [x.internal_calls for x in self.nodes]
internal_calls = [x for x in internal_calls if x]
internal_calls = [item for sublist in internal_calls for item in sublist]
internal_calls = [next(obj) for i, obj in
groupby(sorted(internal_calls, key=lambda x: str(x)), lambda x: str(x))]
self._internal_calls = internal_calls
self._solidity_calls = [c for c in internal_calls if isinstance(c, SolidityFunction)]
low_level_calls = [x.low_level_calls for x in self.nodes]
low_level_calls = [x for x in low_level_calls if x]
low_level_calls = [item for sublist in low_level_calls for item in sublist]
low_level_calls = [next(obj) for i, obj in
groupby(sorted(low_level_calls, key=lambda x: str(x)), lambda x: str(x))]
self._low_level_calls = low_level_calls
high_level_calls = [x.high_level_calls for x in self.nodes]
high_level_calls = [x for x in high_level_calls if x]
high_level_calls = [item for sublist in high_level_calls for item in sublist]
high_level_calls = [next(obj) for i, obj in
groupby(sorted(high_level_calls, key=lambda x: str(x)), lambda x: str(x))]
self._high_level_calls = high_level_calls
external_calls_as_expressions = [x.external_calls_as_expressions for x in self.nodes]
external_calls_as_expressions = [x for x in external_calls_as_expressions if x]
external_calls_as_expressions = [item for sublist in external_calls_as_expressions for item in sublist]
external_calls_as_expressions = [next(obj) for i, obj in
groupby(sorted(external_calls_as_expressions, key=lambda x: str(x)), lambda x: str(x))]
self._external_calls_as_expressions = external_calls_as_expressions
# endregion
###################################################################################
###################################################################################
# region Built in definitions
###################################################################################
###################################################################################
def __str__(self):
return self._name
# endregion

@ -1,6 +1,5 @@
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
from slither.core.variables.function_type_variable import FunctionTypeVariable from slither.core.variables.function_type_variable import FunctionTypeVariable
from slither.core.expressions.expression import Expression
class FunctionType(Type): class FunctionType(Type):

@ -1,20 +1,24 @@
import logging import logging
from slither.core.declarations import (Contract, Enum, Event, SolidityFunction, from slither.core.declarations import (Contract, Enum, Event, Function,
Structure, SolidityVariableComposed, Function, SolidityVariable) SolidityFunction, SolidityVariable,
from slither.core.expressions import Identifier, Literal, TupleExpression SolidityVariableComposed, Structure)
from slither.core.solidity_types import ElementaryType, UserDefinedType, MappingType, ArrayType, FunctionType from slither.core.expressions import Identifier, Literal
from slither.core.solidity_types import (ArrayType, ElementaryType,
FunctionType, MappingType,
UserDefinedType)
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.slithir.operations import (Assignment, Binary, BinaryType, Call, from slither.slithir.operations import (Assignment, Balance, Binary,
Condition, Delete, EventCall, BinaryType, Call, Condition, Delete,
HighLevelCall, Index, InitArray, EventCall, HighLevelCall, Index,
InternalCall, InternalDynamicCall, LibraryCall, InitArray, InternalCall,
LowLevelCall, Member, NewArray, InternalDynamicCall, Length,
NewContract, NewElementaryType, LibraryCall, LowLevelCall, Member,
NewStructure, OperationWithLValue, NewArray, NewContract,
Push, Return, Send, SolidityCall, NewElementaryType, NewStructure,
Transfer, TypeConversion, Unary, OperationWithLValue, Push, Return,
Unpack, Length, Balance) Send, SolidityCall, Transfer,
TypeConversion, Unary, Unpack)
from slither.slithir.tmp_operations.argument import Argument, ArgumentType from slither.slithir.tmp_operations.argument import Argument, ArgumentType
from slither.slithir.tmp_operations.tmp_call import TmpCall from slither.slithir.tmp_operations.tmp_call import TmpCall
from slither.slithir.tmp_operations.tmp_new_array import TmpNewArray from slither.slithir.tmp_operations.tmp_new_array import TmpNewArray
@ -23,11 +27,47 @@ from slither.slithir.tmp_operations.tmp_new_elementary_type import \
TmpNewElementaryType TmpNewElementaryType
from slither.slithir.tmp_operations.tmp_new_structure import TmpNewStructure from slither.slithir.tmp_operations.tmp_new_structure import TmpNewStructure
from slither.slithir.variables import (Constant, ReferenceVariable, from slither.slithir.variables import (Constant, ReferenceVariable,
TemporaryVariable, TupleVariable) TemporaryVariable)
from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR
logger = logging.getLogger('ConvertToIR') logger = logging.getLogger('ConvertToIR')
def convert_expression(expression, node):
# handle standlone expression
# such as return true;
from slither.core.cfg.node import NodeType
if isinstance(expression, Literal) and node.type in [NodeType.IF, NodeType.IFLOOP]:
result = [Condition(Constant(expression.value))]
return result
if isinstance(expression, Identifier) and node.type in [NodeType.IF, NodeType.IFLOOP]:
result = [Condition(expression.value)]
return result
visitor = ExpressionToSlithIR(expression, node)
result = visitor.result()
result = apply_ir_heuristics(result, node)
if result:
if node.type in [NodeType.IF, NodeType.IFLOOP]:
assert isinstance(result[-1], (OperationWithLValue))
result.append(Condition(result[-1].lvalue))
elif node.type == NodeType.RETURN:
# May return None
if isinstance(result[-1], (OperationWithLValue)):
result.append(Return(result[-1].lvalue))
return result
###################################################################################
###################################################################################
# region Helpers
###################################################################################
###################################################################################
def is_value(ins): def is_value(ins):
if isinstance(ins, TmpCall): if isinstance(ins, TmpCall):
if isinstance(ins.ori, Member): if isinstance(ins.ori, Member):
@ -42,7 +82,65 @@ def is_gas(ins):
return True return True
return False return False
def get_sig(ir):
'''
Return a list of potential signature
It is a list, as Constant variables can be converted to int256
Args:
ir (slithIR.operation)
Returns:
list(str)
'''
sig = '{}({})'
name = ir.function_name
# list of list of arguments
argss = [[]]
for arg in ir.arguments:
if isinstance(arg, (list,)):
type_arg = '{}[{}]'.format(get_type(arg[0].type), len(arg))
elif isinstance(arg, Function):
type_arg = arg.signature_str
else:
type_arg = get_type(arg.type)
if isinstance(arg, Constant) and arg.type == ElementaryType('uint256'):
# If it is a constant
# We dupplicate the existing list
# And we add uint256 and int256 cases
# There is no potential collision, as the compiler
# Prevent it with a
# "not unique after argument-dependent loopkup" issue
argss_new = [list(args) for args in argss]
for args in argss:
args.append(str(ElementaryType('uint256')))
for args in argss_new:
args.append(str(ElementaryType('int256')))
argss = argss + argss_new
else:
for args in argss:
args.append(type_arg)
return [sig.format(name, ','.join(args)) for args in argss]
def is_temporary(ins):
return isinstance(ins, (Argument,
TmpNewElementaryType,
TmpNewContract,
TmpNewArray,
TmpNewStructure))
# endregion
###################################################################################
###################################################################################
# region Calls modification
###################################################################################
###################################################################################
def integrate_value_gas(result): def integrate_value_gas(result):
'''
Integrate value and gas temporary arguments to call instruction
'''
was_changed = True was_changed = True
calls = [] calls = []
@ -110,7 +208,17 @@ def integrate_value_gas(result):
return result return result
def propage_type_and_convert_call(result, node): # endregion
###################################################################################
###################################################################################
# region Calls modification and Type propagation
###################################################################################
###################################################################################
def propagate_type_and_convert_call(result, node):
'''
Propagate the types variables and convert tmp call to real call operation
'''
calls_value = {} calls_value = {}
calls_gas = {} calls_gas = {}
@ -179,133 +287,343 @@ def propage_type_and_convert_call(result, node):
idx = idx +1 idx = idx +1
return result return result
def convert_to_low_level(ir): def propagate_types(ir, node):
""" # propagate the type
Convert to a transfer/send/or low level call using_for = node.function.contract.using_for
The funciton assume to receive a correct IR if isinstance(ir, OperationWithLValue):
The checks must be done by the caller # Force assignment in case of missing previous correct type
if not ir.lvalue.type:
Additionally convert abi... to solidityfunction if isinstance(ir, Assignment):
""" ir.lvalue.set_type(ir.rvalue.type)
if ir.function_name == 'transfer': elif isinstance(ir, Binary):
assert len(ir.arguments) == 1 if BinaryType.return_bool(ir.type):
ir = Transfer(ir.destination, ir.arguments[0]) ir.lvalue.set_type(ElementaryType('bool'))
return ir else:
elif ir.function_name == 'send': ir.lvalue.set_type(ir.variable_left.type)
assert len(ir.arguments) == 1 elif isinstance(ir, Delete):
ir = Send(ir.destination, ir.arguments[0], ir.lvalue) # nothing to propagate
ir.lvalue.set_type(ElementaryType('bool')) pass
return ir elif isinstance(ir, LibraryCall):
elif ir.destination.name == 'abi' and ir.function_name in ['encode', return convert_type_library_call(ir, ir.destination)
'encodePacked', elif isinstance(ir, HighLevelCall):
'encodeWithSelector', t = ir.destination.type
'encodeWithSignature',
'decode']:
call = SolidityFunction('abi.{}()'.format(ir.function_name))
new_ir = SolidityCall(call, ir.nbr_arguments, ir.lvalue, ir.type_call)
new_ir.arguments = ir.arguments
if isinstance(call.return_type, list) and len(call.return_type) == 1:
new_ir.lvalue.set_type(call.return_type[0])
else:
new_ir.lvalue.set_type(call.return_type)
return new_ir
elif ir.function_name in ['call',
'delegatecall',
'callcode',
'staticcall']:
new_ir = LowLevelCall(ir.destination,
ir.function_name,
ir.nbr_arguments,
ir.lvalue,
ir.type_call)
new_ir.call_gas = ir.call_gas
new_ir.call_value = ir.call_value
new_ir.arguments = ir.arguments
new_ir.lvalue.set_type(ElementaryType('bool'))
return new_ir
logger.error('Incorrect conversion to low level {}'.format(ir))
exit(-1)
def convert_to_push(ir, node):
"""
Convert a call to a PUSH operaiton
The funciton assume to receive a correct IR
The checks must be done by the caller
May necessitate to create an intermediate operation (InitArray)
Necessitate to return the lenght (see push documentation)
As a result, the function return may return a list
"""
lvalue = ir.lvalue
if isinstance(ir.arguments[0], list):
ret = []
val = TemporaryVariable(node)
operation = InitArray(ir.arguments[0], val)
ret.append(operation)
ir = Push(ir.destination, val)
length = Literal(len(operation.init_values))
t = operation.init_values[0].type
ir.lvalue.set_type(ArrayType(t, length))
ret.append(ir)
if lvalue:
length = Length(ir.array, lvalue)
length.lvalue.points_to = ir.lvalue
ret.append(length)
return ret
ir = Push(ir.destination, ir.arguments[0])
if lvalue:
ret = []
ret.append(ir)
length = Length(ir.array, lvalue) # Temporary operation (they are removed later)
length.lvalue.points_to = ir.lvalue if t is None:
ret.append(length) return
return ret
return ir # convert library
if t in using_for or '*' in using_for:
new_ir = convert_to_library(ir, node, using_for)
if new_ir:
return new_ir
def look_for_library(contract, ir, node, using_for, t): if isinstance(t, UserDefinedType):
for destination in using_for[t]: # UserdefinedType
lib_contract = contract.slither.get_contract_from_name(str(destination)) t_type = t.type
if lib_contract: if isinstance(t_type, Contract):
lib_call = LibraryCall(lib_contract, contract = node.slither.get_contract_from_name(t_type.name)
ir.function_name, return convert_type_of_high_level_call(ir, contract)
ir.nbr_arguments,
ir.lvalue,
ir.type_call)
lib_call.call_gas = ir.call_gas
lib_call.arguments = [ir.destination] + ir.arguments
new_ir = convert_type_library_call(lib_call, lib_contract)
if new_ir:
new_ir.set_node(ir.node)
return new_ir
return None
def convert_to_library(ir, node, using_for): # Convert HighLevelCall to LowLevelCall
contract = node.function.contract if isinstance(t, ElementaryType) and t.name == 'address':
t = ir.destination.type if ir.destination.name == 'this':
return convert_type_of_high_level_call(ir, node.function.contract)
return convert_to_low_level(ir)
if t in using_for: # Convert push operations
new_ir = look_for_library(contract, ir, node, using_for, t) # May need to insert a new operation
if new_ir: # Which leads to return a list of operation
return new_ir if isinstance(t, ArrayType):
if ir.function_name == 'push' and len(ir.arguments) == 1:
return convert_to_push(ir, node)
if '*' in using_for: elif isinstance(ir, Index):
new_ir = look_for_library(contract, ir, node, using_for, '*') if isinstance(ir.variable_left.type, MappingType):
if new_ir: ir.lvalue.set_type(ir.variable_left.type.type_to)
return new_ir elif isinstance(ir.variable_left.type, ArrayType):
ir.lvalue.set_type(ir.variable_left.type.type)
elif isinstance(ir, InitArray):
length = len(ir.init_values)
t = ir.init_values[0].type
ir.lvalue.set_type(ArrayType(t, length))
elif isinstance(ir, InternalCall):
# if its not a tuple, return a singleton
return_type = ir.function.return_type
if return_type:
if len(return_type) == 1:
ir.lvalue.set_type(return_type[0])
elif len(return_type)>1:
ir.lvalue.set_type(return_type)
else:
ir.lvalue = None
elif isinstance(ir, InternalDynamicCall):
# if its not a tuple, return a singleton
return_type = ir.function_type.return_type
if return_type:
if len(return_type) == 1:
ir.lvalue.set_type(return_type[0])
else:
ir.lvalue.set_type(return_type)
else:
ir.lvalue = None
elif isinstance(ir, LowLevelCall):
# Call are not yet converted
# This should not happen
assert False
elif isinstance(ir, Member):
# TODO we should convert the reference to a temporary if the member is a length or a balance
if ir.variable_right == 'length' and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, (ElementaryType, ArrayType)):
length = Length(ir.variable_left, ir.lvalue)
length.lvalue.points_to = ir.variable_left
return length
if ir.variable_right == 'balance'and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, ElementaryType):
return Balance(ir.variable_left, ir.lvalue)
left = ir.variable_left
if isinstance(left, (Variable, SolidityVariable)):
t = ir.variable_left.type
elif isinstance(left, (Contract, Enum, Structure)):
t = UserDefinedType(left)
# can be None due to temporary operation
if t:
if isinstance(t, UserDefinedType):
# UserdefinedType
type_t = t.type
if isinstance(type_t, Enum):
ir.lvalue.set_type(t)
elif isinstance(type_t, Structure):
elems = type_t.elems
for elem in elems:
if elem == ir.variable_right:
ir.lvalue.set_type(elems[elem].type)
else:
assert isinstance(type_t, Contract)
elif isinstance(ir, NewArray):
ir.lvalue.set_type(ir.array_type)
elif isinstance(ir, NewContract):
contract = node.slither.get_contract_from_name(ir.contract_name)
ir.lvalue.set_type(UserDefinedType(contract))
elif isinstance(ir, NewElementaryType):
ir.lvalue.set_type(ir.type)
elif isinstance(ir, NewStructure):
ir.lvalue.set_type(UserDefinedType(ir.structure))
elif isinstance(ir, Push):
# No change required
pass
elif isinstance(ir, Send):
ir.lvalue.set_type(ElementaryType('bool'))
elif isinstance(ir, SolidityCall):
return_type = ir.function.return_type
if len(return_type) == 1:
ir.lvalue.set_type(return_type[0])
elif len(return_type)>1:
ir.lvalue.set_type(return_type)
elif isinstance(ir, TypeConversion):
ir.lvalue.set_type(ir.type)
elif isinstance(ir, Unary):
ir.lvalue.set_type(ir.rvalue.type)
elif isinstance(ir, Unpack):
types = ir.tuple.type.type
idx = ir.index
t = types[idx]
ir.lvalue.set_type(t)
elif isinstance(ir, (Argument, TmpCall, TmpNewArray, TmpNewContract, TmpNewStructure, TmpNewElementaryType)):
# temporary operation; they will be removed
pass
else:
logger.error('Not handling {} during type propgation'.format(type(ir)))
exit(-1)
def extract_tmp_call(ins):
assert isinstance(ins, TmpCall)
if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType):
call = InternalDynamicCall(ins.lvalue, ins.called, ins.called.type)
call.call_id = ins.call_id
return call
if isinstance(ins.ori, Member):
if isinstance(ins.ori.variable_left, Contract):
st = ins.ori.variable_left.get_structure_from_name(ins.ori.variable_right)
if st:
op = NewStructure(st, ins.lvalue)
op.call_id = ins.call_id
return op
libcall = LibraryCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call)
libcall.call_id = ins.call_id
return libcall
msgcall = HighLevelCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call)
msgcall.call_id = ins.call_id
return msgcall
if isinstance(ins.ori, TmpCall):
r = extract_tmp_call(ins.ori)
return r
if isinstance(ins.called, SolidityVariableComposed):
if str(ins.called) == 'block.blockhash':
ins.called = SolidityFunction('blockhash(uint256)')
elif str(ins.called) == 'this.balance':
return SolidityCall(SolidityFunction('this.balance()'), ins.nbr_arguments, ins.lvalue, ins.type_call)
if isinstance(ins.called, SolidityFunction):
return SolidityCall(ins.called, ins.nbr_arguments, ins.lvalue, ins.type_call)
if isinstance(ins.ori, TmpNewElementaryType):
return NewElementaryType(ins.ori.type, ins.lvalue)
if isinstance(ins.ori, TmpNewContract):
op = NewContract(Constant(ins.ori.contract_name), ins.lvalue)
op.call_id = ins.call_id
return op
if isinstance(ins.ori, TmpNewArray):
return NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue)
if isinstance(ins.called, Structure):
op = NewStructure(ins.called, ins.lvalue)
op.call_id = ins.call_id
return op
if isinstance(ins.called, Event):
return EventCall(ins.called.name)
raise Exception('Not extracted {} {}'.format(type(ins.called), ins))
# endregion
###################################################################################
###################################################################################
# region Conversion operations
###################################################################################
###################################################################################
def convert_to_low_level(ir):
"""
Convert to a transfer/send/or low level call
The funciton assume to receive a correct IR
The checks must be done by the caller
Additionally convert abi... to solidityfunction
"""
if ir.function_name == 'transfer':
assert len(ir.arguments) == 1
ir = Transfer(ir.destination, ir.arguments[0])
return ir
elif ir.function_name == 'send':
assert len(ir.arguments) == 1
ir = Send(ir.destination, ir.arguments[0], ir.lvalue)
ir.lvalue.set_type(ElementaryType('bool'))
return ir
elif ir.destination.name == 'abi' and ir.function_name in ['encode',
'encodePacked',
'encodeWithSelector',
'encodeWithSignature',
'decode']:
call = SolidityFunction('abi.{}()'.format(ir.function_name))
new_ir = SolidityCall(call, ir.nbr_arguments, ir.lvalue, ir.type_call)
new_ir.arguments = ir.arguments
if isinstance(call.return_type, list) and len(call.return_type) == 1:
new_ir.lvalue.set_type(call.return_type[0])
else:
new_ir.lvalue.set_type(call.return_type)
return new_ir
elif ir.function_name in ['call',
'delegatecall',
'callcode',
'staticcall']:
new_ir = LowLevelCall(ir.destination,
ir.function_name,
ir.nbr_arguments,
ir.lvalue,
ir.type_call)
new_ir.call_gas = ir.call_gas
new_ir.call_value = ir.call_value
new_ir.arguments = ir.arguments
new_ir.lvalue.set_type(ElementaryType('bool'))
return new_ir
logger.error('Incorrect conversion to low level {}'.format(ir))
exit(-1)
def convert_to_push(ir, node):
"""
Convert a call to a PUSH operaiton
The funciton assume to receive a correct IR
The checks must be done by the caller
May necessitate to create an intermediate operation (InitArray)
Necessitate to return the lenght (see push documentation)
As a result, the function return may return a list
"""
lvalue = ir.lvalue
if isinstance(ir.arguments[0], list):
ret = []
val = TemporaryVariable(node)
operation = InitArray(ir.arguments[0], val)
ret.append(operation)
ir = Push(ir.destination, val)
length = Literal(len(operation.init_values))
t = operation.init_values[0].type
ir.lvalue.set_type(ArrayType(t, length))
ret.append(ir)
if lvalue:
length = Length(ir.array, lvalue)
length.lvalue.points_to = ir.lvalue
ret.append(length)
return ret
ir = Push(ir.destination, ir.arguments[0])
if lvalue:
ret = []
ret.append(ir)
length = Length(ir.array, lvalue)
length.lvalue.points_to = ir.lvalue
ret.append(length)
return ret
return ir
def look_for_library(contract, ir, node, using_for, t):
for destination in using_for[t]:
lib_contract = contract.slither.get_contract_from_name(str(destination))
if lib_contract:
lib_call = LibraryCall(lib_contract,
ir.function_name,
ir.nbr_arguments,
ir.lvalue,
ir.type_call)
lib_call.call_gas = ir.call_gas
lib_call.arguments = [ir.destination] + ir.arguments
new_ir = convert_type_library_call(lib_call, lib_contract)
if new_ir:
new_ir.set_node(ir.node)
return new_ir
return None
def convert_to_library(ir, node, using_for):
contract = node.function.contract
t = ir.destination.type
if t in using_for:
new_ir = look_for_library(contract, ir, node, using_for, t)
if new_ir:
return new_ir
if '*' in using_for:
new_ir = look_for_library(contract, ir, node, using_for, '*')
if new_ir:
return new_ir
return None return None
@ -316,47 +634,8 @@ def get_type(t):
""" """
if isinstance(t, UserDefinedType): if isinstance(t, UserDefinedType):
if isinstance(t.type, Contract): if isinstance(t.type, Contract):
return 'address' return 'address'
return str(t) return str(t)
def get_sig(ir):
'''
Return a list of potential signature
It is a list, as Constant variables can be converted to int256
Args:
ir (slithIR.operation)
Returns:
list(str)
'''
sig = '{}({})'
name = ir.function_name
# list of list of arguments
argss = [[]]
for arg in ir.arguments:
if isinstance(arg, (list,)):
type_arg = '{}[{}]'.format(get_type(arg[0].type), len(arg))
elif isinstance(arg, Function):
type_arg = arg.signature_str
else:
type_arg = get_type(arg.type)
if isinstance(arg, Constant) and arg.type == ElementaryType('uint256'):
# If it is a constant
# We dupplicate the existing list
# And we add uint256 and int256 cases
# There is no potential collision, as the compiler
# Prevent it with a
# "not unique after argument-dependent loopkup" issue
argss_new = [list(args) for args in argss]
for args in argss:
args.append(str(ElementaryType('uint256')))
for args in argss_new:
args.append(str(ElementaryType('int256')))
argss = argss + argss_new
else:
for args in argss:
args.append(type_arg)
return [sig.format(name, ','.join(args)) for args in argss]
def convert_type_library_call(ir, lib_contract): def convert_type_library_call(ir, lib_contract):
sigs = get_sig(ir) sigs = get_sig(ir)
@ -448,167 +727,12 @@ def convert_type_of_high_level_call(ir, contract):
return None return None
def propagate_types(ir, node): # endregion
# propagate the type ###################################################################################
using_for = node.function.contract.using_for ###################################################################################
if isinstance(ir, OperationWithLValue): # region Points to operation
# Force assignment in case of missing previous correct type ###################################################################################
if not ir.lvalue.type: ###################################################################################
if isinstance(ir, Assignment):
ir.lvalue.set_type(ir.rvalue.type)
elif isinstance(ir, Binary):
if BinaryType.return_bool(ir.type):
ir.lvalue.set_type(ElementaryType('bool'))
else:
ir.lvalue.set_type(ir.variable_left.type)
elif isinstance(ir, Delete):
# nothing to propagate
pass
elif isinstance(ir, LibraryCall):
return convert_type_library_call(ir, ir.destination)
elif isinstance(ir, HighLevelCall):
t = ir.destination.type
# Temporary operation (they are removed later)
if t is None:
return
# convert library
if t in using_for or '*' in using_for:
new_ir = convert_to_library(ir, node, using_for)
if new_ir:
return new_ir
if isinstance(t, UserDefinedType):
# UserdefinedType
t_type = t.type
if isinstance(t_type, Contract):
contract = node.slither.get_contract_from_name(t_type.name)
return convert_type_of_high_level_call(ir, contract)
# Convert HighLevelCall to LowLevelCall
if isinstance(t, ElementaryType) and t.name == 'address':
if ir.destination.name == 'this':
return convert_type_of_high_level_call(ir, node.function.contract)
return convert_to_low_level(ir)
# Convert push operations
# May need to insert a new operation
# Which leads to return a list of operation
if isinstance(t, ArrayType):
if ir.function_name == 'push' and len(ir.arguments) == 1:
return convert_to_push(ir, node)
elif isinstance(ir, Index):
if isinstance(ir.variable_left.type, MappingType):
ir.lvalue.set_type(ir.variable_left.type.type_to)
elif isinstance(ir.variable_left.type, ArrayType):
ir.lvalue.set_type(ir.variable_left.type.type)
elif isinstance(ir, InitArray):
length = len(ir.init_values)
t = ir.init_values[0].type
ir.lvalue.set_type(ArrayType(t, length))
elif isinstance(ir, InternalCall):
# if its not a tuple, return a singleton
return_type = ir.function.return_type
if return_type:
if len(return_type) == 1:
ir.lvalue.set_type(return_type[0])
elif len(return_type)>1:
ir.lvalue.set_type(return_type)
else:
ir.lvalue = None
elif isinstance(ir, InternalDynamicCall):
# if its not a tuple, return a singleton
return_type = ir.function_type.return_type
if return_type:
if len(return_type) == 1:
ir.lvalue.set_type(return_type[0])
else:
ir.lvalue.set_type(return_type)
else:
ir.lvalue = None
elif isinstance(ir, LowLevelCall):
# Call are not yet converted
# This should not happen
assert False
elif isinstance(ir, Member):
# TODO we should convert the reference to a temporary if the member is a length or a balance
if ir.variable_right == 'length' and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, (ElementaryType, ArrayType)):
length = Length(ir.variable_left, ir.lvalue)
length.lvalue.points_to = ir.variable_left
return length
if ir.variable_right == 'balance'and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, ElementaryType):
return Balance(ir.variable_left, ir.lvalue)
left = ir.variable_left
if isinstance(left, (Variable, SolidityVariable)):
t = ir.variable_left.type
elif isinstance(left, (Contract, Enum, Structure)):
t = UserDefinedType(left)
# can be None due to temporary operation
if t:
if isinstance(t, UserDefinedType):
# UserdefinedType
type_t = t.type
if isinstance(type_t, Enum):
ir.lvalue.set_type(t)
elif isinstance(type_t, Structure):
elems = type_t.elems
for elem in elems:
if elem == ir.variable_right:
ir.lvalue.set_type(elems[elem].type)
else:
assert isinstance(type_t, Contract)
elif isinstance(ir, NewArray):
ir.lvalue.set_type(ir.array_type)
elif isinstance(ir, NewContract):
contract = node.slither.get_contract_from_name(ir.contract_name)
ir.lvalue.set_type(UserDefinedType(contract))
elif isinstance(ir, NewElementaryType):
ir.lvalue.set_type(ir.type)
elif isinstance(ir, NewStructure):
ir.lvalue.set_type(UserDefinedType(ir.structure))
elif isinstance(ir, Push):
# No change required
pass
elif isinstance(ir, Send):
ir.lvalue.set_type(ElementaryType('bool'))
elif isinstance(ir, SolidityCall):
return_type = ir.function.return_type
if len(return_type) == 1:
ir.lvalue.set_type(return_type[0])
elif len(return_type)>1:
ir.lvalue.set_type(return_type)
elif isinstance(ir, TypeConversion):
ir.lvalue.set_type(ir.type)
elif isinstance(ir, Unary):
ir.lvalue.set_type(ir.rvalue.type)
elif isinstance(ir, Unpack):
types = ir.tuple.type.type
idx = ir.index
t = types[idx]
ir.lvalue.set_type(t)
elif isinstance(ir, (Argument, TmpCall, TmpNewArray, TmpNewContract, TmpNewStructure, TmpNewElementaryType)):
# temporary operation; they will be removed
pass
else:
logger.error('Not handling {} during type propgation'.format(type(ir)))
exit(-1)
def apply_ir_heuristics(irs, node):
"""
Apply a set of heuristic to improve slithIR
"""
irs = integrate_value_gas(irs)
irs = propage_type_and_convert_call(irs, node)
irs = remove_unused(irs)
find_references_origin(irs)
return irs
def find_references_origin(irs): def find_references_origin(irs):
""" """
@ -619,13 +743,12 @@ def find_references_origin(irs):
if isinstance(ir, (Index, Member)): if isinstance(ir, (Index, Member)):
ir.lvalue.points_to = ir.variable_left ir.lvalue.points_to = ir.variable_left
def is_temporary(ins): # endregion
return isinstance(ins, (Argument, ###################################################################################
TmpNewElementaryType, ###################################################################################
TmpNewContract, # region Operation filtering
TmpNewArray, ###################################################################################
TmpNewStructure)) ###################################################################################
def remove_temporary(result): def remove_temporary(result):
result = [ins for ins in result if not isinstance(ins, (Argument, result = [ins for ins in result if not isinstance(ins, (Argument,
@ -668,88 +791,24 @@ def remove_unused(result):
result = [i for i in result if not i in to_remove] result = [i for i in result if not i in to_remove]
return result return result
# endregion
###################################################################################
###################################################################################
# region Heuristics selection
###################################################################################
###################################################################################
def apply_ir_heuristics(irs, node):
"""
Apply a set of heuristic to improve slithIR
"""
def extract_tmp_call(ins): irs = integrate_value_gas(irs)
assert isinstance(ins, TmpCall)
if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType):
call = InternalDynamicCall(ins.lvalue, ins.called, ins.called.type)
call.call_id = ins.call_id
return call
if isinstance(ins.ori, Member):
if isinstance(ins.ori.variable_left, Contract):
st = ins.ori.variable_left.get_structure_from_name(ins.ori.variable_right)
if st:
op = NewStructure(st, ins.lvalue)
op.call_id = ins.call_id
return op
libcall = LibraryCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call)
libcall.call_id = ins.call_id
return libcall
msgcall = HighLevelCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call)
msgcall.call_id = ins.call_id
return msgcall
if isinstance(ins.ori, TmpCall):
r = extract_tmp_call(ins.ori)
return r
if isinstance(ins.called, SolidityVariableComposed):
if str(ins.called) == 'block.blockhash':
ins.called = SolidityFunction('blockhash(uint256)')
elif str(ins.called) == 'this.balance':
return SolidityCall(SolidityFunction('this.balance()'), ins.nbr_arguments, ins.lvalue, ins.type_call)
if isinstance(ins.called, SolidityFunction):
return SolidityCall(ins.called, ins.nbr_arguments, ins.lvalue, ins.type_call)
if isinstance(ins.ori, TmpNewElementaryType):
return NewElementaryType(ins.ori.type, ins.lvalue)
if isinstance(ins.ori, TmpNewContract):
op = NewContract(Constant(ins.ori.contract_name), ins.lvalue)
op.call_id = ins.call_id
return op
if isinstance(ins.ori, TmpNewArray):
return NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue)
if isinstance(ins.called, Structure):
op = NewStructure(ins.called, ins.lvalue)
op.call_id = ins.call_id
return op
if isinstance(ins.called, Event):
return EventCall(ins.called.name)
raise Exception('Not extracted {} {}'.format(type(ins.called), ins))
def convert_expression(expression, node):
# handle standlone expression
# such as return true;
from slither.core.cfg.node import NodeType
if isinstance(expression, Literal) and node.type in [NodeType.IF, NodeType.IFLOOP]:
result = [Condition(Constant(expression.value))]
return result
if isinstance(expression, Identifier) and node.type in [NodeType.IF, NodeType.IFLOOP]:
result = [Condition(expression.value)]
return result
visitor = ExpressionToSlithIR(expression, node) irs = propagate_type_and_convert_call(irs, node)
result = visitor.result() irs = remove_unused(irs)
find_references_origin(irs)
result = apply_ir_heuristics(result, node)
if result: return irs
if node.type in [NodeType.IF, NodeType.IFLOOP]:
assert isinstance(result[-1], (OperationWithLValue))
result.append(Condition(result[-1].lvalue))
elif node.type == NodeType.RETURN:
# May return None
if isinstance(result[-1], (OperationWithLValue)):
result.append(Return(result[-1].lvalue))
return result

@ -3,28 +3,32 @@ import logging
from slither.core.cfg.node import NodeType from slither.core.cfg.node import NodeType
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import (Assignment, Balance, Binary, from slither.slithir.operations import (Assignment, Balance, Binary, Condition,
BinaryType, Condition, Delete, Delete, EventCall, HighLevelCall,
EventCall, HighLevelCall, Index, Index, InitArray, InternalCall,
InitArray, InternalCall,
InternalDynamicCall, Length, InternalDynamicCall, Length,
LibraryCall, LowLevelCall, Member, LibraryCall, LowLevelCall, Member,
NewArray, NewContract, NewArray, NewContract,
NewElementaryType, NewStructure, NewElementaryType, NewStructure,
OperationWithLValue, Phi, PhiCallback, Push, Return, OperationWithLValue, Phi, PhiCallback,
Send, SolidityCall, Transfer, Push, Return, Send, SolidityCall,
TypeConversion, Unary, Unpack) Transfer, TypeConversion, Unary,
from slither.slithir.variables import (Constant, LocalIRVariable, StateIRVariable, Unpack)
ReferenceVariable, TemporaryVariable, from slither.slithir.variables import (LocalIRVariable, ReferenceVariable,
StateIRVariable, TemporaryVariable,
TupleVariable) TupleVariable)
logger = logging.getLogger('SSA_Conversion') logger = logging.getLogger('SSA_Conversion')
###################################################################################
###################################################################################
# region SlihtIR variables to SSA
###################################################################################
###################################################################################
def transform_slithir_vars_to_ssa(function): def transform_slithir_vars_to_ssa(function):
""" """
Transform slithIR vars to SSA Transform slithIR vars to SSA (TemporaryVariable, ReferenceVariable, TupleVariable)
""" """
variables = [] variables = []
for node in function.nodes: for node in function.nodes:
@ -42,6 +46,12 @@ def transform_slithir_vars_to_ssa(function):
for idx in range(len(tuple_variables)): for idx in range(len(tuple_variables)):
tuple_variables[idx].index = idx tuple_variables[idx].index = idx
###################################################################################
###################################################################################
# region SSA conversion
###################################################################################
###################################################################################
def add_ssa_ir(function, all_state_variables_instances): def add_ssa_ir(function, all_state_variables_instances):
''' '''
Add SSA version of the IR Add SSA version of the IR
@ -134,98 +144,6 @@ def add_ssa_ir(function, all_state_variables_instances):
all_state_variables_instances, all_state_variables_instances,
init_local_variables_instances) init_local_variables_instances)
def last_name(n, var, init_vars):
candidates = []
# Todo optimize by creating a variables_ssa_written attribute
for ir_ssa in n.irs_ssa:
if isinstance(ir_ssa, OperationWithLValue):
lvalue = ir_ssa.lvalue
while isinstance(lvalue, ReferenceVariable):
lvalue = lvalue.points_to
if lvalue and lvalue.name == var.name:
candidates.append(lvalue)
if n.variable_declaration and n.variable_declaration.name == var.name:
candidates.append(LocalIRVariable(n.variable_declaration))
if n.type == NodeType.ENTRYPOINT:
if var.name in init_vars:
candidates.append(init_vars[var.name])
assert candidates
return max(candidates, key=lambda v: v.index)
def update_lvalue(new_ir, node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances):
if isinstance(new_ir, OperationWithLValue):
lvalue = new_ir.lvalue
update_through_ref = False
if isinstance(new_ir, (Assignment, Binary)):
if isinstance(lvalue, ReferenceVariable):
update_through_ref = True
while isinstance(lvalue, ReferenceVariable):
lvalue = lvalue.points_to
if isinstance(lvalue, (LocalIRVariable, StateIRVariable)):
if isinstance(lvalue, LocalIRVariable):
new_var = LocalIRVariable(lvalue)
new_var.index = all_local_variables_instances[lvalue.name].index + 1
all_local_variables_instances[lvalue.name] = new_var
local_variables_instances[lvalue.name] = new_var
else:
new_var = StateIRVariable(lvalue)
new_var.index = all_state_variables_instances[lvalue.canonical_name].index + 1
all_state_variables_instances[lvalue.canonical_name] = new_var
state_variables_instances[lvalue.canonical_name] = new_var
if update_through_ref:
phi_operation = Phi(new_var, {node})
phi_operation.rvalues = [lvalue]
node.add_ssa_ir(phi_operation)
if not isinstance(new_ir.lvalue, ReferenceVariable):
new_ir.lvalue = new_var
else:
to_update = new_ir.lvalue
while isinstance(to_update.points_to, ReferenceVariable):
to_update = to_update.points_to
to_update.points_to = new_var
def is_used_later(initial_node, variable):
# TODO: does not handle the case where its read and written in the declaration node
# It can be problematic if this happens in a loop/if structure
# Ex:
# for(;true;){
# if(true){
# uint a = a;
# }
# ..
to_explore = {initial_node}
explored = set()
while to_explore:
node = to_explore.pop()
explored.add(node)
if isinstance(variable, LocalVariable):
if any(v.name == variable.name for v in node.local_variables_read):
return True
if any(v.name == variable.name for v in node.local_variables_written):
return False
if isinstance(variable, StateVariable):
if any(v.name == variable.name and v.contract == variable.contract for v in node.state_variables_read):
return True
if any(v.name == variable.name and v.contract == variable.contract for v in node.state_variables_written):
return False
for son in node.sons:
if not son in explored:
to_explore.add(son)
return False
def initiate_all_local_variables_instances(nodes, local_variables_instances, all_local_variables_instances):
for node in nodes:
if node.variable_declaration:
new_var = LocalIRVariable(node.variable_declaration)
if new_var.name in all_local_variables_instances:
new_var.index = all_local_variables_instances[new_var.name].index + 1
local_variables_instances[node.variable_declaration.name] = new_var
all_local_variables_instances[node.variable_declaration.name] = new_var
def generate_ssa_irs(node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances, init_local_variables_instances, visited): def generate_ssa_irs(node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances, init_local_variables_instances, visited):
if node in visited: if node in visited:
@ -308,6 +226,129 @@ def generate_ssa_irs(node, local_variables_instances, all_local_variables_instan
init_local_variables_instances, init_local_variables_instances,
visited) visited)
# endregion
###################################################################################
###################################################################################
# region Helpers
###################################################################################
###################################################################################
def last_name(n, var, init_vars):
candidates = []
# Todo optimize by creating a variables_ssa_written attribute
for ir_ssa in n.irs_ssa:
if isinstance(ir_ssa, OperationWithLValue):
lvalue = ir_ssa.lvalue
while isinstance(lvalue, ReferenceVariable):
lvalue = lvalue.points_to
if lvalue and lvalue.name == var.name:
candidates.append(lvalue)
if n.variable_declaration and n.variable_declaration.name == var.name:
candidates.append(LocalIRVariable(n.variable_declaration))
if n.type == NodeType.ENTRYPOINT:
if var.name in init_vars:
candidates.append(init_vars[var.name])
assert candidates
return max(candidates, key=lambda v: v.index)
def is_used_later(initial_node, variable):
# TODO: does not handle the case where its read and written in the declaration node
# It can be problematic if this happens in a loop/if structure
# Ex:
# for(;true;){
# if(true){
# uint a = a;
# }
# ..
to_explore = {initial_node}
explored = set()
while to_explore:
node = to_explore.pop()
explored.add(node)
if isinstance(variable, LocalVariable):
if any(v.name == variable.name for v in node.local_variables_read):
return True
if any(v.name == variable.name for v in node.local_variables_written):
return False
if isinstance(variable, StateVariable):
if any(v.name == variable.name and v.contract == variable.contract for v in node.state_variables_read):
return True
if any(v.name == variable.name and v.contract == variable.contract for v in node.state_variables_written):
return False
for son in node.sons:
if not son in explored:
to_explore.add(son)
return False
# endregion
###################################################################################
###################################################################################
# region Update operation
###################################################################################
###################################################################################
def update_lvalue(new_ir, node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances):
if isinstance(new_ir, OperationWithLValue):
lvalue = new_ir.lvalue
update_through_ref = False
if isinstance(new_ir, (Assignment, Binary)):
if isinstance(lvalue, ReferenceVariable):
update_through_ref = True
while isinstance(lvalue, ReferenceVariable):
lvalue = lvalue.points_to
if isinstance(lvalue, (LocalIRVariable, StateIRVariable)):
if isinstance(lvalue, LocalIRVariable):
new_var = LocalIRVariable(lvalue)
new_var.index = all_local_variables_instances[lvalue.name].index + 1
all_local_variables_instances[lvalue.name] = new_var
local_variables_instances[lvalue.name] = new_var
else:
new_var = StateIRVariable(lvalue)
new_var.index = all_state_variables_instances[lvalue.canonical_name].index + 1
all_state_variables_instances[lvalue.canonical_name] = new_var
state_variables_instances[lvalue.canonical_name] = new_var
if update_through_ref:
phi_operation = Phi(new_var, {node})
phi_operation.rvalues = [lvalue]
node.add_ssa_ir(phi_operation)
if not isinstance(new_ir.lvalue, ReferenceVariable):
new_ir.lvalue = new_var
else:
to_update = new_ir.lvalue
while isinstance(to_update.points_to, ReferenceVariable):
to_update = to_update.points_to
to_update.points_to = new_var
# endregion
###################################################################################
###################################################################################
# region Initialization
###################################################################################
###################################################################################
def initiate_all_local_variables_instances(nodes, local_variables_instances, all_local_variables_instances):
for node in nodes:
if node.variable_declaration:
new_var = LocalIRVariable(node.variable_declaration)
if new_var.name in all_local_variables_instances:
new_var.index = all_local_variables_instances[new_var.name].index + 1
local_variables_instances[node.variable_declaration.name] = new_var
all_local_variables_instances[node.variable_declaration.name] = new_var
# endregion
###################################################################################
###################################################################################
# region Phi Operations
###################################################################################
###################################################################################
def fix_phi_rvalues_and_storage_ref(node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances, init_local_variables_instances): def fix_phi_rvalues_and_storage_ref(node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances, init_local_variables_instances):
for ir in node.irs_ssa: for ir in node.irs_ssa:
if isinstance(ir, (Phi)) and not ir.rvalues: if isinstance(ir, (Phi)) and not ir.rvalues:
@ -367,6 +408,14 @@ def add_phi_origins(node, local_variables_definition, state_variables_definition
for succ in node.dominator_successors: for succ in node.dominator_successors:
add_phi_origins(succ, local_variables_definition, state_variables_definition) add_phi_origins(succ, local_variables_definition, state_variables_definition)
# endregion
###################################################################################
###################################################################################
# region IR copy
###################################################################################
###################################################################################
def copy_ir(ir, local_variables_instances, state_variables_instances, temporary_variables_instances, reference_variables_instances, all_local_variables_instances): def copy_ir(ir, local_variables_instances, state_variables_instances, temporary_variables_instances, reference_variables_instances, all_local_variables_instances):
''' '''
Args: Args:
@ -591,4 +640,4 @@ def copy_ir(ir, local_variables_instances, state_variables_instances, temporary_
logger.error('Impossible ir copy on {} ({})'.format(ir, type(ir))) logger.error('Impossible ir copy on {} ({})'.format(ir, type(ir)))
exit(-1) exit(-1)
# endregion

@ -2,16 +2,13 @@ import logging
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.enum import Enum from slither.core.declarations.enum import Enum
from slither.slithir.variables import StateIRVariable
from slither.solc_parsing.declarations.structure import StructureSolc
from slither.solc_parsing.declarations.event import EventSolc from slither.solc_parsing.declarations.event import EventSolc
from slither.solc_parsing.declarations.modifier import ModifierSolc
from slither.solc_parsing.declarations.function import FunctionSolc from slither.solc_parsing.declarations.function import FunctionSolc
from slither.solc_parsing.declarations.modifier import ModifierSolc
from slither.solc_parsing.variables.state_variable import StateVariableSolc from slither.solc_parsing.declarations.structure import StructureSolc
from slither.solc_parsing.solidity_types.type_parsing import parse_type from slither.solc_parsing.solidity_types.type_parsing import parse_type
from slither.solc_parsing.variables.state_variable import StateVariableSolc
from slither.slithir.variables import StateIRVariable
logger = logging.getLogger("ContractSolcParsing") logger = logging.getLogger("ContractSolcParsing")
@ -54,10 +51,25 @@ class ContractSolc04(Contract):
self._parse_contract_items() self._parse_contract_items()
###################################################################################
###################################################################################
# region General Properties
###################################################################################
###################################################################################
@property @property
def is_analyzed(self): def is_analyzed(self):
return self._is_analyzed return self._is_analyzed
def set_is_analyzed(self, is_analyzed):
self._is_analyzed = is_analyzed
###################################################################################
###################################################################################
# region AST
###################################################################################
###################################################################################
def get_key(self): def get_key(self):
return self.slither.get_key() return self.slither.get_key()
@ -74,8 +86,12 @@ class ContractSolc04(Contract):
def is_compact_ast(self): def is_compact_ast(self):
return self.slither.is_compact_ast return self.slither.is_compact_ast
def set_is_analyzed(self, is_analyzed): # endregion
self._is_analyzed = is_analyzed ###################################################################################
###################################################################################
# region SlithIR
###################################################################################
###################################################################################
def _parse_contract_info(self): def _parse_contract_info(self):
if self.is_compact_ast: if self.is_compact_ast:
@ -174,70 +190,6 @@ class ContractSolc04(Contract):
exit(-1) exit(-1)
return return
def analyze_using_for(self):
for father in self.inheritance:
self._using_for.update(father.using_for)
if self.is_compact_ast:
for using_for in self._usingForNotParsed:
lib_name = parse_type(using_for['libraryName'], self)
if 'typeName' in using_for and using_for['typeName']:
type_name = parse_type(using_for['typeName'], self)
else:
type_name = '*'
if not type_name in self._using_for:
self.using_for[type_name] = []
self._using_for[type_name].append(lib_name)
else:
for using_for in self._usingForNotParsed:
children = using_for[self.get_children()]
assert children and len(children) <= 2
if len(children) == 2:
new = parse_type(children[0], self)
old = parse_type(children[1], self)
else:
new = parse_type(children[0], self)
old = '*'
if not old in self._using_for:
self.using_for[old] = []
self._using_for[old].append(new)
self._usingForNotParsed = []
def analyze_enums(self):
for father in self.inheritance:
self._enums.update(father.enums_as_dict())
for enum in self._enumsNotParsed:
# for enum, we can parse and analyze it
# at the same time
self._analyze_enum(enum)
self._enumsNotParsed = None
def _analyze_enum(self, enum):
# Enum can be parsed in one pass
if self.is_compact_ast:
name = enum['name']
canonicalName = enum['canonicalName']
else:
name = enum['attributes'][self.get_key()]
if 'canonicalName' in enum['attributes']:
canonicalName = enum['attributes']['canonicalName']
else:
canonicalName = self.name + '.' + name
values = []
for child in enum[self.get_children('members')]:
assert child[self.get_key()] == 'EnumValue'
if self.is_compact_ast:
values.append(child['name'])
else:
values.append(child['attributes'][self.get_key()])
new_enum = Enum(name, canonicalName, values)
new_enum.set_contract(self)
new_enum.set_offset(enum['src'], self.slither)
self._enums[canonicalName] = new_enum
def _parse_struct(self, struct): def _parse_struct(self, struct):
if self.is_compact_ast: if self.is_compact_ast:
name = struct['name'] name = struct['name']
@ -259,9 +211,6 @@ class ContractSolc04(Contract):
st.set_offset(struct['src'], self.slither) st.set_offset(struct['src'], self.slither)
self._structures[name] = st self._structures[name] = st
def _analyze_struct(self, struct):
struct.analyze()
def parse_structs(self): def parse_structs(self):
for father in self.inheritance_reverse: for father in self.inheritance_reverse:
self._structures.update(father.structures_as_dict()) self._structures.update(father.structures_as_dict())
@ -270,24 +219,6 @@ class ContractSolc04(Contract):
self._parse_struct(struct) self._parse_struct(struct)
self._structuresNotParsed = None self._structuresNotParsed = None
def analyze_structs(self):
for struct in self.structures:
self._analyze_struct(struct)
def analyze_events(self):
for father in self.inheritance_reverse:
self._events.update(father.events_as_dict())
for event_to_parse in self._eventsNotParsed:
event = EventSolc(event_to_parse, self)
event.analyze(self)
event.set_contract(self)
event.set_offset(event_to_parse['src'], self.slither)
self._events[event.full_name] = event
self._eventsNotParsed = None
def parse_state_variables(self): def parse_state_variables(self):
for father in self.inheritance_reverse: for father in self.inheritance_reverse:
self._variables.update(father.variables_as_dict()) self._variables.update(father.variables_as_dict())
@ -299,22 +230,6 @@ class ContractSolc04(Contract):
self._variables[var.name] = var self._variables[var.name] = var
def analyze_constant_state_variables(self):
from slither.solc_parsing.expressions.expression_parsing import VariableNotFound
for var in self.variables:
if var.is_constant:
# cant parse constant expression based on function calls
try:
var.analyze(self)
except VariableNotFound:
pass
return
def analyze_state_variables(self):
for var in self.variables:
var.analyze(self)
return
def _parse_modifier(self, modifier): def _parse_modifier(self, modifier):
modif = ModifierSolc(modifier, self) modif = ModifierSolc(modifier, self)
@ -347,6 +262,24 @@ class ContractSolc04(Contract):
return return
# endregion
###################################################################################
###################################################################################
# region Analyze
###################################################################################
###################################################################################
def analyze_content_modifiers(self):
for modifier in self.modifiers:
modifier.analyze_content()
return
def analyze_content_functions(self):
for function in self.functions:
function.analyze_content()
return
def analyze_params_modifiers(self): def analyze_params_modifiers(self):
for father in self.inheritance_reverse: for father in self.inheritance_reverse:
self._modifiers.update(father.modifiers_as_dict()) self._modifiers.update(father.modifiers_as_dict())
@ -390,17 +323,114 @@ class ContractSolc04(Contract):
self._functions_no_params = [] self._functions_no_params = []
return return
def analyze_content_modifiers(self): def analyze_constant_state_variables(self):
for modifier in self.modifiers: from slither.solc_parsing.expressions.expression_parsing import VariableNotFound
modifier.analyze_content() for var in self.variables:
if var.is_constant:
# cant parse constant expression based on function calls
try:
var.analyze(self)
except VariableNotFound:
pass
return return
def analyze_content_functions(self): def analyze_state_variables(self):
for function in self.functions: for var in self.variables:
function.analyze_content() var.analyze(self)
return return
def analyze_using_for(self):
for father in self.inheritance:
self._using_for.update(father.using_for)
if self.is_compact_ast:
for using_for in self._usingForNotParsed:
lib_name = parse_type(using_for['libraryName'], self)
if 'typeName' in using_for and using_for['typeName']:
type_name = parse_type(using_for['typeName'], self)
else:
type_name = '*'
if not type_name in self._using_for:
self.using_for[type_name] = []
self._using_for[type_name].append(lib_name)
else:
for using_for in self._usingForNotParsed:
children = using_for[self.get_children()]
assert children and len(children) <= 2
if len(children) == 2:
new = parse_type(children[0], self)
old = parse_type(children[1], self)
else:
new = parse_type(children[0], self)
old = '*'
if not old in self._using_for:
self.using_for[old] = []
self._using_for[old].append(new)
self._usingForNotParsed = []
def analyze_enums(self):
for father in self.inheritance:
self._enums.update(father.enums_as_dict())
for enum in self._enumsNotParsed:
# for enum, we can parse and analyze it
# at the same time
self._analyze_enum(enum)
self._enumsNotParsed = None
def _analyze_enum(self, enum):
# Enum can be parsed in one pass
if self.is_compact_ast:
name = enum['name']
canonicalName = enum['canonicalName']
else:
name = enum['attributes'][self.get_key()]
if 'canonicalName' in enum['attributes']:
canonicalName = enum['attributes']['canonicalName']
else:
canonicalName = self.name + '.' + name
values = []
for child in enum[self.get_children('members')]:
assert child[self.get_key()] == 'EnumValue'
if self.is_compact_ast:
values.append(child['name'])
else:
values.append(child['attributes'][self.get_key()])
new_enum = Enum(name, canonicalName, values)
new_enum.set_contract(self)
new_enum.set_offset(enum['src'], self.slither)
self._enums[canonicalName] = new_enum
def _analyze_struct(self, struct):
struct.analyze()
def analyze_structs(self):
for struct in self.structures:
self._analyze_struct(struct)
def analyze_events(self):
for father in self.inheritance_reverse:
self._events.update(father.events_as_dict())
for event_to_parse in self._eventsNotParsed:
event = EventSolc(event_to_parse, self)
event.analyze(self)
event.set_contract(self)
event.set_offset(event_to_parse['src'], self.slither)
self._events[event.full_name] = event
self._eventsNotParsed = None
# endregion
###################################################################################
###################################################################################
# region SlithIR
###################################################################################
###################################################################################
def convert_expression_to_slithir(self): def convert_expression_to_slithir(self):
for func in self.functions + self.modifiers: for func in self.functions + self.modifiers:
@ -442,5 +472,14 @@ class ContractSolc04(Contract):
func.fix_phi(last_state_variables_instances, initial_state_variables_instances) func.fix_phi(last_state_variables_instances, initial_state_variables_instances)
# endregion
###################################################################################
###################################################################################
# region Built in definitions
###################################################################################
###################################################################################
def __hash__(self): def __hash__(self):
return self._id return self._id
# endregion

@ -1,20 +1,19 @@
""" """
Event module
""" """
import logging import logging
from slither.core.cfg.node import NodeType, link_nodes from slither.core.cfg.node import NodeType, link_nodes
from slither.core.declarations.contract import Contract
from slither.core.declarations.function import Function from slither.core.declarations.function import Function
from slither.core.dominators.utils import (compute_dominance_frontier, from slither.core.dominators.utils import (compute_dominance_frontier,
compute_dominators) compute_dominators)
from slither.core.expressions import AssignmentOperation from slither.core.expressions import AssignmentOperation
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import (Assignment, HighLevelCall, from slither.slithir.operations import (InternalCall, OperationWithLValue, Phi,
InternalCall, InternalDynamicCall, PhiCallback)
LowLevelCall, OperationWithLValue, Phi,
PhiCallback, LibraryCall)
from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa
from slither.slithir.variables import LocalIRVariable, ReferenceVariable from slither.slithir.variables import (Constant, ReferenceVariable,
StateIRVariable)
from slither.solc_parsing.cfg.node import NodeSolc from slither.solc_parsing.cfg.node import NodeSolc
from slither.solc_parsing.expressions.expression_parsing import \ from slither.solc_parsing.expressions.expression_parsing import \
parse_expression parse_expression
@ -24,18 +23,14 @@ from slither.solc_parsing.variables.local_variable_init_from_tuple import \
from slither.solc_parsing.variables.variable_declaration import \ from slither.solc_parsing.variables.variable_declaration import \
MultipleVariablesDeclaration MultipleVariablesDeclaration
from slither.utils.expression_manipulations import SplitTernaryExpression from slither.utils.expression_manipulations import SplitTernaryExpression
from slither.utils.utils import unroll
from slither.visitors.expression.export_values import ExportValues from slither.visitors.expression.export_values import ExportValues
from slither.visitors.expression.has_conditional import HasConditional from slither.visitors.expression.has_conditional import HasConditional
from slither.core.declarations.contract import Contract
from slither.slithir.variables import StateIRVariable, LocalIRVariable, Constant
from slither.utils.utils import unroll
logger = logging.getLogger("FunctionSolc") logger = logging.getLogger("FunctionSolc")
class FunctionSolc(Function): class FunctionSolc(Function):
""" """
Event class
""" """
# elems = [(type, name)] # elems = [(type, name)]
@ -60,6 +55,12 @@ class FunctionSolc(Function):
# which is only possible with solc > 0.5 # which is only possible with solc > 0.5
self._variables_renamed = {} self._variables_renamed = {}
###################################################################################
###################################################################################
# region AST format
###################################################################################
###################################################################################
def get_key(self): def get_key(self):
return self.slither.get_key() return self.slither.get_key()
@ -72,6 +73,13 @@ class FunctionSolc(Function):
def is_compact_ast(self): def is_compact_ast(self):
return self.slither.is_compact_ast return self.slither.is_compact_ast
# endregion
###################################################################################
###################################################################################
# region Variables
###################################################################################
###################################################################################
@property @property
def variables_renamed(self): def variables_renamed(self):
return self._variables_renamed return self._variables_renamed
@ -89,6 +97,12 @@ class FunctionSolc(Function):
self._variables_renamed[local_var.reference_id] = local_var self._variables_renamed[local_var.reference_id] = local_var
self._variables[local_var.name] = local_var self._variables[local_var.name] = local_var
# endregion
###################################################################################
###################################################################################
# region Analyses
###################################################################################
###################################################################################
def _analyze_attributes(self): def _analyze_attributes(self):
if self.is_compact_ast: if self.is_compact_ast:
@ -133,6 +147,77 @@ class FunctionSolc(Function):
if 'payable' in attributes: if 'payable' in attributes:
self._payable = attributes['payable'] self._payable = attributes['payable']
def analyze_params(self):
# Can be re-analyzed due to inheritance
if self._params_was_analyzed:
return
self._params_was_analyzed = True
self._analyze_attributes()
if self.is_compact_ast:
params = self._functionNotParsed['parameters']
returns = self._functionNotParsed['returnParameters']
else:
children = self._functionNotParsed[self.get_children('children')]
params = children[0]
returns = children[1]
if params:
self._parse_params(params)
if returns:
self._parse_returns(returns)
def analyze_content(self):
if self._content_was_analyzed:
return
self._content_was_analyzed = True
if self.is_compact_ast:
body = self._functionNotParsed['body']
if body and body[self.get_key()] == 'Block':
self._is_implemented = True
self._parse_cfg(body)
for modifier in self._functionNotParsed['modifiers']:
self._parse_modifier(modifier)
else:
children = self._functionNotParsed[self.get_children('children')]
self._is_implemented = False
for child in children[2:]:
if child[self.get_key()] == 'Block':
self._is_implemented = True
self._parse_cfg(child)
# Parse modifier after parsing all the block
# In the case a local variable is used in the modifier
for child in children[2:]:
if child[self.get_key()] == 'ModifierInvocation':
self._parse_modifier(child)
for local_vars in self.variables:
local_vars.analyze(self)
for node in self.nodes:
node.analyze_expressions(self)
self._filter_ternary()
self._remove_alone_endif()
# endregion
###################################################################################
###################################################################################
# region Nodes
###################################################################################
###################################################################################
def _new_node(self, node_type, src): def _new_node(self, node_type, src):
node = NodeSolc(node_type, self._counter_nodes) node = NodeSolc(node_type, self._counter_nodes)
node.set_offset(src, self.slither) node.set_offset(src, self.slither)
@ -141,6 +226,13 @@ class FunctionSolc(Function):
self._nodes.append(node) self._nodes.append(node)
return node return node
# endregion
###################################################################################
###################################################################################
# region Parsing function
###################################################################################
###################################################################################
def _parse_if(self, ifStatement, node): def _parse_if(self, ifStatement, node):
# IfStatement = 'if' '(' Expression ')' Statement ( 'else' Statement )? # IfStatement = 'if' '(' Expression ')' Statement ( 'else' Statement )?
falseStatement = None falseStatement = None
@ -662,6 +754,13 @@ class FunctionSolc(Function):
self._remove_incorrect_edges() self._remove_incorrect_edges()
self._remove_alone_endif() self._remove_alone_endif()
# endregion
###################################################################################
###################################################################################
# region Loops
###################################################################################
###################################################################################
def _find_end_loop(self, node, visited, counter): def _find_end_loop(self, node, visited, counter):
# counter allows to explore nested loop # counter allows to explore nested loop
if node in visited: if node in visited:
@ -723,48 +822,6 @@ class FunctionSolc(Function):
node.set_sons([start_node]) node.set_sons([start_node])
start_node.add_father(node) start_node.add_father(node)
def _remove_incorrect_edges(self):
for node in self._nodes:
if node.type in [NodeType.RETURN, NodeType.THROW]:
for son in node.sons:
son.remove_father(node)
node.set_sons([])
if node.type in [NodeType.BREAK]:
self._fix_break_node(node)
if node.type in [NodeType.CONTINUE]:
self._fix_continue_node(node)
def _remove_alone_endif(self):
"""
Can occur on:
if(..){
return
}
else{
return
}
Iterate until a fix point to remove the ENDIF node
creates on the following pattern
if(){
return
}
else if(){
return
}
"""
prev_nodes = []
while set(prev_nodes) != set(self.nodes):
prev_nodes = self.nodes
to_remove = []
for node in self.nodes:
if node.type == NodeType.ENDIF and not node.fathers:
for son in node.sons:
son.remove_father(node)
node.set_sons([])
to_remove.append(node)
self._nodes = [n for n in self.nodes if not n in to_remove]
#
def _parse_params(self, params): def _parse_params(self, params):
assert params[self.get_key()] == 'ParameterList' assert params[self.get_key()] == 'ParameterList'
@ -824,67 +881,61 @@ class FunctionSolc(Function):
elif isinstance(m, Contract): elif isinstance(m, Contract):
self._explicit_base_constructor_calls.append(m) self._explicit_base_constructor_calls.append(m)
# endregion
###################################################################################
###################################################################################
# region Edges
###################################################################################
###################################################################################
def analyze_params(self): def _remove_incorrect_edges(self):
# Can be re-analyzed due to inheritance for node in self._nodes:
if self._params_was_analyzed: if node.type in [NodeType.RETURN, NodeType.THROW]:
return for son in node.sons:
son.remove_father(node)
self._params_was_analyzed = True node.set_sons([])
if node.type in [NodeType.BREAK]:
self._analyze_attributes() self._fix_break_node(node)
if node.type in [NodeType.CONTINUE]:
if self.is_compact_ast: self._fix_continue_node(node)
params = self._functionNotParsed['parameters']
returns = self._functionNotParsed['returnParameters']
else:
children = self._functionNotParsed[self.get_children('children')]
params = children[0]
returns = children[1]
if params:
self._parse_params(params)
if returns:
self._parse_returns(returns)
def analyze_content(self):
if self._content_was_analyzed:
return
self._content_was_analyzed = True
if self.is_compact_ast:
body = self._functionNotParsed['body']
if body and body[self.get_key()] == 'Block':
self._is_implemented = True
self._parse_cfg(body)
for modifier in self._functionNotParsed['modifiers']:
self._parse_modifier(modifier)
else:
children = self._functionNotParsed[self.get_children('children')]
self._is_implemented = False
for child in children[2:]:
if child[self.get_key()] == 'Block':
self._is_implemented = True
self._parse_cfg(child)
# Parse modifier after parsing all the block
# In the case a local variable is used in the modifier
for child in children[2:]:
if child[self.get_key()] == 'ModifierInvocation':
self._parse_modifier(child)
for local_vars in self.variables: def _remove_alone_endif(self):
local_vars.analyze(self) """
Can occur on:
if(..){
return
}
else{
return
}
for node in self.nodes: Iterate until a fix point to remove the ENDIF node
node.analyze_expressions(self) creates on the following pattern
if(){
return
}
else if(){
return
}
"""
prev_nodes = []
while set(prev_nodes) != set(self.nodes):
prev_nodes = self.nodes
to_remove = []
for node in self.nodes:
if node.type == NodeType.ENDIF and not node.fathers:
for son in node.sons:
son.remove_father(node)
node.set_sons([])
to_remove.append(node)
self._nodes = [n for n in self.nodes if not n in to_remove]
self._filter_ternary() # endregion
self._remove_alone_endif() ###################################################################################
###################################################################################
# region Ternary
###################################################################################
###################################################################################
def _filter_ternary(self): def _filter_ternary(self):
ternary_found = True ternary_found = True
@ -902,6 +953,64 @@ class FunctionSolc(Function):
ternary_found = True ternary_found = True
break break
def split_ternary_node(self, node, condition, true_expr, false_expr):
condition_node = self._new_node(NodeType.IF, node.source_mapping)
condition_node.add_expression(condition)
condition_node.analyze_expressions(self)
if node.type == NodeType.VARIABLE:
condition_node.add_variable_declaration(node.variable_declaration)
true_node = self._new_node(NodeType.EXPRESSION, node.source_mapping)
if node.type == NodeType.VARIABLE:
assert isinstance(true_expr, AssignmentOperation)
#true_expr = true_expr.expression_right
elif node.type == NodeType.RETURN:
true_node.type = NodeType.RETURN
true_node.add_expression(true_expr)
true_node.analyze_expressions(self)
false_node = self._new_node(NodeType.EXPRESSION, node.source_mapping)
if node.type == NodeType.VARIABLE:
assert isinstance(false_expr, AssignmentOperation)
elif node.type == NodeType.RETURN:
false_node.type = NodeType.RETURN
#false_expr = false_expr.expression_right
false_node.add_expression(false_expr)
false_node.analyze_expressions(self)
endif_node = self._new_node(NodeType.ENDIF, node.source_mapping)
for father in node.fathers:
father.remove_son(node)
father.add_son(condition_node)
condition_node.add_father(father)
for son in node.sons:
son.remove_father(node)
son.add_father(endif_node)
endif_node.add_son(son)
link_nodes(condition_node, true_node)
link_nodes(condition_node, false_node)
if not true_node.type in [NodeType.THROW, NodeType.RETURN]:
link_nodes(true_node, endif_node)
if not false_node.type in [NodeType.THROW, NodeType.RETURN]:
link_nodes(false_node, endif_node)
self._nodes = [n for n in self._nodes if n.node_id != node.node_id]
# endregion
###################################################################################
###################################################################################
# region SlithIr and SSA
###################################################################################
###################################################################################
def get_last_ssa_state_variables_instances(self): def get_last_ssa_state_variables_instances(self):
if not self.is_implemented: if not self.is_implemented:
return dict() return dict()
@ -1006,53 +1115,3 @@ class FunctionSolc(Function):
node.update_read_write_using_ssa() node.update_read_write_using_ssa()
self._analyze_read_write() self._analyze_read_write()
def split_ternary_node(self, node, condition, true_expr, false_expr):
condition_node = self._new_node(NodeType.IF, node.source_mapping)
condition_node.add_expression(condition)
condition_node.analyze_expressions(self)
if node.type == NodeType.VARIABLE:
condition_node.add_variable_declaration(node.variable_declaration)
true_node = self._new_node(NodeType.EXPRESSION, node.source_mapping)
if node.type == NodeType.VARIABLE:
assert isinstance(true_expr, AssignmentOperation)
#true_expr = true_expr.expression_right
elif node.type == NodeType.RETURN:
true_node.type = NodeType.RETURN
true_node.add_expression(true_expr)
true_node.analyze_expressions(self)
false_node = self._new_node(NodeType.EXPRESSION, node.source_mapping)
if node.type == NodeType.VARIABLE:
assert isinstance(false_expr, AssignmentOperation)
elif node.type == NodeType.RETURN:
false_node.type = NodeType.RETURN
#false_expr = false_expr.expression_right
false_node.add_expression(false_expr)
false_node.analyze_expressions(self)
endif_node = self._new_node(NodeType.ENDIF, node.source_mapping)
for father in node.fathers:
father.remove_son(node)
father.add_son(condition_node)
condition_node.add_father(father)
for son in node.sons:
son.remove_father(node)
son.add_father(endif_node)
endif_node.add_son(son)
link_nodes(condition_node, true_node)
link_nodes(condition_node, false_node)
if not true_node.type in [NodeType.THROW, NodeType.RETURN]:
link_nodes(true_node, endif_node)
if not false_node.type in [NodeType.THROW, NodeType.RETURN]:
link_nodes(false_node, endif_node)
self._nodes = [n for n in self._nodes if n.node_id != node.node_id]

@ -1,38 +1,59 @@
import logging import logging
import re import re
from slither.core.expressions.unary_operation import UnaryOperation, UnaryOperationType
from slither.core.expressions.binary_operation import BinaryOperation, BinaryOperationType from slither.core.declarations.contract import Contract
from slither.core.expressions.literal import Literal from slither.core.declarations.function import Function
from slither.core.declarations.solidity_variables import (SOLIDITY_FUNCTIONS,
SOLIDITY_VARIABLES,
SOLIDITY_VARIABLES_COMPOSED,
SolidityFunction,
SolidityVariable,
SolidityVariableComposed)
from slither.core.expressions.assignment_operation import (AssignmentOperation,
AssignmentOperationType)
from slither.core.expressions.binary_operation import (BinaryOperation,
BinaryOperationType)
from slither.core.expressions.call_expression import CallExpression
from slither.core.expressions.conditional_expression import \
ConditionalExpression
from slither.core.expressions.elementary_type_name_expression import \
ElementaryTypeNameExpression
from slither.core.expressions.identifier import Identifier from slither.core.expressions.identifier import Identifier
from slither.core.expressions.super_identifier import SuperIdentifier
from slither.core.expressions.index_access import IndexAccess from slither.core.expressions.index_access import IndexAccess
from slither.core.expressions.literal import Literal
from slither.core.expressions.member_access import MemberAccess from slither.core.expressions.member_access import MemberAccess
from slither.core.expressions.tuple_expression import TupleExpression
from slither.core.expressions.conditional_expression import ConditionalExpression
from slither.core.expressions.assignment_operation import AssignmentOperation, AssignmentOperationType
from slither.core.expressions.type_conversion import TypeConversion
from slither.core.expressions.call_expression import CallExpression
from slither.core.expressions.super_call_expression import SuperCallExpression
from slither.core.expressions.new_array import NewArray from slither.core.expressions.new_array import NewArray
from slither.core.expressions.new_contract import NewContract from slither.core.expressions.new_contract import NewContract
from slither.core.expressions.new_elementary_type import NewElementaryType from slither.core.expressions.new_elementary_type import NewElementaryType
from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression from slither.core.expressions.super_call_expression import SuperCallExpression
from slither.core.expressions.super_identifier import SuperIdentifier
from slither.solc_parsing.solidity_types.type_parsing import parse_type, UnknownType from slither.core.expressions.tuple_expression import TupleExpression
from slither.core.expressions.type_conversion import TypeConversion
from slither.core.declarations.contract import Contract from slither.core.expressions.unary_operation import (UnaryOperation,
from slither.core.declarations.function import Function UnaryOperationType)
from slither.core.solidity_types import (ArrayType, ElementaryType,
from slither.core.declarations.solidity_variables import SOLIDITY_VARIABLES, SOLIDITY_FUNCTIONS, SOLIDITY_VARIABLES_COMPOSED FunctionType, MappingType)
from slither.core.declarations.solidity_variables import SolidityVariable, SolidityFunction, SolidityVariableComposed, solidity_function_signature from slither.solc_parsing.solidity_types.type_parsing import (UnknownType,
parse_type)
from slither.core.solidity_types import ElementaryType, ArrayType, MappingType, FunctionType logger = logging.getLogger("ExpressionParsing")
logger = logging.getLogger("ExpressionParsing") ###################################################################################
###################################################################################
# region Exception
###################################################################################
###################################################################################
class VariableNotFound(Exception): pass class VariableNotFound(Exception): pass
# endregion
###################################################################################
###################################################################################
# region Helpers
###################################################################################
###################################################################################
def get_pointer_name(variable): def get_pointer_name(variable):
curr_type = variable.type curr_type = variable.type
while(isinstance(curr_type, (ArrayType, MappingType))): while(isinstance(curr_type, (ArrayType, MappingType))):
@ -135,6 +156,92 @@ def find_variable(var_name, caller_context, referenced_declaration=None):
raise VariableNotFound('Variable not found: {}'.format(var_name)) raise VariableNotFound('Variable not found: {}'.format(var_name))
# endregion
###################################################################################
###################################################################################
# region Filtering
###################################################################################
###################################################################################
def filter_name(value):
value = value.replace(' memory', '')
value = value.replace(' storage', '')
value = value.replace(' external', '')
value = value.replace(' internal', '')
value = value.replace('struct ', '')
value = value.replace('contract ', '')
value = value.replace('enum ', '')
value = value.replace(' ref', '')
value = value.replace(' pointer', '')
value = value.replace(' pure', '')
value = value.replace(' view', '')
value = value.replace(' constant', '')
value = value.replace(' payable', '')
value = value.replace('function (', 'function(')
value = value.replace('returns (', 'returns(')
# remove the text remaining after functio(...)
# which should only be ..returns(...)
# nested parenthesis so we use a system of counter on parenthesis
idx = value.find('(')
if idx:
counter = 1
max_idx = len(value)
while counter:
assert idx < max_idx
idx = idx +1
if value[idx] == '(':
counter += 1
elif value[idx] == ')':
counter -= 1
value = value[:idx+1]
return value
# endregion
###################################################################################
###################################################################################
# region Conversion
###################################################################################
###################################################################################
def convert_subdenomination(value, sub):
if sub is None:
return value
# to allow 0.1 ether conversion
if value[0:2] == "0x":
value = float(int(value, 16))
else:
value = float(value)
if sub == 'wei':
return int(value)
if sub == 'szabo':
return int(value * int(1e12))
if sub == 'finney':
return int(value * int(1e15))
if sub == 'ether':
return int(value * int(1e18))
if sub == 'seconds':
return int(value)
if sub == 'minutes':
return int(value * 60)
if sub == 'hours':
return int(value * 60 * 60)
if sub == 'days':
return int(value * 60 * 60 * 24)
if sub == 'weeks':
return int(value * 60 * 60 * 24 * 7)
if sub == 'years':
return int(value * 60 * 60 * 24 * 7 * 365)
logger.error('Subdemoniation not found {}'.format(sub))
return int(value)
# endregion
###################################################################################
###################################################################################
# region Parsing
###################################################################################
###################################################################################
def parse_call(expression, caller_context): def parse_call(expression, caller_context):
@ -208,72 +315,6 @@ def parse_super_name(expression, is_compact_ast):
return base_name+arguments return base_name+arguments
def filter_name(value):
value = value.replace(' memory', '')
value = value.replace(' storage', '')
value = value.replace(' external', '')
value = value.replace(' internal', '')
value = value.replace('struct ', '')
value = value.replace('contract ', '')
value = value.replace('enum ', '')
value = value.replace(' ref', '')
value = value.replace(' pointer', '')
value = value.replace(' pure', '')
value = value.replace(' view', '')
value = value.replace(' constant', '')
value = value.replace(' payable', '')
value = value.replace('function (', 'function(')
value = value.replace('returns (', 'returns(')
# remove the text remaining after functio(...)
# which should only be ..returns(...)
# nested parenthesis so we use a system of counter on parenthesis
idx = value.find('(')
if idx:
counter = 1
max_idx = len(value)
while counter:
assert idx < max_idx
idx = idx +1
if value[idx] == '(':
counter += 1
elif value[idx] == ')':
counter -= 1
value = value[:idx+1]
return value
def convert_subdenomination(value, sub):
if sub is None:
return value
# to allow 0.1 ether conversion
if value[0:2] == "0x":
value = float(int(value, 16))
else:
value = float(value)
if sub == 'wei':
return int(value)
if sub == 'szabo':
return int(value * int(1e12))
if sub == 'finney':
return int(value * int(1e15))
if sub == 'ether':
return int(value * int(1e18))
if sub == 'seconds':
return int(value)
if sub == 'minutes':
return int(value * 60)
if sub == 'hours':
return int(value * 60 * 60)
if sub == 'days':
return int(value * 60 * 60 * 24)
if sub == 'weeks':
return int(value * 60 * 60 * 24 * 7)
if sub == 'years':
return int(value * 60 * 60 * 24 * 7 * 365)
logger.error('Subdemoniation not found {}'.format(sub))
return int(value)
def parse_expression(expression, caller_context): def parse_expression(expression, caller_context):
""" """

@ -26,6 +26,13 @@ class SlitherSolc(Slither):
self._is_compact_ast = False self._is_compact_ast = False
###################################################################################
###################################################################################
# region AST
###################################################################################
###################################################################################
def get_key(self): def get_key(self):
if self._is_compact_ast: if self._is_compact_ast:
return 'nodeType' return 'nodeType'
@ -40,6 +47,13 @@ class SlitherSolc(Slither):
def is_compact_ast(self): def is_compact_ast(self):
return self._is_compact_ast return self._is_compact_ast
# endregion
###################################################################################
###################################################################################
# region Parsing
###################################################################################
###################################################################################
def _parse_contracts_from_json(self, json_data): def _parse_contracts_from_json(self, json_data):
try: try:
data_loaded = json.loads(json_data) data_loaded = json.loads(json_data)
@ -148,6 +162,16 @@ class SlitherSolc(Slither):
source_code = f.read() source_code = f.read()
self.source_code[name] = source_code self.source_code[name] = source_code
# endregion
###################################################################################
###################################################################################
# region Analyze
###################################################################################
###################################################################################
@property
def analyzed(self):
return self._analyzed
def _analyze_contracts(self): def _analyze_contracts(self):
if not self._contractsNotParsed: if not self._contractsNotParsed:
@ -234,11 +258,6 @@ class SlitherSolc(Slither):
compute_dependency(self) compute_dependency(self)
# TODO refactor the following functions, and use a lambda function
@property
def analyzed(self):
return self._analyzed
def _analyze_all_enums(self, contracts_to_be_analyzed): def _analyze_all_enums(self, contracts_to_be_analyzed):
while contracts_to_be_analyzed: while contracts_to_be_analyzed:
@ -362,4 +381,4 @@ class SlitherSolc(Slither):
contract.fix_phi() contract.fix_phi()
contract.update_read_write_using_ssa() contract.update_read_write_using_ssa()
# endregion

Loading…
Cancel
Save