You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
858 lines
23 KiB
858 lines
23 KiB
package sync
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/ethereum/go-ethereum/common"
|
|
"github.com/ethereum/go-ethereum/event"
|
|
"github.com/ethereum/go-ethereum/rlp"
|
|
"github.com/harmony-one/harmony/block"
|
|
headerV3 "github.com/harmony-one/harmony/block/v3"
|
|
"github.com/harmony-one/harmony/core/types"
|
|
"github.com/harmony-one/harmony/p2p/stream/common/ratelimiter"
|
|
"github.com/harmony-one/harmony/p2p/stream/common/streammanager"
|
|
"github.com/harmony-one/harmony/p2p/stream/protocols/sync/message"
|
|
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message"
|
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
|
|
)
|
|
|
|
var (
|
|
_ sttypes.Request = &getBlocksByNumberRequest{}
|
|
_ sttypes.Request = &getBlockNumberRequest{}
|
|
_ sttypes.Request = &getReceiptsRequest{}
|
|
_ sttypes.Response = &syncResponse{&syncpb.Response{}}
|
|
// MaxHash represents the maximum possible hash value.
|
|
MaxHash = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
|
|
)
|
|
|
|
var (
|
|
initStreamIDs = []sttypes.StreamID{
|
|
makeTestStreamID(0),
|
|
makeTestStreamID(1),
|
|
makeTestStreamID(2),
|
|
makeTestStreamID(3),
|
|
}
|
|
)
|
|
|
|
var (
|
|
testHeader = &block.Header{Header: headerV3.NewHeader()}
|
|
testBlock = types.NewBlockWithHeader(testHeader)
|
|
testReceipt = &types.Receipt{
|
|
Status: types.ReceiptStatusSuccessful,
|
|
CumulativeGasUsed: 0x888888888,
|
|
Logs: []*types.Log{},
|
|
}
|
|
testNodeData = numberToHash(123456789).Bytes()
|
|
testHeaderBytes, _ = rlp.EncodeToBytes(testHeader)
|
|
testBlockBytes, _ = rlp.EncodeToBytes(testBlock)
|
|
testReceiptBytes, _ = rlp.EncodeToBytes(testReceipt)
|
|
testNodeDataBytes, _ = rlp.EncodeToBytes(testNodeData)
|
|
|
|
testBlockResponse = syncpb.MakeGetBlocksByNumResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1))
|
|
|
|
testCurBlockNumber uint64 = 100
|
|
testBlockNumberResponse = syncpb.MakeGetBlockNumberResponse(0, testCurBlockNumber)
|
|
|
|
testHash = numberToHash(100)
|
|
testBlockHashesResponse = syncpb.MakeGetBlockHashesResponse(0, []common.Hash{testHash})
|
|
|
|
testBlocksByHashesResponse = syncpb.MakeGetBlocksByHashesResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1))
|
|
|
|
testReceipsMap = map[uint64]*message.Receipts{
|
|
0: {ReceiptBytes: [][]byte{testReceiptBytes}},
|
|
}
|
|
testReceiptResponse = syncpb.MakeGetReceiptsResponse(0, testReceipsMap)
|
|
|
|
testNodeDataResponse = syncpb.MakeGetNodeDataResponse(0, [][]byte{testNodeDataBytes})
|
|
|
|
account1 = common.HexToHash("0xf493f79c43bd747129a226ad42529885a4b108aba6046b2d12071695a6627844")
|
|
account2 = common.HexToHash("0xf493f79c43bd747129a226ad42529885a4b108aba6046b2d12071695a6627844")
|
|
resAccounts = []common.Hash{account1, account2}
|
|
|
|
accountsData = []*message.AccountData{
|
|
&syncpb.AccountData{
|
|
Hash: account1[:],
|
|
Body: common.HexToHash("0x00bf100000000000000000000000000000000000000000000000000000000000").Bytes(),
|
|
},
|
|
&syncpb.AccountData{
|
|
Hash: account2[:],
|
|
Body: common.HexToHash("0x00bf100000000000000000000000000000000000000000000000000000000000").Bytes(),
|
|
},
|
|
}
|
|
|
|
slots = []*syncpb.StoragesData{
|
|
&syncpb.StoragesData{
|
|
Data: []*syncpb.StorageData{
|
|
&syncpb.StorageData{
|
|
Hash: account1[:],
|
|
Body: common.HexToHash("0x00bf100000000000000000000000000000000000000000000000000000000000").Bytes(),
|
|
},
|
|
},
|
|
},
|
|
&syncpb.StoragesData{
|
|
Data: []*syncpb.StorageData{
|
|
&syncpb.StorageData{
|
|
Hash: account2[:],
|
|
Body: common.HexToHash("0x00bf100000000000000000000000000000000000000000000000000000000000").Bytes(),
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
proofBytes1, _ = rlp.EncodeToBytes(account1)
|
|
proofBytes2, _ = rlp.EncodeToBytes(account2)
|
|
proof = [][]byte{proofBytes1, proofBytes2}
|
|
|
|
codeBytes1, _ = rlp.EncodeToBytes(account1)
|
|
codeBytes2, _ = rlp.EncodeToBytes(account2)
|
|
testByteCodes = [][]byte{codeBytes1, codeBytes2}
|
|
dataNodeBytes1, _ = rlp.EncodeToBytes(numberToHash(1).Bytes())
|
|
dataNodeBytes2, _ = rlp.EncodeToBytes(numberToHash(2).Bytes())
|
|
testTrieNodes = [][]byte{dataNodeBytes1, dataNodeBytes2}
|
|
testPathSet = [][]byte{numberToHash(19850928).Bytes(), numberToHash(13640607).Bytes()}
|
|
|
|
testPaths = []*syncpb.TrieNodePathSet{
|
|
&syncpb.TrieNodePathSet{
|
|
Pathset: testPathSet,
|
|
},
|
|
&syncpb.TrieNodePathSet{
|
|
Pathset: testPathSet,
|
|
},
|
|
}
|
|
|
|
testAccountRangeResponse = syncpb.MakeGetAccountRangeResponse(0, accountsData, proof)
|
|
|
|
testStorageRangesResponse = syncpb.MakeGetStorageRangesResponse(0, slots, proof)
|
|
|
|
testByteCodesResponse = syncpb.MakeGetByteCodesResponse(0, testByteCodes)
|
|
|
|
testTrieNodesResponse = syncpb.MakeGetTrieNodesResponse(0, testTrieNodes)
|
|
|
|
testErrorResponse = syncpb.MakeErrorResponse(0, errors.New("test error"))
|
|
)
|
|
|
|
func TestProtocol_GetBlocksByNumber(t *testing.T) {
|
|
tests := []struct {
|
|
getResponse getResponseFn
|
|
expErr error
|
|
expStID sttypes.StreamID
|
|
}{
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: nil,
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockNumberResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("not GetBlockByNumber"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: nil,
|
|
expErr: errors.New("get response error"),
|
|
expStID: "",
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testErrorResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("test error"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
protocol := makeTestProtocol(test.getResponse)
|
|
blocks, stid, err := protocol.GetBlocksByNumber(context.Background(), []uint64{0})
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if stid != test.expStID {
|
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
|
|
}
|
|
if test.expErr == nil && (len(blocks) == 0) {
|
|
t.Errorf("Test %v: zero blocks delivered", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestProtocol_GetCurrentBlockNumber(t *testing.T) {
|
|
tests := []struct {
|
|
getResponse getResponseFn
|
|
expErr error
|
|
expStID sttypes.StreamID
|
|
}{
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockNumberResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: nil,
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("not GetBlockNumber"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: nil,
|
|
expErr: errors.New("get response error"),
|
|
expStID: "",
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testErrorResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("test error"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
protocol := makeTestProtocol(test.getResponse)
|
|
res, stid, err := protocol.GetCurrentBlockNumber(context.Background())
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if stid != test.expStID {
|
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
|
|
}
|
|
if test.expErr == nil {
|
|
if res != testCurBlockNumber {
|
|
t.Errorf("Test %v: block number not expected: %v / %v", i, res, testCurBlockNumber)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestProtocol_GetBlockHashes(t *testing.T) {
|
|
tests := []struct {
|
|
getResponse getResponseFn
|
|
expErr error
|
|
expStID sttypes.StreamID
|
|
}{
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockHashesResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: nil,
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("not GetBlockHashes"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: nil,
|
|
expErr: errors.New("get response error"),
|
|
expStID: "",
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testErrorResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("test error"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
protocol := makeTestProtocol(test.getResponse)
|
|
res, stid, err := protocol.GetBlockHashes(context.Background(), []uint64{100})
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if stid != test.expStID {
|
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
|
|
}
|
|
if test.expErr == nil {
|
|
if len(res) != 1 {
|
|
t.Errorf("Test %v: size not 1", i)
|
|
}
|
|
if res[0] != testHash {
|
|
t.Errorf("Test %v: hash not expected", i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestProtocol_GetBlocksByHashes(t *testing.T) {
|
|
tests := []struct {
|
|
getResponse getResponseFn
|
|
expErr error
|
|
expStID sttypes.StreamID
|
|
}{
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlocksByHashesResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: nil,
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("not GetBlocksByHashes"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: nil,
|
|
expErr: errors.New("get response error"),
|
|
expStID: "",
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testErrorResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("test error"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
protocol := makeTestProtocol(test.getResponse)
|
|
blocks, stid, err := protocol.GetBlocksByHashes(context.Background(), []common.Hash{numberToHash(100)})
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if stid != test.expStID {
|
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
|
|
}
|
|
if test.expErr == nil {
|
|
if len(blocks) != 1 {
|
|
t.Errorf("Test %v: size not 1", i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestProtocol_GetReceipts(t *testing.T) {
|
|
tests := []struct {
|
|
getResponse getResponseFn
|
|
expErr error
|
|
expStID sttypes.StreamID
|
|
}{
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testReceiptResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: nil,
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("response not GetReceipts"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: nil,
|
|
expErr: errors.New("get response error"),
|
|
expStID: "",
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testErrorResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("test error"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
protocol := makeTestProtocol(test.getResponse)
|
|
receipts, stid, err := protocol.GetReceipts(context.Background(), []common.Hash{numberToHash(100)})
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if stid != test.expStID {
|
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
|
|
}
|
|
if test.expErr == nil {
|
|
if len(receipts) != 1 {
|
|
t.Errorf("Test %v: size not 1", i)
|
|
}
|
|
if len(receipts[0]) != 1 {
|
|
t.Errorf("Test %v: block receipts size not 1", i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestProtocol_GetNodeData(t *testing.T) {
|
|
tests := []struct {
|
|
getResponse getResponseFn
|
|
expErr error
|
|
expStID sttypes.StreamID
|
|
}{
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testNodeDataResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: nil,
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("response not GetNodeData"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: nil,
|
|
expErr: errors.New("get response error"),
|
|
expStID: "",
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testErrorResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("test error"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
protocol := makeTestProtocol(test.getResponse)
|
|
receipts, stid, err := protocol.GetNodeData(context.Background(), []common.Hash{numberToHash(100)})
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if stid != test.expStID {
|
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
|
|
}
|
|
if test.expErr == nil {
|
|
if len(receipts) != 1 {
|
|
t.Errorf("Test %v: size not 1", i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestProtocol_GetAccountRange(t *testing.T) {
|
|
var (
|
|
root = numberToHash(1985082913640607)
|
|
ffHash = MaxHash
|
|
zero = common.Hash{}
|
|
)
|
|
|
|
tests := []struct {
|
|
getResponse getResponseFn
|
|
expErr error
|
|
expStID sttypes.StreamID
|
|
}{
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testAccountRangeResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: nil,
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("response not GetAccountRange"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: nil,
|
|
expErr: errors.New("get response error"),
|
|
expStID: "",
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testErrorResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("test error"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
protocol := makeTestProtocol(test.getResponse)
|
|
accounts, proof, stid, err := protocol.GetAccountRange(context.Background(), root, zero, ffHash, uint64(100))
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if stid != test.expStID {
|
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
|
|
}
|
|
if test.expErr == nil {
|
|
if len(accounts) != len(proof) {
|
|
t.Errorf("accounts: %v", test.getResponse)
|
|
t.Errorf("accounts: %v", accounts)
|
|
t.Errorf("proof: %v", proof)
|
|
t.Errorf("Test %v: accounts size (%d) not equal to proof size (%d)", i, len(accounts), len(proof))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestProtocol_GetStorageRanges(t *testing.T) {
|
|
var (
|
|
root = numberToHash(1985082913640607)
|
|
firstKey = common.HexToHash("0x00bf49f440a1cd0527e4d06e2765654c0f56452257516d793a9b8d604dcfdf2a")
|
|
secondKey = common.HexToHash("0x09e47cd5056a689e708f22fe1f932709a320518e444f5f7d8d46a3da523d6606")
|
|
testAccounts = []common.Hash{secondKey, firstKey}
|
|
ffHash = MaxHash
|
|
zero = common.Hash{}
|
|
)
|
|
|
|
tests := []struct {
|
|
getResponse getResponseFn
|
|
expErr error
|
|
expStID sttypes.StreamID
|
|
}{
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testStorageRangesResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: nil,
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("response not GetStorageRanges"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: nil,
|
|
expErr: errors.New("get response error"),
|
|
expStID: "",
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testErrorResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("test error"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
protocol := makeTestProtocol(test.getResponse)
|
|
slots, proof, stid, err := protocol.GetStorageRanges(context.Background(), root, testAccounts, zero, ffHash, uint64(100))
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if stid != test.expStID {
|
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
|
|
}
|
|
if test.expErr == nil {
|
|
if len(slots) != len(testAccounts) {
|
|
t.Errorf("Test %v: slots size not equal to accounts size", i)
|
|
}
|
|
if len(slots) != len(proof) {
|
|
t.Errorf("Test %v: account size not equal to proof", i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestProtocol_GetByteCodes(t *testing.T) {
|
|
tests := []struct {
|
|
getResponse getResponseFn
|
|
expErr error
|
|
expStID sttypes.StreamID
|
|
}{
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testByteCodesResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: nil,
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("response not GetByteCodes"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: nil,
|
|
expErr: errors.New("get response error"),
|
|
expStID: "",
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testErrorResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("test error"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
protocol := makeTestProtocol(test.getResponse)
|
|
codes, stid, err := protocol.GetByteCodes(context.Background(), []common.Hash{numberToHash(19850829)}, uint64(500))
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if stid != test.expStID {
|
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
|
|
}
|
|
if test.expErr == nil {
|
|
if len(codes) != 2 {
|
|
t.Errorf("Test %v: size not 2", i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestProtocol_GetTrieNodes(t *testing.T) {
|
|
var (
|
|
root = numberToHash(1985082913640607)
|
|
)
|
|
|
|
tests := []struct {
|
|
getResponse getResponseFn
|
|
expErr error
|
|
expStID sttypes.StreamID
|
|
}{
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testTrieNodesResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: nil,
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testBlockResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("response not GetTrieNodes"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
{
|
|
getResponse: nil,
|
|
expErr: errors.New("get response error"),
|
|
expStID: "",
|
|
},
|
|
{
|
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
|
|
return &syncResponse{
|
|
pb: testErrorResponse,
|
|
}, makeTestStreamID(0)
|
|
},
|
|
expErr: errors.New("test error"),
|
|
expStID: makeTestStreamID(0),
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
protocol := makeTestProtocol(test.getResponse)
|
|
nodes, stid, err := protocol.GetTrieNodes(context.Background(), root, testPaths, uint64(500))
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if stid != test.expStID {
|
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
|
|
}
|
|
if test.expErr == nil {
|
|
if len(nodes) != 2 {
|
|
t.Errorf("Test %v: size not 2", i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
type getResponseFn func(request sttypes.Request) (sttypes.Response, sttypes.StreamID)
|
|
|
|
type testHostRequestManager struct {
|
|
getResponse getResponseFn
|
|
}
|
|
|
|
func makeTestProtocol(f getResponseFn) *Protocol {
|
|
rm := &testHostRequestManager{f}
|
|
|
|
streamIDs := make([]sttypes.StreamID, len(initStreamIDs))
|
|
copy(streamIDs, initStreamIDs)
|
|
sm := &testStreamManager{streamIDs}
|
|
|
|
rl := ratelimiter.NewRateLimiter(sm, 10, 10)
|
|
|
|
return &Protocol{
|
|
rm: rm,
|
|
rl: rl,
|
|
sm: sm,
|
|
}
|
|
}
|
|
|
|
func (rm *testHostRequestManager) Start() {}
|
|
func (rm *testHostRequestManager) Close() {}
|
|
func (rm *testHostRequestManager) DeliverResponse(sttypes.StreamID, sttypes.Response) {}
|
|
|
|
func (rm *testHostRequestManager) DoRequest(ctx context.Context, request sttypes.Request, opts ...Option) (sttypes.Response, sttypes.StreamID, error) {
|
|
if rm.getResponse == nil {
|
|
return nil, "", errors.New("get response error")
|
|
}
|
|
resp, stid := rm.getResponse(request)
|
|
return resp, stid, nil
|
|
}
|
|
|
|
func makeTestStreamID(index int) sttypes.StreamID {
|
|
id := fmt.Sprintf("[test stream %v]", index)
|
|
return sttypes.StreamID(id)
|
|
}
|
|
|
|
// mock stream manager
|
|
type testStreamManager struct {
|
|
streamIDs []sttypes.StreamID
|
|
}
|
|
|
|
func (sm *testStreamManager) Start() {}
|
|
func (sm *testStreamManager) Close() {}
|
|
func (sm *testStreamManager) SubscribeAddStreamEvent(chan<- streammanager.EvtStreamAdded) event.Subscription {
|
|
return nil
|
|
}
|
|
func (sm *testStreamManager) SubscribeRemoveStreamEvent(chan<- streammanager.EvtStreamRemoved) event.Subscription {
|
|
return nil
|
|
}
|
|
|
|
func (sm *testStreamManager) NewStream(stream sttypes.Stream) error {
|
|
stid := stream.ID()
|
|
for _, id := range sm.streamIDs {
|
|
if id == stid {
|
|
return errors.New("stream already exist")
|
|
}
|
|
}
|
|
sm.streamIDs = append(sm.streamIDs, stid)
|
|
return nil
|
|
}
|
|
|
|
func (sm *testStreamManager) RemoveStream(stID sttypes.StreamID) error {
|
|
for i, id := range sm.streamIDs {
|
|
if id == stID {
|
|
sm.streamIDs = append(sm.streamIDs[:i], sm.streamIDs[i+1:]...)
|
|
}
|
|
}
|
|
return errors.New("stream not exist")
|
|
}
|
|
|
|
func (sm *testStreamManager) isStreamExist(stid sttypes.StreamID) bool {
|
|
for _, id := range sm.streamIDs {
|
|
if id == stid {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (sm *testStreamManager) GetStreams() []sttypes.Stream {
|
|
return nil
|
|
}
|
|
|
|
func (sm *testStreamManager) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) {
|
|
return nil, false
|
|
}
|
|
|
|
// assertError compares an actual error with an expected one. It returns nil
// when they agree, and otherwise an error of the form
// "unexpected error: <got> / <expect>".
//
// Both nil is a match and exactly one nil is a mismatch. When both are
// non-nil, got matches if its message contains expect's message as a
// substring, so an expected error only needs the distinctive part of a
// possibly longer, wrapped message.
func assertError(got, expect error) error {
	if got == nil && expect == nil {
		return nil
	}
	// Short-circuit order guarantees Error() is only called on non-nil values.
	if got == nil || expect == nil || !strings.Contains(got.Error(), expect.Error()) {
		return fmt.Errorf("unexpected error: %v / %v", got, expect)
	}
	return nil
}
|
|
|