Add support for CustomError (#947)

* Add support for CustomError
- Create a new CustomError/CustomErrorTopLevel/CustomErrorContract core objects, and associated parsing classes
- Create a specific solidity function to handle the revert CustomError call
- Fix #919, Fix #893
pull/984/head
Feist Josselin 3 years ago committed by GitHub
parent 154dd9f260
commit df1062902f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 6
      slither/core/compilation_unit.py
  2. 34
      slither/core/declarations/contract.py
  3. 71
      slither/core/declarations/custom_error.py
  4. 12
      slither/core/declarations/custom_error_contract.py
  5. 6
      slither/core/declarations/custom_error_top_level.py
  6. 19
      slither/core/declarations/solidity_variables.py
  7. 3
      slither/printers/inheritance/inheritance.py
  8. 8
      slither/slithir/convert.py
  9. 2
      slither/slithir/tmp_operations/tmp_call.py
  10. 30
      slither/solc_parsing/declarations/contract.py
  11. 101
      slither/solc_parsing/declarations/custom_error.py
  12. 14
      slither/solc_parsing/declarations/function.py
  13. 21
      slither/solc_parsing/expressions/find_variable.py
  14. 22
      slither/solc_parsing/slither_compilation_unit_solc.py
  15. 22
      slither/solc_parsing/solidity_types/type_parsing.py
  16. BIN
      tests/ast-parsing/compile/custom_error-0.4.0-legacy.zip
  17. BIN
      tests/ast-parsing/compile/custom_error-0.4.1-legacy.zip
  18. BIN
      tests/ast-parsing/compile/custom_error-0.4.10-legacy.zip
  19. BIN
      tests/ast-parsing/compile/custom_error-0.4.11-legacy.zip
  20. BIN
      tests/ast-parsing/compile/custom_error-0.4.12-compact.zip
  21. BIN
      tests/ast-parsing/compile/custom_error-0.4.12-legacy.zip
  22. BIN
      tests/ast-parsing/compile/custom_error-0.4.13-compact.zip
  23. BIN
      tests/ast-parsing/compile/custom_error-0.4.13-legacy.zip
  24. BIN
      tests/ast-parsing/compile/custom_error-0.4.14-compact.zip
  25. BIN
      tests/ast-parsing/compile/custom_error-0.4.14-legacy.zip
  26. BIN
      tests/ast-parsing/compile/custom_error-0.4.15-compact.zip
  27. BIN
      tests/ast-parsing/compile/custom_error-0.4.15-legacy.zip
  28. BIN
      tests/ast-parsing/compile/custom_error-0.4.16-compact.zip
  29. BIN
      tests/ast-parsing/compile/custom_error-0.4.16-legacy.zip
  30. BIN
      tests/ast-parsing/compile/custom_error-0.4.17-compact.zip
  31. BIN
      tests/ast-parsing/compile/custom_error-0.4.17-legacy.zip
  32. BIN
      tests/ast-parsing/compile/custom_error-0.4.18-compact.zip
  33. BIN
      tests/ast-parsing/compile/custom_error-0.4.18-legacy.zip
  34. BIN
      tests/ast-parsing/compile/custom_error-0.4.19-compact.zip
  35. BIN
      tests/ast-parsing/compile/custom_error-0.4.19-legacy.zip
  36. BIN
      tests/ast-parsing/compile/custom_error-0.4.2-legacy.zip
  37. BIN
      tests/ast-parsing/compile/custom_error-0.4.20-compact.zip
  38. BIN
      tests/ast-parsing/compile/custom_error-0.4.20-legacy.zip
  39. BIN
      tests/ast-parsing/compile/custom_error-0.4.21-compact.zip
  40. BIN
      tests/ast-parsing/compile/custom_error-0.4.21-legacy.zip
  41. BIN
      tests/ast-parsing/compile/custom_error-0.4.22-compact.zip
  42. BIN
      tests/ast-parsing/compile/custom_error-0.4.22-legacy.zip
  43. BIN
      tests/ast-parsing/compile/custom_error-0.4.23-compact.zip
  44. BIN
      tests/ast-parsing/compile/custom_error-0.4.23-legacy.zip
  45. BIN
      tests/ast-parsing/compile/custom_error-0.4.24-compact.zip
  46. BIN
      tests/ast-parsing/compile/custom_error-0.4.24-legacy.zip
  47. BIN
      tests/ast-parsing/compile/custom_error-0.4.25-compact.zip
  48. BIN
      tests/ast-parsing/compile/custom_error-0.4.25-legacy.zip
  49. BIN
      tests/ast-parsing/compile/custom_error-0.4.26-compact.zip
  50. BIN
      tests/ast-parsing/compile/custom_error-0.4.26-legacy.zip
  51. BIN
      tests/ast-parsing/compile/custom_error-0.4.3-legacy.zip
  52. BIN
      tests/ast-parsing/compile/custom_error-0.4.4-legacy.zip
  53. BIN
      tests/ast-parsing/compile/custom_error-0.4.5-legacy.zip
  54. BIN
      tests/ast-parsing/compile/custom_error-0.4.6-legacy.zip
  55. BIN
      tests/ast-parsing/compile/custom_error-0.4.7-legacy.zip
  56. BIN
      tests/ast-parsing/compile/custom_error-0.4.8-legacy.zip
  57. BIN
      tests/ast-parsing/compile/custom_error-0.4.9-legacy.zip
  58. BIN
      tests/ast-parsing/compile/custom_error-0.5.0-compact.zip
  59. BIN
      tests/ast-parsing/compile/custom_error-0.5.0-legacy.zip
  60. BIN
      tests/ast-parsing/compile/custom_error-0.5.1-compact.zip
  61. BIN
      tests/ast-parsing/compile/custom_error-0.5.1-legacy.zip
  62. BIN
      tests/ast-parsing/compile/custom_error-0.5.10-compact.zip
  63. BIN
      tests/ast-parsing/compile/custom_error-0.5.10-legacy.zip
  64. BIN
      tests/ast-parsing/compile/custom_error-0.5.11-compact.zip
  65. BIN
      tests/ast-parsing/compile/custom_error-0.5.11-legacy.zip
  66. BIN
      tests/ast-parsing/compile/custom_error-0.5.12-compact.zip
  67. BIN
      tests/ast-parsing/compile/custom_error-0.5.12-legacy.zip
  68. BIN
      tests/ast-parsing/compile/custom_error-0.5.13-compact.zip
  69. BIN
      tests/ast-parsing/compile/custom_error-0.5.13-legacy.zip
  70. BIN
      tests/ast-parsing/compile/custom_error-0.5.14-compact.zip
  71. BIN
      tests/ast-parsing/compile/custom_error-0.5.14-legacy.zip
  72. BIN
      tests/ast-parsing/compile/custom_error-0.5.15-compact.zip
  73. BIN
      tests/ast-parsing/compile/custom_error-0.5.15-legacy.zip
  74. BIN
      tests/ast-parsing/compile/custom_error-0.5.16-compact.zip
  75. BIN
      tests/ast-parsing/compile/custom_error-0.5.16-legacy.zip
  76. BIN
      tests/ast-parsing/compile/custom_error-0.5.17-compact.zip
  77. BIN
      tests/ast-parsing/compile/custom_error-0.5.17-legacy.zip
  78. BIN
      tests/ast-parsing/compile/custom_error-0.5.2-compact.zip
  79. BIN
      tests/ast-parsing/compile/custom_error-0.5.2-legacy.zip
  80. BIN
      tests/ast-parsing/compile/custom_error-0.5.3-compact.zip
  81. BIN
      tests/ast-parsing/compile/custom_error-0.5.3-legacy.zip
  82. BIN
      tests/ast-parsing/compile/custom_error-0.5.4-compact.zip
  83. BIN
      tests/ast-parsing/compile/custom_error-0.5.4-legacy.zip
  84. BIN
      tests/ast-parsing/compile/custom_error-0.5.5-compact.zip
  85. BIN
      tests/ast-parsing/compile/custom_error-0.5.5-legacy.zip
  86. BIN
      tests/ast-parsing/compile/custom_error-0.5.6-compact.zip
  87. BIN
      tests/ast-parsing/compile/custom_error-0.5.6-legacy.zip
  88. BIN
      tests/ast-parsing/compile/custom_error-0.5.7-compact.zip
  89. BIN
      tests/ast-parsing/compile/custom_error-0.5.7-legacy.zip
  90. BIN
      tests/ast-parsing/compile/custom_error-0.5.8-compact.zip
  91. BIN
      tests/ast-parsing/compile/custom_error-0.5.8-legacy.zip
  92. BIN
      tests/ast-parsing/compile/custom_error-0.5.9-compact.zip
  93. BIN
      tests/ast-parsing/compile/custom_error-0.5.9-legacy.zip
  94. BIN
      tests/ast-parsing/compile/custom_error-0.6.0-compact.zip
  95. BIN
      tests/ast-parsing/compile/custom_error-0.6.0-legacy.zip
  96. BIN
      tests/ast-parsing/compile/custom_error-0.6.1-compact.zip
  97. BIN
      tests/ast-parsing/compile/custom_error-0.6.1-legacy.zip
  98. BIN
      tests/ast-parsing/compile/custom_error-0.6.10-compact.zip
  99. BIN
      tests/ast-parsing/compile/custom_error-0.6.10-legacy.zip
  100. BIN
      tests/ast-parsing/compile/custom_error-0.6.11-compact.zip
  101. Some files were not shown because too many files have changed in this diff Show More

@ -13,6 +13,7 @@ from slither.core.declarations import (
Function, Function,
Modifier, Modifier,
) )
from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.enum_top_level import EnumTopLevel from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.function_top_level import FunctionTopLevel from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.structure_top_level import StructureTopLevel from slither.core.declarations.structure_top_level import StructureTopLevel
@ -40,6 +41,7 @@ class SlitherCompilationUnit(Context):
self._functions_top_level: List[FunctionTopLevel] = [] self._functions_top_level: List[FunctionTopLevel] = []
self._pragma_directives: List[Pragma] = [] self._pragma_directives: List[Pragma] = []
self._import_directives: List[Import] = [] self._import_directives: List[Import] = []
self._custom_errors: List[CustomError] = []
self._all_functions: Set[Function] = set() self._all_functions: Set[Function] = set()
self._all_modifiers: Set[Modifier] = set() self._all_modifiers: Set[Modifier] = set()
@ -210,6 +212,10 @@ class SlitherCompilationUnit(Context):
def functions_top_level(self) -> List[FunctionTopLevel]: def functions_top_level(self) -> List[FunctionTopLevel]:
return self._functions_top_level return self._functions_top_level
@property
def custom_errors(self) -> List[CustomError]:
return self._custom_errors
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################

