Additional improvements - CI still failing

pull/1311/head
Josselin Feist 2 years ago
parent 46cb8a5cd2
commit 485e4a9674
  1. 141
      slither/tools/read_storage/read_storage.py
  2. 4
      tests/test_read_storage.py

@ -2,7 +2,7 @@ import logging
import sys import sys
from math import floor from math import floor
from os import environ from os import environ
from typing import Callable, Optional, Tuple, Union, List, Dict from typing import Callable, Optional, Tuple, Union, List, Dict, Any
try: try:
from web3 import Web3 from web3 import Web3
@ -31,7 +31,7 @@ from dataclasses import dataclass, field
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
from slither.core.solidity_types import ArrayType, ElementaryType, UserDefinedType, MappingType from slither.core.solidity_types import ArrayType, ElementaryType, UserDefinedType, MappingType
from slither.core.declarations import Contract, Structure from slither.core.declarations import Contract, Structure, StructureContract
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.core.variables.structure_variable import StructureVariable from slither.core.variables.structure_variable import StructureVariable
@ -47,7 +47,8 @@ class SlotInfo:
size: int size: int
offset: int offset: int
value: Optional[Union[int, bool, str, ChecksumAddress]] = None value: Optional[Union[int, bool, str, ChecksumAddress]] = None
elems: Dict[int, "SlotInfo"] = field(default_factory=lambda: {}) # For structure, str->SlotInfo, for array, str-> SlotInfo
elems: Dict[str, "SlotInfo"] = field(default_factory=lambda: {})
struct_var: Optional[str] = None struct_var: Optional[str] = None
@ -116,9 +117,8 @@ class SlitherReadStorage:
tmp[var.name] = info tmp[var.name] = info
if isinstance(type_, UserDefinedType) and isinstance(type_.type, Structure): if isinstance(type_, UserDefinedType) and isinstance(type_.type, Structure):
tmp[var.name].elems = self._all_struct_slots(var, type_.type, contract) tmp[var.name].elems = self._all_struct_slots(var, type_.type, contract)
continue
if isinstance(type_, ArrayType): elif isinstance(type_, ArrayType):
elems = self._all_array_slots(var, contract, type_, info.slot) elems = self._all_array_slots(var, contract, type_, info.slot)
tmp[var.name].elems = elems tmp[var.name].elems = elems
@ -127,27 +127,27 @@ class SlitherReadStorage:
# TODO: remove this pylint exception (montyly) # TODO: remove this pylint exception (montyly)
# pylint: disable=too-many-locals # pylint: disable=too-many-locals
def get_storage_slot( def get_storage_slot(
self, self,
target_variable: StateVariable, target_variable: StateVariable,
contract: Contract, contract: Contract,
**kwargs, **kwargs: Any,
) -> Union[SlotInfo, None]: ) -> Union[SlotInfo, None]:
"""Finds the storage slot of a variable in a given contract. """Finds the storage slot of a variable in a given contract.
Args: Args:
target_variable (`StateVariable`): The variable to retrieve the slot for. target_variable (`StateVariable`): The variable to retrieve the slot for.
contracts (`Contract`): The contract that contains the given state variable. contracts (`Contract`): The contract that contains the given state variable.
**kwargs: **kwargs:
key (str): Key of a mapping or index position if an array. key (int): Key of a mapping or index position if an array.
deep_key (str): Key of a mapping embedded within another mapping or deep_key (int): Key of a mapping embedded within another mapping or
secondary index if array. secondary index if array.
struct_var (str): Structure variable name. struct_var (str): Structure variable name.
Returns: Returns:
(`SlotInfo`) | None : A dictionary of the slot information. (`SlotInfo`) | None : A dictionary of the slot information.
""" """
key = kwargs.get("key", None) key: Optional[int] = kwargs.get("key", None)
deep_key = kwargs.get("deep_key", None) deep_key: Optional[int]= kwargs.get("deep_key", None)
struct_var = kwargs.get("struct_var", None) struct_var: Optional[str] = kwargs.get("struct_var", None)
info: str info: str
var_log_name = target_variable.name var_log_name = target_variable.name
try: try:
@ -178,7 +178,9 @@ class SlitherReadStorage:
elif isinstance(target_variable_type, UserDefinedType) and struct_var: elif isinstance(target_variable_type, UserDefinedType) and struct_var:
var_log_name = struct_var var_log_name = struct_var
elems = target_variable_type.type.elems_ordered target_variable_type_type = target_variable_type.type
assert isinstance(target_variable_type_type, Structure)
elems = target_variable_type_type.elems_ordered
info, type_to, slot, size, offset = self._find_struct_var_slot(elems, slot, struct_var) info, type_to, slot, size, offset = self._find_struct_var_slot(elems, slot, struct_var)
self.log += info self.log += info
@ -226,7 +228,11 @@ class SlitherReadStorage:
""" """
Fetches the slot values Fetches the slot values
""" """
for slot_info in self.slot_info.values(): stack = list(self.slot_info.values())
while stack:
slot_info = stack.pop()
if slot_info.elems:
stack.extend(slot_info.elems.values())
hex_bytes = get_storage_data( hex_bytes = get_storage_data(
self.web3, self.checksum_address, int.to_bytes(slot_info.slot, 32, byteorder="big") self.web3, self.checksum_address, int.to_bytes(slot_info.slot, 32, byteorder="big")
) )
@ -269,12 +275,12 @@ class SlitherReadStorage:
if isinstance(type_, UserDefinedType) and isinstance(type_.type, Structure): if isinstance(type_, UserDefinedType) and isinstance(type_.type, Structure):
tabulate_data.pop() tabulate_data.pop()
for item in info.elems: for item in info.elems.values():
slot = info.elems[item].slot slot = item.slot
offset = info.elems[item].offset offset = item.offset
size = info.elems[item].size size = item.size
type_string = info.elems[item].type_string type_string = item.type_string
struct_var = info.elems[item].struct_var struct_var = item.struct_var
# doesn't handle deep keys currently # doesn't handle deep keys currently
var_name_struct_or_array_var = f"{var} -> {struct_var}" var_name_struct_or_array_var = f"{var} -> {struct_var}"
@ -285,20 +291,19 @@ class SlitherReadStorage:
if isinstance(type_, ArrayType): if isinstance(type_, ArrayType):
tabulate_data.pop() tabulate_data.pop()
for item in info.elems: for item_key, item in info.elems.items():
for elem in info.elems.values(): slot = item.slot
slot = elem.slot offset = item.offset
offset = elem.offset size = item.size
size = elem.size type_string = item.type_string
type_string = elem.type_string struct_var = item.struct_var
struct_var = elem.struct_var
# doesn't handle deep keys currently # doesn't handle deep keys currently
var_name_struct_or_array_var = f"{var}[{item}] -> {struct_var}" var_name_struct_or_array_var = f"{var}[{item_key}] -> {struct_var}"
tabulate_data.append( tabulate_data.append(
[slot, offset, size, type_string, var_name_struct_or_array_var] [slot, offset, size, type_string, var_name_struct_or_array_var]
) )
print( print(
tabulate( tabulate(
@ -403,12 +408,12 @@ class SlitherReadStorage:
@staticmethod @staticmethod
def _find_struct_var_slot( def _find_struct_var_slot(
elems: List[StructureVariable], slot: bytes, struct_var: str elems: List[StructureVariable], slot_as_bytes: bytes, struct_var: str
) -> Tuple[str, str, bytes, int, int]: ) -> Tuple[str, str, bytes, int, int]:
"""Finds the slot of a structure variable. """Finds the slot of a structure variable.
Args: Args:
elems (List[StructureVariable]): Ordered list of structure variables. elems (List[StructureVariable]): Ordered list of structure variables.
slot (bytes): The slot of the struct to begin searching at. slot_as_bytes (bytes): The slot of the struct to begin searching at.
struct_var (str): The target structure variable. struct_var (str): The target structure variable.
Returns: Returns:
info (str): Info about the target variable to log. info (str): Info about the target variable to log.
@ -417,15 +422,15 @@ class SlitherReadStorage:
size (int): The size (in bits) of the target variable. size (int): The size (in bits) of the target variable.
offset (int): The size of other variables that share the same slot. offset (int): The size of other variables that share the same slot.
""" """
slot = int.from_bytes(slot, "big") slot = int.from_bytes(slot_as_bytes, "big")
offset = 0 offset = 0
type_to = "" type_to = ""
size = 0 # TODO: find out what size should return here? (montyly) size = 0
for var in elems: for var in elems:
var_type = var.type var_type = var.type
if isinstance(var_type, ElementaryType): if isinstance(var_type, ElementaryType):
size_ = var_type.size size = var_type.size
if size_: if size:
if offset >= 256: if offset >= 256:
slot += 1 slot += 1
offset = 0 offset = 0
@ -434,20 +439,20 @@ class SlitherReadStorage:
break # found struct var break # found struct var
offset += size offset += size
else: else:
print(f"{type(var_type)} is current not implemented in structure") logger.info(f"{type(var_type)} is current not implemented in _find_struct_var_slot")
slot = int.to_bytes(slot, 32, byteorder="big") slot_as_bytes = int.to_bytes(slot, 32, byteorder="big")
info = f"\nStruct Variable: {struct_var}" info = f"\nStruct Variable: {struct_var}"
return info, type_to, slot, size, offset return info, type_to, slot_as_bytes, size, offset
# pylint: disable=too-many-branches # pylint: disable=too-many-branches
@staticmethod @staticmethod
def _find_array_slot( def _find_array_slot(
target_variable_type: ArrayType, target_variable_type: ArrayType,
slot: bytes, slot: bytes,
key: int, key: int,
deep_key: int = None, deep_key: int = None,
struct_var: str = None, struct_var: str = None,
) -> Tuple[str, str, bytes, int, int]: ) -> Tuple[str, str, bytes, int, int]:
"""Finds the slot of array's index. """Finds the slot of array's index.
Args: Args:
@ -470,7 +475,7 @@ class SlitherReadStorage:
target_variable_type_type = target_variable_type.type target_variable_type_type = target_variable_type.type
if isinstance( if isinstance(
target_variable_type_type, ArrayType target_variable_type_type, ArrayType
): # multidimensional array uint[i][], , uint[][i], or uint[][] ): # multidimensional array uint[i][], , uint[][i], or uint[][]
size = target_variable_type_type.type.size size = target_variable_type_type.type.size
type_to = target_variable_type_type.type.name type_to = target_variable_type_type.type.name
@ -538,11 +543,11 @@ class SlitherReadStorage:
@staticmethod @staticmethod
def _find_mapping_slot( def _find_mapping_slot(
target_variable: StateVariable, target_variable: StateVariable,
slot: bytes, slot: bytes,
key: Union[int, str], key: Union[int, str],
deep_key: Union[int, str] = None, deep_key: Union[int, str] = None,
struct_var: str = None, struct_var: str = None,
) -> Tuple[str, str, bytes, int, int]: ) -> Tuple[str, str, bytes, int, int]:
"""Finds the data slot of a target variable within a mapping. """Finds the data slot of a target variable within a mapping.
target_variable (`StateVariable`): The mapping that contains the target variable. target_variable (`StateVariable`): The mapping that contains the target variable.
@ -573,7 +578,7 @@ class SlitherReadStorage:
slot = keccak(encode_abi([key_type, "uint256"], [key, decode_single("uint256", slot)])) slot = keccak(encode_abi([key_type, "uint256"], [key, decode_single("uint256", slot)]))
if isinstance(target_variable.type.type_to, UserDefinedType) and isinstance( if isinstance(target_variable.type.type_to, UserDefinedType) and isinstance(
target_variable.type.type_to.type, Structure target_variable.type.type_to.type, Structure
): # mapping(elem => struct) ): # mapping(elem => struct)
assert struct_var assert struct_var
elems = target_variable.type.type_to.type.elems_ordered elems = target_variable.type.type_to.type.elems_ordered
@ -583,7 +588,7 @@ class SlitherReadStorage:
info += info_tmp info += info_tmp
elif isinstance( elif isinstance(
target_variable.type.type_to, MappingType target_variable.type.type_to, MappingType
): # mapping(elem => mapping(elem => ???)) ): # mapping(elem => mapping(elem => ???))
assert deep_key assert deep_key
key_type = target_variable.type.type_to.type_from.name key_type = target_variable.type.type_to.type_from.name
@ -600,7 +605,7 @@ class SlitherReadStorage:
offset = 0 offset = 0
if isinstance(target_variable.type.type_to.type_to, UserDefinedType) and isinstance( if isinstance(target_variable.type.type_to.type_to, UserDefinedType) and isinstance(
target_variable.type.type_to.type_to.type, Structure target_variable.type.type_to.type_to.type, Structure
): # mapping(elem => mapping(elem => struct)) ): # mapping(elem => mapping(elem => struct))
assert struct_var assert struct_var
elems = target_variable.type.type_to.type_to.type.elems_ordered elems = target_variable.type.type_to.type_to.type.elems_ordered
@ -621,7 +626,7 @@ class SlitherReadStorage:
@staticmethod @staticmethod
def get_variable_info( def get_variable_info(
contract: Contract, target_variable: StateVariable contract: Contract, target_variable: StateVariable
) -> Tuple[int, int, int, str]: ) -> Tuple[int, int, int, str]:
"""Return slot, size, offset, and type.""" """Return slot, size, offset, and type."""
type_to = str(target_variable.type) type_to = str(target_variable.type)
@ -638,7 +643,7 @@ class SlitherReadStorage:
@staticmethod @staticmethod
def convert_value_to_type( def convert_value_to_type(
hex_bytes: HexBytes, size: int, offset: int, type_to: str hex_bytes: HexBytes, size: int, offset: int, type_to: str
) -> Union[int, bool, str, ChecksumAddress]: ) -> Union[int, bool, str, ChecksumAddress]:
"""Convert slot data to type representation.""" """Convert slot data to type representation."""
# Account for storage packing # Account for storage packing
@ -651,7 +656,7 @@ class SlitherReadStorage:
return value return value
def _all_struct_slots( def _all_struct_slots(
self, var: StateVariable, st: Structure, contract: Contract, key=None self, var: StateVariable, st: Structure, contract: Contract, key: Optional[int] = None
) -> Dict[str, SlotInfo]: ) -> Dict[str, SlotInfo]:
"""Retrieves all members of a struct.""" """Retrieves all members of a struct."""
struct_elems = st.elems_ordered struct_elems = st.elems_ordered
@ -669,19 +674,18 @@ class SlitherReadStorage:
return data return data
def _all_array_slots( def _all_array_slots(
self, var: StateVariable, contract: Contract, type_: Type, slot: int self, var: StateVariable, contract: Contract, type_: Type, slot: int
) -> Dict[int, SlotInfo]: ) -> Dict[str, SlotInfo]:
"""Retrieves all members of an array.""" """Retrieves all members of an array."""
array_length = self._get_array_length(type_, slot) array_length = self._get_array_length(type_, slot)
elems: Dict[int, SlotInfo] = {} elems: Dict[str, SlotInfo] = {}
if isinstance(type_, UserDefinedType): if isinstance(type_, UserDefinedType):
st = type_.type st = type_.type
if isinstance(st, Structure): if isinstance(st, Structure):
for i in range(min(array_length, self.max_depth)): for i in range(min(array_length, self.max_depth)):
# TODO: figure out why _all_struct_slots returns a Dict[str, SlotInfo] # TODO: figure out why _all_struct_slots returns a Dict[str, SlotInfo]
# but this expect a SlotInfo (montyly) # but this expect a SlotInfo (montyly)
elems[i] = self._all_struct_slots(var, st, contract, key=str(i)) elems[str(i)] = self._all_struct_slots(var, st, contract, key=i)
continue
else: else:
for i in range(min(array_length, self.max_depth)): for i in range(min(array_length, self.max_depth)):
@ -691,12 +695,11 @@ class SlitherReadStorage:
key=str(i), key=str(i),
) )
if info: if info:
elems[i] = info elems[str(i)] = info
if isinstance(type_.type, ArrayType): # multidimensional array if isinstance(type_.type, ArrayType): # multidimensional array
array_length = self._get_array_length(type_.type, info.slot) array_length = self._get_array_length(type_.type, info.slot)
elems[i].elems = {}
for j in range(min(array_length, self.max_depth)): for j in range(min(array_length, self.max_depth)):
info = self.get_storage_slot( info = self.get_storage_slot(
var, var,
@ -704,8 +707,8 @@ class SlitherReadStorage:
key=str(i), key=str(i),
deep_key=str(j), deep_key=str(j),
) )
if info:
elems[i].elems[j] = info elems[str(i)].elems[str(j)] = info
return elems return elems
def _get_array_length(self, type_: Type, slot: int) -> int: def _get_array_length(self, type_: Type, slot: int) -> int:

@ -135,8 +135,8 @@ def test_read_storage(web3, ganache) -> None:
path_list = re.findall(r"\['(.*?)'\]", change.path()) path_list = re.findall(r"\['(.*?)'\]", change.path())
path = "_".join(path_list) path = "_".join(path_list)
with open(f"{path}_expected.txt", "w", encoding="utf8") as f: with open(f"{path}_expected.txt", "w", encoding="utf8") as f:
f.write(change.t1) f.write(str(change.t1))
with open(f"{path}_actual.txt", "w", encoding="utf8") as f: with open(f"{path}_actual.txt", "w", encoding="utf8") as f:
f.write(change.t2) f.write(str(change.t2))
assert not diff assert not diff

Loading…
Cancel
Save