diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index fd8f761c6..8e56120fe 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -93,6 +93,9 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods self._signatures: Optional[List[str]] = None self._signatures_declared: Optional[List[str]] = None + self._fallback_function: Optional["FunctionContract"] = None + self._receive_function: Optional["FunctionContract"] = None + self._is_upgradeable: Optional[bool] = None self._is_upgradeable_proxy: Optional[bool] = None self._upgradeable_version: Optional[str] = None @@ -663,6 +666,24 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ return self.functions_declared + self.modifiers_declared # type: ignore + @property + def fallback_function(self) -> Optional["FunctionContract"]: + if self._fallback_function is None: + for f in self.functions: + if f.is_fallback: + self._fallback_function = f + break + return self._fallback_function + + @property + def receive_function(self) -> Optional["FunctionContract"]: + if self._receive_function is None: + for f in self.functions: + if f.is_receive: + self._receive_function = f + break + return self._receive_function + def available_elements_from_inheritances( self, elements: Dict[str, "Function"], diff --git a/slither/tools/upgradeability/checks/variables_order.py b/slither/tools/upgradeability/checks/variables_order.py index fc83c44c6..002559b6e 100644 --- a/slither/tools/upgradeability/checks/variables_order.py +++ b/slither/tools/upgradeability/checks/variables_order.py @@ -6,6 +6,7 @@ from slither.tools.upgradeability.checks.abstract_checks import ( AbstractCheck, CHECK_INFO, ) +from slither.utils.upgradeability import get_missing_vars from slither.utils.output import Output @@ -55,25 +56,13 @@ Do not change the order of the state variables in the updated contract. contract2 = self.contract_v2 assert contract2 - - order1 = [ - variable - for variable in contract1.state_variables_ordered - if not (variable.is_constant or variable.is_immutable) - ] - order2 = [ - variable - for variable in contract2.state_variables_ordered - if not (variable.is_constant or variable.is_immutable) - ] + missing = get_missing_vars(contract1, contract2) results = [] - for idx, _ in enumerate(order1): - variable1 = order1[idx] - if len(order2) <= idx: - info: CHECK_INFO = ["Variable missing in ", contract2, ": ", variable1, "\n"] - json = self.generate_result(info) - results.append(json) + for variable1 in missing: + info: CHECK_INFO = ["Variable missing in ", contract2, ": ", variable1, "\n"] + json = self.generate_result(info) + results.append(json) return results diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py new file mode 100644 index 000000000..7b4e8493a --- /dev/null +++ b/slither/utils/upgradeability.py @@ -0,0 +1,580 @@ +from typing import Optional, Tuple, List, Union +from slither.core.declarations import ( + Contract, + Structure, + Enum, + SolidityVariableComposed, + SolidityVariable, + Function, +) +from slither.core.solidity_types import ( + Type, + ElementaryType, + ArrayType, + MappingType, + UserDefinedType, +) +from slither.core.variables.local_variable import LocalVariable +from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple +from slither.core.variables.state_variable import StateVariable +from slither.analyses.data_dependency.data_dependency import get_dependencies +from slither.core.variables.variable import Variable +from slither.core.expressions.literal import Literal +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.assignment_operation import AssignmentOperation +from slither.core.cfg.node import Node, NodeType +from slither.slithir.operations import ( + Operation, + Assignment, + Index, + Member, + Length, + Binary, + Unary, + Condition, + NewArray, + NewStructure, + NewContract, + NewElementaryType, + SolidityCall, + Delete, + EventCall, + LibraryCall, + InternalDynamicCall, + HighLevelCall, + LowLevelCall, + TypeConversion, + Return, + Transfer, + Send, + Unpack, + InitArray, + InternalCall, +) +from slither.slithir.variables import ( + TemporaryVariable, + TupleVariable, + Constant, + ReferenceVariable, +) +from slither.tools.read_storage.read_storage import SlotInfo, SlitherReadStorage + + +# pylint: disable=too-many-locals +def compare( + v1: Contract, v2: Contract +) -> Tuple[ + List[Variable], List[Variable], List[Variable], List[Function], List[Function], List[Function] +]: + """ + Compares two versions of a contract. Most useful for upgradeable (logic) contracts, + but does not require that Contract.is_upgradeable returns true for either contract. + + Args: + v1: Original version of (upgradeable) contract + v2: Updated version of (upgradeable) contract + + Returns: + missing-vars-in-v2: list[Variable], + new-variables: list[Variable], + tainted-variables: list[Variable], + new-functions: list[Function], + modified-functions: list[Function], + tainted-functions: list[Function] + """ + + order_vars1 = [ + v for v in v1.state_variables_ordered if not v.is_constant and not v.is_immutable + ] + order_vars2 = [ + v for v in v2.state_variables_ordered if not v.is_constant and not v.is_immutable + ] + func_sigs1 = [function.solidity_signature for function in v1.functions] + func_sigs2 = [function.solidity_signature for function in v2.functions] + + missing_vars_in_v2 = [] + new_variables = [] + tainted_variables = [] + new_functions = [] + modified_functions = [] + tainted_functions = [] + + # Since this is not a detector, include any missing variables in the v2 contract + if len(order_vars2) < len(order_vars1): + missing_vars_in_v2.extend(get_missing_vars(v1, v2)) + + # Find all new and modified functions in the v2 contract + new_modified_functions = [] + new_modified_function_vars = [] + for sig in func_sigs2: + function = v2.get_function_from_signature(sig) + orig_function = v1.get_function_from_signature(sig) + if sig not in func_sigs1: + new_modified_functions.append(function) + new_functions.append(function) + new_modified_function_vars += ( + function.state_variables_read + function.state_variables_written + ) + elif not function.is_constructor_variables and is_function_modified( + orig_function, function + ): + new_modified_functions.append(function) + modified_functions.append(function) + new_modified_function_vars += ( + function.state_variables_read + function.state_variables_written + ) + + # Find all unmodified functions that call a modified function or read/write the + # same state variable(s) as a new/modified function, i.e., tainted functions + for function in v2.functions: + if ( + function in new_modified_functions + or function.is_constructor + or function.name.startswith("slither") + ): + continue + modified_calls = [ + func for func in new_modified_functions if func in function.internal_calls + ] + tainted_vars = [ + var + for var in set(new_modified_function_vars) + if var in function.variables_read_or_written + and not var.is_constant + and not var.is_immutable + ] + if len(modified_calls) > 0 or len(tainted_vars) > 0: + tainted_functions.append(function) + + # Find all new or tainted variables, i.e., variables that are read or written by a new/modified/tainted function + for var in order_vars2: + read_by = v2.get_functions_reading_from_variable(var) + written_by = v2.get_functions_writing_to_variable(var) + if v1.get_state_variable_from_name(var.name) is None: + new_variables.append(var) + elif any( + func in read_by or func in written_by + for func in new_modified_functions + tainted_functions + ): + tainted_variables.append(var) + + return ( + missing_vars_in_v2, + new_variables, + tainted_variables, + new_functions, + modified_functions, + tainted_functions, + ) + + +def get_missing_vars(v1: Contract, v2: Contract) -> List[StateVariable]: + """ + Gets all non-constant/immutable StateVariables that appear in v1 but not v2 + Args: + v1: Contract version 1 + v2: Contract version 2 + + Returns: + List of StateVariables from v1 missing in v2 + """ + results = [] + order_vars1 = [ + v for v in v1.state_variables_ordered if not v.is_constant and not v.is_immutable + ] + order_vars2 = [ + v for v in v2.state_variables_ordered if not v.is_constant and not v.is_immutable + ] + if len(order_vars2) < len(order_vars1): + for variable in order_vars1: + if variable.name not in [v.name for v in order_vars2]: + results.append(variable) + return results + + +def is_function_modified(f1: Function, f2: Function) -> bool: + """ + Compares two versions of a function, and returns True if the function has been modified. + First checks whether the functions' content hashes are equal to quickly rule out identical functions. + Walks the CFGs and compares IR operations if hashes differ to rule out false positives, i.e., from changed comments. + + Args: + f1: Original version of the function + f2: New version of the function + + Returns: + True if the functions differ, otherwise False + """ + # If the function content hashes are the same, no need to investigate the function further + if f1.source_mapping.content_hash == f2.source_mapping.content_hash: + return False + # If the hashes differ, it is possible a change in a name or in a comment could be the only difference + # So we need to resort to walking through the CFG and comparing the IR operations + queue_f1 = [f1.entry_point] + queue_f2 = [f2.entry_point] + visited = [] + while len(queue_f1) > 0 and len(queue_f2) > 0: + node_f1 = queue_f1.pop(0) + node_f2 = queue_f2.pop(0) + visited.extend([node_f1, node_f2]) + queue_f1.extend(son for son in node_f1.sons if son not in visited) + queue_f2.extend(son for son in node_f2.sons if son not in visited) + for i, ir in enumerate(node_f1.irs): + if encode_ir_for_compare(ir) != encode_ir_for_compare(node_f2.irs[i]): + return True + return False + + +# pylint: disable=too-many-branches +def ntype(_type: Union[Type, str]) -> str: + if isinstance(_type, ElementaryType): + _type = str(_type) + elif isinstance(_type, ArrayType): + if isinstance(_type.type, ElementaryType): + _type = str(_type) + else: + _type = "user_defined_array" + elif isinstance(_type, Structure): + _type = str(_type) + elif isinstance(_type, Enum): + _type = str(_type) + elif isinstance(_type, MappingType): + _type = str(_type) + elif isinstance(_type, UserDefinedType): + if isinstance(_type.type, Contract): + _type = f"contract({_type.type.name})" + elif isinstance(_type.type, Structure): + _type = f"struct({_type.type.name})" + elif isinstance(_type.type, Enum): + _type = f"enum({_type.type.name})" + else: + _type = str(_type) + + _type = _type.replace(" memory", "") + _type = _type.replace(" storage ref", "") + + if "struct" in _type: + return "struct" + if "enum" in _type: + return "enum" + if "tuple" in _type: + return "tuple" + if "contract" in _type: + return "contract" + if "mapping" in _type: + return "mapping" + return _type.replace(" ", "_") + + +# pylint: disable=too-many-branches +def encode_ir_for_compare(ir: Operation) -> str: + # operations + if isinstance(ir, Assignment): + return f"({encode_var_for_compare(ir.lvalue)}):=({encode_var_for_compare(ir.rvalue)})" + if isinstance(ir, Index): + return f"index({ntype(ir.index_type)})" + if isinstance(ir, Member): + return "member" # .format(ntype(ir._type)) + if isinstance(ir, Length): + return "length" + if isinstance(ir, Binary): + return f"binary({str(ir.variable_left)}{str(ir.type)}{str(ir.variable_right)})" + if isinstance(ir, Unary): + return f"unary({str(ir.type)})" + if isinstance(ir, Condition): + return f"condition({encode_var_for_compare(ir.value)})" + if isinstance(ir, NewStructure): + return "new_structure" + if isinstance(ir, NewContract): + return "new_contract" + if isinstance(ir, NewArray): + return f"new_array({ntype(ir.array_type)})" + if isinstance(ir, NewElementaryType): + return f"new_elementary({ntype(ir.type)})" + if isinstance(ir, Delete): + return f"delete({encode_var_for_compare(ir.lvalue)},{encode_var_for_compare(ir.variable)})" + if isinstance(ir, SolidityCall): + return f"solidity_call({ir.function.full_name})" + if isinstance(ir, InternalCall): + return f"internal_call({ntype(ir.type_call)})" + if isinstance(ir, EventCall): # is this useful? + return "event" + if isinstance(ir, LibraryCall): + return "library_call" + if isinstance(ir, InternalDynamicCall): + return "internal_dynamic_call" + if isinstance(ir, HighLevelCall): # TODO: improve + return "high_level_call" + if isinstance(ir, LowLevelCall): # TODO: improve + return "low_level_call" + if isinstance(ir, TypeConversion): + return f"type_conversion({ntype(ir.type)})" + if isinstance(ir, Return): # this can be improved using values + return "return" # .format(ntype(ir.type)) + if isinstance(ir, Transfer): + return f"transfer({encode_var_for_compare(ir.call_value)})" + if isinstance(ir, Send): + return f"send({encode_var_for_compare(ir.call_value)})" + if isinstance(ir, Unpack): # TODO: improve + return "unpack" + if isinstance(ir, InitArray): # TODO: improve + return "init_array" + + # default + return "" + + +# pylint: disable=too-many-branches +def encode_var_for_compare(var: Variable) -> str: + + # variables + if isinstance(var, Constant): + return f"constant({ntype(var.type)})" + if isinstance(var, SolidityVariableComposed): + return f"solidity_variable_composed({var.name})" + if isinstance(var, SolidityVariable): + return f"solidity_variable{var.name}" + if isinstance(var, TemporaryVariable): + return "temporary_variable" + if isinstance(var, ReferenceVariable): + return f"reference({ntype(var.type)})" + if isinstance(var, LocalVariable): + return f"local_solc_variable({var.location})" + if isinstance(var, StateVariable): + return f"state_solc_variable({ntype(var.type)})" + if isinstance(var, LocalVariableInitFromTuple): + return "local_variable_init_tuple" + if isinstance(var, TupleVariable): + return "tuple_variable" + + # default + return "" + + +def get_proxy_implementation_slot(proxy: Contract) -> Optional[SlotInfo]: + """ + Gets information about the storage slot where a proxy's implementation address is stored. + Args: + proxy: A Contract object (proxy.is_upgradeable_proxy should be true). + + Returns: + (`SlotInfo`) | None : A dictionary of the slot information. + """ + + delegate = get_proxy_implementation_var(proxy) + if isinstance(delegate, StateVariable): + if not delegate.is_constant and not delegate.is_immutable: + srs = SlitherReadStorage([proxy], 20) + return srs.get_storage_slot(delegate, proxy) + if delegate.is_constant and delegate.type.name == "bytes32": + return SlotInfo( + name=delegate.name, + type_string="address", + slot=int(delegate.expression.value, 16), + size=160, + offset=0, + ) + return None + + +def get_proxy_implementation_var(proxy: Contract) -> Optional[Variable]: + """ + Gets the Variable that stores a proxy's implementation address. Uses data dependency to trace any LocalVariable + that is passed into a delegatecall as the target address back to its data source, ideally a StateVariable. + Can return a newly created StateVariable if an `sload` from a hardcoded storage slot is found in assembly. + Args: + proxy: A Contract object (proxy.is_upgradeable_proxy should be true). + + Returns: + (`Variable`) | None : The variable, ideally a StateVariable, which stores the proxy's implementation address. + """ + if not proxy.is_upgradeable_proxy or not proxy.fallback_function: + return None + + delegate = find_delegate_in_fallback(proxy) + if isinstance(delegate, LocalVariable): + dependencies = get_dependencies(delegate, proxy) + try: + delegate = next(var for var in dependencies if isinstance(var, StateVariable)) + except StopIteration: + return delegate + return delegate + + +def find_delegate_in_fallback(proxy: Contract) -> Optional[Variable]: + """ + Searches a proxy's fallback function for a delegatecall, then extracts the Variable being passed in as the target. + Can return a newly created StateVariable if an `sload` from a hardcoded storage slot is found in assembly. + Should typically be called by get_proxy_implementation_var(proxy). + Args: + proxy: A Contract object (should have a fallback function). + + Returns: + (`Variable`) | None : The variable being passed as the destination argument in a delegatecall in the fallback. + """ + delegate: Optional[Variable] = None + fallback = proxy.fallback_function + for node in fallback.all_nodes(): + for ir in node.irs: + if isinstance(ir, LowLevelCall) and ir.function_name == "delegatecall": + delegate = ir.destination + if delegate is not None: + break + if ( + node.type == NodeType.ASSEMBLY + and isinstance(node.inline_asm, str) + and "delegatecall" in node.inline_asm + ): + delegate = extract_delegate_from_asm(proxy, node) + elif node.type == NodeType.EXPRESSION: + expression = node.expression + if isinstance(expression, AssignmentOperation): + expression = expression.expression_right + if ( + isinstance(expression, CallExpression) + and "delegatecall" in str(expression.called) + and len(expression.arguments) > 1 + ): + dest = expression.arguments[1] + if isinstance(dest, CallExpression) and "sload" in str(dest.called): + dest = dest.arguments[0] + if isinstance(dest, Identifier): + delegate = dest.value + break + if ( + isinstance(dest, Literal) and len(dest.value) == 66 + ): # 32 bytes = 64 chars + "0x" = 66 chars + # Storage slot is not declared as a constant, but rather is hardcoded in the assembly, + # so create a new StateVariable to represent it. + delegate = create_state_variable_from_slot(dest.value) + break + return delegate + + +def extract_delegate_from_asm(contract: Contract, node: Node) -> Optional[Variable]: + """ + Finds a Variable with a name matching the argument passed into a delegatecall, when all we have is an Assembly node + with a block of code as one long string. Usually only the case for solc versions < 0.6.0. + Can return a newly created StateVariable if an `sload` from a hardcoded storage slot is found in assembly. + Should typically be called by find_delegate_in_fallback(proxy). + Args: + contract: The parent Contract. + node: The Assembly Node (i.e., node.type == NodeType.ASSEMBLY) + + Returns: + (`Variable`) | None : The variable being passed as the destination argument in a delegatecall in the fallback. + """ + asm_split = str(node.inline_asm).split("\n") + asm = next(line for line in asm_split if "delegatecall" in line) + params = asm.split("call(")[1].split(", ") + dest = params[1] + if dest.endswith(")") and not dest.startswith("sload("): + dest = params[2] + if dest.startswith("sload("): + dest = dest.replace(")", "(").split("(")[1] + if dest.startswith("0x"): + return create_state_variable_from_slot(dest) + if dest.isnumeric(): + slot_idx = int(dest) + return next( + ( + v + for v in contract.state_variables_ordered + if SlitherReadStorage.get_variable_info(contract, v)[0] == slot_idx + ), + None, + ) + for v in node.function.variables_read_or_written: + if v.name == dest: + if isinstance(v, LocalVariable) and v.expression is not None: + e = v.expression + if isinstance(e, Identifier) and isinstance(e.value, StateVariable): + v = e.value + # Fall through, return constant storage slot + if isinstance(v, StateVariable) and v.is_constant: + return v + if "_fallback_asm" in dest or "_slot" in dest: + dest = dest.split("_")[0] + return find_delegate_from_name(contract, dest, node.function) + + +def find_delegate_from_name( + contract: Contract, dest: str, parent_func: Function +) -> Optional[Variable]: + """ + Searches for a variable with a given name, starting with StateVariables declared in the contract, followed by + LocalVariables in the parent function, either declared in the function body or as parameters in the signature. + Can return a newly created StateVariable if an `sload` from a hardcoded storage slot is found in assembly. + Args: + contract: The Contract object to search. + dest: The variable name to search for. + parent_func: The Function object to search. + + Returns: + (`Variable`) | None : The variable with the matching name, if found + """ + for sv in contract.state_variables: + if sv.name == dest: + return sv + for lv in parent_func.local_variables: + if lv.name == dest: + return lv + for pv in parent_func.parameters + parent_func.returns: + if pv.name == dest: + return pv + if parent_func.contains_assembly: + for node in parent_func.all_nodes(): + if node.type == NodeType.ASSEMBLY and isinstance(node.inline_asm, str): + asm = next( + ( + s + for s in node.inline_asm.split("\n") + if f"{dest}:=sload(" in s.replace(" ", "") + ), + None, + ) + if asm: + slot = asm.split("sload(")[1].split(")")[0] + if slot.startswith("0x"): + return create_state_variable_from_slot(slot, name=dest) + try: + slot_idx = int(slot) + return next( + ( + v + for v in contract.state_variables_ordered + if SlitherReadStorage.get_variable_info(contract, v)[0] == slot_idx + ), + None, + ) + except TypeError: + continue + return None + + +def create_state_variable_from_slot(slot: str, name: str = None) -> Optional[StateVariable]: + """ + Creates a new StateVariable object to wrap a hardcoded storage slot found in assembly. + Args: + slot: The storage slot hex string. + name: Optional name for the variable. The slot string is used if name is not provided. + + Returns: + A newly created constant StateVariable of type bytes32, with the slot as the variable's expression and name, + if slot matches the length and prefix of a bytes32. Otherwise, returns None. + """ + if len(slot) == 66 and slot.startswith("0x"): # 32 bytes = 64 chars + "0x" = 66 chars + # Storage slot is not declared as a constant, but rather is hardcoded in the assembly, + # so create a new StateVariable to represent it. + v = StateVariable() + v.is_constant = True + v.expression = Literal(slot, ElementaryType("bytes32")) + if name is not None: + v.name = name + else: + v.name = slot + v.type = ElementaryType("bytes32") + return v + # This should probably also handle hashed strings, but for now return None + return None diff --git a/tests/unit/core/test_data/fallback.sol b/tests/unit/core/test_data/fallback.sol new file mode 100644 index 000000000..cd7dc1812 --- /dev/null +++ b/tests/unit/core/test_data/fallback.sol @@ -0,0 +1,29 @@ +pragma solidity ^0.6.12; + +contract FakeFallback { + mapping(address => uint) public contributions; + address payable public owner; + + constructor() public { + owner = payable(msg.sender); + contributions[msg.sender] = 1000 * (1 ether); + } + + function fallback() public payable { + contributions[msg.sender] += msg.value; + } + + function receive() public payable { + contributions[msg.sender] += msg.value; + } +} + +contract Fallback is FakeFallback { + receive() external payable { + contributions[msg.sender] += msg.value; + } + + fallback() external payable { + contributions[msg.sender] += msg.value; + } +} diff --git a/tests/unit/core/test_fallback_receive.py b/tests/unit/core/test_fallback_receive.py new file mode 100644 index 000000000..505a9dd6f --- /dev/null +++ b/tests/unit/core/test_fallback_receive.py @@ -0,0 +1,20 @@ +from pathlib import Path +from solc_select import solc_select + +from slither import Slither +from slither.core.declarations.function import FunctionType + +TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" + + +def test_fallback_receive(): + solc_select.switch_global_version("0.6.12", always_install=True) + file = Path(TEST_DATA_DIR, "fallback.sol").as_posix() + slither = Slither(file) + fake_fallback = slither.get_contract_from_name("FakeFallback")[0] + real_fallback = slither.get_contract_from_name("Fallback")[0] + + assert fake_fallback.fallback_function is None + assert fake_fallback.receive_function is None + assert real_fallback.fallback_function.function_type == FunctionType.FALLBACK + assert real_fallback.receive_function.function_type == FunctionType.RECEIVE diff --git a/tests/unit/utils/test_data/upgradeability_util/TestUpgrades-0.5.0.sol b/tests/unit/utils/test_data/upgradeability_util/TestUpgrades-0.5.0.sol new file mode 100644 index 000000000..eaecfa6e9 --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/TestUpgrades-0.5.0.sol @@ -0,0 +1,6 @@ +pragma solidity ^0.5.0; + +import "./src/EIP1822Proxy.sol"; +import "./src/ZosProxy.sol"; +import "./src/MasterCopyProxy.sol"; +import "./src/SynthProxy.sol"; diff --git a/tests/unit/utils/test_data/upgradeability_util/TestUpgrades-0.8.2.sol b/tests/unit/utils/test_data/upgradeability_util/TestUpgrades-0.8.2.sol new file mode 100644 index 000000000..d3371d3c6 --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/TestUpgrades-0.8.2.sol @@ -0,0 +1,6 @@ +pragma solidity ^0.8.2; + +import "./src/ContractV1.sol"; +import "./src/ContractV2.sol"; +import "./src/InheritedStorageProxy.sol"; +import "./src/ERC1967Proxy.sol"; diff --git a/tests/unit/utils/test_data/upgradeability_util/src/Address.sol b/tests/unit/utils/test_data/upgradeability_util/src/Address.sol new file mode 100644 index 000000000..d440b259e --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/Address.sol @@ -0,0 +1,244 @@ +// SPDX-License-Identifier: MIT +// OpenZeppelin Contracts (last updated v4.8.0) (utils/Address.sol) + +pragma solidity ^0.8.1; + +/** + * @dev Collection of functions related to the address type + */ +library Address { + /** + * @dev Returns true if `account` is a contract. + * + * [IMPORTANT] + * ==== + * It is unsafe to assume that an address for which this function returns + * false is an externally-owned account (EOA) and not a contract. + * + * Among others, `isContract` will return false for the following + * types of addresses: + * + * - an externally-owned account + * - a contract in construction + * - an address where a contract will be created + * - an address where a contract lived, but was destroyed + * ==== + * + * [IMPORTANT] + * ==== + * You shouldn't rely on `isContract` to protect against flash loan attacks! + * + * Preventing calls from contracts is highly discouraged. It breaks composability, breaks support for smart wallets + * like Gnosis Safe, and does not provide security since it can be circumvented by calling from a contract + * constructor. + * ==== + */ + function isContract(address account) internal view returns (bool) { + // This method relies on extcodesize/address.code.length, which returns 0 + // for contracts in construction, since the code is only stored at the end + // of the constructor execution. + + return account.code.length > 0; + } + + /** + * @dev Replacement for Solidity's `transfer`: sends `amount` wei to + * `recipient`, forwarding all available gas and reverting on errors. + * + * https://eips.ethereum.org/EIPS/eip-1884[EIP1884] increases the gas cost + * of certain opcodes, possibly making contracts go over the 2300 gas limit + * imposed by `transfer`, making them unable to receive funds via + * `transfer`. {sendValue} removes this limitation. + * + * https://diligence.consensys.net/posts/2019/09/stop-using-soliditys-transfer-now/[Learn more]. + * + * IMPORTANT: because control is transferred to `recipient`, care must be + * taken to not create reentrancy vulnerabilities. Consider using + * {ReentrancyGuard} or the + * https://solidity.readthedocs.io/en/v0.5.11/security-considerations.html#use-the-checks-effects-interactions-pattern[checks-effects-interactions pattern]. + */ + function sendValue(address payable recipient, uint256 amount) internal { + require(address(this).balance >= amount, "Address: insufficient balance"); + + (bool success, ) = recipient.call{value: amount}(""); + require(success, "Address: unable to send value, recipient may have reverted"); + } + + /** + * @dev Performs a Solidity function call using a low level `call`. A + * plain `call` is an unsafe replacement for a function call: use this + * function instead. + * + * If `target` reverts with a revert reason, it is bubbled up by this + * function (like regular Solidity function calls). + * + * Returns the raw returned data. To convert to the expected return value, + * use https://solidity.readthedocs.io/en/latest/units-and-global-variables.html?highlight=abi.decode#abi-encoding-and-decoding-functions[`abi.decode`]. + * + * Requirements: + * + * - `target` must be a contract. + * - calling `target` with `data` must not revert. + * + * _Available since v3.1._ + */ + function functionCall(address target, bytes memory data) internal returns (bytes memory) { + return functionCallWithValue(target, data, 0, "Address: low-level call failed"); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-}[`functionCall`], but with + * `errorMessage` as a fallback revert reason when `target` reverts. + * + * _Available since v3.1._ + */ + function functionCall( + address target, + bytes memory data, + string memory errorMessage + ) internal returns (bytes memory) { + return functionCallWithValue(target, data, 0, errorMessage); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-}[`functionCall`], + * but also transferring `value` wei to `target`. + * + * Requirements: + * + * - the calling contract must have an ETH balance of at least `value`. + * - the called Solidity function must be `payable`. + * + * _Available since v3.1._ + */ + function functionCallWithValue( + address target, + bytes memory data, + uint256 value + ) internal returns (bytes memory) { + return functionCallWithValue(target, data, value, "Address: low-level call with value failed"); + } + + /** + * @dev Same as {xref-Address-functionCallWithValue-address-bytes-uint256-}[`functionCallWithValue`], but + * with `errorMessage` as a fallback revert reason when `target` reverts. + * + * _Available since v3.1._ + */ + function functionCallWithValue( + address target, + bytes memory data, + uint256 value, + string memory errorMessage + ) internal returns (bytes memory) { + require(address(this).balance >= value, "Address: insufficient balance for call"); + (bool success, bytes memory returndata) = target.call{value: value}(data); + return verifyCallResultFromTarget(target, success, returndata, errorMessage); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-}[`functionCall`], + * but performing a static call. + * + * _Available since v3.3._ + */ + function functionStaticCall(address target, bytes memory data) internal view returns (bytes memory) { + return functionStaticCall(target, data, "Address: low-level static call failed"); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-string-}[`functionCall`], + * but performing a static call. + * + * _Available since v3.3._ + */ + function functionStaticCall( + address target, + bytes memory data, + string memory errorMessage + ) internal view returns (bytes memory) { + (bool success, bytes memory returndata) = target.staticcall(data); + return verifyCallResultFromTarget(target, success, returndata, errorMessage); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-}[`functionCall`], + * but performing a delegate call. + * + * _Available since v3.4._ + */ + function functionDelegateCall(address target, bytes memory data) internal returns (bytes memory) { + return functionDelegateCall(target, data, "Address: low-level delegate call failed"); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-string-}[`functionCall`], + * but performing a delegate call. + * + * _Available since v3.4._ + */ + function functionDelegateCall( + address target, + bytes memory data, + string memory errorMessage + ) internal returns (bytes memory) { + (bool success, bytes memory returndata) = target.delegatecall(data); + return verifyCallResultFromTarget(target, success, returndata, errorMessage); + } + + /** + * @dev Tool to verify that a low level call to smart-contract was successful, and revert (either by bubbling + * the revert reason or using the provided one) in case of unsuccessful call or if target was not a contract. + * + * _Available since v4.8._ + */ + function verifyCallResultFromTarget( + address target, + bool success, + bytes memory returndata, + string memory errorMessage + ) internal view returns (bytes memory) { + if (success) { + if (returndata.length == 0) { + // only check isContract if the call was successful and the return data is empty + // otherwise we already know that it was a contract + require(isContract(target), "Address: call to non-contract"); + } + return returndata; + } else { + _revert(returndata, errorMessage); + } + } + + /** + * @dev Tool to verify that a low level call was successful, and revert if it wasn't, either by bubbling the + * revert reason or using the provided one. + * + * _Available since v4.3._ + */ + function verifyCallResult( + bool success, + bytes memory returndata, + string memory errorMessage + ) internal pure returns (bytes memory) { + if (success) { + return returndata; + } else { + _revert(returndata, errorMessage); + } + } + + function _revert(bytes memory returndata, string memory errorMessage) private pure { + // Look for revert reason and bubble it up if present + if (returndata.length > 0) { + // The easiest way to bubble the revert reason is using memory via assembly + /// @solidity memory-safe-assembly + assembly { + let returndata_size := mload(returndata) + revert(add(32, returndata), returndata_size) + } + } else { + revert(errorMessage); + } + } +} diff --git a/tests/unit/utils/test_data/upgradeability_util/src/ContractV1.sol b/tests/unit/utils/test_data/upgradeability_util/src/ContractV1.sol new file mode 100644 index 000000000..1e2c4b476 --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/ContractV1.sol @@ -0,0 +1,36 @@ +pragma solidity ^0.8.2; + +import "./ProxyStorage.sol"; + +contract ContractV1 is ProxyStorage { + uint private stateA = 0; + uint private stateB = 0; + uint constant CONST = 32; + bool bug = false; + + function f(uint x) public { + if (msg.sender == admin) { + stateA = x; + } + } + + function g(uint y) public { + if (checkA()) { + stateB = y - 10; + } + } + + function h() public { + if (checkB()) { + bug = true; + } + } + + function checkA() internal returns (bool) { + return stateA % CONST == 1; + } + + function checkB() internal returns (bool) { + return stateB == 62; + } +} diff --git a/tests/unit/utils/test_data/upgradeability_util/src/ContractV2.sol b/tests/unit/utils/test_data/upgradeability_util/src/ContractV2.sol new file mode 100644 index 000000000..9b102f3e9 --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/ContractV2.sol @@ -0,0 +1,41 @@ +pragma solidity ^0.8.2; + +import "./ProxyStorage.sol"; + +contract ContractV2 is ProxyStorage { + uint private stateA = 0; + uint private stateB = 0; + uint constant CONST = 32; + bool bug = false; + uint private stateC = 0; + + function f(uint x) public { + if (msg.sender == admin) { + stateA = x; + } + } + + function g(uint y) public { + if (checkA()) { + stateB = y - 10; + } + } + + function h() public { + if (checkB()) { + bug = true; + } + } + + function i() public { + stateC = stateC + 1; + } + + function checkA() internal returns (bool) { + return stateA % CONST == 1; + } + + function checkB() internal returns (bool) { + return stateB == 32; + } +} diff --git a/tests/unit/utils/test_data/upgradeability_util/src/EIP1822Proxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/EIP1822Proxy.sol new file mode 100644 index 000000000..3145eb17e --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/EIP1822Proxy.sol @@ -0,0 +1,47 @@ +pragma solidity ^0.5.0; + +contract EIP1822Proxy { + // Code position in storage is keccak256("PROXIABLE") = "0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7" + constructor(bytes memory constructData, address contractLogic) public { + // save the code address + assembly { // solium-disable-line + sstore(0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7, contractLogic) + } + (bool success, bytes memory _ ) = contractLogic.delegatecall(constructData); // solium-disable-line + require(success, "Construction failed"); + } + + function() external payable { + assembly { // solium-disable-line + let contractLogic := sload(0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7) + calldatacopy(0x0, 0x0, calldatasize) + let success := delegatecall(sub(gas, 10000), contractLogic, 0x0, calldatasize, 0, 0) + let retSz := returndatasize + returndatacopy(0, 0, retSz) + switch success + case 0 { + revert(0, retSz) + } + default { + return(0, retSz) + } + } + } +} + +contract EIP1822Proxiable { + // Code position in storage is keccak256("PROXIABLE") = "0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7" + + function updateCodeAddress(address newAddress) internal { + require( + bytes32(0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7) == EIP1822Proxiable(newAddress).proxiableUUID(), + "Not compatible" + ); + assembly { // solium-disable-line + sstore(0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7, newAddress) + } + } + function proxiableUUID() public pure returns (bytes32) { + return 0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7; + } +} \ No newline at end of file diff --git a/tests/unit/utils/test_data/upgradeability_util/src/ERC1967Proxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/ERC1967Proxy.sol new file mode 100644 index 000000000..f1496c27e --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/ERC1967Proxy.sol @@ -0,0 +1,15 @@ +pragma solidity ^0.8.0; + +import "./Proxy.sol"; +import "./ERC1967Upgrade.sol"; + +contract ERC1967Proxy is Proxy, ERC1967Upgrade { + + constructor(address _logic, bytes memory _data) payable { + _upgradeToAndCall(_logic, _data, false); + } + + function _implementation() internal view virtual override returns (address impl) { + return ERC1967Upgrade._getImplementation(); + } +} diff --git a/tests/unit/utils/test_data/upgradeability_util/src/ERC1967Upgrade.sol b/tests/unit/utils/test_data/upgradeability_util/src/ERC1967Upgrade.sol new file mode 100644 index 000000000..d089e94d9 --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/ERC1967Upgrade.sol @@ -0,0 +1,105 @@ +pragma solidity ^0.8.2; + +import "./Address.sol"; +import "./StorageSlot.sol"; + +interface IBeacon { + function implementation() external view returns (address); +} + +interface IERC1822Proxiable { + function proxiableUUID() external view returns (bytes32); +} + +abstract contract ERC1967Upgrade { + + bytes32 private constant _ROLLBACK_SLOT = 0x4910fdfa16fed3260ed0e7147f7cc6da11a60208b5b9406d12a635614ffd9143; + bytes32 internal constant _IMPLEMENTATION_SLOT = 0x360894a13ba1a3210667c828492db98dca3e2076cc3735a920a3ca505d382bbc; + bytes32 internal constant _ADMIN_SLOT = 0xb53127684a568b3173ae13b9f8a6016e243e63b6e8ee1178d6a717850b5d6103; + bytes32 internal constant _BEACON_SLOT = 0xa3f0ad74e5423aebfd80d3ef4346578335a9a72aeaee59ff6cb3582b35133d50; + + event Upgraded(address indexed implementation); + event AdminChanged(address previousAdmin, address newAdmin); + event BeaconUpgraded(address indexed beacon); + + function _getImplementation() internal view returns (address) { + return StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value; + } + + function _setImplementation(address newImplementation) private { + require(Address.isContract(newImplementation), "ERC1967: new implementation is not a contract"); + StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value = newImplementation; + } + + function _upgradeTo(address newImplementation) internal { + _setImplementation(newImplementation); + emit Upgraded(newImplementation); + } + + function _upgradeToAndCall( + address newImplementation, + bytes memory data, + bool forceCall + ) internal { + _upgradeTo(newImplementation); + if (data.length > 0 || forceCall) { + Address.functionDelegateCall(newImplementation, data); + } + } + + function _upgradeToAndCallUUPS( + address newImplementation, + bytes memory data, + bool forceCall + ) internal { + if (StorageSlot.getBooleanSlot(_ROLLBACK_SLOT).value) { + _setImplementation(newImplementation); + } else { + try IERC1822Proxiable(newImplementation).proxiableUUID() returns (bytes32 slot) { + require(slot == _IMPLEMENTATION_SLOT, "ERC1967Upgrade: unsupported proxiableUUID"); + } catch { + revert("ERC1967Upgrade: new implementation is not UUPS"); + } + _upgradeToAndCall(newImplementation, data, forceCall); + } + } + + function _getAdmin() internal view returns (address) { + return StorageSlot.getAddressSlot(_ADMIN_SLOT).value; + } + + function _setAdmin(address newAdmin) private { + require(newAdmin != address(0), "ERC1967: new admin is the zero address"); + StorageSlot.getAddressSlot(_ADMIN_SLOT).value = newAdmin; + } + + function _changeAdmin(address newAdmin) internal { + emit AdminChanged(_getAdmin(), newAdmin); + _setAdmin(newAdmin); + } + + function _getBeacon() internal view returns (address) { + return StorageSlot.getAddressSlot(_BEACON_SLOT).value; + } + + function _setBeacon(address newBeacon) private { + require(Address.isContract(newBeacon), "ERC1967: new beacon is not a contract"); + require( + Address.isContract(IBeacon(newBeacon).implementation()), + "ERC1967: beacon implementation is not a contract" + ); + StorageSlot.getAddressSlot(_BEACON_SLOT).value = newBeacon; + } + + function _upgradeBeaconToAndCall( + address newBeacon, + bytes memory data, + bool forceCall + ) internal { + _setBeacon(newBeacon); + emit BeaconUpgraded(newBeacon); + if (data.length > 0 || forceCall) { + Address.functionDelegateCall(IBeacon(newBeacon).implementation(), data); + } + } +} \ No newline at end of file diff --git a/tests/unit/utils/test_data/upgradeability_util/src/InheritedStorageProxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/InheritedStorageProxy.sol new file mode 100644 index 000000000..eddbfb0f1 --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/InheritedStorageProxy.sol @@ -0,0 +1,39 @@ +pragma solidity ^0.8.0; + +import "./Proxy.sol"; +import "./ProxyStorage.sol"; + +contract InheritedStorageProxy is Proxy, ProxyStorage { + constructor(address _implementation) { + admin = msg.sender; + implementation = _implementation; + } + + function getImplementation() external view returns (address) { + return _implementation(); + } + + function getAdmin() external view returns (address) { + return _admin(); + } + + function upgrade(address _newImplementation) external { + require(msg.sender == admin, "Only admin can upgrade"); + implementation = _newImplementation; + } + + function setAdmin(address _newAdmin) external { + require(msg.sender == admin, "Only current admin can change admin"); + admin = _newAdmin; + } + + function _implementation() internal view override returns (address) { + return implementation; + } + + function _admin() internal view returns (address) { + return admin; + } + + function _beforeFallback() internal override {} +} diff --git a/tests/unit/utils/test_data/upgradeability_util/src/MasterCopyProxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/MasterCopyProxy.sol new file mode 100644 index 000000000..d25a2a920 --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/MasterCopyProxy.sol @@ -0,0 +1,27 @@ +pragma solidity ^0.5.0; + +contract MasterCopyProxy { + address internal masterCopy; + + constructor(address _masterCopy) + public + { + require(_masterCopy != address(0), "Invalid master copy address provided"); + masterCopy = _masterCopy; + } + + /// @dev Fallback function forwards all transactions and returns all received return data. + function () + external + payable + { + // solium-disable-next-line security/no-inline-assembly + assembly { + calldatacopy(0, 0, calldatasize()) + let success := delegatecall(gas, sload(0), 0, calldatasize(), 0, 0) + returndatacopy(0, 0, returndatasize()) + if eq(success, 0) { revert(0, returndatasize()) } + return(0, returndatasize()) + } + } +} diff --git a/tests/unit/utils/test_data/upgradeability_util/src/Proxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/Proxy.sol new file mode 100644 index 000000000..445ddb170 --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/Proxy.sol @@ -0,0 +1,36 @@ +pragma solidity ^0.8.0; + +abstract contract Proxy { + + function _delegate(address implementation) internal virtual { + assembly { + calldatacopy(0, 0, calldatasize()) + let result := delegatecall(gas(), implementation, 0, calldatasize(), 0, 0) + returndatacopy(0, 0, returndatasize()) + switch result + case 0 { + revert(0, returndatasize()) + } + default { + return(0, returndatasize()) + } + } + } + + function _implementation() internal view virtual returns (address); + + function _fallback() internal virtual { + _beforeFallback(); + _delegate(_implementation()); + } + + fallback() external payable virtual { + _fallback(); + } + + receive() external payable virtual { + _fallback(); + } + + function _beforeFallback() internal virtual {} +} diff --git a/tests/unit/utils/test_data/upgradeability_util/src/ProxyStorage.sol b/tests/unit/utils/test_data/upgradeability_util/src/ProxyStorage.sol new file mode 100644 index 000000000..d591040bd --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/ProxyStorage.sol @@ -0,0 +1,6 @@ +pragma solidity ^0.8.0; + +contract ProxyStorage { + address internal admin; + address internal implementation; +} diff --git a/tests/unit/utils/test_data/upgradeability_util/src/StorageSlot.sol b/tests/unit/utils/test_data/upgradeability_util/src/StorageSlot.sol new file mode 100644 index 000000000..6ab8f5dc6 --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/StorageSlot.sol @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// OpenZeppelin Contracts (last updated v4.7.0) (utils/StorageSlot.sol) + +pragma solidity ^0.8.0; + +/** + * @dev Library for reading and writing primitive types to specific storage slots. + * + * Storage slots are often used to avoid storage conflict when dealing with upgradeable contracts. + * This library helps with reading and writing to such slots without the need for inline assembly. + * + * The functions in this library return Slot structs that contain a `value` member that can be used to read or write. + * + * Example usage to set ERC1967 implementation slot: + * ``` + * contract ERC1967 { + * bytes32 internal constant _IMPLEMENTATION_SLOT = 0x360894a13ba1a3210667c828492db98dca3e2076cc3735a920a3ca505d382bbc; + * + * function _getImplementation() internal view returns (address) { + * return StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value; + * } + * + * function _setImplementation(address newImplementation) internal { + * require(Address.isContract(newImplementation), "ERC1967: new implementation is not a contract"); + * StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value = newImplementation; + * } + * } + * ``` + * + * _Available since v4.1 for `address`, `bool`, `bytes32`, and `uint256`._ + */ +library StorageSlot { + struct AddressSlot { + address value; + } + + struct BooleanSlot { + bool value; + } + + struct Bytes32Slot { + bytes32 value; + } + + struct Uint256Slot { + uint256 value; + } + + /** + * @dev Returns an `AddressSlot` with member `value` located at `slot`. + */ + function getAddressSlot(bytes32 slot) internal pure returns (AddressSlot storage r) { + /// @solidity memory-safe-assembly + assembly { + r.slot := slot + } + } + + /** + * @dev Returns an `BooleanSlot` with member `value` located at `slot`. + */ + function getBooleanSlot(bytes32 slot) internal pure returns (BooleanSlot storage r) { + /// @solidity memory-safe-assembly + assembly { + r.slot := slot + } + } + + /** + * @dev Returns an `Bytes32Slot` with member `value` located at `slot`. + */ + function getBytes32Slot(bytes32 slot) internal pure returns (Bytes32Slot storage r) { + /// @solidity memory-safe-assembly + assembly { + r.slot := slot + } + } + + /** + * @dev Returns an `Uint256Slot` with member `value` located at `slot`. + */ + function getUint256Slot(bytes32 slot) internal pure returns (Uint256Slot storage r) { + /// @solidity memory-safe-assembly + assembly { + r.slot := slot + } + } +} diff --git a/tests/unit/utils/test_data/upgradeability_util/src/SynthProxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/SynthProxy.sol new file mode 100644 index 000000000..9b3a6bdef --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/SynthProxy.sol @@ -0,0 +1,58 @@ +pragma solidity ^0.5.0; + +contract Owned { + address public owner; + + constructor(address _owner) public { + require(_owner != address(0), "Owner address cannot be 0"); + owner = _owner; + } + + modifier onlyOwner { + require(msg.sender == owner, "Only the contract owner may perform this action"); + _; + } +} + +contract Proxyable is Owned { + /* The proxy this contract exists behind. */ + SynthProxy public proxy; + + constructor(address payable _proxy) internal { + // This contract is abstract, and thus cannot be instantiated directly + require(owner != address(0), "Owner must be set"); + + proxy = SynthProxy(_proxy); + } + + function setProxy(address payable _proxy) external onlyOwner { + proxy = SynthProxy(_proxy); + } +} + + +contract SynthProxy is Owned { + Proxyable public target; + + constructor(address _owner) public Owned(_owner) {} + + function setTarget(Proxyable _target) external onlyOwner { + target = _target; + } + + // solhint-disable no-complex-fallback + function() external payable { + assembly { + calldatacopy(0, 0, calldatasize) + + /* We must explicitly forward ether to the underlying contract as well. */ + let result := delegatecall(gas, sload(target_slot), 0, calldatasize, 0, 0) + returndatacopy(0, 0, returndatasize) + + if iszero(result) { + revert(0, returndatasize) + } + return(0, returndatasize) + } + } +} diff --git a/tests/unit/utils/test_data/upgradeability_util/src/ZosProxy.sol b/tests/unit/utils/test_data/upgradeability_util/src/ZosProxy.sol new file mode 100644 index 000000000..db44f4c98 --- /dev/null +++ b/tests/unit/utils/test_data/upgradeability_util/src/ZosProxy.sol @@ -0,0 +1,67 @@ +pragma solidity ^0.5.0; + +contract ZosProxy { + function () payable external { + _fallback(); + } + + function _implementation() internal view returns (address); + + function _delegate(address implementation) internal { + assembly { + calldatacopy(0, 0, calldatasize) + let result := delegatecall(gas, implementation, 0, calldatasize, 0, 0) + returndatacopy(0, 0, returndatasize) + switch result + case 0 { revert(0, returndatasize) } + default { return(0, returndatasize) } + } + } + + function _willFallback() internal { + } + + function _fallback() internal { + _willFallback(); + _delegate(_implementation()); + } +} + +library AddressUtils { + function isContract(address addr) internal view returns (bool) { + uint256 size; + assembly { size := extcodesize(addr) } + return size > 0; + } +} + +contract UpgradeabilityProxy is ZosProxy { + event Upgraded(address indexed implementation); + + bytes32 private constant IMPLEMENTATION_SLOT = 0x7050c9e0f4ca769c69bd3a8ef740bc37934f8e2c036e5a723fd8ee048ed3f8c3; + + constructor(address _implementation) public payable { + assert(IMPLEMENTATION_SLOT == keccak256("org.zeppelinos.proxy.implementation")); + _setImplementation(_implementation); + } + + function _implementation() internal view returns (address impl) { + bytes32 slot = IMPLEMENTATION_SLOT; + assembly { + impl := sload(slot) + } + } + + function _upgradeTo(address newImplementation) internal { + _setImplementation(newImplementation); + emit Upgraded(newImplementation); + } + + function _setImplementation(address newImplementation) private { + require(AddressUtils.isContract(newImplementation), "Cannot set a proxy implementation to a non-contract address"); + bytes32 slot = IMPLEMENTATION_SLOT; + assembly { + sstore(slot, newImplementation) + } + } +} diff --git a/tests/unit/utils/test_upgradeability_util.py b/tests/unit/utils/test_upgradeability_util.py new file mode 100644 index 000000000..7d6fb82da --- /dev/null +++ b/tests/unit/utils/test_upgradeability_util.py @@ -0,0 +1,85 @@ +import os +from pathlib import Path + +from solc_select import solc_select + +from slither import Slither +from slither.core.expressions import Literal +from slither.utils.upgradeability import ( + compare, + get_proxy_implementation_var, + get_proxy_implementation_slot, +) + +SLITHER_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" / "upgradeability_util" + + +# pylint: disable=too-many-locals +def test_upgrades_compare() -> None: + solc_select.switch_global_version("0.8.2", always_install=True) + + sl = Slither(os.path.join(TEST_DATA_DIR, "TestUpgrades-0.8.2.sol")) + v1 = sl.get_contract_from_name("ContractV1")[0] + v2 = sl.get_contract_from_name("ContractV2")[0] + missing_vars, new_vars, tainted_vars, new_funcs, modified_funcs, tainted_funcs = compare(v1, v2) + assert len(missing_vars) == 0 + assert new_vars == [v2.get_state_variable_from_name("stateC")] + assert tainted_vars == [ + v2.get_state_variable_from_name("stateB"), + v2.get_state_variable_from_name("bug"), + ] + assert new_funcs == [v2.get_function_from_signature("i()")] + assert modified_funcs == [v2.get_function_from_signature("checkB()")] + assert tainted_funcs == [ + v2.get_function_from_signature("g(uint256)"), + v2.get_function_from_signature("h()"), + ] + + +def test_upgrades_implementation_var() -> None: + solc_select.switch_global_version("0.8.2", always_install=True) + sl = Slither(os.path.join(TEST_DATA_DIR, "TestUpgrades-0.8.2.sol")) + + erc_1967_proxy = sl.get_contract_from_name("ERC1967Proxy")[0] + storage_proxy = sl.get_contract_from_name("InheritedStorageProxy")[0] + + target = get_proxy_implementation_var(erc_1967_proxy) + slot = get_proxy_implementation_slot(erc_1967_proxy) + assert target == erc_1967_proxy.get_state_variable_from_name("_IMPLEMENTATION_SLOT") + assert slot.slot == 0x360894A13BA1A3210667C828492DB98DCA3E2076CC3735A920A3CA505D382BBC + target = get_proxy_implementation_var(storage_proxy) + slot = get_proxy_implementation_slot(storage_proxy) + assert target == storage_proxy.get_state_variable_from_name("implementation") + assert slot.slot == 1 + + solc_select.switch_global_version("0.5.0", always_install=True) + sl = Slither(os.path.join(TEST_DATA_DIR, "TestUpgrades-0.5.0.sol")) + + eip_1822_proxy = sl.get_contract_from_name("EIP1822Proxy")[0] + # zos_proxy = sl.get_contract_from_name("ZosProxy")[0] + master_copy_proxy = sl.get_contract_from_name("MasterCopyProxy")[0] + synth_proxy = sl.get_contract_from_name("SynthProxy")[0] + + target = get_proxy_implementation_var(eip_1822_proxy) + slot = get_proxy_implementation_slot(eip_1822_proxy) + assert target not in eip_1822_proxy.state_variables_ordered + assert target.name == "contractLogic" and isinstance(target.expression, Literal) + assert ( + target.expression.value + == "0xc5f16f0fcc639fa48a6947836d9850f504798523bf8c9a3a87d5876cf622bcf7" + ) + assert slot.slot == 0xC5F16F0FCC639FA48A6947836D9850F504798523BF8C9A3A87D5876CF622BCF7 + # # The util fails with this proxy due to how Slither parses assembly w/ Solidity versions < 0.6.0 (see issue #1775) + # target = get_proxy_implementation_var(zos_proxy) + # slot = get_proxy_implementation_slot(zos_proxy) + # assert target == zos_proxy.get_state_variable_from_name("IMPLEMENTATION_SLOT") + # assert slot.slot == 0x7050C9E0F4CA769C69BD3A8EF740BC37934F8E2C036E5A723FD8EE048ED3F8C3 + target = get_proxy_implementation_var(master_copy_proxy) + slot = get_proxy_implementation_slot(master_copy_proxy) + assert target == master_copy_proxy.get_state_variable_from_name("masterCopy") + assert slot.slot == 0 + target = get_proxy_implementation_var(synth_proxy) + slot = get_proxy_implementation_slot(synth_proxy) + assert target == synth_proxy.get_state_variable_from_name("target") + assert slot.slot == 1