Merge pull request #2588 from crytic/dev-support-transient

Improve transient storage support
pull/2238/merge
Josselin Feist 3 weeks ago committed by GitHub
commit 9e89bbbe69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 60
      slither/core/compilation_unit.py
  2. 46
      slither/core/declarations/contract.py
  3. 16
      slither/core/variables/state_variable.py
  4. 7
      slither/core/variables/variable.py
  5. 2
      slither/detectors/variables/unchanged_state_variables.py
  6. 13
      slither/printers/summary/variable_order.py
  7. 4
      slither/solc_parsing/declarations/contract.py
  8. 2
      slither/tools/upgradeability/checks/variable_initialization.py
  9. 93
      slither/tools/upgradeability/checks/variables_order.py
  10. 8
      slither/utils/upgradeability.py
  11. 2
      slither/vyper_parsing/declarations/contract.py

@ -73,7 +73,8 @@ class SlitherCompilationUnit(Context):
# Memoize # Memoize
self._all_state_variables: Optional[Set[StateVariable]] = None self._all_state_variables: Optional[Set[StateVariable]] = None
self._storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {} self._persistent_storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {}
self._transient_storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {}
self._contract_with_missing_inheritance: Set[Contract] = set() self._contract_with_missing_inheritance: Set[Contract] = set()
@ -297,33 +298,52 @@ class SlitherCompilationUnit(Context):
def compute_storage_layout(self) -> None: def compute_storage_layout(self) -> None:
assert self.is_solidity assert self.is_solidity
for contract in self.contracts_derived: for contract in self.contracts_derived:
self._storage_layouts[contract.name] = {} self._compute_storage_layout(contract.name, contract.storage_variables_ordered, False)
self._compute_storage_layout(contract.name, contract.transient_variables_ordered, True)
slot = 0
offset = 0 def _compute_storage_layout(
for var in contract.stored_state_variables_ordered: self, contract_name: str, state_variables_ordered: List[StateVariable], is_transient: bool
assert var.type ):
size, new_slot = var.type.storage_size if is_transient:
self._transient_storage_layouts[contract_name] = {}
if new_slot: else:
if offset > 0: self._persistent_storage_layouts[contract_name] = {}
slot += 1
offset = 0 slot = 0
elif size + offset > 32: offset = 0
for var in state_variables_ordered:
assert var.type
size, new_slot = var.type.storage_size
if new_slot:
if offset > 0:
slot += 1 slot += 1
offset = 0 offset = 0
elif size + offset > 32:
slot += 1
offset = 0
self._storage_layouts[contract.name][var.canonical_name] = ( if is_transient:
self._transient_storage_layouts[contract_name][var.canonical_name] = (
slot, slot,
offset, offset,
) )
if new_slot: else:
slot += math.ceil(size / 32) self._persistent_storage_layouts[contract_name][var.canonical_name] = (
else: slot,
offset += size offset,
)
if new_slot:
slot += math.ceil(size / 32)
else:
offset += size
def storage_layout_of(self, contract: Contract, var: StateVariable) -> Tuple[int, int]: def storage_layout_of(self, contract: Contract, var: StateVariable) -> Tuple[int, int]:
return self._storage_layouts[contract.name][var.canonical_name] if var.is_stored:
return self._persistent_storage_layouts[contract.name][var.canonical_name]
return self._transient_storage_layouts[contract.name][var.canonical_name]
# endregion # endregion

