reference API handles abstract contract, interface, and virtual func

pull/2376/head
alpharush 8 months ago
parent 20ed490e2d
commit 0933dfa67c
  1. 40
      slither/core/declarations/function.py
  2. 58
      slither/core/slither_core.py
  3. 12
      slither/printers/summary/declaration.py
  4. 24
      slither/solc_parsing/declarations/function.py
  5. 4
      slither/solc_parsing/slither_compilation_unit_solc.py
  6. 58
      tests/unit/core/test_data/virtual_overrides.sol
  7. 12
      tests/unit/core/test_source_mapping.py
  8. 124
      tests/unit/core/test_virtual_overrides.py

@ -37,7 +37,7 @@ if TYPE_CHECKING:
HighLevelCallType,
LibraryCallType,
)
from slither.core.declarations import Contract
from slither.core.declarations import Contract, FunctionContract
from slither.core.cfg.node import Node, NodeType
from slither.core.variables.variable import Variable
from slither.slithir.variables.variable import SlithIRVariable
@ -46,7 +46,6 @@ if TYPE_CHECKING:
from slither.slithir.operations import Operation
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.scope.scope import FileScope
from slither.slithir.variables.state_variable import StateIRVariable
LOGGER = logging.getLogger("Function")
ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"])
@ -127,7 +126,8 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
self._payable: bool = False
self._visibility: Optional[str] = None
self._virtual: bool = False
self._overrides: List["Contract"] = []
self._overrides: List["FunctionContract"] = []
self._overridden_by: List["FunctionContract"] = []
self._is_implemented: Optional[bool] = None
self._is_empty: Optional[bool] = None
@ -451,19 +451,30 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
###################################################################################
@property
def virtual(self) -> bool:
def is_implemented(self) -> bool:
"""
bool: True if the function is implemented
"""
return self._is_implemented
@is_implemented.setter
def is_implemented(self, is_implemented: bool):
self._is_implemented = is_implemented
@property
def is_virtual(self) -> bool:
"""
Note for Solidity < 0.6.0 it will always be false
bool: True if the function is virtual
"""
return self._virtual
@virtual.setter
def virtual(self, v: bool):
@is_virtual.setter
def is_virtual(self, v: bool):
self._virtual = v
@property
def is_overriden(self) -> bool:
def is_override(self) -> bool:
"""
Note for Solidity < 0.6.0 it will always be false
bool: True if the function overrides a base function
@ -471,15 +482,18 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
return len(self._overrides) > 0
@property
def overrides(self) -> List["Contract"]:
def overridden_by(self) -> List["FunctionContract"]:
"""
List["Contract"]: List of which parent contracts' functions definitions are overridden
List["FunctionContract"]: List offunctions in child contracts that override this function
"""
return self._overrides
return self._overridden_by
@overrides.setter
def overrides(self, o: List["Contract"]):
self._overrides = o
@property
def overrides(self) -> List["FunctionContract"]:
"""
List["FunctionContract"]: List of functions in parent contracts that this function overrides
"""
return self._overrides
# endregion
###################################################################################

@ -202,9 +202,33 @@ class SlitherCore(Context):
return self._offset_to_objects[filename][offset]
def _compute_offsets_from_thing(self, thing: SourceMapping):
definition = get_definition(thing, self.crytic_compile)
references = get_references(thing)
implementation = get_implementation(thing)
implementations = set()
# Abstract contracts and interfaces are implemented by their children
if isinstance(thing, Contract):
is_interface = thing.is_interface
is_implicitly_abstract = not thing.is_fully_implemented
is_explicitly_abstract = thing.is_abstract
if is_interface or is_implicitly_abstract or is_explicitly_abstract:
for contract in self.contracts:
if thing in contract.immediate_inheritance:
implementations.add(contract.source_mapping)
# Parent's virtual functions may be overridden by children
elif isinstance(thing, FunctionContract):
for over in thing.overridden_by:
implementations.add(over.source_mapping)
# Only show implemented virtual functions
if not thing.is_virtual or thing.is_implemented:
implementations.add(get_implementation(thing))
else:
implementations.add(get_implementation(thing))
for offset in range(definition.start, definition.end + 1):
@ -216,29 +240,41 @@ class SlitherCore(Context):
)
or (isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract))
):
self._offset_to_objects[definition.filename][offset].add(thing)
self._offset_to_definitions[definition.filename][offset].add(definition)
self._offset_to_implementations[definition.filename][offset].add(implementation)
self._offset_to_implementations[definition.filename][offset].update(implementations)
self._offset_to_references[definition.filename][offset] |= set(references)
for ref in references:
for offset in range(ref.start, ref.end + 1):
is_declared_function = (
isinstance(thing, FunctionContract)
and thing.contract_declarer == thing.contract
)
if (
isinstance(thing, TopLevel)
or (
isinstance(thing, FunctionContract)
and thing.contract_declarer == thing.contract
)
or is_declared_function
or (
isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract)
)
):
self._offset_to_objects[definition.filename][offset].add(thing)
self._offset_to_definitions[ref.filename][offset].add(definition)
self._offset_to_implementations[ref.filename][offset].add(implementation)
if is_declared_function:
# Only show the nearest lexical definition for declared contract-level functions
if (
offset > thing.contract.source_mapping.start
and offset < thing.contract.source_mapping.end
):
self._offset_to_definitions[ref.filename][offset].add(definition)
else:
self._offset_to_definitions[ref.filename][offset].add(definition)
self._offset_to_implementations[ref.filename][offset].update(implementations)
self._offset_to_references[ref.filename][offset] |= set(references)
def _compute_offsets_to_ref_impl_decl(self): # pylint: disable=too-many-branches
@ -251,11 +287,11 @@ class SlitherCore(Context):
for contract in compilation_unit.contracts:
self._compute_offsets_from_thing(contract)
for function in contract.functions:
for function in contract.functions_declared:
self._compute_offsets_from_thing(function)
for variable in function.local_variables:
self._compute_offsets_from_thing(variable)
for modifier in contract.modifiers:
for modifier in contract.modifiers_declared:
self._compute_offsets_from_thing(modifier)
for variable in modifier.local_variables:
self._compute_offsets_from_thing(variable)

@ -21,18 +21,20 @@ class Declaration(AbstractPrinter):
txt += "\n# Contracts\n"
for contract in compilation_unit.contracts:
txt += f"# {contract.name}\n"
txt += f"\t- Declaration: {get_definition(contract, compilation_unit.core.crytic_compile).to_detailed_str()}\n"
txt += f"\t- Implementation: {get_implementation(contract).to_detailed_str()}\n"
contract_def = get_definition(contract, compilation_unit.core.crytic_compile)
txt += f"\t- Declaration: {contract_def.to_detailed_str()}\n"
txt += f"\t- Implementation(s): {[x.to_detailed_str() for x in list(self.slither.offset_to_implementations(contract.source_mapping.filename.absolute, contract_def.start))]}\n"
txt += (
f"\t- References: {[x.to_detailed_str() for x in get_references(contract)]}\n"
)
txt += "\n\t## Function\n"
for func in contract.functions:
for func in contract.functions_declared:
txt += f"\t\t- {func.canonical_name}\n"
txt += f"\t\t\t- Declaration: {get_definition(func, compilation_unit.core.crytic_compile).to_detailed_str()}\n"
txt += f"\t\t\t- Implementation: {get_implementation(func).to_detailed_str()}\n"
function_def = get_definition(func, compilation_unit.core.crytic_compile)
txt += f"\t\t\t- Declaration: {function_def.to_detailed_str()}\n"
txt += f"\t\t\t- Implementation(s): {[x.to_detailed_str() for x in list(self.slither.offset_to_implementations(func.source_mapping.filename.absolute, function_def.start))]}\n"
txt += f"\t\t\t- References: {[x.to_detailed_str() for x in get_references(func)]}\n"
txt += "\n\t## State variables\n"

@ -241,21 +241,35 @@ class FunctionSolc(CallerContextExpression):
if "payable" in attributes:
self._function.payable = attributes["payable"]
if "baseFunctions" in attributes:
overrides_ids = []
for o_id in attributes["baseFunctions"]:
overrides_ids.append(o_id)
if len(overrides_ids) > 0:
found = 0
for c in self.contract_parser.underlying_contract.immediate_inheritance:
for f in c.functions_declared:
if f.id in overrides_ids:
self._function.overrides.append(f)
f.overridden_by.append(self._function)
found += 1
# Search next parent if already found overridden func in this parent
continue
# Stop searching if we found all the overrides
if len(overrides_ids) == found:
break
if "overrides" in attributes and isinstance(attributes["overrides"], dict):
overrides = []
for override in attributes["overrides"].get("overrides", []):
refId = override["referencedDeclaration"]
overridden_contract = self.slither_parser._contracts_by_id.get(refId, None)
if overridden_contract:
overridden_contract.add_reference_from_raw_source(
override["src"], self.compilation_unit
)
overrides.append(overridden_contract)
self._function.overrides = overrides
if "virtual" in attributes:
self._function.virtual = attributes["virtual"]
self._function.is_virtual = attributes["virtual"]
def analyze_params(self) -> None:
# Can be re-analyzed due to inheritance

@ -113,6 +113,10 @@ class SlitherCompilationUnitSolc(CallerContextExpression):
def slither_parser(self) -> "SlitherCompilationUnitSolc":
return self
@property
def contracts_by_id(self) -> Dict[int, Contract]:
return self._contracts_by_id
###################################################################################
###################################################################################
# region AST

@ -0,0 +1,58 @@
contract Test {
function myVirtualFunction() virtual external {
}
}
contract A is Test {
function myVirtualFunction() virtual override external {
}
}
contract B is A {
function myVirtualFunction() override external {
}
}
contract C is Test {
function myVirtualFunction() override external {
}
}
contract X is Test {
function myVirtualFunction() virtual override external {
}
}
contract Y {
function myVirtualFunction() virtual external {
}
}
contract Z is Y, X{
function myVirtualFunction() virtual override(Y, X) external {
}
}
abstract contract Name {
constructor() {
}
}
contract Name2 is Name {
constructor() {
}
}
abstract contract Test2 {
function f() virtual public;
}
contract A2 is Test2 {
function f() virtual override public {
}
}

@ -33,8 +33,12 @@ def test_source_mapping_inheritance(solc_binary_path, solc_version):
assert {(x.start, x.end) for x in slither.offset_to_definitions(file, 27)} == {(26, 28)}
# Only one reference for A.f(), in A.test()
assert {(x.start, x.end) for x in slither.offset_to_references(file, 27)} == {(92, 93)}
# Only one implementation for A.f(), in A.test()
assert {(x.start, x.end) for x in slither.offset_to_implementations(file, 27)} == {(17, 53)}
# Three overridden implementation of A.f(), in A.test()
assert {(x.start, x.end) for x in slither.offset_to_implementations(file, 27)} == {
(17, 53),
(129, 166),
(193, 230),
}
# Check if C.f() is at the offset 203
functions = slither.offset_to_objects(file, 203)
@ -62,11 +66,9 @@ def test_source_mapping_inheritance(solc_binary_path, solc_version):
assert isinstance(function, Function)
assert function.canonical_name in ["A.f()", "B.f()", "C.f()"]
# There are three definitions possible (in A, B or C)
# There is one definition in the lexical scope of A
assert {(x.start, x.end) for x in slither.offset_to_definitions(file, 93)} == {
(26, 28),
(202, 204),
(138, 140),
}
# There are two references possible (in A.test() or C.test2() )

@ -0,0 +1,124 @@
from pathlib import Path
from slither import Slither
TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data"
def test_overrides(solc_binary_path) -> None:
solc_path = solc_binary_path("0.8.15")
slither = Slither(Path(TEST_DATA_DIR, "virtual_overrides.sol").as_posix(), solc=solc_path)
test = slither.get_contract_from_name("Test")[0]
test_virtual_func = test.get_function_from_full_name("myVirtualFunction()")
assert test_virtual_func.is_virtual
assert not test_virtual_func.is_override
x = test.get_functions_overridden_by(test_virtual_func)
assert len(x) == 0
x = test_virtual_func.overridden_by
assert len(x) == 3, [i.canonical_name for i in x]
assert set([i.canonical_name for i in x]) == set(
["A.myVirtualFunction()", "C.myVirtualFunction()", "X.myVirtualFunction()"]
)
a = slither.get_contract_from_name("A")[0]
a_virtual_func = a.get_function_from_full_name("myVirtualFunction()")
assert a_virtual_func.is_virtual
assert a_virtual_func.is_override
x = a.get_functions_overridden_by(a_virtual_func)
assert len(x) == 1
assert x[0].canonical_name == "Test.myVirtualFunction()"
b = slither.get_contract_from_name("B")[0]
b_virtual_func = b.get_function_from_full_name("myVirtualFunction()")
assert not b_virtual_func.is_virtual
assert b_virtual_func.is_override
x = b.get_functions_overridden_by(b_virtual_func)
assert len(x) == 1
assert x[0].canonical_name == "A.myVirtualFunction()"
assert len(b_virtual_func.overridden_by) == 0
y = slither.get_contract_from_name("Y")[0]
y_virtual_func = y.get_function_from_full_name("myVirtualFunction()")
assert y_virtual_func.is_virtual
assert not y_virtual_func.is_override
x = y_virtual_func.overridden_by
assert len(x) == 1
assert x[0].canonical_name == "Z.myVirtualFunction()"
z = slither.get_contract_from_name("Z")[0]
z_virtual_func = z.get_function_from_full_name("myVirtualFunction()")
assert z_virtual_func.is_virtual
assert z_virtual_func.is_override
x = z.get_functions_overridden_by(z_virtual_func)
assert len(x) == 2
assert set([i.canonical_name for i in x]) == set(
["Y.myVirtualFunction()", "X.myVirtualFunction()"]
)
def test_virtual_override_references_and_implementations(solc_binary_path) -> None:
solc_path = solc_binary_path("0.8.15")
file = Path(TEST_DATA_DIR, "virtual_overrides.sol").as_posix()
slither = Slither(file, solc=solc_path)
funcs = slither.offset_to_objects(file, 29)
assert len(funcs) == 1
func = funcs.pop()
assert func.canonical_name == "Test.myVirtualFunction()"
assert {(x.start, x.end) for x in slither.offset_to_implementations(file, 29)} == {
(20, 73),
(102, 164),
(274, 328),
(357, 419),
}
funcs = slither.offset_to_objects(file, 111)
assert len(funcs) == 1
func = funcs.pop()
assert func.canonical_name == "A.myVirtualFunction()"
# A.myVirtualFunction() is implemented in A and also overridden in B
assert {(x.start, x.end) for x in slither.offset_to_implementations(file, 111)} == {
(102, 164),
(190, 244),
}
# X is inherited by Z and Z.myVirtualFunction() overrides X.myVirtualFunction()
assert {(x.start, x.end) for x in slither.offset_to_references(file, 341)} == {
(514, 515),
(570, 571),
}
# The reference to X in inheritance specifier is the definition of Z
assert {(x.start, x.end) for x in slither.offset_to_definitions(file, 514)} == {(341, 343)}
# The reference to X in the function override specifier is the definition of Z
assert {(x.start, x.end) for x in slither.offset_to_definitions(file, 570)} == {(341, 343)}
# Y is inherited by Z and Z.myVirtualFunction() overrides Y.myVirtualFunction()
assert {(x.start, x.end) for x in slither.offset_to_references(file, 432)} == {
(511, 512),
(567, 568),
}
# The reference to Y in inheritance specifier is the definition of Z
assert {(x.start, x.end) for x in slither.offset_to_definitions(file, 511)} == {(432, 434)}
# The reference to Y in the function override specifier is the definition of Z
assert {(x.start, x.end) for x in slither.offset_to_definitions(file, 567)} == {(432, 434)}
# Name is abstract and has no implementation. It is inherited and implemented by Name2
assert {(x.start, x.end) for x in slither.offset_to_implementations(file, 612)} == {(657, 718)}
def test_virtual_is_implemented(solc_binary_path):
solc_path = solc_binary_path("0.8.15")
file = Path(TEST_DATA_DIR, "virtual_overrides.sol").as_posix()
slither = Slither(file, solc=solc_path)
test2 = slither.get_contract_from_name("Test2")[0]
f = test2.get_function_from_full_name("f()")
assert f.is_virtual
assert not f.is_implemented
a2 = slither.get_contract_from_name("A2")[0]
f = a2.get_function_from_full_name("f()")
assert f.is_virtual
assert f.is_implemented
# Test.2f() is not implemented, but A2 inherits from Test2 and overrides f()
assert {(x.start, x.end) for x in slither.offset_to_implementations(file, 759)} == {(809, 853)}
Loading…
Cancel
Save