@ -1,17 +1,21 @@
import logging
import logging
import re
import re
import uuid
from collections import namedtuple
from collections import namedtuple
from enum import Enum as PythonEnum
from enum import Enum as PythonEnum
from pathlib import Path
from pathlib import Path
from typing import List , Set , Dict , Optional
from typing import List , Set , Dict , Optional
from slither . core . compilation_unit import SlitherCompilationUnit
from slither . core . declarations import SolidityFunction , EnumContract , StructureContract
from slither . core . declarations import SolidityFunction , EnumContract , StructureContract
from slither . core . declarations . contract import Contract
from slither . core . declarations . contract import Contract
from slither . core . slither_core import SlitherCore
from slither . core . declarations . function_top_level import FunctionTopLevel
from slither . core . declarations . top_level import TopLevel
from slither . core . solidity_types import MappingType , ArrayType
from slither . core . solidity_types import MappingType , ArrayType
from slither . core . solidity_types . type import Type
from slither . core . solidity_types . user_defined_type import UserDefinedType
from slither . core . solidity_types . user_defined_type import UserDefinedType
from slither . exceptions import SlitherException
from slither . exceptions import SlitherException
from slither . slithir . operations import NewContract , TypeConversion , SolidityCall
from slither . slithir . operations import NewContract , TypeConversion , SolidityCall , InternalCall
from slither . tools . flattening . export . export import (
from slither . tools . flattening . export . export import (
Export ,
Export ,
export_as_json ,
export_as_json ,
@ -44,7 +48,7 @@ class Flattening:
# pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-locals,too-few-public-methods
# pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-locals,too-few-public-methods
def __init__ (
def __init__ (
self ,
self ,
slither : SlitherCore ,
compilation_unit : SlitherCompilationUnit ,
external_to_public = False ,
external_to_public = False ,
remove_assert = False ,
remove_assert = False ,
private_to_internal = False ,
private_to_internal = False ,
@ -52,7 +56,8 @@ class Flattening:
pragma_solidity : Optional [ str ] = None ,
pragma_solidity : Optional [ str ] = None ,
) :
) :
self . _source_codes : Dict [ Contract , str ] = { }
self . _source_codes : Dict [ Contract , str ] = { }
self . _slither : SlitherCore = slither
self . _source_codes_top_level : Dict [ TopLevel , str ] = { }
self . _compilation_unit : SlitherCompilationUnit = compilation_unit
self . _external_to_public = external_to_public
self . _external_to_public = external_to_public
self . _remove_assert = remove_assert
self . _remove_assert = remove_assert
self . _use_abi_encoder_v2 = False
self . _use_abi_encoder_v2 = False
@ -63,20 +68,32 @@ class Flattening:
self . _check_abi_encoder_v2 ( )
self . _check_abi_encoder_v2 ( )
for contract in slither . contracts :
for contract in compilation_unit . contracts :
self . _get_source_code ( contract )
self . _get_source_code ( contract )
self . _get_source_code_top_level ( compilation_unit . structures_top_level )
self . _get_source_code_top_level ( compilation_unit . enums_top_level )
self . _get_source_code_top_level ( compilation_unit . variables_top_level )
self . _get_source_code_top_level ( compilation_unit . functions_top_level )
def _get_source_code_top_level ( self , elems : List [ TopLevel ] ) - > None :
for elem in elems :
src_mapping = elem . source_mapping
content = self . _compilation_unit . core . source_code [ src_mapping [ " filename_absolute " ] ]
start = src_mapping [ " start " ]
end = src_mapping [ " start " ] + src_mapping [ " length " ]
self . _source_codes_top_level [ elem ] = content [ start : end ]
def _check_abi_encoder_v2 ( self ) :
def _check_abi_encoder_v2 ( self ) :
"""
"""
Check if ABIEncoderV2 is required
Check if ABIEncoderV2 is required
Set _use_abi_encorder_v2
Set _use_abi_encorder_v2
: return :
: return :
"""
"""
for compilation_unit in self . _slither . compilation_units :
for p in self . _compilation_unit . pragma_directives :
for p in compilation_unit . pragma_directives :
if " ABIEncoderV2 " in str ( p . directive ) :
if " ABIEncoderV2 " in str ( p . directive ) :
self . _use_abi_encoder_v2 = True
self . _use_abi_encoder_v2 = True
return
return
def _get_source_code (
def _get_source_code (
self , contract : Contract
self , contract : Contract
@ -88,7 +105,7 @@ class Flattening:
: return :
: return :
"""
"""
src_mapping = contract . source_mapping
src_mapping = contract . source_mapping
content = self . _slither . source_code [ src_mapping [ " filename_absolute " ] ]
content = self . _compilation_unit . core . source_code [ src_mapping [ " filename_absolute " ] ]
start = src_mapping [ " start " ]
start = src_mapping [ " start " ]
end = src_mapping [ " start " ] + src_mapping [ " length " ]
end = src_mapping [ " start " ] + src_mapping [ " length " ]
@ -132,11 +149,9 @@ class Flattening:
if self . _private_to_internal :
if self . _private_to_internal :
for variable in contract . state_variables_declared :
for variable in contract . state_variables_declared :
if variable . visibility == " private " :
if variable . visibility == " private " :
print ( variable . source_mapping )
attributes_start = variable . source_mapping [ " start " ]
attributes_start = variable . source_mapping [ " start " ]
attributes_end = attributes_start + variable . source_mapping [ " length " ]
attributes_end = attributes_start + variable . source_mapping [ " length " ]
attributes = content [ attributes_start : attributes_end ]
attributes = content [ attributes_start : attributes_end ]
print ( attributes )
regex = re . search ( r " private " , attributes )
regex = re . search ( r " private " , attributes )
if regex :
if regex :
to_patch . append (
to_patch . append (
@ -191,35 +206,54 @@ class Flattening:
ret + = f " pragma solidity { self . _pragma_solidity } ; \n "
ret + = f " pragma solidity { self . _pragma_solidity } ; \n "
else :
else :
# TODO support multiple compiler version
# TODO support multiple compiler version
ret + = f " pragma solidity { list ( self . _slither . crytic_compile . compilation_units . values ( ) ) [ 0 ] . compiler_version . version } ; \n "
ret + = f " pragma solidity { list ( self . _compilation_unit . crytic_compile . compilation_units . values ( ) ) [ 0 ] . compiler_version . version } ; \n "
if self . _use_abi_encoder_v2 :
if self . _use_abi_encoder_v2 :
ret + = " pragma experimental ABIEncoderV2; \n "
ret + = " pragma experimental ABIEncoderV2; \n "
return ret
return ret
def _export_from_type ( self , t , contract , exported , list_contract ) :
def _export_from_type (
self ,
t : Type ,
contract : Contract ,
exported : Set [ str ] ,
list_contract : List [ Contract ] ,
list_top_level : List [ TopLevel ] ,
) :
if isinstance ( t , UserDefinedType ) :
if isinstance ( t , UserDefinedType ) :
if isinstance ( t . type , ( EnumContract , StructureContract ) ) :
t_type = t . type
if t . type . contract != contract and t . type . contract not in exported :
if isinstance ( t_type , ( EnumContract , StructureContract ) ) :
self . _export_list_used_contracts ( t . type . contract , exported , list_contract )
if t_type . contract != contract and t_type . contract not in exported :
self . _export_list_used_contracts (
t_type . contract , exported , list_contract , list_top_level
)
else :
else :
assert isinstance ( t . type , Contract )
assert isinstance ( t . type , Contract )
if t . type != contract and t . type not in exported :
if t . type != contract and t . type not in exported :
self . _export_list_used_contracts ( t . type , exported , list_contract )
self . _export_list_used_contracts (
t . type , exported , list_contract , list_top_level
)
elif isinstance ( t , MappingType ) :
elif isinstance ( t , MappingType ) :
self . _export_from_type ( t . type_from , contract , exported , list_contract )
self . _export_from_type ( t . type_from , contract , exported , list_contract , list_top_level )
self . _export_from_type ( t . type_to , contract , exported , list_contract )
self . _export_from_type ( t . type_to , contract , exported , list_contract , list_top_level )
elif isinstance ( t , ArrayType ) :
elif isinstance ( t , ArrayType ) :
self . _export_from_type ( t . type , contract , exported , list_contract )
self . _export_from_type ( t . type , contract , exported , list_contract , list_top_level )
def _export_list_used_contracts ( # pylint: disable=too-many-branches
def _export_list_used_contracts ( # pylint: disable=too-many-branches
self , contract : Contract , exported : Set [ str ] , list_contract : List [ Contract ]
self ,
contract : Contract ,
exported : Set [ str ] ,
list_contract : List [ Contract ] ,
list_top_level : List [ TopLevel ] ,
) :
) :
# TODO: investigate why this happen
if not isinstance ( contract , Contract ) :
return
if contract . name in exported :
if contract . name in exported :
return
return
exported . add ( contract . name )
exported . add ( contract . name )
for inherited in contract . inheritance :
for inherited in contract . inheritance :
self . _export_list_used_contracts ( inherited , exported , list_contract )
self . _export_list_used_contracts ( inherited , exported , list_contract , list_top_level )
# Find all the external contracts called
# Find all the external contracts called
externals = contract . all_library_calls + contract . all_high_level_calls
externals = contract . all_library_calls + contract . all_high_level_calls
@ -228,7 +262,16 @@ class Flattening:
externals = list ( { e [ 0 ] for e in externals if e [ 0 ] != contract } )
externals = list ( { e [ 0 ] for e in externals if e [ 0 ] != contract } )
for inherited in externals :
for inherited in externals :
self . _export_list_used_contracts ( inherited , exported , list_contract )
self . _export_list_used_contracts ( inherited , exported , list_contract , list_top_level )
for list_libs in contract . using_for . values ( ) :
for lib_candidate_type in list_libs :
if isinstance ( lib_candidate_type , UserDefinedType ) :
lib_candidate = lib_candidate_type . type
if isinstance ( lib_candidate , Contract ) :
self . _export_list_used_contracts (
lib_candidate , exported , list_contract , list_top_level
)
# Find all the external contracts use as a base type
# Find all the external contracts use as a base type
local_vars = [ ]
local_vars = [ ]
@ -236,11 +279,11 @@ class Flattening:
local_vars + = f . variables
local_vars + = f . variables
for v in contract . variables + local_vars :
for v in contract . variables + local_vars :
self . _export_from_type ( v . type , contract , exported , list_contract )
self . _export_from_type ( v . type , contract , exported , list_contract , list_top_level )
for s in contract . structures :
for s in contract . structures :
for elem in s . elems . values ( ) :
for elem in s . elems . values ( ) :
self . _export_from_type ( elem . type , contract , exported , list_contract )
self . _export_from_type ( elem . type , contract , exported , list_contract , list_top_level )
# Find all convert and "new" operation that can lead to use an external contract
# Find all convert and "new" operation that can lead to use an external contract
for f in contract . functions_declared :
for f in contract . functions_declared :
@ -248,21 +291,38 @@ class Flattening:
if isinstance ( ir , NewContract ) :
if isinstance ( ir , NewContract ) :
if ir . contract_created != contract and not ir . contract_created in exported :
if ir . contract_created != contract and not ir . contract_created in exported :
self . _export_list_used_contracts (
self . _export_list_used_contracts (
ir . contract_created , exported , list_contract
ir . contract_created , exported , list_contract , list_top_level
)
)
if isinstance ( ir , TypeConversion ) :
if isinstance ( ir , TypeConversion ) :
self . _export_from_type ( ir . type , contract , exported , list_contract )
self . _export_from_type (
ir . type , contract , exported , list_contract , list_top_level
)
for read in ir . read :
if isinstance ( read , TopLevel ) :
if read not in list_top_level :
list_top_level . append ( read )
if isinstance ( ir , InternalCall ) :
function_called = ir . function
if isinstance ( function_called , FunctionTopLevel ) :
list_top_level . append ( function_called )
if contract not in list_contract :
if contract not in list_contract :
list_contract . append ( contract )
list_contract . append ( contract )
def _export_contract_with_inheritance ( self , contract ) - > Export :
def _export_contract_with_inheritance ( self , contract ) - > Export :
list_contracts : List [ Contract ] = [ ] # will contain contract itself
list_contracts : List [ Contract ] = [ ] # will contain contract itself
self . _export_list_used_contracts ( contract , set ( ) , list_contracts )
list_top_level : List [ TopLevel ] = [ ]
path = Path ( self . _export_path , f " { contract . name } .sol " )
self . _export_list_used_contracts ( contract , set ( ) , list_contracts , list_top_level )
path = Path ( self . _export_path , f " { contract . name } _ { uuid . uuid4 ( ) } .sol " )
content = " "
content = " "
content + = self . _pragmas ( )
content + = self . _pragmas ( )
for listed_top_level in list_top_level :
content + = self . _source_codes_top_level [ listed_top_level ]
content + = " \n "
for listed_contract in list_contracts :
for listed_contract in list_contracts :
content + = self . _source_codes [ listed_contract ]
content + = self . _source_codes [ listed_contract ]
content + = " \n "
content + = " \n "
@ -271,7 +331,7 @@ class Flattening:
def _export_most_derived ( self ) - > List [ Export ] :
def _export_most_derived ( self ) - > List [ Export ] :
ret : List [ Export ] = [ ]
ret : List [ Export ] = [ ]
for contract in self . _slither . contracts_derived :
for contract in self . _compilation_unit . contracts_derived :
ret . append ( self . _export_contract_with_inheritance ( contract ) )
ret . append ( self . _export_contract_with_inheritance ( contract ) )
return ret
return ret
@ -281,8 +341,13 @@ class Flattening:
content = " "
content = " "
content + = self . _pragmas ( )
content + = self . _pragmas ( )
for top_level_content in self . _source_codes_top_level . values ( ) :
content + = " \n "
content + = top_level_content
content + = " \n "
contract_seen = set ( )
contract_seen = set ( )
contract_to_explore = list ( self . _slither . contracts )
contract_to_explore = list ( self . _compilation_unit . contracts )
# We only need the inheritance order here, as solc can compile
# We only need the inheritance order here, as solc can compile
# a contract that use another contract type (ex: state variable) that he has not seen yet
# a contract that use another contract type (ex: state variable) that he has not seen yet
@ -303,9 +368,17 @@ class Flattening:
def _export_with_import ( self ) - > List [ Export ] :
def _export_with_import ( self ) - > List [ Export ] :
exports : List [ Export ] = [ ]
exports : List [ Export ] = [ ]
for contract in self . _slither . contracts :
for contract in self . _compilation_unit . contracts :
list_contracts : List [ Contract ] = [ ] # will contain contract itself
list_contracts : List [ Contract ] = [ ] # will contain contract itself
self . _export_list_used_contracts ( contract , set ( ) , list_contracts )
list_top_level : List [ TopLevel ] = [ ]
self . _export_list_used_contracts ( contract , set ( ) , list_contracts , list_top_level )
if list_top_level :
logger . info (
" Top level objects are not yet supported with the local import flattening "
)
for elem in list_top_level :
logger . info ( f " Missing { elem } for { contract . name } " )
path = Path ( self . _export_path , f " { contract . name } .sol " )
path = Path ( self . _export_path , f " { contract . name } .sol " )
@ -341,12 +414,13 @@ class Flattening:
elif strategy == Strategy . LocalImport :
elif strategy == Strategy . LocalImport :
exports = self . _export_with_import ( )
exports = self . _export_with_import ( )
else :
else :
contracts = self . _slither . get_contract_from_name ( target )
contracts = self . _compilation_unit . get_contract_from_name ( target )
if len ( contracts ) != 1 :
if len ( contracts ) == 0 :
logger . error ( f " { target } not found " )
logger . error ( f " { target } not found " )
return
return
contract = contracts [ 0 ]
exports = [ ]
exports = [ self . _export_contract_with_inheritance ( contract ) ]
for contract in contracts :
exports . append ( self . _export_contract_with_inheritance ( contract ) )
if json :
if json :
export_as_json ( exports , json )
export_as_json ( exports , json )