Merge pull request #1852 from crytic/dev-flattening-improvements

make slither-flat work for top level errors, structs, enums
pull/1912/head
alpharush 2 years ago committed by GitHub
commit eda3e540fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      examples/flat/a.sol
  2. 13
      examples/flat/b.sol
  3. 10
      scripts/ci_test_flat.sh
  4. 6
      slither/core/compilation_unit.py
  5. 4
      slither/core/declarations/solidity_variables.py
  6. 2
      slither/tools/flattening/export/export.py
  7. 45
      slither/tools/flattening/flattening.py

@ -1,3 +1,9 @@
contract A{
pragma solidity 0.8.19;
}
error RevertIt();
contract Example {
function reverts() external pure {
revert RevertIt();
}
}

@ -1,5 +1,16 @@
import "./a.sol";
contract B is A{
pragma solidity 0.8.19;
enum B {
a,
b
}
contract T {
Example e = new Example();
function b() public returns(uint) {
B b = B.a;
return 4;
}
}

@ -1,6 +1,8 @@
#!/usr/bin/env bash
shopt -s extglob
### Test slither-prop
### Test slither-flat
solc-select use 0.8.19 --always-install
cd examples/flat || exit 1
@ -8,5 +10,11 @@ if ! slither-flat b.sol; then
echo "slither-flat failed"
exit 1
fi
SUFFIX="@(sol)"
if ! solc "crytic-export/flattening/"*$SUFFIX; then
echo "solc failed on flattened files"
exit 1
fi
exit 0

@ -13,7 +13,7 @@ from slither.core.declarations import (
Function,
Modifier,
)
from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.structure_top_level import StructureTopLevel
@ -46,7 +46,7 @@ class SlitherCompilationUnit(Context):
self._using_for_top_level: List[UsingForTopLevel] = []
self._pragma_directives: List[Pragma] = []
self._import_directives: List[Import] = []
self._custom_errors: List[CustomError] = []
self._custom_errors: List[CustomErrorTopLevel] = []
self._user_defined_value_types: Dict[str, TypeAliasTopLevel] = {}
self._all_functions: Set[Function] = set()
@ -216,7 +216,7 @@ class SlitherCompilationUnit(Context):
return self._using_for_top_level
@property
def custom_errors(self) -> List[CustomError]:
def custom_errors(self) -> List[CustomErrorTopLevel]:
return self._custom_errors
@property

@ -201,6 +201,10 @@ class SolidityCustomRevert(SolidityFunction):
self._custom_error = custom_error
self._return_type: List[Union[TypeInformation, ElementaryType]] = []
@property
def custom_error(self) -> CustomError:
return self._custom_error
def __eq__(self, other: Any) -> bool:
return (
self.__class__ == other.__class__

@ -15,7 +15,7 @@ ZIP_TYPES_ACCEPTED = {
Export = namedtuple("Export", ["filename", "content"])
logger = logging.getLogger("Slither")
logger = logging.getLogger("Slither-flat")
def save_to_zip(files: List[Export], zip_filename: str, zip_type: str = "lzma"):

@ -11,6 +11,7 @@ from slither.core.declarations import SolidityFunction, EnumContract, StructureC
from slither.core.declarations.contract import Contract
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.top_level import TopLevel
from slither.core.declarations.solidity_variables import SolidityCustomRevert
from slither.core.solidity_types import MappingType, ArrayType
from slither.core.solidity_types.type import Type
from slither.core.solidity_types.user_defined_type import UserDefinedType
@ -23,7 +24,8 @@ from slither.tools.flattening.export.export import (
save_to_disk,
)
logger = logging.getLogger("Slither-flattening")
logger = logging.getLogger("Slither-flat")
logger.setLevel(logging.INFO)
# index: where to start
# patch_type:
@ -75,6 +77,7 @@ class Flattening:
self._get_source_code_top_level(compilation_unit.structures_top_level)
self._get_source_code_top_level(compilation_unit.enums_top_level)
self._get_source_code_top_level(compilation_unit.custom_errors)
self._get_source_code_top_level(compilation_unit.variables_top_level)
self._get_source_code_top_level(compilation_unit.functions_top_level)
@ -249,12 +252,14 @@ class Flattening:
t: Type,
contract: Contract,
exported: Set[str],
list_contract: List[Contract],
list_top_level: List[TopLevel],
list_contract: Set[Contract],
list_top_level: Set[TopLevel],
):
if isinstance(t, UserDefinedType):
t_type = t.type
if isinstance(t_type, (EnumContract, StructureContract)):
if isinstance(t_type, TopLevel):
list_top_level.add(t_type)
elif isinstance(t_type, (EnumContract, StructureContract)):
if t_type.contract != contract and t_type.contract not in exported:
self._export_list_used_contracts(
t_type.contract, exported, list_contract, list_top_level
@ -275,8 +280,8 @@ class Flattening:
self,
contract: Contract,
exported: Set[str],
list_contract: List[Contract],
list_top_level: List[TopLevel],
list_contract: Set[Contract],
list_top_level: Set[TopLevel],
):
# TODO: investigate why this happen
if not isinstance(contract, Contract):
@ -332,19 +337,21 @@ class Flattening:
for read in ir.read:
if isinstance(read, TopLevel):
if read not in list_top_level:
list_top_level.append(read)
if isinstance(ir, InternalCall):
function_called = ir.function
if isinstance(function_called, FunctionTopLevel):
list_top_level.append(function_called)
if contract not in list_contract:
list_contract.append(contract)
list_top_level.add(read)
if isinstance(ir, InternalCall) and isinstance(ir.function, FunctionTopLevel):
list_top_level.add(ir.function)
if (
isinstance(ir, SolidityCall)
and isinstance(ir.function, SolidityCustomRevert)
and isinstance(ir.function.custom_error, TopLevel)
):
list_top_level.add(ir.function.custom_error)
list_contract.add(contract)
def _export_contract_with_inheritance(self, contract) -> Export:
list_contracts: List[Contract] = [] # will contain contract itself
list_top_level: List[TopLevel] = []
list_contracts: Set[Contract] = set() # will contain contract itself
list_top_level: Set[TopLevel] = set()
self._export_list_used_contracts(contract, set(), list_contracts, list_top_level)
path = Path(self._export_path, f"{contract.name}_{uuid.uuid4()}.sol")
@ -401,8 +408,8 @@ class Flattening:
def _export_with_import(self) -> List[Export]:
exports: List[Export] = []
for contract in self._compilation_unit.contracts:
list_contracts: List[Contract] = [] # will contain contract itself
list_top_level: List[TopLevel] = []
list_contracts: Set[Contract] = set() # will contain contract itself
list_top_level: Set[TopLevel] = set()
self._export_list_used_contracts(contract, set(), list_contracts, list_top_level)
if list_top_level:

Loading…
Cancel
Save