The core protocol of WoopChain
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
woop/p2p/stream/protocols/sync/client_test.go

462 lines
12 KiB

package sync
import (
"context"
"errors"
"fmt"
"math/big"
"strings"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/rlp"
"github.com/harmony-one/harmony/block"
headerV3 "github.com/harmony-one/harmony/block/v3"
"github.com/harmony-one/harmony/core/types"
"github.com/harmony-one/harmony/p2p/stream/common/ratelimiter"
"github.com/harmony-one/harmony/p2p/stream/common/streammanager"
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
"github.com/harmony-one/harmony/shard"
)
// Compile-time assertions that the sync request/response wrappers satisfy
// the stream-layer Request and Response interfaces.
var (
	_ sttypes.Request  = (*getBlocksByNumberRequest)(nil)
	_ sttypes.Request  = (*getEpochBlockRequest)(nil)
	_ sttypes.Request  = (*getBlockNumberRequest)(nil)
	_ sttypes.Response = (*syncResponse)(nil)
)
// initStreamIDs is the set of stream IDs that every mock stream manager
// built by makeTestProtocol starts out with (it receives its own copy).
var initStreamIDs = []sttypes.StreamID{
	makeTestStreamID(0),
	makeTestStreamID(1),
	makeTestStreamID(2),
	makeTestStreamID(3),
}
// mustRLPEncodeFixture RLP-encodes v for use as a test fixture. It panics on
// failure so that a broken fixture aborts the test binary loudly instead of
// silently yielding nil bytes and producing confusing downstream failures.
func mustRLPEncodeFixture(v interface{}) []byte {
	b, err := rlp.EncodeToBytes(v)
	if err != nil {
		panic(fmt.Sprintf("rlp encode test fixture %T: %v", v, err))
	}
	return b
}

