Merge pull request #2323 from dokzai/issue-1644

Create a variable API that filters out constants and immutables
pull/2146/head
alpharush 9 months ago committed by GitHub
commit 9c868e7400
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 5
      slither/core/compilation_unit.py
  2. 27
      slither/core/declarations/contract.py
  3. 7
      slither/core/variables/variable.py
  4. 4
      slither/detectors/variables/unchanged_state_variables.py
  5. 7
      slither/printers/summary/variable_order.py
  6. 2
      slither/tools/read_storage/read_storage.py
  7. 4
      slither/tools/upgradeability/checks/variable_initialization.py
  8. 24
      slither/tools/upgradeability/checks/variables_order.py
  9. 2
      slither/utils/encoding.py
  10. 20
      slither/utils/upgradeability.py

@ -302,10 +302,7 @@ class SlitherCompilationUnit(Context):
slot = 0 slot = 0
offset = 0 offset = 0
for var in contract.state_variables_ordered: for var in contract.stored_state_variables_ordered:
if var.is_constant or var.is_immutable:
continue
assert var.type assert var.type
size, new_slot = var.type.storage_size size, new_slot = var.type.storage_size

@ -436,6 +436,33 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
""" """
return list(self._variables.values()) return list(self._variables.values())
@property
def stored_state_variables(self) -> List["StateVariable"]:
"""
Returns state variables with storage locations, excluding private variables from inherited contracts.
Use stored_state_variables_ordered to access variables with storage locations in their declaration order.
This implementation filters out state variables if they are constant or immutable. It will be
updated to accommodate any new non-storage keywords that might replace 'constant' and 'immutable' in the future.
Returns:
List[StateVariable]: A list of state variables with storage locations.
"""
return [variable for variable in self.state_variables if variable.is_stored]
@property
def stored_state_variables_ordered(self) -> List["StateVariable"]:
"""
list(StateVariable): List of the state variables with storage locations by order of declaration.
This implementation filters out state variables if they are constant or immutable. It will be
updated to accommodate any new non-storage keywords that might replace 'constant' and 'immutable' in the future.
Returns:
List[StateVariable]: A list of state variables with storage locations ordered by declaration.
"""
return [variable for variable in self.state_variables_ordered if variable.is_stored]
@property @property
def state_variables_entry_points(self) -> List["StateVariable"]: def state_variables_entry_points(self) -> List["StateVariable"]:
""" """

@ -93,6 +93,13 @@ class Variable(SourceMapping):
def is_constant(self, is_cst: bool) -> None: def is_constant(self, is_cst: bool) -> None:
self._is_constant = is_cst self._is_constant = is_cst
@property
def is_stored(self) -> bool:
"""
Checks if a variable is stored, based on it not being constant or immutable. Future updates may adjust for new non-storage keywords.
"""
return not self._is_constant and not self._is_immutable
@property @property
def is_reentrant(self) -> bool: def is_reentrant(self) -> bool:
return self._is_reentrant return self._is_reentrant

