add override for all instance of virtual, inherited functions

pull/2376/head
alpharush 8 months ago
parent 759a4fcead
commit e7edac5038
  1. 2
      slither/core/declarations/function.py
  2. 24
      slither/solc_parsing/declarations/function.py
  3. 9
      slither/solc_parsing/slither_compilation_unit_solc.py
  4. 34
      tests/unit/core/test_virtual_overrides.py

@ -474,6 +474,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
def overridden_by(self) -> List["FunctionContract"]: def overridden_by(self) -> List["FunctionContract"]:
""" """
List["FunctionContract"]: List of functions in child contracts that override this function List["FunctionContract"]: List of functions in child contracts that override this function
This may include distinct instances of the same function due to inheritance
""" """
return self._overridden_by return self._overridden_by
@ -481,6 +482,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
def overrides(self) -> List["FunctionContract"]: def overrides(self) -> List["FunctionContract"]:
""" """
List["FunctionContract"]: List of functions in parent contracts that this function overrides List["FunctionContract"]: List of functions in parent contracts that this function overrides
This may include distinct instances of the same function due to inheritance
""" """
return self._overrides return self._overrides

@ -242,23 +242,21 @@ class FunctionSolc(CallerContextExpression):
self._function.payable = attributes["payable"] self._function.payable = attributes["payable"]
if "baseFunctions" in attributes: if "baseFunctions" in attributes:
overrides_ids = [] overrides_ids = attributes["baseFunctions"]
for o_id in attributes["baseFunctions"]:
overrides_ids.append(o_id)
if len(overrides_ids) > 0: if len(overrides_ids) > 0:
found = 0 for f_id in overrides_ids:
for c in self.contract_parser.underlying_contract.immediate_inheritance: funcs = self.slither_parser.functions_by_id[f_id]
for f in c.functions_declared: for f in funcs:
if f.id in overrides_ids: # Do not consider leaf contracts as overrides.
# B is A { function a() override {} } and C is A { function a() override {} } override A.a(), not each other.
if (
f.contract == self._function.contract
or f.contract in self._function.contract.inheritance
):
self._function.overrides.append(f) self._function.overrides.append(f)
f.overridden_by.append(self._function) f.overridden_by.append(self._function)
found += 1
# Search next parent if already found overridden func in this parent
continue
# Stop searching if we found all the overrides
if len(overrides_ids) == found:
break
# Attaches reference to override specifier e.g. X is referenced by `function a() override(X)`
if "overrides" in attributes and isinstance(attributes["overrides"], dict): if "overrides" in attributes and isinstance(attributes["overrides"], dict):
for override in attributes["overrides"].get("overrides", []): for override in attributes["overrides"].get("overrides", []):
refId = override["referencedDeclaration"] refId = override["referencedDeclaration"]

@ -1,3 +1,4 @@
from collections import defaultdict
import json import json
import logging import logging
import os import os
@ -7,7 +8,7 @@ from typing import List, Dict
from slither.analyses.data_dependency.data_dependency import compute_dependency from slither.analyses.data_dependency.data_dependency import compute_dependency
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations import Contract from slither.core.declarations import Contract, Function
from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.declarations.enum_top_level import EnumTopLevel from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.event_top_level import EventTopLevel from slither.core.declarations.event_top_level import EventTopLevel
@ -80,6 +81,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression):
self._compilation_unit: SlitherCompilationUnit = compilation_unit self._compilation_unit: SlitherCompilationUnit = compilation_unit
self._contracts_by_id: Dict[int, Contract] = {} self._contracts_by_id: Dict[int, Contract] = {}
self._functions_by_id: Dict[int, List[Function]] = defaultdict(list)
self._parsed = False self._parsed = False
self._analyzed = False self._analyzed = False
self._is_compact_ast = False self._is_compact_ast = False
@ -104,6 +106,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression):
def add_function_or_modifier_parser(self, f: FunctionSolc) -> None: def add_function_or_modifier_parser(self, f: FunctionSolc) -> None:
self._all_functions_and_modifier_parser.append(f) self._all_functions_and_modifier_parser.append(f)
self._functions_by_id[f.underlying_function.id].append(f.underlying_function)
@property @property
def underlying_contract_to_parser(self) -> Dict[Contract, ContractSolc]: def underlying_contract_to_parser(self) -> Dict[Contract, ContractSolc]:
@ -117,6 +120,10 @@ class SlitherCompilationUnitSolc(CallerContextExpression):
def contracts_by_id(self) -> Dict[int, Contract]: def contracts_by_id(self) -> Dict[int, Contract]:
return self._contracts_by_id return self._contracts_by_id
@property
def functions_by_id(self) -> Dict[int, List[Function]]:
return self._functions_by_id
################################################################################### ###################################################################################
################################################################################### ###################################################################################
# region AST # region AST