// Canned fixtures and pre-built sync responses shared by the tests in this
// file. Every response is built with request ID 0.
var (
	testHeader      = &block.Header{Header: headerV3.NewHeader()}
	testBlock       = types.NewBlockWithHeader(testHeader)
	testHeaderBytes = mustRLPEncodeFixture(testHeader)
	testBlockBytes  = mustRLPEncodeFixture(testBlock)
	// One encoded block; the third argument is a single nil entry paired
	// with that block.
	testBlockResponse = syncpb.MakeGetBlocksByNumResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1))

	// Epoch-1 shard state with no committees.
	testEpochState = &shard.State{
		Epoch:  new(big.Int).SetInt64(1),
		Shards: []shard.Committee{},
	}
	testEpochStateBytes    = mustRLPEncodeFixture(testEpochState)
	testEpochStateResponse = syncpb.MakeGetEpochStateResponse(0, testHeaderBytes, testEpochStateBytes)

	testCurBlockNumber      uint64 = 100
	testBlockNumberResponse        = syncpb.MakeGetBlockNumberResponse(0, testCurBlockNumber)

	testHash                   = numberToHash(100)
	testBlockHashesResponse    = syncpb.MakeGetBlockHashesResponse(0, []common.Hash{testHash})
	testBlocksByHashesResponse = syncpb.MakeGetBlocksByHashesResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1))

	testErrorResponse = syncpb.MakeErrorResponse(0, errors.New("test error"))
)
// TestProtocol_GetBlocksByNumber checks that GetBlocksByNumber returns blocks
// and the serving stream ID on success. It also checks that an error is
// reported for a mismatched response type, a failing request manager, and an
// error response.
func TestProtocol_GetBlocksByNumber(t *testing.T) {
	// respondWith builds a getResponseFn that always answers with pb from
	// test stream 0.
	respondWith := func(pb *syncpb.Response) getResponseFn {
		return func(sttypes.Request) (sttypes.Response, sttypes.StreamID) {
			return &syncResponse{pb: pb}, makeTestStreamID(0)
		}
	}

	cases := []struct {
		getResponse getResponseFn
		expErr      error
		expStID     sttypes.StreamID
	}{
		{
			getResponse: respondWith(testBlockResponse),
			expErr:      nil,
			expStID:     makeTestStreamID(0),
		},
		{
			getResponse: respondWith(testEpochStateResponse),
			expErr:      errors.New("not GetBlockByNumber"),
			expStID:     makeTestStreamID(0),
		},
		{
			// A nil getResponse makes the mock request manager fail.
			getResponse: nil,
			expErr:      errors.New("get response error"),
			expStID:     "",
		},
		{
			getResponse: respondWith(testErrorResponse),
			expErr:      errors.New("test error"),
			expStID:     makeTestStreamID(0),
		},
	}

	for idx, tc := range cases {
		p := makeTestProtocol(tc.getResponse)
		blocks, gotStID, err := p.GetBlocksByNumber(context.Background(), []uint64{0})
		if mismatch := assertError(err, tc.expErr); mismatch != nil {
			t.Errorf("Test %v: %v", idx, mismatch)
			continue
		}
		if gotStID != tc.expStID {
			t.Errorf("Test %v: unexpected st id: %v / %v", idx, gotStID, tc.expStID)
		}
		if tc.expErr == nil && len(blocks) == 0 {
			t.Errorf("Test %v: zero blocks delivered", idx)
		}
	}
}
// TestProtocol_GetEpochState checks that GetEpochState returns the epoch-1
// state and the serving stream ID on success. It also checks that an error is
// reported for a mismatched response type, a failing request manager, and an
// error response.
func TestProtocol_GetEpochState(t *testing.T) {
	// respondWith builds a getResponseFn that always answers with pb from
	// test stream 0.
	respondWith := func(pb *syncpb.Response) getResponseFn {
		return func(sttypes.Request) (sttypes.Response, sttypes.StreamID) {
			return &syncResponse{pb: pb}, makeTestStreamID(0)
		}
	}

	cases := []struct {
		getResponse getResponseFn
		expErr      error
		expStID     sttypes.StreamID
	}{
		{
			getResponse: respondWith(testEpochStateResponse),
			expErr:      nil,
			expStID:     makeTestStreamID(0),
		},
		{
			getResponse: respondWith(testBlockResponse),
			expErr:      errors.New("not GetEpochStateResponse"),
			expStID:     makeTestStreamID(0),
		},
		{
			// A nil getResponse makes the mock request manager fail.
			getResponse: nil,
			expErr:      errors.New("get response error"),
			expStID:     "",
		},
		{
			getResponse: respondWith(testErrorResponse),
			expErr:      errors.New("test error"),
			expStID:     makeTestStreamID(0),
		},
	}

	for idx, tc := range cases {
		p := makeTestProtocol(tc.getResponse)
		res, gotStID, err := p.GetEpochState(context.Background(), 0)
		if mismatch := assertError(err, tc.expErr); mismatch != nil {
			t.Errorf("Test %v: %v", idx, mismatch)
			continue
		}
		if gotStID != tc.expStID {
			t.Errorf("Test %v: unexpected st id: %v / %v", idx, gotStID, tc.expStID)
		}
		if tc.expErr != nil {
			continue
		}
		wantEpoch := new(big.Int).SetUint64(1)
		if got := res.State.Epoch; got.Cmp(wantEpoch) != 0 {
			t.Errorf("Test %v: unexpected epoch delivered: %v / %v", idx, got.String(), 1)
		}
	}
}
// TestProtocol_GetCurrentBlockNumber checks that GetCurrentBlockNumber
// returns testCurBlockNumber and the serving stream ID on success. It also
// checks that an error is reported for a mismatched response type, a failing
// request manager, and an error response.
func TestProtocol_GetCurrentBlockNumber(t *testing.T) {
	// respondWith builds a getResponseFn that always answers with pb from
	// test stream 0.
	respondWith := func(pb *syncpb.Response) getResponseFn {
		return func(sttypes.Request) (sttypes.Response, sttypes.StreamID) {
			return &syncResponse{pb: pb}, makeTestStreamID(0)
		}
	}

	cases := []struct {
		getResponse getResponseFn
		expErr      error
		expStID     sttypes.StreamID
	}{
		{
			getResponse: respondWith(testBlockNumberResponse),
			expErr:      nil,
			expStID:     makeTestStreamID(0),
		},
		{
			getResponse: respondWith(testBlockResponse),
			expErr:      errors.New("not GetBlockNumber"),
			expStID:     makeTestStreamID(0),
		},
		{
			// A nil getResponse makes the mock request manager fail.
			getResponse: nil,
			expErr:      errors.New("get response error"),
			expStID:     "",
		},
		{
			getResponse: respondWith(testErrorResponse),
			expErr:      errors.New("test error"),
			expStID:     makeTestStreamID(0),
		},
	}

	for idx, tc := range cases {
		p := makeTestProtocol(tc.getResponse)
		num, gotStID, err := p.GetCurrentBlockNumber(context.Background())
		if mismatch := assertError(err, tc.expErr); mismatch != nil {
			t.Errorf("Test %v: %v", idx, mismatch)
			continue
		}
		if gotStID != tc.expStID {
			t.Errorf("Test %v: unexpected st id: %v / %v", idx, gotStID, tc.expStID)
		}
		if tc.expErr == nil && num != testCurBlockNumber {
			t.Errorf("Test %v: block number not expected: %v / %v", idx, num, testCurBlockNumber)
		}
	}
}
// TestProtocol_GetBlockHashes checks that GetBlockHashes returns exactly
// testHash and the serving stream ID on success. It also checks that an error
// is reported for a mismatched response type, a failing request manager, and
// an error response.
func TestProtocol_GetBlockHashes(t *testing.T) {
	// respondWith builds a getResponseFn that always answers with pb from
	// test stream 0.
	respondWith := func(pb *syncpb.Response) getResponseFn {
		return func(sttypes.Request) (sttypes.Response, sttypes.StreamID) {
			return &syncResponse{pb: pb}, makeTestStreamID(0)
		}
	}

	tests := []struct {
		getResponse getResponseFn
		expErr      error
		expStID     sttypes.StreamID
	}{
		{
			getResponse: respondWith(testBlockHashesResponse),
			expErr:      nil,
			expStID:     makeTestStreamID(0),
		},
		{
			getResponse: respondWith(testBlockResponse),
			expErr:      errors.New("not GetBlockHashes"),
			expStID:     makeTestStreamID(0),
		},
		{
			// A nil getResponse makes the mock request manager fail.
			getResponse: nil,
			expErr:      errors.New("get response error"),
			expStID:     "",
		},
		{
			getResponse: respondWith(testErrorResponse),
			expErr:      errors.New("test error"),
			expStID:     makeTestStreamID(0),
		},
	}

	for i, test := range tests {
		protocol := makeTestProtocol(test.getResponse)
		res, stid, err := protocol.GetBlockHashes(context.Background(), []uint64{100})
		if assErr := assertError(err, test.expErr); assErr != nil {
			t.Errorf("Test %v: %v", i, assErr)
			continue
		}
		if stid != test.expStID {
			t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
		}
		if test.expErr != nil {
			continue
		}
		if len(res) != 1 {
			t.Errorf("Test %v: size not 1", i)
			// Indexing res below would panic on an empty result and abort the
			// whole test, so stop checking this case here.
			continue
		}
		if res[0] != testHash {
			t.Errorf("Test %v: hash not expected", i)
		}
	}
}
// TestProtocol_GetBlocksByHashes checks that GetBlocksByHashes returns one
// block and the serving stream ID on success. It also checks that an error is
// reported for a mismatched response type, a failing request manager, and an
// error response.
func TestProtocol_GetBlocksByHashes(t *testing.T) {
	// respondWith builds a getResponseFn that always answers with pb from
	// test stream 0.
	respondWith := func(pb *syncpb.Response) getResponseFn {
		return func(sttypes.Request) (sttypes.Response, sttypes.StreamID) {
			return &syncResponse{pb: pb}, makeTestStreamID(0)
		}
	}

	cases := []struct {
		getResponse getResponseFn
		expErr      error
		expStID     sttypes.StreamID
	}{
		{
			getResponse: respondWith(testBlocksByHashesResponse),
			expErr:      nil,
			expStID:     makeTestStreamID(0),
		},
		{
			getResponse: respondWith(testBlockResponse),
			expErr:      errors.New("not GetBlocksByHashes"),
			expStID:     makeTestStreamID(0),
		},
		{
			// A nil getResponse makes the mock request manager fail.
			getResponse: nil,
			expErr:      errors.New("get response error"),
			expStID:     "",
		},
		{
			getResponse: respondWith(testErrorResponse),
			expErr:      errors.New("test error"),
			expStID:     makeTestStreamID(0),
		},
	}

	for idx, tc := range cases {
		p := makeTestProtocol(tc.getResponse)
		hashes := []common.Hash{numberToHash(100)}
		blocks, gotStID, err := p.GetBlocksByHashes(context.Background(), hashes)
		if mismatch := assertError(err, tc.expErr); mismatch != nil {
			t.Errorf("Test %v: %v", idx, mismatch)
			continue
		}
		if gotStID != tc.expStID {
			t.Errorf("Test %v: unexpected st id: %v / %v", idx, gotStID, tc.expStID)
		}
		if tc.expErr == nil && len(blocks) != 1 {
			t.Errorf("Test %v: size not 1", idx)
		}
	}
}
type (
	// getResponseFn yields the canned response for request together with the
	// ID of the stream that supposedly served it.
	getResponseFn func(request sttypes.Request) (sttypes.Response, sttypes.StreamID)

	// testHostRequestManager is a mock request manager whose DoRequest
	// answers through getResponse; a nil getResponse makes requests fail.
	testHostRequestManager struct {
		getResponse getResponseFn
	}
)
// makeTestProtocol builds a Protocol wired entirely to mocks:
//   - requests are answered by f (a nil f makes every request fail),
//   - the stream manager starts with its own copy of initStreamIDs,
//   - the rate limiter is ratelimiter.NewRateLimiter(10, 10).
func makeTestProtocol(f getResponseFn) *Protocol {
	// Give the mock stream manager a private copy so that its in-place
	// slice edits cannot leak into initStreamIDs or into other tests.
	ids := append(make([]sttypes.StreamID, 0, len(initStreamIDs)), initStreamIDs...)

	return &Protocol{
		rm: &testHostRequestManager{getResponse: f},
		rl: ratelimiter.NewRateLimiter(10, 10),
		sm: &testStreamManager{streamIDs: ids},
	}
}
// Start is a no-op for the mock.
func (rm *testHostRequestManager) Start() {}

