parent
1caf88c7db
commit
3977ac49eb
@ -0,0 +1,169 @@ |
|||||||
|
package sync |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"math/big" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/common" |
||||||
|
"github.com/harmony-one/harmony/block" |
||||||
|
"github.com/harmony-one/harmony/consensus/engine" |
||||||
|
"github.com/harmony-one/harmony/core/types" |
||||||
|
shardingconfig "github.com/harmony-one/harmony/internal/configs/sharding" |
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// chainHelper is the adapter for blockchain which is friendly to unit test.
|
||||||
|
type chainHelper interface { |
||||||
|
getCurrentBlockNumber() uint64 |
||||||
|
getBlockHashes(bns []uint64) []common.Hash |
||||||
|
getBlocksByNumber(bns []uint64) ([]*types.Block, error) |
||||||
|
getBlocksByHashes(hs []common.Hash) ([]*types.Block, error) |
||||||
|
getEpochState(epoch uint64) (*EpochStateResult, error) |
||||||
|
} |
||||||
|
|
||||||
|
type chainHelperImpl struct { |
||||||
|
chain engine.ChainReader |
||||||
|
schedule shardingconfig.Schedule |
||||||
|
} |
||||||
|
|
||||||
|
func newChainHelper(chain engine.ChainReader, schedule shardingconfig.Schedule) *chainHelperImpl { |
||||||
|
return &chainHelperImpl{ |
||||||
|
chain: chain, |
||||||
|
schedule: schedule, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *chainHelperImpl) getCurrentBlockNumber() uint64 { |
||||||
|
return ch.chain.CurrentBlock().NumberU64() |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *chainHelperImpl) getBlockHashes(bns []uint64) []common.Hash { |
||||||
|
hashes := make([]common.Hash, 0, len(bns)) |
||||||
|
for _, bn := range bns { |
||||||
|
var h common.Hash |
||||||
|
header := ch.chain.GetHeaderByNumber(bn) |
||||||
|
if header != nil { |
||||||
|
h = header.Hash() |
||||||
|
} |
||||||
|
hashes = append(hashes, h) |
||||||
|
} |
||||||
|
return hashes |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *chainHelperImpl) getBlocksByNumber(bns []uint64) ([]*types.Block, error) { |
||||||
|
var ( |
||||||
|
blocks = make([]*types.Block, 0, len(bns)) |
||||||
|
) |
||||||
|
for _, bn := range bns { |
||||||
|
var ( |
||||||
|
block *types.Block |
||||||
|
err error |
||||||
|
) |
||||||
|
header := ch.chain.GetHeaderByNumber(bn) |
||||||
|
if header != nil { |
||||||
|
block, err = ch.getBlockWithSigByHeader(header) |
||||||
|
if err != nil { |
||||||
|
return nil, errors.Wrapf(err, "get block %v at %v", header.Hash().String(), header.Number()) |
||||||
|
} |
||||||
|
} |
||||||
|
blocks = append(blocks, block) |
||||||
|
|
||||||
|
} |
||||||
|
return blocks, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *chainHelperImpl) getBlocksByHashes(hs []common.Hash) ([]*types.Block, error) { |
||||||
|
var ( |
||||||
|
blocks = make([]*types.Block, 0, len(hs)) |
||||||
|
) |
||||||
|
for _, h := range hs { |
||||||
|
var ( |
||||||
|
block *types.Block |
||||||
|
err error |
||||||
|
) |
||||||
|
header := ch.chain.GetHeaderByHash(h) |
||||||
|
if header != nil { |
||||||
|
block, err = ch.getBlockWithSigByHeader(header) |
||||||
|
if err != nil { |
||||||
|
return nil, errors.Wrapf(err, "get block %v at %v", header.Hash().String(), header.Number()) |
||||||
|
} |
||||||
|
} |
||||||
|
blocks = append(blocks, block) |
||||||
|
} |
||||||
|
return blocks, nil |
||||||
|
} |
||||||
|
|
||||||
|
var errBlockNotFound = errors.New("block not found") |
||||||
|
|
||||||
|
func (ch *chainHelperImpl) getBlockWithSigByHeader(header *block.Header) (*types.Block, error) { |
||||||
|
b := ch.chain.GetBlock(header.Hash(), header.Number().Uint64()) |
||||||
|
if b == nil { |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
commitSig, err := ch.getBlockSigAndBitmap(header) |
||||||
|
if err != nil { |
||||||
|
return nil, errors.New("missing commit signature") |
||||||
|
} |
||||||
|
b.SetCurrentCommitSig(commitSig) |
||||||
|
return b, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *chainHelperImpl) getBlockSigAndBitmap(header *block.Header) ([]byte, error) { |
||||||
|
sb := ch.getBlockSigFromNextBlock(header) |
||||||
|
if len(sb) != 0 { |
||||||
|
return sb, nil |
||||||
|
} |
||||||
|
// Note: some commit sig read from db is different from [nextHeader.sig, nextHeader.bitMap]
|
||||||
|
// nextBlock data is better to be used.
|
||||||
|
return ch.getBlockSigFromDB(header) |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *chainHelperImpl) getBlockSigFromNextBlock(header *block.Header) []byte { |
||||||
|
nextBN := header.Number().Uint64() + 1 |
||||||
|
nextHeader := ch.chain.GetHeaderByNumber(nextBN) |
||||||
|
if nextHeader == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
sigBytes := nextHeader.LastCommitSignature() |
||||||
|
bitMap := nextHeader.LastCommitBitmap() |
||||||
|
sb := make([]byte, len(sigBytes)+len(bitMap)) |
||||||
|
copy(sb[:], sigBytes[:]) |
||||||
|
copy(sb[len(sigBytes):], bitMap[:]) |
||||||
|
|
||||||
|
return sb |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *chainHelperImpl) getBlockSigFromDB(header *block.Header) ([]byte, error) { |
||||||
|
return ch.chain.ReadCommitSig(header.Number().Uint64()) |
||||||
|
} |
||||||
|
|
||||||
|
func (ch *chainHelperImpl) getEpochState(epoch uint64) (*EpochStateResult, error) { |
||||||
|
if ch.chain.ShardID() != 0 { |
||||||
|
return nil, errors.New("get epoch state currently unavailable on side chain") |
||||||
|
} |
||||||
|
if epoch == 0 { |
||||||
|
return nil, errors.New("nil shard state for epoch 0") |
||||||
|
} |
||||||
|
res := &EpochStateResult{} |
||||||
|
|
||||||
|
targetBN := ch.schedule.EpochLastBlock(epoch - 1) |
||||||
|
res.Header = ch.chain.GetHeaderByNumber(targetBN) |
||||||
|
if res.Header == nil { |
||||||
|
// we still don't have the given epoch
|
||||||
|
return res, nil |
||||||
|
} |
||||||
|
epochBI := new(big.Int).SetUint64(epoch) |
||||||
|
if ch.chain.Config().IsPreStaking(epochBI) { |
||||||
|
// For epoch before preStaking, only hash is stored in header
|
||||||
|
ss, err := ch.chain.ReadShardState(epochBI) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if ss == nil { |
||||||
|
return nil, fmt.Errorf("missing shard state for [EPOCH-%v]", epoch) |
||||||
|
} |
||||||
|
res.State = ss |
||||||
|
} |
||||||
|
return res, nil |
||||||
|
} |
@ -0,0 +1,208 @@ |
|||||||
|
package sync |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"encoding/binary" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"math/big" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/common" |
||||||
|
"github.com/ethereum/go-ethereum/rlp" |
||||||
|
protobuf "github.com/golang/protobuf/proto" |
||||||
|
"github.com/harmony-one/harmony/block" |
||||||
|
"github.com/harmony-one/harmony/core/types" |
||||||
|
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" |
||||||
|
"github.com/harmony-one/harmony/shard" |
||||||
|
) |
||||||
|
|
||||||
|
type testChainHelper struct{} |
||||||
|
|
||||||
|
func (tch *testChainHelper) getCurrentBlockNumber() uint64 { |
||||||
|
return 100 |
||||||
|
} |
||||||
|
|
||||||
|
func (tch *testChainHelper) getBlocksByNumber(bns []uint64) ([]*types.Block, error) { |
||||||
|
blocks := make([]*types.Block, 0, len(bns)) |
||||||
|
for _, bn := range bns { |
||||||
|
blocks = append(blocks, makeTestBlock(bn)) |
||||||
|
} |
||||||
|
return blocks, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (tch *testChainHelper) getEpochState(epoch uint64) (*EpochStateResult, error) { |
||||||
|
header := &block.Header{Header: testHeader.Copy()} |
||||||
|
header.SetEpoch(big.NewInt(int64(epoch - 1))) |
||||||
|
|
||||||
|
state := testEpochState.DeepCopy() |
||||||
|
state.Epoch = big.NewInt(int64(epoch)) |
||||||
|
|
||||||
|
return &EpochStateResult{ |
||||||
|
Header: header, |
||||||
|
State: state, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (tch *testChainHelper) getBlockHashes(bns []uint64) []common.Hash { |
||||||
|
hs := make([]common.Hash, 0, len(bns)) |
||||||
|
for _, bn := range bns { |
||||||
|
hs = append(hs, numberToHash(bn)) |
||||||
|
} |
||||||
|
return hs |
||||||
|
} |
||||||
|
|
||||||
|
func (tch *testChainHelper) getBlocksByHashes(hs []common.Hash) ([]*types.Block, error) { |
||||||
|
bs := make([]*types.Block, 0, len(hs)) |
||||||
|
for _, h := range hs { |
||||||
|
bn := hashToNumber(h) |
||||||
|
bs = append(bs, makeTestBlock(bn)) |
||||||
|
} |
||||||
|
return bs, nil |
||||||
|
} |
||||||
|
|
||||||
|
func numberToHash(bn uint64) common.Hash { |
||||||
|
var h common.Hash |
||||||
|
binary.LittleEndian.PutUint64(h[:], bn) |
||||||
|
return h |
||||||
|
} |
||||||
|
|
||||||
|
func hashToNumber(h common.Hash) uint64 { |
||||||
|
return binary.LittleEndian.Uint64(h[:]) |
||||||
|
} |
||||||
|
|
||||||
|
func checkBlocksResult(bns []uint64, b []byte) error { |
||||||
|
var msg = &syncpb.Message{} |
||||||
|
if err := protobuf.Unmarshal(b, msg); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
gbResp, err := msg.GetBlocksByNumberResponse() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if len(gbResp.BlocksBytes) == 0 { |
||||||
|
return errors.New("nil response from GetBlocksByNumber") |
||||||
|
} |
||||||
|
blocks, err := decodeBlocksBytes(gbResp.BlocksBytes) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if len(blocks) != len(bns) { |
||||||
|
return errors.New("unexpected blocks number") |
||||||
|
} |
||||||
|
for i, bn := range bns { |
||||||
|
blk := blocks[i] |
||||||
|
if bn != blk.NumberU64() { |
||||||
|
return errors.New("unexpected number of a block") |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func makeTestBlock(bn uint64) *types.Block { |
||||||
|
header := testHeader.Copy() |
||||||
|
header.SetNumber(big.NewInt(int64(bn))) |
||||||
|
return types.NewBlockWithHeader(&block.Header{Header: header}) |
||||||
|
} |
||||||
|
|
||||||
|
func decodeBlocksBytes(bbs [][]byte) ([]*types.Block, error) { |
||||||
|
blocks := make([]*types.Block, 0, len(bbs)) |
||||||
|
|
||||||
|
for _, bb := range bbs { |
||||||
|
var block *types.Block |
||||||
|
if err := rlp.DecodeBytes(bb, &block); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
blocks = append(blocks, block) |
||||||
|
} |
||||||
|
return blocks, nil |
||||||
|
} |
||||||
|
|
||||||
|
func checkEpochStateResult(epoch uint64, b []byte) error { |
||||||
|
var msg = &syncpb.Message{} |
||||||
|
if err := protobuf.Unmarshal(b, msg); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
geResp, err := msg.GetEpochStateResponse() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
var ( |
||||||
|
header *block.Header |
||||||
|
epochState *shard.State |
||||||
|
) |
||||||
|
if err := rlp.DecodeBytes(geResp.HeaderBytes, &header); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if err := rlp.DecodeBytes(geResp.ShardState, &epochState); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if header.Epoch().Uint64() != epoch-1 { |
||||||
|
return fmt.Errorf("unexpected epoch of header %v / %v", header.Epoch(), epoch-1) |
||||||
|
} |
||||||
|
if epochState.Epoch.Uint64() != epoch { |
||||||
|
return fmt.Errorf("unexpected epoch of shard state %v / %v", epochState.Epoch.Uint64(), epoch) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func checkBlockNumberResult(b []byte) error { |
||||||
|
var msg = &syncpb.Message{} |
||||||
|
if err := protobuf.Unmarshal(b, msg); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
gnResp, err := msg.GetBlockNumberResponse() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if gnResp.Number != testCurBlockNumber { |
||||||
|
return fmt.Errorf("unexpected block number: %v / %v", gnResp.Number, testCurBlockNumber) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func checkBlockHashesResult(b []byte, bns []uint64) error { |
||||||
|
var msg = &syncpb.Message{} |
||||||
|
if err := protobuf.Unmarshal(b, msg); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
bhResp, err := msg.GetBlockHashesResponse() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
got := bhResp.Hashes |
||||||
|
if len(got) != len(bns) { |
||||||
|
return errors.New("unexpected size") |
||||||
|
} |
||||||
|
for i, bn := range bns { |
||||||
|
expect := numberToHash(bn) |
||||||
|
if !bytes.Equal(expect[:], got[i]) { |
||||||
|
return errors.New("unexpected hash") |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func checkBlocksByHashesResult(b []byte, hs []common.Hash) error { |
||||||
|
var msg = &syncpb.Message{} |
||||||
|
if err := protobuf.Unmarshal(b, msg); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
bhResp, err := msg.GetBlocksByHashesResponse() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if len(hs) != len(bhResp.BlocksBytes) { |
||||||
|
return errors.New("unexpected size") |
||||||
|
} |
||||||
|
for i, h := range hs { |
||||||
|
num := hashToNumber(h) |
||||||
|
var blk *types.Block |
||||||
|
if err := rlp.DecodeBytes(bhResp.BlocksBytes[i], &blk); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if blk.NumberU64() != num { |
||||||
|
return fmt.Errorf("unexpected number %v != %v", blk.NumberU64(), num) |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,396 @@ |
|||||||
|
package sync |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/common" |
||||||
|
"github.com/ethereum/go-ethereum/rlp" |
||||||
|
protobuf "github.com/golang/protobuf/proto" |
||||||
|
"github.com/harmony-one/harmony/core/types" |
||||||
|
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" |
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// GetBlocksByNumber do getBlocksByNumberRequest through sync stream protocol.
|
||||||
|
// Return the block as result, target stream id, and error
|
||||||
|
func (p *Protocol) GetBlocksByNumber(ctx context.Context, bns []uint64, opts ...Option) ([]*types.Block, sttypes.StreamID, error) { |
||||||
|
if len(bns) > GetBlocksByNumAmountCap { |
||||||
|
return nil, "", fmt.Errorf("number of blocks exceed cap of %v", GetBlocksByNumAmountCap) |
||||||
|
} |
||||||
|
|
||||||
|
req := newGetBlocksByNumberRequest(bns) |
||||||
|
resp, stid, err := p.rm.DoRequest(ctx, req, opts...) |
||||||
|
if err != nil { |
||||||
|
// At this point, error can be context canceled, context timed out, or waiting queue
|
||||||
|
// is already full.
|
||||||
|
return nil, stid, err |
||||||
|
} |
||||||
|
|
||||||
|
// Parse and return blocks
|
||||||
|
blocks, err := req.getBlocksFromResponse(resp) |
||||||
|
if err != nil { |
||||||
|
return nil, stid, err |
||||||
|
} |
||||||
|
return blocks, stid, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetEpochState get the epoch block from querying the remote node running sync stream protocol.
|
||||||
|
// Currently, this method is only supported by beacon syncer.
|
||||||
|
// Note: use this after epoch chain is implemented.
|
||||||
|
func (p *Protocol) GetEpochState(ctx context.Context, epoch uint64, opts ...Option) (*EpochStateResult, sttypes.StreamID, error) { |
||||||
|
req := newGetEpochBlockRequest(epoch) |
||||||
|
|
||||||
|
resp, stid, err := p.rm.DoRequest(ctx, req, opts...) |
||||||
|
if err != nil { |
||||||
|
return nil, stid, err |
||||||
|
} |
||||||
|
|
||||||
|
res, err := epochStateResultFromResponse(resp) |
||||||
|
if err != nil { |
||||||
|
return nil, stid, err |
||||||
|
} |
||||||
|
return res, stid, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetCurrentBlockNumber get the current block number from remote node
|
||||||
|
func (p *Protocol) GetCurrentBlockNumber(ctx context.Context, opts ...Option) (uint64, sttypes.StreamID, error) { |
||||||
|
req := newGetBlockNumberRequest() |
||||||
|
|
||||||
|
resp, stid, err := p.rm.DoRequest(ctx, req, opts...) |
||||||
|
if err != nil { |
||||||
|
return 0, stid, err |
||||||
|
} |
||||||
|
|
||||||
|
bn, err := req.getNumberFromResponse(resp) |
||||||
|
if err != nil { |
||||||
|
return bn, stid, err |
||||||
|
} |
||||||
|
return bn, stid, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetBlockHashes do getBlockHashesRequest through sync stream protocol.
|
||||||
|
// Return the hash of the given block number. If a block is unknown, the hash will be emptyHash.
|
||||||
|
func (p *Protocol) GetBlockHashes(ctx context.Context, bns []uint64, opts ...Option) ([]common.Hash, sttypes.StreamID, error) { |
||||||
|
if len(bns) > GetBlockHashesAmountCap { |
||||||
|
return nil, "", fmt.Errorf("number of requested numbers exceed limit") |
||||||
|
} |
||||||
|
|
||||||
|
req := newGetBlockHashesRequest(bns) |
||||||
|
resp, stid, err := p.rm.DoRequest(ctx, req, opts...) |
||||||
|
if err != nil { |
||||||
|
return nil, stid, err |
||||||
|
} |
||||||
|
hashes, err := req.getHashesFromResponse(resp) |
||||||
|
if err != nil { |
||||||
|
return nil, stid, err |
||||||
|
} |
||||||
|
return hashes, stid, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetBlocksByHashes do getBlocksByHashesRequest through sync stream protocol.
|
||||||
|
func (p *Protocol) GetBlocksByHashes(ctx context.Context, hs []common.Hash, opts ...Option) ([]*types.Block, sttypes.StreamID, error) { |
||||||
|
if len(hs) > GetBlocksByHashesAmountCap { |
||||||
|
return nil, "", fmt.Errorf("number of requested hashes exceed limit") |
||||||
|
} |
||||||
|
req := newGetBlocksByHashesRequest(hs) |
||||||
|
resp, stid, err := p.rm.DoRequest(ctx, req, opts...) |
||||||
|
if err != nil { |
||||||
|
return nil, stid, err |
||||||
|
} |
||||||
|
blocks, err := req.getBlocksFromResponse(resp) |
||||||
|
if err != nil { |
||||||
|
return nil, stid, err |
||||||
|
} |
||||||
|
return blocks, stid, nil |
||||||
|
} |
||||||
|
|
||||||
|
// getBlocksByNumberRequest is the request for get block by numbers which implements
|
||||||
|
// sttypes.Request interface
|
||||||
|
type getBlocksByNumberRequest struct { |
||||||
|
bns []uint64 |
||||||
|
pbReq *syncpb.Request |
||||||
|
} |
||||||
|
|
||||||
|
func newGetBlocksByNumberRequest(bns []uint64) *getBlocksByNumberRequest { |
||||||
|
pbReq := syncpb.MakeGetBlocksByNumRequest(bns) |
||||||
|
return &getBlocksByNumberRequest{ |
||||||
|
bns: bns, |
||||||
|
pbReq: pbReq, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByNumberRequest) ReqID() uint64 { |
||||||
|
return req.pbReq.GetReqId() |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByNumberRequest) SetReqID(val uint64) { |
||||||
|
req.pbReq.ReqId = val |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByNumberRequest) String() string { |
||||||
|
ss := make([]string, 0, len(req.bns)) |
||||||
|
for _, bn := range req.bns { |
||||||
|
ss = append(ss, strconv.Itoa(int(bn))) |
||||||
|
} |
||||||
|
bnsStr := strings.Join(ss, ",") |
||||||
|
return fmt.Sprintf("REQUEST [GetBlockByNumber: %s]", bnsStr) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByNumberRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool { |
||||||
|
return target.Version.GreaterThanOrEqual(MinVersion) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByNumberRequest) Encode() ([]byte, error) { |
||||||
|
msg := syncpb.MakeMessageFromRequest(req.pbReq) |
||||||
|
return protobuf.Marshal(msg) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByNumberRequest) getBlocksFromResponse(resp sttypes.Response) ([]*types.Block, error) { |
||||||
|
sResp, ok := resp.(*syncResponse) |
||||||
|
if !ok || sResp == nil { |
||||||
|
return nil, errors.New("not sync response") |
||||||
|
} |
||||||
|
blockBytes, sigs, err := req.parseBlockBytesAndSigs(sResp) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
blocks := make([]*types.Block, 0, len(blockBytes)) |
||||||
|
for i, bb := range blockBytes { |
||||||
|
var block *types.Block |
||||||
|
if err := rlp.DecodeBytes(bb, &block); err != nil { |
||||||
|
return nil, errors.Wrap(err, "[GetBlocksByNumResponse]") |
||||||
|
} |
||||||
|
if block != nil { |
||||||
|
block.SetCurrentCommitSig(sigs[i]) |
||||||
|
} |
||||||
|
blocks = append(blocks, block) |
||||||
|
} |
||||||
|
return blocks, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByNumberRequest) parseBlockBytesAndSigs(resp *syncResponse) ([][]byte, [][]byte, error) { |
||||||
|
if errResp := resp.pb.GetErrorResponse(); errResp != nil { |
||||||
|
return nil, nil, errors.New(errResp.Error) |
||||||
|
} |
||||||
|
gbResp := resp.pb.GetGetBlocksByNumResponse() |
||||||
|
if gbResp == nil { |
||||||
|
return nil, nil, errors.New("response not GetBlockByNumber") |
||||||
|
} |
||||||
|
if len(gbResp.BlocksBytes) != len(gbResp.CommitSig) { |
||||||
|
return nil, nil, fmt.Errorf("commit sigs size not expected: %v / %v", |
||||||
|
len(gbResp.CommitSig), len(gbResp.BlocksBytes)) |
||||||
|
} |
||||||
|
return gbResp.BlocksBytes, gbResp.CommitSig, nil |
||||||
|
} |
||||||
|
|
||||||
|
type getEpochBlockRequest struct { |
||||||
|
epoch uint64 |
||||||
|
pbReq *syncpb.Request |
||||||
|
} |
||||||
|
|
||||||
|
func newGetEpochBlockRequest(epoch uint64) *getEpochBlockRequest { |
||||||
|
pbReq := syncpb.MakeGetEpochStateRequest(epoch) |
||||||
|
return &getEpochBlockRequest{ |
||||||
|
epoch: epoch, |
||||||
|
pbReq: pbReq, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getEpochBlockRequest) ReqID() uint64 { |
||||||
|
return req.pbReq.GetReqId() |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getEpochBlockRequest) SetReqID(val uint64) { |
||||||
|
req.pbReq.ReqId = val |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getEpochBlockRequest) String() string { |
||||||
|
return fmt.Sprintf("REQUEST [GetEpochBlock: %v]", req.epoch) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getEpochBlockRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool { |
||||||
|
return target.Version.GreaterThanOrEqual(MinVersion) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getEpochBlockRequest) Encode() ([]byte, error) { |
||||||
|
msg := syncpb.MakeMessageFromRequest(req.pbReq) |
||||||
|
return protobuf.Marshal(msg) |
||||||
|
} |
||||||
|
|
||||||
|
type getBlockNumberRequest struct { |
||||||
|
pbReq *syncpb.Request |
||||||
|
} |
||||||
|
|
||||||
|
func newGetBlockNumberRequest() *getBlockNumberRequest { |
||||||
|
pbReq := syncpb.MakeGetBlockNumberRequest() |
||||||
|
return &getBlockNumberRequest{ |
||||||
|
pbReq: pbReq, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlockNumberRequest) ReqID() uint64 { |
||||||
|
return req.pbReq.GetReqId() |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlockNumberRequest) SetReqID(val uint64) { |
||||||
|
req.pbReq.ReqId = val |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlockNumberRequest) String() string { |
||||||
|
return fmt.Sprintf("REQUEST [GetBlockNumber]") |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlockNumberRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool { |
||||||
|
return target.Version.GreaterThanOrEqual(MinVersion) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlockNumberRequest) Encode() ([]byte, error) { |
||||||
|
msg := syncpb.MakeMessageFromRequest(req.pbReq) |
||||||
|
return protobuf.Marshal(msg) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlockNumberRequest) getNumberFromResponse(resp sttypes.Response) (uint64, error) { |
||||||
|
sResp, ok := resp.(*syncResponse) |
||||||
|
if !ok || sResp == nil { |
||||||
|
return 0, errors.New("not sync response") |
||||||
|
} |
||||||
|
if errResp := sResp.pb.GetErrorResponse(); errResp != nil { |
||||||
|
return 0, errors.New(errResp.Error) |
||||||
|
} |
||||||
|
gnResp := sResp.pb.GetGetBlockNumberResponse() |
||||||
|
if gnResp == nil { |
||||||
|
return 0, errors.New("response not GetBlockNumber") |
||||||
|
} |
||||||
|
return gnResp.Number, nil |
||||||
|
} |
||||||
|
|
||||||
|
type getBlockHashesRequest struct { |
||||||
|
bns []uint64 |
||||||
|
pbReq *syncpb.Request |
||||||
|
} |
||||||
|
|
||||||
|
func newGetBlockHashesRequest(bns []uint64) *getBlockHashesRequest { |
||||||
|
pbReq := syncpb.MakeGetBlockHashesRequest(bns) |
||||||
|
return &getBlockHashesRequest{ |
||||||
|
bns: bns, |
||||||
|
pbReq: pbReq, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlockHashesRequest) ReqID() uint64 { |
||||||
|
return req.pbReq.ReqId |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlockHashesRequest) SetReqID(val uint64) { |
||||||
|
req.pbReq.ReqId = val |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlockHashesRequest) String() string { |
||||||
|
ss := make([]string, 0, len(req.bns)) |
||||||
|
for _, bn := range req.bns { |
||||||
|
ss = append(ss, strconv.Itoa(int(bn))) |
||||||
|
} |
||||||
|
bnsStr := strings.Join(ss, ",") |
||||||
|
return fmt.Sprintf("REQUEST [GetBlockHashes: %s]", bnsStr) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlockHashesRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool { |
||||||
|
return target.Version.GreaterThanOrEqual(MinVersion) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlockHashesRequest) Encode() ([]byte, error) { |
||||||
|
msg := syncpb.MakeMessageFromRequest(req.pbReq) |
||||||
|
return protobuf.Marshal(msg) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlockHashesRequest) getHashesFromResponse(resp sttypes.Response) ([]common.Hash, error) { |
||||||
|
sResp, ok := resp.(*syncResponse) |
||||||
|
if !ok || sResp == nil { |
||||||
|
return nil, errors.New("not sync response") |
||||||
|
} |
||||||
|
if errResp := sResp.pb.GetErrorResponse(); errResp != nil { |
||||||
|
return nil, errors.New(errResp.Error) |
||||||
|
} |
||||||
|
bhResp := sResp.pb.GetGetBlockHashesResponse() |
||||||
|
if bhResp == nil { |
||||||
|
return nil, errors.New("response not GetBlockHashes") |
||||||
|
} |
||||||
|
hashBytes := bhResp.Hashes |
||||||
|
return bytesToHashes(hashBytes), nil |
||||||
|
} |
||||||
|
|
||||||
|
type getBlocksByHashesRequest struct { |
||||||
|
hashes []common.Hash |
||||||
|
pbReq *syncpb.Request |
||||||
|
} |
||||||
|
|
||||||
|
func newGetBlocksByHashesRequest(hashes []common.Hash) *getBlocksByHashesRequest { |
||||||
|
pbReq := syncpb.MakeGetBlocksByHashesRequest(hashes) |
||||||
|
return &getBlocksByHashesRequest{ |
||||||
|
hashes: hashes, |
||||||
|
pbReq: pbReq, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByHashesRequest) ReqID() uint64 { |
||||||
|
return req.pbReq.GetReqId() |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByHashesRequest) SetReqID(val uint64) { |
||||||
|
req.pbReq.ReqId = val |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByHashesRequest) String() string { |
||||||
|
hashStrs := make([]string, 0, len(req.hashes)) |
||||||
|
for _, h := range req.hashes { |
||||||
|
hashStrs = append(hashStrs, fmt.Sprintf("%x", h[:])) |
||||||
|
} |
||||||
|
hStr := strings.Join(hashStrs, ", ") |
||||||
|
return fmt.Sprintf("REQUEST [GetBlocksByHashes: %v]", hStr) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByHashesRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool { |
||||||
|
return target.Version.GreaterThanOrEqual(MinVersion) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByHashesRequest) Encode() ([]byte, error) { |
||||||
|
msg := syncpb.MakeMessageFromRequest(req.pbReq) |
||||||
|
return protobuf.Marshal(msg) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *getBlocksByHashesRequest) getBlocksFromResponse(resp sttypes.Response) ([]*types.Block, error) { |
||||||
|
sResp, ok := resp.(*syncResponse) |
||||||
|
if !ok || sResp == nil { |
||||||
|
return nil, errors.New("not sync response") |
||||||
|
} |
||||||
|
if errResp := sResp.pb.GetErrorResponse(); errResp != nil { |
||||||
|
return nil, errors.New(errResp.Error) |
||||||
|
} |
||||||
|
bhResp := sResp.pb.GetGetBlocksByHashesResponse() |
||||||
|
if bhResp == nil { |
||||||
|
return nil, errors.New("response not GetBlocksByHashes") |
||||||
|
} |
||||||
|
var ( |
||||||
|
blockBytes = bhResp.BlocksBytes |
||||||
|
sigs = bhResp.CommitSig |
||||||
|
) |
||||||
|
if len(blockBytes) != len(sigs) { |
||||||
|
return nil, fmt.Errorf("sig size not expected: %v / %v", len(sigs), len(blockBytes)) |
||||||
|
} |
||||||
|
blocks := make([]*types.Block, 0, len(blockBytes)) |
||||||
|
for i, bb := range blockBytes { |
||||||
|
var block *types.Block |
||||||
|
if err := rlp.DecodeBytes(bb, &block); err != nil { |
||||||
|
return nil, errors.Wrap(err, "[GetBlocksByHashesResponse]") |
||||||
|
} |
||||||
|
if block != nil { |
||||||
|
block.SetCurrentCommitSig(sigs[i]) |
||||||
|
} |
||||||
|
blocks = append(blocks, block) |
||||||
|
} |
||||||
|
return blocks, nil |
||||||
|
} |
@ -0,0 +1,462 @@ |
|||||||
|
package sync |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"math/big" |
||||||
|
"strings" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/common" |
||||||
|
"github.com/ethereum/go-ethereum/event" |
||||||
|
"github.com/ethereum/go-ethereum/rlp" |
||||||
|
"github.com/harmony-one/harmony/block" |
||||||
|
headerV3 "github.com/harmony-one/harmony/block/v3" |
||||||
|
"github.com/harmony-one/harmony/core/types" |
||||||
|
"github.com/harmony-one/harmony/p2p/stream/common/ratelimiter" |
||||||
|
"github.com/harmony-one/harmony/p2p/stream/common/streammanager" |
||||||
|
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" |
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
"github.com/harmony-one/harmony/shard" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
_ sttypes.Request = &getBlocksByNumberRequest{} |
||||||
|
_ sttypes.Request = &getEpochBlockRequest{} |
||||||
|
_ sttypes.Request = &getBlockNumberRequest{} |
||||||
|
_ sttypes.Response = &syncResponse{&syncpb.Response{}} |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
initStreamIDs = []sttypes.StreamID{ |
||||||
|
makeTestStreamID(0), |
||||||
|
makeTestStreamID(1), |
||||||
|
makeTestStreamID(2), |
||||||
|
makeTestStreamID(3), |
||||||
|
} |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
testHeader = &block.Header{Header: headerV3.NewHeader()} |
||||||
|
testBlock = types.NewBlockWithHeader(testHeader) |
||||||
|
testHeaderBytes, _ = rlp.EncodeToBytes(testHeader) |
||||||
|
testBlockBytes, _ = rlp.EncodeToBytes(testBlock) |
||||||
|
testBlockResponse = syncpb.MakeGetBlocksByNumResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1)) |
||||||
|
|
||||||
|
testEpochState = &shard.State{ |
||||||
|
Epoch: new(big.Int).SetInt64(1), |
||||||
|
Shards: []shard.Committee{}, |
||||||
|
} |
||||||
|
testEpochStateBytes, _ = rlp.EncodeToBytes(testEpochState) |
||||||
|
testEpochStateResponse = syncpb.MakeGetEpochStateResponse(0, testHeaderBytes, testEpochStateBytes) |
||||||
|
|
||||||
|
testCurBlockNumber uint64 = 100 |
||||||
|
testBlockNumberResponse = syncpb.MakeGetBlockNumberResponse(0, testCurBlockNumber) |
||||||
|
|
||||||
|
testHash = numberToHash(100) |
||||||
|
testBlockHashesResponse = syncpb.MakeGetBlockHashesResponse(0, []common.Hash{testHash}) |
||||||
|
|
||||||
|
testBlocksByHashesResponse = syncpb.MakeGetBlocksByHashesResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1)) |
||||||
|
|
||||||
|
testErrorResponse = syncpb.MakeErrorResponse(0, errors.New("test error")) |
||||||
|
) |
||||||
|
|
||||||
|
func TestProtocol_GetBlocksByNumber(t *testing.T) { |
||||||
|
tests := []struct { |
||||||
|
getResponse getResponseFn |
||||||
|
expErr error |
||||||
|
expStID sttypes.StreamID |
||||||
|
}{ |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testBlockResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: nil, |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testEpochStateResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: errors.New("not GetBlockByNumber"), |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: nil, |
||||||
|
expErr: errors.New("get response error"), |
||||||
|
expStID: "", |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testErrorResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: errors.New("test error"), |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for i, test := range tests { |
||||||
|
protocol := makeTestProtocol(test.getResponse) |
||||||
|
blocks, stid, err := protocol.GetBlocksByNumber(context.Background(), []uint64{0}) |
||||||
|
|
||||||
|
if assErr := assertError(err, test.expErr); assErr != nil { |
||||||
|
t.Errorf("Test %v: %v", i, assErr) |
||||||
|
continue |
||||||
|
} |
||||||
|
if stid != test.expStID { |
||||||
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID) |
||||||
|
} |
||||||
|
if test.expErr == nil && (len(blocks) == 0) { |
||||||
|
t.Errorf("Test %v: zero blocks delivered", i) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestProtocol_GetEpochState(t *testing.T) { |
||||||
|
tests := []struct { |
||||||
|
getResponse getResponseFn |
||||||
|
expErr error |
||||||
|
expStID sttypes.StreamID |
||||||
|
}{ |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testEpochStateResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: nil, |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testBlockResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: errors.New("not GetEpochStateResponse"), |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: nil, |
||||||
|
expErr: errors.New("get response error"), |
||||||
|
expStID: "", |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testErrorResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: errors.New("test error"), |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for i, test := range tests { |
||||||
|
protocol := makeTestProtocol(test.getResponse) |
||||||
|
res, stid, err := protocol.GetEpochState(context.Background(), 0) |
||||||
|
|
||||||
|
if assErr := assertError(err, test.expErr); assErr != nil { |
||||||
|
t.Errorf("Test %v: %v", i, assErr) |
||||||
|
continue |
||||||
|
} |
||||||
|
if stid != test.expStID { |
||||||
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID) |
||||||
|
} |
||||||
|
if test.expErr == nil { |
||||||
|
if gotEpoch := res.State.Epoch; gotEpoch.Cmp(new(big.Int).SetUint64(1)) != 0 { |
||||||
|
t.Errorf("Test %v: unexpected epoch delivered: %v / %v", i, gotEpoch.String(), 1) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestProtocol_GetCurrentBlockNumber(t *testing.T) { |
||||||
|
tests := []struct { |
||||||
|
getResponse getResponseFn |
||||||
|
expErr error |
||||||
|
expStID sttypes.StreamID |
||||||
|
}{ |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testBlockNumberResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: nil, |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testBlockResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: errors.New("not GetBlockNumber"), |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: nil, |
||||||
|
expErr: errors.New("get response error"), |
||||||
|
expStID: "", |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testErrorResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: errors.New("test error"), |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for i, test := range tests { |
||||||
|
protocol := makeTestProtocol(test.getResponse) |
||||||
|
res, stid, err := protocol.GetCurrentBlockNumber(context.Background()) |
||||||
|
|
||||||
|
if assErr := assertError(err, test.expErr); assErr != nil { |
||||||
|
t.Errorf("Test %v: %v", i, assErr) |
||||||
|
continue |
||||||
|
} |
||||||
|
if stid != test.expStID { |
||||||
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID) |
||||||
|
} |
||||||
|
if test.expErr == nil { |
||||||
|
if res != testCurBlockNumber { |
||||||
|
t.Errorf("Test %v: block number not expected: %v / %v", i, res, testCurBlockNumber) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestProtocol_GetBlockHashes(t *testing.T) { |
||||||
|
tests := []struct { |
||||||
|
getResponse getResponseFn |
||||||
|
expErr error |
||||||
|
expStID sttypes.StreamID |
||||||
|
}{ |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testBlockHashesResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: nil, |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testBlockResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: errors.New("not GetBlockHashes"), |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: nil, |
||||||
|
expErr: errors.New("get response error"), |
||||||
|
expStID: "", |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testErrorResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: errors.New("test error"), |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for i, test := range tests { |
||||||
|
protocol := makeTestProtocol(test.getResponse) |
||||||
|
res, stid, err := protocol.GetBlockHashes(context.Background(), []uint64{100}) |
||||||
|
|
||||||
|
if assErr := assertError(err, test.expErr); assErr != nil { |
||||||
|
t.Errorf("Test %v: %v", i, assErr) |
||||||
|
continue |
||||||
|
} |
||||||
|
if stid != test.expStID { |
||||||
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID) |
||||||
|
} |
||||||
|
if test.expErr == nil { |
||||||
|
if len(res) != 1 { |
||||||
|
t.Errorf("Test %v: size not 1", i) |
||||||
|
} |
||||||
|
if res[0] != testHash { |
||||||
|
t.Errorf("Test %v: hash not expected", i) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestProtocol_GetBlocksByHashes(t *testing.T) { |
||||||
|
tests := []struct { |
||||||
|
getResponse getResponseFn |
||||||
|
expErr error |
||||||
|
expStID sttypes.StreamID |
||||||
|
}{ |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testBlocksByHashesResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: nil, |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testBlockResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: errors.New("not GetBlocksByHashes"), |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: nil, |
||||||
|
expErr: errors.New("get response error"), |
||||||
|
expStID: "", |
||||||
|
}, |
||||||
|
{ |
||||||
|
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||||
|
return &syncResponse{ |
||||||
|
pb: testErrorResponse, |
||||||
|
}, makeTestStreamID(0) |
||||||
|
}, |
||||||
|
expErr: errors.New("test error"), |
||||||
|
expStID: makeTestStreamID(0), |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for i, test := range tests { |
||||||
|
protocol := makeTestProtocol(test.getResponse) |
||||||
|
blocks, stid, err := protocol.GetBlocksByHashes(context.Background(), []common.Hash{numberToHash(100)}) |
||||||
|
|
||||||
|
if assErr := assertError(err, test.expErr); assErr != nil { |
||||||
|
t.Errorf("Test %v: %v", i, assErr) |
||||||
|
continue |
||||||
|
} |
||||||
|
if stid != test.expStID { |
||||||
|
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID) |
||||||
|
} |
||||||
|
if test.expErr == nil { |
||||||
|
if len(blocks) != 1 { |
||||||
|
t.Errorf("Test %v: size not 1", i) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type getResponseFn func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) |
||||||
|
|
||||||
|
type testHostRequestManager struct { |
||||||
|
getResponse getResponseFn |
||||||
|
} |
||||||
|
|
||||||
|
func makeTestProtocol(f getResponseFn) *Protocol { |
||||||
|
rm := &testHostRequestManager{f} |
||||||
|
|
||||||
|
streamIDs := make([]sttypes.StreamID, len(initStreamIDs)) |
||||||
|
copy(streamIDs, initStreamIDs) |
||||||
|
sm := &testStreamManager{streamIDs} |
||||||
|
|
||||||
|
rl := ratelimiter.NewRateLimiter(10, 10) |
||||||
|
|
||||||
|
return &Protocol{ |
||||||
|
rm: rm, |
||||||
|
rl: rl, |
||||||
|
sm: sm, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *testHostRequestManager) Start() {} |
||||||
|
func (rm *testHostRequestManager) Close() {} |
||||||
|
func (rm *testHostRequestManager) DeliverResponse(sttypes.StreamID, sttypes.Response) {} |
||||||
|
|
||||||
|
func (rm *testHostRequestManager) DoRequest(ctx context.Context, request sttypes.Request, opts ...Option) (sttypes.Response, sttypes.StreamID, error) { |
||||||
|
if rm.getResponse == nil { |
||||||
|
return nil, "", errors.New("get response error") |
||||||
|
} |
||||||
|
resp, stid := rm.getResponse(request) |
||||||
|
return resp, stid, nil |
||||||
|
} |
||||||
|
|
||||||
|
func makeTestStreamID(index int) sttypes.StreamID { |
||||||
|
id := fmt.Sprintf("[test stream %v]", index) |
||||||
|
return sttypes.StreamID(id) |
||||||
|
} |
||||||
|
|
||||||
|
// mock stream manager
|
||||||
|
type testStreamManager struct { |
||||||
|
streamIDs []sttypes.StreamID |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *testStreamManager) Start() {} |
||||||
|
func (sm *testStreamManager) Close() {} |
||||||
|
func (sm *testStreamManager) SubscribeAddStreamEvent(chan<- streammanager.EvtStreamAdded) event.Subscription { |
||||||
|
return nil |
||||||
|
} |
||||||
|
func (sm *testStreamManager) SubscribeRemoveStreamEvent(chan<- streammanager.EvtStreamRemoved) event.Subscription { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *testStreamManager) NewStream(stream sttypes.Stream) error { |
||||||
|
stid := stream.ID() |
||||||
|
for _, id := range sm.streamIDs { |
||||||
|
if id == stid { |
||||||
|
return errors.New("stream already exist") |
||||||
|
} |
||||||
|
} |
||||||
|
sm.streamIDs = append(sm.streamIDs, stid) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *testStreamManager) RemoveStream(stID sttypes.StreamID) error { |
||||||
|
for i, id := range sm.streamIDs { |
||||||
|
if id == stID { |
||||||
|
sm.streamIDs = append(sm.streamIDs[:i], sm.streamIDs[i+1:]...) |
||||||
|
} |
||||||
|
} |
||||||
|
return errors.New("stream not exist") |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *testStreamManager) isStreamExist(stid sttypes.StreamID) bool { |
||||||
|
for _, id := range sm.streamIDs { |
||||||
|
if id == stid { |
||||||
|
return true |
||||||
|
} |
||||||
|
} |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *testStreamManager) GetStreams() []sttypes.Stream { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *testStreamManager) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) { |
||||||
|
return nil, false |
||||||
|
} |
||||||
|
|
||||||
|
func assertError(got, expect error) error { |
||||||
|
if (got == nil) != (expect == nil) { |
||||||
|
return fmt.Errorf("unexpected error: %v / %v", got, expect) |
||||||
|
} |
||||||
|
if got == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
if !strings.Contains(got.Error(), expect.Error()) { |
||||||
|
return fmt.Errorf("unexpected error: %v/ %v", got, expect) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,17 @@ |
|||||||
|
package sync |
||||||
|
|
||||||
|
import "time" |
||||||
|
|
||||||
|
const ( |
||||||
|
// GetBlockHashesAmountCap is the cap of GetBlockHashes reqeust
|
||||||
|
GetBlockHashesAmountCap = 50 |
||||||
|
|
||||||
|
// GetBlocksByNumAmountCap is the cap of request of a single GetBlocksByNum request
|
||||||
|
GetBlocksByNumAmountCap = 10 |
||||||
|
|
||||||
|
// GetBlocksByHashesAmountCap is the cap of request of single GetBlocksByHashes request
|
||||||
|
GetBlocksByHashesAmountCap = 10 |
||||||
|
|
||||||
|
// minAdvertiseInterval is the minimum advertise interval
|
||||||
|
minAdvertiseInterval = 1 * time.Minute |
||||||
|
) |
@ -0,0 +1,260 @@ |
|||||||
|
package sync |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"strconv" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/event" |
||||||
|
"github.com/harmony-one/harmony/consensus/engine" |
||||||
|
nodeconfig "github.com/harmony-one/harmony/internal/configs/node" |
||||||
|
shardingconfig "github.com/harmony-one/harmony/internal/configs/sharding" |
||||||
|
"github.com/harmony-one/harmony/internal/utils" |
||||||
|
"github.com/harmony-one/harmony/p2p/discovery" |
||||||
|
"github.com/harmony-one/harmony/p2p/stream/common/ratelimiter" |
||||||
|
"github.com/harmony-one/harmony/p2p/stream/common/requestmanager" |
||||||
|
"github.com/harmony-one/harmony/p2p/stream/common/streammanager" |
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
"github.com/hashicorp/go-version" |
||||||
|
libp2p_host "github.com/libp2p/go-libp2p-core/host" |
||||||
|
libp2p_network "github.com/libp2p/go-libp2p-core/network" |
||||||
|
"github.com/rs/zerolog" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
// serviceSpecifier is the specifier for the service.
|
||||||
|
serviceSpecifier = "sync" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
version100, _ = version.NewVersion("1.0.0") |
||||||
|
|
||||||
|
// MyVersion is the version of sync protocol of the local node
|
||||||
|
MyVersion = version100 |
||||||
|
|
||||||
|
// MinVersion is the minimum version for matching function
|
||||||
|
MinVersion = version100 |
||||||
|
) |
||||||
|
|
||||||
|
type ( |
||||||
|
// Protocol is the protocol for sync streaming
|
||||||
|
Protocol struct { |
||||||
|
chain engine.ChainReader // provide SYNC data
|
||||||
|
schedule shardingconfig.Schedule // provide schedule information
|
||||||
|
rl ratelimiter.RateLimiter // limit the incoming request rate
|
||||||
|
sm streammanager.StreamManager // stream management
|
||||||
|
rm requestmanager.RequestManager // deliver the response from stream
|
||||||
|
disc discovery.Discovery |
||||||
|
|
||||||
|
config Config |
||||||
|
logger zerolog.Logger |
||||||
|
|
||||||
|
ctx context.Context |
||||||
|
cancel func() |
||||||
|
closeC chan struct{} |
||||||
|
} |
||||||
|
|
||||||
|
// Config is the sync protocol config
|
||||||
|
Config struct { |
||||||
|
Chain engine.ChainReader |
||||||
|
Host libp2p_host.Host |
||||||
|
Discovery discovery.Discovery |
||||||
|
ShardID nodeconfig.ShardID |
||||||
|
Network nodeconfig.NetworkType |
||||||
|
|
||||||
|
// stream manager config
|
||||||
|
SmSoftLowCap int |
||||||
|
SmHardLowCap int |
||||||
|
SmHiCap int |
||||||
|
DiscBatch int |
||||||
|
} |
||||||
|
) |
||||||
|
|
||||||
|
// NewProtocol creates a new sync protocol
|
||||||
|
func NewProtocol(config Config) *Protocol { |
||||||
|
ctx, cancel := context.WithCancel(context.Background()) |
||||||
|
|
||||||
|
sp := &Protocol{ |
||||||
|
chain: config.Chain, |
||||||
|
disc: config.Discovery, |
||||||
|
config: config, |
||||||
|
ctx: ctx, |
||||||
|
cancel: cancel, |
||||||
|
closeC: make(chan struct{}), |
||||||
|
} |
||||||
|
smConfig := streammanager.Config{ |
||||||
|
SoftLoCap: config.SmSoftLowCap, |
||||||
|
HardLoCap: config.SmHardLowCap, |
||||||
|
HiCap: config.SmHiCap, |
||||||
|
DiscBatch: config.DiscBatch, |
||||||
|
} |
||||||
|
sp.sm = streammanager.NewStreamManager(sp.ProtoID(), config.Host, config.Discovery, |
||||||
|
sp.HandleStream, smConfig) |
||||||
|
|
||||||
|
sp.rl = ratelimiter.NewRateLimiter(sp.sm, 50, 10) |
||||||
|
|
||||||
|
sp.rm = requestmanager.NewRequestManager(sp.sm) |
||||||
|
|
||||||
|
sp.logger = utils.Logger().With().Str("Protocol", string(sp.ProtoID())).Logger() |
||||||
|
return sp |
||||||
|
} |
||||||
|
|
||||||
|
// Start starts the sync protocol
|
||||||
|
func (p *Protocol) Start() { |
||||||
|
p.sm.Start() |
||||||
|
p.rm.Start() |
||||||
|
p.rl.Start() |
||||||
|
go p.advertiseLoop() |
||||||
|
} |
||||||
|
|
||||||
|
// Close close the protocol
|
||||||
|
func (p *Protocol) Close() { |
||||||
|
p.rl.Close() |
||||||
|
p.rm.Close() |
||||||
|
p.sm.Close() |
||||||
|
p.cancel() |
||||||
|
close(p.closeC) |
||||||
|
} |
||||||
|
|
||||||
|
// Specifier return the specifier for the protocol
|
||||||
|
func (p *Protocol) Specifier() string { |
||||||
|
return serviceSpecifier + "/" + strconv.Itoa(int(p.config.ShardID)) |
||||||
|
} |
||||||
|
|
||||||
|
// ProtoID return the ProtoID of the sync protocol
|
||||||
|
func (p *Protocol) ProtoID() sttypes.ProtoID { |
||||||
|
return p.protoIDByVersion(MyVersion) |
||||||
|
} |
||||||
|
|
||||||
|
// Version returns the sync protocol version
|
||||||
|
func (p *Protocol) Version() *version.Version { |
||||||
|
return MyVersion |
||||||
|
} |
||||||
|
|
||||||
|
// Match checks the compatibility to the target protocol ID.
|
||||||
|
func (p *Protocol) Match(targetID string) bool { |
||||||
|
target, err := sttypes.ProtoIDToProtoSpec(sttypes.ProtoID(targetID)) |
||||||
|
if err != nil { |
||||||
|
return false |
||||||
|
} |
||||||
|
if target.Service != serviceSpecifier { |
||||||
|
return false |
||||||
|
} |
||||||
|
if target.NetworkType != p.config.Network { |
||||||
|
return false |
||||||
|
} |
||||||
|
if target.ShardID != p.config.ShardID { |
||||||
|
return false |
||||||
|
} |
||||||
|
if target.Version.LessThan(MinVersion) { |
||||||
|
return false |
||||||
|
} |
||||||
|
return true |
||||||
|
} |
||||||
|
|
||||||
|
// HandleStream is the stream handle function being registered to libp2p.
|
||||||
|
func (p *Protocol) HandleStream(raw libp2p_network.Stream) { |
||||||
|
p.logger.Info().Str("stream", raw.ID()).Msg("handle new sync stream") |
||||||
|
st := p.wrapStream(raw) |
||||||
|
if err := p.sm.NewStream(st); err != nil { |
||||||
|
// Possibly we have reach the hard limit of the stream
|
||||||
|
p.logger.Warn().Err(err).Str("stream ID", string(st.ID())). |
||||||
|
Msg("failed to add new stream") |
||||||
|
return |
||||||
|
} |
||||||
|
st.run() |
||||||
|
} |
||||||
|
|
||||||
|
func (p *Protocol) advertiseLoop() { |
||||||
|
for { |
||||||
|
sleep := p.advertise() |
||||||
|
select { |
||||||
|
case <-time.After(sleep): |
||||||
|
case <-p.closeC: |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// advertise will advertise all compatible protocol versions for helping nodes running low
|
||||||
|
// version
|
||||||
|
func (p *Protocol) advertise() time.Duration { |
||||||
|
var nextWait time.Duration |
||||||
|
|
||||||
|
pids := p.supportedProtoIDs() |
||||||
|
for _, pid := range pids { |
||||||
|
w, e := p.disc.Advertise(p.ctx, string(pid)) |
||||||
|
if e != nil { |
||||||
|
p.logger.Warn().Err(e).Str("protocol", string(pid)). |
||||||
|
Msg("cannot advertise sync protocol") |
||||||
|
continue |
||||||
|
} |
||||||
|
if nextWait == 0 || nextWait > w { |
||||||
|
nextWait = w |
||||||
|
} |
||||||
|
} |
||||||
|
if nextWait < minAdvertiseInterval { |
||||||
|
nextWait = minAdvertiseInterval |
||||||
|
} |
||||||
|
return nextWait |
||||||
|
} |
||||||
|
|
||||||
|
func (p *Protocol) supportedProtoIDs() []sttypes.ProtoID { |
||||||
|
vs := p.supportedVersions() |
||||||
|
|
||||||
|
pids := make([]sttypes.ProtoID, 0, len(vs)) |
||||||
|
for _, v := range vs { |
||||||
|
pids = append(pids, p.protoIDByVersion(v)) |
||||||
|
} |
||||||
|
return pids |
||||||
|
} |
||||||
|
|
||||||
|
func (p *Protocol) supportedVersions() []*version.Version { |
||||||
|
return []*version.Version{version100} |
||||||
|
} |
||||||
|
|
||||||
|
func (p *Protocol) protoIDByVersion(v *version.Version) sttypes.ProtoID { |
||||||
|
spec := sttypes.ProtoSpec{ |
||||||
|
Service: serviceSpecifier, |
||||||
|
NetworkType: p.config.Network, |
||||||
|
ShardID: p.config.ShardID, |
||||||
|
Version: v, |
||||||
|
} |
||||||
|
return spec.ToProtoID() |
||||||
|
} |
||||||
|
|
||||||
|
// RemoveStream removes the stream of the given stream ID
|
||||||
|
func (p *Protocol) RemoveStream(stID sttypes.StreamID) { |
||||||
|
if stID == "" { |
||||||
|
return |
||||||
|
} |
||||||
|
st, exist := p.sm.GetStreamByID(stID) |
||||||
|
if exist && st != nil { |
||||||
|
st.Close() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// NumStreams return the streams with minimum version.
|
||||||
|
// Note: nodes with sync version smaller than minVersion is not counted.
|
||||||
|
func (p *Protocol) NumStreams() int { |
||||||
|
res := 0 |
||||||
|
sts := p.sm.GetStreams() |
||||||
|
|
||||||
|
for _, st := range sts { |
||||||
|
ps, _ := st.ProtoSpec() |
||||||
|
if ps.Version.GreaterThanOrEqual(MinVersion) { |
||||||
|
res++ |
||||||
|
} |
||||||
|
} |
||||||
|
return res |
||||||
|
} |
||||||
|
|
||||||
|
// GetStreamManager get the underlying stream manager for upper level stream operations
|
||||||
|
func (p *Protocol) GetStreamManager() streammanager.StreamManager { |
||||||
|
return p.sm |
||||||
|
} |
||||||
|
|
||||||
|
// SubscribeAddStreamEvent subscribe the stream add event
|
||||||
|
func (p *Protocol) SubscribeAddStreamEvent(ch chan<- streammanager.EvtStreamAdded) event.Subscription { |
||||||
|
return p.sm.SubscribeAddStreamEvent(ch) |
||||||
|
} |
@ -0,0 +1,95 @@ |
|||||||
|
package sync |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"testing" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/libp2p/go-libp2p-core/discovery" |
||||||
|
libp2p_peer "github.com/libp2p/go-libp2p-core/peer" |
||||||
|
) |
||||||
|
|
||||||
|
func TestProtocol_Match(t *testing.T) { |
||||||
|
tests := []struct { |
||||||
|
targetID string |
||||||
|
exp bool |
||||||
|
}{ |
||||||
|
{"harmony/sync/unitest/0/1.0.1", true}, |
||||||
|
{"h123456", false}, |
||||||
|
{"harmony/sync/unitest/0/0.9.9", false}, |
||||||
|
{"harmony/epoch/unitest/0/1.0.1", false}, |
||||||
|
{"harmony/sync/mainnet/0/1.0.1", false}, |
||||||
|
{"harmony/sync/unitest/1/1.0.1", false}, |
||||||
|
} |
||||||
|
|
||||||
|
for i, test := range tests { |
||||||
|
p := &Protocol{ |
||||||
|
config: Config{ |
||||||
|
Network: "unitest", |
||||||
|
ShardID: 0, |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
res := p.Match(test.targetID) |
||||||
|
|
||||||
|
if res != test.exp { |
||||||
|
t.Errorf("Test %v: unexpected result %v / %v", i, res, test.exp) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestProtocol_advertiseLoop(t *testing.T) { |
||||||
|
disc := newTestDiscovery(100 * time.Millisecond) |
||||||
|
p := &Protocol{ |
||||||
|
disc: disc, |
||||||
|
closeC: make(chan struct{}), |
||||||
|
} |
||||||
|
|
||||||
|
go p.advertiseLoop() |
||||||
|
|
||||||
|
time.Sleep(150 * time.Millisecond) |
||||||
|
close(p.closeC) |
||||||
|
|
||||||
|
if len(disc.advCnt) != len(p.supportedVersions()) { |
||||||
|
t.Errorf("unexpected advertise topic count: %v / %v", len(disc.advCnt), |
||||||
|
len(p.supportedVersions())) |
||||||
|
} |
||||||
|
for _, cnt := range disc.advCnt { |
||||||
|
if cnt < 1 { |
||||||
|
t.Errorf("unexpected discovery count: %v", cnt) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type testDiscovery struct { |
||||||
|
advCnt map[string]int |
||||||
|
sleep time.Duration |
||||||
|
} |
||||||
|
|
||||||
|
func newTestDiscovery(discInterval time.Duration) *testDiscovery { |
||||||
|
return &testDiscovery{ |
||||||
|
advCnt: make(map[string]int), |
||||||
|
sleep: discInterval, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (disc *testDiscovery) Start() error { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (disc *testDiscovery) Close() error { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (disc *testDiscovery) Advertise(ctx context.Context, ns string) (time.Duration, error) { |
||||||
|
disc.advCnt[ns]++ |
||||||
|
return disc.sleep, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (disc *testDiscovery) FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) { |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (disc *testDiscovery) GetRawDiscovery() discovery.Discovery { |
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,377 @@ |
|||||||
|
package sync |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"sync/atomic" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/common" |
||||||
|
"github.com/ethereum/go-ethereum/rlp" |
||||||
|
protobuf "github.com/golang/protobuf/proto" |
||||||
|
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" |
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
libp2p_network "github.com/libp2p/go-libp2p-core/network" |
||||||
|
"github.com/pkg/errors" |
||||||
|
"github.com/rs/zerolog" |
||||||
|
) |
||||||
|
|
||||||
|
// syncStream is the structure for a stream running sync protocol.
|
||||||
|
type syncStream struct { |
||||||
|
// Basic stream
|
||||||
|
*sttypes.BaseStream |
||||||
|
|
||||||
|
protocol *Protocol |
||||||
|
chain chainHelper |
||||||
|
|
||||||
|
// pipeline channels
|
||||||
|
reqC chan *syncpb.Request |
||||||
|
respC chan *syncpb.Response |
||||||
|
|
||||||
|
// close related fields. Concurrent call of close is possible.
|
||||||
|
closeC chan struct{} |
||||||
|
closeStat uint32 |
||||||
|
|
||||||
|
logger zerolog.Logger |
||||||
|
} |
||||||
|
|
||||||
|
// wrapStream wraps the raw libp2p stream to syncStream
|
||||||
|
func (p *Protocol) wrapStream(raw libp2p_network.Stream) *syncStream { |
||||||
|
bs := sttypes.NewBaseStream(raw) |
||||||
|
logger := p.logger.With(). |
||||||
|
Str("ID", string(bs.ID())). |
||||||
|
Str("Remote Protocol", string(bs.ProtoID())). |
||||||
|
Logger() |
||||||
|
|
||||||
|
return &syncStream{ |
||||||
|
BaseStream: bs, |
||||||
|
protocol: p, |
||||||
|
chain: newChainHelper(p.chain, p.schedule), |
||||||
|
reqC: make(chan *syncpb.Request, 100), |
||||||
|
respC: make(chan *syncpb.Response, 100), |
||||||
|
closeC: make(chan struct{}), |
||||||
|
closeStat: 0, |
||||||
|
logger: logger, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) run() { |
||||||
|
st.logger.Info().Str("StreamID", string(st.ID())).Msg("running sync protocol on stream") |
||||||
|
defer st.logger.Info().Str("StreamID", string(st.ID())).Msg("end running sync protocol on stream") |
||||||
|
|
||||||
|
go st.handleReqLoop() |
||||||
|
go st.handleRespLoop() |
||||||
|
st.readMsgLoop() |
||||||
|
} |
||||||
|
|
||||||
|
// readMsgLoop is the loop
|
||||||
|
func (st *syncStream) readMsgLoop() { |
||||||
|
for { |
||||||
|
msg, err := st.readMsg() |
||||||
|
if err != nil { |
||||||
|
if err := st.Close(); err != nil { |
||||||
|
st.logger.Err(err).Msg("failed to close sync stream") |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
st.deliverMsg(msg) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// deliverMsg process the delivered message and forward to the corresponding channel
|
||||||
|
func (st *syncStream) deliverMsg(msg protobuf.Message) { |
||||||
|
syncMsg := msg.(*syncpb.Message) |
||||||
|
if syncMsg == nil { |
||||||
|
st.logger.Info().Str("message", msg.String()).Msg("received unexpected sync message") |
||||||
|
return |
||||||
|
} |
||||||
|
if req := syncMsg.GetReq(); req != nil { |
||||||
|
go func() { |
||||||
|
select { |
||||||
|
case st.reqC <- req: |
||||||
|
case <-time.After(1 * time.Minute): |
||||||
|
st.logger.Warn().Str("request", req.String()). |
||||||
|
Msg("request handler severely jammed, message dropped") |
||||||
|
} |
||||||
|
}() |
||||||
|
} |
||||||
|
if resp := syncMsg.GetResp(); resp != nil { |
||||||
|
go func() { |
||||||
|
select { |
||||||
|
case st.respC <- resp: |
||||||
|
case <-time.After(1 * time.Minute): |
||||||
|
st.logger.Warn().Str("response", resp.String()). |
||||||
|
Msg("response handler severely jammed, message dropped") |
||||||
|
} |
||||||
|
}() |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) handleReqLoop() { |
||||||
|
for { |
||||||
|
select { |
||||||
|
case req := <-st.reqC: |
||||||
|
st.protocol.rl.LimitRequest(st.ID()) |
||||||
|
err := st.handleReq(req) |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
st.logger.Info().Err(err).Str("request", req.String()). |
||||||
|
Msg("handle request error. Closing stream") |
||||||
|
if err := st.Close(); err != nil { |
||||||
|
st.logger.Err(err).Msg("failed to close sync stream") |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
case <-st.closeC: |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) handleRespLoop() { |
||||||
|
for { |
||||||
|
select { |
||||||
|
case resp := <-st.respC: |
||||||
|
st.handleResp(resp) |
||||||
|
|
||||||
|
case <-st.closeC: |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Close stops the stream handling and closes the underlying stream
|
||||||
|
func (st *syncStream) Close() error { |
||||||
|
notClosed := atomic.CompareAndSwapUint32(&st.closeStat, 0, 1) |
||||||
|
if !notClosed { |
||||||
|
// Already closed by another goroutine. Directly return
|
||||||
|
return nil |
||||||
|
} |
||||||
|
if err := st.protocol.sm.RemoveStream(st.ID()); err != nil { |
||||||
|
st.logger.Err(err).Str("stream ID", string(st.ID())). |
||||||
|
Msg("failed to remove sync stream on close") |
||||||
|
} |
||||||
|
close(st.closeC) |
||||||
|
return st.BaseStream.Close() |
||||||
|
} |
||||||
|
|
||||||
|
// ResetOnClose reset the stream on close
|
||||||
|
func (st *syncStream) ResetOnClose() error { |
||||||
|
notClosed := atomic.CompareAndSwapUint32(&st.closeStat, 0, 1) |
||||||
|
if !notClosed { |
||||||
|
// Already closed by another goroutine. Directly return
|
||||||
|
return nil |
||||||
|
} |
||||||
|
close(st.closeC) |
||||||
|
return st.BaseStream.ResetOnClose() |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) handleReq(req *syncpb.Request) error { |
||||||
|
if gnReq := req.GetGetBlockNumberRequest(); gnReq != nil { |
||||||
|
return st.handleGetBlockNumberRequest(req.ReqId) |
||||||
|
} |
||||||
|
if ghReq := req.GetGetBlockHashesRequest(); ghReq != nil { |
||||||
|
return st.handleGetBlockHashesRequest(req.ReqId, ghReq) |
||||||
|
} |
||||||
|
if bnReq := req.GetGetBlocksByNumRequest(); bnReq != nil { |
||||||
|
return st.handleGetBlocksByNumRequest(req.ReqId, bnReq) |
||||||
|
} |
||||||
|
if bhReq := req.GetGetBlocksByHashesRequest(); bhReq != nil { |
||||||
|
return st.handleGetBlocksByHashesRequest(req.ReqId, bhReq) |
||||||
|
} |
||||||
|
if esReq := req.GetGetEpochStateRequest(); esReq != nil { |
||||||
|
return st.handleEpochStateRequest(req.ReqId, esReq) |
||||||
|
} |
||||||
|
// unsupported request type
|
||||||
|
resp := syncpb.MakeErrorResponseMessage(req.ReqId, errUnknownReqType) |
||||||
|
return st.writeMsg(resp) |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) handleGetBlockNumberRequest(rid uint64) error { |
||||||
|
resp := st.computeBlockNumberResp(rid) |
||||||
|
if err := st.writeMsg(resp); err != nil { |
||||||
|
return errors.Wrap(err, "[GetBlockNumber]: writeMsg") |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) handleGetBlockHashesRequest(rid uint64, req *syncpb.GetBlockHashesRequest) error { |
||||||
|
resp, err := st.computeGetBlockHashesResp(rid, req.Nums) |
||||||
|
if err != nil { |
||||||
|
resp = syncpb.MakeErrorResponseMessage(rid, err) |
||||||
|
} |
||||||
|
if writeErr := st.writeMsg(resp); writeErr != nil { |
||||||
|
if err == nil { |
||||||
|
err = writeErr |
||||||
|
} else { |
||||||
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr) |
||||||
|
} |
||||||
|
} |
||||||
|
return errors.Wrap(err, "[GetBlockHashes]") |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) handleGetBlocksByNumRequest(rid uint64, req *syncpb.GetBlocksByNumRequest) error { |
||||||
|
resp, err := st.computeRespFromBlockNumber(rid, req.Nums) |
||||||
|
if resp == nil && err != nil { |
||||||
|
resp = syncpb.MakeErrorResponseMessage(rid, err) |
||||||
|
} |
||||||
|
if writeErr := st.writeMsg(resp); writeErr != nil { |
||||||
|
if err == nil { |
||||||
|
err = writeErr |
||||||
|
} else { |
||||||
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr) |
||||||
|
} |
||||||
|
} |
||||||
|
return errors.Wrap(err, "[GetBlocksByNumber]") |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) handleGetBlocksByHashesRequest(rid uint64, req *syncpb.GetBlocksByHashesRequest) error { |
||||||
|
hashes := bytesToHashes(req.BlockHashes) |
||||||
|
resp, err := st.computeRespFromBlockHashes(rid, hashes) |
||||||
|
if resp == nil && err != nil { |
||||||
|
resp = syncpb.MakeErrorResponseMessage(rid, err) |
||||||
|
} |
||||||
|
if writeErr := st.writeMsg(resp); writeErr != nil { |
||||||
|
if err == nil { |
||||||
|
err = writeErr |
||||||
|
} else { |
||||||
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr) |
||||||
|
} |
||||||
|
} |
||||||
|
return errors.Wrap(err, "[GetBlocksByHashes]") |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) handleEpochStateRequest(rid uint64, req *syncpb.GetEpochStateRequest) error { |
||||||
|
resp, err := st.computeEpochStateResp(rid, req.Epoch) |
||||||
|
if err != nil { |
||||||
|
resp = syncpb.MakeErrorResponseMessage(rid, err) |
||||||
|
} |
||||||
|
if writeErr := st.writeMsg(resp); writeErr != nil { |
||||||
|
if err == nil { |
||||||
|
err = writeErr |
||||||
|
} else { |
||||||
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr) |
||||||
|
} |
||||||
|
} |
||||||
|
return errors.Wrap(err, "[GetEpochState]") |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) handleResp(resp *syncpb.Response) { |
||||||
|
st.protocol.rm.DeliverResponse(st.ID(), &syncResponse{resp}) |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) readMsg() (*syncpb.Message, error) { |
||||||
|
b, err := st.ReadBytes() |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
var msg = &syncpb.Message{} |
||||||
|
if err := protobuf.Unmarshal(b, msg); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return msg, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) writeMsg(msg *syncpb.Message) error { |
||||||
|
b, err := protobuf.Marshal(msg) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return st.WriteBytes(b) |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) computeBlockNumberResp(rid uint64) *syncpb.Message { |
||||||
|
bn := st.chain.getCurrentBlockNumber() |
||||||
|
return syncpb.MakeGetBlockNumberResponseMessage(rid, bn) |
||||||
|
} |
||||||
|
|
||||||
|
func (st syncStream) computeGetBlockHashesResp(rid uint64, bns []uint64) (*syncpb.Message, error) { |
||||||
|
if len(bns) > GetBlockHashesAmountCap { |
||||||
|
err := fmt.Errorf("GetBlockHashes amount exceed cap: %v>%v", len(bns), GetBlockHashesAmountCap) |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
hashes := st.chain.getBlockHashes(bns) |
||||||
|
return syncpb.MakeGetBlockHashesResponseMessage(rid, hashes), nil |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) computeRespFromBlockNumber(rid uint64, bns []uint64) (*syncpb.Message, error) { |
||||||
|
if len(bns) > GetBlocksByNumAmountCap { |
||||||
|
err := fmt.Errorf("GetBlocksByNum amount exceed cap: %v>%v", len(bns), GetBlocksByNumAmountCap) |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
blocks, err := st.chain.getBlocksByNumber(bns) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var ( |
||||||
|
blocksBytes = make([][]byte, 0, len(blocks)) |
||||||
|
sigs = make([][]byte, 0, len(blocks)) |
||||||
|
) |
||||||
|
for _, block := range blocks { |
||||||
|
bb, err := rlp.EncodeToBytes(block) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
blocksBytes = append(blocksBytes, bb) |
||||||
|
|
||||||
|
var sig []byte |
||||||
|
if block != nil { |
||||||
|
sig = block.GetCurrentCommitSig() |
||||||
|
} |
||||||
|
sigs = append(sigs, sig) |
||||||
|
} |
||||||
|
return syncpb.MakeGetBlocksByNumResponseMessage(rid, blocksBytes, sigs), nil |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) computeRespFromBlockHashes(rid uint64, hs []common.Hash) (*syncpb.Message, error) { |
||||||
|
if len(hs) > GetBlocksByHashesAmountCap { |
||||||
|
err := fmt.Errorf("GetBlockByHashes amount exceed cap: %v > %v", len(hs), GetBlocksByHashesAmountCap) |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
blocks, err := st.chain.getBlocksByHashes(hs) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var ( |
||||||
|
blocksBytes = make([][]byte, 0, len(blocks)) |
||||||
|
sigs = make([][]byte, 0, len(blocks)) |
||||||
|
) |
||||||
|
for _, block := range blocks { |
||||||
|
bb, err := rlp.EncodeToBytes(block) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
blocksBytes = append(blocksBytes, bb) |
||||||
|
|
||||||
|
var sig []byte |
||||||
|
if block != nil { |
||||||
|
sig = block.GetCurrentCommitSig() |
||||||
|
} |
||||||
|
sigs = append(sigs, sig) |
||||||
|
} |
||||||
|
return syncpb.MakeGetBlocksByHashesResponseMessage(rid, blocksBytes, sigs), nil |
||||||
|
} |
||||||
|
|
||||||
|
func (st *syncStream) computeEpochStateResp(rid uint64, epoch uint64) (*syncpb.Message, error) { |
||||||
|
if epoch == 0 { |
||||||
|
return nil, errors.New("Epoch 0 does not have shard state") |
||||||
|
} |
||||||
|
esRes, err := st.chain.getEpochState(epoch) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return esRes.toMessage(rid) |
||||||
|
} |
||||||
|
|
||||||
|
func bytesToHashes(bs [][]byte) []common.Hash { |
||||||
|
hs := make([]common.Hash, 0, len(bs)) |
||||||
|
for _, b := range bs { |
||||||
|
var h common.Hash |
||||||
|
copy(h[:], b) |
||||||
|
hs = append(hs, h) |
||||||
|
} |
||||||
|
return hs |
||||||
|
} |
@ -0,0 +1,267 @@ |
|||||||
|
package sync |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"context" |
||||||
|
"testing" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/common" |
||||||
|
protobuf "github.com/golang/protobuf/proto" |
||||||
|
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" |
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
ic "github.com/libp2p/go-libp2p-core/crypto" |
||||||
|
libp2p_network "github.com/libp2p/go-libp2p-core/network" |
||||||
|
"github.com/libp2p/go-libp2p-core/peer" |
||||||
|
"github.com/libp2p/go-libp2p-core/protocol" |
||||||
|
ma "github.com/multiformats/go-multiaddr" |
||||||
|
) |
||||||
|
|
||||||
|
var _ sttypes.Protocol = &Protocol{} |
||||||
|
|
||||||
|
var ( |
||||||
|
testGetBlockNumbers = []uint64{1, 2, 3, 4, 5} |
||||||
|
testGetBlockRequest = syncpb.MakeGetBlocksByNumRequest(testGetBlockNumbers) |
||||||
|
testGetBlockRequestMsg = syncpb.MakeMessageFromRequest(testGetBlockRequest) |
||||||
|
|
||||||
|
testEpoch uint64 = 20 |
||||||
|
testEpochStateRequest = syncpb.MakeGetEpochStateRequest(testEpoch) |
||||||
|
testEpochStateRequestMsg = syncpb.MakeMessageFromRequest(testEpochStateRequest) |
||||||
|
|
||||||
|
testCurrentNumberRequest = syncpb.MakeGetBlockNumberRequest() |
||||||
|
testCurrentNumberRequestMsg = syncpb.MakeMessageFromRequest(testCurrentNumberRequest) |
||||||
|
|
||||||
|
testGetBlockHashNums = []uint64{1, 2, 3, 4, 5} |
||||||
|
testGetBlockHashesRequest = syncpb.MakeGetBlockHashesRequest(testGetBlockHashNums) |
||||||
|
testGetBlockHashesRequestMsg = syncpb.MakeMessageFromRequest(testGetBlockHashesRequest) |
||||||
|
|
||||||
|
testGetBlockByHashes = []common.Hash{ |
||||||
|
numberToHash(1), |
||||||
|
numberToHash(2), |
||||||
|
numberToHash(3), |
||||||
|
numberToHash(4), |
||||||
|
numberToHash(5), |
||||||
|
} |
||||||
|
testGetBlocksByHashesRequest = syncpb.MakeGetBlocksByHashesRequest(testGetBlockByHashes) |
||||||
|
testGetBlocksByHashesRequestMsg = syncpb.MakeMessageFromRequest(testGetBlocksByHashesRequest) |
||||||
|
) |
||||||
|
|
||||||
|
func TestSyncStream_HandleGetBlocksByRequest(t *testing.T) { |
||||||
|
st, remoteSt := makeTestSyncStream() |
||||||
|
|
||||||
|
go st.run() |
||||||
|
defer close(st.closeC) |
||||||
|
|
||||||
|
req := testGetBlockRequestMsg |
||||||
|
b, _ := protobuf.Marshal(req) |
||||||
|
err := remoteSt.WriteBytes(b) |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond) |
||||||
|
receivedBytes, _ := remoteSt.ReadBytes() |
||||||
|
|
||||||
|
if err := checkBlocksResult(testGetBlockNumbers, receivedBytes); err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestSyncStream_HandleEpochStateRequest(t *testing.T) { |
||||||
|
st, remoteSt := makeTestSyncStream() |
||||||
|
|
||||||
|
go st.run() |
||||||
|
defer close(st.closeC) |
||||||
|
|
||||||
|
req := testEpochStateRequestMsg |
||||||
|
b, _ := protobuf.Marshal(req) |
||||||
|
err := remoteSt.WriteBytes(b) |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond) |
||||||
|
receivedBytes, _ := remoteSt.ReadBytes() |
||||||
|
|
||||||
|
if err := checkEpochStateResult(testEpoch, receivedBytes); err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestSyncStream_HandleCurrentBlockNumber(t *testing.T) { |
||||||
|
st, remoteSt := makeTestSyncStream() |
||||||
|
|
||||||
|
go st.run() |
||||||
|
defer close(st.closeC) |
||||||
|
|
||||||
|
req := testCurrentNumberRequestMsg |
||||||
|
b, _ := protobuf.Marshal(req) |
||||||
|
err := remoteSt.WriteBytes(b) |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond) |
||||||
|
receivedBytes, _ := remoteSt.ReadBytes() |
||||||
|
|
||||||
|
if err := checkBlockNumberResult(receivedBytes); err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestSyncStream_HandleGetBlockHashes(t *testing.T) { |
||||||
|
st, remoteSt := makeTestSyncStream() |
||||||
|
|
||||||
|
go st.run() |
||||||
|
defer close(st.closeC) |
||||||
|
|
||||||
|
req := testGetBlockHashesRequestMsg |
||||||
|
b, _ := protobuf.Marshal(req) |
||||||
|
err := remoteSt.WriteBytes(b) |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond) |
||||||
|
receivedBytes, _ := remoteSt.ReadBytes() |
||||||
|
|
||||||
|
if err := checkBlockHashesResult(receivedBytes, testGetBlockNumbers); err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestSyncStream_HandleGetBlocksByHashes(t *testing.T) { |
||||||
|
st, remoteSt := makeTestSyncStream() |
||||||
|
|
||||||
|
go st.run() |
||||||
|
defer close(st.closeC) |
||||||
|
|
||||||
|
req := testGetBlocksByHashesRequestMsg |
||||||
|
b, _ := protobuf.Marshal(req) |
||||||
|
err := remoteSt.WriteBytes(b) |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond) |
||||||
|
receivedBytes, _ := remoteSt.ReadBytes() |
||||||
|
|
||||||
|
if err := checkBlocksByHashesResult(receivedBytes, testGetBlockByHashes); err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func makeTestSyncStream() (*syncStream, *testRemoteBaseStream) { |
||||||
|
localRaw, remoteRaw := makePairP2PStreams() |
||||||
|
remote := newTestRemoteBaseStream(remoteRaw) |
||||||
|
|
||||||
|
bs := sttypes.NewBaseStream(localRaw) |
||||||
|
|
||||||
|
return &syncStream{ |
||||||
|
BaseStream: bs, |
||||||
|
chain: &testChainHelper{}, |
||||||
|
protocol: makeTestProtocol(nil), |
||||||
|
reqC: make(chan *syncpb.Request, 100), |
||||||
|
respC: make(chan *syncpb.Response, 100), |
||||||
|
closeC: make(chan struct{}), |
||||||
|
closeStat: 0, |
||||||
|
}, remote |
||||||
|
} |
||||||
|
|
||||||
|
type testP2PStream struct { |
||||||
|
readBuf *bytes.Buffer |
||||||
|
inC chan struct{} |
||||||
|
identity string |
||||||
|
|
||||||
|
writeHook func([]byte) (int, error) |
||||||
|
} |
||||||
|
|
||||||
|
func makePairP2PStreams() (*testP2PStream, *testP2PStream) { |
||||||
|
buf1 := bytes.NewBuffer(nil) |
||||||
|
buf2 := bytes.NewBuffer(nil) |
||||||
|
|
||||||
|
st1 := &testP2PStream{ |
||||||
|
readBuf: buf1, |
||||||
|
inC: make(chan struct{}, 1), |
||||||
|
identity: "local", |
||||||
|
} |
||||||
|
st2 := &testP2PStream{ |
||||||
|
readBuf: buf2, |
||||||
|
inC: make(chan struct{}, 1), |
||||||
|
identity: "remote", |
||||||
|
} |
||||||
|
st1.writeHook = st2.receiveBytes |
||||||
|
st2.writeHook = st1.receiveBytes |
||||||
|
return st1, st2 |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testP2PStream) Read(b []byte) (n int, err error) { |
||||||
|
<-st.inC |
||||||
|
n, err = st.readBuf.Read(b) |
||||||
|
if st.readBuf.Len() != 0 { |
||||||
|
select { |
||||||
|
case st.inC <- struct{}{}: |
||||||
|
default: |
||||||
|
} |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testP2PStream) Write(b []byte) (n int, err error) { |
||||||
|
return st.writeHook(b) |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testP2PStream) receiveBytes(b []byte) (n int, err error) { |
||||||
|
n, err = st.readBuf.Write(b) |
||||||
|
select { |
||||||
|
case st.inC <- struct{}{}: |
||||||
|
default: |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testP2PStream) Close() error { return nil } |
||||||
|
func (st *testP2PStream) CloseRead() error { return nil } |
||||||
|
func (st *testP2PStream) CloseWrite() error { return nil } |
||||||
|
func (st *testP2PStream) Reset() error { return nil } |
||||||
|
func (st *testP2PStream) SetDeadline(time.Time) error { return nil } |
||||||
|
func (st *testP2PStream) SetReadDeadline(time.Time) error { return nil } |
||||||
|
func (st *testP2PStream) SetWriteDeadline(time.Time) error { return nil } |
||||||
|
func (st *testP2PStream) ID() string { return "" } |
||||||
|
func (st *testP2PStream) Protocol() protocol.ID { return "" } |
||||||
|
func (st *testP2PStream) SetProtocol(protocol.ID) {} |
||||||
|
func (st *testP2PStream) Stat() libp2p_network.Stat { return libp2p_network.Stat{} } |
||||||
|
func (st *testP2PStream) Conn() libp2p_network.Conn { return &fakeConn{} } |
||||||
|
|
||||||
|
type testRemoteBaseStream struct { |
||||||
|
base *sttypes.BaseStream |
||||||
|
} |
||||||
|
|
||||||
|
func newTestRemoteBaseStream(st *testP2PStream) *testRemoteBaseStream { |
||||||
|
rst := &testRemoteBaseStream{ |
||||||
|
base: sttypes.NewBaseStream(st), |
||||||
|
} |
||||||
|
return rst |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testRemoteBaseStream) ReadBytes() ([]byte, error) { |
||||||
|
return st.base.ReadBytes() |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testRemoteBaseStream) WriteBytes(b []byte) error { |
||||||
|
return st.base.WriteBytes(b) |
||||||
|
} |
||||||
|
|
||||||
|
type fakeConn struct{} |
||||||
|
|
||||||
|
func (conn *fakeConn) Close() error { return nil } |
||||||
|
func (conn *fakeConn) LocalPeer() peer.ID { return "" } |
||||||
|
func (conn *fakeConn) LocalPrivateKey() ic.PrivKey { return nil } |
||||||
|
func (conn *fakeConn) RemotePeer() peer.ID { return "" } |
||||||
|
func (conn *fakeConn) RemotePublicKey() ic.PubKey { return nil } |
||||||
|
func (conn *fakeConn) LocalMultiaddr() ma.Multiaddr { return nil } |
||||||
|
func (conn *fakeConn) RemoteMultiaddr() ma.Multiaddr { return nil } |
||||||
|
func (conn *fakeConn) ID() string { return "" } |
||||||
|
func (conn *fakeConn) NewStream(context.Context) (libp2p_network.Stream, error) { return nil, nil } |
||||||
|
func (conn *fakeConn) GetStreams() []libp2p_network.Stream { return nil } |
||||||
|
func (conn *fakeConn) Stat() libp2p_network.Stat { return libp2p_network.Stat{} } |
@ -0,0 +1,112 @@ |
|||||||
|
package sync |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/rlp" |
||||||
|
protobuf "github.com/golang/protobuf/proto" |
||||||
|
"github.com/harmony-one/harmony/block" |
||||||
|
"github.com/harmony-one/harmony/p2p/stream/common/requestmanager" |
||||||
|
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" |
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
"github.com/harmony-one/harmony/shard" |
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
errUnknownReqType = errors.New("unknown request") |
||||||
|
) |
||||||
|
|
||||||
|
// syncResponse is the sync protocol response which implements sttypes.Response
|
||||||
|
type syncResponse struct { |
||||||
|
pb *syncpb.Response |
||||||
|
} |
||||||
|
|
||||||
|
// ReqID return the request ID of the response
|
||||||
|
func (resp *syncResponse) ReqID() uint64 { |
||||||
|
return resp.pb.ReqId |
||||||
|
} |
||||||
|
|
||||||
|
// GetProtobufMsg return the raw protobuf message
|
||||||
|
func (resp *syncResponse) GetProtobufMsg() protobuf.Message { |
||||||
|
return resp.pb |
||||||
|
} |
||||||
|
|
||||||
|
func (resp *syncResponse) String() string { |
||||||
|
return fmt.Sprintf("[SyncResponse %v]", resp.pb.String()) |
||||||
|
} |
||||||
|
|
||||||
|
// EpochStateResult is the result for GetEpochStateQuery
|
||||||
|
type EpochStateResult struct { |
||||||
|
Header *block.Header |
||||||
|
State *shard.State |
||||||
|
} |
||||||
|
|
||||||
|
func epochStateResultFromResponse(resp sttypes.Response) (*EpochStateResult, error) { |
||||||
|
sResp, ok := resp.(*syncResponse) |
||||||
|
if !ok || sResp == nil { |
||||||
|
return nil, errors.New("not sync response") |
||||||
|
} |
||||||
|
if errResp := sResp.pb.GetErrorResponse(); errResp != nil { |
||||||
|
return nil, errors.New(errResp.Error) |
||||||
|
} |
||||||
|
gesResp := sResp.pb.GetGetEpochStateResponse() |
||||||
|
if gesResp == nil { |
||||||
|
return nil, errors.New("not GetEpochStateResponse") |
||||||
|
} |
||||||
|
var ( |
||||||
|
headerBytes = gesResp.HeaderBytes |
||||||
|
ssBytes = gesResp.ShardState |
||||||
|
|
||||||
|
header *block.Header |
||||||
|
ss *shard.State |
||||||
|
) |
||||||
|
if len(headerBytes) > 0 { |
||||||
|
if err := rlp.DecodeBytes(headerBytes, &header); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
} |
||||||
|
if len(ssBytes) > 0 { |
||||||
|
// here shard state is not encoded with legacy rules
|
||||||
|
if err := rlp.DecodeBytes(ssBytes, &ss); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
} |
||||||
|
return &EpochStateResult{ |
||||||
|
Header: header, |
||||||
|
State: ss, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (res *EpochStateResult) toMessage(rid uint64) (*syncpb.Message, error) { |
||||||
|
headerBytes, err := rlp.EncodeToBytes(res.Header) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
// Shard state is not wrapped here, means no legacy shard state encoding rule as
|
||||||
|
// in shard.EncodeWrapper.
|
||||||
|
ssBytes, err := rlp.EncodeToBytes(res.State) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return syncpb.MakeGetEpochStateResponseMessage(rid, headerBytes, ssBytes), nil |
||||||
|
} |
||||||
|
|
||||||
|
// Option is the additional option to do requests.
|
||||||
|
// Currently, two options are supported:
|
||||||
|
// 1. WithHighPriority - do the request in high priority.
|
||||||
|
// 2. WithBlacklist - do the request without the given stream ids as blacklist
|
||||||
|
// 3. WithWhitelist - do the request only with the given stream ids
|
||||||
|
type Option = requestmanager.RequestOption |
||||||
|
|
||||||
|
var ( |
||||||
|
// WithHighPriority instruct the request manager to do the request with high
|
||||||
|
// priority
|
||||||
|
WithHighPriority = requestmanager.WithHighPriority |
||||||
|
// WithBlacklist instruct the request manager not to assign the request to the
|
||||||
|
// given streamID
|
||||||
|
WithBlacklist = requestmanager.WithBlacklist |
||||||
|
// WithWhitelist instruct the request manager only to assign the request to the
|
||||||
|
// given streamID
|
||||||
|
WithWhitelist = requestmanager.WithWhitelist |
||||||
|
) |
Loading…
Reference in new issue