Merge pull request #990 from crytic/dev-scope-file

Refactor core objects to add a file scope.
pull/994/head
Feist Josselin 3 years ago committed by GitHub
commit 1c127979d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 47
      .github/workflows/features.yml
  2. 14
      slither/core/cfg/node.py
  3. 50
      slither/core/compilation_unit.py
  4. 10
      slither/core/declarations/contract.py
  5. 10
      slither/core/declarations/custom_error_top_level.py
  6. 9
      slither/core/declarations/enum_top_level.py
  7. 39
      slither/core/declarations/function.py
  8. 7
      slither/core/declarations/function_contract.py
  9. 18
      slither/core/declarations/function_top_level.py
  10. 8
      slither/core/declarations/import_directive.py
  11. 8
      slither/core/declarations/pragma_directive.py
  12. 4
      slither/core/declarations/structure.py
  13. 10
      slither/core/declarations/structure_top_level.py
  14. 0
      slither/core/scope/__init__.py
  15. 100
      slither/core/scope/scope.py
  16. 4
      slither/core/slither_core.py
  17. 11
      slither/detectors/slither/name_reused.py
  18. 11
      slither/formatters/attributes/const_functions.py
  19. 11
      slither/formatters/functions/external_function.py
  20. 11
      slither/formatters/naming_convention/naming_convention.py
  21. 3
      slither/formatters/variables/possible_const_state_variables.py
  22. 20
      slither/slither.py
  23. 14
      slither/slithir/convert.py
  24. 2
      slither/slithir/operations/new_contract.py
  25. 35
      slither/solc_parsing/declarations/caller_context.py
  26. 15
      slither/solc_parsing/declarations/contract.py
  27. 14
      slither/solc_parsing/declarations/custom_error.py
  28. 12
      slither/solc_parsing/declarations/function.py
  29. 29
      slither/solc_parsing/declarations/structure_top_level.py
  30. 13
      slither/solc_parsing/expressions/expression_parsing.py
  31. 118
      slither/solc_parsing/expressions/find_variable.py
  32. 70
      slither/solc_parsing/slither_compilation_unit_solc.py
  33. 65
      slither/solc_parsing/solidity_types/type_parsing.py
  34. 3
      slither/solc_parsing/variables/variable_declaration.py
  35. 12
      slither/solc_parsing/yul/parse_yul.py
  36. 8
      slither/tools/upgradeability/checks/initialization.py
  37. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.0-legacy.zip
  38. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.1-legacy.zip
  39. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.10-legacy.zip
  40. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.11-legacy.zip
  41. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.12-compact.zip
  42. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.12-legacy.zip
  43. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.13-compact.zip
  44. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.13-legacy.zip
  45. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.14-compact.zip
  46. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.14-legacy.zip
  47. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.15-compact.zip
  48. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.15-legacy.zip
  49. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.16-compact.zip
  50. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.16-legacy.zip
  51. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.17-compact.zip
  52. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.17-legacy.zip
  53. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.18-compact.zip
  54. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.18-legacy.zip
  55. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.19-compact.zip
  56. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.19-legacy.zip
  57. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.2-legacy.zip
  58. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.20-compact.zip
  59. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.20-legacy.zip
  60. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.21-compact.zip
  61. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.21-legacy.zip
  62. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.22-compact.zip
  63. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.22-legacy.zip
  64. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.23-compact.zip
  65. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.23-legacy.zip
  66. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.24-compact.zip
  67. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.24-legacy.zip
  68. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.25-compact.zip
  69. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.25-legacy.zip
  70. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.26-compact.zip
  71. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.26-legacy.zip
  72. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.3-legacy.zip
  73. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.4-legacy.zip
  74. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.5-legacy.zip
  75. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.6-legacy.zip
  76. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.7-legacy.zip
  77. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.8-legacy.zip
  78. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.4.9-legacy.zip
  79. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.0-compact.zip
  80. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.0-legacy.zip
  81. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.1-compact.zip
  82. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.1-legacy.zip
  83. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.10-compact.zip
  84. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.10-legacy.zip
  85. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.11-compact.zip
  86. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.11-legacy.zip
  87. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.12-compact.zip
  88. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.12-legacy.zip
  89. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.13-compact.zip
  90. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.13-legacy.zip
  91. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.14-compact.zip
  92. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.14-legacy.zip
  93. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.15-compact.zip
  94. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.15-legacy.zip
  95. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.16-compact.zip
  96. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.16-legacy.zip
  97. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.17-compact.zip
  98. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.17-legacy.zip
  99. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.2-compact.zip
  100. BIN
      tests/ast-parsing/compile/top-level-nested-import-0.5.2-legacy.zip
  101. Some files were not shown because too many files have changed in this diff Show More

@ -0,0 +1,47 @@
---
name: Features tests
defaults:
run:
# To load bashrc
shell: bash -ieo pipefail {0}
on:
pull_request:
branches: [master, dev]
schedule:
# run CI every day even if no PRs/merges occur
- cron: '0 12 * * *'
jobs:
build:
name: Features tests
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v2
- name: Set up Python 3.6
uses: actions/setup-python@v2
with:
python-version: 3.6
- name: Install dependencies
run: |
python setup.py install
pip install deepdiff
pip install pytest
pip install solc-select
solc-select install all
solc-select use 0.8.0
cd tests/test_node_modules/
npm install hardhat
cd ../..
- name: Test with pytest
run: |
pytest tests/test_features.py

@ -54,6 +54,7 @@ if TYPE_CHECKING:
LowLevelCallType, LowLevelCallType,
) )
from slither.core.cfg.scope import Scope from slither.core.cfg.scope import Scope
from slither.core.scope.scope import FileScope
# pylint: disable=too-many-lines,too-many-branches,too-many-instance-attributes # pylint: disable=too-many-lines,too-many-branches,too-many-instance-attributes
@ -152,7 +153,13 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
""" """
def __init__(self, node_type: NodeType, node_id: int, scope: Union["Scope", "Function"]): def __init__(
self,
node_type: NodeType,
node_id: int,
scope: Union["Scope", "Function"],
file_scope: "FileScope",
):
super().__init__() super().__init__()
self._node_type = node_type self._node_type = node_type
@ -220,7 +227,8 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._asm_source_code: Optional[Union[str, Dict]] = None self._asm_source_code: Optional[Union[str, Dict]] = None
self.scope = scope self.scope: Union["Scope", "Function"] = scope
self.file_scope: "FileScope" = file_scope
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -902,7 +910,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
except AttributeError as error: except AttributeError as error:
# pylint: disable=raise-missing-from # pylint: disable=raise-missing-from
raise SlitherException( raise SlitherException(
f"Function not found on {ir}. Please try compiling with a recent Solidity version. {error}" f"Function not found on IR: {ir}.\nNode: {self} ({self.source_mapping_str})\nFunction: {self.function}\nPlease try compiling with a recent Solidity version. {error}"
) )
elif isinstance(ir, LibraryCall): elif isinstance(ir, LibraryCall):
assert isinstance(ir.destination, Contract) assert isinstance(ir.destination, Contract)

@ -1,9 +1,9 @@
import math import math
from collections import defaultdict
from typing import Optional, Dict, List, Set, Union, TYPE_CHECKING, Tuple from typing import Optional, Dict, List, Set, Union, TYPE_CHECKING, Tuple
from crytic_compile import CompilationUnit, CryticCompile from crytic_compile import CompilationUnit, CryticCompile
from crytic_compile.compiler.compiler import CompilerVersion from crytic_compile.compiler.compiler import CompilerVersion
from crytic_compile.utils.naming import Filename
from slither.core.context.context import Context from slither.core.context.context import Context
from slither.core.declarations import ( from slither.core.declarations import (
@ -17,6 +17,7 @@ from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.enum_top_level import EnumTopLevel from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.function_top_level import FunctionTopLevel from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.structure_top_level import StructureTopLevel from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.core.scope.scope import FileScope
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.core.variables.top_level_variable import TopLevelVariable from slither.core.variables.top_level_variable import TopLevelVariable
from slither.slithir.operations import InternalCall from slither.slithir.operations import InternalCall
@ -25,6 +26,7 @@ from slither.slithir.variables import Constant
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.slither_core import SlitherCore from slither.core.slither_core import SlitherCore
# pylint: disable=too-many-instance-attributes,too-many-public-methods # pylint: disable=too-many-instance-attributes,too-many-public-methods
class SlitherCompilationUnit(Context): class SlitherCompilationUnit(Context):
def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit): def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit):
@ -34,7 +36,7 @@ class SlitherCompilationUnit(Context):
self._crytic_compile_compilation_unit = crytic_compilation_unit self._crytic_compile_compilation_unit = crytic_compilation_unit
# Top level object # Top level object
self._contracts: Dict[str, Contract] = {} self.contracts: List[Contract] = []
self._structures_top_level: List[StructureTopLevel] = [] self._structures_top_level: List[StructureTopLevel] = []
self._enums_top_level: List[EnumTopLevel] = [] self._enums_top_level: List[EnumTopLevel] = []
self._variables_top_level: List[TopLevelVariable] = [] self._variables_top_level: List[TopLevelVariable] = []
@ -51,7 +53,6 @@ class SlitherCompilationUnit(Context):
self._storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {} self._storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {}
self._contract_name_collisions = defaultdict(list)
self._contract_with_missing_inheritance = set() self._contract_with_missing_inheritance = set()
self._source_units: Dict[int, str] = {} self._source_units: Dict[int, str] = {}
@ -60,6 +61,8 @@ class SlitherCompilationUnit(Context):
self.counter_slithir_temporary = 0 self.counter_slithir_temporary = 0
self.counter_slithir_reference = 0 self.counter_slithir_reference = 0
self.scopes: Dict[Filename, FileScope] = dict()
@property @property
def core(self) -> "SlitherCore": def core(self) -> "SlitherCore":
return self._core return self._core
@ -115,32 +118,22 @@ class SlitherCompilationUnit(Context):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@property
def contracts(self) -> List[Contract]:
"""list(Contract): List of contracts."""
return list(self._contracts.values())
@property @property
def contracts_derived(self) -> List[Contract]: def contracts_derived(self) -> List[Contract]:
"""list(Contract): List of contracts that are derived and not inherited.""" """list(Contract): List of contracts that are derived and not inherited."""
inheritances = [x.inheritance for x in self.contracts] inheritances = [x.inheritance for x in self.contracts]
inheritance = [item for sublist in inheritances for item in sublist] inheritance = [item for sublist in inheritances for item in sublist]
return [c for c in self._contracts.values() if c not in inheritance and not c.is_top_level] return [c for c in self.contracts if c not in inheritance and not c.is_top_level]
@property def get_contract_from_name(self, contract_name: Union[str, Constant]) -> List[Contract]:
def contracts_as_dict(self) -> Dict[str, Contract]:
"""list(dict(str: Contract): List of contracts as dict: name -> Contract."""
return self._contracts
def get_contract_from_name(self, contract_name: Union[str, Constant]) -> Optional[Contract]:
""" """
Return a contract from a name Return a list of contract from a name
Args: Args:
contract_name (str): name of the contract contract_name (str): name of the contract
Returns: Returns:
Contract List[Contract]
""" """
return next((c for c in self.contracts if c.name == contract_name), None) return [c for c in self.contracts if c.name == contract_name]
# endregion # endregion
################################################################################### ###################################################################################
@ -223,14 +216,27 @@ class SlitherCompilationUnit(Context):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@property
def contract_name_collisions(self) -> Dict:
return self._contract_name_collisions
@property @property
def contracts_with_missing_inheritance(self) -> Set: def contracts_with_missing_inheritance(self) -> Set:
return self._contract_with_missing_inheritance return self._contract_with_missing_inheritance
# endregion
###################################################################################
###################################################################################
# region Scope
###################################################################################
###################################################################################
def get_scope(self, filename_str: str) -> FileScope:
filename = self._crytic_compile_compilation_unit.crytic_compile.filename_lookup(
filename_str
)
if filename not in self.scopes:
self.scopes[filename] = FileScope(filename)
return self.scopes[filename]
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################

