slither-read-storage native POA support (#1843)

* Native support for POA networks in read_storage

* Set `srs.rpc` before `srs.block`

* Type hint

* New RpcInfo class w/ RpcInfo.web3 and RpcInfo.block
In SlitherReadStorage, self.rpc_info: Optional[RpcInfo]
replaces self.rpc, self._block, self._web3

* Black

* Update test_read_storage.py

* Add import in __init__.py

* Avoid instantiating SRS twice

* Add comment about `get_block` for POA networks

* Pylint

* Black

* Allow other valid block string arguments
["latest", "earliest", "pending", "safe", "finalized"]

* `args.block` can be in ["latest", "earliest", "pending", "safe", "finalized"]

* Use BlockTag enum class for valid `str` arguments

* Tweak `RpcInfo.__init__()` signature

* get rid of `or "latest"`

* Import BlockTag

* Use `web3.types.BlockIdentifier`

* Revert BlockTag enum

* Pylint and black

* Replace missing newline

* Update slither/tools/read_storage/__main__.py

Better, cleaner python

Co-authored-by: alpharush <0xalpharush@protonmail.com>

* Drop try/except around args.block parsing
allow ValueError if user provides invalid block arg

* Remove unused import

---------

Co-authored-by: alpharush <0xalpharush@protonmail.com>
pull/1947/head
William E Bodell III 1 year ago committed by GitHub
parent 48e3466628
commit 00461aad9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      slither/tools/read_storage/__init__.py
  2. 29
      slither/tools/read_storage/__main__.py
  3. 57
      slither/tools/read_storage/read_storage.py
  4. 6
      tests/tools/read-storage/test_read_storage.py

@ -1 +1 @@
from .read_storage import SlitherReadStorage
from .read_storage import SlitherReadStorage, RpcInfo

@ -7,7 +7,7 @@ import argparse
from crytic_compile import cryticparser
from slither import Slither
from slither.tools.read_storage.read_storage import SlitherReadStorage
from slither.tools.read_storage.read_storage import SlitherReadStorage, RpcInfo
def parse_args() -> argparse.Namespace:
@ -126,22 +126,19 @@ def main() -> None:
else:
contracts = slither.contracts
srs = SlitherReadStorage(contracts, args.max_depth)
try:
srs.block = int(args.block)
except ValueError:
srs.block = str(args.block or "latest")
rpc_info = None
if args.rpc_url:
# Remove target prefix e.g. rinkeby:0x0 -> 0x0.
address = target[target.find(":") + 1 :]
# Default to implementation address unless a storage address is given.
if not args.storage_address:
args.storage_address = address
srs.storage_address = args.storage_address
srs.rpc = args.rpc_url
valid = ["latest", "earliest", "pending", "safe", "finalized"]
block = args.block if args.block in valid else int(args.block)
rpc_info = RpcInfo(args.rpc_url, block)
srs = SlitherReadStorage(contracts, args.max_depth, rpc_info)
# Remove target prefix e.g. rinkeby:0x0 -> 0x0.
address = target[target.find(":") + 1 :]
# Default to implementation address unless a storage address is given.
if not args.storage_address:
args.storage_address = address
srs.storage_address = args.storage_address
if args.variable_name:
# Use a lambda func to only return variables that have same name as target.

@ -6,8 +6,11 @@ import dataclasses
from eth_abi import decode, encode
from eth_typing.evm import ChecksumAddress
from eth_utils import keccak
from eth_utils import keccak, to_checksum_address
from web3 import Web3
from web3.types import BlockIdentifier
from web3.exceptions import ExtraDataLengthError
from web3.middleware import geth_poa_middleware
from slither.core.declarations import Contract, Structure
from slither.core.solidity_types import ArrayType, ElementaryType, MappingType, UserDefinedType
@ -42,18 +45,43 @@ class SlitherReadStorageException(Exception):
pass
class RpcInfo:
def __init__(self, rpc_url: str, block: BlockIdentifier = "latest") -> None:
assert isinstance(block, int) or block in [
"latest",
"earliest",
"pending",
"safe",
"finalized",
]
self.rpc: str = rpc_url
self._web3: Web3 = Web3(Web3.HTTPProvider(self.rpc))
"""If the RPC is for a POA network, the first call to get_block fails, so we inject geth_poa_middleware"""
try:
self._block: int = self.web3.eth.get_block(block)["number"]
except ExtraDataLengthError:
self._web3.middleware_onion.inject(geth_poa_middleware, layer=0)
self._block: int = self.web3.eth.get_block(block)["number"]
@property
def web3(self) -> Web3:
return self._web3
@property
def block(self) -> int:
return self._block
# pylint: disable=too-many-instance-attributes
class SlitherReadStorage:
def __init__(self, contracts: List[Contract], max_depth: int) -> None:
def __init__(self, contracts: List[Contract], max_depth: int, rpc_info: RpcInfo = None) -> None:
self._checksum_address: Optional[ChecksumAddress] = None
self._contracts: List[Contract] = contracts
self._log: str = ""
self._max_depth: int = max_depth
self._slot_info: Dict[str, SlotInfo] = {}
self._target_variables: List[Tuple[Contract, StateVariable]] = []
self._web3: Optional[Web3] = None
self.block: Union[str, int] = "latest"
self.rpc: Optional[str] = None
self.rpc_info: Optional[RpcInfo] = rpc_info
self.storage_address: Optional[str] = None
self.table: Optional[MyPrettyTable] = None
@ -73,18 +101,12 @@ class SlitherReadStorage:
def log(self, log: str) -> None:
self._log = log
@property
def web3(self) -> Web3:
if not self._web3:
self._web3 = Web3(Web3.HTTPProvider(self.rpc))
return self._web3
@property
def checksum_address(self) -> ChecksumAddress:
if not self.storage_address:
raise ValueError
if not self._checksum_address:
self._checksum_address = self.web3.to_checksum_address(self.storage_address)
self._checksum_address = to_checksum_address(self.storage_address)
return self._checksum_address
@property
@ -223,11 +245,12 @@ class SlitherReadStorage:
"""Fetches the slot value of `SlotInfo` object
:param slot_info:
"""
assert self.rpc_info is not None
hex_bytes = get_storage_data(
self.web3,
self.rpc_info.web3,
self.checksum_address,
int.to_bytes(slot_info.slot, 32, byteorder="big"),
self.block,
self.rpc_info.block,
)
slot_info.value = self.convert_value_to_type(
hex_bytes, slot_info.size, slot_info.offset, slot_info.type_string
@ -600,15 +623,15 @@ class SlitherReadStorage:
(int): The length of the array.
"""
val = 0
if self.rpc:
if self.rpc_info:
# The length of dynamic arrays is stored at the starting slot.
# Convert from hexadecimal to decimal.
val = int(
get_storage_data(
self.web3,
self.rpc_info.web3,
self.checksum_address,
int.to_bytes(slot, 32, byteorder="big"),
self.block,
self.rpc_info.block,
).hex(),
16,
)

@ -12,7 +12,7 @@ from web3 import Web3
from web3.contract import Contract
from slither import Slither
from slither.tools.read_storage import SlitherReadStorage
from slither.tools.read_storage import SlitherReadStorage, RpcInfo
TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data"
@ -105,8 +105,8 @@ def test_read_storage(web3, ganache, solc_binary_path) -> None:
sl = Slither(Path(TEST_DATA_DIR, "storage_layout-0.8.10.sol").as_posix(), solc=solc_path)
contracts = sl.contracts
srs = SlitherReadStorage(contracts, 100)
srs.rpc = ganache.provider
rpc_info: RpcInfo = RpcInfo(ganache.provider)
srs = SlitherReadStorage(contracts, 100, rpc_info)
srs.storage_address = address
srs.get_all_storage_variables()
srs.get_storage_layout()

Loading…
Cancel
Save