diff --git a/pyhmy/rpc/account.py b/pyhmy/rpc/account.py index 35c8636..bfe3295 100644 --- a/pyhmy/rpc/account.py +++ b/pyhmy/rpc/account.py @@ -6,6 +6,14 @@ from .blockchain import ( get_sharding_structure ) +from .exceptions import ( + RPCError, + InvalidRPCReplyError, + JSONDecodeError, + RequestsError, + RequestsTimeoutError +) + _default_endpoint = 'http://localhost:9500' _default_timeout = 30 @@ -27,24 +35,34 @@ def get_balance(address, endpoint=_default_endpoint, timeout=_default_timeout) - ------- int Account balance in ATTO + + Raises + ------ + InvalidRPCReplyError + If received unknown result from endpoint """ + method = 'hmy_getBalance' params = [ address, 'latest' ] - return int(rpc_request('hmy_getBalance', params=params, endpoint=endpoint, timeout=timeout)['result'], 16) + balance = rpc_request(method, params=params, endpoint=endpoint, timeout=timeout)['result'] + try: + return int(balance, 16) + except TypeError as e: + raise InvalidRPCReplyError(method, endpoint) from e def get_balance_by_block(address, block_num, endpoint=_default_endpoint, timeout=_default_timeout) -> int: """ - Get account balance at time of given block + Get account balance for address at a given block number Parameters ---------- address: str Address to get balance for block_num: int - Block number to req + Block to get balance at endpoint: :obj:`str`, optional Endpoint to send request to timeout: :obj:`int`, optional @@ -53,23 +71,36 @@ def get_balance_by_block(address, block_num, endpoint=_default_endpoint, timeout Returns ------- int - Account balance in ATTO at given block + Account balance in ATTO + + Raises + ------ + InvalidRPCReplyError + If received unknown result from endpoint """ + method = 'hmy_getBalanceByBlockNumber' params = [ address, str(hex(block_num)) ] - return int(rpc_request('hmy_getBalanceByBlockNumber', params=params, endpoint=endpoint, timeout=timeout)['result'], 16) + balance = rpc_request(method, params=params, endpoint=endpoint, timeout=timeout)['result'] + try: + return int(balance, 16) + except TypeError as e: + raise InvalidRPCReplyError(method, endpoint) from e -def get_transaction_count(address, endpoint=_default_endpoint, timeout=_default_timeout) -> int: +def get_account_nonce(address, true_nonce=False, endpoint=_default_endpoint, timeout=_default_timeout) -> int: """ - Get number of transactions & staking transactions sent by an account + Get the account nonce Parameters ---------- address: str Address to get transaction count for + true_nonce: :obj:`bool`, optional + True to get on-chain nonce + False to get nonce based on pending transaction pool endpoint: :obj:`str`, optional Endpoint to send request to timeout: :obj:`int`, optional @@ -78,13 +109,48 @@ def get_transaction_count(address, endpoint=_default_endpoint, timeout=_default_ Returns ------- int - Number of transactions sent by the account (account nonce) + Account nonce + + Raises + ------ + InvalidRPCReplyError + If received unknown result from endpoint """ + method = 'hmy_getTransactionCount' params = [ address, - 'latest' + 'latest' if true_nonce else 'pending' ] - return int(rpc_request('hmy_getTransactionCount', params=params, endpoint=endpoint, timeout=timeout)['result'], 16) + nonce = rpc_request(method, params=params, endpoint=endpoint, timeout=timeout)['result'] + try: + return int(nonce, 16) + except TypeError as e: + raise InvalidRPCReplyError(method, endpoint) from e + + +def get_transaction_count(address, endpoint=_default_endpoint, timeout=_default_timeout) -> int: + """ + Get number of transactions & staking transactions sent by an account + + Parameters + ---------- + address: str + Address to get transaction count for + endpoint: :obj:`str`, optional + Endpoint to send request to + timeout: :obj:`int`, optional + Timeout in seconds + + Returns + ------- + int + Number of transactions sent by the account + + See also + -------- + get_account_nonce + """ + return get_account_nonce(address, true_nonce=True, endpoint=endpoint, timeout=timeout) def get_transaction_history(address, page=0, page_size=1000, include_full_tx=False, tx_type='ALL', @@ -120,6 +186,11 @@ def get_transaction_history(address, page=0, page_size=1000, include_full_tx=Fal ------- list # TODO: Add link to reference RPC documentation + + Raises + ------ + InvalidRPCReplyError + If received unknown result from endpoint """ params = [ { @@ -131,8 +202,12 @@ def get_transaction_history(address, page=0, page_size=1000, include_full_tx=Fal 'order': order } ] - tx_history = rpc_request('hmy_getTransactionsHistory', params=params, endpoint=endpoint, timeout=timeout) - return tx_history['result']['transactions'] + method = 'hmy_getTransactionsHistory' + tx_history = rpc_request(method, params=params, endpoint=endpoint, timeout=timeout) + try: + return tx_history['result']['transactions'] + except KeyError as e: + raise InvalidRPCReplyError(method, endpoint) from e def get_staking_transaction_history(address, page=0, page_size=1000, include_full_tx=False, tx_type='ALL', @@ -166,6 +241,11 @@ def get_staking_transaction_history(address, page=0, page_size=1000, include_ful ------- list # TODO: Add link to reference RPC documentation + + Raises + ------ + InvalidRPCReplyError + If received unknown result from endpoint """ params = [ { @@ -178,18 +258,25 @@ def get_staking_transaction_history(address, page=0, page_size=1000, include_ful } ] # Using v2 API, because getStakingTransactionHistory not implemented in v1 - stx_history = rpc_request('hmyv2_getStakingTransactionsHistory', params=params, endpoint=endpoint, timeout=timeout) - return stx_history['result']['staking_transactions'] + method = 'hmyv2_getStakingTransactionsHistory' + stx_history = rpc_request(method, params=params, endpoint=endpoint, timeout=timeout)['result'] + try: + return stx_history['staking_transactions'] + except KeyError as e: + raise InvalidRPCReplyError(method, endpoint) from e -def get_balance_on_all_shards(address, endpoint=_default_endpoint, timeout=_default_timeout): +def get_balance_on_all_shards(address, skip_error=True, endpoint=_default_endpoint, timeout=_default_timeout) -> list: """ - Get current account balance in all shards + Get current account balance in all shards & optionally report errors getting account balance for a shard Parameters ---------- address: str Address to get balance for + skip_error: :obj:`bool`, optional + True to ignore errors getting balance for shard + False to include errors when getting balance for shard endpoint: :obj:`str`, optional Endpoint to send request to timeout: :obj:`int`, optional @@ -197,11 +284,59 @@ def get_balance_on_all_shards(address, endpoint=_default_endpoint, timeout=_defa Returns ------- - dict + list Account balance per shard in ATTO + Example reply: + [ + { + 'shard': 0, + 'balance': 0, + }, + ... + ] """ - balances = {} + balances = [] sharding_structure = get_sharding_structure(endpoint=endpoint, timeout=timeout) for shard in sharding_structure: - balances[shard['shardID']] = get_balance(address, endpoint=shard['http'], timeout=timeout) + try: + balances.append({ + 'shard': shard['shardID'], + 'balance': get_balance(address, endpoint=shard['http'], timeout=timeout) + }) + except (KeyError, RPCError, RequestsError, RequestsTimeoutError, JSONDecodeError): + if not skip_error: + balances.append({ + 'shard': shard['shardID'], + 'balance': None + }) return balances + + +def get_total_balance(address, endpoint=_default_endpoint, timeout=_default_timeout) -> int: + """ + Get total account balance on all shards + + Parameters + ---------- + address: str + Address to get balance for + endpoint: :obj:`str`, optional + Endpoint to send request to + timeout: :obj:`int`, optional + Timeout in seconds per request + + Returns + ------- + int + Total account balance in ATTO + + Raises + ------ + RuntimeError + If error occurred getting account balance for a shard + """ + try: + balances = get_balance_on_all_shards(address, skip_error=False, endpoint=endpoint, timeout=timeout) + return sum(b['balance'] for b in balances) + except TypeError as e: + raise RuntimeError from e diff --git a/pyhmy/rpc/exceptions.py b/pyhmy/rpc/exceptions.py index a2cfe6b..2db62b9 100644 --- a/pyhmy/rpc/exceptions.py +++ b/pyhmy/rpc/exceptions.py @@ -7,6 +7,15 @@ class RPCError(RuntimeError): Exception raised when RPC call returns an error """ +class InvalidRPCReplyError(RuntimeError): + """ + Exception raised when RPC call returns unexpected result + Generally indicates Harmony API has been updated & pyhmy library needs to be updated as well + """ + + def __init__(self, method, endpoint): + self.message = f'Unexpected reply for {method} from {endpoint}' + class JSONDecodeError(json.decoder.JSONDecodeError): """ Wrapper for json lib DecodeError exception diff --git a/tests/rpc-pyhmy/test_account.py b/tests/rpc-pyhmy/test_account.py index e4ec79a..6d22e4b 100644 --- a/tests/rpc-pyhmy/test_account.py +++ b/tests/rpc-pyhmy/test_account.py @@ -41,16 +41,31 @@ def test_get_balance_by_block(setup_blockchain): assert balance > 0 @pytest.mark.run(order=3) -def test_get_transaction_count(setup_blockchain): - transactions = _test_account_rpc(account.get_transaction_count, local_test_address, endpoint=endpoint_shard_one) - assert transactions > 0 +def test_get_true_nonce(setup_blockchain): + true_nonce = _test_account_rpc(account.get_account_nonce, local_test_address, true_nonce=True, endpoint=endpoint_shard_one) + assert true_nonce > 0 @pytest.mark.run(order=4) +def test_get_pending_nonce(setup_blockchain): + pending_nonce = _test_account_rpc(account.get_account_nonce, local_test_address, endpoint=endpoint_shard_one) + assert pending_nonce > 0 + +@pytest.mark.run(order=5) def test_get_transaction_history(setup_blockchain): tx_history = _test_account_rpc(account.get_transaction_history, local_test_address, endpoint=explorer_endpoint) assert len(tx_history) >= 0 -@pytest.mark.run(order=5) +@pytest.mark.run(order=6) def test_get_staking_transaction_history(setup_blockchain): staking_tx_history = _test_account_rpc(account.get_staking_transaction_history, test_validator_address, endpoint=explorer_endpoint) assert len(staking_tx_history) > 0 + +@pytest.mark.run(order=7) +def test_get_balance_on_all_shards(setup_blockchain): + balances = _test_account_rpc(account.get_balance_on_all_shards, local_test_address) + assert len(balances) == 2 + +@pytest.mark.run(order=8) +def test_get_total_balance(setup_blockchain): + total_balance = _test_account_rpc(account.get_total_balance, local_test_address) + assert total_balance > 0