@ -25,7 +25,7 @@ def _is_valid_type(v: StateVariable) -> bool:
def _valid_candidate(v: StateVariable) -> bool: def _valid_candidate(v: StateVariable) -> bool:
return _is_valid_type(v) and not (v.is_constant or v.is_immutable) return _is_valid_type(v)
def _is_constant_var(v: Variable) -> bool: def _is_constant_var(v: Variable) -> bool:
@ -92,7 +92,7 @@ class UnchangedStateVariables:
variables = [] variables = []
functions = [] functions = []
variables.append(c.state_variables) variables.append(c.stored_state_variables)
functions.append(c.all_functions_called) functions.append(c.all_functions_called)
valid_candidates: Set[StateVariable] = { valid_candidates: Set[StateVariable] = {

@ -28,10 +28,9 @@ class VariableOrder(AbstractPrinter):
for contract in self.slither.contracts_derived: for contract in self.slither.contracts_derived:
txt += f"\n{contract.name}:\n" txt += f"\n{contract.name}:\n"
table = MyPrettyTable(["Name", "Type", "Slot", "Offset"]) table = MyPrettyTable(["Name", "Type", "Slot", "Offset"])
for variable in contract.state_variables_ordered: for variable in contract.stored_state_variables_ordered:
if not variable.is_constant and not variable.is_immutable: slot, offset = contract.compilation_unit.storage_layout_of(contract, variable)
slot, offset = contract.compilation_unit.storage_layout_of(contract, variable) table.add_row([variable.canonical_name, str(variable.type), slot, offset])
table.add_row([variable.canonical_name, str(variable.type), slot, offset])
all_tables.append((contract.name, table)) all_tables.append((contract.name, table))
txt += str(table) + "\n" txt += str(table) + "\n"

@ -398,7 +398,7 @@ class SlitherReadStorage:
for contract in self.contracts: for contract in self.contracts:
for var in contract.state_variables_ordered: for var in contract.state_variables_ordered:
if func(var): if func(var):
if not var.is_constant and not var.is_immutable: if var.is_stored:
self._target_variables.append((contract, var)) self._target_variables.append((contract, var))
elif ( elif (
self.unstructured self.unstructured

@ -43,8 +43,8 @@ Using initialize functions to write initial values in state variables.
def _check(self) -> List[Output]: def _check(self) -> List[Output]:
results = [] results = []
for s in self.contract.state_variables_ordered: for s in self.contract.stored_state_variables_ordered:
if s.initialized and not (s.is_constant or s.is_immutable): if s.initialized:
info: CHECK_INFO = [s, " is a state variable with an initial value.\n"] info: CHECK_INFO = [s, " is a state variable with an initial value.\n"]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)

@ -115,16 +115,8 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
def _check(self) -> List[Output]: def _check(self) -> List[Output]:
contract1 = self._contract1() contract1 = self._contract1()
contract2 = self._contract2() contract2 = self._contract2()
order1 = [ order1 = contract1.stored_state_variables_ordered
variable order2 = contract2.stored_state_variables_ordered
for variable in contract1.state_variables_ordered
if not (variable.is_constant or variable.is_immutable)
]
order2 = [
variable
for variable in contract2.state_variables_ordered
if not (variable.is_constant or variable.is_immutable)
]
results: List[Output] = [] results: List[Output] = []
for idx, _ in enumerate(order1): for idx, _ in enumerate(order1):
@ -244,16 +236,8 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
def _check(self) -> List[Output]: def _check(self) -> List[Output]:
contract1 = self._contract1() contract1 = self._contract1()
contract2 = self._contract2() contract2 = self._contract2()
order1 = [ order1 = contract1.stored_state_variables_ordered
variable order2 = contract2.stored_state_variables_ordered
for variable in contract1.state_variables_ordered
if not (variable.is_constant or variable.is_immutable)
]
order2 = [
variable
for variable in contract2.state_variables_ordered
if not (variable.is_constant or variable.is_immutable)
]
results = [] results = []

@ -71,7 +71,7 @@ def encode_var_for_compare(var: Union[variables.Variable, SolidityVariable]) ->
if isinstance(var, variables.LocalVariable): if isinstance(var, variables.LocalVariable):
return f"local_solc_variable({ntype(var.type)},{var.location})" return f"local_solc_variable({ntype(var.type)},{var.location})"
if isinstance(var, variables.StateVariable): if isinstance(var, variables.StateVariable):
if not (var.is_constant or var.is_immutable): if var.is_stored:
try: try:
slot, _ = var.contract.compilation_unit.storage_layout_of(var.contract, var) slot, _ = var.contract.compilation_unit.storage_layout_of(var.contract, var)
except KeyError: except KeyError:

@ -81,12 +81,8 @@ def compare(
tainted-contracts: list[TaintedExternalContract] tainted-contracts: list[TaintedExternalContract]
""" """
order_vars1 = [ order_vars1 = v1.stored_state_variables_ordered
v for v in v1.state_variables_ordered if not v.is_constant and not v.is_immutable order_vars2 = v2.stored_state_variables_ordered
]
order_vars2 = [
v for v in v2.state_variables_ordered if not v.is_constant and not v.is_immutable
]
func_sigs1 = [function.solidity_signature for function in v1.functions] func_sigs1 = [function.solidity_signature for function in v1.functions]
func_sigs2 = [function.solidity_signature for function in v2.functions] func_sigs2 = [function.solidity_signature for function in v2.functions]
@ -206,7 +202,7 @@ def tainted_external_contracts(funcs: List[Function]) -> List[TaintedExternalCon
elif ( elif (
isinstance(target, StateVariable) isinstance(target, StateVariable)
and target not in (v for v in tainted_contracts[contract.name].tainted_variables) and target not in (v for v in tainted_contracts[contract.name].tainted_variables)
and not (target.is_constant or target.is_immutable) and target.is_stored
): ):
# Found a new high-level call to a public state variable getter # Found a new high-level call to a public state variable getter
tainted_contracts[contract.name].add_tainted_variable(target) tainted_contracts[contract.name].add_tainted_variable(target)
@ -304,12 +300,8 @@ def get_missing_vars(v1: Contract, v2: Contract) -> List[StateVariable]:
List of StateVariables from v1 missing in v2 List of StateVariables from v1 missing in v2
""" """
results = [] results = []
order_vars1 = [ order_vars1 = v1.stored_state_variables_ordered
v for v in v1.state_variables_ordered if not v.is_constant and not v.is_immutable order_vars2 = v2.stored_state_variables_ordered
]
order_vars2 = [
v for v in v2.state_variables_ordered if not v.is_constant and not v.is_immutable
]
if len(order_vars2) < len(order_vars1): if len(order_vars2) < len(order_vars1):
for variable in order_vars1: for variable in order_vars1:
if variable.name not in [v.name for v in order_vars2]: if variable.name not in [v.name for v in order_vars2]:
@ -366,7 +358,7 @@ def get_proxy_implementation_slot(proxy: Contract) -> Optional[SlotInfo]:
delegate = get_proxy_implementation_var(proxy) delegate = get_proxy_implementation_var(proxy)
if isinstance(delegate, StateVariable): if isinstance(delegate, StateVariable):
if not delegate.is_constant and not delegate.is_immutable: if delegate.is_stored:
srs = SlitherReadStorage([proxy], 20) srs = SlitherReadStorage([proxy], 20)
return srs.get_storage_slot(delegate, proxy) return srs.get_storage_slot(delegate, proxy)
if delegate.is_constant and delegate.type.name == "bytes32": if delegate.is_constant and delegate.type.name == "bytes32":

Loading…
Cancel
Save