@ -15,8 +15,8 @@ def test_overrides(solc_binary_path) -> None:
x = test.get_functions_overridden_by(test_virtual_func) x = test.get_functions_overridden_by(test_virtual_func)
assert len(x) == 0 assert len(x) == 0
x = test_virtual_func.overridden_by x = test_virtual_func.overridden_by
assert len(x) == 3, [i.canonical_name for i in x] assert len(x) == 5
assert set([i.canonical_name for i in x]) == set( assert set(i.canonical_name for i in x) == set(
["A.myVirtualFunction()", "C.myVirtualFunction()", "X.myVirtualFunction()"] ["A.myVirtualFunction()", "C.myVirtualFunction()", "X.myVirtualFunction()"]
) )
@ -25,18 +25,27 @@ def test_overrides(solc_binary_path) -> None:
assert a_virtual_func.is_virtual assert a_virtual_func.is_virtual
assert a_virtual_func.is_override assert a_virtual_func.is_override
x = a.get_functions_overridden_by(a_virtual_func) x = a.get_functions_overridden_by(a_virtual_func)
assert len(x) == 1 assert len(x) == 2
assert x[0].canonical_name == "Test.myVirtualFunction()" assert set(i.canonical_name for i in x) == set(["Test.myVirtualFunction()"])
b = slither.get_contract_from_name("B")[0] b = slither.get_contract_from_name("B")[0]
b_virtual_func = b.get_function_from_full_name("myVirtualFunction()") b_virtual_func = b.get_function_from_full_name("myVirtualFunction()")
assert not b_virtual_func.is_virtual assert not b_virtual_func.is_virtual
assert b_virtual_func.is_override assert b_virtual_func.is_override
x = b.get_functions_overridden_by(b_virtual_func) x = b.get_functions_overridden_by(b_virtual_func)
assert len(x) == 1 assert len(x) == 2
assert x[0].canonical_name == "A.myVirtualFunction()" assert set(i.canonical_name for i in x) == set(["A.myVirtualFunction()"])
assert len(b_virtual_func.overridden_by) == 0 assert len(b_virtual_func.overridden_by) == 0
c = slither.get_contract_from_name("C")[0]
c_virtual_func = c.get_function_from_full_name("myVirtualFunction()")
assert not c_virtual_func.is_virtual
assert c_virtual_func.is_override
x = c.get_functions_overridden_by(c_virtual_func)
assert len(x) == 2
# C should not override B as they are distinct leaves in the inheritance tree
assert set(i.canonical_name for i in x) == set(["Test.myVirtualFunction()"])
y = slither.get_contract_from_name("Y")[0] y = slither.get_contract_from_name("Y")[0]
y_virtual_func = y.get_function_from_full_name("myVirtualFunction()") y_virtual_func = y.get_function_from_full_name("myVirtualFunction()")
assert y_virtual_func.is_virtual assert y_virtual_func.is_virtual
@ -50,8 +59,8 @@ def test_overrides(solc_binary_path) -> None:
assert z_virtual_func.is_virtual assert z_virtual_func.is_virtual
assert z_virtual_func.is_override assert z_virtual_func.is_override
x = z.get_functions_overridden_by(z_virtual_func) x = z.get_functions_overridden_by(z_virtual_func)
assert len(x) == 2 assert len(x) == 4
assert set([i.canonical_name for i in x]) == set( assert set(i.canonical_name for i in x) == set(
["Y.myVirtualFunction()", "X.myVirtualFunction()"] ["Y.myVirtualFunction()", "X.myVirtualFunction()"]
) )
@ -59,14 +68,19 @@ def test_overrides(solc_binary_path) -> None:
k_virtual_func = k.get_function_from_full_name("a()") k_virtual_func = k.get_function_from_full_name("a()")
assert not k_virtual_func.is_virtual assert not k_virtual_func.is_virtual
assert k_virtual_func.is_override assert k_virtual_func.is_override
assert len(k_virtual_func.overrides) == 1 assert len(k_virtual_func.overrides) == 3
x = k_virtual_func.overrides
assert set(i.canonical_name for i in x) == set(["I.a()"])
i = slither.get_contract_from_name("I")[0] i = slither.get_contract_from_name("I")[0]
i_virtual_func = i.get_function_from_full_name("a()") i_virtual_func = i.get_function_from_full_name("a()")
assert i_virtual_func.is_virtual assert i_virtual_func.is_virtual
assert not i_virtual_func.is_override assert not i_virtual_func.is_override
assert len(i_virtual_func.overrides) == 0 assert len(i_virtual_func.overrides) == 0
assert len(i_virtual_func.overridden_by) == 1 x = i_virtual_func.overridden_by
assert len(x) == 1
assert x[0].canonical_name == "K.a()"
def test_virtual_override_references_and_implementations(solc_binary_path) -> None: def test_virtual_override_references_and_implementations(solc_binary_path) -> None:
solc_path = solc_binary_path("0.8.15") solc_path = solc_binary_path("0.8.15")

Loading…
Cancel
Save