@ -39,6 +39,7 @@ if TYPE_CHECKING:
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations.custom_error_contract import CustomErrorContract from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.scope.scope import FileScope
LOGGER = logging.getLogger("Contract") LOGGER = logging.getLogger("Contract")
@ -49,7 +50,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
Contract class Contract class
""" """
def __init__(self, compilation_unit: "SlitherCompilationUnit"): def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"):
super().__init__() super().__init__()
self._name: Optional[str] = None self._name: Optional[str] = None
@ -69,7 +70,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
self._modifiers: Dict[str, "Modifier"] = {} self._modifiers: Dict[str, "Modifier"] = {}
self._functions: Dict[str, "FunctionContract"] = {} self._functions: Dict[str, "FunctionContract"] = {}
self._linearizedBaseContracts: List[int] = [] self._linearizedBaseContracts: List[int] = []
self._custom_errors: Dict[str:"CustomErrorContract"] = {} self._custom_errors: Dict[str, "CustomErrorContract"] = {}
# The only str is "*" # The only str is "*"
self._using_for: Dict[Union[str, Type], List[str]] = {} self._using_for: Dict[Union[str, Type], List[str]] = {}
@ -92,6 +93,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
self._all_functions_called: Optional[List["InternalCallType"]] = None self._all_functions_called: Optional[List["InternalCallType"]] = None
self.compilation_unit: "SlitherCompilationUnit" = compilation_unit self.compilation_unit: "SlitherCompilationUnit" = compilation_unit
self.file_scope: "FileScope" = scope
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -1086,7 +1088,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
self._is_upgradeable = False self._is_upgradeable = False
if self.is_upgradeable_proxy: if self.is_upgradeable_proxy:
return False return False
initializable = self.compilation_unit.get_contract_from_name("Initializable") initializable = self.file_scope.get_contract_from_name("Initializable")
if initializable: if initializable:
if initializable in self.inheritance: if initializable in self.inheritance:
self._is_upgradeable = True self._is_upgradeable = True
@ -1224,7 +1226,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
) )
# Function uses to create node for state variable declaration statements # Function uses to create node for state variable declaration statements
node = Node(NodeType.OTHER_ENTRYPOINT, counter, scope) node = Node(NodeType.OTHER_ENTRYPOINT, counter, scope, func.file_scope)
node.set_offset(variable.source_mapping, self.compilation_unit) node.set_offset(variable.source_mapping, self.compilation_unit)
node.set_function(func) node.set_function(func)
func.add_node(node) func.add_node(node)

@ -1,6 +1,14 @@
from typing import TYPE_CHECKING
from slither.core.declarations.custom_error import CustomError from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.top_level import TopLevel from slither.core.declarations.top_level import TopLevel
if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.scope.scope import FileScope
class CustomErrorTopLevel(CustomError, TopLevel): class CustomErrorTopLevel(CustomError, TopLevel):
pass def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"):
super().__init__(compilation_unit)
self.file_scope: "FileScope" = scope

@ -1,6 +1,13 @@
from typing import TYPE_CHECKING, List
from slither.core.declarations import Enum from slither.core.declarations import Enum
from slither.core.declarations.top_level import TopLevel from slither.core.declarations.top_level import TopLevel
if TYPE_CHECKING:
from slither.core.scope.scope import FileScope
class EnumTopLevel(Enum, TopLevel): class EnumTopLevel(Enum, TopLevel):
pass def __init__(self, name: str, canonical_name: str, values: List[str], scope: "FileScope"):
super().__init__(name, canonical_name, values)
self.file_scope: "FileScope" = scope

@ -44,6 +44,7 @@ if TYPE_CHECKING:
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.slithir.operations import Operation from slither.slithir.operations import Operation
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.scope.scope import FileScope
LOGGER = logging.getLogger("Function") LOGGER = logging.getLogger("Function")
ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"]) ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"])
@ -117,7 +118,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
def __init__(self, compilation_unit: "SlitherCompilationUnit"): def __init__(self, compilation_unit: "SlitherCompilationUnit"):
super().__init__() super().__init__()
self._scope: List[str] = [] self._internal_scope: List[str] = []
self._name: Optional[str] = None self._name: Optional[str] = None
self._view: bool = False self._view: bool = False
self._pure: bool = False self._pure: bool = False
@ -216,6 +217,8 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
# Assume we are analyzing Solidty by default # Assume we are analyzing Solidty by default
self.function_language: FunctionLanguage = FunctionLanguage.Solidity self.function_language: FunctionLanguage = FunctionLanguage.Solidity
self._id: Optional[str] = None
################################################################################### ###################################################################################
################################################################################### ###################################################################################
# region General properties # region General properties
@ -244,18 +247,18 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
self._name = new_name self._name = new_name
@property @property
def scope(self) -> List[str]: def internal_scope(self) -> List[str]:
""" """
Return a list of name representing the scope of the function Return a list of name representing the scope of the function
This is used to model nested functions declared in YUL This is used to model nested functions declared in YUL
:return: :return:
""" """
return self._scope return self._internal_scope
@scope.setter @internal_scope.setter
def scope(self, new_scope: List[str]): def internal_scope(self, new_scope: List[str]):
self._scope = new_scope self._internal_scope = new_scope
@property @property
def full_name(self) -> str: def full_name(self) -> str:
@ -265,7 +268,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
""" """
if self._full_name is None: if self._full_name is None:
name, parameters, _ = self.signature name, parameters, _ = self.signature
full_name = ".".join(self._scope + [name]) + "(" + ",".join(parameters) + ")" full_name = ".".join(self._internal_scope + [name]) + "(" + ",".join(parameters) + ")"
self._full_name = full_name self._full_name = full_name
return self._full_name return self._full_name
@ -334,6 +337,26 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
return self.compilation_unit.solc_version >= "0.8.0" return self.compilation_unit.solc_version >= "0.8.0"
@property
def id(self) -> Optional[str]:
"""
Return the ID of the funciton. For Solidity with compact-AST the ID is the reference ID
For other, the ID is None
:return:
:rtype:
"""
return self._id
@id.setter
def id(self, new_id: str):
self._id = new_id
@property
@abstractmethod
def file_scope(self) -> "FileScope":
pass
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -1552,7 +1575,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
) -> "Node": ) -> "Node":
from slither.core.cfg.node import Node from slither.core.cfg.node import Node
node = Node(node_type, self._counter_nodes, scope) node = Node(node_type, self._counter_nodes, scope, self.file_scope)
node.set_offset(src, self.compilation_unit) node.set_offset(src, self.compilation_unit)
self._counter_nodes += 1 self._counter_nodes += 1
node.set_function(self) node.set_function(self)