@ -440,55 +440,43 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
def state_variables(self) -> List["StateVariable"]: def state_variables(self) -> List["StateVariable"]:
""" """
Returns all the accessible variables (do not include private variable from inherited contract). Returns all the accessible variables (do not include private variable from inherited contract).
Use state_variables_ordered for all the variables following the storage order Use stored_state_variables_ordered for all the storage variables following the storage order
Use transient_state_variables_ordered for all the transient variables following the storage order
list(StateVariable): List of the state variables. list(StateVariable): List of the state variables.
""" """
return list(self._variables.values()) return list(self._variables.values())
@property @property
def stored_state_variables(self) -> List["StateVariable"]: def state_variables_entry_points(self) -> List["StateVariable"]:
""" """
Returns state variables with storage locations, excluding private variables from inherited contracts. list(StateVariable): List of the state variables that are public.
Use stored_state_variables_ordered to access variables with storage locations in their declaration order.
This implementation filters out state variables if they are constant or immutable. It will be
updated to accommodate any new non-storage keywords that might replace 'constant' and 'immutable' in the future.
Returns:
List[StateVariable]: A list of state variables with storage locations.
""" """
return [variable for variable in self.state_variables if variable.is_stored] return [var for var in self._variables.values() if var.visibility == "public"]
@property @property
def stored_state_variables_ordered(self) -> List["StateVariable"]: def state_variables_ordered(self) -> List["StateVariable"]:
""" """
list(StateVariable): List of the state variables with storage locations by order of declaration. list(StateVariable): List of the state variables by order of declaration.
This implementation filters out state variables if they are constant or immutable. It will be
updated to accommodate any new non-storage keywords that might replace 'constant' and 'immutable' in the future.
Returns:
List[StateVariable]: A list of state variables with storage locations ordered by declaration.
""" """
return [variable for variable in self.state_variables_ordered if variable.is_stored] return self._variables_ordered
def add_state_variables_ordered(self, new_vars: List["StateVariable"]) -> None:
self._variables_ordered += new_vars
@property @property
def state_variables_entry_points(self) -> List["StateVariable"]: def storage_variables_ordered(self) -> List["StateVariable"]:
""" """
list(StateVariable): List of the state variables that are public. list(StateVariable): List of the state variables in storage location by order of declaration.
""" """
return [var for var in self._variables.values() if var.visibility == "public"] return [v for v in self._variables_ordered if v.is_stored]
@property @property
def state_variables_ordered(self) -> List["StateVariable"]: def transient_variables_ordered(self) -> List["StateVariable"]:
""" """
list(StateVariable): List of the state variables by order of declaration. list(StateVariable): List of the state variables in transient location by order of declaration.
""" """
return list(self._variables_ordered) return [v for v in self._variables_ordered if v.is_transient]
def add_variables_ordered(self, new_vars: List["StateVariable"]) -> None:
self._variables_ordered += new_vars
@property @property
def state_variables_inherited(self) -> List["StateVariable"]: def state_variables_inherited(self) -> List["StateVariable"]:

@ -35,6 +35,22 @@ class StateVariable(ContractLevel, Variable):
""" """
return self._location return self._location
@property
def is_stored(self) -> bool:
"""
Checks if the state variable is stored, based on it not being constant or immutable or transient.
"""
return (
not self._is_constant and not self._is_immutable and not self._location == "transient"
)
@property
def is_transient(self) -> bool:
"""
Checks if the state variable is transient. A transient variable can not be constant or immutable.
"""
return self._location == "transient"
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################

@ -93,13 +93,6 @@ class Variable(SourceMapping):
def is_constant(self, is_cst: bool) -> None: def is_constant(self, is_cst: bool) -> None:
self._is_constant = is_cst self._is_constant = is_cst
@property
def is_stored(self) -> bool:
"""
Checks if a variable is stored, based on it not being constant or immutable. Future updates may adjust for new non-storage keywords.
"""
return not self._is_constant and not self._is_immutable
@property @property
def is_reentrant(self) -> bool: def is_reentrant(self) -> bool:
return self._is_reentrant return self._is_reentrant