// Close is a no-op for the mock.
func (rm *testHostRequestManager) Close() {}

// DeliverResponse ignores delivered responses.
func (rm *testHostRequestManager) DeliverResponse(sttypes.StreamID, sttypes.Response) {}

// DoRequest answers request via rm.getResponse and never fails in that case.
// When no getResponse is configured it fails with "get response error".
// ctx and opts are ignored.
func (rm *testHostRequestManager) DoRequest(ctx context.Context, request sttypes.Request, opts ...Option) (sttypes.Response, sttypes.StreamID, error) {
	if rm.getResponse != nil {
		resp, stid := rm.getResponse(request)
		return resp, stid, nil
	}
	return nil, "", errors.New("get response error")
}
// makeTestStreamID returns the deterministic stream ID "[test stream <index>]"
// used throughout these tests.
func makeTestStreamID(index int) sttypes.StreamID {
	return sttypes.StreamID(fmt.Sprintf("[test stream %v]", index))
}
// testStreamManager is a mock stream manager that only tracks stream IDs.
// Event subscriptions and stream lookups are stubbed out.
type testStreamManager struct {
	streamIDs []sttypes.StreamID
}

// Start is a no-op for the mock.
func (sm *testStreamManager) Start() {}

// Close is a no-op for the mock.
func (sm *testStreamManager) Close() {}