@ -11,6 +11,7 @@ from slither.core.declarations import Function
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.declarations import Contract from slither.core.declarations import Contract
from slither.core.scope.scope import FileScope
class FunctionContract(Function, ChildContract, ChildInheritance): class FunctionContract(Function, ChildContract, ChildInheritance):
@ -23,7 +24,7 @@ class FunctionContract(Function, ChildContract, ChildInheritance):
if self._canonical_name is None: if self._canonical_name is None:
name, parameters, _ = self.signature name, parameters, _ = self.signature
self._canonical_name = ( self._canonical_name = (
".".join([self.contract_declarer.name] + self._scope + [name]) ".".join([self.contract_declarer.name] + self._internal_scope + [name])
+ "(" + "("
+ ",".join(parameters) + ",".join(parameters)
+ ")" + ")"
@ -38,6 +39,10 @@ class FunctionContract(Function, ChildContract, ChildInheritance):
""" """
return self.contract_declarer == contract return self.contract_declarer == contract
@property
def file_scope(self) -> "FileScope":
return self.contract.file_scope
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################

@ -1,13 +1,25 @@
""" """
Function module Function module
""" """
from typing import List, Tuple from typing import List, Tuple, TYPE_CHECKING
from slither.core.declarations import Function from slither.core.declarations import Function
from slither.core.declarations.top_level import TopLevel from slither.core.declarations.top_level import TopLevel
if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.scope.scope import FileScope
class FunctionTopLevel(Function, TopLevel): class FunctionTopLevel(Function, TopLevel):
def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"):
super().__init__(compilation_unit)
self._scope: "FileScope" = scope
@property
def file_scope(self) -> "FileScope":
return self._scope
@property @property
def canonical_name(self) -> str: def canonical_name(self) -> str:
""" """
@ -16,7 +28,9 @@ class FunctionTopLevel(Function, TopLevel):
""" """
if self._canonical_name is None: if self._canonical_name is None:
name, parameters, _ = self.signature name, parameters, _ = self.signature
self._canonical_name = ".".join(self._scope + [name]) + "(" + ",".join(parameters) + ")" self._canonical_name = (
".".join(self._internal_scope + [name]) + "(" + ",".join(parameters) + ")"
)
return self._canonical_name return self._canonical_name
# endregion # endregion

@ -1,14 +1,18 @@
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, TYPE_CHECKING
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
if TYPE_CHECKING:
from slither.core.scope.scope import FileScope
class Import(SourceMapping): class Import(SourceMapping):
def __init__(self, filename: Path): def __init__(self, filename: Path, scope: "FileScope"):
super().__init__() super().__init__()
self._filename: Path = filename self._filename: Path = filename
self._alias: Optional[str] = None self._alias: Optional[str] = None
self.scope: "FileScope" = scope
@property @property
def filename(self) -> str: def filename(self) -> str:

@ -1,12 +1,16 @@
from typing import List from typing import List, TYPE_CHECKING
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
if TYPE_CHECKING:
from slither.core.scope.scope import FileScope
class Pragma(SourceMapping): class Pragma(SourceMapping):
def __init__(self, directive: List[str]): def __init__(self, directive: List[str], scope: "FileScope"):
super().__init__() super().__init__()
self._directive = directive self._directive = directive
self.scope: "FileScope" = scope
@property @property
def directive(self) -> List[str]: def directive(self) -> List[str]:

@ -4,16 +4,18 @@ from slither.core.source_mapping.source_mapping import SourceMapping
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.variables.structure_variable import StructureVariable from slither.core.variables.structure_variable import StructureVariable
from slither.core.compilation_unit import SlitherCompilationUnit
class Structure(SourceMapping): class Structure(SourceMapping):
def __init__(self): def __init__(self, compilation_unit: "SlitherCompilationUnit"):
super().__init__() super().__init__()
self._name = None self._name = None
self._canonical_name = None self._canonical_name = None
self._elems: Dict[str, "StructureVariable"] = dict() self._elems: Dict[str, "StructureVariable"] = dict()
# Name of the elements in the order of declaration # Name of the elements in the order of declaration
self._elems_ordered: List[str] = [] self._elems_ordered: List[str] = []
self.compilation_unit = compilation_unit
@property @property
def canonical_name(self) -> str: def canonical_name(self) -> str:

@ -1,6 +1,14 @@
from typing import TYPE_CHECKING
from slither.core.declarations import Structure from slither.core.declarations import Structure
from slither.core.declarations.top_level import TopLevel from slither.core.declarations.top_level import TopLevel
if TYPE_CHECKING:
from slither.core.scope.scope import FileScope
from slither.core.compilation_unit import SlitherCompilationUnit
class StructureTopLevel(Structure, TopLevel): class StructureTopLevel(Structure, TopLevel):
pass def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"):
super().__init__(compilation_unit)
self.file_scope: "FileScope" = scope

@ -0,0 +1,100 @@
from typing import List, Any, Dict, Optional, Union, Set
from crytic_compile.utils.naming import Filename
from slither.core.declarations import Contract, Import, Pragma
from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.slithir.variables import Constant
def _dict_contain(d1: Dict, d2: Dict) -> bool:
"""
Return true if d1 is included in d2
"""
d2_keys = d2.keys()
return all(item in d2_keys for item in d1.keys())
# pylint: disable=too-many-instance-attributes
class FileScope:
def __init__(self, filename: Filename):
self.filename = filename
self.accessible_scopes: List[FileScope] = []
self.contracts: Dict[str, Contract] = dict()
# Custom error are a list instead of a dict
# Because we parse the function signature later on
# So we simplify the logic and have the scope fields all populated
self.custom_errors: Set[CustomErrorTopLevel] = set()
self.enums: Dict[str, EnumTopLevel] = dict()
# Functions is a list instead of a dict
# Because we parse the function signature later on
# So we simplify the logic and have the scope fields all populated
self.functions: Set[FunctionTopLevel] = set()
self.imports: Set[Import] = set()
self.pragmas: Set[Pragma] = set()
self.structures: Dict[str, StructureTopLevel] = dict()
def add_accesible_scopes(self) -> bool:
"""
Add information from accessible scopes. Return true if new information was obtained
:return:
:rtype:
"""
learn_something = False
for new_scope in self.accessible_scopes:
if not _dict_contain(new_scope.contracts, self.contracts):
self.contracts.update(new_scope.contracts)
learn_something = True
if not new_scope.custom_errors.issubset(self.custom_errors):
self.custom_errors |= new_scope.custom_errors
learn_something = True
if not _dict_contain(new_scope.enums, self.enums):
self.enums.update(new_scope.enums)
learn_something = True
if not new_scope.functions.issubset(self.functions):
self.functions |= new_scope.functions
learn_something = True
if not new_scope.imports.issubset(self.imports):
self.imports |= new_scope.imports
learn_something = True
if not new_scope.pragmas.issubset(self.pragmas):
self.pragmas |= new_scope.pragmas
learn_something = True
if not _dict_contain(new_scope.structures, self.structures):
self.structures.update(new_scope.structures)
learn_something = True
return learn_something
def get_contract_from_name(self, name: Union[str, Constant]) -> Optional[Contract]:
if isinstance(name, Constant):
return self.contracts.get(name.name, None)
return self.contracts.get(name, None)
# region Built in definitions
###################################################################################
###################################################################################
def __eq__(self, other: Any) -> bool:
if isinstance(other, str):
return other == self.filename
return NotImplemented
def __neq__(self, other: Any) -> bool:
if isinstance(other, str):
return other != self.filename
return NotImplemented
def __str__(self) -> str:
return str(self.filename.relative)
def __hash__(self) -> int:
return hash(self.filename.relative)
# endregion

@ -110,9 +110,7 @@ class SlitherCore(Context):
""" """
contracts = [] contracts = []
for compilation_unit in self._compilation_units: for compilation_unit in self._compilation_units:
contract = compilation_unit.get_contract_from_name(contract_name) contracts += compilation_unit.get_contract_from_name(contract_name)
if contract:
contracts.append(contract)
return contracts return contracts
################################################################################### ###################################################################################

@ -53,7 +53,16 @@ As a result, the second contract cannot be analyzed.
def _detect(self): # pylint: disable=too-many-locals,too-many-branches def _detect(self): # pylint: disable=too-many-locals,too-many-branches
results = [] results = []
compilation_unit = self.compilation_unit compilation_unit = self.compilation_unit
names_reused = compilation_unit.contract_name_collisions
all_contracts = compilation_unit.contracts
all_contracts_name = [c.name for c in all_contracts]
contracts_name_reused = {
contract for contract in all_contracts_name if all_contracts_name.count(contract) > 1
}
names_reused = {
name: compilation_unit.get_contract_from_name(name) for name in contracts_name_reused
}
# First show the contracts that we know are missing # First show the contracts that we know are missing
incorrectly_constructed = [ incorrectly_constructed = [

@ -5,13 +5,14 @@ from slither.formatters.exceptions import FormatError
from slither.formatters.utils.patches import create_patch from slither.formatters.utils.patches import create_patch
def custom_format(comilation_unit: SlitherCompilationUnit, result): def custom_format(compilation_unit: SlitherCompilationUnit, result):
for file_scope in compilation_unit.scopes.values():
elements = result["elements"] elements = result["elements"]
for element in elements: for element in elements:
if element["type"] != "function": if element["type"] != "function":
# Skip variable elements # Skip variable elements
continue continue
target_contract = comilation_unit.get_contract_from_name( target_contract = file_scope.get_contract_from_name(
element["type_specific_fields"]["parent"]["name"] element["type_specific_fields"]["parent"]["name"]
) )
if target_contract: if target_contract:
@ -20,7 +21,7 @@ def custom_format(comilation_unit: SlitherCompilationUnit, result):
) )
if function: if function:
_patch( _patch(
comilation_unit, compilation_unit,
result, result,
element["source_mapping"]["filename_absolute"], element["source_mapping"]["filename_absolute"],
int( int(
@ -32,9 +33,9 @@ def custom_format(comilation_unit: SlitherCompilationUnit, result):
def _patch( def _patch(
comilation_unit: SlitherCompilationUnit, result, in_file, modify_loc_start, modify_loc_end compilation_unit: SlitherCompilationUnit, result, in_file, modify_loc_start, modify_loc_end
): ):
in_file_str = comilation_unit.core.source_code[in_file].encode("utf8") in_file_str = compilation_unit.core.source_code[in_file].encode("utf8")
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end] old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
# Find the keywords view|pure|constant and remove them # Find the keywords view|pure|constant and remove them
m = re.search("(view|pure|constant)", old_str_of_interest.decode("utf-8")) m = re.search("(view|pure|constant)", old_str_of_interest.decode("utf-8"))

@ -4,10 +4,11 @@ from slither.core.compilation_unit import SlitherCompilationUnit
from slither.formatters.utils.patches import create_patch from slither.formatters.utils.patches import create_patch
def custom_format(comilation_unit: SlitherCompilationUnit, result): def custom_format(compilation_unit: SlitherCompilationUnit, result):
for file_scope in compilation_unit.scopes.values():
elements = result["elements"] elements = result["elements"]
for element in elements: for element in elements:
target_contract = comilation_unit.get_contract_from_name( target_contract = file_scope.get_contract_from_name(
element["type_specific_fields"]["parent"]["name"] element["type_specific_fields"]["parent"]["name"]
) )
if target_contract: if target_contract:
@ -16,7 +17,7 @@ def custom_format(comilation_unit: SlitherCompilationUnit, result):
) )
if function: if function:
_patch( _patch(
comilation_unit, compilation_unit,
result, result,
element["source_mapping"]["filename_absolute"], element["source_mapping"]["filename_absolute"],
int(function.parameters_src().source_mapping["start"]), int(function.parameters_src().source_mapping["start"]),
@ -25,9 +26,9 @@ def custom_format(comilation_unit: SlitherCompilationUnit, result):
def _patch( def _patch(
comilation_unit: SlitherCompilationUnit, result, in_file, modify_loc_start, modify_loc_end compilation_unit: SlitherCompilationUnit, result, in_file, modify_loc_start, modify_loc_end
): ):
in_file_str = comilation_unit.core.source_code[in_file].encode("utf8") in_file_str = compilation_unit.core.source_code[in_file].encode("utf8")
old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end] old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end]
# Search for 'public' keyword which is in-between the function name and modifier name (if present) # Search for 'public' keyword which is in-between the function name and modifier name (if present)
# regex: 'public' could have spaces around or be at the end of the line # regex: 'public' could have spaces around or be at the end of the line

@ -201,8 +201,9 @@ conventions = {
def _get_from_contract(compilation_unit: SlitherCompilationUnit, element, name, getter): def _get_from_contract(compilation_unit: SlitherCompilationUnit, element, name, getter):
scope = compilation_unit.get_scope(element["source_mapping"]["filename_absolute"])
contract_name = element["type_specific_fields"]["parent"]["name"] contract_name = element["type_specific_fields"]["parent"]["name"]
contract = compilation_unit.get_contract_from_name(contract_name) contract = scope.get_contract_from_name(contract_name)
return getattr(contract, getter)(name) return getattr(contract, getter)(name)
@ -215,8 +216,10 @@ def _get_from_contract(compilation_unit: SlitherCompilationUnit, element, name,
def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target): def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target):
scope = compilation_unit.get_scope(element["source_mapping"]["filename_absolute"])
if _target == "contract": if _target == "contract":
target = compilation_unit.get_contract_from_name(element["name"]) target = scope.get_contract_from_name(element["name"])
elif _target == "structure": elif _target == "structure":
target = _get_from_contract( target = _get_from_contract(
@ -250,7 +253,7 @@ def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target):
"signature" "signature"
] ]
param_name = element["name"] param_name = element["name"]
contract = compilation_unit.get_contract_from_name(contract_name) contract = scope.get_contract_from_name(contract_name)
function = contract.get_function_from_signature(function_sig) function = contract.get_function_from_signature(function_sig)
target = function.get_local_variable_from_name(param_name) target = function.get_local_variable_from_name(param_name)
@ -264,7 +267,7 @@ def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target):
"signature" "signature"
] ]
var_name = element["name"] var_name = element["name"]
contract = compilation_unit.get_contract_from_name(contract_name) contract = scope.get_contract_from_name(contract_name)
function = contract.get_function_from_signature(function_sig) function = contract.get_function_from_signature(function_sig)
target = function.get_local_variable_from_name(var_name) target = function.get_local_variable_from_name(var_name)
# State variable # State variable

