add override for all instance of virtual, inherited functions

pull/2376/head
alpharush 8 months ago
parent 759a4fcead
commit e7edac5038
  1. 4
      slither/core/declarations/function.py
  2. 24
      slither/solc_parsing/declarations/function.py
  3. 9
      slither/solc_parsing/slither_compilation_unit_solc.py
  4. 34
      tests/unit/core/test_virtual_overrides.py

@ -473,7 +473,8 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
@property
def overridden_by(self) -> List["FunctionContract"]:
"""
List["FunctionContract"]: List offunctions in child contracts that override this function
List["FunctionContract"]: List of functions in child contracts that override this function
This may include distinct instances of the same function due to inheritance
"""
return self._overridden_by
@ -481,6 +482,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
def overrides(self) -> List["FunctionContract"]:
"""
List["FunctionContract"]: List of functions in parent contracts that this function overrides
This may include distinct instances of the same function due to inheritance
"""
return self._overrides

@ -242,23 +242,21 @@ class FunctionSolc(CallerContextExpression):
self._function.payable = attributes["payable"]
if "baseFunctions" in attributes:
overrides_ids = []
for o_id in attributes["baseFunctions"]:
overrides_ids.append(o_id)
overrides_ids = attributes["baseFunctions"]
if len(overrides_ids) > 0:
found = 0
for c in self.contract_parser.underlying_contract.immediate_inheritance:
for f in c.functions_declared:
if f.id in overrides_ids:
for f_id in overrides_ids:
funcs = self.slither_parser.functions_by_id[f_id]
for f in funcs:
# Do not consider leaf contracts as overrides.
# B is A { function a() override {} } and C is A { function a() override {} } override A.a(), not each other.
if (
f.contract == self._function.contract
or f.contract in self._function.contract.inheritance
):
self._function.overrides.append(f)
f.overridden_by.append(self._function)
found += 1
# Search next parent if already found overridden func in this parent
continue
# Stop searching if we found all the overrides
if len(overrides_ids) == found:
break
# Attaches reference to override specifier e.g. X is referenced by `function a() override(X)`
if "overrides" in attributes and isinstance(attributes["overrides"], dict):
for override in attributes["overrides"].get("overrides", []):
refId = override["referencedDeclaration"]

@ -1,3 +1,4 @@
from collections import defaultdict
import json
import logging
import os
@ -7,7 +8,7 @@ from typing import List, Dict
from slither.analyses.data_dependency.data_dependency import compute_dependency
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations import Contract
from slither.core.declarations import Contract, Function
from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.event_top_level import EventTopLevel
@ -80,6 +81,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression):
self._compilation_unit: SlitherCompilationUnit = compilation_unit
self._contracts_by_id: Dict[int, Contract] = {}
self._functions_by_id: Dict[int, List[Function]] = defaultdict(list)
self._parsed = False
self._analyzed = False
self._is_compact_ast = False
@ -104,6 +106,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression):
def add_function_or_modifier_parser(self, f: FunctionSolc) -> None:
self._all_functions_and_modifier_parser.append(f)
self._functions_by_id[f.underlying_function.id].append(f.underlying_function)
@property
def underlying_contract_to_parser(self) -> Dict[Contract, ContractSolc]:
@ -117,6 +120,10 @@ class SlitherCompilationUnitSolc(CallerContextExpression):
def contracts_by_id(self) -> Dict[int, Contract]:
return self._contracts_by_id
@property
def functions_by_id(self) -> Dict[int, List[Function]]:
return self._functions_by_id
###################################################################################
###################################################################################
# region AST

@ -15,8 +15,8 @@ def test_overrides(solc_binary_path) -> None:
x = test.get_functions_overridden_by(test_virtual_func)
assert len(x) == 0
x = test_virtual_func.overridden_by
assert len(x) == 3, [i.canonical_name for i in x]
assert set([i.canonical_name for i in x]) == set(
assert len(x) == 5
assert set(i.canonical_name for i in x) == set(
["A.myVirtualFunction()", "C.myVirtualFunction()", "X.myVirtualFunction()"]
)
@ -25,18 +25,27 @@ def test_overrides(solc_binary_path) -> None:
assert a_virtual_func.is_virtual
assert a_virtual_func.is_override
x = a.get_functions_overridden_by(a_virtual_func)
assert len(x) == 1
assert x[0].canonical_name == "Test.myVirtualFunction()"
assert len(x) == 2
assert set(i.canonical_name for i in x) == set(["Test.myVirtualFunction()"])
b = slither.get_contract_from_name("B")[0]
b_virtual_func = b.get_function_from_full_name("myVirtualFunction()")
assert not b_virtual_func.is_virtual
assert b_virtual_func.is_override
x = b.get_functions_overridden_by(b_virtual_func)
assert len(x) == 1
assert x[0].canonical_name == "A.myVirtualFunction()"
assert len(x) == 2
assert set(i.canonical_name for i in x) == set(["A.myVirtualFunction()"])
assert len(b_virtual_func.overridden_by) == 0
c = slither.get_contract_from_name("C")[0]
c_virtual_func = c.get_function_from_full_name("myVirtualFunction()")
assert not c_virtual_func.is_virtual
assert c_virtual_func.is_override
x = c.get_functions_overridden_by(c_virtual_func)
assert len(x) == 2
# C should not override B as they are distinct leaves in the inheritance tree
assert set(i.canonical_name for i in x) == set(["Test.myVirtualFunction()"])
y = slither.get_contract_from_name("Y")[0]
y_virtual_func = y.get_function_from_full_name("myVirtualFunction()")
assert y_virtual_func.is_virtual
@ -50,8 +59,8 @@ def test_overrides(solc_binary_path) -> None:
assert z_virtual_func.is_virtual
assert z_virtual_func.is_override
x = z.get_functions_overridden_by(z_virtual_func)
assert len(x) == 2
assert set([i.canonical_name for i in x]) == set(
assert len(x) == 4
assert set(i.canonical_name for i in x) == set(
["Y.myVirtualFunction()", "X.myVirtualFunction()"]
)
@ -59,14 +68,19 @@ def test_overrides(solc_binary_path) -> None:
k_virtual_func = k.get_function_from_full_name("a()")
assert not k_virtual_func.is_virtual
assert k_virtual_func.is_override
assert len(k_virtual_func.overrides) == 1
assert len(k_virtual_func.overrides) == 3
x = k_virtual_func.overrides
assert set(i.canonical_name for i in x) == set(["I.a()"])
i = slither.get_contract_from_name("I")[0]
i_virtual_func = i.get_function_from_full_name("a()")
assert i_virtual_func.is_virtual
assert not i_virtual_func.is_override
assert len(i_virtual_func.overrides) == 0
assert len(i_virtual_func.overridden_by) == 1
x = i_virtual_func.overridden_by
assert len(x) == 1
assert x[0].canonical_name == "K.a()"
def test_virtual_override_references_and_implementations(solc_binary_path) -> None:
solc_path = solc_binary_path("0.8.15")

Loading…
Cancel
Save