@ -38,6 +38,7 @@ if TYPE_CHECKING:
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations.custom_error_contract import CustomErrorContract
LOGGER = logging.getLogger("Contract") LOGGER = logging.getLogger("Contract")
@ -68,6 +69,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
self._modifiers: Dict[str, "Modifier"] = {} self._modifiers: Dict[str, "Modifier"] = {}
self._functions: Dict[str, "FunctionContract"] = {} self._functions: Dict[str, "FunctionContract"] = {}
self._linearizedBaseContracts: List[int] = [] self._linearizedBaseContracts: List[int] = []
self._custom_errors: Dict[str:"CustomErrorContract"] = {}
# The only str is "*" # The only str is "*"
self._using_for: Dict[Union[str, Type], List[str]] = {} self._using_for: Dict[Union[str, Type], List[str]] = {}
@ -242,6 +244,38 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
def using_for(self) -> Dict[Union[str, Type], List[str]]: def using_for(self) -> Dict[Union[str, Type], List[str]]:
return self._using_for return self._using_for
# endregion
###################################################################################
###################################################################################
# region Custom Errors
###################################################################################
###################################################################################
@property
def custom_errors(self) -> List["CustomErrorContract"]:
"""
list(CustomErrorContract): List of the contract's custom errors
"""
return list(self._custom_errors.values())
@property
def custom_errors_inherited(self) -> List["CustomErrorContract"]:
"""
list(CustomErrorContract): List of the inherited custom errors
"""
return [s for s in self.custom_errors if s.contract != self]
@property
def custom_errors_declared(self) -> List["CustomErrorContract"]:
"""
list(CustomErrorContract): List of the custom errors declared within the contract (not inherited)
"""
return [s for s in self.custom_errors if s.contract == self]
@property
def custom_errors_as_dict(self) -> Dict[str, "CustomErrorContract"]:
return self._custom_errors
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################

