@ -14,7 +14,12 @@ from slither.core.declarations.solidity_variables import (
SolidityVariable ,
SolidityVariable ,
SolidityVariableComposed ,
SolidityVariableComposed ,
)
)
from slither . core . expressions import Identifier , IndexAccess , MemberAccess , UnaryOperation
from slither . core . expressions import (
Identifier ,
IndexAccess ,
MemberAccess ,
UnaryOperation ,
)
from slither . core . solidity_types import UserDefinedType
from slither . core . solidity_types import UserDefinedType
from slither . core . solidity_types . type import Type
from slither . core . solidity_types . type import Type
from slither . core . source_mapping . source_mapping import SourceMapping
from slither . core . source_mapping . source_mapping import SourceMapping
@ -23,6 +28,8 @@ from slither.core.variables.local_variable import LocalVariable
from slither . core . variables . state_variable import StateVariable
from slither . core . variables . state_variable import StateVariable
from slither . utils . utils import unroll
from slither . utils . utils import unroll
# pylint: disable=import-outside-toplevel,too-many-instance-attributes,too-many-statements,too-many-lines
if TYPE_CHECKING :
if TYPE_CHECKING :
from slither . utils . type_helpers import (
from slither . utils . type_helpers import (
InternalCallType ,
InternalCallType ,
@ -46,7 +53,10 @@ ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"])
class ModifierStatements :
class ModifierStatements :
def __init__ (
def __init__ (
self , modifier : Union [ " Contract " , " Function " ] , entry_point : " Node " , nodes : List [ " Node " ]
self ,
modifier : Union [ " Contract " , " Function " ] ,
entry_point : " Node " ,
nodes : List [ " Node " ] ,
) :
) :
self . _modifier = modifier
self . _modifier = modifier
self . _entry_point = entry_point
self . _entry_point = entry_point
@ -79,10 +89,26 @@ class FunctionType(Enum):
FALLBACK = 2
FALLBACK = 2
RECEIVE = 3
RECEIVE = 3
CONSTRUCTOR_VARIABLES = 10 # Fake function to hold variable declaration statements
CONSTRUCTOR_VARIABLES = 10 # Fake function to hold variable declaration statements
CONSTRUCTOR_CONSTANT_VARIABLES = 11 # Fake function to hold variable declaration statements
CONSTRUCTOR_CONSTANT_VARIABLES = (
11 # Fake function to hold variable declaration statements
)
class Function ( ChildContract , ChildInheritance , SourceMapping ) :
def _filter_state_variables_written ( expressions : List [ " Expression " ] ) :
ret = [ ]
for expression in expressions :
if isinstance ( expression , Identifier ) :
ret . append ( expression )
if isinstance ( expression , UnaryOperation ) :
ret . append ( expression . expression )
if isinstance ( expression , MemberAccess ) :
ret . append ( expression . expression )
if isinstance ( expression , IndexAccess ) :
ret . append ( expression . expression_left )
return ret
class Function ( ChildContract , ChildInheritance , SourceMapping ) : # pylint: disable=too-many-public-methods
"""
"""
Function class
Function class
"""
"""
@ -147,13 +173,21 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
self . _all_state_variables_written : Optional [ List [ " StateVariable " ] ] = None
self . _all_state_variables_written : Optional [ List [ " StateVariable " ] ] = None
self . _all_slithir_variables : Optional [ List [ " SlithIRVariable " ] ] = None
self . _all_slithir_variables : Optional [ List [ " SlithIRVariable " ] ] = None
self . _all_nodes : Optional [ List [ " Node " ] ] = None
self . _all_nodes : Optional [ List [ " Node " ] ] = None
self . _all_conditional_state_variables_read : Optional [ List [ " StateVariable " ] ] = None
self . _all_conditional_state_variables_read : Optional [
self . _all_conditional_state_variables_read_with_loop : Optional [ List [ " StateVariable " ] ] = None
List [ " StateVariable " ]
self . _all_conditional_solidity_variables_read : Optional [ List [ " SolidityVariable " ] ] = None
] = None
self . _all_conditional_state_variables_read_with_loop : Optional [
List [ " StateVariable " ]
] = None
self . _all_conditional_solidity_variables_read : Optional [
List [ " SolidityVariable " ]
] = None
self . _all_conditional_solidity_variables_read_with_loop : Optional [
self . _all_conditional_solidity_variables_read_with_loop : Optional [
List [ " SolidityVariable " ]
List [ " SolidityVariable " ]
] = None
] = None
self . _all_solidity_variables_used_as_args : Optional [ List [ " SolidityVariable " ] ] = None
self . _all_solidity_variables_used_as_args : Optional [
List [ " SolidityVariable " ]
] = None
self . _is_shadowed : bool = False
self . _is_shadowed : bool = False
self . _shadows : bool = False
self . _shadows : bool = False
@ -187,13 +221,13 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
"""
"""
if self . _name == " " and self . _function_type == FunctionType . CONSTRUCTOR :
if self . _name == " " and self . _function_type == FunctionType . CONSTRUCTOR :
return " constructor "
return " constructor "
el if self . _function_type == FunctionType . FALLBACK :
if self . _function_type == FunctionType . FALLBACK :
return " fallback "
return " fallback "
el if self . _function_type == FunctionType . RECEIVE :
if self . _function_type == FunctionType . RECEIVE :
return " receive "
return " receive "
el if self . _function_type == FunctionType . CONSTRUCTOR_VARIABLES :
if self . _function_type == FunctionType . CONSTRUCTOR_VARIABLES :
return " slitherConstructorVariables "
return " slitherConstructorVariables "
el if self . _function_type == FunctionType . CONSTRUCTOR_CONSTANT_VARIABLES :
if self . _function_type == FunctionType . CONSTRUCTOR_CONSTANT_VARIABLES :
return " slitherConstructorConstantVariables "
return " slitherConstructorConstantVariables "
return self . _name
return self . _name
@ -815,15 +849,13 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
if self . _return_values is None :
if self . _return_values is None :
return_values = list ( )
return_values = list ( )
returns = [ n for n in self . nodes if n . type == NodeType . RETURN ]
returns = [ n for n in self . nodes if n . type == NodeType . RETURN ]
[
[ # pylint: disable=expression-not-assigned
return_values . extend ( ir . values )
return_values . extend ( ir . values )
for node in returns
for node in returns
for ir in node . irs
for ir in node . irs
if isinstance ( ir , Return )
if isinstance ( ir , Return )
]
]
self . _return_values = list (
self . _return_values = list ( { x for x in return_values if not isinstance ( x , Constant ) } )
set ( [ x for x in return_values if not isinstance ( x , Constant ) ] )
)
return self . _return_values
return self . _return_values
@property
@property
@ -838,15 +870,13 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
if self . _return_values_ssa is None :
if self . _return_values_ssa is None :
return_values_ssa = list ( )
return_values_ssa = list ( )
returns = [ n for n in self . nodes if n . type == NodeType . RETURN ]
returns = [ n for n in self . nodes if n . type == NodeType . RETURN ]
[
[ # pylint: disable=expression-not-assigned
return_values_ssa . extend ( ir . values )
return_values_ssa . extend ( ir . values )
for node in returns
for node in returns
for ir in node . irs_ssa
for ir in node . irs_ssa
if isinstance ( ir , Return )
if isinstance ( ir , Return )
]
]
self . _return_values_ssa = list (
self . _return_values_ssa = list ( { x for x in return_values_ssa if not isinstance ( x , Constant ) } )
set ( [ x for x in return_values_ssa if not isinstance ( x , Constant ) ] )
)
return self . _return_values_ssa
return self . _return_values_ssa
# endregion
# endregion
@ -900,7 +930,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
Contract and converted into address
Contract and converted into address
: return : the solidity signature
: return : the solidity signature
"""
"""
parameters = [ self . _convert_type_for_solidity_signature ( x . type ) for x in self . parameters ]
parameters = [
self . _convert_type_for_solidity_signature ( x . type ) for x in self . parameters
]
return self . name + " ( " + " , " . join ( parameters ) + " ) "
return self . name + " ( " + " , " . join ( parameters ) + " ) "
@property
@property
@ -922,7 +954,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
Return the function signature as a str ( contains the return values )
Return the function signature as a str ( contains the return values )
"""
"""
name , parameters , returnVars = self . signature
name , parameters , returnVars = self . signature
return name + " ( " + " , " . join ( parameters ) + " ) returns( " + " , " . join ( returnVars ) + " ) "
return (
name
+ " ( "
+ " , " . join ( parameters )
+ " ) returns( "
+ " , " . join ( returnVars )
+ " ) "
)
# endregion
# endregion
###################################################################################
###################################################################################
@ -977,10 +1016,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
values = f_new_values ( self )
values = f_new_values ( self )
explored = [ self ]
explored = [ self ]
to_explore = [
to_explore = [
c for c in self . internal_calls if isinstance ( c , Function ) and c not in explored
c
for c in self . internal_calls
if isinstance ( c , Function ) and c not in explored
]
]
to_explore + = [
to_explore + = [
c for ( _ , c ) in self . library_calls if isinstance ( c , Function ) and c not in explored
c
for ( _ , c ) in self . library_calls
if isinstance ( c , Function ) and c not in explored
]
]
to_explore + = [ m for m in self . modifiers if m not in explored ]
to_explore + = [ m for m in self . modifiers if m not in explored ]
@ -1003,7 +1046,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
for ( _ , c ) in f . library_calls
for ( _ , c ) in f . library_calls
if isinstance ( c , Function ) and c not in explored and c not in to_explore
if isinstance ( c , Function ) and c not in explored and c not in to_explore
]
]
to_explore + = [ m for m in f . modifiers if m not in explored and m not in to_explore ]
to_explore + = [
m for m in f . modifiers if m not in explored and m not in to_explore
]
return list ( set ( values ) )
return list ( set ( values ) )
@ -1029,7 +1074,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
""" recursive version of slithir_variables
""" recursive version of slithir_variables
"""
"""
if self . _all_slithir_variables is None :
if self . _all_slithir_variables is None :
self . _all_slithir_variables = self . _explore_functions ( lambda x : x . slithir_variables )
self . _all_slithir_variables = self . _explore_functions (
lambda x : x . slithir_variables
)
return self . _all_slithir_variables
return self . _all_slithir_variables
def all_nodes ( self ) - > List [ " Node " ] :
def all_nodes ( self ) - > List [ " Node " ] :
@ -1047,10 +1094,10 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
return self . _all_expressions
return self . _all_expressions
def all_slithir_operations ( self ) - > List [ " Operation " ] :
def all_slithir_operations ( self ) - > List [ " Operation " ] :
"""
"""
if self . _all_slithir_operations is None :
if self . _all_slithir_operations is None :
self . _all_slithir_operations = self . _explore_functions ( lambda x : x . slithir_operations )
self . _all_slithir_operations = self . _explore_functions (
lambda x : x . slithir_operations
)
return self . _all_slithir_operations
return self . _all_slithir_operations
def all_state_variables_written ( self ) - > List [ StateVariable ] :
def all_state_variables_written ( self ) - > List [ StateVariable ] :
@ -1066,21 +1113,27 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
""" recursive version of internal_calls
""" recursive version of internal_calls
"""
"""
if self . _all_internals_calls is None :
if self . _all_internals_calls is None :
self . _all_internals_calls = self . _explore_functions ( lambda x : x . internal_calls )
self . _all_internals_calls = self . _explore_functions (
lambda x : x . internal_calls
)
return self . _all_internals_calls
return self . _all_internals_calls
def all_low_level_calls ( self ) - > List [ " LowLevelCallType " ] :
def all_low_level_calls ( self ) - > List [ " LowLevelCallType " ] :
""" recursive version of low_level calls
""" recursive version of low_level calls
"""
"""
if self . _all_low_level_calls is None :
if self . _all_low_level_calls is None :
self . _all_low_level_calls = self . _explore_functions ( lambda x : x . low_level_calls )
self . _all_low_level_calls = self . _explore_functions (
lambda x : x . low_level_calls
)
return self . _all_low_level_calls
return self . _all_low_level_calls
def all_high_level_calls ( self ) - > List [ " HighLevelCallType " ] :
def all_high_level_calls ( self ) - > List [ " HighLevelCallType " ] :
""" recursive version of high_level calls
""" recursive version of high_level calls
"""
"""
if self . _all_high_level_calls is None :
if self . _all_high_level_calls is None :
self . _all_high_level_calls = self . _explore_functions ( lambda x : x . high_level_calls )
self . _all_high_level_calls = self . _explore_functions (
lambda x : x . high_level_calls
)
return self . _all_high_level_calls
return self . _all_high_level_calls
def all_library_calls ( self ) - > List [ " LibraryCallType " ] :
def all_library_calls ( self ) - > List [ " LibraryCallType " ] :
@ -1094,15 +1147,23 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
""" recursive version of solidity calls
""" recursive version of solidity calls
"""
"""
if self . _all_solidity_calls is None :
if self . _all_solidity_calls is None :
self . _all_solidity_calls = self . _explore_functions ( lambda x : x . solidity_calls )
self . _all_solidity_calls = self . _explore_functions (
lambda x : x . solidity_calls
)
return self . _all_solidity_calls
return self . _all_solidity_calls
@staticmethod
@staticmethod
def _explore_func_cond_read ( func : " Function " , include_loop : bool ) - > List [ " StateVariable " ] :
def _explore_func_cond_read (
ret = [ n . state_variables_read for n in func . nodes if n . is_conditional ( include_loop ) ]
func : " Function " , include_loop : bool
) - > List [ " StateVariable " ] :
ret = [
n . state_variables_read for n in func . nodes if n . is_conditional ( include_loop )
]
return [ item for sublist in ret for item in sublist ]
return [ item for sublist in ret for item in sublist ]
def all_conditional_state_variables_read ( self , include_loop = True ) - > List [ " StateVariable " ] :
def all_conditional_state_variables_read (
self , include_loop = True
) - > List [ " StateVariable " ] :
"""
"""
Return the state variable used in a condition
Return the state variable used in a condition
@ -1133,12 +1194,16 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
@staticmethod
@staticmethod
def _explore_func_conditional (
def _explore_func_conditional (
func : " Function " , f : Callable [ [ " Node " ] , List [ SolidityVariable ] ] , include_loop : bool
func : " Function " ,
f : Callable [ [ " Node " ] , List [ SolidityVariable ] ] ,
include_loop : bool ,
) :
) :
ret = [ f ( n ) for n in func . nodes if n . is_conditional ( include_loop ) ]
ret = [ f ( n ) for n in func . nodes if n . is_conditional ( include_loop ) ]
return [ item for sublist in ret for item in sublist ]
return [ item for sublist in ret for item in sublist ]
def all_conditional_solidity_variables_read ( self , include_loop = True ) - > List [ SolidityVariable ] :
def all_conditional_solidity_variables_read (
self , include_loop = True
) - > List [ SolidityVariable ] :
"""
"""
Return the Soldiity variables directly used in a condtion
Return the Soldiity variables directly used in a condtion
@ -1174,7 +1239,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
return [ var for var in ret if isinstance ( var , SolidityVariable ) ]
return [ var for var in ret if isinstance ( var , SolidityVariable ) ]
@staticmethod
@staticmethod
def _explore_func_nodes ( func : " Function " , f : Callable [ [ " Node " ] , List [ SolidityVariable ] ] ) :
def _explore_func_nodes (
func : " Function " , f : Callable [ [ " Node " ] , List [ SolidityVariable ] ]
) :
ret = [ f ( n ) for n in func . nodes ]
ret = [ f ( n ) for n in func . nodes ]
return [ item for sublist in ret for item in sublist ]
return [ item for sublist in ret for item in sublist ]
@ -1187,7 +1254,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
"""
"""
if self . _all_solidity_variables_used_as_args is None :
if self . _all_solidity_variables_used_as_args is None :
self . _all_solidity_variables_used_as_args = self . _explore_functions (
self . _all_solidity_variables_used_as_args = self . _explore_functions (
lambda x : self . _explore_func_nodes ( x , self . _solidity_variable_in_internal_calls )
lambda x : self . _explore_func_nodes (
x , self . _solidity_variable_in_internal_calls
)
)
)
return self . _all_solidity_variables_used_as_args
return self . _all_solidity_variables_used_as_args
@ -1217,7 +1286,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
###################################################################################
###################################################################################
###################################################################################
###################################################################################
def get_local_variable_from_name ( self , variable_name : str ) - > Optional [ LocalVariable ] :
def get_local_variable_from_name (
self , variable_name : str
) - > Optional [ LocalVariable ] :
"""
"""
Return a local variable from a name
Return a local variable from a name
@ -1271,7 +1342,11 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
for node in self . nodes :
for node in self . nodes :
f . write ( ' {} [label= " {} " ]; \n ' . format ( node . node_id , description ( node ) ) )
f . write ( ' {} [label= " {} " ]; \n ' . format ( node . node_id , description ( node ) ) )
if node . immediate_dominator :
if node . immediate_dominator :
f . write ( " {} -> {} ; \n " . format ( node . immediate_dominator . node_id , node . node_id ) )
f . write (
" {} -> {} ; \n " . format (
node . immediate_dominator . node_id , node . node_id
)
)
f . write ( " } \n " )
f . write ( " } \n " )
@ -1305,10 +1380,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
if node . type in [ NodeType . IF , NodeType . IFLOOP ] :
if node . type in [ NodeType . IF , NodeType . IFLOOP ] :
true_node = node . son_true
true_node = node . son_true
if true_node :
if true_node :
content + = ' {} -> {} [label= " True " ]; \n ' . format ( node . node_id , true_node . node_id )
content + = ' {} -> {} [label= " True " ]; \n ' . format (
node . node_id , true_node . node_id
)
false_node = node . son_false
false_node = node . son_false
if false_node :
if false_node :
content + = ' {} -> {} [label= " False " ]; \n ' . format ( node . node_id , false_node . node_id )
content + = ' {} -> {} [label= " False " ]; \n ' . format (
node . node_id , false_node . node_id
)
else :
else :
for son in node . sons :
for son in node . sons :
content + = " {} -> {} ; \n " . format ( node . node_id , son . node_id )
content + = " {} -> {} ; \n " . format ( node . node_id , son . node_id )
@ -1353,7 +1432,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
Returns :
Returns :
bool : True if the variable is read
bool : True if the variable is read
"""
"""
variables_reads = [ n . variables_read for n in self . nodes if n . contains_require_or_assert ( ) ]
variables_reads = [
n . variables_read for n in self . nodes if n . contains_require_or_assert ( )
]
variables_read = [ item for sublist in variables_reads for item in sublist ]
variables_read = [ item for sublist in variables_reads for item in sublist ]
return variable in variables_read
return variable in variables_read
@ -1401,7 +1482,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
if self . is_constructor :
if self . is_constructor :
return True
return True
conditional_vars = self . all_conditional_solidity_variables_read ( include_loop = False )
conditional_vars = self . all_conditional_solidity_variables_read (
include_loop = False
)
args_vars = self . all_solidity_variables_used_as_args ( )
args_vars = self . all_solidity_variables_used_as_args ( )
return SolidityVariableComposed ( " msg.sender " ) in conditional_vars + args_vars
return SolidityVariableComposed ( " msg.sender " ) in conditional_vars + args_vars
@ -1412,19 +1495,6 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
###################################################################################
###################################################################################
###################################################################################
###################################################################################
def _filter_state_variables_written ( self , expressions : List [ " Expression " ] ) :
ret = [ ]
for expression in expressions :
if isinstance ( expression , Identifier ) :
ret . append ( expression )
if isinstance ( expression , UnaryOperation ) :
ret . append ( expression . expression )
if isinstance ( expression , MemberAccess ) :
ret . append ( expression . expression )
if isinstance ( expression , IndexAccess ) :
ret . append ( expression . expression_left )
return ret
def _analyze_read_write ( self ) :
def _analyze_read_write ( self ) :
""" Compute variables read/written/...
""" Compute variables read/written/...
@ -1436,7 +1506,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
# Remove dupplicate if they share the same string representation
# Remove dupplicate if they share the same string representation
write_var = [
write_var = [
next ( obj )
next ( obj )
for i , obj in groupby ( sorted ( write_var , key = lambda x : str ( x ) ) , lambda x : str ( x ) )
for i , obj in groupby (
sorted ( write_var , key = lambda x : str ( x ) ) , lambda x : str ( x )
)
]
]
self . _expression_vars_written = write_var
self . _expression_vars_written = write_var
@ -1447,7 +1519,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
# Remove dupplicate if they share the same string representation
# Remove dupplicate if they share the same string representation
write_var = [
write_var = [
next ( obj )
next ( obj )
for i , obj in groupby ( sorted ( write_var , key = lambda x : str ( x ) ) , lambda x : str ( x ) )
for i , obj in groupby (
sorted ( write_var , key = lambda x : str ( x ) ) , lambda x : str ( x )
)
]
]
self . _vars_written = write_var
self . _vars_written = write_var
@ -1457,7 +1531,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
# Remove dupplicate if they share the same string representation
# Remove dupplicate if they share the same string representation
read_var = [
read_var = [
next ( obj )
next ( obj )
for i , obj in groupby ( sorted ( read_var , key = lambda x : str ( x ) ) , lambda x : str ( x ) )
for i , obj in groupby (
sorted ( read_var , key = lambda x : str ( x ) ) , lambda x : str ( x )
)
]
]
self . _expression_vars_read = read_var
self . _expression_vars_read = read_var
@ -1467,14 +1543,18 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
# Remove dupplicate if they share the same string representation
# Remove dupplicate if they share the same string representation
read_var = [
read_var = [
next ( obj )
next ( obj )
for i , obj in groupby ( sorted ( read_var , key = lambda x : str ( x ) ) , lambda x : str ( x ) )
for i , obj in groupby (
sorted ( read_var , key = lambda x : str ( x ) ) , lambda x : str ( x )
)
]
]
self . _vars_read = read_var
self . _vars_read = read_var
self . _state_vars_written = [
self . _state_vars_written = [
x for x in self . variables_written if isinstance ( x , StateVariable )
x for x in self . variables_written if isinstance ( x , StateVariable )
]
]
self . _state_vars_read = [ x for x in self . variables_read if isinstance ( x , StateVariable ) ]
self . _state_vars_read = [
x for x in self . variables_read if isinstance ( x , StateVariable )
]
self . _solidity_vars_read = [
self . _solidity_vars_read = [
x for x in self . variables_read if isinstance ( x , SolidityVariable )
x for x in self . variables_read if isinstance ( x , SolidityVariable )
]
]
@ -1483,7 +1563,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
slithir_variables = [ x . slithir_variables for x in self . nodes ]
slithir_variables = [ x . slithir_variables for x in self . nodes ]
slithir_variables = [ x for x in slithir_variables if x ]
slithir_variables = [ x for x in slithir_variables if x ]
self . _slithir_variables = [ item for sublist in slithir_variables for item in sublist ]
self . _slithir_variables = [
item for sublist in slithir_variables for item in sublist
]
def _analyze_calls ( self ) :
def _analyze_calls ( self ) :
calls = [ x . calls_as_expression for x in self . nodes ]
calls = [ x . calls_as_expression for x in self . nodes ]
@ -1496,7 +1578,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
internal_calls = [ item for sublist in internal_calls for item in sublist ]
internal_calls = [ item for sublist in internal_calls for item in sublist ]
self . _internal_calls = list ( set ( internal_calls ) )
self . _internal_calls = list ( set ( internal_calls ) )
self . _solidity_calls = [ c for c in internal_calls if isinstance ( c , SolidityFunction ) ]
self . _solidity_calls = [
c for c in internal_calls if isinstance ( c , SolidityFunction )
]
low_level_calls = [ x . low_level_calls for x in self . nodes ]
low_level_calls = [ x . low_level_calls for x in self . nodes ]
low_level_calls = [ x for x in low_level_calls if x ]
low_level_calls = [ x for x in low_level_calls if x ]
@ -1513,7 +1597,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
library_calls = [ item for sublist in library_calls for item in sublist ]
library_calls = [ item for sublist in library_calls for item in sublist ]
self . _library_calls = list ( set ( library_calls ) )
self . _library_calls = list ( set ( library_calls ) )
external_calls_as_expressions = [ x . external_calls_as_expressions for x in self . nodes ]
external_calls_as_expressions = [
x . external_calls_as_expressions for x in self . nodes
]
external_calls_as_expressions = [ x for x in external_calls_as_expressions if x ]
external_calls_as_expressions = [ x for x in external_calls_as_expressions if x ]
external_calls_as_expressions = [
external_calls_as_expressions = [
item for sublist in external_calls_as_expressions for item in sublist
item for sublist in external_calls_as_expressions for item in sublist
@ -1548,6 +1634,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
def _get_last_ssa_variable_instances (
def _get_last_ssa_variable_instances (
self , target_state : bool , target_local : bool
self , target_state : bool , target_local : bool
) - > Dict [ str , Set [ " SlithIRVariable " ] ] :
) - > Dict [ str , Set [ " SlithIRVariable " ] ] :
# pylint: disable=too-many-locals,too-many-branches
from slither . slithir . variables import ReferenceVariable
from slither . slithir . variables import ReferenceVariable
from slither . slithir . operations import OperationWithLValue
from slither . slithir . operations import OperationWithLValue
from slither . core . cfg . node import NodeType
from slither . core . cfg . node import NodeType
@ -1603,11 +1690,19 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
return ret
return ret
def get_last_ssa_state_variables_instances ( self ) - > Dict [ str , Set [ " SlithIRVariable " ] ] :
def get_last_ssa_state_variables_instances (
return self . _get_last_ssa_variable_instances ( target_state = True , target_local = False )
self ,
) - > Dict [ str , Set [ " SlithIRVariable " ] ] :
return self . _get_last_ssa_variable_instances (
target_state = True , target_local = False
)
def get_last_ssa_local_variables_instances ( self ) - > Dict [ str , Set [ " SlithIRVariable " ] ] :
def get_last_ssa_local_variables_instances (
return self . _get_last_ssa_variable_instances ( target_state = False , target_local = True )
self ,
) - > Dict [ str , Set [ " SlithIRVariable " ] ] :
return self . _get_last_ssa_variable_instances (
target_state = False , target_local = True
)
@staticmethod
@staticmethod
def _unchange_phi ( ir : " Operation " ) :
def _unchange_phi ( ir : " Operation " ) :
@ -1619,7 +1714,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
return True
return True
return ir . rvalues [ 0 ] == ir . lvalue
return ir . rvalues [ 0 ] == ir . lvalue
def fix_phi ( self , last_state_variables_instances , initial_state_variables_instances ) :
def fix_phi (
self , last_state_variables_instances , initial_state_variables_instances
) :
from slither . slithir . operations import InternalCall , PhiCallback
from slither . slithir . operations import InternalCall , PhiCallback
from slither . slithir . variables import Constant , StateIRVariable
from slither . slithir . variables import Constant , StateIRVariable
@ -1627,28 +1724,40 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
for ir in node . irs_ssa :
for ir in node . irs_ssa :
if node == self . entry_point :
if node == self . entry_point :
if isinstance ( ir . lvalue , StateIRVariable ) :
if isinstance ( ir . lvalue , StateIRVariable ) :
additional = [ initial_state_variables_instances [ ir . lvalue . canonical_name ] ]
additional = [
additional + = last_state_variables_instances [ ir . lvalue . canonical_name ]
initial_state_variables_instances [ ir . lvalue . canonical_name ]
]
additional + = last_state_variables_instances [
ir . lvalue . canonical_name
]
ir . rvalues = list ( set ( additional + ir . rvalues ) )
ir . rvalues = list ( set ( additional + ir . rvalues ) )
# function parameter
# function parameter
else :
else :
# find index of the parameter
# find index of the parameter
idx = self . parameters . index ( ir . lvalue . non_ssa_version )
idx = self . parameters . index ( ir . lvalue . non_ssa_version )
# find non ssa version of that index
# find non ssa version of that index
additional = [ n . ir . arguments [ idx ] for n in self . reachable_from_nodes ]
additional = [
n . ir . arguments [ idx ] for n in self . reachable_from_nodes
]
additional = unroll ( additional )
additional = unroll ( additional )
additional = [ a for a in additional if not isinstance ( a , Constant ) ]
additional = [
a for a in additional if not isinstance ( a , Constant )
]
ir . rvalues = list ( set ( additional + ir . rvalues ) )
ir . rvalues = list ( set ( additional + ir . rvalues ) )
if isinstance ( ir , PhiCallback ) :
if isinstance ( ir , PhiCallback ) :
callee_ir = ir . callee_ir
callee_ir = ir . callee_ir
if isinstance ( callee_ir , InternalCall ) :
if isinstance ( callee_ir , InternalCall ) :
last_ssa = callee_ir . function . get_last_ssa_state_variables_instances ( )
last_ssa = (
callee_ir . function . get_last_ssa_state_variables_instances ( )
)
if ir . lvalue . canonical_name in last_ssa :
if ir . lvalue . canonical_name in last_ssa :
ir . rvalues = list ( last_ssa [ ir . lvalue . canonical_name ] )
ir . rvalues = list ( last_ssa [ ir . lvalue . canonical_name ] )
else :
else :
ir . rvalues = [ ir . lvalue ]
ir . rvalues = [ ir . lvalue ]
else :
else :
additional = last_state_variables_instances [ ir . lvalue . canonical_name ]
additional = last_state_variables_instances [
ir . lvalue . canonical_name
]
ir . rvalues = list ( set ( additional + ir . rvalues ) )
ir . rvalues = list ( set ( additional + ir . rvalues ) )
node . irs_ssa = [ ir for ir in node . irs_ssa if not self . _unchange_phi ( ir ) ]
node . irs_ssa = [ ir for ir in node . irs_ssa if not self . _unchange_phi ( ir ) ]
@ -1662,7 +1771,10 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
def generate_slithir_ssa ( self , all_ssa_state_variables_instances ) :
def generate_slithir_ssa ( self , all_ssa_state_variables_instances ) :
from slither . slithir . utils . ssa import add_ssa_ir , transform_slithir_vars_to_ssa
from slither . slithir . utils . ssa import add_ssa_ir , transform_slithir_vars_to_ssa
from slither . core . dominators . utils import compute_dominance_frontier , compute_dominators
from slither . core . dominators . utils import (
compute_dominance_frontier ,
compute_dominators ,
)
compute_dominators ( self . nodes )
compute_dominators ( self . nodes )
compute_dominance_frontier ( self . nodes )
compute_dominance_frontier ( self . nodes )