diff --git a/.github/workflows/black_auto.yml b/.github/workflows/black_auto.yml new file mode 100644 index 000000000..ad1ad2ea4 --- /dev/null +++ b/.github/workflows/black_auto.yml @@ -0,0 +1,43 @@ +--- +name: Run black (auto) + +defaults: + run: + # To load bashrc + shell: bash -ieo pipefail {0} + +on: + pull_request: + branches: [master, dev] + paths: + - "**/*.py" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build: + name: Black + runs-on: ubuntu-latest + + steps: + - name: Checkout Code + uses: actions/checkout@v4 + + - name: Set up Python 3.8 + uses: actions/setup-python@v5 + with: + python-version: 3.8 + + - name: Run black + uses: psf/black@stable + with: + options: "" + summary: false + version: "~= 22.3.0" + + - name: Annotate diff changes using reviewdog + uses: reviewdog/action-suggester@v1 + with: + tool_name: blackfmt diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d2ba83d0..0bac900ab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,26 +25,9 @@ jobs: strategy: fail-fast: false matrix: - os: ["ubuntu-latest", "windows-2022"] + os: ${{ (github.event_name == 'pull_request' && fromJSON('["ubuntu-latest"]')) || fromJSON('["ubuntu-latest", "windows-2022"]') }} python: ${{ (github.event_name == 'pull_request' && fromJSON('["3.8", "3.12"]')) || fromJSON('["3.8", "3.9", "3.10", "3.11", "3.12"]') }} - type: ["cli", - "dapp", - "data_dependency", - "path_filtering", - # "embark", - "erc", - # "etherlime", - "etherscan", - "find_paths", - "flat", - "interface", - "kspec", - "printers", - # "prop" - "simil", - "slither_config", - "truffle", - "upgradability"] + type: ${{ (github.event_name == 'pull_request' && fromJSON('["data_dependency", "path_filtering","erc","find_paths","flat","interface", "printers","slither_config","upgradability"]')) || fromJSON('["data_dependency", "path_filtering","erc","find_paths","flat","interface", "printers","slither_config","upgradability", "cli", "dapp", "etherscan", "kspec", "simil", "truffle"]') }} exclude: # Requires nix - os: windows-2022 @@ -67,7 +50,7 @@ jobs: - name: Set up nix if: matrix.type == 'dapp' - uses: cachix/install-nix-action@V27 + uses: cachix/install-nix-action@v30 - name: Set up cachix if: matrix.type == 'dapp' diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 0a0f04f2b..2e9f9d966 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,7 +44,7 @@ jobs: path: dist/ - name: publish - uses: pypa/gh-action-pypi-publish@v1.9.0 + uses: pypa/gh-action-pypi-publish@v1.10.3 - name: sign uses: sigstore/gh-action-sigstore-python@v3.0.0 diff --git a/CODEOWNERS b/CODEOWNERS index c92f0d79d..496da0c30 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,5 +1,4 @@ -* @montyly @0xalpharush @smonicas -/slither/tools/read_storage/ @0xalpharush +* @montyly @smonicas /slither/tools/doctor/ @elopez /slither/slithir/ @montyly /slither/analyses/ @montyly diff --git a/README.md b/README.md index 660f4f8e8..011bb5314 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,12 @@ If you're **not** going to use one of the [supported compilation frameworks](htt python3 -m pip install slither-analyzer ``` +#### How to upgrade + +```console +python3 -m pip install --upgrade slither-analyzer +``` + ### Using Git ```bash diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index 87d0e16a2..fc178db4a 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -46,12 +46,6 @@ from slither.slithir.variables import ( if TYPE_CHECKING: from slither.slithir.variables.variable import SlithIRVariable from slither.core.compilation_unit import SlitherCompilationUnit - from slither.utils.type_helpers import ( - InternalCallType, - HighLevelCallType, - LibraryCallType, - LowLevelCallType, - ) from slither.core.cfg.scope import Scope from slither.core.scope.scope import FileScope @@ -153,11 +147,11 @@ class Node(SourceMapping): # pylint: disable=too-many-public-methods self._ssa_vars_written: List["SlithIRVariable"] = [] self._ssa_vars_read: List["SlithIRVariable"] = [] - self._internal_calls: List[Union["Function", "SolidityFunction"]] = [] - self._solidity_calls: List[SolidityFunction] = [] - self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls - self._library_calls: List["LibraryCallType"] = [] - self._low_level_calls: List["LowLevelCallType"] = [] + self._internal_calls: List[InternalCall] = [] # contains solidity calls + self._solidity_calls: List[SolidityCall] = [] + self._high_level_calls: List[Tuple[Contract, HighLevelCall]] = [] # contains library calls + self._library_calls: List[LibraryCall] = [] + self._low_level_calls: List[LowLevelCall] = [] self._external_calls_as_expressions: List[Expression] = [] self._internal_calls_as_expressions: List[Expression] = [] self._irs: List[Operation] = [] @@ -226,8 +220,9 @@ class Node(SourceMapping): # pylint: disable=too-many-public-methods @property def will_return(self) -> bool: if not self.sons and self.type != NodeType.THROW: - if SolidityFunction("revert()") not in self.solidity_calls: - if SolidityFunction("revert(string)") not in self.solidity_calls: + solidity_calls = [ir.function for ir in self.solidity_calls] + if SolidityFunction("revert()") not in solidity_calls: + if SolidityFunction("revert(string)") not in solidity_calls: return True return False @@ -373,44 +368,38 @@ class Node(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### @property - def internal_calls(self) -> List["InternalCallType"]: + def internal_calls(self) -> List[InternalCall]: """ - list(Function or SolidityFunction): List of internal/soldiity function calls + list(InternalCall): List of IR operations with internal/solidity function calls """ return list(self._internal_calls) @property - def solidity_calls(self) -> List[SolidityFunction]: + def solidity_calls(self) -> List[SolidityCall]: """ - list(SolidityFunction): List of Soldity calls + list(SolidityCall): List of IR operations with solidity calls """ return list(self._solidity_calls) @property - def high_level_calls(self) -> List["HighLevelCallType"]: + def high_level_calls(self) -> List[HighLevelCall]: """ - list((Contract, Function|Variable)): - List of high level calls (external calls). - A variable is called in case of call to a public state variable + list(HighLevelCall): List of IR operations with high level calls (external calls). Include library calls """ return list(self._high_level_calls) @property - def library_calls(self) -> List["LibraryCallType"]: + def library_calls(self) -> List[LibraryCall]: """ - list((Contract, Function)): - Include library calls + list(LibraryCall): List of IR operations with library calls. """ return list(self._library_calls) @property - def low_level_calls(self) -> List["LowLevelCallType"]: + def low_level_calls(self) -> List[LowLevelCall]: """ - list((Variable|SolidityVariable, str)): List of low_level call - A low level call is defined by - - the variable called - - the name of the function (call/delegatecall/codecall) + list(LowLevelCall): List of IR operations with low_level call """ return list(self._low_level_calls) @@ -529,8 +518,9 @@ class Node(SourceMapping): # pylint: disable=too-many-public-methods bool: True if the node has a require or assert call """ return any( - c.name in ["require(bool)", "require(bool,string)", "assert(bool)"] - for c in self.internal_calls + ir.function.name + in ["require(bool)", "require(bool,string)", "require(bool,error)", "assert(bool)"] + for ir in self.internal_calls ) def contains_if(self, include_loop: bool = True) -> bool: @@ -894,11 +884,11 @@ class Node(SourceMapping): # pylint: disable=too-many-public-methods self._vars_written.append(var) if isinstance(ir, InternalCall): - self._internal_calls.append(ir.function) + self._internal_calls.append(ir) if isinstance(ir, SolidityCall): # TODO: consider removing dependancy of solidity_call to internal_call - self._solidity_calls.append(ir.function) - self._internal_calls.append(ir.function) + self._solidity_calls.append(ir) + self._internal_calls.append(ir) if ( isinstance(ir, SolidityCall) and ir.function == SolidityFunction("sstore(uint256,uint256)") @@ -916,22 +906,22 @@ class Node(SourceMapping): # pylint: disable=too-many-public-methods self._vars_read.append(ir.arguments[0]) if isinstance(ir, LowLevelCall): assert isinstance(ir.destination, (Variable, SolidityVariable)) - self._low_level_calls.append((ir.destination, str(ir.function_name.value))) + self._low_level_calls.append(ir) elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall): # Todo investigate this if condition # It does seem right to compare against a contract # This might need a refactoring if isinstance(ir.destination.type, Contract): - self._high_level_calls.append((ir.destination.type, ir.function)) + self._high_level_calls.append((ir.destination.type, ir)) elif ir.destination == SolidityVariable("this"): func = self.function # Can't use this in a top level function assert isinstance(func, FunctionContract) - self._high_level_calls.append((func.contract, ir.function)) + self._high_level_calls.append((func.contract, ir)) else: try: # Todo this part needs more tests and documentation - self._high_level_calls.append((ir.destination.type.type, ir.function)) + self._high_level_calls.append((ir.destination.type.type, ir)) except AttributeError as error: # pylint: disable=raise-missing-from raise SlitherException( @@ -940,8 +930,8 @@ class Node(SourceMapping): # pylint: disable=too-many-public-methods elif isinstance(ir, LibraryCall): assert isinstance(ir.destination, Contract) assert isinstance(ir.function, Function) - self._high_level_calls.append((ir.destination, ir.function)) - self._library_calls.append((ir.destination, ir.function)) + self._high_level_calls.append((ir.destination, ir)) + self._library_calls.append(ir) self._vars_read = list(set(self._vars_read)) self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] diff --git a/slither/core/compilation_unit.py b/slither/core/compilation_unit.py index df652dab0..f4bd07e55 100644 --- a/slither/core/compilation_unit.py +++ b/slither/core/compilation_unit.py @@ -73,7 +73,8 @@ class SlitherCompilationUnit(Context): # Memoize self._all_state_variables: Optional[Set[StateVariable]] = None - self._storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {} + self._persistent_storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {} + self._transient_storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {} self._contract_with_missing_inheritance: Set[Contract] = set() @@ -297,33 +298,52 @@ class SlitherCompilationUnit(Context): def compute_storage_layout(self) -> None: assert self.is_solidity + for contract in self.contracts_derived: - self._storage_layouts[contract.name] = {} - - slot = 0 - offset = 0 - for var in contract.stored_state_variables_ordered: - assert var.type - size, new_slot = var.type.storage_size - - if new_slot: - if offset > 0: - slot += 1 - offset = 0 - elif size + offset > 32: + self._compute_storage_layout(contract.name, contract.storage_variables_ordered, False) + self._compute_storage_layout(contract.name, contract.transient_variables_ordered, True) + + def _compute_storage_layout( + self, contract_name: str, state_variables_ordered: List[StateVariable], is_transient: bool + ): + if is_transient: + self._transient_storage_layouts[contract_name] = {} + else: + self._persistent_storage_layouts[contract_name] = {} + + slot = 0 + offset = 0 + for var in state_variables_ordered: + assert var.type + size, new_slot = var.type.storage_size + + if new_slot: + if offset > 0: slot += 1 offset = 0 + elif size + offset > 32: + slot += 1 + offset = 0 - self._storage_layouts[contract.name][var.canonical_name] = ( + if is_transient: + self._transient_storage_layouts[contract_name][var.canonical_name] = ( slot, offset, ) - if new_slot: - slot += math.ceil(size / 32) - else: - offset += size + else: + self._persistent_storage_layouts[contract_name][var.canonical_name] = ( + slot, + offset, + ) + + if new_slot: + slot += math.ceil(size / 32) + else: + offset += size def storage_layout_of(self, contract: Contract, var: StateVariable) -> Tuple[int, int]: - return self._storage_layouts[contract.name][var.canonical_name] + if var.is_stored: + return self._persistent_storage_layouts[contract.name][var.canonical_name] + return self._transient_storage_layouts[contract.name][var.canonical_name] # endregion diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 3f97a33ed..8dccc007f 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -29,7 +29,6 @@ from slither.utils.tests_pattern import is_test_contract # pylint: disable=too-many-lines,too-many-instance-attributes,import-outside-toplevel,too-many-nested-blocks if TYPE_CHECKING: - from slither.utils.type_helpers import LibraryCallType, HighLevelCallType, InternalCallType from slither.core.declarations import ( Enum, EventContract, @@ -39,6 +38,7 @@ if TYPE_CHECKING: FunctionContract, CustomErrorContract, ) + from slither.slithir.operations import HighLevelCall, LibraryCall from slither.slithir.variables.variable import SlithIRVariable from slither.core.variables import Variable, StateVariable from slither.core.compilation_unit import SlitherCompilationUnit @@ -106,7 +106,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods self._is_incorrectly_parsed: bool = False self._available_functions_as_dict: Optional[Dict[str, "Function"]] = None - self._all_functions_called: Optional[List["InternalCallType"]] = None + self._all_functions_called: Optional[List["Function"]] = None self.compilation_unit: "SlitherCompilationUnit" = compilation_unit self.file_scope: "FileScope" = scope @@ -440,55 +440,43 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods def state_variables(self) -> List["StateVariable"]: """ Returns all the accessible variables (do not include private variable from inherited contract). - Use state_variables_ordered for all the variables following the storage order + Use stored_state_variables_ordered for all the storage variables following the storage order + Use transient_state_variables_ordered for all the transient variables following the storage order list(StateVariable): List of the state variables. """ return list(self._variables.values()) @property - def stored_state_variables(self) -> List["StateVariable"]: + def state_variables_entry_points(self) -> List["StateVariable"]: """ - Returns state variables with storage locations, excluding private variables from inherited contracts. - Use stored_state_variables_ordered to access variables with storage locations in their declaration order. - - This implementation filters out state variables if they are constant or immutable. It will be - updated to accommodate any new non-storage keywords that might replace 'constant' and 'immutable' in the future. - - Returns: - List[StateVariable]: A list of state variables with storage locations. + list(StateVariable): List of the state variables that are public. """ - return [variable for variable in self.state_variables if variable.is_stored] + return [var for var in self._variables.values() if var.visibility == "public"] @property - def stored_state_variables_ordered(self) -> List["StateVariable"]: + def state_variables_ordered(self) -> List["StateVariable"]: """ - list(StateVariable): List of the state variables with storage locations by order of declaration. - - This implementation filters out state variables if they are constant or immutable. It will be - updated to accommodate any new non-storage keywords that might replace 'constant' and 'immutable' in the future. - - Returns: - List[StateVariable]: A list of state variables with storage locations ordered by declaration. + list(StateVariable): List of the state variables by order of declaration. """ - return [variable for variable in self.state_variables_ordered if variable.is_stored] + return self._variables_ordered + + def add_state_variables_ordered(self, new_vars: List["StateVariable"]) -> None: + self._variables_ordered += new_vars @property - def state_variables_entry_points(self) -> List["StateVariable"]: + def storage_variables_ordered(self) -> List["StateVariable"]: """ - list(StateVariable): List of the state variables that are public. + list(StateVariable): List of the state variables in storage location by order of declaration. """ - return [var for var in self._variables.values() if var.visibility == "public"] + return [v for v in self._variables_ordered if v.is_stored] @property - def state_variables_ordered(self) -> List["StateVariable"]: + def transient_variables_ordered(self) -> List["StateVariable"]: """ - list(StateVariable): List of the state variables by order of declaration. + list(StateVariable): List of the state variables in transient location by order of declaration. """ - return list(self._variables_ordered) - - def add_variables_ordered(self, new_vars: List["StateVariable"]) -> None: - self._variables_ordered += new_vars + return [v for v in self._variables_ordered if v.is_transient] @property def state_variables_inherited(self) -> List["StateVariable"]: @@ -1023,15 +1011,21 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### @property - def all_functions_called(self) -> List["InternalCallType"]: + def all_functions_called(self) -> List["Function"]: """ list(Function): List of functions reachable from the contract Includes super, and private/internal functions not shadowed """ + from slither.slithir.operations import Operation + if self._all_functions_called is None: all_functions = [f for f in self.functions + self.modifiers if not f.is_shadowed] # type: ignore all_callss = [f.all_internal_calls() for f in all_functions] + [list(all_functions)] - all_calls = [item for sublist in all_callss for item in sublist] + all_calls = [ + item.function if isinstance(item, Operation) else item + for sublist in all_callss + for item in sublist + ] all_calls = list(set(all_calls)) all_constructors = [c.constructor for c in self.inheritance if c.constructor] @@ -1069,18 +1063,18 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return list(set(all_state_variables_read)) @property - def all_library_calls(self) -> List["LibraryCallType"]: + def all_library_calls(self) -> List["LibraryCall"]: """ - list((Contract, Function): List all of the libraries func called + list(LibraryCall): List all of the libraries func called """ all_high_level_callss = [f.all_library_calls() for f in self.functions + self.modifiers] # type: ignore all_high_level_calls = [item for sublist in all_high_level_callss for item in sublist] return list(set(all_high_level_calls)) @property - def all_high_level_calls(self) -> List["HighLevelCallType"]: + def all_high_level_calls(self) -> List[Tuple["Contract", "HighLevelCall"]]: """ - list((Contract, Function|Variable)): List all of the external high level calls + list(Tuple("Contract", "HighLevelCall")): List all of the external high level calls """ all_high_level_callss = [f.all_high_level_calls() for f in self.functions + self.modifiers] # type: ignore all_high_level_calls = [item for sublist in all_high_level_callss for item in sublist] diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 6e8968dfb..b91e58f24 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -31,19 +31,20 @@ from slither.utils.utils import unroll # pylint: disable=import-outside-toplevel,too-many-instance-attributes,too-many-statements,too-many-lines if TYPE_CHECKING: - from slither.utils.type_helpers import ( - InternalCallType, - LowLevelCallType, - HighLevelCallType, - LibraryCallType, - ) from slither.core.declarations import Contract, FunctionContract from slither.core.cfg.node import Node, NodeType from slither.core.variables.variable import Variable from slither.slithir.variables.variable import SlithIRVariable from slither.slithir.variables import LocalIRVariable from slither.core.expressions.expression import Expression - from slither.slithir.operations import Operation + from slither.slithir.operations import ( + HighLevelCall, + InternalCall, + LibraryCall, + LowLevelCall, + SolidityCall, + Operation, + ) from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.scope.scope import FileScope @@ -149,11 +150,11 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu self._vars_read_or_written: List["Variable"] = [] self._solidity_vars_read: List["SolidityVariable"] = [] self._state_vars_written: List["StateVariable"] = [] - self._internal_calls: List["InternalCallType"] = [] - self._solidity_calls: List["SolidityFunction"] = [] - self._low_level_calls: List["LowLevelCallType"] = [] - self._high_level_calls: List["HighLevelCallType"] = [] - self._library_calls: List["LibraryCallType"] = [] + self._internal_calls: List["InternalCall"] = [] + self._solidity_calls: List["SolidityCall"] = [] + self._low_level_calls: List["LowLevelCall"] = [] + self._high_level_calls: List[Tuple["Contract", "HighLevelCall"]] = [] + self._library_calls: List["LibraryCall"] = [] self._external_calls_as_expressions: List["Expression"] = [] self._expression_vars_read: List["Expression"] = [] self._expression_vars_written: List["Expression"] = [] @@ -169,11 +170,11 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu self._all_expressions: Optional[List["Expression"]] = None self._all_slithir_operations: Optional[List["Operation"]] = None - self._all_internals_calls: Optional[List["InternalCallType"]] = None - self._all_high_level_calls: Optional[List["HighLevelCallType"]] = None - self._all_library_calls: Optional[List["LibraryCallType"]] = None - self._all_low_level_calls: Optional[List["LowLevelCallType"]] = None - self._all_solidity_calls: Optional[List["SolidityFunction"]] = None + self._all_internals_calls: Optional[List["InternalCall"]] = None + self._all_high_level_calls: Optional[List[Tuple["Contract", "HighLevelCall"]]] = None + self._all_library_calls: Optional[List["LibraryCall"]] = None + self._all_low_level_calls: Optional[List["LowLevelCall"]] = None + self._all_solidity_calls: Optional[List["SolidityCall"]] = None self._all_variables_read: Optional[List["Variable"]] = None self._all_variables_written: Optional[List["Variable"]] = None self._all_state_variables_read: Optional[List["StateVariable"]] = None @@ -857,43 +858,42 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ################################################################################### @property - def internal_calls(self) -> List["InternalCallType"]: + def internal_calls(self) -> List["InternalCall"]: """ - list(Function or SolidityFunction): List of function calls (that does not create a transaction) + list(InternalCall): List of IR operations for internal calls """ return list(self._internal_calls) @property - def solidity_calls(self) -> List[SolidityFunction]: + def solidity_calls(self) -> List["SolidityCall"]: """ - list(SolidityFunction): List of Soldity calls + list(SolidityCall): List of IR operations for Solidity calls """ return list(self._solidity_calls) @property - def high_level_calls(self) -> List["HighLevelCallType"]: + def high_level_calls(self) -> List[Tuple["Contract", "HighLevelCall"]]: """ - list((Contract, Function|Variable)): - List of high level calls (external calls). + list(Tuple(Contract, "HighLevelCall")): List of call target contract and IR of the high level call A variable is called in case of call to a public state variable Include library calls """ return list(self._high_level_calls) @property - def library_calls(self) -> List["LibraryCallType"]: + def library_calls(self) -> List["LibraryCall"]: """ - list((Contract, Function)): + list(LibraryCall): List of IR operations for library calls """ return list(self._library_calls) @property - def low_level_calls(self) -> List["LowLevelCallType"]: + def low_level_calls(self) -> List["LowLevelCall"]: """ - list((Variable|SolidityVariable, str)): List of low_level call + list(LowLevelCall): List of IR operations for low level calls A low level call is defined by - the variable called - - the name of the function (call/delegatecall/codecall) + - the name of the function (call/delegatecall/callcode) """ return list(self._low_level_calls) @@ -1121,10 +1121,14 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu values = f_new_values(self) explored = [self] to_explore = [ - c for c in self.internal_calls if isinstance(c, Function) and c not in explored + ir.function + for ir in self.internal_calls + if isinstance(ir.function, Function) and ir.function not in explored ] to_explore += [ - c for (_, c) in self.library_calls if isinstance(c, Function) and c not in explored + ir.function + for ir in self.library_calls + if isinstance(ir.function, Function) and ir.function not in explored ] to_explore += [m for m in self.modifiers if m not in explored] @@ -1138,14 +1142,18 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu values += f_new_values(f) to_explore += [ - c - for c in f.internal_calls - if isinstance(c, Function) and c not in explored and c not in to_explore + ir.function + for ir in f.internal_calls + if isinstance(ir.function, Function) + and ir.function not in explored + and ir.function not in to_explore ] to_explore += [ - c - for (_, c) in f.library_calls - if isinstance(c, Function) and c not in explored and c not in to_explore + ir.function + for ir in f.library_calls + if isinstance(ir.function, Function) + and ir.function not in explored + and ir.function not in to_explore ] to_explore += [m for m in f.modifiers if m not in explored and m not in to_explore] @@ -1210,31 +1218,31 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ) return self._all_state_variables_written - def all_internal_calls(self) -> List["InternalCallType"]: + def all_internal_calls(self) -> List["InternalCall"]: """recursive version of internal_calls""" if self._all_internals_calls is None: self._all_internals_calls = self._explore_functions(lambda x: x.internal_calls) return self._all_internals_calls - def all_low_level_calls(self) -> List["LowLevelCallType"]: + def all_low_level_calls(self) -> List["LowLevelCall"]: """recursive version of low_level calls""" if self._all_low_level_calls is None: self._all_low_level_calls = self._explore_functions(lambda x: x.low_level_calls) return self._all_low_level_calls - def all_high_level_calls(self) -> List["HighLevelCallType"]: + def all_high_level_calls(self) -> List[Tuple["Contract", "HighLevelCall"]]: """recursive version of high_level calls""" if self._all_high_level_calls is None: self._all_high_level_calls = self._explore_functions(lambda x: x.high_level_calls) return self._all_high_level_calls - def all_library_calls(self) -> List["LibraryCallType"]: + def all_library_calls(self) -> List["LibraryCall"]: """recursive version of library calls""" if self._all_library_calls is None: self._all_library_calls = self._explore_functions(lambda x: x.library_calls) return self._all_library_calls - def all_solidity_calls(self) -> List[SolidityFunction]: + def all_solidity_calls(self) -> List["SolidityCall"]: """recursive version of solidity calls""" if self._all_solidity_calls is None: self._all_solidity_calls = self._explore_functions(lambda x: x.solidity_calls) @@ -1653,7 +1661,9 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu internal_calls = [item for sublist in internal_calls for item in sublist] self._internal_calls = list(set(internal_calls)) - self._solidity_calls = [c for c in internal_calls if isinstance(c, SolidityFunction)] + self._solidity_calls = [ + ir for ir in internal_calls if isinstance(ir.function, SolidityFunction) + ] low_level_calls = [x.low_level_calls for x in self.nodes] low_level_calls = [x for x in low_level_calls if x] diff --git a/slither/core/declarations/solidity_variables.py b/slither/core/declarations/solidity_variables.py index 8094ab7c3..ce8477ab2 100644 --- a/slither/core/declarations/solidity_variables.py +++ b/slither/core/declarations/solidity_variables.py @@ -50,6 +50,7 @@ SOLIDITY_FUNCTIONS: Dict[str, List[str]] = { "assert(bool)": [], "require(bool)": [], "require(bool,string)": [], + "require(bool,error)": [], # Solidity 0.8.26 via-ir and Solidity >= 0.8.27 "revert()": [], "revert(string)": [], "revert ": [], diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index 1206e564b..f5e36ace3 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -8,7 +8,7 @@ import pathlib import posixpath import re from collections import defaultdict -from typing import Optional, Dict, List, Set, Union, Tuple +from typing import Optional, Dict, List, Set, Union, Tuple, TypeVar from crytic_compile import CryticCompile from crytic_compile.utils.naming import Filename @@ -88,6 +88,7 @@ class SlitherCore(Context): self._contracts: List[Contract] = [] self._contracts_derived: List[Contract] = [] + self._offset_to_min_offset: Optional[Dict[Filename, Dict[int, Set[int]]]] = None self._offset_to_objects: Optional[Dict[Filename, Dict[int, Set[SourceMapping]]]] = None self._offset_to_references: Optional[Dict[Filename, Dict[int, Set[Source]]]] = None self._offset_to_implementations: Optional[Dict[Filename, Dict[int, Set[Source]]]] = None @@ -195,69 +196,70 @@ class SlitherCore(Context): for f in c.functions: f.cfg_to_dot(os.path.join(d, f"{c.name}.{f.name}.dot")) - def offset_to_objects(self, filename_str: str, offset: int) -> Set[SourceMapping]: - if self._offset_to_objects is None: - self._compute_offsets_to_ref_impl_decl() - filename: Filename = self.crytic_compile.filename_lookup(filename_str) - return self._offset_to_objects[filename][offset] - def _compute_offsets_from_thing(self, thing: SourceMapping): definition = get_definition(thing, self.crytic_compile) references = get_references(thing) implementations = get_all_implementations(thing, self.contracts) + # Create the offset mapping for offset in range(definition.start, definition.end + 1): - if ( - isinstance(thing, (TopLevel, Contract)) - or ( - isinstance(thing, FunctionContract) - and thing.contract_declarer == thing.contract - ) - or (isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract)) - ): + self._offset_to_min_offset[definition.filename][offset].add(definition.start) - self._offset_to_objects[definition.filename][offset].add(thing) + is_declared_function = ( + isinstance(thing, FunctionContract) and thing.contract_declarer == thing.contract + ) - self._offset_to_definitions[definition.filename][offset].add(definition) - self._offset_to_implementations[definition.filename][offset].update(implementations) - self._offset_to_references[definition.filename][offset] |= set(references) + should_add_to_objects = ( + isinstance(thing, (TopLevel, Contract)) + or is_declared_function + or (isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract)) + ) + + if should_add_to_objects: + self._offset_to_objects[definition.filename][definition.start].add(thing) + + self._offset_to_definitions[definition.filename][definition.start].add(definition) + self._offset_to_implementations[definition.filename][definition.start].update( + implementations + ) + self._offset_to_references[definition.filename][definition.start] |= set(references) + + # For references + should_add_to_objects = ( + isinstance(thing, TopLevel) + or is_declared_function + or (isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract)) + ) for ref in references: for offset in range(ref.start, ref.end + 1): - is_declared_function = ( - isinstance(thing, FunctionContract) - and thing.contract_declarer == thing.contract - ) + self._offset_to_min_offset[definition.filename][offset].add(ref.start) + + if should_add_to_objects: + self._offset_to_objects[definition.filename][ref.start].add(thing) + + if is_declared_function: + # Only show the nearest lexical definition for declared contract-level functions if ( - isinstance(thing, TopLevel) - or is_declared_function - or ( - isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract) - ) + thing.contract.source_mapping.start + < ref.start + < thing.contract.source_mapping.end ): - self._offset_to_objects[definition.filename][offset].add(thing) - - if is_declared_function: - # Only show the nearest lexical definition for declared contract-level functions - if ( - thing.contract.source_mapping.start - < offset - < thing.contract.source_mapping.end - ): - self._offset_to_definitions[ref.filename][offset].add(definition) + self._offset_to_definitions[ref.filename][ref.start].add(definition) - else: - self._offset_to_definitions[ref.filename][offset].add(definition) + else: + self._offset_to_definitions[ref.filename][ref.start].add(definition) - self._offset_to_implementations[ref.filename][offset].update(implementations) - self._offset_to_references[ref.filename][offset] |= set(references) + self._offset_to_implementations[ref.filename][ref.start].update(implementations) + self._offset_to_references[ref.filename][ref.start] |= set(references) def _compute_offsets_to_ref_impl_decl(self): # pylint: disable=too-many-branches self._offset_to_references = defaultdict(lambda: defaultdict(lambda: set())) self._offset_to_definitions = defaultdict(lambda: defaultdict(lambda: set())) self._offset_to_implementations = defaultdict(lambda: defaultdict(lambda: set())) self._offset_to_objects = defaultdict(lambda: defaultdict(lambda: set())) + self._offset_to_min_offset = defaultdict(lambda: defaultdict(lambda: set())) for compilation_unit in self._compilation_units: for contract in compilation_unit.contracts: @@ -308,23 +310,59 @@ class SlitherCore(Context): for pragma in compilation_unit.pragma_directives: self._compute_offsets_from_thing(pragma) + T = TypeVar("T", Source, SourceMapping) + + def _get_offset( + self, mapping: Dict[Filename, Dict[int, Set[T]]], filename_str: str, offset: int + ) -> Set[T]: + """Get the Source/SourceMapping referenced by the offset. + + For performance reasons, references are only stored once at the lowest offset. + It uses the _offset_to_min_offset mapping to retrieve the correct offsets. + As multiple definitions can be related to the same offset, we retrieve all of them. + + :param mapping: Mapping to search for (objects. references, ...) + :param filename_str: Filename to consider + :param offset: Look-up offset + :raises IndexError: When the start offset is not found + :return: The corresponding set of Source/SourceMapping + """ + filename: Filename = self.crytic_compile.filename_lookup(filename_str) + + start_offsets = self._offset_to_min_offset[filename][offset] + if not start_offsets: + msg = f"Unable to find reference for offset {offset}" + raise IndexError(msg) + + results = set() + for start_offset in start_offsets: + results |= mapping[filename][start_offset] + + return results + def offset_to_references(self, filename_str: str, offset: int) -> Set[Source]: if self._offset_to_references is None: self._compute_offsets_to_ref_impl_decl() - filename: Filename = self.crytic_compile.filename_lookup(filename_str) - return self._offset_to_references[filename][offset] + + return self._get_offset(self._offset_to_references, filename_str, offset) def offset_to_implementations(self, filename_str: str, offset: int) -> Set[Source]: if self._offset_to_implementations is None: self._compute_offsets_to_ref_impl_decl() - filename: Filename = self.crytic_compile.filename_lookup(filename_str) - return self._offset_to_implementations[filename][offset] + + return self._get_offset(self._offset_to_implementations, filename_str, offset) def offset_to_definitions(self, filename_str: str, offset: int) -> Set[Source]: if self._offset_to_definitions is None: self._compute_offsets_to_ref_impl_decl() - filename: Filename = self.crytic_compile.filename_lookup(filename_str) - return self._offset_to_definitions[filename][offset] + + return self._get_offset(self._offset_to_definitions, filename_str, offset) + + def offset_to_objects(self, filename_str: str, offset: int) -> Set[SourceMapping]: + if self._offset_to_objects is None: + self._compute_offsets_to_ref_impl_decl() + + return self._get_offset(self._offset_to_objects, filename_str, offset) # endregion ################################################################################### diff --git a/slither/core/solidity_types/function_type.py b/slither/core/solidity_types/function_type.py index 2d644148e..8a328e361 100644 --- a/slither/core/solidity_types/function_type.py +++ b/slither/core/solidity_types/function_type.py @@ -36,7 +36,7 @@ class FunctionType(Type): def is_dynamic(self) -> bool: return False - def __str__(self): + def __str__(self) -> str: # Use x.type # x.name may be empty params = ",".join([str(x.type) for x in self._params]) diff --git a/slither/core/source_mapping/source_mapping.py b/slither/core/source_mapping/source_mapping.py index 41841f1e8..355aa5538 100644 --- a/slither/core/source_mapping/source_mapping.py +++ b/slither/core/source_mapping/source_mapping.py @@ -112,12 +112,8 @@ class Source: try: return ( self.start == other.start - and self.length == other.length - and self.filename == other.filename + and self.filename.relative == other.filename.relative and self.is_dependency == other.is_dependency - and self.lines == other.lines - and self.starting_column == other.starting_column - and self.ending_column == other.ending_column and self.end == other.end ) except AttributeError: diff --git a/slither/core/variables/state_variable.py b/slither/core/variables/state_variable.py index f2a2d6ee3..d3e3e6018 100644 --- a/slither/core/variables/state_variable.py +++ b/slither/core/variables/state_variable.py @@ -12,6 +12,7 @@ class StateVariable(ContractLevel, Variable): def __init__(self) -> None: super().__init__() self._node_initialization: Optional["Node"] = None + self._location: Optional[str] = None def is_declared_by(self, contract: "Contract") -> bool: """ @@ -21,6 +22,35 @@ class StateVariable(ContractLevel, Variable): """ return self.contract == contract + def set_location(self, loc: str) -> None: + self._location = loc + + @property + def location(self) -> Optional[str]: + """ + Variable Location + Can be default or transient + Returns: + (str) + """ + return self._location + + @property + def is_stored(self) -> bool: + """ + Checks if the state variable is stored, based on it not being constant or immutable or transient. + """ + return ( + not self._is_constant and not self._is_immutable and not self._location == "transient" + ) + + @property + def is_transient(self) -> bool: + """ + Checks if the state variable is transient. A transient variable can not be constant or immutable. + """ + return self._location == "transient" + # endregion ################################################################################### ################################################################################### diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index 63d1a7a83..4af350d31 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -9,6 +9,7 @@ from slither.core.solidity_types.elementary_type import ElementaryType if TYPE_CHECKING: from slither.core.expressions.expression import Expression + from slither.core.declarations import Function # pylint: disable=too-many-instance-attributes class Variable(SourceMapping): @@ -16,7 +17,7 @@ class Variable(SourceMapping): super().__init__() self._name: Optional[str] = None self._initial_expression: Optional["Expression"] = None - self._type: Optional[Type] = None + self._type: Optional[Union[List, Type, "Function", str]] = None self._initialized: Optional[bool] = None self._visibility: Optional[str] = None self._is_constant = False @@ -77,7 +78,7 @@ class Variable(SourceMapping): self._name = name @property - def type(self) -> Optional[Type]: + def type(self) -> Optional[Union[List, Type, "Function", str]]: return self._type @type.setter @@ -93,13 +94,6 @@ class Variable(SourceMapping): def is_constant(self, is_cst: bool) -> None: self._is_constant = is_cst - @property - def is_stored(self) -> bool: - """ - Checks if a variable is stored, based on it not being constant or immutable. Future updates may adjust for new non-storage keywords. - """ - return not self._is_constant and not self._is_immutable - @property def is_reentrant(self) -> bool: return self._is_reentrant @@ -127,7 +121,7 @@ class Variable(SourceMapping): def visibility(self, v: str) -> None: self._visibility = v - def set_type(self, t: Optional[Union[List, Type, str]]) -> None: + def set_type(self, t: Optional[Union[List, Type, "Function", str]]) -> None: if isinstance(t, str): self._type = ElementaryType(t) return diff --git a/slither/detectors/all_detectors.py b/slither/detectors/all_detectors.py index 44a168c2b..a30d2b3c0 100644 --- a/slither/detectors/all_detectors.py +++ b/slither/detectors/all_detectors.py @@ -97,5 +97,12 @@ from .operations.incorrect_exp import IncorrectOperatorExponentiation from .statements.tautological_compare import TautologicalCompare from .statements.return_bomb import ReturnBomb from .functions.out_of_order_retryable import OutOfOrderRetryable +from .functions.gelato_unprotected_randomness import GelatoUnprotectedRandomness +from .statements.chronicle_unchecked_price import ChronicleUncheckedPrice +from .statements.pyth_unchecked_confidence import PythUncheckedConfidence +from .statements.pyth_unchecked_publishtime import PythUncheckedPublishTime +from .functions.chainlink_feed_registry import ChainlinkFeedRegistry +from .functions.pyth_deprecated_functions import PythDeprecatedFunctions +from .functions.optimism_deprecation import OptimismDeprecation # from .statements.unused_import import UnusedImport diff --git a/slither/detectors/assembly/incorrect_return.py b/slither/detectors/assembly/incorrect_return.py index bd5a6d844..9052979ac 100644 --- a/slither/detectors/assembly/incorrect_return.py +++ b/slither/detectors/assembly/incorrect_return.py @@ -21,10 +21,8 @@ def _assembly_node(function: Function) -> Optional[SolidityCall]: """ - for ir in function.all_slithir_operations(): - if isinstance(ir, SolidityCall) and ir.function == SolidityFunction( - "return(uint256,uint256)" - ): + for ir in function.all_solidity_calls(): + if ir.function == SolidityFunction("return(uint256,uint256)"): return ir return None @@ -71,23 +69,23 @@ The function will return 6 bytes starting from offset 5, instead of returning a for c in self.contracts: for f in c.functions_and_modifiers_declared: - for node in f.nodes: - if node.sons: - for function_called in node.internal_calls: - if isinstance(function_called, Function): - found = _assembly_node(function_called) - if found: - - info: DETECTOR_INFO = [ - f, - " calls ", - function_called, - " which halt the execution ", - found.node, - "\n", - ] - json = self.generate_result(info) - - results.append(json) + for ir in f.internal_calls: + if ir.node.sons: + function_called = ir.function + if isinstance(function_called, Function): + found = _assembly_node(function_called) + if found: + + info: DETECTOR_INFO = [ + f, + " calls ", + function_called, + " which halt the execution ", + found.node, + "\n", + ] + json = self.generate_result(info) + + results.append(json) return results diff --git a/slither/detectors/assembly/return_instead_of_leave.py b/slither/detectors/assembly/return_instead_of_leave.py index a1ad9c87e..603705974 100644 --- a/slither/detectors/assembly/return_instead_of_leave.py +++ b/slither/detectors/assembly/return_instead_of_leave.py @@ -6,7 +6,6 @@ from slither.detectors.abstract_detector import ( DetectorClassification, DETECTOR_INFO, ) -from slither.slithir.operations import SolidityCall from slither.utils.output import Output @@ -42,15 +41,12 @@ The function will halt the execution, instead of returning a two uint.""" def _check_function(self, f: Function) -> List[Output]: results: List[Output] = [] - for node in f.nodes: - for ir in node.irs: - if isinstance(ir, SolidityCall) and ir.function == SolidityFunction( - "return(uint256,uint256)" - ): - info: DETECTOR_INFO = [f, " contains an incorrect call to return: ", node, "\n"] - json = self.generate_result(info) + for ir in f.solidity_calls: + if ir.function == SolidityFunction("return(uint256,uint256)"): + info: DETECTOR_INFO = [f, " contains an incorrect call to return: ", ir.node, "\n"] + json = self.generate_result(info) - results.append(json) + results.append(json) return results def _detect(self) -> List[Output]: diff --git a/slither/detectors/attributes/locked_ether.py b/slither/detectors/attributes/locked_ether.py index 91ec68650..efb376e22 100644 --- a/slither/detectors/attributes/locked_ether.py +++ b/slither/detectors/attributes/locked_ether.py @@ -59,7 +59,7 @@ Every Ether sent to `Locked` will be lost.""" explored += to_explore to_explore = [] for function in functions: - calls = [c.name for c in function.internal_calls] + calls = [ir.function.name for ir in function.internal_calls] if "suicide(address)" in calls or "selfdestruct(address)" in calls: return False for node in function.nodes: diff --git a/slither/detectors/compiler_bugs/array_by_reference.py b/slither/detectors/compiler_bugs/array_by_reference.py index 47e2af581..e4dde4360 100644 --- a/slither/detectors/compiler_bugs/array_by_reference.py +++ b/slither/detectors/compiler_bugs/array_by_reference.py @@ -13,8 +13,6 @@ from slither.detectors.abstract_detector import ( from slither.core.solidity_types.array_type import ArrayType from slither.core.variables.state_variable import StateVariable from slither.core.variables.local_variable import LocalVariable -from slither.slithir.operations.high_level_call import HighLevelCall -from slither.slithir.operations.internal_call import InternalCall from slither.core.cfg.node import Node from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract @@ -117,37 +115,26 @@ As a result, Bob's usage of the contract is incorrect.""" # pylint: disable=too-many-nested-blocks for contract in contracts: for function in contract.functions_and_modifiers_declared: - for node in function.nodes: + for ir in [ir for _, ir in function.high_level_calls] + function.internal_calls: - # If this node has no expression, skip it. - if not node.expression: + # Verify this references a function in our array modifying functions collection. + if ir.function not in array_modifying_funcs: continue - for ir in node.irs: - # Verify this is a high level call. - if not isinstance(ir, (HighLevelCall, InternalCall)): + # Verify one of these parameters is an array in storage. + for (param, arg) in zip(ir.function.parameters, ir.arguments): + # Verify this argument is a variable that is an array type. + if not isinstance(arg, (StateVariable, LocalVariable)): continue - - # Verify this references a function in our array modifying functions collection. - if ir.function not in array_modifying_funcs: + if not isinstance(arg.type, ArrayType): continue - # Verify one of these parameters is an array in storage. - for (param, arg) in zip(ir.function.parameters, ir.arguments): - # Verify this argument is a variable that is an array type. - if not isinstance(arg, (StateVariable, LocalVariable)): - continue - if not isinstance(arg.type, ArrayType): - continue - - # If it is a state variable OR a local variable referencing storage, we add it to the list. - if ( - isinstance(arg, StateVariable) - or (isinstance(arg, LocalVariable) and arg.location == "storage") - ) and ( - isinstance(param.type, ArrayType) and param.location != "storage" - ): - results.append((node, arg, ir.function)) + # If it is a state variable OR a local variable referencing storage, we add it to the list. + if ( + isinstance(arg, StateVariable) + or (isinstance(arg, LocalVariable) and arg.location == "storage") + ) and (isinstance(param.type, ArrayType) and param.location != "storage"): + results.append((ir.node, arg, ir.function)) return results def _detect(self) -> List[Output]: diff --git a/slither/detectors/erc/erc20/arbitrary_send_erc20.py b/slither/detectors/erc/erc20/arbitrary_send_erc20.py index f06005459..4dc1f8db5 100644 --- a/slither/detectors/erc/erc20/arbitrary_send_erc20.py +++ b/slither/detectors/erc/erc20/arbitrary_send_erc20.py @@ -3,7 +3,7 @@ from typing import List from slither.analyses.data_dependency.data_dependency import is_dependent from slither.core.cfg.node import Node from slither.core.compilation_unit import SlitherCompilationUnit -from slither.core.declarations import Contract, Function, SolidityVariableComposed +from slither.core.declarations import Contract, Function, SolidityVariableComposed, FunctionContract from slither.core.declarations.solidity_variables import SolidityVariable from slither.slithir.operations import HighLevelCall, LibraryCall @@ -31,11 +31,11 @@ class ArbitrarySendErc20: def _detect_arbitrary_from(self, contract: Contract) -> None: for f in contract.functions: all_high_level_calls = [ - f_called[1].solidity_signature - for f_called in f.high_level_calls - if isinstance(f_called[1], Function) + ir.function.solidity_signature + for _, ir in f.high_level_calls + if isinstance(ir.function, Function) ] - all_library_calls = [f_called[1].solidity_signature for f_called in f.library_calls] + all_library_calls = [ir.function.solidity_signature for ir in f.library_calls] if ( "transferFrom(address,address,uint256)" in all_high_level_calls or "safeTransferFrom(address,address,address,uint256)" in all_library_calls @@ -44,51 +44,50 @@ class ArbitrarySendErc20: "permit(address,address,uint256,uint256,uint8,bytes32,bytes32)" in all_high_level_calls ): - ArbitrarySendErc20._arbitrary_from(f.nodes, self._permit_results) + ArbitrarySendErc20._arbitrary_from(f, self._permit_results) else: - ArbitrarySendErc20._arbitrary_from(f.nodes, self._no_permit_results) + ArbitrarySendErc20._arbitrary_from(f, self._no_permit_results) @staticmethod - def _arbitrary_from(nodes: List[Node], results: List[Node]) -> None: + def _arbitrary_from(function: FunctionContract, results: List[Node]) -> None: """Finds instances of (safe)transferFrom that do not use msg.sender or address(this) as from parameter.""" - for node in nodes: - for ir in node.irs: - if ( - isinstance(ir, HighLevelCall) - and isinstance(ir.function, Function) - and ir.function.solidity_signature == "transferFrom(address,address,uint256)" - and not ( - is_dependent( - ir.arguments[0], - SolidityVariableComposed("msg.sender"), - node, - ) - or is_dependent( - ir.arguments[0], - SolidityVariable("this"), - node, - ) + for _, ir in function.high_level_calls: + if ( + isinstance(ir, LibraryCall) + and ir.function.solidity_signature + == "safeTransferFrom(address,address,address,uint256)" + and not ( + is_dependent( + ir.arguments[1], + SolidityVariableComposed("msg.sender"), + ir.node, ) - ): - results.append(ir.node) - elif ( - isinstance(ir, LibraryCall) - and ir.function.solidity_signature - == "safeTransferFrom(address,address,address,uint256)" - and not ( - is_dependent( - ir.arguments[1], - SolidityVariableComposed("msg.sender"), - node, - ) - or is_dependent( - ir.arguments[1], - SolidityVariable("this"), - node, - ) + or is_dependent( + ir.arguments[1], + SolidityVariable("this"), + ir.node, ) - ): - results.append(ir.node) + ) + ): + results.append(ir.node) + elif ( + isinstance(ir, HighLevelCall) + and isinstance(ir.function, Function) + and ir.function.solidity_signature == "transferFrom(address,address,uint256)" + and not ( + is_dependent( + ir.arguments[0], + SolidityVariableComposed("msg.sender"), + ir.node, + ) + or is_dependent( + ir.arguments[0], + SolidityVariable("this"), + ir.node, + ) + ) + ): + results.append(ir.node) def detect(self) -> None: """Detect transfers that use arbitrary `from` parameter.""" diff --git a/slither/detectors/functions/chainlink_feed_registry.py b/slither/detectors/functions/chainlink_feed_registry.py new file mode 100644 index 000000000..82ab17424 --- /dev/null +++ b/slither/detectors/functions/chainlink_feed_registry.py @@ -0,0 +1,102 @@ +from typing import List + +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) +from slither.utils.output import Output + + +class ChainlinkFeedRegistry(AbstractDetector): + + ARGUMENT = "chainlink-feed-registry" + HELP = "Detect when chainlink feed registry is used" + IMPACT = DetectorClassification.LOW + CONFIDENCE = DetectorClassification.HIGH + + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#chainlink-feed-registry" + + WIKI_TITLE = "Chainlink Feed Registry usage" + WIKI_DESCRIPTION = "Detect when Chainlink Feed Registry is used. At the moment is only available on Ethereum Mainnet." + + # region wiki_exploit_scenario + WIKI_EXPLOIT_SCENARIO = """ +```solidity +import "chainlink/contracts/src/v0.8/interfaces/FeedRegistryInteface.sol" + +contract A { + FeedRegistryInterface public immutable registry; + + constructor(address _registry) { + registry = _registry; + } + + function getPrice(address base, address quote) public return(uint256) { + (, int256 price,,,) = registry.latestRoundData(base, quote); + // Do price validation + return uint256(price); + } +} +``` +If the contract is deployed on a different chain than Ethereum Mainnet the `getPrice` function will revert. +""" + # endregion wiki_exploit_scenario + + WIKI_RECOMMENDATION = "Do not use Chainlink Feed Registry outside of Ethereum Mainnet." + + def _detect(self) -> List[Output]: + # https://github.com/smartcontractkit/chainlink/blob/8ca41fc8f722accfccccb4b1778db2df8fef5437/contracts/src/v0.8/interfaces/FeedRegistryInterface.sol + registry_functions = [ + "decimals", + "description", + "versiom", + "latestRoundData", + "getRoundData", + "latestAnswer", + "latestTimestamp", + "latestRound", + "getAnswer", + "getTimestamp", + "getFeed", + "getPhaseFeed", + "isFeedEnabled", + "getPhase", + "getRoundFeed", + "getPhaseRange", + "getPreviousRoundId", + "getNextRoundId", + "proposeFeed", + "confirmFeed", + "getProposedFeed", + "proposedGetRoundData", + "proposedLatestRoundData", + "getCurrentPhaseId", + ] + results = [] + + for contract in self.compilation_unit.contracts_derived: + nodes = [] + for target, ir in contract.all_high_level_calls: + if ( + target.name == "FeedRegistryInterface" + and ir.function_name in registry_functions + ): + nodes.append(ir.node) + # Sort so output is deterministic + nodes.sort(key=lambda x: (x.node_id, x.function.full_name)) + + if len(nodes) > 0: + info: DETECTOR_INFO = [ + "The Chainlink Feed Registry is used in the ", + contract.name, + " contract. It's only available on Ethereum Mainnet, consider to not use it if the contract needs to be deployed on other chains.\n", + ] + + for node in nodes: + info.extend(["\t - ", node, "\n"]) + + res = self.generate_result(info) + results.append(res) + + return results diff --git a/slither/detectors/functions/dead_code.py b/slither/detectors/functions/dead_code.py index 5cafa1650..3628d10a2 100644 --- a/slither/detectors/functions/dead_code.py +++ b/slither/detectors/functions/dead_code.py @@ -48,13 +48,15 @@ contract Contract{ all_functionss_called = [ f.all_internal_calls() for f in contract.functions_entry_points ] - all_functions_called = [item for sublist in all_functionss_called for item in sublist] + all_functions_called = [ + item.function for sublist in all_functionss_called for item in sublist + ] functions_used |= { f.canonical_name for f in all_functions_called if isinstance(f, Function) } all_libss_called = [f.all_library_calls() for f in contract.functions_entry_points] all_libs_called: List[Tuple[Contract, Function]] = [ - item for sublist in all_libss_called for item in sublist + item.function for sublist in all_libss_called for item in sublist ] functions_used |= { lib[1].canonical_name for lib in all_libs_called if isinstance(lib, tuple) diff --git a/slither/detectors/functions/external_function.py b/slither/detectors/functions/external_function.py index 5858c2baf..d9cc2bc36 100644 --- a/slither/detectors/functions/external_function.py +++ b/slither/detectors/functions/external_function.py @@ -13,8 +13,7 @@ from slither.detectors.abstract_detector import ( make_solc_versions, ) from slither.formatters.functions.external_function import custom_format -from slither.slithir.operations import InternalCall, InternalDynamicCall -from slither.slithir.operations import SolidityCall +from slither.slithir.operations import InternalDynamicCall from slither.utils.output import Output @@ -55,11 +54,11 @@ class ExternalFunction(AbstractDetector): for func in contract.all_functions_called: if not isinstance(func, Function): continue - # Loop through all nodes in the function, add all calls to a list. - for node in func.nodes: - for ir in node.irs: - if isinstance(ir, (InternalCall, SolidityCall)): - result.append(ir.function) + + # Loop through all internal and solidity calls in the function, add them to a list. + for ir in func.internal_calls + func.solidity_calls: + result.append(ir.function) + return result @staticmethod @@ -101,6 +100,7 @@ class ExternalFunction(AbstractDetector): # Somehow we couldn't resolve it, which shouldn't happen, as the provided function should be found if we could # not find some any more basic. + # pylint: disable=broad-exception-raised raise Exception("Could not resolve the base-most function for the provided function.") @staticmethod diff --git a/slither/detectors/functions/gelato_unprotected_randomness.py b/slither/detectors/functions/gelato_unprotected_randomness.py new file mode 100644 index 000000000..bdc3a6fb0 --- /dev/null +++ b/slither/detectors/functions/gelato_unprotected_randomness.py @@ -0,0 +1,78 @@ +from typing import List + +from slither.slithir.operations.internal_call import InternalCall +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) +from slither.utils.output import Output + + +class GelatoUnprotectedRandomness(AbstractDetector): + """ + Unprotected Gelato VRF requests + """ + + ARGUMENT = "gelato-unprotected-randomness" + HELP = "Call to _requestRandomness within an unprotected function" + IMPACT = DetectorClassification.MEDIUM + CONFIDENCE = DetectorClassification.MEDIUM + + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#gelato-unprotected-randomness" + + WIKI_TITLE = "Gelato unprotected randomness" + WIKI_DESCRIPTION = "Detect calls to `_requestRandomness` within an unprotected function." + + # region wiki_exploit_scenario + WIKI_EXPLOIT_SCENARIO = """ +```solidity +contract C is GelatoVRFConsumerBase { + function _fulfillRandomness( + uint256 randomness, + uint256, + bytes memory extraData + ) internal override { + // Do something with the random number + } + + function bad() public { + _requestRandomness(abi.encode(msg.sender)); + } +} +``` +The function `bad` is uprotected and requests randomness.""" + # endregion wiki_exploit_scenario + + WIKI_RECOMMENDATION = ( + "Function that request randomness should be allowed only to authorized users." + ) + + def _detect(self) -> List[Output]: + results = [] + + for contract in self.compilation_unit.contracts_derived: + if "GelatoVRFConsumerBase" in [c.name for c in contract.inheritance]: + for function in contract.functions_entry_points: + if not function.is_protected() and ( + nodes_request := [ + ir.node + for ir in function.all_internal_calls() + if isinstance(ir, InternalCall) + and ir.function_name == "_requestRandomness" + ] + ): + # Sort so output is deterministic + nodes_request.sort(key=lambda x: (x.node_id, x.function.full_name)) + + for node in nodes_request: + info: DETECTOR_INFO = [ + function, + " is unprotected and request randomness from Gelato VRF\n\t- ", + node, + "\n", + ] + res = self.generate_result(info) + results.append(res) + + return results diff --git a/slither/detectors/functions/modifier.py b/slither/detectors/functions/modifier.py index 7f1487266..a888d5b70 100644 --- a/slither/detectors/functions/modifier.py +++ b/slither/detectors/functions/modifier.py @@ -17,7 +17,7 @@ from slither.utils.output import Output def is_revert(node: Node) -> bool: return node.type == NodeType.THROW or any( - c.name in ["revert()", "revert(string"] for c in node.internal_calls + ir.function.name in ["revert()", "revert(string"] for ir in node.internal_calls ) diff --git a/slither/detectors/functions/optimism_deprecation.py b/slither/detectors/functions/optimism_deprecation.py new file mode 100644 index 000000000..752e8bb2d --- /dev/null +++ b/slither/detectors/functions/optimism_deprecation.py @@ -0,0 +1,92 @@ +from typing import List + +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) +from slither.core.cfg.node import Node +from slither.core.variables.variable import Variable +from slither.core.expressions import TypeConversion, Literal +from slither.utils.output import Output + + +class OptimismDeprecation(AbstractDetector): + + ARGUMENT = "optimism-deprecation" + HELP = "Detect when deprecated Optimism predeploy or function is used." + IMPACT = DetectorClassification.LOW + CONFIDENCE = DetectorClassification.HIGH + + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#optimism-deprecation" + + WIKI_TITLE = "Optimism deprecated predeploy or function" + WIKI_DESCRIPTION = "Detect when deprecated Optimism predeploy or function is used." + + # region wiki_exploit_scenario + WIKI_EXPLOIT_SCENARIO = """ +```solidity +interface GasPriceOracle { + function scalar() external view returns (uint256); +} + +contract Test { + GasPriceOracle constant OPT_GAS = GasPriceOracle(0x420000000000000000000000000000000000000F); + + function a() public { + OPT_GAS.scalar(); + } +} +``` +The call to the `scalar` function of the Optimism GasPriceOracle predeploy always revert. +""" + # endregion wiki_exploit_scenario + + WIKI_RECOMMENDATION = "Do not use the deprecated components." + + def _detect(self) -> List[Output]: + results = [] + + deprecated_predeploys = [ + "0x4200000000000000000000000000000000000000", # LegacyMessagePasser + "0x4200000000000000000000000000000000000001", # L1MessageSender + "0x4200000000000000000000000000000000000002", # DeployerWhitelist + "0x4200000000000000000000000000000000000013", # L1BlockNumber + ] + + for contract in self.compilation_unit.contracts_derived: + use_deprecated: List[Node] = [] + + for _, ir in contract.all_high_level_calls: + # To avoid FPs we assume predeploy contracts are always assigned to a constant and typecasted to an interface + # and we check the target address of a high level call. + if ( + isinstance(ir.destination, Variable) + and isinstance(ir.destination.expression, TypeConversion) + and isinstance(ir.destination.expression.expression, Literal) + ): + if ir.destination.expression.expression.value in deprecated_predeploys: + use_deprecated.append(ir.node) + + if ( + ir.destination.expression.expression.value + == "0x420000000000000000000000000000000000000F" + and ir.function_name in ("overhead", "scalar", "getL1GasUsed") + ): + use_deprecated.append(ir.node) + # Sort so output is deterministic + use_deprecated.sort(key=lambda x: (x.node_id, x.function.full_name)) + if len(use_deprecated) > 0: + info: DETECTOR_INFO = [ + "A deprecated Optimism predeploy or function is used in the ", + contract.name, + " contract.\n", + ] + + for node in use_deprecated: + info.extend(["\t - ", node, "\n"]) + + res = self.generate_result(info) + results.append(res) + + return results diff --git a/slither/detectors/functions/out_of_order_retryable.py b/slither/detectors/functions/out_of_order_retryable.py index db9096f95..a11e31ef4 100644 --- a/slither/detectors/functions/out_of_order_retryable.py +++ b/slither/detectors/functions/out_of_order_retryable.py @@ -101,9 +101,9 @@ Bob calls `doStuffOnL2` but the first retryable ticket calling `claim_rewards` f # include ops from internal function calls internal_ops = [] - for internal_call in node.internal_calls: - if isinstance(internal_call, Function): - internal_ops += internal_call.all_slithir_operations() + for ir in node.internal_calls: + if isinstance(ir.function, Function): + internal_ops += ir.function.all_slithir_operations() # analyze node for retryable tickets for ir in node.irs + internal_ops: diff --git a/slither/detectors/functions/protected_variable.py b/slither/detectors/functions/protected_variable.py index 579672926..b9260abd6 100644 --- a/slither/detectors/functions/protected_variable.py +++ b/slither/detectors/functions/protected_variable.py @@ -61,7 +61,9 @@ contract Buggy{ if not function_protection: self.logger.error(f"{function_sig} not found") continue - if function_protection not in function.all_internal_calls(): + if function_protection not in [ + ir.function for ir in function.all_internal_calls() + ]: info: DETECTOR_INFO = [ function, " should have ", diff --git a/slither/detectors/functions/pyth_deprecated_functions.py b/slither/detectors/functions/pyth_deprecated_functions.py new file mode 100644 index 000000000..87cff9181 --- /dev/null +++ b/slither/detectors/functions/pyth_deprecated_functions.py @@ -0,0 +1,73 @@ +from typing import List + +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) +from slither.utils.output import Output + + +class PythDeprecatedFunctions(AbstractDetector): + """ + Documentation: This detector finds deprecated Pyth function calls + """ + + ARGUMENT = "pyth-deprecated-functions" + HELP = "Detect Pyth deprecated functions" + IMPACT = DetectorClassification.MEDIUM + CONFIDENCE = DetectorClassification.HIGH + + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#pyth-deprecated-functions" + WIKI_TITLE = "Pyth deprecated functions" + WIKI_DESCRIPTION = "Detect when a Pyth deprecated function is used" + WIKI_RECOMMENDATION = ( + "Do not use deprecated Pyth functions. Visit https://api-reference.pyth.network/." + ) + + WIKI_EXPLOIT_SCENARIO = """ +```solidity +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; +import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; + +contract C { + + IPyth pyth; + + constructor(IPyth _pyth) { + pyth = _pyth; + } + + function A(bytes32 priceId) public { + PythStructs.Price memory price = pyth.getPrice(priceId); + ... + } +} +``` +The function `A` uses the deprecated `getPrice` Pyth function. +""" + + def _detect(self): + DEPRECATED_PYTH_FUNCTIONS = [ + "getValidTimePeriod", + "getEmaPrice", + "getPrice", + ] + results: List[Output] = [] + + for contract in self.compilation_unit.contracts_derived: + for target_contract, ir in contract.all_high_level_calls: + if ( + target_contract.name == "IPyth" + and ir.function_name in DEPRECATED_PYTH_FUNCTIONS + ): + info: DETECTOR_INFO = [ + "The following Pyth deprecated function is used\n\t- ", + ir.node, + "\n", + ] + + res = self.generate_result(info) + results.append(res) + + return results diff --git a/slither/detectors/functions/suicidal.py b/slither/detectors/functions/suicidal.py index f0af978ec..7c7d87f8a 100644 --- a/slither/detectors/functions/suicidal.py +++ b/slither/detectors/functions/suicidal.py @@ -59,7 +59,7 @@ Bob calls `kill` and destructs the contract.""" if func.visibility not in ["public", "external"]: return False - calls = [c.name for c in func.all_internal_calls()] + calls = [ir.function.name for ir in func.all_internal_calls()] if not ("suicide(address)" in calls or "selfdestruct(address)" in calls): return False diff --git a/slither/detectors/operations/encode_packed.py b/slither/detectors/operations/encode_packed.py index ea7b094df..b661ddcd7 100644 --- a/slither/detectors/operations/encode_packed.py +++ b/slither/detectors/operations/encode_packed.py @@ -3,14 +3,14 @@ Module detecting usage of more than one dynamic type in abi.encodePacked() argum """ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.core.declarations.solidity_variables import SolidityFunction -from slither.slithir.operations import SolidityCall +from slither.core.declarations import Contract, SolidityFunction +from slither.core.variables import Variable from slither.analyses.data_dependency.data_dependency import is_tainted from slither.core.solidity_types import ElementaryType from slither.core.solidity_types import ArrayType -def _is_dynamic_type(arg): +def _is_dynamic_type(arg: Variable): """ Args: arg (function argument) @@ -25,7 +25,7 @@ def _is_dynamic_type(arg): return False -def _detect_abi_encodePacked_collision(contract): +def _detect_abi_encodePacked_collision(contract: Contract): """ Args: contract (Contract) @@ -35,22 +35,19 @@ def _detect_abi_encodePacked_collision(contract): ret = [] # pylint: disable=too-many-nested-blocks for f in contract.functions_and_modifiers_declared: - for n in f.nodes: - for ir in n.irs: - if isinstance(ir, SolidityCall) and ir.function == SolidityFunction( - "abi.encodePacked()" - ): - dynamic_type_count = 0 - for arg in ir.arguments: - if is_tainted(arg, contract) and _is_dynamic_type(arg): - dynamic_type_count += 1 - elif dynamic_type_count > 1: - ret.append((f, n)) - dynamic_type_count = 0 - else: - dynamic_type_count = 0 - if dynamic_type_count > 1: - ret.append((f, n)) + for ir in f.solidity_calls: + if ir.function == SolidityFunction("abi.encodePacked()"): + dynamic_type_count = 0 + for arg in ir.arguments: + if is_tainted(arg, contract) and _is_dynamic_type(arg): + dynamic_type_count += 1 + elif dynamic_type_count > 1: + ret.append((f, ir.node)) + dynamic_type_count = 0 + else: + dynamic_type_count = 0 + if dynamic_type_count > 1: + ret.append((f, ir.node)) return ret diff --git a/slither/detectors/operations/low_level_calls.py b/slither/detectors/operations/low_level_calls.py index 463c74875..4925fc466 100644 --- a/slither/detectors/operations/low_level_calls.py +++ b/slither/detectors/operations/low_level_calls.py @@ -44,10 +44,9 @@ class LowLevelCalls(AbstractDetector): ) -> List[Tuple[FunctionContract, List[Node]]]: ret = [] for f in [f for f in contract.functions if contract == f.contract_declarer]: - nodes = f.nodes - assembly_nodes = [n for n in nodes if self._contains_low_level_calls(n)] - if assembly_nodes: - ret.append((f, assembly_nodes)) + low_level_nodes = [ir.node for ir in f.low_level_calls] + if low_level_nodes: + ret.append((f, low_level_nodes)) return ret def _detect(self) -> List[Output]: diff --git a/slither/detectors/reentrancy/reentrancy.py b/slither/detectors/reentrancy/reentrancy.py index 8dd9aecc0..2982801cb 100644 --- a/slither/detectors/reentrancy/reentrancy.py +++ b/slither/detectors/reentrancy/reentrancy.py @@ -145,15 +145,16 @@ class AbstractState: ) slithir_operations = [] # Add the state variables written in internal calls - for internal_call in node.internal_calls: + for ir in node.internal_calls: # Filter to Function, as internal_call can be a solidity call - if isinstance(internal_call, Function): - for internal_node in internal_call.all_nodes(): + function = ir.function + if isinstance(function, Function): + for internal_node in function.all_nodes(): for read in internal_node.state_variables_read: state_vars_read[read].add(internal_node) for write in internal_node.state_variables_written: state_vars_written[write].add(internal_node) - slithir_operations += internal_call.all_slithir_operations() + slithir_operations += function.all_slithir_operations() contains_call = False diff --git a/slither/detectors/statements/assert_state_change.py b/slither/detectors/statements/assert_state_change.py index 769d730b8..d70495365 100644 --- a/slither/detectors/statements/assert_state_change.py +++ b/slither/detectors/statements/assert_state_change.py @@ -30,22 +30,22 @@ def detect_assert_state_change( # Loop for each function and modifier. for function in contract.functions_declared + list(contract.modifiers_declared): - for node in function.nodes: + for ir_call in function.internal_calls: # Detect assert() calls - if any(c.name == "assert(bool)" for c in node.internal_calls) and ( + if ir_call.function.name == "assert(bool)" and ( # Detect direct changes to state - node.state_variables_written + ir_call.node.state_variables_written or # Detect changes to state via function calls any( ir - for ir in node.irs + for ir in ir_call.node.irs if isinstance(ir, InternalCall) and ir.function and ir.function.state_variables_written ) ): - results.append((function, node)) + results.append((function, ir_call.node)) # Return the resulting set of nodes return results diff --git a/slither/detectors/statements/chronicle_unchecked_price.py b/slither/detectors/statements/chronicle_unchecked_price.py new file mode 100644 index 000000000..47ad2ddc5 --- /dev/null +++ b/slither/detectors/statements/chronicle_unchecked_price.py @@ -0,0 +1,147 @@ +from typing import List + +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) +from slither.utils.output import Output +from slither.slithir.operations import Binary, Assignment, Unpack, SolidityCall +from slither.core.variables import Variable +from slither.core.declarations.solidity_variables import SolidityFunction +from slither.core.cfg.node import Node + + +class ChronicleUncheckedPrice(AbstractDetector): + """ + Documentation: This detector finds calls to Chronicle oracle where the returned price is not checked + https://docs.chroniclelabs.org/Resources/FAQ/Oracles#how-do-i-check-if-an-oracle-becomes-inactive-gets-deprecated + """ + + ARGUMENT = "chronicle-unchecked-price" + HELP = "Detect when Chronicle price is not checked." + IMPACT = DetectorClassification.MEDIUM + CONFIDENCE = DetectorClassification.MEDIUM + + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#chronicle-unchecked-price" + + WIKI_TITLE = "Chronicle unchecked price" + WIKI_DESCRIPTION = "Chronicle oracle is used and the price returned is not checked to be valid. For more information https://docs.chroniclelabs.org/Resources/FAQ/Oracles#how-do-i-check-if-an-oracle-becomes-inactive-gets-deprecated." + + # region wiki_exploit_scenario + WIKI_EXPLOIT_SCENARIO = """ +```solidity +contract C { + IChronicle chronicle; + + constructor(address a) { + chronicle = IChronicle(a); + } + + function bad() public { + uint256 price = chronicle.read(); + } +``` +The `bad` function gets the price from Chronicle by calling the read function however it does not check if the price is valid.""" + # endregion wiki_exploit_scenario + + WIKI_RECOMMENDATION = "Validate that the price returned by the oracle is valid." + + def _var_is_checked(self, nodes: List[Node], var_to_check: Variable) -> bool: + visited = set() + checked = False + + while nodes: + if checked: + break + next_node = nodes[0] + nodes = nodes[1:] + + for node_ir in next_node.all_slithir_operations(): + if isinstance(node_ir, Binary) and var_to_check in node_ir.read: + checked = True + break + # This case is for tryRead and tryReadWithAge + # if the isValid boolean is checked inside a require(isValid) + if ( + isinstance(node_ir, SolidityCall) + and node_ir.function + in ( + SolidityFunction("require(bool)"), + SolidityFunction("require(bool,string)"), + SolidityFunction("require(bool,error)"), + ) + and var_to_check in node_ir.read + ): + checked = True + break + + if next_node not in visited: + visited.add(next_node) + for son in next_node.sons: + if son not in visited: + nodes.append(son) + return checked + + # pylint: disable=too-many-nested-blocks,too-many-branches + def _detect(self) -> List[Output]: + results: List[Output] = [] + + for contract in self.compilation_unit.contracts_derived: + for target_contract, ir in sorted( + contract.all_high_level_calls, + key=lambda x: (x[1].node.node_id, x[1].node.function.full_name), + ): + if target_contract.name in ("IScribe", "IChronicle") and ir.function_name in ( + "read", + "tryRead", + "readWithAge", + "tryReadWithAge", + "latestAnswer", + "latestRoundData", + ): + found = False + if ir.function_name in ("read", "latestAnswer"): + # We need to iterate the IRs as we are not always sure that the following IR is the assignment + # for example in case of type conversion it isn't + for node_ir in ir.node.irs: + if isinstance(node_ir, Assignment): + possible_unchecked_variable_ir = node_ir.lvalue + found = True + break + elif ir.function_name in ("readWithAge", "tryRead", "tryReadWithAge"): + # We are interested in the first item of the tuple + # readWithAge : value + # tryRead/tryReadWithAge : isValid + for node_ir in ir.node.irs: + if isinstance(node_ir, Unpack) and node_ir.index == 0: + possible_unchecked_variable_ir = node_ir.lvalue + found = True + break + elif ir.function_name == "latestRoundData": + found = False + for node_ir in ir.node.irs: + if isinstance(node_ir, Unpack) and node_ir.index == 1: + possible_unchecked_variable_ir = node_ir.lvalue + found = True + break + + # If we did not find the variable assignment we know it's not checked + checked = ( + self._var_is_checked(ir.node.sons, possible_unchecked_variable_ir) + if found + else False + ) + + if not checked: + info: DETECTOR_INFO = [ + "Chronicle price is not checked to be valid in ", + ir.node.function, + "\n\t- ", + ir.node, + "\n", + ] + res = self.generate_result(info) + results.append(res) + + return results diff --git a/slither/detectors/statements/controlled_delegatecall.py b/slither/detectors/statements/controlled_delegatecall.py index 32e59d6eb..bf78b3bf9 100644 --- a/slither/detectors/statements/controlled_delegatecall.py +++ b/slither/detectors/statements/controlled_delegatecall.py @@ -8,20 +8,18 @@ from slither.detectors.abstract_detector import ( DetectorClassification, DETECTOR_INFO, ) -from slither.slithir.operations import LowLevelCall from slither.utils.output import Output def controlled_delegatecall(function: FunctionContract) -> List[Node]: ret = [] - for node in function.nodes: - for ir in node.irs: - if isinstance(ir, LowLevelCall) and ir.function_name in [ - "delegatecall", - "callcode", - ]: - if is_tainted(ir.destination, function.contract): - ret.append(node) + for ir in function.low_level_calls: + if ir.function_name in [ + "delegatecall", + "callcode", + ]: + if is_tainted(ir.destination, function.contract): + ret.append(ir.node) return ret diff --git a/slither/detectors/statements/divide_before_multiply.py b/slither/detectors/statements/divide_before_multiply.py index e33477135..6734fb239 100644 --- a/slither/detectors/statements/divide_before_multiply.py +++ b/slither/detectors/statements/divide_before_multiply.py @@ -56,7 +56,7 @@ def is_assert(node: Node) -> bool: # Old Solidity code where using an internal 'assert(bool)' function # While we dont check that this function is correct, we assume it is # To avoid too many FP - if "assert(bool)" in [c.full_name for c in node.internal_calls]: + if "assert(bool)" in [ir.function.full_name for ir in node.internal_calls]: return True return False diff --git a/slither/detectors/statements/pyth_unchecked.py b/slither/detectors/statements/pyth_unchecked.py new file mode 100644 index 000000000..959aee6a5 --- /dev/null +++ b/slither/detectors/statements/pyth_unchecked.py @@ -0,0 +1,79 @@ +from typing import List + +from slither.detectors.abstract_detector import ( + AbstractDetector, + DETECTOR_INFO, +) +from slither.utils.output import Output +from slither.slithir.operations import Member, Binary, Assignment + + +class PythUnchecked(AbstractDetector): + """ + Documentation: This detector finds deprecated Pyth function calls + """ + + # To be overriden in the derived class + PYTH_FUNCTIONS = [] + PYTH_FIELD = "" + + # pylint: disable=too-many-nested-blocks + def _detect(self) -> List[Output]: + results: List[Output] = [] + + for contract in self.compilation_unit.contracts_derived: + for target_contract, ir in contract.all_high_level_calls: + if target_contract.name == "IPyth" and ir.function_name in self.PYTH_FUNCTIONS: + # We know for sure the second IR in the node is an Assignment operation of the TMP variable. Example: + # Expression: price = pyth.getEmaPriceNoOlderThan(id,age) + # IRs: + # TMP_0(PythStructs.Price) = HIGH_LEVEL_CALL, dest:pyth(IPyth), function:getEmaPriceNoOlderThan, arguments:['id', 'age'] + # price(PythStructs.Price) := TMP_0(PythStructs.Price) + assert isinstance(ir.node.irs[1], Assignment) + return_variable = ir.node.irs[1].lvalue + checked = False + + possible_unchecked_variable_ir = None + nodes = ir.node.sons + visited = set() + while nodes: + if checked: + break + next_node = nodes[0] + nodes = nodes[1:] + + for node_ir in next_node.all_slithir_operations(): + # We are accessing the unchecked_var field of the returned Price struct + if ( + isinstance(node_ir, Member) + and node_ir.variable_left == return_variable + and node_ir.variable_right.name == self.PYTH_FIELD + ): + possible_unchecked_variable_ir = node_ir.lvalue + # We assume that if unchecked_var happens to be inside a binary operation is checked + if ( + isinstance(node_ir, Binary) + and possible_unchecked_variable_ir is not None + and possible_unchecked_variable_ir in node_ir.read + ): + checked = True + break + + if next_node not in visited: + visited.add(next_node) + for son in next_node.sons: + if son not in visited: + nodes.append(son) + + if not checked: + info: DETECTOR_INFO = [ + f"Pyth price {self.PYTH_FIELD} field is not checked in ", + ir.node.function, + "\n\t- ", + ir.node, + "\n", + ] + res = self.generate_result(info) + results.append(res) + + return results diff --git a/slither/detectors/statements/pyth_unchecked_confidence.py b/slither/detectors/statements/pyth_unchecked_confidence.py new file mode 100644 index 000000000..2e99851a8 --- /dev/null +++ b/slither/detectors/statements/pyth_unchecked_confidence.py @@ -0,0 +1,50 @@ +from slither.detectors.abstract_detector import DetectorClassification +from slither.detectors.statements.pyth_unchecked import PythUnchecked + + +class PythUncheckedConfidence(PythUnchecked): + """ + Documentation: This detector finds when the confidence level of a Pyth price is not checked + """ + + ARGUMENT = "pyth-unchecked-confidence" + HELP = "Detect when the confidence level of a Pyth price is not checked" + IMPACT = DetectorClassification.MEDIUM + CONFIDENCE = DetectorClassification.HIGH + + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#pyth-unchecked-confidence" + WIKI_TITLE = "Pyth unchecked confidence level" + WIKI_DESCRIPTION = "Detect when the confidence level of a Pyth price is not checked" + WIKI_RECOMMENDATION = "Check the confidence level of a Pyth price. Visit https://docs.pyth.network/price-feeds/best-practices#confidence-intervals for more information." + + WIKI_EXPLOIT_SCENARIO = """ +```solidity +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; +import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; + +contract C { + IPyth pyth; + + constructor(IPyth _pyth) { + pyth = _pyth; + } + + function bad(bytes32 id, uint256 age) public { + PythStructs.Price memory price = pyth.getEmaPriceNoOlderThan(id, age); + // Use price + } +} +``` +The function `A` uses the price without checking its confidence level. +""" + + PYTH_FUNCTIONS = [ + "getEmaPrice", + "getEmaPriceNoOlderThan", + "getEmaPriceUnsafe", + "getPrice", + "getPriceNoOlderThan", + "getPriceUnsafe", + ] + + PYTH_FIELD = "conf" diff --git a/slither/detectors/statements/pyth_unchecked_publishtime.py b/slither/detectors/statements/pyth_unchecked_publishtime.py new file mode 100644 index 000000000..e3e2010d6 --- /dev/null +++ b/slither/detectors/statements/pyth_unchecked_publishtime.py @@ -0,0 +1,52 @@ +from slither.detectors.abstract_detector import DetectorClassification +from slither.detectors.statements.pyth_unchecked import PythUnchecked + + +class PythUncheckedPublishTime(PythUnchecked): + """ + Documentation: This detector finds when the publishTime of a Pyth price is not checked + """ + + ARGUMENT = "pyth-unchecked-publishtime" + HELP = "Detect when the publishTime of a Pyth price is not checked" + IMPACT = DetectorClassification.MEDIUM + CONFIDENCE = DetectorClassification.HIGH + + WIKI = ( + "https://github.com/crytic/slither/wiki/Detector-Documentation#pyth-unchecked-publishtime" + ) + WIKI_TITLE = "Pyth unchecked publishTime" + WIKI_DESCRIPTION = "Detect when the publishTime of a Pyth price is not checked" + WIKI_RECOMMENDATION = "Check the publishTime of a Pyth price." + + WIKI_EXPLOIT_SCENARIO = """ +```solidity +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; +import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; + +contract C { + IPyth pyth; + + constructor(IPyth _pyth) { + pyth = _pyth; + } + + function bad(bytes32 id) public { + PythStructs.Price memory price = pyth.getEmaPriceUnsafe(id); + // Use price + } +} +``` +The function `A` uses the price without checking its `publishTime` coming from the `getEmaPriceUnsafe` function. +""" + + PYTH_FUNCTIONS = [ + "getEmaPrice", + # "getEmaPriceNoOlderThan", + "getEmaPriceUnsafe", + "getPrice", + # "getPriceNoOlderThan", + "getPriceUnsafe", + ] + + PYTH_FIELD = "publishTime" diff --git a/slither/detectors/statements/return_bomb.py b/slither/detectors/statements/return_bomb.py index 8b6cd07a2..6d7052cf4 100644 --- a/slither/detectors/statements/return_bomb.py +++ b/slither/detectors/statements/return_bomb.py @@ -9,7 +9,7 @@ from slither.detectors.abstract_detector import ( DetectorClassification, DETECTOR_INFO, ) -from slither.slithir.operations import LowLevelCall, HighLevelCall +from slither.slithir.operations import HighLevelCall from slither.analyses.data_dependency.data_dependency import is_tainted from slither.utils.output import Output @@ -71,34 +71,31 @@ Callee unexpectedly makes the caller OOG. def get_nodes_for_function(self, function: Function, contract: Contract) -> List[Node]: nodes = [] - for node in function.nodes: - for ir in node.irs: - if isinstance(ir, (HighLevelCall, LowLevelCall)): - if not is_tainted(ir.destination, contract): # type:ignore - # Only interested if the target address is controlled/tainted - continue - - if isinstance(ir, HighLevelCall) and isinstance(ir.function, Function): - # in normal highlevel calls return bombs are _possible_ - # if the return type is dynamic and the caller tries to copy and decode large data - has_dyn = False - if ir.function.return_type: - has_dyn = any( - self.is_dynamic_type(ty) for ty in ir.function.return_type - ) - - if not has_dyn: - continue - - # If a gas budget was specified then the - # user may not know about the return bomb - if ir.call_gas is None: - # if a gas budget was NOT specified then the caller - # may already suspect the call may spend all gas? - continue - - nodes.append(node) - # TODO: check that there is some state change after the call + + for ir in [ir for _, ir in function.high_level_calls] + function.low_level_calls: + if not is_tainted(ir.destination, contract): # type:ignore + # Only interested if the target address is controlled/tainted + continue + + if isinstance(ir, HighLevelCall) and isinstance(ir.function, Function): + # in normal highlevel calls return bombs are _possible_ + # if the return type is dynamic and the caller tries to copy and decode large data + has_dyn = False + if ir.function.return_type: + has_dyn = any(self.is_dynamic_type(ty) for ty in ir.function.return_type) + + if not has_dyn: + continue + + # If a gas budget was specified then the + # user may not know about the return bomb + if ir.call_gas is None: + # if a gas budget was NOT specified then the caller + # may already suspect the call may spend all gas? + continue + + nodes.append(ir.node) + # TODO: check that there is some state change after the call return nodes diff --git a/slither/detectors/statements/unprotected_upgradeable.py b/slither/detectors/statements/unprotected_upgradeable.py index d25aff187..aeb785da3 100644 --- a/slither/detectors/statements/unprotected_upgradeable.py +++ b/slither/detectors/statements/unprotected_upgradeable.py @@ -7,23 +7,28 @@ from slither.detectors.abstract_detector import ( DetectorClassification, DETECTOR_INFO, ) -from slither.slithir.operations import LowLevelCall, SolidityCall from slither.utils.output import Output def _can_be_destroyed(contract: Contract) -> List[Function]: targets = [] for f in contract.functions_entry_points: - for ir in f.all_slithir_operations(): - if ( - isinstance(ir, LowLevelCall) and ir.function_name in ["delegatecall", "codecall"] - ) or ( - isinstance(ir, SolidityCall) - and ir.function - in [SolidityFunction("suicide(address)"), SolidityFunction("selfdestruct(address)")] - ): + found = False + for ir in f.all_low_level_calls(): + if ir.function_name in ["delegatecall", "codecall"]: targets.append(f) + found = True break + + if not found: + for ir in f.all_solidity_calls(): + if ir.function in [ + SolidityFunction("suicide(address)"), + SolidityFunction("selfdestruct(address)"), + ]: + targets.append(f) + break + return targets @@ -35,8 +40,8 @@ def _has_initializing_protection(functions: List[Function]) -> bool: for m in f.modifiers: if m.name == "initializer": return True - for ifc in f.all_internal_calls(): - if ifc.name == "_disableInitializers": + for ir in f.all_internal_calls(): + if ir.function.name == "_disableInitializers": return True # to avoid future FPs in different modifier + function naming implementations, we can also implement a broader check for state var "_initialized" being written to in the constructor diff --git a/slither/detectors/variables/unchanged_state_variables.py b/slither/detectors/variables/unchanged_state_variables.py index 5771d9630..64c4c350f 100644 --- a/slither/detectors/variables/unchanged_state_variables.py +++ b/slither/detectors/variables/unchanged_state_variables.py @@ -92,7 +92,7 @@ class UnchangedStateVariables: variables = [] functions = [] - variables.append(c.stored_state_variables) + variables.append(c.storage_variables_ordered) functions.append(c.all_functions_called) valid_candidates: Set[StateVariable] = { diff --git a/slither/detectors/variables/var_read_using_this.py b/slither/detectors/variables/var_read_using_this.py index 537eecf8a..1e4787e36 100644 --- a/slither/detectors/variables/var_read_using_this.py +++ b/slither/detectors/variables/var_read_using_this.py @@ -7,7 +7,6 @@ from slither.detectors.abstract_detector import ( DetectorClassification, DETECTOR_INFO, ) -from slither.slithir.operations.high_level_call import HighLevelCall from slither.utils.output import Output @@ -54,13 +53,11 @@ contract C { @staticmethod def _detect_var_read_using_this(func: Function) -> List[Node]: results: List[Node] = [] - for node in func.nodes: - for ir in node.irs: - if isinstance(ir, HighLevelCall): - if ( - ir.destination == SolidityVariable("this") - and ir.is_static_call() - and ir.function.visibility == "public" - ): - results.append(node) + for _, ir in func.high_level_calls: + if ( + ir.destination == SolidityVariable("this") + and ir.is_static_call() + and ir.function.visibility == "public" + ): + results.append(ir.node) return sorted(results, key=lambda x: x.node_id) diff --git a/slither/printers/call/call_graph.py b/slither/printers/call/call_graph.py index 38225e6d7..668606760 100644 --- a/slither/printers/call/call_graph.py +++ b/slither/printers/call/call_graph.py @@ -10,6 +10,7 @@ from typing import Optional, Union, Dict, Set, Tuple, Sequence from slither.core.declarations import Contract, FunctionContract from slither.core.declarations.function import Function +from slither.slithir.operations import HighLevelCall, InternalCall from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.variables.variable import Variable from slither.printers.abstract_printer import AbstractPrinter @@ -49,26 +50,26 @@ def _node(node: str, label: Optional[str] = None) -> str: def _process_internal_call( contract: Contract, function: Function, - internal_call: Union[Function, SolidityFunction], + internal_call: InternalCall, contract_calls: Dict[Contract, Set[str]], solidity_functions: Set[str], solidity_calls: Set[str], ) -> None: - if isinstance(internal_call, (Function)): + if isinstance(internal_call.function, (Function)): contract_calls[contract].add( _edge( _function_node(contract, function), - _function_node(contract, internal_call), + _function_node(contract, internal_call.function), ) ) - elif isinstance(internal_call, (SolidityFunction)): + elif isinstance(internal_call.function, (SolidityFunction)): solidity_functions.add( - _node(_solidity_function_node(internal_call)), + _node(_solidity_function_node(internal_call.function)), ) solidity_calls.add( _edge( _function_node(contract, function), - _solidity_function_node(internal_call), + _solidity_function_node(internal_call.function), ) ) @@ -112,29 +113,29 @@ def _render_solidity_calls(solidity_functions: Set[str], solidity_calls: Set[str def _process_external_call( contract: Contract, function: Function, - external_call: Tuple[Contract, Union[Function, Variable]], + external_call: Tuple[Contract, HighLevelCall], contract_functions: Dict[Contract, Set[str]], external_calls: Set[str], all_contracts: Set[Contract], ) -> None: - external_contract, external_function = external_call + external_contract, ir = external_call if not external_contract in all_contracts: return # add variable as node to respective contract - if isinstance(external_function, (Variable)): + if isinstance(ir.function, (Variable)): contract_functions[external_contract].add( _node( - _function_node(external_contract, external_function), - external_function.name, + _function_node(external_contract, ir.function), + ir.function.name, ) ) external_calls.add( _edge( _function_node(contract, function), - _function_node(external_contract, external_function), + _function_node(external_contract, ir.function), ) ) diff --git a/slither/printers/functions/authorization.py b/slither/printers/functions/authorization.py index 32efeaabe..288392a46 100644 --- a/slither/printers/functions/authorization.py +++ b/slither/printers/functions/authorization.py @@ -19,7 +19,11 @@ class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): @staticmethod def get_msg_sender_checks(function: Function) -> List[str]: all_functions = ( - [f for f in function.all_internal_calls() if isinstance(f, Function)] + [ + ir.function + for ir in function.all_internal_calls() + if isinstance(ir.function, Function) + ] + [function] + [m for m in function.modifiers if isinstance(m, Function)] ) diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index 0c47fa0f9..7e76cec0d 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -13,6 +13,7 @@ from slither.core.declarations.solidity_variables import ( from slither.core.expressions import NewContract from slither.core.slither_core import SlitherCore from slither.core.solidity_types import TypeAlias +from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable from slither.printers.abstract_printer import AbstractPrinter @@ -140,8 +141,8 @@ def _extract_assert(contracts: List[Contract]) -> Dict[str, Dict[str, List[Dict] for contract in contracts: functions_using_assert = [] # Dict[str, List[Dict]] = defaultdict(list) for f in contract.functions_entry_points: - for v in f.all_solidity_calls(): - if v == SolidityFunction("assert(bool)"): + for ir in f.all_solidity_calls(): + if ir.function == SolidityFunction("assert(bool)"): functions_using_assert.append(_get_name(f)) break # Revert https://github.com/crytic/slither/pull/2105 until format is supported by echidna. @@ -156,7 +157,7 @@ def _extract_assert(contracts: List[Contract]) -> Dict[str, Dict[str, List[Dict] # Create a named tuple that is serialization in json def json_serializable(cls): - # pylint: disable=unnecessary-comprehension + # pylint: disable=unnecessary-comprehension,unnecessary-dunder-call # TODO: the next line is a quick workaround to prevent pylint from crashing # It can be removed once https://github.com/PyCQA/pylint/pull/3810 is merged my_super = super @@ -179,7 +180,66 @@ class ConstantValue(NamedTuple): # pylint: disable=inherit-non-class,too-few-pu type: str -def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-nested-blocks +def _extract_constant_from_read( + ir: Operation, + r: SourceMapping, + all_cst_used: List[ConstantValue], + all_cst_used_in_binary: Dict[str, List[ConstantValue]], + context_explored: Set[Node], +) -> None: + var_read = r.points_to_origin if isinstance(r, ReferenceVariable) else r + # Do not report struct_name in a.struct_name + if isinstance(ir, Member): + return + if isinstance(var_read, Variable) and var_read.is_constant: + # In case of type conversion we use the destination type + if isinstance(ir, TypeConversion): + if isinstance(ir.type, TypeAlias): + value_type = ir.type.type + else: + value_type = ir.type + else: + value_type = var_read.type + try: + value = ConstantFolding(var_read.expression, value_type).result() + all_cst_used.append(ConstantValue(str(value), str(value_type))) + except NotConstant: + pass + if isinstance(var_read, Constant): + all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type))) + if isinstance(var_read, StateVariable): + if var_read.node_initialization: + if var_read.node_initialization.irs: + if var_read.node_initialization in context_explored: + return + context_explored.add(var_read.node_initialization) + _extract_constants_from_irs( + var_read.node_initialization.irs, + all_cst_used, + all_cst_used_in_binary, + context_explored, + ) + + +def _extract_constant_from_binary( + ir: Binary, + all_cst_used: List[ConstantValue], + all_cst_used_in_binary: Dict[str, List[ConstantValue]], +): + for r in ir.read: + if isinstance(r, Constant): + all_cst_used_in_binary[str(ir.type)].append(ConstantValue(str(r.value), str(r.type))) + if isinstance(ir.variable_left, Constant) or isinstance(ir.variable_right, Constant): + if ir.lvalue: + try: + type_ = ir.lvalue.type + cst = ConstantFolding(ir.expression, type_).result() + all_cst_used.append(ConstantValue(str(cst.value), str(type_))) + except NotConstant: + pass + + +def _extract_constants_from_irs( irs: List[Operation], all_cst_used: List[ConstantValue], all_cst_used_in_binary: Dict[str, List[ConstantValue]], @@ -187,21 +247,7 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n ) -> None: for ir in irs: if isinstance(ir, Binary): - for r in ir.read: - if isinstance(r, Constant): - all_cst_used_in_binary[str(ir.type)].append( - ConstantValue(str(r.value), str(r.type)) - ) - if isinstance(ir.variable_left, Constant) or isinstance( - ir.variable_right, Constant - ): - if ir.lvalue: - try: - type_ = ir.lvalue.type - cst = ConstantFolding(ir.expression, type_).result() - all_cst_used.append(ConstantValue(str(cst.value), str(type_))) - except NotConstant: - pass + _extract_constant_from_binary(ir, all_cst_used, all_cst_used_in_binary) if isinstance(ir, TypeConversion): if isinstance(ir.variable, Constant): if isinstance(ir.type, TypeAlias): @@ -222,24 +268,9 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n except ValueError: # index could fail; should never happen in working solidity code pass for r in ir.read: - var_read = r.points_to_origin if isinstance(r, ReferenceVariable) else r - # Do not report struct_name in a.struct_name - if isinstance(ir, Member): - continue - if isinstance(var_read, Constant): - all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type))) - if isinstance(var_read, StateVariable): - if var_read.node_initialization: - if var_read.node_initialization.irs: - if var_read.node_initialization in context_explored: - continue - context_explored.add(var_read.node_initialization) - _extract_constants_from_irs( - var_read.node_initialization.irs, - all_cst_used, - all_cst_used_in_binary, - context_explored, - ) + _extract_constant_from_read( + ir, r, all_cst_used, all_cst_used_in_binary, context_explored + ) def _extract_constants( diff --git a/slither/printers/summary/modifier_calls.py b/slither/printers/summary/modifier_calls.py index cd6c4062e..225376a3c 100644 --- a/slither/printers/summary/modifier_calls.py +++ b/slither/printers/summary/modifier_calls.py @@ -29,12 +29,12 @@ class Modifiers(AbstractPrinter): table = MyPrettyTable(["Function", "Modifiers"]) for function in contract.functions: modifiers = function.modifiers - for call in function.all_internal_calls(): - if isinstance(call, Function): - modifiers += call.modifiers - for (_, call) in function.all_library_calls(): - if isinstance(call, Function): - modifiers += call.modifiers + for ir in function.all_internal_calls(): + if isinstance(ir.function, Function): + modifiers += ir.function.modifiers + for ir in function.all_library_calls(): + if isinstance(ir.function, Function): + modifiers += ir.function.modifiers table.add_row([function.name, sorted([m.name for m in set(modifiers)])]) txt += "\n" + str(table) self.info(txt) diff --git a/slither/printers/summary/require_calls.py b/slither/printers/summary/require_calls.py index 7823de160..ae79e9ed6 100644 --- a/slither/printers/summary/require_calls.py +++ b/slither/printers/summary/require_calls.py @@ -11,6 +11,7 @@ require_or_assert = [ SolidityFunction("assert(bool)"), SolidityFunction("require(bool)"), SolidityFunction("require(bool,string)"), + SolidityFunction("require(bool,error)"), ] diff --git a/slither/printers/summary/variable_order.py b/slither/printers/summary/variable_order.py index 0d8ce2612..fb19e3985 100644 --- a/slither/printers/summary/variable_order.py +++ b/slither/printers/summary/variable_order.py @@ -27,10 +27,17 @@ class VariableOrder(AbstractPrinter): for contract in self.slither.contracts_derived: txt += f"\n{contract.name}:\n" - table = MyPrettyTable(["Name", "Type", "Slot", "Offset"]) - for variable in contract.stored_state_variables_ordered: + table = MyPrettyTable(["Name", "Type", "Slot", "Offset", "State"]) + for variable in contract.storage_variables_ordered: slot, offset = contract.compilation_unit.storage_layout_of(contract, variable) - table.add_row([variable.canonical_name, str(variable.type), slot, offset]) + table.add_row( + [variable.canonical_name, str(variable.type), slot, offset, "Storage"] + ) + for variable in contract.transient_variables_ordered: + slot, offset = contract.compilation_unit.storage_layout_of(contract, variable) + table.add_row( + [variable.canonical_name, str(variable.type), slot, offset, "Transient"] + ) all_tables.append((contract.name, table)) txt += str(table) + "\n" diff --git a/slither/printers/summary/when_not_paused.py b/slither/printers/summary/when_not_paused.py index aaeeeacec..fc96268ef 100644 --- a/slither/printers/summary/when_not_paused.py +++ b/slither/printers/summary/when_not_paused.py @@ -11,8 +11,8 @@ from slither.utils.myprettytable import MyPrettyTable def _use_modifier(function: Function, modifier_name: str = "whenNotPaused") -> bool: - for internal_call in function.all_internal_calls(): - if isinstance(internal_call, SolidityFunction): + for ir in function.all_internal_calls(): + if isinstance(ir, function, SolidityFunction): continue if any(modifier.name == modifier_name for modifier in function.modifiers): return True diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index 7d8aa543b..cfeb0fe39 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -36,7 +36,6 @@ from slither.core.solidity_types.elementary_type import ( ) from slither.core.solidity_types.type import Type from slither.core.solidity_types.type_alias import TypeAliasTopLevel, TypeAlias -from slither.core.variables.function_type_variable import FunctionTypeVariable from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable from slither.slithir.exceptions import SlithIRError @@ -81,7 +80,7 @@ from slither.slithir.tmp_operations.tmp_new_elementary_type import TmpNewElement from slither.slithir.tmp_operations.tmp_new_structure import TmpNewStructure from slither.slithir.variables import Constant, ReferenceVariable, TemporaryVariable from slither.slithir.variables import TupleVariable -from slither.utils.function import get_function_id +from slither.utils.function import get_function_id, get_event_id from slither.utils.type import export_nested_types_from_variable from slither.utils.using_for import USING_FOR from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR @@ -279,20 +278,6 @@ def is_temporary(ins: Operation) -> bool: ) -def _make_function_type(func: Function) -> FunctionType: - parameters = [] - returns = [] - for parameter in func.parameters: - v = FunctionTypeVariable() - v.name = parameter.name - parameters.append(v) - for return_var in func.returns: - v = FunctionTypeVariable() - v.name = return_var.name - returns.append(v) - return FunctionType(parameters, returns) - - # endregion ################################################################################### ################################################################################### @@ -793,12 +778,29 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo assignment.set_node(ir.node) assignment.lvalue.set_type(ElementaryType("bytes4")) return assignment - if ir.variable_right == "selector" and isinstance( - ir.variable_left.type, (Function) + if ir.variable_right == "selector" and isinstance(ir.variable_left, (Event)): + # the event selector returns a bytes32, which is different from the error/function selector + # which returns a bytes4 + assignment = Assignment( + ir.lvalue, + Constant( + str(get_event_id(ir.variable_left.full_name)), ElementaryType("bytes32") + ), + ElementaryType("bytes32"), + ) + assignment.set_expression(ir.expression) + assignment.set_node(ir.node) + assignment.lvalue.set_type(ElementaryType("bytes32")) + return assignment + if ir.variable_right == "selector" and ( + isinstance(ir.variable_left.type, (Function)) ): assignment = Assignment( ir.lvalue, - Constant(str(get_function_id(ir.variable_left.type.full_name))), + Constant( + str(get_function_id(ir.variable_left.type.full_name)), + ElementaryType("bytes4"), + ), ElementaryType("bytes4"), ) assignment.set_expression(ir.expression) @@ -826,10 +828,9 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo targeted_function = next( (x for x in ir_func.contract.functions if x.name == str(ir.variable_right)) ) - t = _make_function_type(targeted_function) - ir.lvalue.set_type(t) + ir.lvalue.set_type(targeted_function) elif isinstance(left, (Variable, SolidityVariable)): - t = ir.variable_left.type + t = left.type elif isinstance(left, (Contract, Enum, Structure)): t = UserDefinedType(left) # can be None due to temporary operation @@ -846,10 +847,10 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo ir.lvalue.set_type(elems[elem].type) else: assert isinstance(type_t, Contract) - # Allow type propagtion as a Function + # Allow type propagation as a Function # Only for reference variables # This allows to track the selector keyword - # We dont need to check for function collision, as solc prevents the use of selector + # We don't need to check for function collision, as solc prevents the use of selector # if there are multiple functions with the same name f = next( (f for f in type_t.functions if f.name == ir.variable_right), @@ -858,7 +859,7 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo if f: ir.lvalue.set_type(f) else: - # Allow propgation for variable access through contract's name + # Allow propagation for variable access through contract's name # like Base_contract.my_variable v = next( ( diff --git a/slither/slithir/operations/assignment.py b/slither/slithir/operations/assignment.py index 1f29ceb7b..ab6637faa 100644 --- a/slither/slithir/operations/assignment.py +++ b/slither/slithir/operations/assignment.py @@ -45,10 +45,19 @@ class Assignment(OperationWithLValue): def __str__(self) -> str: lvalue = self.lvalue + + # When rvalues are functions, we want to properly display their return type + # Fix: https://github.com/crytic/slither/issues/2266 + if isinstance(self.rvalue.type, list): + rvalue_type = ",".join(f"{rvalue_type}" for rvalue_type in self.rvalue.type) + else: + rvalue_type = f"{self.rvalue.type}" + assert lvalue if lvalue and isinstance(lvalue, ReferenceVariable): points = lvalue.points_to while isinstance(points, ReferenceVariable): points = points.points_to - return f"{lvalue}({lvalue.type}) (->{points}) := {self.rvalue}({self.rvalue.type})" - return f"{lvalue}({lvalue.type}) := {self.rvalue}({self.rvalue.type})" + return f"{lvalue}({lvalue.type}) (->{points}) := {self.rvalue}({rvalue_type})" + + return f"{lvalue}({lvalue.type}) := {self.rvalue}({rvalue_type})" diff --git a/slither/slithir/operations/member.py b/slither/slithir/operations/member.py index 0942813cf..55979572c 100644 --- a/slither/slithir/operations/member.py +++ b/slither/slithir/operations/member.py @@ -1,5 +1,5 @@ from typing import List, Union -from slither.core.declarations import Contract, Function +from slither.core.declarations import Contract, Function, Event from slither.core.declarations.custom_error import CustomError from slither.core.declarations.enum import Enum from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder @@ -33,14 +33,29 @@ class Member(OperationWithLValue): # Can be an ElementaryType because of bytes.concat, string.concat assert is_valid_rvalue(variable_left) or isinstance( variable_left, - (Contract, Enum, Function, CustomError, SolidityImportPlaceHolder, ElementaryType), + ( + Contract, + Enum, + Function, + Event, + CustomError, + SolidityImportPlaceHolder, + ElementaryType, + ), ) assert isinstance(variable_right, Constant) assert isinstance(result, ReferenceVariable) super().__init__() self._variable_left: Union[ - RVALUE, Contract, Enum, Function, CustomError, SolidityImportPlaceHolder, ElementaryType + RVALUE, + Contract, + Enum, + Function, + Event, + CustomError, + SolidityImportPlaceHolder, + ElementaryType, ] = variable_left self._variable_right = variable_right self._lvalue = result diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 1ccdc5760..06fc03b7a 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -350,7 +350,7 @@ class ContractSolc(CallerContextExpression): if v.visibility != "private" } ) - self._contract.add_variables_ordered( + self._contract.add_state_variables_ordered( [ var for var in father.state_variables_ordered @@ -370,7 +370,7 @@ class ContractSolc(CallerContextExpression): if var_parser.reference_id is not None: self._contract.state_variables_by_ref_id[var_parser.reference_id] = var self._contract.variables_as_dict[var.name] = var - self._contract.add_variables_ordered([var]) + self._contract.add_state_variables_ordered([var]) def _parse_modifier(self, modifier_data: Dict) -> None: modif = Modifier(self._contract.compilation_unit) diff --git a/slither/solc_parsing/variables/state_variable.py b/slither/solc_parsing/variables/state_variable.py index a9c0ff730..227a84c61 100644 --- a/slither/solc_parsing/variables/state_variable.py +++ b/slither/solc_parsing/variables/state_variable.py @@ -13,3 +13,18 @@ class StateVariableSolc(VariableDeclarationSolc): # Todo: Not sure how to overcome this with mypy assert isinstance(self._variable, StateVariable) return self._variable + + def _analyze_variable_attributes(self, attributes: Dict) -> None: + """ + Variable Location + Can be default or transient + """ + if "storageLocation" in attributes: + self.underlying_variable.set_location(attributes["storageLocation"]) + else: + # We don't have to support legacy ast + # as transient location was added in 0.8.28 + # and we know it must be default + self.underlying_variable.set_location("default") + + super()._analyze_variable_attributes(attributes) diff --git a/slither/tools/flattening/flattening.py b/slither/tools/flattening/flattening.py index 9cb2abc3f..182333d3f 100644 --- a/slither/tools/flattening/flattening.py +++ b/slither/tools/flattening/flattening.py @@ -294,10 +294,9 @@ class Flattening: self._export_list_used_contracts(inherited, exported, list_contract, list_top_level) # Find all the external contracts called - externals = contract.all_library_calls + contract.all_high_level_calls - # externals is a list of (contract, function) + # High level calls already includes library calls # We also filter call to itself to avoid infilite loop - externals = list({e[0] for e in externals if e[0] != contract}) + externals = list({e[0] for e in contract.all_high_level_calls if e[0] != contract}) for inherited in externals: self._export_list_used_contracts(inherited, exported, list_contract, list_top_level) diff --git a/slither/tools/possible_paths/possible_paths.py b/slither/tools/possible_paths/possible_paths.py index 6e836e76a..15218a872 100644 --- a/slither/tools/possible_paths/possible_paths.py +++ b/slither/tools/possible_paths/possible_paths.py @@ -123,10 +123,14 @@ def __find_target_paths( # Find all function calls in this function (except for low level) called_functions_list = [ - f for (_, f) in function.high_level_calls if isinstance(f, Function) + ir.function + for _, ir in function.high_level_calls + if isinstance(ir.function, Function) + ] + called_functions_list += [ir.function for ir in function.library_calls] + called_functions_list += [ + ir.function for ir in function.internal_calls if isinstance(ir.function, Function) ] - called_functions_list += [f for (_, f) in function.library_calls] - called_functions_list += [f for f in function.internal_calls if isinstance(f, Function)] called_functions = set(called_functions_list) # If any of our target functions are reachable from this function, it's a result. diff --git a/slither/tools/upgradeability/checks/variable_initialization.py b/slither/tools/upgradeability/checks/variable_initialization.py index b86036c87..047c652dc 100644 --- a/slither/tools/upgradeability/checks/variable_initialization.py +++ b/slither/tools/upgradeability/checks/variable_initialization.py @@ -43,7 +43,7 @@ Using initialize functions to write initial values in state variables. def _check(self) -> List[Output]: results = [] - for s in self.contract.stored_state_variables_ordered: + for s in self.contract.storage_variables_ordered: if s.initialized: info: CHECK_INFO = [s, " is a state variable with an initial value.\n"] json = self.generate_result(info) diff --git a/slither/tools/upgradeability/checks/variables_order.py b/slither/tools/upgradeability/checks/variables_order.py index 8d525a6dd..8f5017d74 100644 --- a/slither/tools/upgradeability/checks/variables_order.py +++ b/slither/tools/upgradeability/checks/variables_order.py @@ -115,29 +115,43 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s def _check(self) -> List[Output]: contract1 = self._contract1() contract2 = self._contract2() - order1 = contract1.stored_state_variables_ordered - order2 = contract2.stored_state_variables_ordered results: List[Output] = [] - for idx, _ in enumerate(order1): - if len(order2) <= idx: - # Handle by MissingVariable - return results - - variable1 = order1[idx] - variable2 = order2[idx] - if (variable1.name != variable2.name) or (variable1.type != variable2.type): - info: CHECK_INFO = [ - "Different variables between ", - contract1, - " and ", - contract2, - "\n", - ] - info += ["\t ", variable1, "\n"] - info += ["\t ", variable2, "\n"] - json = self.generate_result(info) - results.append(json) + + def _check_internal( + contract1: Contract, contract2: Contract, results: List[Output], is_transient: bool + ): + if is_transient: + order1 = contract1.transient_variables_ordered + order2 = contract2.transient_variables_ordered + else: + order1 = contract1.storage_variables_ordered + order2 = contract2.storage_variables_ordered + + for idx, _ in enumerate(order1): + if len(order2) <= idx: + # Handle by MissingVariable + return + + variable1 = order1[idx] + variable2 = order2[idx] + if (variable1.name != variable2.name) or (variable1.type != variable2.type): + info: CHECK_INFO = [ + "Different variables between ", + contract1, + " and ", + contract2, + "\n", + ] + info += ["\t ", variable1, "\n"] + info += ["\t ", variable2, "\n"] + json = self.generate_result(info) + results.append(json) + + # Checking state variables with storage location + _check_internal(contract1, contract2, results, False) + # Checking state variables with transient location + _check_internal(contract1, contract2, results, True) return results @@ -236,22 +250,35 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s def _check(self) -> List[Output]: contract1 = self._contract1() contract2 = self._contract2() - order1 = contract1.stored_state_variables_ordered - order2 = contract2.stored_state_variables_ordered - results = [] + results: List[Output] = [] - if len(order2) <= len(order1): - return [] + def _check_internal( + contract1: Contract, contract2: Contract, results: List[Output], is_transient: bool + ): + if is_transient: + order1 = contract1.transient_variables_ordered + order2 = contract2.transient_variables_ordered + else: + order1 = contract1.storage_variables_ordered + order2 = contract2.storage_variables_ordered - idx = len(order1) + if len(order2) <= len(order1): + return - while idx < len(order2): - variable2 = order2[idx] - info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"] - json = self.generate_result(info) - results.append(json) - idx = idx + 1 + idx = len(order1) + + while idx < len(order2): + variable2 = order2[idx] + info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"] + json = self.generate_result(info) + results.append(json) + idx = idx + 1 + + # Checking state variables with storage location + _check_internal(contract1, contract2, results, False) + # Checking state variables with transient location + _check_internal(contract1, contract2, results, True) return results diff --git a/slither/utils/function.py b/slither/utils/function.py index 34e6f221b..64c098bfd 100644 --- a/slither/utils/function.py +++ b/slither/utils/function.py @@ -12,3 +12,16 @@ def get_function_id(sig: str) -> int: digest = keccak.new(digest_bits=256) digest.update(sig.encode("utf-8")) return int("0x" + digest.hexdigest()[:8], 16) + + +def get_event_id(sig: str) -> int: + """' + Return the event id of the given signature + Args: + sig (str) + Return: + (int) + """ + digest = keccak.new(digest_bits=256) + digest.update(sig.encode("utf-8")) + return int("0x" + digest.hexdigest(), 16) diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 22461dbcf..bedf08d4b 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -80,8 +80,8 @@ def compare( tainted-contracts: list[TaintedExternalContract] """ - order_vars1 = v1.stored_state_variables_ordered - order_vars2 = v2.stored_state_variables_ordered + order_vars1 = v1.storage_variables_ordered + v1.transient_variables_ordered + order_vars2 = v2.storage_variables_ordered + v2.transient_variables_ordered func_sigs1 = [function.solidity_signature for function in v1.functions] func_sigs2 = [function.solidity_signature for function in v2.functions] @@ -123,7 +123,9 @@ def compare( ): continue modified_calls = [ - func for func in new_modified_functions if func in function.internal_calls + func + for func in new_modified_functions + if func in [ir.function for ir in function.internal_calls] ] tainted_vars = [ var @@ -179,7 +181,8 @@ def tainted_external_contracts(funcs: List[Function]) -> List[TaintedExternalCon tainted_list: list[TaintedExternalContract] = [] for func in funcs: - for contract, target in func.all_high_level_calls(): + for contract, ir in func.all_high_level_calls(): + target = ir.function if contract.is_library: # Not interested in library calls continue @@ -254,7 +257,11 @@ def tainted_inheriting_contracts( new_taint = TaintedExternalContract(c) for f in c.functions_declared: # Search for functions that call an inherited tainted function or access an inherited tainted variable - internal_calls = [c for c in f.all_internal_calls() if isinstance(c, Function)] + internal_calls = [ + ir.function + for ir in f.all_internal_calls() + if isinstance(ir.function, Function) + ] if any( call.canonical_name == t.canonical_name for t in tainted.tainted_functions @@ -299,8 +306,8 @@ def get_missing_vars(v1: Contract, v2: Contract) -> List[StateVariable]: List of StateVariables from v1 missing in v2 """ results = [] - order_vars1 = v1.stored_state_variables_ordered - order_vars2 = v2.stored_state_variables_ordered + order_vars1 = v1.storage_variables_ordered + v1.transient_variables_ordered + order_vars2 = v2.storage_variables_ordered + v2.transient_variables_ordered if len(order_vars2) < len(order_vars1): for variable in order_vars1: if variable.name not in [v.name for v in order_vars2]: diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index b1fa570c6..ddadb77a1 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -13,7 +13,9 @@ from slither.core.expressions import ( TupleExpression, TypeConversion, CallExpression, + MemberAccess, ) +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression from slither.core.variables import Variable from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int from slither.visitors.expression.expression import ExpressionVisitor @@ -27,7 +29,13 @@ class NotConstant(Exception): KEY = "ConstantFolding" CONSTANT_TYPES_OPERATIONS = Union[ - Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, ] @@ -69,6 +77,9 @@ class ConstantFolding(ExpressionVisitor): # pylint: disable=import-outside-toplevel def _post_identifier(self, expression: Identifier) -> None: from slither.core.declarations.solidity_variables import SolidityFunction + from slither.core.declarations.enum import Enum + from slither.core.solidity_types.type_alias import TypeAlias + from slither.core.declarations.contract import Contract if isinstance(expression.value, Variable): if expression.value.is_constant: @@ -77,7 +88,14 @@ class ConstantFolding(ExpressionVisitor): # Everything outside of literal if isinstance( expr, - (BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ( + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, + ), ): cf = ConstantFolding(expr, self._type) expr = cf.result() @@ -88,7 +106,12 @@ class ConstantFolding(ExpressionVisitor): elif isinstance(expression.value, SolidityFunction): set_val(expression, expression.value) else: - raise NotConstant + # Enum: We don't want to raise an error for a direct access to an Enum as they can be converted to a constant value + # We can't handle it here because we don't have the field accessed so we do it in _post_member_access + # TypeAlias: Support when a .wrap() is done with a constant + # Contract: Support when a constatn is use from a different contract + if not isinstance(expression.value, (Enum, TypeAlias, Contract)): + raise NotConstant # pylint: disable=too-many-branches,too-many-statements def _post_binary_operation(self, expression: BinaryOperation) -> None: @@ -96,12 +119,28 @@ class ConstantFolding(ExpressionVisitor): expression_right = expression.expression_right if not isinstance( expression_left, - (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ( + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, + ), ): raise NotConstant if not isinstance( expression_right, - (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ( + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, + ), ): raise NotConstant left = get_val(expression_left) @@ -205,6 +244,34 @@ class ConstantFolding(ExpressionVisitor): raise NotConstant def _post_call_expression(self, expression: expressions.CallExpression) -> None: + from slither.core.declarations.solidity_variables import SolidityFunction + from slither.core.declarations.enum import Enum + from slither.core.solidity_types import TypeAlias + + # pylint: disable=too-many-boolean-expressions + if ( + isinstance(expression.called, Identifier) + and expression.called.value == SolidityFunction("type()") + and len(expression.arguments) == 1 + and ( + isinstance(expression.arguments[0], ElementaryTypeNameExpression) + or isinstance(expression.arguments[0], Identifier) + and isinstance(expression.arguments[0].value, Enum) + ) + ): + # Returning early to support type(ElemType).max/min or type(MyEnum).max/min + return + if ( + isinstance(expression.called.expression, Identifier) + and isinstance(expression.called.expression.value, TypeAlias) + and isinstance(expression.called, MemberAccess) + and expression.called.member_name == "wrap" + and len(expression.arguments) == 1 + ): + # Handle constants in .wrap of user defined type + set_val(expression, get_val(expression.arguments[0])) + return + called = get_val(expression.called) args = [get_val(arg) for arg in expression.arguments] if called.name == "keccak256(bytes)": @@ -220,12 +287,104 @@ class ConstantFolding(ExpressionVisitor): def _post_elementary_type_name_expression( self, expression: expressions.ElementaryTypeNameExpression ) -> None: - raise NotConstant + # We don't have to raise an exception to support type(uint112).max or similar + pass def _post_index_access(self, expression: expressions.IndexAccess) -> None: raise NotConstant + # pylint: disable=too-many-locals def _post_member_access(self, expression: expressions.MemberAccess) -> None: + from slither.core.declarations import ( + SolidityFunction, + Contract, + EnumContract, + EnumTopLevel, + Enum, + ) + from slither.core.solidity_types import UserDefinedType, TypeAlias + + # pylint: disable=too-many-nested-blocks + if isinstance(expression.expression, CallExpression) and expression.member_name in [ + "min", + "max", + ]: + if isinstance(expression.expression.called, Identifier): + if expression.expression.called.value == SolidityFunction("type()"): + assert len(expression.expression.arguments) == 1 + type_expression_found = expression.expression.arguments[0] + type_found: Union[ElementaryType, UserDefinedType] + if isinstance(type_expression_found, ElementaryTypeNameExpression): + type_expression_found_type = type_expression_found.type + assert isinstance(type_expression_found_type, ElementaryType) + type_found = type_expression_found_type + value = ( + type_found.max if expression.member_name == "max" else type_found.min + ) + set_val(expression, value) + return + # type(enum).max/min + # Case when enum is in another contract e.g. type(C.E).max + if isinstance(type_expression_found, MemberAccess): + contract = type_expression_found.expression.value + assert isinstance(contract, Contract) + for enum in contract.enums: + if enum.name == type_expression_found.member_name: + type_found_in_expression = enum + type_found = UserDefinedType(enum) + break + else: + assert isinstance(type_expression_found, Identifier) + type_found_in_expression = type_expression_found.value + assert isinstance(type_found_in_expression, (EnumContract, EnumTopLevel)) + type_found = UserDefinedType(type_found_in_expression) + value = ( + type_found_in_expression.max + if expression.member_name == "max" + else type_found_in_expression.min + ) + set_val(expression, value) + return + elif isinstance(expression.expression, Identifier) and isinstance( + expression.expression.value, Enum + ): + # Handle direct access to enum field + set_val(expression, expression.expression.value.values.index(expression.member_name)) + return + elif isinstance(expression.expression, Identifier) and isinstance( + expression.expression.value, TypeAlias + ): + # User defined type .wrap call handled in _post_call_expression + return + elif ( + isinstance(expression.expression.value, Contract) + and expression.member_name in expression.expression.value.variables_as_dict + and expression.expression.value.variables_as_dict[expression.member_name].is_constant + ): + # Handles when a constant is accessed on another contract + variables = expression.expression.value.variables_as_dict + if isinstance(variables[expression.member_name].expression, MemberAccess): + self._post_member_access(variables[expression.member_name].expression) + set_val(expression, get_val(variables[expression.member_name].expression)) + return + + # If the variable is a Literal we convert its value to int + if isinstance(variables[expression.member_name].expression, Literal): + value = convert_string_to_int( + variables[expression.member_name].expression.converted_value + ) + # If the variable is a UnaryOperation we need convert its value to int + # and replacing possible spaces + elif isinstance(variables[expression.member_name].expression, UnaryOperation): + value = convert_string_to_int( + str(variables[expression.member_name].expression).replace(" ", "") + ) + else: + value = variables[expression.member_name].expression + + set_val(expression, value) + return + raise NotConstant def _post_new_array(self, expression: expressions.NewArray) -> None: @@ -272,6 +431,7 @@ class ConstantFolding(ExpressionVisitor): TupleExpression, TypeConversion, CallExpression, + MemberAccess, ), ): raise NotConstant diff --git a/slither/vyper_parsing/declarations/contract.py b/slither/vyper_parsing/declarations/contract.py index 64fab1c54..ddf516150 100644 --- a/slither/vyper_parsing/declarations/contract.py +++ b/slither/vyper_parsing/declarations/contract.py @@ -470,7 +470,7 @@ class ContractVyper: # pylint: disable=too-many-instance-attributes assert var.name self._contract.variables_as_dict[var.name] = var - self._contract.add_variables_ordered([var]) + self._contract.add_state_variables_ordered([var]) # Interfaces can refer to constants self._contract.file_scope.variables[var.name] = var diff --git a/tests/e2e/detectors/snapshots/detectors__detector_ChainlinkFeedRegistry_0_8_20_chainlink_feed_registry_sol__0.txt b/tests/e2e/detectors/snapshots/detectors__detector_ChainlinkFeedRegistry_0_8_20_chainlink_feed_registry_sol__0.txt new file mode 100644 index 000000000..6b7653ed0 --- /dev/null +++ b/tests/e2e/detectors/snapshots/detectors__detector_ChainlinkFeedRegistry_0_8_20_chainlink_feed_registry_sol__0.txt @@ -0,0 +1,3 @@ +The Chainlink Feed Registry is used in the A contract. It's only available on Ethereum Mainnet, consider to not use it if the contract needs to be deployed on other chains. + - (None,price,None,None,None) = registry.latestRoundData(base,quote) (tests/e2e/detectors/test_data/chainlink-feed-registry/0.8.20/chainlink_feed_registry.sol#25) + diff --git a/tests/e2e/detectors/snapshots/detectors__detector_ChronicleUncheckedPrice_0_8_20_chronicle_unchecked_price_sol__0.txt b/tests/e2e/detectors/snapshots/detectors__detector_ChronicleUncheckedPrice_0_8_20_chronicle_unchecked_price_sol__0.txt new file mode 100644 index 000000000..6ddbfa4e5 --- /dev/null +++ b/tests/e2e/detectors/snapshots/detectors__detector_ChronicleUncheckedPrice_0_8_20_chronicle_unchecked_price_sol__0.txt @@ -0,0 +1,18 @@ +Chronicle price is not checked to be valid in C.bad2() (tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol#74-76) + - (price,None) = chronicle.readWithAge() (tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol#75) + +Chronicle price is not checked to be valid in C.bad() (tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol#65-67) + - price = chronicle.read() (tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol#66) + +Chronicle price is not checked to be valid in C.bad5() (tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol#101-103) + - price = scribe.latestAnswer() (tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol#102) + +Chronicle price is not checked to be valid in C.bad4() (tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol#92-94) + - (isValid,price,None) = chronicle.tryReadWithAge() (tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol#93) + +Chronicle price is not checked to be valid in C.bad3() (tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol#83-85) + - (isValid,price) = chronicle.tryRead() (tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol#84) + +Chronicle price is not checked to be valid in C.bad6() (tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol#110-112) + - (None,price,None,None,None) = scribe.latestRoundData() (tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol#111) + diff --git a/tests/e2e/detectors/snapshots/detectors__detector_GelatoUnprotectedRandomness_0_8_20_gelato_unprotected_randomness_sol__0.txt b/tests/e2e/detectors/snapshots/detectors__detector_GelatoUnprotectedRandomness_0_8_20_gelato_unprotected_randomness_sol__0.txt new file mode 100644 index 000000000..aee2ea4dd --- /dev/null +++ b/tests/e2e/detectors/snapshots/detectors__detector_GelatoUnprotectedRandomness_0_8_20_gelato_unprotected_randomness_sol__0.txt @@ -0,0 +1,6 @@ +C.bad() (tests/e2e/detectors/test_data/gelato-unprotected-randomness/0.8.20/gelato_unprotected_randomness.sol#42-44) is unprotected and request randomness from Gelato VRF + - id = _requestRandomness(abi.encode(msg.sender)) (tests/e2e/detectors/test_data/gelato-unprotected-randomness/0.8.20/gelato_unprotected_randomness.sol#43) + +C.good2() (tests/e2e/detectors/test_data/gelato-unprotected-randomness/0.8.20/gelato_unprotected_randomness.sol#51-54) is unprotected and request randomness from Gelato VRF + - id = _requestRandomness(abi.encode(msg.sender)) (tests/e2e/detectors/test_data/gelato-unprotected-randomness/0.8.20/gelato_unprotected_randomness.sol#53) + diff --git a/tests/e2e/detectors/snapshots/detectors__detector_OptimismDeprecation_0_8_20_optimism_deprecation_sol__0.txt b/tests/e2e/detectors/snapshots/detectors__detector_OptimismDeprecation_0_8_20_optimism_deprecation_sol__0.txt new file mode 100644 index 000000000..f6f4dccba --- /dev/null +++ b/tests/e2e/detectors/snapshots/detectors__detector_OptimismDeprecation_0_8_20_optimism_deprecation_sol__0.txt @@ -0,0 +1,4 @@ +A deprecated Optimism predeploy or function is used in the Test contract. + - OPT_GAS.scalar() (tests/e2e/detectors/test_data/optimism-deprecation/0.8.20/optimism_deprecation.sol#15) + - L1_BLOCK_NUMBER.q() (tests/e2e/detectors/test_data/optimism-deprecation/0.8.20/optimism_deprecation.sol#19) + diff --git a/tests/e2e/detectors/snapshots/detectors__detector_PythDeprecatedFunctions_0_8_20_pyth_deprecated_functions_sol__0.txt b/tests/e2e/detectors/snapshots/detectors__detector_PythDeprecatedFunctions_0_8_20_pyth_deprecated_functions_sol__0.txt new file mode 100644 index 000000000..4cc23d213 --- /dev/null +++ b/tests/e2e/detectors/snapshots/detectors__detector_PythDeprecatedFunctions_0_8_20_pyth_deprecated_functions_sol__0.txt @@ -0,0 +1,3 @@ +The following Pyth deprecated function is used + - price = pyth.getPrice(priceId) (tests/e2e/detectors/test_data/pyth-deprecated-functions/0.8.20/pyth_deprecated_functions.sol#23) + diff --git a/tests/e2e/detectors/snapshots/detectors__detector_PythUncheckedConfidence_0_8_20_pyth_unchecked_confidence_sol__0.txt b/tests/e2e/detectors/snapshots/detectors__detector_PythUncheckedConfidence_0_8_20_pyth_unchecked_confidence_sol__0.txt new file mode 100644 index 000000000..ae0dc2ae2 --- /dev/null +++ b/tests/e2e/detectors/snapshots/detectors__detector_PythUncheckedConfidence_0_8_20_pyth_unchecked_confidence_sol__0.txt @@ -0,0 +1,3 @@ +Pyth price conf field is not checked in C.bad(bytes32,uint256) (tests/e2e/detectors/test_data/pyth-unchecked-confidence/0.8.20/pyth_unchecked_confidence.sol#171-175) + - price = pyth.getEmaPriceNoOlderThan(id,age) (tests/e2e/detectors/test_data/pyth-unchecked-confidence/0.8.20/pyth_unchecked_confidence.sol#172) + diff --git a/tests/e2e/detectors/snapshots/detectors__detector_PythUncheckedPublishTime_0_8_20_pyth_unchecked_publishtime_sol__0.txt b/tests/e2e/detectors/snapshots/detectors__detector_PythUncheckedPublishTime_0_8_20_pyth_unchecked_publishtime_sol__0.txt new file mode 100644 index 000000000..cb331c8d5 --- /dev/null +++ b/tests/e2e/detectors/snapshots/detectors__detector_PythUncheckedPublishTime_0_8_20_pyth_unchecked_publishtime_sol__0.txt @@ -0,0 +1,3 @@ +Pyth price publishTime field is not checked in C.bad(bytes32) (tests/e2e/detectors/test_data/pyth-unchecked-publishtime/0.8.20/pyth_unchecked_publishtime.sol#171-175) + - price = pyth.getEmaPriceUnsafe(id) (tests/e2e/detectors/test_data/pyth-unchecked-publishtime/0.8.20/pyth_unchecked_publishtime.sol#172) + diff --git a/tests/e2e/detectors/test_data/chainlink-feed-registry/0.8.20/chainlink_feed_registry.sol b/tests/e2e/detectors/test_data/chainlink-feed-registry/0.8.20/chainlink_feed_registry.sol new file mode 100644 index 000000000..cf5d1ad4d --- /dev/null +++ b/tests/e2e/detectors/test_data/chainlink-feed-registry/0.8.20/chainlink_feed_registry.sol @@ -0,0 +1,37 @@ +interface FeedRegistryInterface { + function latestRoundData( + address base, + address quote + ) external view returns (uint80 roundId, int256 answer, uint256 startedAt, uint256 updatedAt, uint80 answeredInRound); +} + +interface MyInterface { + function latestRoundData( + address base, + address quote + ) external view returns (uint80 roundId, int256 answer, uint256 startedAt, uint256 updatedAt, uint80 answeredInRound); +} + +contract A { + FeedRegistryInterface public immutable registry; + MyInterface public immutable my_interface; + + constructor(FeedRegistryInterface _registry, MyInterface _my_interface) { + registry = _registry; + my_interface = _my_interface; + } + + function getPriceBad(address base, address quote) public returns (uint256) { + (, int256 price,,,) = registry.latestRoundData(base, quote); + // Do price validation + return uint256(price); + } + + function getPriceGood(address base, address quote) public returns (uint256) { + (, int256 price,,,) = my_interface.latestRoundData(base, quote); + // Do price validation + return uint256(price); + } + + +} \ No newline at end of file diff --git a/tests/e2e/detectors/test_data/chainlink-feed-registry/0.8.20/chainlink_feed_registry.sol-0.8.20.zip b/tests/e2e/detectors/test_data/chainlink-feed-registry/0.8.20/chainlink_feed_registry.sol-0.8.20.zip new file mode 100644 index 000000000..262ede23f Binary files /dev/null and b/tests/e2e/detectors/test_data/chainlink-feed-registry/0.8.20/chainlink_feed_registry.sol-0.8.20.zip differ diff --git a/tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol b/tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol new file mode 100644 index 000000000..e12560fa7 --- /dev/null +++ b/tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol @@ -0,0 +1,119 @@ +interface IChronicle { + /// @notice Returns the oracle's current value. + /// @dev Reverts if no value set. + /// @return value The oracle's current value. + function read() external view returns (uint value); + + /// @notice Returns the oracle's current value and its age. + /// @dev Reverts if no value set. + /// @return value The oracle's current value. + /// @return age The value's age. + function readWithAge() external view returns (uint value, uint age); + + /// @notice Returns the oracle's current value. + /// @return isValid True if value exists, false otherwise. + /// @return value The oracle's current value if it exists, zero otherwise. + function tryRead() external view returns (bool isValid, uint value); + + /// @notice Returns the oracle's current value and its age. + /// @return isValid True if value exists, false otherwise. + /// @return value The oracle's current value if it exists, zero otherwise. + /// @return age The value's age if value exists, zero otherwise. + function tryReadWithAge() + external + view + returns (bool isValid, uint value, uint age); +} + +interface IScribe is IChronicle { + /// @notice Returns the oracle's latest value. + /// @dev Provides partial compatibility with Chainlink's + /// IAggregatorV3Interface. + /// @return roundId 1. + /// @return answer The oracle's latest value. + /// @return startedAt 0. + /// @return updatedAt The timestamp of oracle's latest update. + /// @return answeredInRound 1. + function latestRoundData() + external + view + returns ( + uint80 roundId, + int answer, + uint startedAt, + uint updatedAt, + uint80 answeredInRound + ); + + /// @notice Returns the oracle's latest value. + /// @dev Provides partial compatibility with Chainlink's + /// IAggregatorV3Interface. + /// @custom:deprecated See https://docs.chain.link/data-feeds/api-reference/#latestanswer. + /// @return answer The oracle's latest value. + function latestAnswer() external view returns (int); +} + +contract C { + IScribe scribe; + IChronicle chronicle; + + constructor(address a) { + scribe = IScribe(a); + chronicle = IChronicle(a); + } + + function bad() public { + uint256 price = chronicle.read(); + } + + function good() public { + uint256 price = chronicle.read(); + require(price != 0); + } + + function bad2() public { + (uint256 price,) = chronicle.readWithAge(); + } + + function good2() public { + (uint256 price,) = chronicle.readWithAge(); + require(price != 0); + } + + function bad3() public { + (bool isValid, uint256 price) = chronicle.tryRead(); + } + + function good3() public { + (bool isValid, uint256 price) = chronicle.tryRead(); + require(isValid); + } + + function bad4() public { + (bool isValid, uint256 price,) = chronicle.tryReadWithAge(); + } + + function good4() public { + (bool isValid, uint256 price,) = chronicle.tryReadWithAge(); + require(isValid); + } + + function bad5() public { + int256 price = scribe.latestAnswer(); + } + + function good5() public { + int256 price = scribe.latestAnswer(); + require(price != 0); + } + + function bad6() public { + (, int256 price,,,) = scribe.latestRoundData(); + } + + function good6() public { + (, int256 price,,,) = scribe.latestRoundData(); + require(price != 0); + } + +} \ No newline at end of file diff --git a/tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol-0.8.20.zip b/tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol-0.8.20.zip new file mode 100644 index 000000000..746efabf6 Binary files /dev/null and b/tests/e2e/detectors/test_data/chronicle-unchecked-price/0.8.20/chronicle_unchecked_price.sol-0.8.20.zip differ diff --git a/tests/e2e/detectors/test_data/gelato-unprotected-randomness/0.8.20/gelato_unprotected_randomness.sol b/tests/e2e/detectors/test_data/gelato-unprotected-randomness/0.8.20/gelato_unprotected_randomness.sol new file mode 100644 index 000000000..108859e9e --- /dev/null +++ b/tests/e2e/detectors/test_data/gelato-unprotected-randomness/0.8.20/gelato_unprotected_randomness.sol @@ -0,0 +1,62 @@ +// Mock GelatoVRFConsumerBase for what we need +abstract contract GelatoVRFConsumerBase { + bool[] public requestPending; + mapping(uint256 => bytes32) public requestedHash; + + function _fulfillRandomness( + uint256 randomness, + uint256 requestId, + bytes memory extraData + ) internal virtual; + + function _requestRandomness( + bytes memory extraData + ) internal returns (uint256 requestId) { + requestId = uint256(requestPending.length); + requestPending.push(); + requestPending[requestId] = true; + + bytes memory data = abi.encode(requestId, extraData); + uint256 round = 111; + + bytes memory dataWithRound = abi.encode(round, data); + bytes32 requestHash = keccak256(dataWithRound); + + requestedHash[requestId] = requestHash; + } + +} + +contract C is GelatoVRFConsumerBase { + address owner; + mapping(address => bool) authorized; + + function _fulfillRandomness( + uint256 randomness, + uint256, + bytes memory extraData + ) internal override { + // Do something with the random number + } + + function bad() public { + uint id = _requestRandomness(abi.encode(msg.sender)); + } + + function good() public { + require(msg.sender == owner); + uint id = _requestRandomness(abi.encode(msg.sender)); + } + + // This is currently a FP due to the limitation of function.is_protected + function good2() public { + require(authorized[msg.sender]); + uint id = _requestRandomness(abi.encode(msg.sender)); + } + + function good3() public { + if (msg.sender != owner) { revert(); } + uint id = _requestRandomness(abi.encode(msg.sender)); + } + +} diff --git a/tests/e2e/detectors/test_data/gelato-unprotected-randomness/0.8.20/gelato_unprotected_randomness.sol-0.8.20.zip b/tests/e2e/detectors/test_data/gelato-unprotected-randomness/0.8.20/gelato_unprotected_randomness.sol-0.8.20.zip new file mode 100644 index 000000000..013d3ef28 Binary files /dev/null and b/tests/e2e/detectors/test_data/gelato-unprotected-randomness/0.8.20/gelato_unprotected_randomness.sol-0.8.20.zip differ diff --git a/tests/e2e/detectors/test_data/optimism-deprecation/0.8.20/optimism_deprecation.sol b/tests/e2e/detectors/test_data/optimism-deprecation/0.8.20/optimism_deprecation.sol new file mode 100644 index 000000000..7ad55f3dd --- /dev/null +++ b/tests/e2e/detectors/test_data/optimism-deprecation/0.8.20/optimism_deprecation.sol @@ -0,0 +1,27 @@ +interface GasPriceOracle { + function scalar() external view returns (uint256); + function baseFee() external view returns (uint256); +} + +interface L1BlockNumber { + function q() external view returns (uint256); +} + +contract Test { + GasPriceOracle constant OPT_GAS = GasPriceOracle(0x420000000000000000000000000000000000000F); + L1BlockNumber constant L1_BLOCK_NUMBER = L1BlockNumber(0x4200000000000000000000000000000000000013); + + function bad() public { + OPT_GAS.scalar(); + } + + function bad2() public { + L1_BLOCK_NUMBER.q(); + } + + function good() public { + OPT_GAS.baseFee(); + } + + +} diff --git a/tests/e2e/detectors/test_data/optimism-deprecation/0.8.20/optimism_deprecation.sol-0.8.20.zip b/tests/e2e/detectors/test_data/optimism-deprecation/0.8.20/optimism_deprecation.sol-0.8.20.zip new file mode 100644 index 000000000..de18d4a0d Binary files /dev/null and b/tests/e2e/detectors/test_data/optimism-deprecation/0.8.20/optimism_deprecation.sol-0.8.20.zip differ diff --git a/tests/e2e/detectors/test_data/pyth-deprecated-functions/0.8.20/pyth_deprecated_functions.sol b/tests/e2e/detectors/test_data/pyth-deprecated-functions/0.8.20/pyth_deprecated_functions.sol new file mode 100644 index 000000000..dc8130db5 --- /dev/null +++ b/tests/e2e/detectors/test_data/pyth-deprecated-functions/0.8.20/pyth_deprecated_functions.sol @@ -0,0 +1,35 @@ + +// Fake Pyth interface +interface IPyth { + function getPrice(bytes32 id) external returns (uint256 price); + function notDeprecated(bytes32 id) external returns (uint256 price); +} + +interface INotPyth { + function getPrice(bytes32 id) external returns (uint256 price); +} + +contract C { + + IPyth pyth; + INotPyth notPyth; + + constructor(IPyth _pyth, INotPyth _notPyth) { + pyth = _pyth; + notPyth = _notPyth; + } + + function Deprecated(bytes32 priceId) public { + uint256 price = pyth.getPrice(priceId); + } + + function notDeprecated(bytes32 priceId) public { + uint256 price = pyth.notDeprecated(priceId); + } + + function notPythCall(bytes32 priceId) public { + uint256 price = notPyth.getPrice(priceId); + } + + +} diff --git a/tests/e2e/detectors/test_data/pyth-deprecated-functions/0.8.20/pyth_deprecated_functions.sol-0.8.20.zip b/tests/e2e/detectors/test_data/pyth-deprecated-functions/0.8.20/pyth_deprecated_functions.sol-0.8.20.zip new file mode 100644 index 000000000..258a28c93 Binary files /dev/null and b/tests/e2e/detectors/test_data/pyth-deprecated-functions/0.8.20/pyth_deprecated_functions.sol-0.8.20.zip differ diff --git a/tests/e2e/detectors/test_data/pyth-unchecked-confidence/0.8.20/pyth_unchecked_confidence.sol b/tests/e2e/detectors/test_data/pyth-unchecked-confidence/0.8.20/pyth_unchecked_confidence.sol new file mode 100644 index 000000000..58880c382 --- /dev/null +++ b/tests/e2e/detectors/test_data/pyth-unchecked-confidence/0.8.20/pyth_unchecked_confidence.sol @@ -0,0 +1,193 @@ +contract PythStructs { + // A price with a degree of uncertainty, represented as a price +- a confidence interval. + // + // The confidence interval roughly corresponds to the standard error of a normal distribution. + // Both the price and confidence are stored in a fixed-point numeric representation, + // `x * (10^expo)`, where `expo` is the exponent. + // + // Please refer to the documentation at https://docs.pyth.network/consumers/best-practices for how + // to how this price safely. + struct Price { + // Price + int64 price; + // Confidence interval around the price + uint64 conf; + // Price exponent + int32 expo; + // Unix timestamp describing when the price was published + uint publishTime; + } + + // PriceFeed represents a current aggregate price from pyth publisher feeds. + struct PriceFeed { + // The price ID. + bytes32 id; + // Latest available price + Price price; + // Latest available exponentially-weighted moving average price + Price emaPrice; + } +} + +interface IPyth { + /// @notice Returns the period (in seconds) that a price feed is considered valid since its publish time + function getValidTimePeriod() external view returns (uint validTimePeriod); + + /// @notice Returns the price and confidence interval. + /// @dev Reverts if the price has not been updated within the last `getValidTimePeriod()` seconds. + /// @param id The Pyth Price Feed ID of which to fetch the price and confidence interval. + /// @return price - please read the documentation of PythStructs.Price to understand how to use this safely. + function getPrice( + bytes32 id + ) external view returns (PythStructs.Price memory price); + + /// @notice Returns the exponentially-weighted moving average price and confidence interval. + /// @dev Reverts if the EMA price is not available. + /// @param id The Pyth Price Feed ID of which to fetch the EMA price and confidence interval. + /// @return price - please read the documentation of PythStructs.Price to understand how to use this safely. + function getEmaPrice( + bytes32 id + ) external view returns (PythStructs.Price memory price); + + /// @notice Returns the price of a price feed without any sanity checks. + /// @dev This function returns the most recent price update in this contract without any recency checks. + /// This function is unsafe as the returned price update may be arbitrarily far in the past. + /// + /// Users of this function should check the `publishTime` in the price to ensure that the returned price is + /// sufficiently recent for their application. If you are considering using this function, it may be + /// safer / easier to use either `getPrice` or `getPriceNoOlderThan`. + /// @return price - please read the documentation of PythStructs.Price to understand how to use this safely. + function getPriceUnsafe( + bytes32 id + ) external view returns (PythStructs.Price memory price); + + /// @notice Returns the price that is no older than `age` seconds of the current time. + /// @dev This function is a sanity-checked version of `getPriceUnsafe` which is useful in + /// applications that require a sufficiently-recent price. Reverts if the price wasn't updated sufficiently + /// recently. + /// @return price - please read the documentation of PythStructs.Price to understand how to use this safely. + function getPriceNoOlderThan( + bytes32 id, + uint age + ) external view returns (PythStructs.Price memory price); + + /// @notice Returns the exponentially-weighted moving average price of a price feed without any sanity checks. + /// @dev This function returns the same price as `getEmaPrice` in the case where the price is available. + /// However, if the price is not recent this function returns the latest available price. + /// + /// The returned price can be from arbitrarily far in the past; this function makes no guarantees that + /// the returned price is recent or useful for any particular application. + /// + /// Users of this function should check the `publishTime` in the price to ensure that the returned price is + /// sufficiently recent for their application. If you are considering using this function, it may be + /// safer / easier to use either `getEmaPrice` or `getEmaPriceNoOlderThan`. + /// @return price - please read the documentation of PythStructs.Price to understand how to use this safely. + function getEmaPriceUnsafe( + bytes32 id + ) external view returns (PythStructs.Price memory price); + + /// @notice Returns the exponentially-weighted moving average price that is no older than `age` seconds + /// of the current time. + /// @dev This function is a sanity-checked version of `getEmaPriceUnsafe` which is useful in + /// applications that require a sufficiently-recent price. Reverts if the price wasn't updated sufficiently + /// recently. + /// @return price - please read the documentation of PythStructs.Price to understand how to use this safely. + function getEmaPriceNoOlderThan( + bytes32 id, + uint age + ) external view returns (PythStructs.Price memory price); + + /// @notice Update price feeds with given update messages. + /// This method requires the caller to pay a fee in wei; the required fee can be computed by calling + /// `getUpdateFee` with the length of the `updateData` array. + /// Prices will be updated if they are more recent than the current stored prices. + /// The call will succeed even if the update is not the most recent. + /// @dev Reverts if the transferred fee is not sufficient or the updateData is invalid. + /// @param updateData Array of price update data. + function updatePriceFeeds(bytes[] calldata updateData) external payable; + + /// @notice Wrapper around updatePriceFeeds that rejects fast if a price update is not necessary. A price update is + /// necessary if the current on-chain publishTime is older than the given publishTime. It relies solely on the + /// given `publishTimes` for the price feeds and does not read the actual price update publish time within `updateData`. + /// + /// This method requires the caller to pay a fee in wei; the required fee can be computed by calling + /// `getUpdateFee` with the length of the `updateData` array. + /// + /// `priceIds` and `publishTimes` are two arrays with the same size that correspond to senders known publishTime + /// of each priceId when calling this method. If all of price feeds within `priceIds` have updated and have + /// a newer or equal publish time than the given publish time, it will reject the transaction to save gas. + /// Otherwise, it calls updatePriceFeeds method to update the prices. + /// + /// @dev Reverts if update is not needed or the transferred fee is not sufficient or the updateData is invalid. + /// @param updateData Array of price update data. + /// @param priceIds Array of price ids. + /// @param publishTimes Array of publishTimes. `publishTimes[i]` corresponds to known `publishTime` of `priceIds[i]` + function updatePriceFeedsIfNecessary( + bytes[] calldata updateData, + bytes32[] calldata priceIds, + uint64[] calldata publishTimes + ) external payable; + + /// @notice Returns the required fee to update an array of price updates. + /// @param updateData Array of price update data. + /// @return feeAmount The required fee in Wei. + function getUpdateFee( + bytes[] calldata updateData + ) external view returns (uint feeAmount); + + /// @notice Parse `updateData` and return price feeds of the given `priceIds` if they are all published + /// within `minPublishTime` and `maxPublishTime`. + /// + /// You can use this method if you want to use a Pyth price at a fixed time and not the most recent price; + /// otherwise, please consider using `updatePriceFeeds`. This method does not store the price updates on-chain. + /// + /// This method requires the caller to pay a fee in wei; the required fee can be computed by calling + /// `getUpdateFee` with the length of the `updateData` array. + /// + /// + /// @dev Reverts if the transferred fee is not sufficient or the updateData is invalid or there is + /// no update for any of the given `priceIds` within the given time range. + /// @param updateData Array of price update data. + /// @param priceIds Array of price ids. + /// @param minPublishTime minimum acceptable publishTime for the given `priceIds`. + /// @param maxPublishTime maximum acceptable publishTime for the given `priceIds`. + /// @return priceFeeds Array of the price feeds corresponding to the given `priceIds` (with the same order). + function parsePriceFeedUpdates( + bytes[] calldata updateData, + bytes32[] calldata priceIds, + uint64 minPublishTime, + uint64 maxPublishTime + ) external payable returns (PythStructs.PriceFeed[] memory priceFeeds); +} + + +contract C { + IPyth pyth; + + constructor(IPyth _pyth) { + pyth = _pyth; + } + + function bad(bytes32 id, uint256 age) public { + PythStructs.Price memory price = pyth.getEmaPriceNoOlderThan(id, age); + require(price.publishTime > block.timestamp - 120); + // Use price + } + + function good(bytes32 id, uint256 age) public { + PythStructs.Price memory price = pyth.getEmaPriceNoOlderThan(id, age); + require(price.conf < 10000); + require(price.publishTime > block.timestamp - 120); + // Use price + } + + function good2(bytes32 id, uint256 age) public { + PythStructs.Price memory price = pyth.getEmaPriceNoOlderThan(id, age); + require(price.publishTime > block.timestamp - 120); + if (price.conf >= 10000) { + revert(); + } + // Use price + } + +} \ No newline at end of file diff --git a/tests/e2e/detectors/test_data/pyth-unchecked-confidence/0.8.20/pyth_unchecked_confidence.sol-0.8.20.zip b/tests/e2e/detectors/test_data/pyth-unchecked-confidence/0.8.20/pyth_unchecked_confidence.sol-0.8.20.zip new file mode 100644 index 000000000..6e5fa1b9f Binary files /dev/null and b/tests/e2e/detectors/test_data/pyth-unchecked-confidence/0.8.20/pyth_unchecked_confidence.sol-0.8.20.zip differ diff --git a/tests/e2e/detectors/test_data/pyth-unchecked-publishtime/0.8.20/pyth_unchecked_publishtime.sol b/tests/e2e/detectors/test_data/pyth-unchecked-publishtime/0.8.20/pyth_unchecked_publishtime.sol new file mode 100644 index 000000000..74ab10fe3 --- /dev/null +++ b/tests/e2e/detectors/test_data/pyth-unchecked-publishtime/0.8.20/pyth_unchecked_publishtime.sol @@ -0,0 +1,193 @@ +contract PythStructs { + // A price with a degree of uncertainty, represented as a price +- a confidence interval. + // + // The confidence interval roughly corresponds to the standard error of a normal distribution. + // Both the price and confidence are stored in a fixed-point numeric representation, + // `x * (10^expo)`, where `expo` is the exponent. + // + // Please refer to the documentation at https://docs.pyth.network/consumers/best-practices for how + // to how this price safely. + struct Price { + // Price + int64 price; + // Confidence interval around the price + uint64 conf; + // Price exponent + int32 expo; + // Unix timestamp describing when the price was published + uint publishTime; + } + + // PriceFeed represents a current aggregate price from pyth publisher feeds. + struct PriceFeed { + // The price ID. + bytes32 id; + // Latest available price + Price price; + // Latest available exponentially-weighted moving average price + Price emaPrice; + } +} + +interface IPyth { + /// @notice Returns the period (in seconds) that a price feed is considered valid since its publish time + function getValidTimePeriod() external view returns (uint validTimePeriod); + + /// @notice Returns the price and confidence interval. + /// @dev Reverts if the price has not been updated within the last `getValidTimePeriod()` seconds. + /// @param id The Pyth Price Feed ID of which to fetch the price and confidence interval. + /// @return price - please read the documentation of PythStructs.Price to understand how to use this safely. + function getPrice( + bytes32 id + ) external view returns (PythStructs.Price memory price); + + /// @notice Returns the exponentially-weighted moving average price and confidence interval. + /// @dev Reverts if the EMA price is not available. + /// @param id The Pyth Price Feed ID of which to fetch the EMA price and confidence interval. + /// @return price - please read the documentation of PythStructs.Price to understand how to use this safely. + function getEmaPrice( + bytes32 id + ) external view returns (PythStructs.Price memory price); + + /// @notice Returns the price of a price feed without any sanity checks. + /// @dev This function returns the most recent price update in this contract without any recency checks. + /// This function is unsafe as the returned price update may be arbitrarily far in the past. + /// + /// Users of this function should check the `publishTime` in the price to ensure that the returned price is + /// sufficiently recent for their application. If you are considering using this function, it may be + /// safer / easier to use either `getPrice` or `getPriceNoOlderThan`. + /// @return price - please read the documentation of PythStructs.Price to understand how to use this safely. + function getPriceUnsafe( + bytes32 id + ) external view returns (PythStructs.Price memory price); + + /// @notice Returns the price that is no older than `age` seconds of the current time. + /// @dev This function is a sanity-checked version of `getPriceUnsafe` which is useful in + /// applications that require a sufficiently-recent price. Reverts if the price wasn't updated sufficiently + /// recently. + /// @return price - please read the documentation of PythStructs.Price to understand how to use this safely. + function getPriceNoOlderThan( + bytes32 id, + uint age + ) external view returns (PythStructs.Price memory price); + + /// @notice Returns the exponentially-weighted moving average price of a price feed without any sanity checks. + /// @dev This function returns the same price as `getEmaPrice` in the case where the price is available. + /// However, if the price is not recent this function returns the latest available price. + /// + /// The returned price can be from arbitrarily far in the past; this function makes no guarantees that + /// the returned price is recent or useful for any particular application. + /// + /// Users of this function should check the `publishTime` in the price to ensure that the returned price is + /// sufficiently recent for their application. If you are considering using this function, it may be + /// safer / easier to use either `getEmaPrice` or `getEmaPriceNoOlderThan`. + /// @return price - please read the documentation of PythStructs.Price to understand how to use this safely. + function getEmaPriceUnsafe( + bytes32 id + ) external view returns (PythStructs.Price memory price); + + /// @notice Returns the exponentially-weighted moving average price that is no older than `age` seconds + /// of the current time. + /// @dev This function is a sanity-checked version of `getEmaPriceUnsafe` which is useful in + /// applications that require a sufficiently-recent price. Reverts if the price wasn't updated sufficiently + /// recently. + /// @return price - please read the documentation of PythStructs.Price to understand how to use this safely. + function getEmaPriceNoOlderThan( + bytes32 id, + uint age + ) external view returns (PythStructs.Price memory price); + + /// @notice Update price feeds with given update messages. + /// This method requires the caller to pay a fee in wei; the required fee can be computed by calling + /// `getUpdateFee` with the length of the `updateData` array. + /// Prices will be updated if they are more recent than the current stored prices. + /// The call will succeed even if the update is not the most recent. + /// @dev Reverts if the transferred fee is not sufficient or the updateData is invalid. + /// @param updateData Array of price update data. + function updatePriceFeeds(bytes[] calldata updateData) external payable; + + /// @notice Wrapper around updatePriceFeeds that rejects fast if a price update is not necessary. A price update is + /// necessary if the current on-chain publishTime is older than the given publishTime. It relies solely on the + /// given `publishTimes` for the price feeds and does not read the actual price update publish time within `updateData`. + /// + /// This method requires the caller to pay a fee in wei; the required fee can be computed by calling + /// `getUpdateFee` with the length of the `updateData` array. + /// + /// `priceIds` and `publishTimes` are two arrays with the same size that correspond to senders known publishTime + /// of each priceId when calling this method. If all of price feeds within `priceIds` have updated and have + /// a newer or equal publish time than the given publish time, it will reject the transaction to save gas. + /// Otherwise, it calls updatePriceFeeds method to update the prices. + /// + /// @dev Reverts if update is not needed or the transferred fee is not sufficient or the updateData is invalid. + /// @param updateData Array of price update data. + /// @param priceIds Array of price ids. + /// @param publishTimes Array of publishTimes. `publishTimes[i]` corresponds to known `publishTime` of `priceIds[i]` + function updatePriceFeedsIfNecessary( + bytes[] calldata updateData, + bytes32[] calldata priceIds, + uint64[] calldata publishTimes + ) external payable; + + /// @notice Returns the required fee to update an array of price updates. + /// @param updateData Array of price update data. + /// @return feeAmount The required fee in Wei. + function getUpdateFee( + bytes[] calldata updateData + ) external view returns (uint feeAmount); + + /// @notice Parse `updateData` and return price feeds of the given `priceIds` if they are all published + /// within `minPublishTime` and `maxPublishTime`. + /// + /// You can use this method if you want to use a Pyth price at a fixed time and not the most recent price; + /// otherwise, please consider using `updatePriceFeeds`. This method does not store the price updates on-chain. + /// + /// This method requires the caller to pay a fee in wei; the required fee can be computed by calling + /// `getUpdateFee` with the length of the `updateData` array. + /// + /// + /// @dev Reverts if the transferred fee is not sufficient or the updateData is invalid or there is + /// no update for any of the given `priceIds` within the given time range. + /// @param updateData Array of price update data. + /// @param priceIds Array of price ids. + /// @param minPublishTime minimum acceptable publishTime for the given `priceIds`. + /// @param maxPublishTime maximum acceptable publishTime for the given `priceIds`. + /// @return priceFeeds Array of the price feeds corresponding to the given `priceIds` (with the same order). + function parsePriceFeedUpdates( + bytes[] calldata updateData, + bytes32[] calldata priceIds, + uint64 minPublishTime, + uint64 maxPublishTime + ) external payable returns (PythStructs.PriceFeed[] memory priceFeeds); +} + + +contract C { + IPyth pyth; + + constructor(IPyth _pyth) { + pyth = _pyth; + } + + function bad(bytes32 id) public { + PythStructs.Price memory price = pyth.getEmaPriceUnsafe(id); + require(price.conf < 10000); + // Use price + } + + function good(bytes32 id) public { + PythStructs.Price memory price = pyth.getEmaPriceUnsafe(id); + require(price.publishTime > block.timestamp - 120); + require(price.conf < 10000); + // Use price + } + + function good2(bytes32 id) public { + PythStructs.Price memory price = pyth.getEmaPriceUnsafe(id); + require(price.conf < 10000); + if (price.publishTime <= block.timestamp - 120) { + revert(); + } + // Use price + } + +} \ No newline at end of file diff --git a/tests/e2e/detectors/test_data/pyth-unchecked-publishtime/0.8.20/pyth_unchecked_publishtime.sol-0.8.20.zip b/tests/e2e/detectors/test_data/pyth-unchecked-publishtime/0.8.20/pyth_unchecked_publishtime.sol-0.8.20.zip new file mode 100644 index 000000000..178b65b38 Binary files /dev/null and b/tests/e2e/detectors/test_data/pyth-unchecked-publishtime/0.8.20/pyth_unchecked_publishtime.sol-0.8.20.zip differ diff --git a/tests/e2e/detectors/test_detectors.py b/tests/e2e/detectors/test_detectors.py index 2c6a5f55a..d2f191a4d 100644 --- a/tests/e2e/detectors/test_detectors.py +++ b/tests/e2e/detectors/test_detectors.py @@ -1714,6 +1714,41 @@ ALL_TESTS = [ "out_of_order_retryable.sol", "0.8.20", ), + Test( + all_detectors.GelatoUnprotectedRandomness, + "gelato_unprotected_randomness.sol", + "0.8.20", + ), + Test( + all_detectors.ChronicleUncheckedPrice, + "chronicle_unchecked_price.sol", + "0.8.20", + ), + Test( + all_detectors.PythUncheckedConfidence, + "pyth_unchecked_confidence.sol", + "0.8.20", + ), + Test( + all_detectors.PythUncheckedPublishTime, + "pyth_unchecked_publishtime.sol", + "0.8.20", + ), + Test( + all_detectors.ChainlinkFeedRegistry, + "chainlink_feed_registry.sol", + "0.8.20", + ), + Test( + all_detectors.PythDeprecatedFunctions, + "pyth_deprecated_functions.sol", + "0.8.20", + ), + Test( + all_detectors.OptimismDeprecation, + "optimism_deprecation.sol", + "0.8.20", + ), # Test( # all_detectors.UnusedImport, # "ConstantContractLevelUsedInContractTest.sol", diff --git a/tests/e2e/printers/test_data/test_printer_slithir/bug-2266.sol b/tests/e2e/printers/test_data/test_printer_slithir/bug-2266.sol new file mode 100644 index 000000000..5c11a2914 --- /dev/null +++ b/tests/e2e/printers/test_data/test_printer_slithir/bug-2266.sol @@ -0,0 +1,13 @@ +pragma solidity ^0.8.0; + +contract A { + function add(uint256 a, uint256 b) public returns (uint256) { + return a + b; + } +} + +contract B is A { + function assignFunction() public { + function(uint256, uint256) returns (uint256) myFunction = super.add; + } +} \ No newline at end of file diff --git a/tests/e2e/printers/test_printers.py b/tests/e2e/printers/test_printers.py index 3dea8b74a..aa5d7f8a4 100644 --- a/tests/e2e/printers/test_printers.py +++ b/tests/e2e/printers/test_printers.py @@ -7,6 +7,7 @@ from crytic_compile.platform.solc_standard_json import SolcStandardJson from slither import Slither from slither.printers.inheritance.inheritance_graph import PrinterInheritanceGraph +from slither.printers.summary.slithir import PrinterSlithIR TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" @@ -34,8 +35,7 @@ def test_inheritance_printer(solc_binary_path) -> None: assert counter["B -> A"] == 2 assert counter["C -> A"] == 1 - - # Lets also test the include/exclude interface behavior + # Let also test the include/exclude interface behavior # Check that the interface is not included assert "MyInterfaceX" not in content @@ -46,3 +46,18 @@ def test_inheritance_printer(solc_binary_path) -> None: # Remove test generated files Path("test_printer.dot").unlink(missing_ok=True) + + +def test_slithir_printer(solc_binary_path) -> None: + solc_path = solc_binary_path("0.8.0") + standard_json = SolcStandardJson() + standard_json.add_source_file( + Path(TEST_DATA_DIR, "test_printer_slithir", "bug-2266.sol").as_posix() + ) + compilation = CryticCompile(standard_json, solc=solc_path) + slither = Slither(compilation) + + printer = PrinterSlithIR(slither, logger=None) + output = printer.output("test_printer_slithir.dot") + + assert "slither.core.solidity_types" not in output.data["description"] diff --git a/tests/e2e/solc_parsing/test_ast_parsing.py b/tests/e2e/solc_parsing/test_ast_parsing.py index ca3872f8c..6ec7b6fbd 100644 --- a/tests/e2e/solc_parsing/test_ast_parsing.py +++ b/tests/e2e/solc_parsing/test_ast_parsing.py @@ -475,6 +475,7 @@ ALL_TESTS = [ Test("solidity-0.8.24.sol", ["0.8.24"], solc_args="--evm-version cancun"), Test("scope/inherited_function_scope.sol", ["0.8.24"]), Test("using_for_global_user_defined_operator_1.sol", ["0.8.24"]), + Test("require-error.sol", ["0.8.27"]), ] # create the output folder if needed try: diff --git a/tests/e2e/solc_parsing/test_data/compile/require-error.sol-0.8.27-compact.zip b/tests/e2e/solc_parsing/test_data/compile/require-error.sol-0.8.27-compact.zip new file mode 100644 index 000000000..63aa223b3 Binary files /dev/null and b/tests/e2e/solc_parsing/test_data/compile/require-error.sol-0.8.27-compact.zip differ diff --git a/tests/e2e/solc_parsing/test_data/expected/require-error.sol-0.8.27-compact.json b/tests/e2e/solc_parsing/test_data/expected/require-error.sol-0.8.27-compact.json new file mode 100644 index 000000000..3c3089c04 --- /dev/null +++ b/tests/e2e/solc_parsing/test_data/expected/require-error.sol-0.8.27-compact.json @@ -0,0 +1,5 @@ +{ + "TestToken": { + "transferWithRequireError(address,uint256)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n1->2;\n2[label=\"Node Type: EXPRESSION 2\n\"];\n2->3;\n3[label=\"Node Type: EXPRESSION 3\n\"];\n}\n" + } +} \ No newline at end of file diff --git a/tests/e2e/solc_parsing/test_data/require-error.sol b/tests/e2e/solc_parsing/test_data/require-error.sol new file mode 100644 index 000000000..97c8ecac4 --- /dev/null +++ b/tests/e2e/solc_parsing/test_data/require-error.sol @@ -0,0 +1,20 @@ +pragma solidity 0.8.27; + +/// Insufficient balance for transfer. Needed `required` but only +/// `available` available. +/// @param available balance available. +/// @param required requested amount to transfer. +error InsufficientBalance(uint256 available, uint256 required); + +contract TestToken { + mapping(address => uint) balance; + function transferWithRequireError(address to, uint256 amount) public { + require( + balance[msg.sender] >= amount, + InsufficientBalance(balance[msg.sender], amount) + ); + balance[msg.sender] -= amount; + balance[to] += amount; + } + // ... +} diff --git a/tests/unit/slithir/test_constantfolding.py b/tests/unit/slithir/test_constantfolding.py new file mode 100644 index 000000000..fcf00035b --- /dev/null +++ b/tests/unit/slithir/test_constantfolding.py @@ -0,0 +1,22 @@ +from pathlib import Path + +from slither import Slither +from slither.printers.guidance.echidna import _extract_constants, ConstantValue + +TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" + + +def test_enum_max_min(solc_binary_path) -> None: + solc_path = solc_binary_path("0.8.19") + slither = Slither(Path(TEST_DATA_DIR, "constantfolding.sol").as_posix(), solc=solc_path) + + contracts = slither.get_contract_from_name("A") + + constants = _extract_constants(contracts)[0]["A"]["use()"] + + assert set(constants) == { + ConstantValue(value="2", type="uint256"), + ConstantValue(value="10", type="uint256"), + ConstantValue(value="100", type="uint256"), + ConstantValue(value="4294967295", type="uint32"), + } diff --git a/tests/unit/slithir/test_data/constantfolding.sol b/tests/unit/slithir/test_data/constantfolding.sol new file mode 100644 index 000000000..aef4a2427 --- /dev/null +++ b/tests/unit/slithir/test_data/constantfolding.sol @@ -0,0 +1,19 @@ +type MyType is uint256; + +contract A{ + + enum E{ + a,b,c + } + + + uint a = 10; + E b = type(E).max; + uint c = type(uint32).max; + MyType d = MyType.wrap(100); + + function use() public returns(uint){ + E e = b; + return a +c + MyType.unwrap(d); + } +} \ No newline at end of file diff --git a/tests/unit/slithir/test_data/selector.sol b/tests/unit/slithir/test_data/selector.sol new file mode 100644 index 000000000..60ec33cd6 --- /dev/null +++ b/tests/unit/slithir/test_data/selector.sol @@ -0,0 +1,47 @@ +interface I{ + function testFunction(uint a) external ; +} + +contract A{ + function testFunction() public{} +} + +contract Test{ + event TestEvent(); + struct St{ + uint a; + } + error TestError(); + + function testFunction(uint a) public {} + + + function testFunctionStructure(St memory s) public {} + + function returnEvent() public returns (bytes32){ + return TestEvent.selector; + } + + function returnError() public returns (bytes4){ + return TestError.selector; + } + + + function returnFunctionFromContract() public returns (bytes4){ + return I.testFunction.selector; + } + + + function returnFunction() public returns (bytes4){ + return this.testFunction.selector; + } + + function returnFunctionWithStructure() public returns (bytes4){ + return this.testFunctionStructure.selector; + } + + function returnFunctionThroughLocaLVar() public returns(bytes4){ + A a; + return a.testFunction.selector; + } +} \ No newline at end of file diff --git a/tests/unit/slithir/test_selector.py b/tests/unit/slithir/test_selector.py new file mode 100644 index 000000000..34643b58d --- /dev/null +++ b/tests/unit/slithir/test_selector.py @@ -0,0 +1,32 @@ +from pathlib import Path +from slither import Slither +from slither.slithir.operations import Assignment +from slither.slithir.variables import Constant + +TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" + + +func_to_results = { + "returnEvent()": "16700440330922901039223184000601971290390760458944929668086539975128325467771", + "returnError()": "224292994", + "returnFunctionFromContract()": "890000139", + "returnFunction()": "890000139", + "returnFunctionWithStructure()": "1430834845", + "returnFunctionThroughLocaLVar()": "3781905051", +} + + +def test_enum_max_min(solc_binary_path) -> None: + solc_path = solc_binary_path("0.8.19") + slither = Slither(Path(TEST_DATA_DIR, "selector.sol").as_posix(), solc=solc_path) + + contract = slither.get_contract_from_name("Test")[0] + + for func_name, value in func_to_results.items(): + f = contract.get_function_from_signature(func_name) + assignment = f.slithir_operations[0] + assert ( + isinstance(assignment, Assignment) + and isinstance(assignment.rvalue, Constant) + and assignment.rvalue.value == value + ) diff --git a/tests/unit/slithir/vyper/test_ir_generation.py b/tests/unit/slithir/vyper/test_ir_generation.py index 73c9b5e70..efcf5ce54 100644 --- a/tests/unit/slithir/vyper/test_ir_generation.py +++ b/tests/unit/slithir/vyper/test_ir_generation.py @@ -35,9 +35,9 @@ def bar(): interface = next(iter(x for x in sl.contracts if x.is_interface)) contract = next(iter(x for x in sl.contracts if not x.is_interface)) func = contract.get_function_from_signature("bar()") - (contract, function) = func.high_level_calls[0] + (contract, ir) = func.high_level_calls[0] assert contract == interface - assert function.signature_str == "foo() returns(int128,uint256)" + assert ir.function.signature_str == "foo() returns(int128,uint256)" def test_phi_entry_point_internal_call(slither_from_vyper_source):