@ -0,0 +1,71 @@
from typing import List, TYPE_CHECKING, Optional, Type, Union
from slither.core.solidity_types import UserDefinedType
from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.local_variable import LocalVariable
if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit
class CustomError(SourceMapping):
def __init__(self, compilation_unit: "SlitherCompilationUnit"):
super().__init__()
self._name: str = ""
self._parameters: List[LocalVariable] = []
self._compilation_unit = compilation_unit
self._solidity_signature: Optional[str] = None
@property
def name(self) -> str:
return self._name
@name.setter
def name(self, new_name: str) -> None:
self._name = new_name
@property
def parameters(self) -> List[LocalVariable]:
return self._parameters
def add_parameters(self, p: "LocalVariable"):
self._parameters.append(p)
@property
def compilation_unit(self) -> "SlitherCompilationUnit":
return self._compilation_unit
# region Signature
###################################################################################
###################################################################################
@staticmethod
def _convert_type_for_solidity_signature(t: Optional[Union[Type, List[Type]]]):
# pylint: disable=import-outside-toplevel
from slither.core.declarations import Contract
if isinstance(t, UserDefinedType) and isinstance(t.type, Contract):
return "address"
return str(t)
@property
def solidity_signature(self) -> str:
"""
Return a signature following the Solidity Standard
Contract and converted into address
:return: the solidity signature
"""
if self._solidity_signature is None:
parameters = [
self._convert_type_for_solidity_signature(x.type) for x in self.parameters
]
self._solidity_signature = self.name + "(" + ",".join(parameters) + ")"
return self._solidity_signature
# endregion
###################################################################################
###################################################################################
def __str__(self):
return "revert " + self.solidity_signature

@ -0,0 +1,12 @@
from slither.core.children.child_contract import ChildContract
from slither.core.declarations.custom_error import CustomError
class CustomErrorContract(CustomError, ChildContract):
def is_declared_by(self, contract):
"""
Check if the element is declared by the contract
:param contract:
:return:
"""
return self.contract == contract

@ -0,0 +1,6 @@
from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.top_level import TopLevel
class CustomErrorTopLevel(CustomError, TopLevel):
pass