@ -11,7 +11,8 @@ def custom_format(compilation_unit: SlitherCompilationUnit, result):
# TODO: decide if this should be changed in the constant detector # TODO: decide if this should be changed in the constant detector
contract_name = element["type_specific_fields"]["parent"]["name"] contract_name = element["type_specific_fields"]["parent"]["name"]
contract = compilation_unit.get_contract_from_name(contract_name) scope = compilation_unit.get_scope(element["source_mapping"]["filename_absolute"])
contract = scope.get_contract_from_name(contract_name)
var = contract.get_state_variable_from_name(element["name"]) var = contract.get_state_variable_from_name(element["name"])
if not var.expression: if not var.expression:
raise FormatImpossible(f"{var.name} is uninitialized and cannot become constant.") raise FormatImpossible(f"{var.name} is uninitialized and cannot become constant.")

@ -1,10 +1,11 @@
import logging import logging
from typing import Union, List from typing import Union, List, ValuesView
from crytic_compile import CryticCompile, InvalidCompilation from crytic_compile import CryticCompile, InvalidCompilation
# pylint: disable= no-name-in-module # pylint: disable= no-name-in-module
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.scope.scope import FileScope
from slither.core.slither_core import SlitherCore from slither.core.slither_core import SlitherCore
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.exceptions import SlitherError from slither.exceptions import SlitherError
@ -31,6 +32,21 @@ def _check_common_things(thing_name, cls, base_cls, instances_list):
raise Exception("You can't register {!r} twice.".format(cls)) raise Exception("You can't register {!r} twice.".format(cls))
def _update_file_scopes(candidates: ValuesView[FileScope]):
"""
Because solc's import allows cycle in the import
We iterate until we aren't adding new information to the scope
"""
learned_something = False
while True:
for candidate in candidates:
learned_something |= candidate.add_accesible_scopes()
if not learned_something:
break
learned_something = False
class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes
def __init__(self, target: Union[str, CryticCompile], **kwargs): def __init__(self, target: Union[str, CryticCompile], **kwargs):
""" """
@ -81,6 +97,8 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes
parser.parse_top_level_from_loaded_json(ast, path) parser.parse_top_level_from_loaded_json(ast, path)
self.add_source_code(path) self.add_source_code(path)
_update_file_scopes(compilation_unit_slither.scopes.values())
if kwargs.get("generate_patches", False): if kwargs.get("generate_patches", False):
self.generate_patches = True self.generate_patches = True

@ -527,7 +527,7 @@ def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals
# UserdefinedType # UserdefinedType
t_type = t.type t_type = t.type
if isinstance(t_type, Contract): if isinstance(t_type, Contract):
contract = node.compilation_unit.get_contract_from_name(t_type.name) contract = node.file_scope.get_contract_from_name(t_type.name)
return convert_type_of_high_and_internal_level_call(ir, contract) return convert_type_of_high_and_internal_level_call(ir, contract)
# Convert HighLevelCall to LowLevelCall # Convert HighLevelCall to LowLevelCall
@ -729,7 +729,7 @@ def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals
elif isinstance(ir, NewArray): elif isinstance(ir, NewArray):
ir.lvalue.set_type(ir.array_type) ir.lvalue.set_type(ir.array_type)
elif isinstance(ir, NewContract): elif isinstance(ir, NewContract):
contract = node.compilation_unit.get_contract_from_name(ir.contract_name) contract = node.file_scope.get_contract_from_name(ir.contract_name)
ir.lvalue.set_type(UserDefinedType(contract)) ir.lvalue.set_type(UserDefinedType(contract))
elif isinstance(ir, NewElementaryType): elif isinstance(ir, NewElementaryType):
ir.lvalue.set_type(ir.type) ir.lvalue.set_type(ir.type)
@ -1287,7 +1287,7 @@ def convert_to_pop(ir, node):
def look_for_library(contract, ir, using_for, t): def look_for_library(contract, ir, using_for, t):
for destination in using_for[t]: for destination in using_for[t]:
lib_contract = contract.compilation_unit.get_contract_from_name(str(destination)) lib_contract = contract.file_scope.get_contract_from_name(str(destination))
if lib_contract: if lib_contract:
lib_call = LibraryCall( lib_call = LibraryCall(
lib_contract, lib_contract,
@ -1434,7 +1434,7 @@ def _convert_to_structure_to_list(return_type: Type) -> List[Type]:
def convert_type_of_high_and_internal_level_call(ir: Operation, contract: Optional[Contract]): def convert_type_of_high_and_internal_level_call(ir: Operation, contract: Optional[Contract]):
func = None func = None
if isinstance(ir, InternalCall): if isinstance(ir, InternalCall):
candidates: List[Function]
if ir.function_candidates: if ir.function_candidates:
# This path is taken only for SolidityImportPlaceHolder # This path is taken only for SolidityImportPlaceHolder
# Here we have already done a filtering on the potential targets # Here we have already done a filtering on the potential targets
@ -1447,6 +1447,12 @@ def convert_type_of_high_and_internal_level_call(ir: Operation, contract: Option
and f.contract_declarer.name == ir.contract_name and f.contract_declarer.name == ir.contract_name
and len(f.parameters) == len(ir.arguments) and len(f.parameters) == len(ir.arguments)
] ]
for import_statement in contract.file_scope.imports:
if import_statement.alias and import_statement.alias == ir.contract_name:
imported_scope = contract.compilation_unit.get_scope(import_statement.filename)
candidates += list(imported_scope.functions)
func = _find_function_from_parameter(ir, candidates) func = _find_function_from_parameter(ir, candidates)
if not func: if not func:

@ -50,7 +50,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan
@property @property
def contract_created(self): def contract_created(self):
contract_name = self.contract_name contract_name = self.contract_name
contract_instance = self.compilation_unit.get_contract_from_name(contract_name) contract_instance = self.node.file_scope.get_contract_from_name(contract_name)
return contract_instance return contract_instance
################################################################################### ###################################################################################

@ -0,0 +1,35 @@
import abc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
class CallerContextExpression(metaclass=abc.ABCMeta):
"""
This class is inherited by all the declarations class that can be used in the expression/type parsing
As a source of context/scope
It is used by any declaration class that can be top-level and require complex parsing
"""
@property
@abc.abstractmethod
def is_compact_ast(self) -> bool:
pass
@property
@abc.abstractmethod
def compilation_unit(self) -> "SlitherCompilationUnit":
pass
@abc.abstractmethod
def get_key(self) -> str:
pass
@property
@abc.abstractmethod
def slither_parser(self) -> "SlitherCompilationUnitSolc":
pass

@ -6,6 +6,7 @@ from slither.core.declarations.contract import Contract
from slither.core.declarations.custom_error_contract import CustomErrorContract from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
from slither.solc_parsing.declarations.event import EventSolc from slither.solc_parsing.declarations.event import EventSolc
from slither.solc_parsing.declarations.function import FunctionSolc from slither.solc_parsing.declarations.function import FunctionSolc
@ -25,7 +26,7 @@ if TYPE_CHECKING:
# pylint: disable=too-many-instance-attributes,import-outside-toplevel,too-many-nested-blocks,too-many-public-methods # pylint: disable=too-many-instance-attributes,import-outside-toplevel,too-many-nested-blocks,too-many-public-methods
class ContractSolc: class ContractSolc(CallerContextExpression):
def __init__(self, slither_parser: "SlitherCompilationUnitSolc", contract: Contract, data): def __init__(self, slither_parser: "SlitherCompilationUnitSolc", contract: Contract, data):
# assert slitherSolc.solc_version.startswith('0.4') # assert slitherSolc.solc_version.startswith('0.4')
@ -258,7 +259,7 @@ class ContractSolc:
def _parse_struct(self, struct: Dict): def _parse_struct(self, struct: Dict):
st = StructureContract() st = StructureContract(self._contract.compilation_unit)
st.set_contract(self._contract) st.set_contract(self._contract)
st.set_offset(struct["src"], self._contract.compilation_unit) st.set_offset(struct["src"], self._contract.compilation_unit)
@ -423,7 +424,7 @@ class ContractSolc:
Cls: Callable, Cls: Callable,
Cls_parser: Callable, Cls_parser: Callable,
element_parser: FunctionSolc, element_parser: FunctionSolc,
explored_reference_id: Set[int], explored_reference_id: Set[str],
parser: List[FunctionSolc], parser: List[FunctionSolc],
all_elements: Dict[str, Function], all_elements: Dict[str, Function],
): ):
@ -442,13 +443,13 @@ class ContractSolc:
elem, element_parser.function_not_parsed, self, self.slither_parser elem, element_parser.function_not_parsed, self, self.slither_parser
) )
if ( if (
element_parser.referenced_declaration element_parser.underlying_function.id
and element_parser.referenced_declaration in explored_reference_id and element_parser.underlying_function.id in explored_reference_id
): ):
# Already added from other fathers # Already added from other fathers
return return
if element_parser.referenced_declaration: if element_parser.underlying_function.id:
explored_reference_id.add(element_parser.referenced_declaration) explored_reference_id.add(element_parser.underlying_function.id)
elem_parser.analyze_params() elem_parser.analyze_params()
if isinstance(elem, Modifier): if isinstance(elem, Modifier):
self._contract.compilation_unit.add_modifier(elem) self._contract.compilation_unit.add_modifier(elem)

