Merge pull request #2588 from crytic/dev-support-transient

Improve transient storage support
pull/2238/merge
Josselin Feist 3 weeks ago committed by GitHub
commit 9e89bbbe69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 60
      slither/core/compilation_unit.py
  2. 46
      slither/core/declarations/contract.py
  3. 16
      slither/core/variables/state_variable.py
  4. 7
      slither/core/variables/variable.py
  5. 2
      slither/detectors/variables/unchanged_state_variables.py
  6. 13
      slither/printers/summary/variable_order.py
  7. 4
      slither/solc_parsing/declarations/contract.py
  8. 2
      slither/tools/upgradeability/checks/variable_initialization.py
  9. 93
      slither/tools/upgradeability/checks/variables_order.py
  10. 8
      slither/utils/upgradeability.py
  11. 2
      slither/vyper_parsing/declarations/contract.py

@ -73,7 +73,8 @@ class SlitherCompilationUnit(Context):
# Memoize
self._all_state_variables: Optional[Set[StateVariable]] = None
self._storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {}
self._persistent_storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {}
self._transient_storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {}
self._contract_with_missing_inheritance: Set[Contract] = set()
@ -297,33 +298,52 @@ class SlitherCompilationUnit(Context):
def compute_storage_layout(self) -> None:
assert self.is_solidity
for contract in self.contracts_derived:
self._storage_layouts[contract.name] = {}
slot = 0
offset = 0
for var in contract.stored_state_variables_ordered:
assert var.type
size, new_slot = var.type.storage_size
if new_slot:
if offset > 0:
slot += 1
offset = 0
elif size + offset > 32:
self._compute_storage_layout(contract.name, contract.storage_variables_ordered, False)
self._compute_storage_layout(contract.name, contract.transient_variables_ordered, True)
def _compute_storage_layout(
self, contract_name: str, state_variables_ordered: List[StateVariable], is_transient: bool
):
if is_transient:
self._transient_storage_layouts[contract_name] = {}
else:
self._persistent_storage_layouts[contract_name] = {}
slot = 0
offset = 0
for var in state_variables_ordered:
assert var.type
size, new_slot = var.type.storage_size
if new_slot:
if offset > 0:
slot += 1
offset = 0
elif size + offset > 32:
slot += 1
offset = 0
self._storage_layouts[contract.name][var.canonical_name] = (
if is_transient:
self._transient_storage_layouts[contract_name][var.canonical_name] = (
slot,
offset,
)
if new_slot:
slot += math.ceil(size / 32)
else:
offset += size
else:
self._persistent_storage_layouts[contract_name][var.canonical_name] = (
slot,
offset,
)
if new_slot:
slot += math.ceil(size / 32)
else:
offset += size
def storage_layout_of(self, contract: Contract, var: StateVariable) -> Tuple[int, int]:
return self._storage_layouts[contract.name][var.canonical_name]
if var.is_stored:
return self._persistent_storage_layouts[contract.name][var.canonical_name]
return self._transient_storage_layouts[contract.name][var.canonical_name]
# endregion

@ -440,55 +440,43 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
def state_variables(self) -> List["StateVariable"]:
"""
Returns all the accessible variables (do not include private variable from inherited contract).
Use state_variables_ordered for all the variables following the storage order
Use stored_state_variables_ordered for all the storage variables following the storage order
Use transient_state_variables_ordered for all the transient variables following the storage order
list(StateVariable): List of the state variables.
"""
return list(self._variables.values())
@property
def stored_state_variables(self) -> List["StateVariable"]:
def state_variables_entry_points(self) -> List["StateVariable"]:
"""
Returns state variables with storage locations, excluding private variables from inherited contracts.
Use stored_state_variables_ordered to access variables with storage locations in their declaration order.
This implementation filters out state variables if they are constant or immutable. It will be
updated to accommodate any new non-storage keywords that might replace 'constant' and 'immutable' in the future.
Returns:
List[StateVariable]: A list of state variables with storage locations.
list(StateVariable): List of the state variables that are public.
"""
return [variable for variable in self.state_variables if variable.is_stored]
return [var for var in self._variables.values() if var.visibility == "public"]
@property
def stored_state_variables_ordered(self) -> List["StateVariable"]:
def state_variables_ordered(self) -> List["StateVariable"]:
"""
list(StateVariable): List of the state variables with storage locations by order of declaration.
This implementation filters out state variables if they are constant or immutable. It will be
updated to accommodate any new non-storage keywords that might replace 'constant' and 'immutable' in the future.
Returns:
List[StateVariable]: A list of state variables with storage locations ordered by declaration.
list(StateVariable): List of the state variables by order of declaration.
"""
return [variable for variable in self.state_variables_ordered if variable.is_stored]
return self._variables_ordered
def add_state_variables_ordered(self, new_vars: List["StateVariable"]) -> None:
self._variables_ordered += new_vars
@property
def state_variables_entry_points(self) -> List["StateVariable"]:
def storage_variables_ordered(self) -> List["StateVariable"]:
"""
list(StateVariable): List of the state variables that are public.
list(StateVariable): List of the state variables in storage location by order of declaration.
"""
return [var for var in self._variables.values() if var.visibility == "public"]
return [v for v in self._variables_ordered if v.is_stored]
@property
def state_variables_ordered(self) -> List["StateVariable"]:
def transient_variables_ordered(self) -> List["StateVariable"]:
"""
list(StateVariable): List of the state variables by order of declaration.
list(StateVariable): List of the state variables in transient location by order of declaration.
"""
return list(self._variables_ordered)
def add_variables_ordered(self, new_vars: List["StateVariable"]) -> None:
self._variables_ordered += new_vars
return [v for v in self._variables_ordered if v.is_transient]
@property
def state_variables_inherited(self) -> List["StateVariable"]:

@ -35,6 +35,22 @@ class StateVariable(ContractLevel, Variable):
"""
return self._location
@property
def is_stored(self) -> bool:
"""
Checks if the state variable is stored, based on it not being constant or immutable or transient.
"""
return (
not self._is_constant and not self._is_immutable and not self._location == "transient"
)
@property
def is_transient(self) -> bool:
"""
Checks if the state variable is transient. A transient variable can not be constant or immutable.
"""
return self._location == "transient"
# endregion
###################################################################################
###################################################################################

@ -93,13 +93,6 @@ class Variable(SourceMapping):
def is_constant(self, is_cst: bool) -> None:
self._is_constant = is_cst
@property
def is_stored(self) -> bool:
"""
Checks if a variable is stored, based on it not being constant or immutable. Future updates may adjust for new non-storage keywords.
"""
return not self._is_constant and not self._is_immutable
@property
def is_reentrant(self) -> bool:
return self._is_reentrant

@ -92,7 +92,7 @@ class UnchangedStateVariables:
variables = []
functions = []
variables.append(c.stored_state_variables)
variables.append(c.storage_variables_ordered)
functions.append(c.all_functions_called)
valid_candidates: Set[StateVariable] = {

@ -27,10 +27,17 @@ class VariableOrder(AbstractPrinter):
for contract in self.slither.contracts_derived:
txt += f"\n{contract.name}:\n"
table = MyPrettyTable(["Name", "Type", "Slot", "Offset"])
for variable in contract.stored_state_variables_ordered:
table = MyPrettyTable(["Name", "Type", "Slot", "Offset", "State"])
for variable in contract.storage_variables_ordered:
slot, offset = contract.compilation_unit.storage_layout_of(contract, variable)
table.add_row([variable.canonical_name, str(variable.type), slot, offset])
table.add_row(
[variable.canonical_name, str(variable.type), slot, offset, "Storage"]
)
for variable in contract.transient_variables_ordered:
slot, offset = contract.compilation_unit.storage_layout_of(contract, variable)
table.add_row(
[variable.canonical_name, str(variable.type), slot, offset, "Transient"]
)
all_tables.append((contract.name, table))
txt += str(table) + "\n"

@ -350,7 +350,7 @@ class ContractSolc(CallerContextExpression):
if v.visibility != "private"
}
)
self._contract.add_variables_ordered(
self._contract.add_state_variables_ordered(
[
var
for var in father.state_variables_ordered
@ -370,7 +370,7 @@ class ContractSolc(CallerContextExpression):
if var_parser.reference_id is not None:
self._contract.state_variables_by_ref_id[var_parser.reference_id] = var
self._contract.variables_as_dict[var.name] = var
self._contract.add_variables_ordered([var])
self._contract.add_state_variables_ordered([var])
def _parse_modifier(self, modifier_data: Dict) -> None:
modif = Modifier(self._contract.compilation_unit)

@ -43,7 +43,7 @@ Using initialize functions to write initial values in state variables.
def _check(self) -> List[Output]:
results = []
for s in self.contract.stored_state_variables_ordered:
for s in self.contract.storage_variables_ordered:
if s.initialized:
info: CHECK_INFO = [s, " is a state variable with an initial value.\n"]
json = self.generate_result(info)

@ -115,29 +115,43 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
def _check(self) -> List[Output]:
contract1 = self._contract1()
contract2 = self._contract2()
order1 = contract1.stored_state_variables_ordered
order2 = contract2.stored_state_variables_ordered
results: List[Output] = []
for idx, _ in enumerate(order1):
if len(order2) <= idx:
# Handle by MissingVariable
return results
variable1 = order1[idx]
variable2 = order2[idx]
if (variable1.name != variable2.name) or (variable1.type != variable2.type):
info: CHECK_INFO = [
"Different variables between ",
contract1,
" and ",
contract2,
"\n",
]
info += ["\t ", variable1, "\n"]
info += ["\t ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)
def _check_internal(
contract1: Contract, contract2: Contract, results: List[Output], is_transient: bool
):
if is_transient:
order1 = contract1.transient_variables_ordered
order2 = contract2.transient_variables_ordered
else:
order1 = contract1.storage_variables_ordered
order2 = contract2.storage_variables_ordered
for idx, _ in enumerate(order1):
if len(order2) <= idx:
# Handle by MissingVariable
return
variable1 = order1[idx]
variable2 = order2[idx]
if (variable1.name != variable2.name) or (variable1.type != variable2.type):
info: CHECK_INFO = [
"Different variables between ",
contract1,
" and ",
contract2,
"\n",
]
info += ["\t ", variable1, "\n"]
info += ["\t ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)
# Checking state variables with storage location
_check_internal(contract1, contract2, results, False)
# Checking state variables with transient location
_check_internal(contract1, contract2, results, True)
return results
@ -236,22 +250,35 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
def _check(self) -> List[Output]:
contract1 = self._contract1()
contract2 = self._contract2()
order1 = contract1.stored_state_variables_ordered
order2 = contract2.stored_state_variables_ordered
results = []
results: List[Output] = []
if len(order2) <= len(order1):
return []
def _check_internal(
contract1: Contract, contract2: Contract, results: List[Output], is_transient: bool
):
if is_transient:
order1 = contract1.transient_variables_ordered
order2 = contract2.transient_variables_ordered
else:
order1 = contract1.storage_variables_ordered
order2 = contract2.storage_variables_ordered
idx = len(order1)
if len(order2) <= len(order1):
return
while idx < len(order2):
variable2 = order2[idx]
info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)
idx = idx + 1
idx = len(order1)
while idx < len(order2):
variable2 = order2[idx]
info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)
idx = idx + 1
# Checking state variables with storage location
_check_internal(contract1, contract2, results, False)
# Checking state variables with transient location
_check_internal(contract1, contract2, results, True)
return results

@ -80,8 +80,8 @@ def compare(
tainted-contracts: list[TaintedExternalContract]
"""
order_vars1 = v1.stored_state_variables_ordered
order_vars2 = v2.stored_state_variables_ordered
order_vars1 = v1.storage_variables_ordered + v1.transient_variables_ordered
order_vars2 = v2.storage_variables_ordered + v2.transient_variables_ordered
func_sigs1 = [function.solidity_signature for function in v1.functions]
func_sigs2 = [function.solidity_signature for function in v2.functions]
@ -306,8 +306,8 @@ def get_missing_vars(v1: Contract, v2: Contract) -> List[StateVariable]:
List of StateVariables from v1 missing in v2
"""
results = []
order_vars1 = v1.stored_state_variables_ordered
order_vars2 = v2.stored_state_variables_ordered
order_vars1 = v1.storage_variables_ordered + v1.transient_variables_ordered
order_vars2 = v2.storage_variables_ordered + v2.transient_variables_ordered
if len(order_vars2) < len(order_vars1):
for variable in order_vars1:
if variable.name not in [v.name for v in order_vars2]:

@ -470,7 +470,7 @@ class ContractVyper: # pylint: disable=too-many-instance-attributes
assert var.name
self._contract.variables_as_dict[var.name] = var
self._contract.add_variables_ordered([var])
self._contract.add_state_variables_ordered([var])
# Interfaces can refer to constants
self._contract.file_scope.variables[var.name] = var

Loading…
Cancel
Save