@ -2,6 +2,7 @@
from typing import List, Dict, Union, TYPE_CHECKING from typing import List, Dict, Union, TYPE_CHECKING
from slither.core.context.context import Context from slither.core.context.context import Context
from slither.core.declarations.custom_error import CustomError
from slither.core.solidity_types import ElementaryType, TypeInformation from slither.core.solidity_types import ElementaryType, TypeInformation
from slither.exceptions import SlitherException from slither.exceptions import SlitherException
@ -42,6 +43,7 @@ SOLIDITY_FUNCTIONS: Dict[str, List[str]] = {
"require(bool,string)": [], "require(bool,string)": [],
"revert()": [], "revert()": [],
"revert(string)": [], "revert(string)": [],
"revert ": [],
"addmod(uint256,uint256,uint256)": ["uint256"], "addmod(uint256,uint256,uint256)": ["uint256"],
"mulmod(uint256,uint256,uint256)": ["uint256"], "mulmod(uint256,uint256,uint256)": ["uint256"],
"keccak256()": ["bytes32"], "keccak256()": ["bytes32"],
@ -184,3 +186,20 @@ class SolidityFunction:
def __hash__(self): def __hash__(self):
return hash(self.name) return hash(self.name)
class SolidityCustomRevert(SolidityFunction):
def __init__(self, custom_error: CustomError): # pylint: disable=super-init-not-called
self._name = "revert " + custom_error.solidity_signature
self._custom_error = custom_error
self._return_type: List[Union[TypeInformation, ElementaryType]] = []
def __eq__(self, other):
return (
self.__class__ == other.__class__
and self.name == other.name
and self._custom_error == other._custom_error
)
def __hash__(self):
return hash(hash(self.name) + hash(self._custom_error))

@ -30,9 +30,6 @@ class PrinterInheritance(AbstractPrinter):
""" """
info = "Inheritance\n" info = "Inheritance\n"
if not self.contracts:
return []
info += blue("Child_Contract -> ") + green("Immediate_Base_Contracts") info += blue("Child_Contract -> ") + green("Immediate_Base_Contracts")
info += green(" [Not_Immediate_Base_Contracts]") info += green(" [Not_Immediate_Base_Contracts]")

@ -13,8 +13,10 @@ from slither.core.declarations import (
SolidityVariableComposed, SolidityVariableComposed,
Structure, Structure,
) )
from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder
from slither.core.declarations.solidity_variables import SolidityCustomRevert
from slither.core.expressions import Identifier, Literal from slither.core.expressions import Identifier, Literal
from slither.core.solidity_types import ( from slither.core.solidity_types import (
ArrayType, ArrayType,
@ -941,6 +943,12 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: dis
s.set_expression(ins.expression) s.set_expression(ins.expression)
return s return s
if isinstance(ins.called, CustomError):
sol_function = SolidityCustomRevert(ins.called)
s = SolidityCall(sol_function, ins.nbr_arguments, ins.lvalue, ins.type_call)
s.set_expression(ins.expression)
return s
if isinstance(ins.ori, TmpNewElementaryType): if isinstance(ins.ori, TmpNewElementaryType):
n = NewElementaryType(ins.ori.type, ins.lvalue) n = NewElementaryType(ins.ori.type, ins.lvalue)
n.set_expression(ins.expression) n.set_expression(ins.expression)

@ -5,6 +5,7 @@ from slither.core.declarations import (
SolidityFunction, SolidityFunction,
Structure, Structure,
) )
from slither.core.declarations.custom_error import CustomError
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.operations.lvalue import OperationWithLValue
@ -20,6 +21,7 @@ class TmpCall(OperationWithLValue): # pylint: disable=too-many-instance-attribu
SolidityFunction, SolidityFunction,
Structure, Structure,
Event, Event,
CustomError,
), ),
) )
super().__init__() super().__init__()