@ -1,18 +1,22 @@
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
from slither.core.declarations.custom_error import CustomError from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.variables.local_variable import LocalVariableSolc from slither.solc_parsing.variables.local_variable import LocalVariableSolc
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
from slither.core.compilation_unit import SlitherCompilationUnit
# Part of the code was copied from the function parsing # Part of the code was copied from the function parsing
# In the long term we should refactor these two classes to merge the duplicated code # In the long term we should refactor these two classes to merge the duplicated code
class CustomErrorSolc: class CustomErrorSolc(CallerContextExpression):
def __init__( def __init__(
self, self,
custom_error: CustomError, custom_error: CustomError,
@ -84,6 +88,10 @@ class CustomErrorSolc:
local_var_parser = LocalVariableSolc(local_var, param) local_var_parser = LocalVariableSolc(local_var, param)
if isinstance(self._custom_error, CustomErrorTopLevel):
local_var_parser.analyze(self)
else:
assert isinstance(self._custom_error, CustomErrorContract)
local_var_parser.analyze(self) local_var_parser.analyze(self)
# see https://solidity.readthedocs.io/en/v0.4.24/types.html?highlight=storage%20location#data-location # see https://solidity.readthedocs.io/en/v0.4.24/types.html?highlight=storage%20location#data-location
@ -99,3 +107,7 @@ class CustomErrorSolc:
@property @property
def slither_parser(self) -> "SlitherCompilationUnitSolc": def slither_parser(self) -> "SlitherCompilationUnitSolc":
return self._slither_parser return self._slither_parser
@property
def compilation_unit(self) -> "SlitherCompilationUnit":
return self._custom_error.compilation_unit

@ -14,6 +14,7 @@ from slither.core.expressions import AssignmentOperation
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple
from slither.solc_parsing.cfg.node import NodeSolc from slither.solc_parsing.cfg.node import NodeSolc
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.exceptions import ParsingError from slither.solc_parsing.exceptions import ParsingError
from slither.solc_parsing.expressions.expression_parsing import parse_expression from slither.solc_parsing.expressions.expression_parsing import parse_expression
from slither.solc_parsing.variables.local_variable import LocalVariableSolc from slither.solc_parsing.variables.local_variable import LocalVariableSolc
@ -43,7 +44,7 @@ def link_underlying_nodes(node1: NodeSolc, node2: NodeSolc):
# pylint: disable=too-many-lines,too-many-branches,too-many-locals,too-many-statements,too-many-instance-attributes # pylint: disable=too-many-lines,too-many-branches,too-many-locals,too-many-statements,too-many-instance-attributes
class FunctionSolc: class FunctionSolc(CallerContextExpression):
# elems = [(type, name)] # elems = [(type, name)]
@ -59,11 +60,9 @@ class FunctionSolc:
self._function = function self._function = function
# Only present if compact AST # Only present if compact AST
self._referenced_declaration: Optional[int] = None
if self.is_compact_ast: if self.is_compact_ast:
self._function.name = function_data["name"] self._function.name = function_data["name"]
if "id" in function_data: if "id" in function_data:
self._referenced_declaration = function_data["id"]
self._function.id = function_data["id"] self._function.id = function_data["id"]
else: else:
self._function.name = function_data["attributes"][self.get_key()] self._function.name = function_data["attributes"][self.get_key()]
@ -125,13 +124,6 @@ class FunctionSolc:
def is_compact_ast(self): def is_compact_ast(self):
return self._slither_parser.is_compact_ast return self._slither_parser.is_compact_ast
@property
def referenced_declaration(self) -> Optional[str]:
"""
Return the compact AST referenced declaration id (None for legacy AST)
"""
return self._referenced_declaration
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################

@ -3,15 +3,17 @@
""" """
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
from slither.core.declarations.structure import Structure from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.core.variables.structure_variable import StructureVariable from slither.core.variables.structure_variable import StructureVariable
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.variables.structure_variable import StructureVariableSolc from slither.solc_parsing.variables.structure_variable import StructureVariableSolc
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
class StructureTopLevelSolc: # pylint: disable=too-few-public-methods class StructureTopLevelSolc(CallerContextExpression): # pylint: disable=too-few-public-methods
""" """
Structure class Structure class
""" """
@ -20,7 +22,7 @@ class StructureTopLevelSolc: # pylint: disable=too-few-public-methods
def __init__( # pylint: disable=too-many-arguments def __init__( # pylint: disable=too-many-arguments
self, self,
st: Structure, st: StructureTopLevel,
struct: Dict, struct: Dict,
slither_parser: "SlitherCompilationUnitSolc", slither_parser: "SlitherCompilationUnitSolc",
): ):
@ -52,8 +54,27 @@ class StructureTopLevelSolc: # pylint: disable=too-few-public-methods
elem.set_offset(elem_to_parse["src"], self._slither_parser.compilation_unit) elem.set_offset(elem_to_parse["src"], self._slither_parser.compilation_unit)
elem_parser = StructureVariableSolc(elem, elem_to_parse) elem_parser = StructureVariableSolc(elem, elem_to_parse)
elem_parser.analyze(self._slither_parser) elem_parser.analyze(self)
self._structure.elems[elem.name] = elem self._structure.elems[elem.name] = elem
self._structure.add_elem_in_order(elem.name) self._structure.add_elem_in_order(elem.name)
self._elemsNotParsed = [] self._elemsNotParsed = []
@property
def is_compact_ast(self) -> bool:
return self._slither_parser.is_compact_ast
@property
def compilation_unit(self) -> SlitherCompilationUnit:
return self._slither_parser.compilation_unit
def get_key(self) -> str:
return self._slither_parser.get_key()
@property
def slither_parser(self) -> "SlitherCompilationUnitSolc":
return self._slither_parser
@property
def underlying_structure(self) -> StructureTopLevel:
return self._structure

@ -33,8 +33,9 @@ from slither.core.solidity_types import (
ArrayType, ArrayType,
ElementaryType, ElementaryType,
) )
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.exceptions import ParsingError, VariableNotFound from slither.solc_parsing.exceptions import ParsingError, VariableNotFound
from slither.solc_parsing.expressions.find_variable import CallerContext, find_variable from slither.solc_parsing.expressions.find_variable import find_variable
from slither.solc_parsing.solidity_types.type_parsing import UnknownType, parse_type from slither.solc_parsing.solidity_types.type_parsing import UnknownType, parse_type
if TYPE_CHECKING: if TYPE_CHECKING:
@ -196,7 +197,7 @@ def parse_super_name(expression: Dict, is_compact_ast: bool) -> str:
def _parse_elementary_type_name_expression( def _parse_elementary_type_name_expression(
expression: Dict, is_compact_ast: bool, caller_context expression: Dict, is_compact_ast: bool, caller_context: CallerContextExpression
) -> ElementaryTypeNameExpression: ) -> ElementaryTypeNameExpression:
# nop exression # nop exression
# uint; # uint;
@ -216,7 +217,12 @@ def _parse_elementary_type_name_expression(
return e return e
def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expression": if TYPE_CHECKING:
from slither.core.scope.scope import FileScope
def parse_expression(expression: Dict, caller_context: CallerContextExpression) -> "Expression":
# pylint: disable=too-many-nested-blocks,too-many-statements # pylint: disable=too-many-nested-blocks,too-many-statements
""" """
@ -246,6 +252,7 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres
# | Expression ('=' | '|=' | '^=' | '&=' | '<<=' | '>>=' | '+=' | '-=' | '*=' | '/=' | '%=') Expression # | Expression ('=' | '|=' | '^=' | '&=' | '<<=' | '>>=' | '+=' | '-=' | '*=' | '/=' | '%=') Expression
# | PrimaryExpression # | PrimaryExpression
# The AST naming does not follow the spec # The AST naming does not follow the spec
assert isinstance(caller_context, CallerContextExpression)
name = expression[caller_context.get_key()] name = expression[caller_context.get_key()]
is_compact_ast = caller_context.is_compact_ast is_compact_ast = caller_context.is_compact_ast
src = expression["src"] src = expression["src"]

@ -13,6 +13,7 @@ from slither.core.declarations.solidity_variables import (
SolidityFunction, SolidityFunction,
SolidityVariable, SolidityVariable,
) )
from slither.core.scope.scope import FileScope
from slither.core.solidity_types import ( from slither.core.solidity_types import (
ArrayType, ArrayType,
FunctionType, FunctionType,
@ -20,18 +21,17 @@ from slither.core.solidity_types import (
) )
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.exceptions import SlitherError from slither.exceptions import SlitherError
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.exceptions import VariableNotFound from slither.solc_parsing.exceptions import VariableNotFound
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.solc_parsing.declarations.function import FunctionSolc from slither.solc_parsing.declarations.function import FunctionSolc
from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.contract import ContractSolc
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
from slither.core.compilation_unit import SlitherCompilationUnit
# pylint: disable=import-outside-toplevel,too-many-branches,too-many-locals # pylint: disable=import-outside-toplevel,too-many-branches,too-many-locals
CallerContext = Union["ContractSolc", "FunctionSolc"] # CallerContext =Union["ContractSolc", "FunctionSolc", "CustomErrorSolc", "StructureTopLevelSolc"]
def _get_pointer_name(variable: Variable): def _get_pointer_name(variable: Variable):
@ -51,7 +51,7 @@ def _get_pointer_name(variable: Variable):
def _find_variable_from_ref_declaration( def _find_variable_from_ref_declaration(
referenced_declaration: Optional[int], referenced_declaration: Optional[int],
all_contracts: List["Contract"], all_contracts: List["Contract"],
all_functions_parser: List["FunctionSolc"], all_functions: List["Function"],
) -> Optional[Union[Contract, Function]]: ) -> Optional[Union[Contract, Function]]:
if referenced_declaration is None: if referenced_declaration is None:
return None return None
@ -59,14 +59,11 @@ def _find_variable_from_ref_declaration(
# This is not true for the functions, as we dont always have the referenced_declaration # This is not true for the functions, as we dont always have the referenced_declaration
# But maybe we could? (TODO) # But maybe we could? (TODO)
for contract_candidate in all_contracts: for contract_candidate in all_contracts:
if contract_candidate.id == referenced_declaration: if contract_candidate and contract_candidate.id == referenced_declaration:
return contract_candidate return contract_candidate
for function_candidate in all_functions_parser: for function_candidate in all_functions:
if ( if function_candidate.id == referenced_declaration and not function_candidate.is_shadowed:
function_candidate.referenced_declaration == referenced_declaration return function_candidate
and not function_candidate.underlying_function.is_shadowed
):
return function_candidate.underlying_function
return None return None
@ -100,7 +97,7 @@ def _find_variable_in_function_parser(
def _find_top_level( def _find_top_level(
var_name: str, sl: "SlitherCompilationUnit" var_name: str, scope: "FileScope"
) -> Tuple[Optional[Union[Enum, Structure, SolidityImportPlaceHolder, CustomError]], bool]: ) -> Tuple[Optional[Union[Enum, Structure, SolidityImportPlaceHolder, CustomError]], bool]:
""" """
Return the top level variable use, and a boolean indicating if the variable returning was cretead Return the top level variable use, and a boolean indicating if the variable returning was cretead
@ -113,23 +110,19 @@ def _find_top_level(
:return: :return:
:rtype: :rtype:
""" """
structures_top_level = sl.structures_top_level
for st in structures_top_level:
if st.name == var_name:
return st, False
enums_top_level = sl.enums_top_level if var_name in scope.structures:
for enum in enums_top_level: return scope.structures[var_name], False
if enum.name == var_name:
return enum, False
for import_directive in sl.import_directives: if var_name in scope.enums:
return scope.enums[var_name], False
for import_directive in scope.imports:
if import_directive.alias == var_name: if import_directive.alias == var_name:
new_val = SolidityImportPlaceHolder(import_directive) new_val = SolidityImportPlaceHolder(import_directive)
return new_val, True return new_val, True
# Note for now solidity prevent two custom error from having the same name for custom_error in scope.custom_errors:
for custom_error in sl.custom_errors:
if custom_error.solidity_signature == var_name: if custom_error.solidity_signature == var_name:
return custom_error, False return custom_error, False
@ -142,7 +135,6 @@ def _find_in_contract(
contract_declarer: Optional[Contract], contract_declarer: Optional[Contract],
is_super: bool, is_super: bool,
) -> Optional[Union[Variable, Function, Contract, Event, Enum, Structure, CustomError]]: ) -> Optional[Union[Variable, Function, Contract, Event, Enum, Structure, CustomError]]:
if contract is None or contract_declarer is None: if contract is None or contract_declarer is None:
return None return None
@ -214,56 +206,57 @@ def _find_in_contract(
def _find_variable_init( def _find_variable_init(
caller_context: CallerContext, caller_context: CallerContextExpression,
) -> Tuple[ ) -> Tuple[List[Contract], List["Function"], FileScope,]:
List[Contract],
Union[List["FunctionSolc"]],
"SlitherCompilationUnit",
"SlitherCompilationUnitSolc",
]:
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.contract import ContractSolc
from slither.solc_parsing.declarations.function import FunctionSolc from slither.solc_parsing.declarations.function import FunctionSolc
direct_contracts: List[Contract] direct_contracts: List[Contract]
direct_functions_parser: List[FunctionSolc] direct_functions_parser: List[Function]
scope: FileScope
if isinstance(caller_context, SlitherCompilationUnitSolc): if isinstance(caller_context, FileScope):
direct_contracts = [] direct_contracts = []
direct_functions_parser = [] direct_functions_parser = []
sl = caller_context.compilation_unit scope = caller_context
sl_parser = caller_context
elif isinstance(caller_context, ContractSolc): elif isinstance(caller_context, ContractSolc):
direct_contracts = [caller_context.underlying_contract] direct_contracts = [caller_context.underlying_contract]
direct_functions_parser = caller_context.functions_parser + caller_context.modifiers_parser direct_functions_parser = [
sl = caller_context.slither_parser.compilation_unit f.underlying_function
sl_parser = caller_context.slither_parser for f in caller_context.functions_parser + caller_context.modifiers_parser
]
scope = caller_context.underlying_contract.file_scope
elif isinstance(caller_context, FunctionSolc): elif isinstance(caller_context, FunctionSolc):
if caller_context.contract_parser: if caller_context.contract_parser:
direct_contracts = [caller_context.contract_parser.underlying_contract] direct_contracts = [caller_context.contract_parser.underlying_contract]
direct_functions_parser = ( direct_functions_parser = [
caller_context.contract_parser.functions_parser f.underlying_function
for f in caller_context.contract_parser.functions_parser
+ caller_context.contract_parser.modifiers_parser + caller_context.contract_parser.modifiers_parser
) ]
else: else:
# Top level functions # Top level functions
direct_contracts = [] direct_contracts = []
direct_functions_parser = [] direct_functions_parser = []
sl = caller_context.underlying_function.compilation_unit underlying_function = caller_context.underlying_function
sl_parser = caller_context.slither_parser if isinstance(underlying_function, FunctionTopLevel):
scope = underlying_function.file_scope
else:
assert isinstance(underlying_function, FunctionContract)
scope = underlying_function.contract.file_scope
else: else:
raise SlitherError( raise SlitherError(
f"{type(caller_context)} ({caller_context} is not valid for find_variable" f"{type(caller_context)} ({caller_context} is not valid for find_variable"
) )
return direct_contracts, direct_functions_parser, sl, sl_parser return direct_contracts, direct_functions_parser, scope
def find_variable( def find_variable(
var_name: str, var_name: str,
caller_context: CallerContext, caller_context: CallerContextExpression,
referenced_declaration: Optional[int] = None, referenced_declaration: Optional[int] = None,
is_super=False, is_super: bool = False,
) -> Tuple[ ) -> Tuple[
Union[ Union[
Variable, Variable,
@ -311,28 +304,25 @@ def find_variable(
# for events it's unclear what should be the behavior, as they can be shadowed, but there is not impact # for events it's unclear what should be the behavior, as they can be shadowed, but there is not impact
# structure/enums cannot be shadowed # structure/enums cannot be shadowed
direct_contracts, direct_functions_parser, sl, sl_parser = _find_variable_init(caller_context) direct_contracts, direct_functions, current_scope = _find_variable_init(caller_context)
all_contracts = sl.contracts
all_functions_parser = sl_parser.all_functions_and_modifiers_parser
# Only look for reference declaration in the direct contract, see comment at the end # Only look for reference declaration in the direct contract, see comment at the end
# Reference looked are split between direct and all # Reference looked are split between direct and all
# Because functions are copied between contracts, two functions can have the same ref # Because functions are copied between contracts, two functions can have the same ref
# So we need to first look with respect to the direct context # So we need to first look with respect to the direct context
ret = _find_variable_from_ref_declaration( # Use ret0/ret1 to help mypy
referenced_declaration, direct_contracts, direct_functions_parser ret0 = _find_variable_from_ref_declaration(
referenced_declaration, direct_contracts, direct_functions
) )
if ret: if ret0:
return ret, False return ret0, False
function_parser: Optional[FunctionSolc] = ( function_parser: Optional[FunctionSolc] = (
caller_context if isinstance(caller_context, FunctionSolc) else None caller_context if isinstance(caller_context, FunctionSolc) else None
) )
ret = _find_variable_in_function_parser(var_name, function_parser, referenced_declaration) ret1 = _find_variable_in_function_parser(var_name, function_parser, referenced_declaration)
if ret: if ret1:
return ret, False return ret1, False
contract: Optional[Contract] = None contract: Optional[Contract] = None
contract_declarer: Optional[Contract] = None contract_declarer: Optional[Contract] = None
@ -352,12 +342,12 @@ def find_variable(
return ret, False return ret, False
# Could refer to any enum # Could refer to any enum
all_enumss = [c.enums_as_dict for c in sl.contracts] all_enumss = [c.enums_as_dict for c in current_scope.contracts.values()]
all_enums = {k: v for d in all_enumss for k, v in d.items()} all_enums = {k: v for d in all_enumss for k, v in d.items()}
if var_name in all_enums: if var_name in all_enums:
return all_enums[var_name], False return all_enums[var_name], False
contracts = sl.contracts_as_dict contracts = current_scope.contracts
if var_name in contracts: if var_name in contracts:
return contracts[var_name], False return contracts[var_name], False
@ -368,7 +358,7 @@ def find_variable(
return SolidityFunction(var_name), False return SolidityFunction(var_name), False
# Top level must be at the end, if nothing else was found # Top level must be at the end, if nothing else was found
ret, var_was_created = _find_top_level(var_name, sl) ret, var_was_created = _find_top_level(var_name, current_scope)
if ret: if ret:
return ret, var_was_created return ret, var_was_created
@ -394,7 +384,9 @@ def find_variable(
# get's AST will say that the ref declaration for _f() is A._f(), but in the context of B, its not # get's AST will say that the ref declaration for _f() is A._f(), but in the context of B, its not
ret = _find_variable_from_ref_declaration( ret = _find_variable_from_ref_declaration(
referenced_declaration, all_contracts, all_functions_parser referenced_declaration,
list(current_scope.contracts.values()),
list(current_scope.functions),
) )
if ret: if ret:
return ret, False return ret, False

@ -120,7 +120,7 @@ class SlitherCompilationUnitSolc:
return True return True
return False return False
def _parse_enum(self, top_level_data: Dict): def _parse_enum(self, top_level_data: Dict, filename: str):
if self.is_compact_ast: if self.is_compact_ast:
name = top_level_data["name"] name = top_level_data["name"]
canonicalName = top_level_data["canonicalName"] canonicalName = top_level_data["canonicalName"]
@ -143,7 +143,9 @@ class SlitherCompilationUnitSolc:
else: else:
values.append(child["attributes"][self.get_key()]) values.append(child["attributes"][self.get_key()])
enum = EnumTopLevel(name, canonicalName, values) scope = self.compilation_unit.get_scope(filename)
enum = EnumTopLevel(name, canonicalName, values, scope)
scope.enums[name] = enum
enum.set_offset(top_level_data["src"], self._compilation_unit) enum.set_offset(top_level_data["src"], self._compilation_unit)
self._compilation_unit.enums_top_level.append(enum) self._compilation_unit.enums_top_level.append(enum)
@ -169,10 +171,13 @@ class SlitherCompilationUnitSolc:
if self.get_children() not in data_loaded: if self.get_children() not in data_loaded:
return return
scope = self.compilation_unit.get_scope(filename)
for top_level_data in data_loaded[self.get_children()]: for top_level_data in data_loaded[self.get_children()]:
if top_level_data[self.get_key()] == "ContractDefinition": if top_level_data[self.get_key()] == "ContractDefinition":
contract = Contract(self._compilation_unit) contract = Contract(self._compilation_unit, scope)
contract_parser = ContractSolc(self, contract, top_level_data) contract_parser = ContractSolc(self, contract, top_level_data)
scope.contracts[contract.name] = contract
if "src" in top_level_data: if "src" in top_level_data:
contract.set_offset(top_level_data["src"], self._compilation_unit) contract.set_offset(top_level_data["src"], self._compilation_unit)
@ -180,29 +185,33 @@ class SlitherCompilationUnitSolc:
elif top_level_data[self.get_key()] == "PragmaDirective": elif top_level_data[self.get_key()] == "PragmaDirective":
if self._is_compact_ast: if self._is_compact_ast:
pragma = Pragma(top_level_data["literals"]) pragma = Pragma(top_level_data["literals"], scope)
scope.pragmas.add(pragma)
else: else:
pragma = Pragma(top_level_data["attributes"]["literals"]) pragma = Pragma(top_level_data["attributes"]["literals"], scope)
scope.pragmas.add(pragma)
pragma.set_offset(top_level_data["src"], self._compilation_unit) pragma.set_offset(top_level_data["src"], self._compilation_unit)
self._compilation_unit.pragma_directives.append(pragma) self._compilation_unit.pragma_directives.append(pragma)
elif top_level_data[self.get_key()] == "ImportDirective": elif top_level_data[self.get_key()] == "ImportDirective":
if self.is_compact_ast: if self.is_compact_ast:
import_directive = Import( import_directive = Import(
Path( Path(
self._compilation_unit.crytic_compile.working_dir,
top_level_data["absolutePath"], top_level_data["absolutePath"],
),
scope,
) )
) scope.imports.add(import_directive)
# TODO investigate unitAlias in version < 0.7 and legacy ast # TODO investigate unitAlias in version < 0.7 and legacy ast
if "unitAlias" in top_level_data: if "unitAlias" in top_level_data:
import_directive.alias = top_level_data["unitAlias"] import_directive.alias = top_level_data["unitAlias"]
else: else:
import_directive = Import( import_directive = Import(
Path( Path(
self._compilation_unit.crytic_compile.working_dir,
top_level_data["attributes"].get("absolutePath", ""), top_level_data["attributes"].get("absolutePath", ""),
),
scope,
) )
) scope.imports.add(import_directive)
# TODO investigate unitAlias in version < 0.7 and legacy ast # TODO investigate unitAlias in version < 0.7 and legacy ast
if ( if (
"attributes" in top_level_data "attributes" in top_level_data
@ -212,17 +221,22 @@ class SlitherCompilationUnitSolc:
import_directive.set_offset(top_level_data["src"], self._compilation_unit) import_directive.set_offset(top_level_data["src"], self._compilation_unit)
self._compilation_unit.import_directives.append(import_directive) self._compilation_unit.import_directives.append(import_directive)
get_imported_scope = self.compilation_unit.get_scope(import_directive.filename)
scope.accessible_scopes.append(get_imported_scope)
elif top_level_data[self.get_key()] == "StructDefinition": elif top_level_data[self.get_key()] == "StructDefinition":
st = StructureTopLevel() scope = self.compilation_unit.get_scope(filename)
st = StructureTopLevel(self.compilation_unit, scope)
st.set_offset(top_level_data["src"], self._compilation_unit) st.set_offset(top_level_data["src"], self._compilation_unit)
st_parser = StructureTopLevelSolc(st, top_level_data, self) st_parser = StructureTopLevelSolc(st, top_level_data, self)
scope.structures[st.name] = st
self._compilation_unit.structures_top_level.append(st) self._compilation_unit.structures_top_level.append(st)
self._structures_top_level_parser.append(st_parser) self._structures_top_level_parser.append(st_parser)
elif top_level_data[self.get_key()] == "EnumDefinition": elif top_level_data[self.get_key()] == "EnumDefinition":
# Note enum don't need a complex parser, so everything is directly done # Note enum don't need a complex parser, so everything is directly done
self._parse_enum(top_level_data) self._parse_enum(top_level_data, filename)
elif top_level_data[self.get_key()] == "VariableDeclaration": elif top_level_data[self.get_key()] == "VariableDeclaration":
var = TopLevelVariable() var = TopLevelVariable()
@ -232,7 +246,9 @@ class SlitherCompilationUnitSolc:
self._compilation_unit.variables_top_level.append(var) self._compilation_unit.variables_top_level.append(var)
self._variables_top_level_parser.append(var_parser) self._variables_top_level_parser.append(var_parser)
elif top_level_data[self.get_key()] == "FunctionDefinition": elif top_level_data[self.get_key()] == "FunctionDefinition":
func = FunctionTopLevel(self._compilation_unit) scope = self.compilation_unit.get_scope(filename)
func = FunctionTopLevel(self._compilation_unit, scope)
scope.functions.add(func)
func.set_offset(top_level_data["src"], self._compilation_unit) func.set_offset(top_level_data["src"], self._compilation_unit)
func_parser = FunctionSolc(func, top_level_data, None, self) func_parser = FunctionSolc(func, top_level_data, None, self)
@ -241,10 +257,12 @@ class SlitherCompilationUnitSolc:
self.add_function_or_modifier_parser(func_parser) self.add_function_or_modifier_parser(func_parser)
elif top_level_data[self.get_key()] == "ErrorDefinition": elif top_level_data[self.get_key()] == "ErrorDefinition":
custom_error = CustomErrorTopLevel(self._compilation_unit) scope = self.compilation_unit.get_scope(filename)
custom_error = CustomErrorTopLevel(self._compilation_unit, scope)
custom_error.set_offset(top_level_data["src"], self._compilation_unit) custom_error.set_offset(top_level_data["src"], self._compilation_unit)
custom_error_parser = CustomErrorSolc(custom_error, top_level_data, self) custom_error_parser = CustomErrorSolc(custom_error, top_level_data, self)
scope.custom_errors.add(custom_error)
self._compilation_unit.custom_errors.append(custom_error) self._compilation_unit.custom_errors.append(custom_error)
self._custom_error_parser.append(custom_error_parser) self._custom_error_parser.append(custom_error_parser)
@ -322,17 +340,8 @@ class SlitherCompilationUnitSolc:
Please rename it, this name is reserved for Slither's internals""" Please rename it, this name is reserved for Slither's internals"""
# endregion multi-line # endregion multi-line
) )
if contract.name in self._compilation_unit.contracts_as_dict:
if contract.id != self._compilation_unit.contracts_as_dict[contract.name].id:
self._compilation_unit.contract_name_collisions[contract.name].append(
contract.source_mapping_str
)
self._compilation_unit.contract_name_collisions[contract.name].append(
self._compilation_unit.contracts_as_dict[contract.name].source_mapping_str
)
else:
self._contracts_by_id[contract.id] = contract self._contracts_by_id[contract.id] = contract
self._compilation_unit.contracts_as_dict[contract.name] = contract self._compilation_unit.contracts.append(contract)
# Update of the inheritance # Update of the inheritance
for contract_parser in self._underlying_contract_to_parser.values(): for contract_parser in self._underlying_contract_to_parser.values():
@ -347,7 +356,10 @@ Please rename it, this name is reserved for Slither's internals"""
for i in contract_parser.linearized_base_contracts[1:]: for i in contract_parser.linearized_base_contracts[1:]:
if i in contract_parser.remapping: if i in contract_parser.remapping:
ancestors.append( ancestors.append(
self._compilation_unit.get_contract_from_name(contract_parser.remapping[i]) contract_parser.underlying_contract.file_scope.get_contract_from_name(
contract_parser.remapping[i]
)
# self._compilation_unit.get_contract_from_name(contract_parser.remapping[i])
) )
elif i in self._contracts_by_id: elif i in self._contracts_by_id:
ancestors.append(self._contracts_by_id[i]) ancestors.append(self._contracts_by_id[i])
@ -358,7 +370,10 @@ Please rename it, this name is reserved for Slither's internals"""
for i in contract_parser.baseContracts: for i in contract_parser.baseContracts:
if i in contract_parser.remapping: if i in contract_parser.remapping:
fathers.append( fathers.append(
self._compilation_unit.get_contract_from_name(contract_parser.remapping[i]) contract_parser.underlying_contract.file_scope.get_contract_from_name(
contract_parser.remapping[i]
)
# self._compilation_unit.get_contract_from_name(contract_parser.remapping[i])
) )
elif i in self._contracts_by_id: elif i in self._contracts_by_id:
fathers.append(self._contracts_by_id[i]) fathers.append(self._contracts_by_id[i])
@ -369,7 +384,10 @@ Please rename it, this name is reserved for Slither's internals"""
for i in contract_parser.baseConstructorContractsCalled: for i in contract_parser.baseConstructorContractsCalled:
if i in contract_parser.remapping: if i in contract_parser.remapping:
father_constructors.append( father_constructors.append(
self._compilation_unit.get_contract_from_name(contract_parser.remapping[i]) contract_parser.underlying_contract.file_scope.get_contract_from_name(
contract_parser.remapping[i]
)
# self._compilation_unit.get_contract_from_name(contract_parser.remapping[i])
) )
elif i in self._contracts_by_id: elif i in self._contracts_by_id:
father_constructors.append(self._contracts_by_id[i]) father_constructors.append(self._contracts_by_id[i])

@ -1,7 +1,9 @@
import logging import logging
import re import re
from typing import List, TYPE_CHECKING, Union, Dict from typing import List, TYPE_CHECKING, Union, Dict, ValuesView
from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.core.expressions.literal import Literal from slither.core.expressions.literal import Literal
from slither.core.solidity_types.array_type import ArrayType from slither.core.solidity_types.array_type import ArrayType
@ -16,11 +18,13 @@ from slither.core.solidity_types.user_defined_type import UserDefinedType
from slither.core.variables.function_type_variable import FunctionTypeVariable from slither.core.variables.function_type_variable import FunctionTypeVariable
from slither.exceptions import SlitherError from slither.exceptions import SlitherError
from slither.solc_parsing.exceptions import ParsingError from slither.solc_parsing.exceptions import ParsingError
from slither.solc_parsing.expressions.expression_parsing import CallerContextExpression
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.declarations import Structure, Enum from slither.core.declarations import Structure, Enum
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
logger = logging.getLogger("TypeParsing") logger = logging.getLogger("TypeParsing")
@ -41,9 +45,9 @@ def _find_from_type_name( # pylint: disable=too-many-locals,too-many-branches,t
functions_direct_access: List["Function"], functions_direct_access: List["Function"],
contracts_direct_access: List["Contract"], contracts_direct_access: List["Contract"],
structures_direct_access: List["Structure"], structures_direct_access: List["Structure"],
all_structures: List["Structure"], all_structures: ValuesView["Structure"],
enums_direct_access: List["Enum"], enums_direct_access: List["Enum"],
all_enums: List["Enum"], all_enums: ValuesView["Enum"],
) -> Type: ) -> Type:
name_elementary = name.split(" ")[0] name_elementary = name.split(" ")[0]
if "[" in name_elementary: if "[" in name_elementary:
@ -190,12 +194,22 @@ def _find_from_type_name( # pylint: disable=too-many-locals,too-many-branches,t
return UserDefinedType(var_type) return UserDefinedType(var_type)
# TODO: since the add of FileScope, we can probably refactor this function and makes it a lot simpler
def parse_type( def parse_type(
t: Union[Dict, UnknownType], t: Union[Dict, UnknownType],
caller_context: Union[ caller_context: Union[CallerContextExpression, "SlitherCompilationUnitSolc"],
"SlitherCompilationUnitSolc", "FunctionSolc", "ContractSolc", "CustomSolc"
],
): ):
"""
caller_context can be a SlitherCompilationUnitSolc because we recursively call the function
and go up in the context's scope. If we are really lost we just go over the SlitherCompilationUnitSolc
:param t:
:type t:
:param caller_context:
:type caller_context:
:return:
:rtype:
"""
# local import to avoid circular dependency # local import to avoid circular dependency
# pylint: disable=too-many-locals,too-many-branches,too-many-statements # pylint: disable=too-many-locals,too-many-branches,too-many-statements
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
@ -204,21 +218,19 @@ def parse_type(
from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.contract import ContractSolc
from slither.solc_parsing.declarations.function import FunctionSolc from slither.solc_parsing.declarations.function import FunctionSolc
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
from slither.solc_parsing.declarations.structure_top_level import StructureTopLevelSolc
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
sl: "SlitherCompilationUnit" sl: "SlitherCompilationUnit"
# Note: for convenicence top level functions use the same parser than function in contract # Note: for convenicence top level functions use the same parser than function in contract
# but contract_parser is set to None # but contract_parser is set to None
if isinstance(caller_context, (SlitherCompilationUnitSolc, CustomErrorSolc)) or ( if isinstance(caller_context, SlitherCompilationUnitSolc) or (
isinstance(caller_context, FunctionSolc) and caller_context.contract_parser is None isinstance(caller_context, FunctionSolc) and caller_context.contract_parser is None
): ):
structures_direct_access: List["Structure"] structures_direct_access: List["Structure"]
if isinstance(caller_context, SlitherCompilationUnitSolc): if isinstance(caller_context, SlitherCompilationUnitSolc):
sl = caller_context.compilation_unit sl = caller_context.compilation_unit
next_context = caller_context next_context = caller_context
elif isinstance(caller_context, CustomErrorSolc):
sl = caller_context.underlying_custom_error.compilation_unit
next_context = caller_context.slither_parser
else: else:
assert isinstance(caller_context, FunctionSolc) assert isinstance(caller_context, FunctionSolc)
sl = caller_context.underlying_function.compilation_unit sl = caller_context.underlying_function.compilation_unit
@ -233,6 +245,25 @@ def parse_type(
all_enums += enums_direct_access all_enums += enums_direct_access
contracts = sl.contracts contracts = sl.contracts
functions = [] functions = []
elif isinstance(caller_context, (StructureTopLevelSolc, CustomErrorSolc)):
if isinstance(caller_context, StructureTopLevelSolc):
scope = caller_context.underlying_structure.file_scope
else:
assert isinstance(caller_context, CustomErrorSolc)
custom_error = caller_context.underlying_custom_error
if isinstance(custom_error, CustomErrorTopLevel):
scope = custom_error.file_scope
else:
assert isinstance(custom_error, CustomErrorContract)
scope = custom_error.contract.file_scope
next_context = caller_context.slither_parser
structures_direct_access = []
all_structures = scope.structures.values()
enums_direct_access = []
all_enums = scope.enums.values()
contracts = scope.contracts.values()
functions = list(scope.functions)
elif isinstance(caller_context, (ContractSolc, FunctionSolc)): elif isinstance(caller_context, (ContractSolc, FunctionSolc)):
if isinstance(caller_context, FunctionSolc): if isinstance(caller_context, FunctionSolc):
underlying_func = caller_context.underlying_function underlying_func = caller_context.underlying_function
@ -246,16 +277,16 @@ def parse_type(
next_context = caller_context next_context = caller_context
structures_direct_access = contract.structures structures_direct_access = contract.structures
structures_direct_access += contract.compilation_unit.structures_top_level structures_direct_access += contract.file_scope.structures.values()
all_structuress = [c.structures for c in contract.compilation_unit.contracts] all_structuress = [c.structures for c in contract.file_scope.contracts.values()]
all_structures = [item for sublist in all_structuress for item in sublist] all_structures = [item for sublist in all_structuress for item in sublist]
all_structures += contract.compilation_unit.structures_top_level all_structures += contract.file_scope.structures.values()
enums_direct_access: List["Enum"] = contract.enums enums_direct_access: List["Enum"] = contract.enums
enums_direct_access += contract.compilation_unit.enums_top_level enums_direct_access += contract.file_scope.enums.values()
all_enumss = [c.enums for c in contract.compilation_unit.contracts] all_enumss = [c.enums for c in contract.file_scope.contracts.values()]
all_enums = [item for sublist in all_enumss for item in sublist] all_enums = [item for sublist in all_enumss for item in sublist]
all_enums += contract.compilation_unit.enums_top_level all_enums += contract.file_scope.enums.values()
contracts = contract.compilation_unit.contracts contracts = contract.file_scope.contracts.values()
functions = contract.functions + contract.modifiers functions = contract.functions + contract.modifiers
else: else:
raise ParsingError(f"Incorrect caller context: {type(caller_context)}") raise ParsingError(f"Incorrect caller context: {type(caller_context)}")

@ -1,6 +1,7 @@
import logging import logging
from typing import Dict from typing import Dict
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.expressions.expression_parsing import parse_expression from slither.solc_parsing.expressions.expression_parsing import parse_expression
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
@ -179,7 +180,7 @@ class VariableDeclarationSolc:
self._variable.initialized = True self._variable.initialized = True
self._initializedNotParsed = var["children"][1] self._initializedNotParsed = var["children"][1]
def analyze(self, caller_context): def analyze(self, caller_context: CallerContextExpression):
# Can be re-analyzed due to inheritance # Can be re-analyzed due to inheritance
if self._was_analyzed: if self._was_analyzed:
return return

@ -223,7 +223,7 @@ class YulFunction(YulScope):
func.set_contract(root.contract) func.set_contract(root.contract)
func.set_contract_declarer(root.contract) func.set_contract_declarer(root.contract)
func.compilation_unit = root.compilation_unit func.compilation_unit = root.compilation_unit
func.scope = root.id func.internal_scope = root.id
func.is_implemented = True func.is_implemented = True
self.node_scope = node_scope self.node_scope = node_scope
@ -359,8 +359,14 @@ def convert_yul_function_definition(
while not isinstance(top_node_scope, Function): while not isinstance(top_node_scope, Function):
top_node_scope = top_node_scope.father top_node_scope = top_node_scope.father
assert isinstance(top_node_scope, (FunctionTopLevel, FunctionContract)) if isinstance(top_node_scope, FunctionTopLevel):
func = type(top_node_scope)(root.compilation_unit) scope = root.contract.file_scope
func = FunctionTopLevel(root.compilation_unit, scope)
# Note: we do not add the function in the scope
# While its a top level function, it is not accessible outside of the function definition
# In practice we should probably have a specific function type for function defined within a function
else:
func = FunctionContract(root.compilation_unit)
func.function_language = FunctionLanguage.Yul func.function_language = FunctionLanguage.Yul
yul_function = YulFunction(func, root, ast, node_scope) yul_function = YulFunction(func, root, ast, node_scope)

@ -71,7 +71,7 @@ Consider using a `Initializable` contract to follow [standard practice](https://
# endregion wiki_recommendation # endregion wiki_recommendation
def _check(self): def _check(self):
initializable = self.contract.compilation_unit.get_contract_from_name("Initializable") initializable = self.contract.file_scope.get_contract_from_name("Initializable")
if initializable is None: if initializable is None:
info = [ info = [
"Initializable contract not found, the contract does not follow a standard initalization schema.\n" "Initializable contract not found, the contract does not follow a standard initalization schema.\n"
@ -104,7 +104,7 @@ Review manually the contract's initialization. Consider inheriting `Initializabl
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
def _check(self): def _check(self):
initializable = self.contract.compilation_unit.get_contract_from_name("Initializable") initializable = self.contract.file_scope.get_contract_from_name("Initializable")
# See InitializablePresent # See InitializablePresent
if initializable is None: if initializable is None:
return [] return []
@ -138,7 +138,7 @@ Review manually the contract's initialization. Consider inheriting a `Initializa
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
def _check(self): def _check(self):
initializable = self.contract.compilation_unit.get_contract_from_name("Initializable") initializable = self.contract.file_scope.get_contract_from_name("Initializable")
# See InitializablePresent # See InitializablePresent
if initializable is None: if initializable is None:
return [] return []
@ -191,7 +191,7 @@ Use `Initializable.initializer()`.
REQUIRE_CONTRACT = True REQUIRE_CONTRACT = True
def _check(self): def _check(self):
initializable = self.contract.compilation_unit.get_contract_from_name("Initializable") initializable = self.contract.file_scope.get_contract_from_name("Initializable")
# See InitializablePresent # See InitializablePresent
if initializable is None: if initializable is None:
return [] return []

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save