diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 93d4039db..0a6f5ae2a 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -473,7 +473,8 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu @property def overridden_by(self) -> List["FunctionContract"]: """ - List["FunctionContract"]: List offunctions in child contracts that override this function + List["FunctionContract"]: List of functions in child contracts that override this function + This may include distinct instances of the same function due to inheritance """ return self._overridden_by @@ -481,6 +482,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu def overrides(self) -> List["FunctionContract"]: """ List["FunctionContract"]: List of functions in parent contracts that this function overrides + This may include distinct instances of the same function due to inheritance """ return self._overrides diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 7a5852324..4ff77d008 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -242,23 +242,21 @@ class FunctionSolc(CallerContextExpression): self._function.payable = attributes["payable"] if "baseFunctions" in attributes: - overrides_ids = [] - for o_id in attributes["baseFunctions"]: - overrides_ids.append(o_id) + overrides_ids = attributes["baseFunctions"] if len(overrides_ids) > 0: - found = 0 - for c in self.contract_parser.underlying_contract.immediate_inheritance: - for f in c.functions_declared: - if f.id in overrides_ids: + for f_id in overrides_ids: + funcs = self.slither_parser.functions_by_id[f_id] + for f in funcs: + # Do not consider leaf contracts as overrides. + # B is A { function a() override {} } and C is A { function a() override {} } override A.a(), not each other. + if ( + f.contract == self._function.contract + or f.contract in self._function.contract.inheritance + ): self._function.overrides.append(f) f.overridden_by.append(self._function) - found += 1 - # Search next parent if already found overridden func in this parent - continue - # Stop searching if we found all the overrides - if len(overrides_ids) == found: - break + # Attaches reference to override specifier e.g. X is referenced by `function a() override(X)` if "overrides" in attributes and isinstance(attributes["overrides"], dict): for override in attributes["overrides"].get("overrides", []): refId = override["referencedDeclaration"] diff --git a/slither/solc_parsing/slither_compilation_unit_solc.py b/slither/solc_parsing/slither_compilation_unit_solc.py index 212ebca2e..fcf0723c4 100644 --- a/slither/solc_parsing/slither_compilation_unit_solc.py +++ b/slither/solc_parsing/slither_compilation_unit_solc.py @@ -1,3 +1,4 @@ +from collections import defaultdict import json import logging import os @@ -7,7 +8,7 @@ from typing import List, Dict from slither.analyses.data_dependency.data_dependency import compute_dependency from slither.core.compilation_unit import SlitherCompilationUnit -from slither.core.declarations import Contract +from slither.core.declarations import Contract, Function from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel from slither.core.declarations.enum_top_level import EnumTopLevel from slither.core.declarations.event_top_level import EventTopLevel @@ -80,6 +81,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression): self._compilation_unit: SlitherCompilationUnit = compilation_unit self._contracts_by_id: Dict[int, Contract] = {} + self._functions_by_id: Dict[int, List[Function]] = defaultdict(list) self._parsed = False self._analyzed = False self._is_compact_ast = False @@ -104,6 +106,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression): def add_function_or_modifier_parser(self, f: FunctionSolc) -> None: self._all_functions_and_modifier_parser.append(f) + self._functions_by_id[f.underlying_function.id].append(f.underlying_function) @property def underlying_contract_to_parser(self) -> Dict[Contract, ContractSolc]: @@ -117,6 +120,10 @@ class SlitherCompilationUnitSolc(CallerContextExpression): def contracts_by_id(self) -> Dict[int, Contract]: return self._contracts_by_id + @property + def functions_by_id(self) -> Dict[int, List[Function]]: + return self._functions_by_id + ################################################################################### ################################################################################### # region AST diff --git a/tests/unit/core/test_virtual_overrides.py b/tests/unit/core/test_virtual_overrides.py index b58e45b62..f2c6b2e3d 100644 --- a/tests/unit/core/test_virtual_overrides.py +++ b/tests/unit/core/test_virtual_overrides.py @@ -15,8 +15,8 @@ def test_overrides(solc_binary_path) -> None: x = test.get_functions_overridden_by(test_virtual_func) assert len(x) == 0 x = test_virtual_func.overridden_by - assert len(x) == 3, [i.canonical_name for i in x] - assert set([i.canonical_name for i in x]) == set( + assert len(x) == 5 + assert set(i.canonical_name for i in x) == set( ["A.myVirtualFunction()", "C.myVirtualFunction()", "X.myVirtualFunction()"] ) @@ -25,18 +25,27 @@ def test_overrides(solc_binary_path) -> None: assert a_virtual_func.is_virtual assert a_virtual_func.is_override x = a.get_functions_overridden_by(a_virtual_func) - assert len(x) == 1 - assert x[0].canonical_name == "Test.myVirtualFunction()" + assert len(x) == 2 + assert set(i.canonical_name for i in x) == set(["Test.myVirtualFunction()"]) b = slither.get_contract_from_name("B")[0] b_virtual_func = b.get_function_from_full_name("myVirtualFunction()") assert not b_virtual_func.is_virtual assert b_virtual_func.is_override x = b.get_functions_overridden_by(b_virtual_func) - assert len(x) == 1 - assert x[0].canonical_name == "A.myVirtualFunction()" + assert len(x) == 2 + assert set(i.canonical_name for i in x) == set(["A.myVirtualFunction()"]) assert len(b_virtual_func.overridden_by) == 0 + c = slither.get_contract_from_name("C")[0] + c_virtual_func = c.get_function_from_full_name("myVirtualFunction()") + assert not c_virtual_func.is_virtual + assert c_virtual_func.is_override + x = c.get_functions_overridden_by(c_virtual_func) + assert len(x) == 2 + # C should not override B as they are distinct leaves in the inheritance tree + assert set(i.canonical_name for i in x) == set(["Test.myVirtualFunction()"]) + y = slither.get_contract_from_name("Y")[0] y_virtual_func = y.get_function_from_full_name("myVirtualFunction()") assert y_virtual_func.is_virtual @@ -50,8 +59,8 @@ def test_overrides(solc_binary_path) -> None: assert z_virtual_func.is_virtual assert z_virtual_func.is_override x = z.get_functions_overridden_by(z_virtual_func) - assert len(x) == 2 - assert set([i.canonical_name for i in x]) == set( + assert len(x) == 4 + assert set(i.canonical_name for i in x) == set( ["Y.myVirtualFunction()", "X.myVirtualFunction()"] ) @@ -59,14 +68,19 @@ def test_overrides(solc_binary_path) -> None: k_virtual_func = k.get_function_from_full_name("a()") assert not k_virtual_func.is_virtual assert k_virtual_func.is_override - assert len(k_virtual_func.overrides) == 1 + assert len(k_virtual_func.overrides) == 3 + x = k_virtual_func.overrides + assert set(i.canonical_name for i in x) == set(["I.a()"]) i = slither.get_contract_from_name("I")[0] i_virtual_func = i.get_function_from_full_name("a()") assert i_virtual_func.is_virtual assert not i_virtual_func.is_override assert len(i_virtual_func.overrides) == 0 - assert len(i_virtual_func.overridden_by) == 1 + x = i_virtual_func.overridden_by + assert len(x) == 1 + assert x[0].canonical_name == "K.a()" + def test_virtual_override_references_and_implementations(solc_binary_path) -> None: solc_path = solc_binary_path("0.8.15")