@ -3,8 +3,10 @@ from typing import List, Dict, Callable, TYPE_CHECKING, Union, Set
from slither.core.declarations import Modifier, Event, EnumContract, StructureContract, Function from slither.core.declarations import Modifier, Event, EnumContract, StructureContract, Function
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
from slither.solc_parsing.declarations.event import EventSolc from slither.solc_parsing.declarations.event import EventSolc
from slither.solc_parsing.declarations.function import FunctionSolc from slither.solc_parsing.declarations.function import FunctionSolc
from slither.solc_parsing.declarations.modifier import ModifierSolc from slither.solc_parsing.declarations.modifier import ModifierSolc
@ -35,15 +37,17 @@ class ContractSolc:
self._modifiersNotParsed: List[Dict] = [] self._modifiersNotParsed: List[Dict] = []
self._functions_no_params: List[FunctionSolc] = [] self._functions_no_params: List[FunctionSolc] = []
self._modifiers_no_params: List[ModifierSolc] = [] self._modifiers_no_params: List[ModifierSolc] = []
self._eventsNotParsed: List[EventSolc] = [] self._eventsNotParsed: List[Dict] = []
self._variablesNotParsed: List[Dict] = [] self._variablesNotParsed: List[Dict] = []
self._enumsNotParsed: List[Dict] = [] self._enumsNotParsed: List[Dict] = []
self._structuresNotParsed: List[Dict] = [] self._structuresNotParsed: List[Dict] = []
self._usingForNotParsed: List[Dict] = [] self._usingForNotParsed: List[Dict] = []
self._customErrorParsed: List[Dict] = []
self._functions_parser: List[FunctionSolc] = [] self._functions_parser: List[FunctionSolc] = []
self._modifiers_parser: List[ModifierSolc] = [] self._modifiers_parser: List[ModifierSolc] = []
self._structures_parser: List[StructureContractSolc] = [] self._structures_parser: List[StructureContractSolc] = []
self._custom_errors_parser: List[CustomErrorSolc] = []
self._is_analyzed: bool = False self._is_analyzed: bool = False
@ -246,6 +250,8 @@ class ContractSolc:
self._structuresNotParsed.append(item) self._structuresNotParsed.append(item)
elif item[self.get_key()] == "UsingForDirective": elif item[self.get_key()] == "UsingForDirective":
self._usingForNotParsed.append(item) self._usingForNotParsed.append(item)
elif item[self.get_key()] == "ErrorDefinition":
self._customErrorParsed.append(item)
else: else:
raise ParsingError("Unknown contract item: " + item[self.get_key()]) raise ParsingError("Unknown contract item: " + item[self.get_key()])
return return
@ -268,6 +274,23 @@ class ContractSolc:
self._parse_struct(struct) self._parse_struct(struct)
self._structuresNotParsed = None self._structuresNotParsed = None
def _parse_custom_error(self, custom_error: Dict):
ce = CustomErrorContract(self.compilation_unit)
ce.set_contract(self._contract)
ce.set_offset(custom_error["src"], self.compilation_unit)
ce_parser = CustomErrorSolc(ce, custom_error, self._slither_parser)
self._contract.custom_errors_as_dict[ce.name] = ce
self._custom_errors_parser.append(ce_parser)
def parse_custom_errors(self):
for father in self._contract.inheritance_reverse:
self._contract.custom_errors_as_dict.update(father.custom_errors_as_dict)
for custom_error in self._customErrorParsed:
self._parse_custom_error(custom_error)
self._customErrorParsed = None
def parse_state_variables(self): def parse_state_variables(self):
for father in self._contract.inheritance_reverse: for father in self._contract.inheritance_reverse:
self._contract.variables_as_dict.update(father.variables_as_dict) self._contract.variables_as_dict.update(father.variables_as_dict)
@ -600,6 +623,10 @@ class ContractSolc:
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing struct {e}") self.log_incorrect_parsing(f"Missing struct {e}")
def analyze_custom_errors(self):
for custom_error in self._custom_errors_parser:
custom_error.analyze_params()
def analyze_events(self): def analyze_events(self):
try: try:
for father in self._contract.inheritance_reverse: for father in self._contract.inheritance_reverse:
@ -640,6 +667,7 @@ class ContractSolc:
self._enumsNotParsed = [] self._enumsNotParsed = []
self._structuresNotParsed = [] self._structuresNotParsed = []
self._usingForNotParsed = [] self._usingForNotParsed = []
self._customErrorParsed = []
# endregion # endregion
################################################################################### ###################################################################################

@ -0,0 +1,101 @@
from typing import TYPE_CHECKING, Dict
from slither.core.declarations.custom_error import CustomError
from slither.core.variables.local_variable import LocalVariable
from slither.solc_parsing.variables.local_variable import LocalVariableSolc
if TYPE_CHECKING:
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
# Part of the code was copied from the function parsing
# In the long term we should refactor these two classes to merge the duplicated code
class CustomErrorSolc:
def __init__(
self,
custom_error: CustomError,
custom_error_data: dict,
slither_parser: "SlitherCompilationUnitSolc",
):
self._slither_parser: "SlitherCompilationUnitSolc" = slither_parser
self._custom_error = custom_error
custom_error.name = custom_error_data["name"]
self._params_was_analyzed = False
if not self._slither_parser.is_compact_ast:
custom_error_data = custom_error_data["attributes"]
self._custom_error_data = custom_error_data
def analyze_params(self):
# Can be re-analyzed due to inheritance
if self._params_was_analyzed:
return
self._params_was_analyzed = True
if self._slither_parser.is_compact_ast:
params = self._custom_error_data["parameters"]
else:
children = self._custom_error_data[self.get_children("children")]
# It uses to be
# params = children[0]
# returns = children[1]
# But from Solidity 0.6.3 to 0.6.10 (included)
# Comment above a function might be added in the children
child_iter = iter(
[child for child in children if child[self.get_key()] == "ParameterList"]
)
params = next(child_iter)
if params:
self._parse_params(params)
@property
def is_compact_ast(self) -> bool:
return self._slither_parser.is_compact_ast
def get_key(self) -> str:
return self._slither_parser.get_key()
def get_children(self, key: str) -> str:
if self._slither_parser.is_compact_ast:
return key
return "children"
def _parse_params(self, params: Dict):
assert params[self.get_key()] == "ParameterList"
if self._slither_parser.is_compact_ast:
params = params["parameters"]
else:
params = params[self.get_children("children")]
for param in params:
assert param[self.get_key()] == "VariableDeclaration"
local_var = self._add_param(param)
self._custom_error.add_parameters(local_var.underlying_variable)
def _add_param(self, param: Dict) -> LocalVariableSolc:
local_var = LocalVariable()
local_var.set_offset(param["src"], self._slither_parser.compilation_unit)
local_var_parser = LocalVariableSolc(local_var, param)
local_var_parser.analyze(self)
# see https://solidity.readthedocs.io/en/v0.4.24/types.html?highlight=storage%20location#data-location
if local_var.location == "default":
local_var.set_location("memory")
return local_var_parser
@property
def underlying_custom_error(self) -> CustomError:
return self._custom_error
@property
def slither_parser(self) -> "SlitherCompilationUnitSolc":
return self._slither_parser