// SubscribeAddStreamEvent is not supported by the mock; it returns a nil
// subscription.
func (sm *testStreamManager) SubscribeAddStreamEvent(chan<- streammanager.EvtStreamAdded) event.Subscription {
	return nil
}

// SubscribeRemoveStreamEvent is not supported by the mock; it returns a nil
// subscription.
func (sm *testStreamManager) SubscribeRemoveStreamEvent(chan<- streammanager.EvtStreamRemoved) event.Subscription {
	return nil
}

// NewStream starts tracking stream's ID. It returns an error if that ID is
// already tracked.
func (sm *testStreamManager) NewStream(stream sttypes.Stream) error {
	stid := stream.ID()
	if sm.isStreamExist(stid) {
		return errors.New("stream already exist")
	}
	sm.streamIDs = append(sm.streamIDs, stid)
	return nil
}

// RemoveStream stops tracking stID. It returns nil when stID was removed and
// an error when stID was not tracked.
func (sm *testStreamManager) RemoveStream(stID sttypes.StreamID) error {
	for i, id := range sm.streamIDs {
		if id == stID {
			// NewStream rejects duplicates, so there is at most one match;
			// return right away rather than keep ranging over a slice that
			// was just modified in place.
			sm.streamIDs = append(sm.streamIDs[:i], sm.streamIDs[i+1:]...)
			return nil
		}
	}
	return errors.New("stream not exist")
}

// isStreamExist reports whether stid is currently tracked.
func (sm *testStreamManager) isStreamExist(stid sttypes.StreamID) bool {
	for _, id := range sm.streamIDs {
		if id == stid {
			return true
		}
	}
	return false
}

// GetStreams is not supported by the mock; it always returns nil.
func (sm *testStreamManager) GetStreams() []sttypes.Stream {
	return nil
}

// GetStreamByID is not supported by the mock; it always reports not found.
func (sm *testStreamManager) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) {
	return nil, false
}
// assertError compares an actual error with an expected one.
//
// The errors match when both are nil, or when both are non-nil and got's
// message contains expect's message as a substring. assertError returns nil
// on a match and a descriptive error otherwise.
func assertError(got, expect error) error {
	switch {
	case (got == nil) != (expect == nil):
		// Exactly one side is nil.
		return fmt.Errorf("unexpected error: %v / %v", got, expect)
	case got == nil:
		// Both nil: nothing more to compare.
		return nil
	case !strings.Contains(got.Error(), expect.Error()):
		return fmt.Errorf("unexpected error: %v / %v", got, expect)
	default:
		return nil
	}
}