@ -92,7 +92,7 @@ class UnchangedStateVariables:
variables = [] variables = []
functions = [] functions = []
variables.append(c.stored_state_variables) variables.append(c.storage_variables_ordered)
functions.append(c.all_functions_called) functions.append(c.all_functions_called)
valid_candidates: Set[StateVariable] = { valid_candidates: Set[StateVariable] = {

@ -27,10 +27,17 @@ class VariableOrder(AbstractPrinter):
for contract in self.slither.contracts_derived: for contract in self.slither.contracts_derived:
txt += f"\n{contract.name}:\n" txt += f"\n{contract.name}:\n"
table = MyPrettyTable(["Name", "Type", "Slot", "Offset"]) table = MyPrettyTable(["Name", "Type", "Slot", "Offset", "State"])
for variable in contract.stored_state_variables_ordered: for variable in contract.storage_variables_ordered:
slot, offset = contract.compilation_unit.storage_layout_of(contract, variable) slot, offset = contract.compilation_unit.storage_layout_of(contract, variable)
table.add_row([variable.canonical_name, str(variable.type), slot, offset]) table.add_row(
[variable.canonical_name, str(variable.type), slot, offset, "Storage"]
)
for variable in contract.transient_variables_ordered:
slot, offset = contract.compilation_unit.storage_layout_of(contract, variable)
table.add_row(
[variable.canonical_name, str(variable.type), slot, offset, "Transient"]
)
all_tables.append((contract.name, table)) all_tables.append((contract.name, table))
txt += str(table) + "\n" txt += str(table) + "\n"

@ -350,7 +350,7 @@ class ContractSolc(CallerContextExpression):
if v.visibility != "private" if v.visibility != "private"
} }
) )
self._contract.add_variables_ordered( self._contract.add_state_variables_ordered(
[ [
var var
for var in father.state_variables_ordered for var in father.state_variables_ordered
@ -370,7 +370,7 @@ class ContractSolc(CallerContextExpression):
if var_parser.reference_id is not None: if var_parser.reference_id is not None:
self._contract.state_variables_by_ref_id[var_parser.reference_id] = var self._contract.state_variables_by_ref_id[var_parser.reference_id] = var
self._contract.variables_as_dict[var.name] = var self._contract.variables_as_dict[var.name] = var
self._contract.add_variables_ordered([var]) self._contract.add_state_variables_ordered([var])
def _parse_modifier(self, modifier_data: Dict) -> None: def _parse_modifier(self, modifier_data: Dict) -> None:
modif = Modifier(self._contract.compilation_unit) modif = Modifier(self._contract.compilation_unit)

@ -43,7 +43,7 @@ Using initialize functions to write initial values in state variables.
def _check(self) -> List[Output]: def _check(self) -> List[Output]:
results = [] results = []
for s in self.contract.stored_state_variables_ordered: for s in self.contract.storage_variables_ordered:
if s.initialized: if s.initialized:
info: CHECK_INFO = [s, " is a state variable with an initial value.\n"] info: CHECK_INFO = [s, " is a state variable with an initial value.\n"]
json = self.generate_result(info) json = self.generate_result(info)

@ -115,29 +115,43 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
def _check(self) -> List[Output]: def _check(self) -> List[Output]:
contract1 = self._contract1() contract1 = self._contract1()
contract2 = self._contract2() contract2 = self._contract2()
order1 = contract1.stored_state_variables_ordered
order2 = contract2.stored_state_variables_ordered
results: List[Output] = [] results: List[Output] = []
for idx, _ in enumerate(order1):
if len(order2) <= idx: def _check_internal(
# Handle by MissingVariable contract1: Contract, contract2: Contract, results: List[Output], is_transient: bool
return results ):
if is_transient:
variable1 = order1[idx] order1 = contract1.transient_variables_ordered
variable2 = order2[idx] order2 = contract2.transient_variables_ordered
if (variable1.name != variable2.name) or (variable1.type != variable2.type): else:
info: CHECK_INFO = [ order1 = contract1.storage_variables_ordered
"Different variables between ", order2 = contract2.storage_variables_ordered
contract1,
" and ", for idx, _ in enumerate(order1):
contract2, if len(order2) <= idx:
"\n", # Handle by MissingVariable
] return
info += ["\t ", variable1, "\n"]
info += ["\t ", variable2, "\n"] variable1 = order1[idx]
json = self.generate_result(info) variable2 = order2[idx]
results.append(json) if (variable1.name != variable2.name) or (variable1.type != variable2.type):
info: CHECK_INFO = [
"Different variables between ",
contract1,
" and ",
contract2,
"\n",
]
info += ["\t ", variable1, "\n"]
info += ["\t ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)
# Checking state variables with storage location
_check_internal(contract1, contract2, results, False)
# Checking state variables with transient location
_check_internal(contract1, contract2, results, True)
return results return results
@ -236,22 +250,35 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s
def _check(self) -> List[Output]: def _check(self) -> List[Output]:
contract1 = self._contract1() contract1 = self._contract1()
contract2 = self._contract2() contract2 = self._contract2()
order1 = contract1.stored_state_variables_ordered
order2 = contract2.stored_state_variables_ordered
results = [] results: List[Output] = []
if len(order2) <= len(order1): def _check_internal(
return [] contract1: Contract, contract2: Contract, results: List[Output], is_transient: bool
):
if is_transient:
order1 = contract1.transient_variables_ordered
order2 = contract2.transient_variables_ordered
else:
order1 = contract1.storage_variables_ordered
order2 = contract2.storage_variables_ordered
idx = len(order1) if len(order2) <= len(order1):
return
while idx < len(order2): idx = len(order1)
variable2 = order2[idx]
info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"] while idx < len(order2):
json = self.generate_result(info) variable2 = order2[idx]
results.append(json) info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"]
idx = idx + 1 json = self.generate_result(info)
results.append(json)
idx = idx + 1
# Checking state variables with storage location
_check_internal(contract1, contract2, results, False)
# Checking state variables with transient location
_check_internal(contract1, contract2, results, True)
return results return results

@ -80,8 +80,8 @@ def compare(
tainted-contracts: list[TaintedExternalContract] tainted-contracts: list[TaintedExternalContract]
""" """
order_vars1 = v1.stored_state_variables_ordered order_vars1 = v1.storage_variables_ordered + v1.transient_variables_ordered
order_vars2 = v2.stored_state_variables_ordered order_vars2 = v2.storage_variables_ordered + v2.transient_variables_ordered
func_sigs1 = [function.solidity_signature for function in v1.functions] func_sigs1 = [function.solidity_signature for function in v1.functions]
func_sigs2 = [function.solidity_signature for function in v2.functions] func_sigs2 = [function.solidity_signature for function in v2.functions]
@ -306,8 +306,8 @@ def get_missing_vars(v1: Contract, v2: Contract) -> List[StateVariable]:
List of StateVariables from v1 missing in v2 List of StateVariables from v1 missing in v2
""" """
results = [] results = []
order_vars1 = v1.stored_state_variables_ordered order_vars1 = v1.storage_variables_ordered + v1.transient_variables_ordered
order_vars2 = v2.stored_state_variables_ordered order_vars2 = v2.storage_variables_ordered + v2.transient_variables_ordered
if len(order_vars2) < len(order_vars1): if len(order_vars2) < len(order_vars1):
for variable in order_vars1: for variable in order_vars1:
if variable.name not in [v.name for v in order_vars2]: if variable.name not in [v.name for v in order_vars2]:

@ -470,7 +470,7 @@ class ContractVyper: # pylint: disable=too-many-instance-attributes
assert var.name assert var.name
self._contract.variables_as_dict[var.name] = var self._contract.variables_as_dict[var.name] = var
self._contract.add_variables_ordered([var]) self._contract.add_state_variables_ordered([var])
# Interfaces can refer to constants # Interfaces can refer to constants
self._contract.file_scope.variables[var.name] = var self._contract.file_scope.variables[var.name] = var

Loading…
Cancel
Save