@ -10,12 +10,11 @@ from slither.core.declarations.function import (
FunctionType, FunctionType,
) )
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.core.expressions import AssignmentOperation from slither.core.expressions import AssignmentOperation
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple
from slither.solc_parsing.cfg.node import NodeSolc from slither.solc_parsing.cfg.node import NodeSolc
from slither.solc_parsing.exceptions import ParsingError
from slither.solc_parsing.expressions.expression_parsing import parse_expression from slither.solc_parsing.expressions.expression_parsing import parse_expression
from slither.solc_parsing.variables.local_variable import LocalVariableSolc from slither.solc_parsing.variables.local_variable import LocalVariableSolc
from slither.solc_parsing.variables.local_variable_init_from_tuple import ( from slither.solc_parsing.variables.local_variable_init_from_tuple import (
@ -26,13 +25,11 @@ from slither.solc_parsing.yul.parse_yul import YulBlock
from slither.utils.expression_manipulations import SplitTernaryExpression from slither.utils.expression_manipulations import SplitTernaryExpression
from slither.visitors.expression.export_values import ExportValues from slither.visitors.expression.export_values import ExportValues
from slither.visitors.expression.has_conditional import HasConditional from slither.visitors.expression.has_conditional import HasConditional
from slither.solc_parsing.exceptions import ParsingError
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.contract import ContractSolc
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
from slither.core.slither_core import SlitherCore
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
@ -1012,6 +1009,15 @@ class FunctionSolc:
node = self._parse_try_catch(statement, node) node = self._parse_try_catch(statement, node)
# elif name == 'TryCatchClause': # elif name == 'TryCatchClause':
# self._parse_catch(statement, node) # self._parse_catch(statement, node)
elif name == "RevertStatement":
if self.is_compact_ast:
expression = statement[self.get_children("errorCall")]
else:
expression = statement[self.get_children("errorCall")][0]
new_node = self._new_node(NodeType.EXPRESSION, statement["src"], scope)
new_node.add_unparsed_expression(expression)
link_underlying_nodes(node, new_node)
node = new_node
else: else:
raise ParsingError("Statement not parsed %s" % name) raise ParsingError("Statement not parsed %s" % name)

@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Optional, Union, List, Tuple
from slither.core.declarations import Event, Enum, Structure from slither.core.declarations import Event, Enum, Structure
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.function import Function from slither.core.declarations.function import Function
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.core.declarations.function_top_level import FunctionTopLevel from slither.core.declarations.function_top_level import FunctionTopLevel
@ -100,7 +101,7 @@ def _find_variable_in_function_parser(
def _find_top_level( def _find_top_level(
var_name: str, sl: "SlitherCompilationUnit" var_name: str, sl: "SlitherCompilationUnit"
) -> Tuple[Optional[Union[Enum, Structure, SolidityImportPlaceHolder]], bool]: ) -> Tuple[Optional[Union[Enum, Structure, SolidityImportPlaceHolder, CustomError]], bool]:
""" """
Return the top level variable use, and a boolean indicating if the variable returning was cretead Return the top level variable use, and a boolean indicating if the variable returning was cretead
If the variable was created, it has no source_mapping If the variable was created, it has no source_mapping
@ -127,6 +128,11 @@ def _find_top_level(
new_val = SolidityImportPlaceHolder(import_directive) new_val = SolidityImportPlaceHolder(import_directive)
return new_val, True return new_val, True
# Note for now solidity prevent two custom error from having the same name
for custom_error in sl.custom_errors:
if custom_error.solidity_signature == var_name:
return custom_error, False
return None, False return None, False
@ -135,7 +141,7 @@ def _find_in_contract(
contract: Optional[Contract], contract: Optional[Contract],
contract_declarer: Optional[Contract], contract_declarer: Optional[Contract],
is_super: bool, is_super: bool,
) -> Optional[Union[Variable, Function, Contract, Event, Enum, Structure,]]: ) -> Optional[Union[Variable, Function, Contract, Event, Enum, Structure, CustomError]]:
if contract is None or contract_declarer is None: if contract is None or contract_declarer is None:
return None return None
@ -191,6 +197,14 @@ def _find_in_contract(
if var_name in enums: if var_name in enums:
return enums[var_name] return enums[var_name]
# Note: contract.custom_errors_as_dict uses the name (not the sol sig) as key
# This is because when the dic is populated the underlying object is not yet parsed
# As a result, we need to iterate over all the custom errors here instead of using the dict
custom_errors = contract.custom_errors
for custom_error in custom_errors:
if var_name == custom_error.solidity_signature:
return custom_error
# If the enum is refered as its name rather than its canonicalName # If the enum is refered as its name rather than its canonicalName
enums = {e.name: e for e in contract.enums} enums = {e.name: e for e in contract.enums}
if var_name in enums: if var_name in enums:
@ -260,6 +274,7 @@ def find_variable(
Event, Event,
Enum, Enum,
Structure, Structure,
CustomError,
], ],
bool, bool,
]: ]:
@ -384,4 +399,4 @@ def find_variable(
if ret: if ret:
return ret, False return ret, False
raise VariableNotFound("Variable not found: {} (context {})".format(var_name, caller_context)) raise VariableNotFound("Variable not found: {} (context {})".format(var_name, contract))

@ -8,6 +8,7 @@ from typing import List, Dict
from slither.analyses.data_dependency.data_dependency import compute_dependency from slither.analyses.data_dependency.data_dependency import compute_dependency
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations import Contract from slither.core.declarations import Contract
from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.declarations.enum_top_level import EnumTopLevel from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.function_top_level import FunctionTopLevel from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.import_directive import Import from slither.core.declarations.import_directive import Import
@ -16,6 +17,7 @@ from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.core.variables.top_level_variable import TopLevelVariable from slither.core.variables.top_level_variable import TopLevelVariable
from slither.exceptions import SlitherException from slither.exceptions import SlitherException
from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.contract import ContractSolc
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
from slither.solc_parsing.declarations.function import FunctionSolc from slither.solc_parsing.declarations.function import FunctionSolc
from slither.solc_parsing.declarations.structure_top_level import StructureTopLevelSolc from slither.solc_parsing.declarations.structure_top_level import StructureTopLevelSolc
from slither.solc_parsing.exceptions import VariableNotFound from slither.solc_parsing.exceptions import VariableNotFound
@ -37,6 +39,7 @@ class SlitherCompilationUnitSolc:
self._underlying_contract_to_parser: Dict[Contract, ContractSolc] = dict() self._underlying_contract_to_parser: Dict[Contract, ContractSolc] = dict()
self._structures_top_level_parser: List[StructureTopLevelSolc] = [] self._structures_top_level_parser: List[StructureTopLevelSolc] = []
self._custom_error_parser: List[CustomErrorSolc] = []
self._variables_top_level_parser: List[TopLevelVariableSolc] = [] self._variables_top_level_parser: List[TopLevelVariableSolc] = []
self._functions_top_level_parser: List[FunctionSolc] = [] self._functions_top_level_parser: List[FunctionSolc] = []
@ -146,7 +149,7 @@ class SlitherCompilationUnitSolc:
def parse_top_level_from_loaded_json( def parse_top_level_from_loaded_json(
self, data_loaded: Dict, filename: str self, data_loaded: Dict, filename: str
): # pylint: disable=too-many-branches,too-many-statements ): # pylint: disable=too-many-branches,too-many-statements,too-many-locals
if "nodeType" in data_loaded: if "nodeType" in data_loaded:
self._is_compact_ast = True self._is_compact_ast = True
@ -164,6 +167,8 @@ class SlitherCompilationUnitSolc:
logger.error("solc version is not supported") logger.error("solc version is not supported")
return return
if self.get_children() not in data_loaded:
return
for top_level_data in data_loaded[self.get_children()]: for top_level_data in data_loaded[self.get_children()]:
if top_level_data[self.get_key()] == "ContractDefinition": if top_level_data[self.get_key()] == "ContractDefinition":
contract = Contract(self._compilation_unit) contract = Contract(self._compilation_unit)
@ -235,6 +240,14 @@ class SlitherCompilationUnitSolc:
self._functions_top_level_parser.append(func_parser) self._functions_top_level_parser.append(func_parser)
self.add_function_or_modifier_parser(func_parser) self.add_function_or_modifier_parser(func_parser)
elif top_level_data[self.get_key()] == "ErrorDefinition":
custom_error = CustomErrorTopLevel(self._compilation_unit)
custom_error.set_offset(top_level_data["src"], self._compilation_unit)
custom_error_parser = CustomErrorSolc(custom_error, top_level_data, self)
self._compilation_unit.custom_errors.append(custom_error)
self._custom_error_parser.append(custom_error_parser)
else: else:
raise SlitherException(f"Top level {top_level_data[self.get_key()]} not supported") raise SlitherException(f"Top level {top_level_data[self.get_key()]} not supported")
@ -522,6 +535,7 @@ Please rename it, this name is reserved for Slither's internals"""
contract.parse_state_variables() contract.parse_state_variables()
contract.parse_modifiers() contract.parse_modifiers()
contract.parse_functions() contract.parse_functions()
contract.parse_custom_errors()
contract.set_is_analyzed(True) contract.set_is_analyzed(True)
def _analyze_struct_events(self, contract: ContractSolc): def _analyze_struct_events(self, contract: ContractSolc):
@ -534,6 +548,7 @@ Please rename it, this name is reserved for Slither's internals"""
contract.analyze_events() contract.analyze_events()
contract.analyze_using_for() contract.analyze_using_for()
contract.analyze_custom_errors()
contract.set_is_analyzed(True) contract.set_is_analyzed(True)
@ -556,6 +571,10 @@ Please rename it, this name is reserved for Slither's internals"""
func_parser.analyze_params() func_parser.analyze_params()
self._compilation_unit.add_function(func_parser.underlying_function) self._compilation_unit.add_function(func_parser.underlying_function)
def _analyze_params_custom_error(self):
for custom_error_parser in self._custom_error_parser:
custom_error_parser.analyze_params()
def _analyze_content_top_level_function(self): def _analyze_content_top_level_function(self):
try: try:
for func_parser in self._functions_top_level_parser: for func_parser in self._functions_top_level_parser:
@ -569,6 +588,7 @@ Please rename it, this name is reserved for Slither's internals"""
contract.analyze_params_modifiers() contract.analyze_params_modifiers()
contract.analyze_params_functions() contract.analyze_params_functions()
self._analyze_params_top_level_function() self._analyze_params_top_level_function()
self._analyze_params_custom_error()
contract.analyze_state_variables() contract.analyze_state_variables()

@ -190,7 +190,12 @@ def _find_from_type_name( # pylint: disable=too-many-locals,too-many-branches,t
return UserDefinedType(var_type) return UserDefinedType(var_type)
def parse_type(t: Union[Dict, UnknownType], caller_context): def parse_type(
t: Union[Dict, UnknownType],
caller_context: Union[
"SlitherCompilationUnitSolc", "FunctionSolc", "ContractSolc", "CustomSolc"
],
):
# local import to avoid circular dependency # local import to avoid circular dependency
# pylint: disable=too-many-locals,too-many-branches,too-many-statements # pylint: disable=too-many-locals,too-many-branches,too-many-statements
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
@ -198,17 +203,22 @@ def parse_type(t: Union[Dict, UnknownType], caller_context):
from slither.solc_parsing.variables.function_type_variable import FunctionTypeVariableSolc from slither.solc_parsing.variables.function_type_variable import FunctionTypeVariableSolc
from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.contract import ContractSolc
from slither.solc_parsing.declarations.function import FunctionSolc from slither.solc_parsing.declarations.function import FunctionSolc
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
sl: "SlitherCompilationUnit" sl: "SlitherCompilationUnit"
# Note: for convenicence top level functions use the same parser than function in contract # Note: for convenicence top level functions use the same parser than function in contract
# but contract_parser is set to None # but contract_parser is set to None
if isinstance(caller_context, SlitherCompilationUnitSolc) or ( if isinstance(caller_context, (SlitherCompilationUnitSolc, CustomErrorSolc)) or (
isinstance(caller_context, FunctionSolc) and caller_context.contract_parser is None isinstance(caller_context, FunctionSolc) and caller_context.contract_parser is None
): ):
structures_direct_access: List["Structure"]
if isinstance(caller_context, SlitherCompilationUnitSolc): if isinstance(caller_context, SlitherCompilationUnitSolc):
sl = caller_context.compilation_unit sl = caller_context.compilation_unit
next_context = caller_context next_context = caller_context
elif isinstance(caller_context, CustomErrorSolc):
sl = caller_context.underlying_custom_error.compilation_unit
next_context = caller_context.slither_parser
else: else:
assert isinstance(caller_context, FunctionSolc) assert isinstance(caller_context, FunctionSolc)
sl = caller_context.underlying_function.compilation_unit sl = caller_context.underlying_function.compilation_unit
@ -235,13 +245,13 @@ def parse_type(t: Union[Dict, UnknownType], caller_context):
contract = caller_context.underlying_contract contract = caller_context.underlying_contract
next_context = caller_context next_context = caller_context
structures_direct_access = ( structures_direct_access = contract.structures
contract.structures + contract.compilation_unit.structures_top_level structures_direct_access += contract.compilation_unit.structures_top_level
)
all_structuress = [c.structures for c in contract.compilation_unit.contracts] all_structuress = [c.structures for c in contract.compilation_unit.contracts]
all_structures = [item for sublist in all_structuress for item in sublist] all_structures = [item for sublist in all_structuress for item in sublist]
all_structures += contract.compilation_unit.structures_top_level all_structures += contract.compilation_unit.structures_top_level
enums_direct_access = contract.enums + contract.compilation_unit.enums_top_level enums_direct_access: List["Enum"] = contract.enums
enums_direct_access += contract.compilation_unit.enums_top_level
all_enumss = [c.enums for c in contract.compilation_unit.contracts] all_enumss = [c.enums for c in contract.compilation_unit.contracts]
all_enums = [item for sublist in all_enumss for item in sublist] all_enums = [item for sublist in all_enumss for item in sublist]
all_enums += contract.compilation_unit.enums_top_level all_enums += contract.compilation_unit.enums_top_level

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save