[stream] added sync protocol

pull/3583/head
Jacky Wang 4 years ago
parent 1caf88c7db
commit 3977ac49eb
No known key found for this signature in database
GPG Key ID: 1085CE5F4FF5842C
  1. 169
      p2p/stream/protocols/sync/chain.go
  2. 208
      p2p/stream/protocols/sync/chain_test.go
  3. 396
      p2p/stream/protocols/sync/client.go
  4. 462
      p2p/stream/protocols/sync/client_test.go
  5. 17
      p2p/stream/protocols/sync/const.go
  6. 260
      p2p/stream/protocols/sync/protocol.go
  7. 95
      p2p/stream/protocols/sync/protocol_test.go
  8. 377
      p2p/stream/protocols/sync/stream.go
  9. 267
      p2p/stream/protocols/sync/stream_test.go
  10. 112
      p2p/stream/protocols/sync/utils.go

@ -0,0 +1,169 @@
package sync
import (
"fmt"
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/harmony-one/harmony/block"
"github.com/harmony-one/harmony/consensus/engine"
"github.com/harmony-one/harmony/core/types"
shardingconfig "github.com/harmony-one/harmony/internal/configs/sharding"
"github.com/pkg/errors"
)
// chainHelper is the adapter for blockchain which is friendly to unit test.
type chainHelper interface {
getCurrentBlockNumber() uint64
getBlockHashes(bns []uint64) []common.Hash
getBlocksByNumber(bns []uint64) ([]*types.Block, error)
getBlocksByHashes(hs []common.Hash) ([]*types.Block, error)
getEpochState(epoch uint64) (*EpochStateResult, error)
}
type chainHelperImpl struct {
chain engine.ChainReader
schedule shardingconfig.Schedule
}
func newChainHelper(chain engine.ChainReader, schedule shardingconfig.Schedule) *chainHelperImpl {
return &chainHelperImpl{
chain: chain,
schedule: schedule,
}
}
func (ch *chainHelperImpl) getCurrentBlockNumber() uint64 {
return ch.chain.CurrentBlock().NumberU64()
}
func (ch *chainHelperImpl) getBlockHashes(bns []uint64) []common.Hash {
hashes := make([]common.Hash, 0, len(bns))
for _, bn := range bns {
var h common.Hash
header := ch.chain.GetHeaderByNumber(bn)
if header != nil {
h = header.Hash()
}
hashes = append(hashes, h)
}
return hashes
}
func (ch *chainHelperImpl) getBlocksByNumber(bns []uint64) ([]*types.Block, error) {
var (
blocks = make([]*types.Block, 0, len(bns))
)
for _, bn := range bns {
var (
block *types.Block
err error
)
header := ch.chain.GetHeaderByNumber(bn)
if header != nil {
block, err = ch.getBlockWithSigByHeader(header)
if err != nil {
return nil, errors.Wrapf(err, "get block %v at %v", header.Hash().String(), header.Number())
}
}
blocks = append(blocks, block)
}
return blocks, nil
}
func (ch *chainHelperImpl) getBlocksByHashes(hs []common.Hash) ([]*types.Block, error) {
var (
blocks = make([]*types.Block, 0, len(hs))
)
for _, h := range hs {
var (
block *types.Block
err error
)
header := ch.chain.GetHeaderByHash(h)
if header != nil {
block, err = ch.getBlockWithSigByHeader(header)
if err != nil {
return nil, errors.Wrapf(err, "get block %v at %v", header.Hash().String(), header.Number())
}
}
blocks = append(blocks, block)
}
return blocks, nil
}
var errBlockNotFound = errors.New("block not found")
func (ch *chainHelperImpl) getBlockWithSigByHeader(header *block.Header) (*types.Block, error) {
b := ch.chain.GetBlock(header.Hash(), header.Number().Uint64())
if b == nil {
return nil, nil
}
commitSig, err := ch.getBlockSigAndBitmap(header)
if err != nil {
return nil, errors.New("missing commit signature")
}
b.SetCurrentCommitSig(commitSig)
return b, nil
}
func (ch *chainHelperImpl) getBlockSigAndBitmap(header *block.Header) ([]byte, error) {
sb := ch.getBlockSigFromNextBlock(header)
if len(sb) != 0 {
return sb, nil
}
// Note: some commit sig read from db is different from [nextHeader.sig, nextHeader.bitMap]
// nextBlock data is better to be used.
return ch.getBlockSigFromDB(header)
}
func (ch *chainHelperImpl) getBlockSigFromNextBlock(header *block.Header) []byte {
nextBN := header.Number().Uint64() + 1
nextHeader := ch.chain.GetHeaderByNumber(nextBN)
if nextHeader == nil {
return nil
}
sigBytes := nextHeader.LastCommitSignature()
bitMap := nextHeader.LastCommitBitmap()
sb := make([]byte, len(sigBytes)+len(bitMap))
copy(sb[:], sigBytes[:])
copy(sb[len(sigBytes):], bitMap[:])
return sb
}
func (ch *chainHelperImpl) getBlockSigFromDB(header *block.Header) ([]byte, error) {
return ch.chain.ReadCommitSig(header.Number().Uint64())
}
func (ch *chainHelperImpl) getEpochState(epoch uint64) (*EpochStateResult, error) {
if ch.chain.ShardID() != 0 {
return nil, errors.New("get epoch state currently unavailable on side chain")
}
if epoch == 0 {
return nil, errors.New("nil shard state for epoch 0")
}
res := &EpochStateResult{}
targetBN := ch.schedule.EpochLastBlock(epoch - 1)
res.Header = ch.chain.GetHeaderByNumber(targetBN)
if res.Header == nil {
// we still don't have the given epoch
return res, nil
}
epochBI := new(big.Int).SetUint64(epoch)
if ch.chain.Config().IsPreStaking(epochBI) {
// For epoch before preStaking, only hash is stored in header
ss, err := ch.chain.ReadShardState(epochBI)
if err != nil {
return nil, err
}
if ss == nil {
return nil, fmt.Errorf("missing shard state for [EPOCH-%v]", epoch)
}
res.State = ss
}
return res, nil
}

@ -0,0 +1,208 @@
package sync
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
protobuf "github.com/golang/protobuf/proto"
"github.com/harmony-one/harmony/block"
"github.com/harmony-one/harmony/core/types"
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message"
"github.com/harmony-one/harmony/shard"
)
type testChainHelper struct{}
func (tch *testChainHelper) getCurrentBlockNumber() uint64 {
return 100
}
func (tch *testChainHelper) getBlocksByNumber(bns []uint64) ([]*types.Block, error) {
blocks := make([]*types.Block, 0, len(bns))
for _, bn := range bns {
blocks = append(blocks, makeTestBlock(bn))
}
return blocks, nil
}
func (tch *testChainHelper) getEpochState(epoch uint64) (*EpochStateResult, error) {
header := &block.Header{Header: testHeader.Copy()}
header.SetEpoch(big.NewInt(int64(epoch - 1)))
state := testEpochState.DeepCopy()
state.Epoch = big.NewInt(int64(epoch))
return &EpochStateResult{
Header: header,
State: state,
}, nil
}
func (tch *testChainHelper) getBlockHashes(bns []uint64) []common.Hash {
hs := make([]common.Hash, 0, len(bns))
for _, bn := range bns {
hs = append(hs, numberToHash(bn))
}
return hs
}
func (tch *testChainHelper) getBlocksByHashes(hs []common.Hash) ([]*types.Block, error) {
bs := make([]*types.Block, 0, len(hs))
for _, h := range hs {
bn := hashToNumber(h)
bs = append(bs, makeTestBlock(bn))
}
return bs, nil
}
func numberToHash(bn uint64) common.Hash {
var h common.Hash
binary.LittleEndian.PutUint64(h[:], bn)
return h
}
func hashToNumber(h common.Hash) uint64 {
return binary.LittleEndian.Uint64(h[:])
}
func checkBlocksResult(bns []uint64, b []byte) error {
var msg = &syncpb.Message{}
if err := protobuf.Unmarshal(b, msg); err != nil {
return err
}
gbResp, err := msg.GetBlocksByNumberResponse()
if err != nil {
return err
}
if len(gbResp.BlocksBytes) == 0 {
return errors.New("nil response from GetBlocksByNumber")
}
blocks, err := decodeBlocksBytes(gbResp.BlocksBytes)
if err != nil {
return err
}
if len(blocks) != len(bns) {
return errors.New("unexpected blocks number")
}
for i, bn := range bns {
blk := blocks[i]
if bn != blk.NumberU64() {
return errors.New("unexpected number of a block")
}
}
return nil
}
func makeTestBlock(bn uint64) *types.Block {
header := testHeader.Copy()
header.SetNumber(big.NewInt(int64(bn)))
return types.NewBlockWithHeader(&block.Header{Header: header})
}
func decodeBlocksBytes(bbs [][]byte) ([]*types.Block, error) {
blocks := make([]*types.Block, 0, len(bbs))
for _, bb := range bbs {
var block *types.Block
if err := rlp.DecodeBytes(bb, &block); err != nil {
return nil, err
}
blocks = append(blocks, block)
}
return blocks, nil
}
func checkEpochStateResult(epoch uint64, b []byte) error {
var msg = &syncpb.Message{}
if err := protobuf.Unmarshal(b, msg); err != nil {
return err
}
geResp, err := msg.GetEpochStateResponse()
if err != nil {
return err
}
var (
header *block.Header
epochState *shard.State
)
if err := rlp.DecodeBytes(geResp.HeaderBytes, &header); err != nil {
return err
}
if err := rlp.DecodeBytes(geResp.ShardState, &epochState); err != nil {
return err
}
if header.Epoch().Uint64() != epoch-1 {
return fmt.Errorf("unexpected epoch of header %v / %v", header.Epoch(), epoch-1)
}
if epochState.Epoch.Uint64() != epoch {
return fmt.Errorf("unexpected epoch of shard state %v / %v", epochState.Epoch.Uint64(), epoch)
}
return nil
}
func checkBlockNumberResult(b []byte) error {
var msg = &syncpb.Message{}
if err := protobuf.Unmarshal(b, msg); err != nil {
return err
}
gnResp, err := msg.GetBlockNumberResponse()
if err != nil {
return err
}
if gnResp.Number != testCurBlockNumber {
return fmt.Errorf("unexpected block number: %v / %v", gnResp.Number, testCurBlockNumber)
}
return nil
}
func checkBlockHashesResult(b []byte, bns []uint64) error {
var msg = &syncpb.Message{}
if err := protobuf.Unmarshal(b, msg); err != nil {
return err
}
bhResp, err := msg.GetBlockHashesResponse()
if err != nil {
return err
}
got := bhResp.Hashes
if len(got) != len(bns) {
return errors.New("unexpected size")
}
for i, bn := range bns {
expect := numberToHash(bn)
if !bytes.Equal(expect[:], got[i]) {
return errors.New("unexpected hash")
}
}
return nil
}
func checkBlocksByHashesResult(b []byte, hs []common.Hash) error {
var msg = &syncpb.Message{}
if err := protobuf.Unmarshal(b, msg); err != nil {
return err
}
bhResp, err := msg.GetBlocksByHashesResponse()
if err != nil {
return err
}
if len(hs) != len(bhResp.BlocksBytes) {
return errors.New("unexpected size")
}
for i, h := range hs {
num := hashToNumber(h)
var blk *types.Block
if err := rlp.DecodeBytes(bhResp.BlocksBytes[i], &blk); err != nil {
return err
}
if blk.NumberU64() != num {
return fmt.Errorf("unexpected number %v != %v", blk.NumberU64(), num)
}
}
return nil
}

@ -0,0 +1,396 @@
package sync
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
protobuf "github.com/golang/protobuf/proto"
"github.com/harmony-one/harmony/core/types"
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
"github.com/pkg/errors"
)
// GetBlocksByNumber do getBlocksByNumberRequest through sync stream protocol.
// Return the block as result, target stream id, and error
func (p *Protocol) GetBlocksByNumber(ctx context.Context, bns []uint64, opts ...Option) ([]*types.Block, sttypes.StreamID, error) {
if len(bns) > GetBlocksByNumAmountCap {
return nil, "", fmt.Errorf("number of blocks exceed cap of %v", GetBlocksByNumAmountCap)
}
req := newGetBlocksByNumberRequest(bns)
resp, stid, err := p.rm.DoRequest(ctx, req, opts...)
if err != nil {
// At this point, error can be context canceled, context timed out, or waiting queue
// is already full.
return nil, stid, err
}
// Parse and return blocks
blocks, err := req.getBlocksFromResponse(resp)
if err != nil {
return nil, stid, err
}
return blocks, stid, nil
}
// GetEpochState get the epoch block from querying the remote node running sync stream protocol.
// Currently, this method is only supported by beacon syncer.
// Note: use this after epoch chain is implemented.
func (p *Protocol) GetEpochState(ctx context.Context, epoch uint64, opts ...Option) (*EpochStateResult, sttypes.StreamID, error) {
req := newGetEpochBlockRequest(epoch)
resp, stid, err := p.rm.DoRequest(ctx, req, opts...)
if err != nil {
return nil, stid, err
}
res, err := epochStateResultFromResponse(resp)
if err != nil {
return nil, stid, err
}
return res, stid, nil
}
// GetCurrentBlockNumber get the current block number from remote node
func (p *Protocol) GetCurrentBlockNumber(ctx context.Context, opts ...Option) (uint64, sttypes.StreamID, error) {
req := newGetBlockNumberRequest()
resp, stid, err := p.rm.DoRequest(ctx, req, opts...)
if err != nil {
return 0, stid, err
}
bn, err := req.getNumberFromResponse(resp)
if err != nil {
return bn, stid, err
}
return bn, stid, nil
}
// GetBlockHashes do getBlockHashesRequest through sync stream protocol.
// Return the hash of the given block number. If a block is unknown, the hash will be emptyHash.
func (p *Protocol) GetBlockHashes(ctx context.Context, bns []uint64, opts ...Option) ([]common.Hash, sttypes.StreamID, error) {
if len(bns) > GetBlockHashesAmountCap {
return nil, "", fmt.Errorf("number of requested numbers exceed limit")
}
req := newGetBlockHashesRequest(bns)
resp, stid, err := p.rm.DoRequest(ctx, req, opts...)
if err != nil {
return nil, stid, err
}
hashes, err := req.getHashesFromResponse(resp)
if err != nil {
return nil, stid, err
}
return hashes, stid, nil
}
// GetBlocksByHashes do getBlocksByHashesRequest through sync stream protocol.
func (p *Protocol) GetBlocksByHashes(ctx context.Context, hs []common.Hash, opts ...Option) ([]*types.Block, sttypes.StreamID, error) {
if len(hs) > GetBlocksByHashesAmountCap {
return nil, "", fmt.Errorf("number of requested hashes exceed limit")
}
req := newGetBlocksByHashesRequest(hs)
resp, stid, err := p.rm.DoRequest(ctx, req, opts...)
if err != nil {
return nil, stid, err
}
blocks, err := req.getBlocksFromResponse(resp)
if err != nil {
return nil, stid, err
}
return blocks, stid, nil
}
// getBlocksByNumberRequest is the request for get block by numbers which implements
// sttypes.Request interface
type getBlocksByNumberRequest struct {
bns []uint64
pbReq *syncpb.Request
}
func newGetBlocksByNumberRequest(bns []uint64) *getBlocksByNumberRequest {
pbReq := syncpb.MakeGetBlocksByNumRequest(bns)
return &getBlocksByNumberRequest{
bns: bns,
pbReq: pbReq,
}
}
func (req *getBlocksByNumberRequest) ReqID() uint64 {
return req.pbReq.GetReqId()
}
func (req *getBlocksByNumberRequest) SetReqID(val uint64) {
req.pbReq.ReqId = val
}
func (req *getBlocksByNumberRequest) String() string {
ss := make([]string, 0, len(req.bns))
for _, bn := range req.bns {
ss = append(ss, strconv.Itoa(int(bn)))
}
bnsStr := strings.Join(ss, ",")
return fmt.Sprintf("REQUEST [GetBlockByNumber: %s]", bnsStr)
}
func (req *getBlocksByNumberRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool {
return target.Version.GreaterThanOrEqual(MinVersion)
}
func (req *getBlocksByNumberRequest) Encode() ([]byte, error) {
msg := syncpb.MakeMessageFromRequest(req.pbReq)
return protobuf.Marshal(msg)
}
func (req *getBlocksByNumberRequest) getBlocksFromResponse(resp sttypes.Response) ([]*types.Block, error) {
sResp, ok := resp.(*syncResponse)
if !ok || sResp == nil {
return nil, errors.New("not sync response")
}
blockBytes, sigs, err := req.parseBlockBytesAndSigs(sResp)
if err != nil {
return nil, err
}
blocks := make([]*types.Block, 0, len(blockBytes))
for i, bb := range blockBytes {
var block *types.Block
if err := rlp.DecodeBytes(bb, &block); err != nil {
return nil, errors.Wrap(err, "[GetBlocksByNumResponse]")
}
if block != nil {
block.SetCurrentCommitSig(sigs[i])
}
blocks = append(blocks, block)
}
return blocks, nil
}
func (req *getBlocksByNumberRequest) parseBlockBytesAndSigs(resp *syncResponse) ([][]byte, [][]byte, error) {
if errResp := resp.pb.GetErrorResponse(); errResp != nil {
return nil, nil, errors.New(errResp.Error)
}
gbResp := resp.pb.GetGetBlocksByNumResponse()
if gbResp == nil {
return nil, nil, errors.New("response not GetBlockByNumber")
}
if len(gbResp.BlocksBytes) != len(gbResp.CommitSig) {
return nil, nil, fmt.Errorf("commit sigs size not expected: %v / %v",
len(gbResp.CommitSig), len(gbResp.BlocksBytes))
}
return gbResp.BlocksBytes, gbResp.CommitSig, nil
}
type getEpochBlockRequest struct {
epoch uint64
pbReq *syncpb.Request
}
func newGetEpochBlockRequest(epoch uint64) *getEpochBlockRequest {
pbReq := syncpb.MakeGetEpochStateRequest(epoch)
return &getEpochBlockRequest{
epoch: epoch,
pbReq: pbReq,
}
}
func (req *getEpochBlockRequest) ReqID() uint64 {
return req.pbReq.GetReqId()
}
func (req *getEpochBlockRequest) SetReqID(val uint64) {
req.pbReq.ReqId = val
}
func (req *getEpochBlockRequest) String() string {
return fmt.Sprintf("REQUEST [GetEpochBlock: %v]", req.epoch)
}
func (req *getEpochBlockRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool {
return target.Version.GreaterThanOrEqual(MinVersion)
}
func (req *getEpochBlockRequest) Encode() ([]byte, error) {
msg := syncpb.MakeMessageFromRequest(req.pbReq)
return protobuf.Marshal(msg)
}
type getBlockNumberRequest struct {
pbReq *syncpb.Request
}
func newGetBlockNumberRequest() *getBlockNumberRequest {
pbReq := syncpb.MakeGetBlockNumberRequest()
return &getBlockNumberRequest{
pbReq: pbReq,
}
}
func (req *getBlockNumberRequest) ReqID() uint64 {
return req.pbReq.GetReqId()
}
func (req *getBlockNumberRequest) SetReqID(val uint64) {
req.pbReq.ReqId = val
}
func (req *getBlockNumberRequest) String() string {
return fmt.Sprintf("REQUEST [GetBlockNumber]")
}
func (req *getBlockNumberRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool {
return target.Version.GreaterThanOrEqual(MinVersion)
}
func (req *getBlockNumberRequest) Encode() ([]byte, error) {
msg := syncpb.MakeMessageFromRequest(req.pbReq)
return protobuf.Marshal(msg)
}
func (req *getBlockNumberRequest) getNumberFromResponse(resp sttypes.Response) (uint64, error) {
sResp, ok := resp.(*syncResponse)
if !ok || sResp == nil {
return 0, errors.New("not sync response")
}
if errResp := sResp.pb.GetErrorResponse(); errResp != nil {
return 0, errors.New(errResp.Error)
}
gnResp := sResp.pb.GetGetBlockNumberResponse()
if gnResp == nil {
return 0, errors.New("response not GetBlockNumber")
}
return gnResp.Number, nil
}
type getBlockHashesRequest struct {
bns []uint64
pbReq *syncpb.Request
}
func newGetBlockHashesRequest(bns []uint64) *getBlockHashesRequest {
pbReq := syncpb.MakeGetBlockHashesRequest(bns)
return &getBlockHashesRequest{
bns: bns,
pbReq: pbReq,
}
}
func (req *getBlockHashesRequest) ReqID() uint64 {
return req.pbReq.ReqId
}
func (req *getBlockHashesRequest) SetReqID(val uint64) {
req.pbReq.ReqId = val
}
func (req *getBlockHashesRequest) String() string {
ss := make([]string, 0, len(req.bns))
for _, bn := range req.bns {
ss = append(ss, strconv.Itoa(int(bn)))
}
bnsStr := strings.Join(ss, ",")
return fmt.Sprintf("REQUEST [GetBlockHashes: %s]", bnsStr)
}
func (req *getBlockHashesRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool {
return target.Version.GreaterThanOrEqual(MinVersion)
}
func (req *getBlockHashesRequest) Encode() ([]byte, error) {
msg := syncpb.MakeMessageFromRequest(req.pbReq)
return protobuf.Marshal(msg)
}
func (req *getBlockHashesRequest) getHashesFromResponse(resp sttypes.Response) ([]common.Hash, error) {
sResp, ok := resp.(*syncResponse)
if !ok || sResp == nil {
return nil, errors.New("not sync response")
}
if errResp := sResp.pb.GetErrorResponse(); errResp != nil {
return nil, errors.New(errResp.Error)
}
bhResp := sResp.pb.GetGetBlockHashesResponse()
if bhResp == nil {
return nil, errors.New("response not GetBlockHashes")
}
hashBytes := bhResp.Hashes
return bytesToHashes(hashBytes), nil
}
type getBlocksByHashesRequest struct {
hashes []common.Hash
pbReq *syncpb.Request
}
func newGetBlocksByHashesRequest(hashes []common.Hash) *getBlocksByHashesRequest {
pbReq := syncpb.MakeGetBlocksByHashesRequest(hashes)
return &getBlocksByHashesRequest{
hashes: hashes,
pbReq: pbReq,
}
}
func (req *getBlocksByHashesRequest) ReqID() uint64 {
return req.pbReq.GetReqId()
}
func (req *getBlocksByHashesRequest) SetReqID(val uint64) {
req.pbReq.ReqId = val
}
func (req *getBlocksByHashesRequest) String() string {
hashStrs := make([]string, 0, len(req.hashes))
for _, h := range req.hashes {
hashStrs = append(hashStrs, fmt.Sprintf("%x", h[:]))
}
hStr := strings.Join(hashStrs, ", ")
return fmt.Sprintf("REQUEST [GetBlocksByHashes: %v]", hStr)
}
func (req *getBlocksByHashesRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool {
return target.Version.GreaterThanOrEqual(MinVersion)
}
func (req *getBlocksByHashesRequest) Encode() ([]byte, error) {
msg := syncpb.MakeMessageFromRequest(req.pbReq)
return protobuf.Marshal(msg)
}
func (req *getBlocksByHashesRequest) getBlocksFromResponse(resp sttypes.Response) ([]*types.Block, error) {
sResp, ok := resp.(*syncResponse)
if !ok || sResp == nil {
return nil, errors.New("not sync response")
}
if errResp := sResp.pb.GetErrorResponse(); errResp != nil {
return nil, errors.New(errResp.Error)
}
bhResp := sResp.pb.GetGetBlocksByHashesResponse()
if bhResp == nil {
return nil, errors.New("response not GetBlocksByHashes")
}
var (
blockBytes = bhResp.BlocksBytes
sigs = bhResp.CommitSig
)
if len(blockBytes) != len(sigs) {
return nil, fmt.Errorf("sig size not expected: %v / %v", len(sigs), len(blockBytes))
}
blocks := make([]*types.Block, 0, len(blockBytes))
for i, bb := range blockBytes {
var block *types.Block
if err := rlp.DecodeBytes(bb, &block); err != nil {
return nil, errors.Wrap(err, "[GetBlocksByHashesResponse]")
}
if block != nil {
block.SetCurrentCommitSig(sigs[i])
}
blocks = append(blocks, block)
}
return blocks, nil
}

@ -0,0 +1,462 @@
package sync
import (
"context"
"errors"
"fmt"
"math/big"
"strings"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/rlp"
"github.com/harmony-one/harmony/block"
headerV3 "github.com/harmony-one/harmony/block/v3"
"github.com/harmony-one/harmony/core/types"
"github.com/harmony-one/harmony/p2p/stream/common/ratelimiter"
"github.com/harmony-one/harmony/p2p/stream/common/streammanager"
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
"github.com/harmony-one/harmony/shard"
)
var (
_ sttypes.Request = &getBlocksByNumberRequest{}
_ sttypes.Request = &getEpochBlockRequest{}
_ sttypes.Request = &getBlockNumberRequest{}
_ sttypes.Response = &syncResponse{&syncpb.Response{}}
)
var (
initStreamIDs = []sttypes.StreamID{
makeTestStreamID(0),
makeTestStreamID(1),
makeTestStreamID(2),
makeTestStreamID(3),
}
)
var (
testHeader = &block.Header{Header: headerV3.NewHeader()}
testBlock = types.NewBlockWithHeader(testHeader)
testHeaderBytes, _ = rlp.EncodeToBytes(testHeader)
testBlockBytes, _ = rlp.EncodeToBytes(testBlock)
testBlockResponse = syncpb.MakeGetBlocksByNumResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1))
testEpochState = &shard.State{
Epoch: new(big.Int).SetInt64(1),
Shards: []shard.Committee{},
}
testEpochStateBytes, _ = rlp.EncodeToBytes(testEpochState)
testEpochStateResponse = syncpb.MakeGetEpochStateResponse(0, testHeaderBytes, testEpochStateBytes)
testCurBlockNumber uint64 = 100
testBlockNumberResponse = syncpb.MakeGetBlockNumberResponse(0, testCurBlockNumber)
testHash = numberToHash(100)
testBlockHashesResponse = syncpb.MakeGetBlockHashesResponse(0, []common.Hash{testHash})
testBlocksByHashesResponse = syncpb.MakeGetBlocksByHashesResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1))
testErrorResponse = syncpb.MakeErrorResponse(0, errors.New("test error"))
)
func TestProtocol_GetBlocksByNumber(t *testing.T) {
tests := []struct {
getResponse getResponseFn
expErr error
expStID sttypes.StreamID
}{
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testBlockResponse,
}, makeTestStreamID(0)
},
expErr: nil,
expStID: makeTestStreamID(0),
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testEpochStateResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("not GetBlockByNumber"),
expStID: makeTestStreamID(0),
},
{
getResponse: nil,
expErr: errors.New("get response error"),
expStID: "",
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testErrorResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("test error"),
expStID: makeTestStreamID(0),
},
}
for i, test := range tests {
protocol := makeTestProtocol(test.getResponse)
blocks, stid, err := protocol.GetBlocksByNumber(context.Background(), []uint64{0})
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
continue
}
if stid != test.expStID {
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
}
if test.expErr == nil && (len(blocks) == 0) {
t.Errorf("Test %v: zero blocks delivered", i)
}
}
}
func TestProtocol_GetEpochState(t *testing.T) {
tests := []struct {
getResponse getResponseFn
expErr error
expStID sttypes.StreamID
}{
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testEpochStateResponse,
}, makeTestStreamID(0)
},
expErr: nil,
expStID: makeTestStreamID(0),
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testBlockResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("not GetEpochStateResponse"),
expStID: makeTestStreamID(0),
},
{
getResponse: nil,
expErr: errors.New("get response error"),
expStID: "",
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testErrorResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("test error"),
expStID: makeTestStreamID(0),
},
}
for i, test := range tests {
protocol := makeTestProtocol(test.getResponse)
res, stid, err := protocol.GetEpochState(context.Background(), 0)
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
continue
}
if stid != test.expStID {
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
}
if test.expErr == nil {
if gotEpoch := res.State.Epoch; gotEpoch.Cmp(new(big.Int).SetUint64(1)) != 0 {
t.Errorf("Test %v: unexpected epoch delivered: %v / %v", i, gotEpoch.String(), 1)
}
}
}
}
func TestProtocol_GetCurrentBlockNumber(t *testing.T) {
tests := []struct {
getResponse getResponseFn
expErr error
expStID sttypes.StreamID
}{
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testBlockNumberResponse,
}, makeTestStreamID(0)
},
expErr: nil,
expStID: makeTestStreamID(0),
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testBlockResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("not GetBlockNumber"),
expStID: makeTestStreamID(0),
},
{
getResponse: nil,
expErr: errors.New("get response error"),
expStID: "",
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testErrorResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("test error"),
expStID: makeTestStreamID(0),
},
}
for i, test := range tests {
protocol := makeTestProtocol(test.getResponse)
res, stid, err := protocol.GetCurrentBlockNumber(context.Background())
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
continue
}
if stid != test.expStID {
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
}
if test.expErr == nil {
if res != testCurBlockNumber {
t.Errorf("Test %v: block number not expected: %v / %v", i, res, testCurBlockNumber)
}
}
}
}
func TestProtocol_GetBlockHashes(t *testing.T) {
tests := []struct {
getResponse getResponseFn
expErr error
expStID sttypes.StreamID
}{
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testBlockHashesResponse,
}, makeTestStreamID(0)
},
expErr: nil,
expStID: makeTestStreamID(0),
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testBlockResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("not GetBlockHashes"),
expStID: makeTestStreamID(0),
},
{
getResponse: nil,
expErr: errors.New("get response error"),
expStID: "",
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testErrorResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("test error"),
expStID: makeTestStreamID(0),
},
}
for i, test := range tests {
protocol := makeTestProtocol(test.getResponse)
res, stid, err := protocol.GetBlockHashes(context.Background(), []uint64{100})
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
continue
}
if stid != test.expStID {
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
}
if test.expErr == nil {
if len(res) != 1 {
t.Errorf("Test %v: size not 1", i)
}
if res[0] != testHash {
t.Errorf("Test %v: hash not expected", i)
}
}
}
}
func TestProtocol_GetBlocksByHashes(t *testing.T) {
tests := []struct {
getResponse getResponseFn
expErr error
expStID sttypes.StreamID
}{
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testBlocksByHashesResponse,
}, makeTestStreamID(0)
},
expErr: nil,
expStID: makeTestStreamID(0),
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testBlockResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("not GetBlocksByHashes"),
expStID: makeTestStreamID(0),
},
{
getResponse: nil,
expErr: errors.New("get response error"),
expStID: "",
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testErrorResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("test error"),
expStID: makeTestStreamID(0),
},
}
for i, test := range tests {
protocol := makeTestProtocol(test.getResponse)
blocks, stid, err := protocol.GetBlocksByHashes(context.Background(), []common.Hash{numberToHash(100)})
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
continue
}
if stid != test.expStID {
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
}
if test.expErr == nil {
if len(blocks) != 1 {
t.Errorf("Test %v: size not 1", i)
}
}
}
}
type getResponseFn func(request sttypes.Request) (sttypes.Response, sttypes.StreamID)
type testHostRequestManager struct {
getResponse getResponseFn
}
func makeTestProtocol(f getResponseFn) *Protocol {
rm := &testHostRequestManager{f}
streamIDs := make([]sttypes.StreamID, len(initStreamIDs))
copy(streamIDs, initStreamIDs)
sm := &testStreamManager{streamIDs}
rl := ratelimiter.NewRateLimiter(10, 10)
return &Protocol{
rm: rm,
rl: rl,
sm: sm,
}
}
func (rm *testHostRequestManager) Start() {}
func (rm *testHostRequestManager) Close() {}
func (rm *testHostRequestManager) DeliverResponse(sttypes.StreamID, sttypes.Response) {}
func (rm *testHostRequestManager) DoRequest(ctx context.Context, request sttypes.Request, opts ...Option) (sttypes.Response, sttypes.StreamID, error) {
if rm.getResponse == nil {
return nil, "", errors.New("get response error")
}
resp, stid := rm.getResponse(request)
return resp, stid, nil
}
func makeTestStreamID(index int) sttypes.StreamID {
id := fmt.Sprintf("[test stream %v]", index)
return sttypes.StreamID(id)
}
// mock stream manager
type testStreamManager struct {
streamIDs []sttypes.StreamID
}
func (sm *testStreamManager) Start() {}
func (sm *testStreamManager) Close() {}
func (sm *testStreamManager) SubscribeAddStreamEvent(chan<- streammanager.EvtStreamAdded) event.Subscription {
return nil
}
func (sm *testStreamManager) SubscribeRemoveStreamEvent(chan<- streammanager.EvtStreamRemoved) event.Subscription {
return nil
}
func (sm *testStreamManager) NewStream(stream sttypes.Stream) error {
stid := stream.ID()
for _, id := range sm.streamIDs {
if id == stid {
return errors.New("stream already exist")
}
}
sm.streamIDs = append(sm.streamIDs, stid)
return nil
}
func (sm *testStreamManager) RemoveStream(stID sttypes.StreamID) error {
for i, id := range sm.streamIDs {
if id == stID {
sm.streamIDs = append(sm.streamIDs[:i], sm.streamIDs[i+1:]...)
}
}
return errors.New("stream not exist")
}
func (sm *testStreamManager) isStreamExist(stid sttypes.StreamID) bool {
for _, id := range sm.streamIDs {
if id == stid {
return true
}
}
return false
}
func (sm *testStreamManager) GetStreams() []sttypes.Stream {
return nil
}
func (sm *testStreamManager) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) {
return nil, false
}
func assertError(got, expect error) error {
if (got == nil) != (expect == nil) {
return fmt.Errorf("unexpected error: %v / %v", got, expect)
}
if got == nil {
return nil
}
if !strings.Contains(got.Error(), expect.Error()) {
return fmt.Errorf("unexpected error: %v/ %v", got, expect)
}
return nil
}

@ -0,0 +1,17 @@
package sync
import "time"
const (
// GetBlockHashesAmountCap is the cap of GetBlockHashes reqeust
GetBlockHashesAmountCap = 50
// GetBlocksByNumAmountCap is the cap of request of a single GetBlocksByNum request
GetBlocksByNumAmountCap = 10
// GetBlocksByHashesAmountCap is the cap of request of single GetBlocksByHashes request
GetBlocksByHashesAmountCap = 10
// minAdvertiseInterval is the minimum advertise interval
minAdvertiseInterval = 1 * time.Minute
)

@ -0,0 +1,260 @@
package sync
import (
"context"
"strconv"
"time"
"github.com/ethereum/go-ethereum/event"
"github.com/harmony-one/harmony/consensus/engine"
nodeconfig "github.com/harmony-one/harmony/internal/configs/node"
shardingconfig "github.com/harmony-one/harmony/internal/configs/sharding"
"github.com/harmony-one/harmony/internal/utils"
"github.com/harmony-one/harmony/p2p/discovery"
"github.com/harmony-one/harmony/p2p/stream/common/ratelimiter"
"github.com/harmony-one/harmony/p2p/stream/common/requestmanager"
"github.com/harmony-one/harmony/p2p/stream/common/streammanager"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
"github.com/hashicorp/go-version"
libp2p_host "github.com/libp2p/go-libp2p-core/host"
libp2p_network "github.com/libp2p/go-libp2p-core/network"
"github.com/rs/zerolog"
)
const (
// serviceSpecifier is the specifier for the service.
serviceSpecifier = "sync"
)
var (
version100, _ = version.NewVersion("1.0.0")
// MyVersion is the version of sync protocol of the local node
MyVersion = version100
// MinVersion is the minimum version for matching function
MinVersion = version100
)
type (
// Protocol is the protocol for sync streaming
Protocol struct {
chain engine.ChainReader // provide SYNC data
schedule shardingconfig.Schedule // provide schedule information
rl ratelimiter.RateLimiter // limit the incoming request rate
sm streammanager.StreamManager // stream management
rm requestmanager.RequestManager // deliver the response from stream
disc discovery.Discovery
config Config
logger zerolog.Logger
ctx context.Context
cancel func()
closeC chan struct{}
}
// Config is the sync protocol config
Config struct {
Chain engine.ChainReader
Host libp2p_host.Host
Discovery discovery.Discovery
ShardID nodeconfig.ShardID
Network nodeconfig.NetworkType
// stream manager config
SmSoftLowCap int
SmHardLowCap int
SmHiCap int
DiscBatch int
}
)
// NewProtocol creates a new sync protocol
func NewProtocol(config Config) *Protocol {
ctx, cancel := context.WithCancel(context.Background())
sp := &Protocol{
chain: config.Chain,
disc: config.Discovery,
config: config,
ctx: ctx,
cancel: cancel,
closeC: make(chan struct{}),
}
smConfig := streammanager.Config{
SoftLoCap: config.SmSoftLowCap,
HardLoCap: config.SmHardLowCap,
HiCap: config.SmHiCap,
DiscBatch: config.DiscBatch,
}
sp.sm = streammanager.NewStreamManager(sp.ProtoID(), config.Host, config.Discovery,
sp.HandleStream, smConfig)
sp.rl = ratelimiter.NewRateLimiter(sp.sm, 50, 10)
sp.rm = requestmanager.NewRequestManager(sp.sm)
sp.logger = utils.Logger().With().Str("Protocol", string(sp.ProtoID())).Logger()
return sp
}
// Start starts the sync protocol
func (p *Protocol) Start() {
p.sm.Start()
p.rm.Start()
p.rl.Start()
go p.advertiseLoop()
}
// Close close the protocol
func (p *Protocol) Close() {
p.rl.Close()
p.rm.Close()
p.sm.Close()
p.cancel()
close(p.closeC)
}
// Specifier return the specifier for the protocol
func (p *Protocol) Specifier() string {
return serviceSpecifier + "/" + strconv.Itoa(int(p.config.ShardID))
}
// ProtoID return the ProtoID of the sync protocol
func (p *Protocol) ProtoID() sttypes.ProtoID {
return p.protoIDByVersion(MyVersion)
}
// Version returns the sync protocol version
func (p *Protocol) Version() *version.Version {
return MyVersion
}
// Match checks the compatibility to the target protocol ID.
func (p *Protocol) Match(targetID string) bool {
target, err := sttypes.ProtoIDToProtoSpec(sttypes.ProtoID(targetID))
if err != nil {
return false
}
if target.Service != serviceSpecifier {
return false
}
if target.NetworkType != p.config.Network {
return false
}
if target.ShardID != p.config.ShardID {
return false
}
if target.Version.LessThan(MinVersion) {
return false
}
return true
}
// HandleStream is the stream handle function being registered to libp2p.
func (p *Protocol) HandleStream(raw libp2p_network.Stream) {
p.logger.Info().Str("stream", raw.ID()).Msg("handle new sync stream")
st := p.wrapStream(raw)
if err := p.sm.NewStream(st); err != nil {
// Possibly we have reach the hard limit of the stream
p.logger.Warn().Err(err).Str("stream ID", string(st.ID())).
Msg("failed to add new stream")
return
}
st.run()
}
func (p *Protocol) advertiseLoop() {
for {
sleep := p.advertise()
select {
case <-time.After(sleep):
case <-p.closeC:
return
}
}
}
// advertise will advertise all compatible protocol versions for helping nodes running low
// version
func (p *Protocol) advertise() time.Duration {
var nextWait time.Duration
pids := p.supportedProtoIDs()
for _, pid := range pids {
w, e := p.disc.Advertise(p.ctx, string(pid))
if e != nil {
p.logger.Warn().Err(e).Str("protocol", string(pid)).
Msg("cannot advertise sync protocol")
continue
}
if nextWait == 0 || nextWait > w {
nextWait = w
}
}
if nextWait < minAdvertiseInterval {
nextWait = minAdvertiseInterval
}
return nextWait
}
func (p *Protocol) supportedProtoIDs() []sttypes.ProtoID {
vs := p.supportedVersions()
pids := make([]sttypes.ProtoID, 0, len(vs))
for _, v := range vs {
pids = append(pids, p.protoIDByVersion(v))
}
return pids
}
func (p *Protocol) supportedVersions() []*version.Version {
return []*version.Version{version100}
}
func (p *Protocol) protoIDByVersion(v *version.Version) sttypes.ProtoID {
spec := sttypes.ProtoSpec{
Service: serviceSpecifier,
NetworkType: p.config.Network,
ShardID: p.config.ShardID,
Version: v,
}
return spec.ToProtoID()
}
// RemoveStream removes the stream of the given stream ID
func (p *Protocol) RemoveStream(stID sttypes.StreamID) {
if stID == "" {
return
}
st, exist := p.sm.GetStreamByID(stID)
if exist && st != nil {
st.Close()
}
}
// NumStreams return the streams with minimum version.
// Note: nodes with sync version smaller than minVersion is not counted.
func (p *Protocol) NumStreams() int {
res := 0
sts := p.sm.GetStreams()
for _, st := range sts {
ps, _ := st.ProtoSpec()
if ps.Version.GreaterThanOrEqual(MinVersion) {
res++
}
}
return res
}
// GetStreamManager get the underlying stream manager for upper level stream operations
func (p *Protocol) GetStreamManager() streammanager.StreamManager {
return p.sm
}
// SubscribeAddStreamEvent subscribe the stream add event
func (p *Protocol) SubscribeAddStreamEvent(ch chan<- streammanager.EvtStreamAdded) event.Subscription {
return p.sm.SubscribeAddStreamEvent(ch)
}

@ -0,0 +1,95 @@
package sync
import (
"context"
"testing"
"time"
"github.com/libp2p/go-libp2p-core/discovery"
libp2p_peer "github.com/libp2p/go-libp2p-core/peer"
)
func TestProtocol_Match(t *testing.T) {
tests := []struct {
targetID string
exp bool
}{
{"harmony/sync/unitest/0/1.0.1", true},
{"h123456", false},
{"harmony/sync/unitest/0/0.9.9", false},
{"harmony/epoch/unitest/0/1.0.1", false},
{"harmony/sync/mainnet/0/1.0.1", false},
{"harmony/sync/unitest/1/1.0.1", false},
}
for i, test := range tests {
p := &Protocol{
config: Config{
Network: "unitest",
ShardID: 0,
},
}
res := p.Match(test.targetID)
if res != test.exp {
t.Errorf("Test %v: unexpected result %v / %v", i, res, test.exp)
}
}
}
func TestProtocol_advertiseLoop(t *testing.T) {
disc := newTestDiscovery(100 * time.Millisecond)
p := &Protocol{
disc: disc,
closeC: make(chan struct{}),
}
go p.advertiseLoop()
time.Sleep(150 * time.Millisecond)
close(p.closeC)
if len(disc.advCnt) != len(p.supportedVersions()) {
t.Errorf("unexpected advertise topic count: %v / %v", len(disc.advCnt),
len(p.supportedVersions()))
}
for _, cnt := range disc.advCnt {
if cnt < 1 {
t.Errorf("unexpected discovery count: %v", cnt)
}
}
}
type testDiscovery struct {
advCnt map[string]int
sleep time.Duration
}
func newTestDiscovery(discInterval time.Duration) *testDiscovery {
return &testDiscovery{
advCnt: make(map[string]int),
sleep: discInterval,
}
}
func (disc *testDiscovery) Start() error {
return nil
}
func (disc *testDiscovery) Close() error {
return nil
}
func (disc *testDiscovery) Advertise(ctx context.Context, ns string) (time.Duration, error) {
disc.advCnt[ns]++
return disc.sleep, nil
}
func (disc *testDiscovery) FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) {
return nil, nil
}
func (disc *testDiscovery) GetRawDiscovery() discovery.Discovery {
return nil
}

@ -0,0 +1,377 @@
package sync
import (
"fmt"
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
protobuf "github.com/golang/protobuf/proto"
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
libp2p_network "github.com/libp2p/go-libp2p-core/network"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
// syncStream is the structure for a stream running sync protocol.
type syncStream struct {
// Basic stream
*sttypes.BaseStream
protocol *Protocol
chain chainHelper
// pipeline channels
reqC chan *syncpb.Request
respC chan *syncpb.Response
// close related fields. Concurrent call of close is possible.
closeC chan struct{}
closeStat uint32
logger zerolog.Logger
}
// wrapStream wraps the raw libp2p stream to syncStream
func (p *Protocol) wrapStream(raw libp2p_network.Stream) *syncStream {
bs := sttypes.NewBaseStream(raw)
logger := p.logger.With().
Str("ID", string(bs.ID())).
Str("Remote Protocol", string(bs.ProtoID())).
Logger()
return &syncStream{
BaseStream: bs,
protocol: p,
chain: newChainHelper(p.chain, p.schedule),
reqC: make(chan *syncpb.Request, 100),
respC: make(chan *syncpb.Response, 100),
closeC: make(chan struct{}),
closeStat: 0,
logger: logger,
}
}
func (st *syncStream) run() {
st.logger.Info().Str("StreamID", string(st.ID())).Msg("running sync protocol on stream")
defer st.logger.Info().Str("StreamID", string(st.ID())).Msg("end running sync protocol on stream")
go st.handleReqLoop()
go st.handleRespLoop()
st.readMsgLoop()
}
// readMsgLoop is the loop
func (st *syncStream) readMsgLoop() {
for {
msg, err := st.readMsg()
if err != nil {
if err := st.Close(); err != nil {
st.logger.Err(err).Msg("failed to close sync stream")
}
return
}
st.deliverMsg(msg)
}
}
// deliverMsg process the delivered message and forward to the corresponding channel
func (st *syncStream) deliverMsg(msg protobuf.Message) {
syncMsg := msg.(*syncpb.Message)
if syncMsg == nil {
st.logger.Info().Str("message", msg.String()).Msg("received unexpected sync message")
return
}
if req := syncMsg.GetReq(); req != nil {
go func() {
select {
case st.reqC <- req:
case <-time.After(1 * time.Minute):
st.logger.Warn().Str("request", req.String()).
Msg("request handler severely jammed, message dropped")
}
}()
}
if resp := syncMsg.GetResp(); resp != nil {
go func() {
select {
case st.respC <- resp:
case <-time.After(1 * time.Minute):
st.logger.Warn().Str("response", resp.String()).
Msg("response handler severely jammed, message dropped")
}
}()
}
return
}
func (st *syncStream) handleReqLoop() {
for {
select {
case req := <-st.reqC:
st.protocol.rl.LimitRequest(st.ID())
err := st.handleReq(req)
if err != nil {
st.logger.Info().Err(err).Str("request", req.String()).
Msg("handle request error. Closing stream")
if err := st.Close(); err != nil {
st.logger.Err(err).Msg("failed to close sync stream")
}
return
}
case <-st.closeC:
return
}
}
}
func (st *syncStream) handleRespLoop() {
for {
select {
case resp := <-st.respC:
st.handleResp(resp)
case <-st.closeC:
return
}
}
}
// Close stops the stream handling and closes the underlying stream
func (st *syncStream) Close() error {
notClosed := atomic.CompareAndSwapUint32(&st.closeStat, 0, 1)
if !notClosed {
// Already closed by another goroutine. Directly return
return nil
}
if err := st.protocol.sm.RemoveStream(st.ID()); err != nil {
st.logger.Err(err).Str("stream ID", string(st.ID())).
Msg("failed to remove sync stream on close")
}
close(st.closeC)
return st.BaseStream.Close()
}
// ResetOnClose reset the stream on close
func (st *syncStream) ResetOnClose() error {
notClosed := atomic.CompareAndSwapUint32(&st.closeStat, 0, 1)
if !notClosed {
// Already closed by another goroutine. Directly return
return nil
}
close(st.closeC)
return st.BaseStream.ResetOnClose()
}
func (st *syncStream) handleReq(req *syncpb.Request) error {
if gnReq := req.GetGetBlockNumberRequest(); gnReq != nil {
return st.handleGetBlockNumberRequest(req.ReqId)
}
if ghReq := req.GetGetBlockHashesRequest(); ghReq != nil {
return st.handleGetBlockHashesRequest(req.ReqId, ghReq)
}
if bnReq := req.GetGetBlocksByNumRequest(); bnReq != nil {
return st.handleGetBlocksByNumRequest(req.ReqId, bnReq)
}
if bhReq := req.GetGetBlocksByHashesRequest(); bhReq != nil {
return st.handleGetBlocksByHashesRequest(req.ReqId, bhReq)
}
if esReq := req.GetGetEpochStateRequest(); esReq != nil {
return st.handleEpochStateRequest(req.ReqId, esReq)
}
// unsupported request type
resp := syncpb.MakeErrorResponseMessage(req.ReqId, errUnknownReqType)
return st.writeMsg(resp)
}
func (st *syncStream) handleGetBlockNumberRequest(rid uint64) error {
resp := st.computeBlockNumberResp(rid)
if err := st.writeMsg(resp); err != nil {
return errors.Wrap(err, "[GetBlockNumber]: writeMsg")
}
return nil
}
func (st *syncStream) handleGetBlockHashesRequest(rid uint64, req *syncpb.GetBlockHashesRequest) error {
resp, err := st.computeGetBlockHashesResp(rid, req.Nums)
if err != nil {
resp = syncpb.MakeErrorResponseMessage(rid, err)
}
if writeErr := st.writeMsg(resp); writeErr != nil {
if err == nil {
err = writeErr
} else {
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
}
}
return errors.Wrap(err, "[GetBlockHashes]")
}
func (st *syncStream) handleGetBlocksByNumRequest(rid uint64, req *syncpb.GetBlocksByNumRequest) error {
resp, err := st.computeRespFromBlockNumber(rid, req.Nums)
if resp == nil && err != nil {
resp = syncpb.MakeErrorResponseMessage(rid, err)
}
if writeErr := st.writeMsg(resp); writeErr != nil {
if err == nil {
err = writeErr
} else {
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
}
}
return errors.Wrap(err, "[GetBlocksByNumber]")
}
func (st *syncStream) handleGetBlocksByHashesRequest(rid uint64, req *syncpb.GetBlocksByHashesRequest) error {
hashes := bytesToHashes(req.BlockHashes)
resp, err := st.computeRespFromBlockHashes(rid, hashes)
if resp == nil && err != nil {
resp = syncpb.MakeErrorResponseMessage(rid, err)
}
if writeErr := st.writeMsg(resp); writeErr != nil {
if err == nil {
err = writeErr
} else {
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
}
}
return errors.Wrap(err, "[GetBlocksByHashes]")
}
func (st *syncStream) handleEpochStateRequest(rid uint64, req *syncpb.GetEpochStateRequest) error {
resp, err := st.computeEpochStateResp(rid, req.Epoch)
if err != nil {
resp = syncpb.MakeErrorResponseMessage(rid, err)
}
if writeErr := st.writeMsg(resp); writeErr != nil {
if err == nil {
err = writeErr
} else {
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
}
}
return errors.Wrap(err, "[GetEpochState]")
}
func (st *syncStream) handleResp(resp *syncpb.Response) {
st.protocol.rm.DeliverResponse(st.ID(), &syncResponse{resp})
}
func (st *syncStream) readMsg() (*syncpb.Message, error) {
b, err := st.ReadBytes()
if err != nil {
return nil, err
}
var msg = &syncpb.Message{}
if err := protobuf.Unmarshal(b, msg); err != nil {
return nil, err
}
return msg, nil
}
func (st *syncStream) writeMsg(msg *syncpb.Message) error {
b, err := protobuf.Marshal(msg)
if err != nil {
return err
}
return st.WriteBytes(b)
}
func (st *syncStream) computeBlockNumberResp(rid uint64) *syncpb.Message {
bn := st.chain.getCurrentBlockNumber()
return syncpb.MakeGetBlockNumberResponseMessage(rid, bn)
}
func (st syncStream) computeGetBlockHashesResp(rid uint64, bns []uint64) (*syncpb.Message, error) {
if len(bns) > GetBlockHashesAmountCap {
err := fmt.Errorf("GetBlockHashes amount exceed cap: %v>%v", len(bns), GetBlockHashesAmountCap)
return nil, err
}
hashes := st.chain.getBlockHashes(bns)
return syncpb.MakeGetBlockHashesResponseMessage(rid, hashes), nil
}
func (st *syncStream) computeRespFromBlockNumber(rid uint64, bns []uint64) (*syncpb.Message, error) {
if len(bns) > GetBlocksByNumAmountCap {
err := fmt.Errorf("GetBlocksByNum amount exceed cap: %v>%v", len(bns), GetBlocksByNumAmountCap)
return nil, err
}
blocks, err := st.chain.getBlocksByNumber(bns)
if err != nil {
return nil, err
}
var (
blocksBytes = make([][]byte, 0, len(blocks))
sigs = make([][]byte, 0, len(blocks))
)
for _, block := range blocks {
bb, err := rlp.EncodeToBytes(block)
if err != nil {
return nil, err
}
blocksBytes = append(blocksBytes, bb)
var sig []byte
if block != nil {
sig = block.GetCurrentCommitSig()
}
sigs = append(sigs, sig)
}
return syncpb.MakeGetBlocksByNumResponseMessage(rid, blocksBytes, sigs), nil
}
func (st *syncStream) computeRespFromBlockHashes(rid uint64, hs []common.Hash) (*syncpb.Message, error) {
if len(hs) > GetBlocksByHashesAmountCap {
err := fmt.Errorf("GetBlockByHashes amount exceed cap: %v > %v", len(hs), GetBlocksByHashesAmountCap)
return nil, err
}
blocks, err := st.chain.getBlocksByHashes(hs)
if err != nil {
return nil, err
}
var (
blocksBytes = make([][]byte, 0, len(blocks))
sigs = make([][]byte, 0, len(blocks))
)
for _, block := range blocks {
bb, err := rlp.EncodeToBytes(block)
if err != nil {
return nil, err
}
blocksBytes = append(blocksBytes, bb)
var sig []byte
if block != nil {
sig = block.GetCurrentCommitSig()
}
sigs = append(sigs, sig)
}
return syncpb.MakeGetBlocksByHashesResponseMessage(rid, blocksBytes, sigs), nil
}
func (st *syncStream) computeEpochStateResp(rid uint64, epoch uint64) (*syncpb.Message, error) {
if epoch == 0 {
return nil, errors.New("Epoch 0 does not have shard state")
}
esRes, err := st.chain.getEpochState(epoch)
if err != nil {
return nil, err
}
return esRes.toMessage(rid)
}
func bytesToHashes(bs [][]byte) []common.Hash {
hs := make([]common.Hash, 0, len(bs))
for _, b := range bs {
var h common.Hash
copy(h[:], b)
hs = append(hs, h)
}
return hs
}

@ -0,0 +1,267 @@
package sync
import (
"bytes"
"context"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
protobuf "github.com/golang/protobuf/proto"
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
ic "github.com/libp2p/go-libp2p-core/crypto"
libp2p_network "github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/protocol"
ma "github.com/multiformats/go-multiaddr"
)
var _ sttypes.Protocol = &Protocol{}
var (
testGetBlockNumbers = []uint64{1, 2, 3, 4, 5}
testGetBlockRequest = syncpb.MakeGetBlocksByNumRequest(testGetBlockNumbers)
testGetBlockRequestMsg = syncpb.MakeMessageFromRequest(testGetBlockRequest)
testEpoch uint64 = 20
testEpochStateRequest = syncpb.MakeGetEpochStateRequest(testEpoch)
testEpochStateRequestMsg = syncpb.MakeMessageFromRequest(testEpochStateRequest)
testCurrentNumberRequest = syncpb.MakeGetBlockNumberRequest()
testCurrentNumberRequestMsg = syncpb.MakeMessageFromRequest(testCurrentNumberRequest)
testGetBlockHashNums = []uint64{1, 2, 3, 4, 5}
testGetBlockHashesRequest = syncpb.MakeGetBlockHashesRequest(testGetBlockHashNums)
testGetBlockHashesRequestMsg = syncpb.MakeMessageFromRequest(testGetBlockHashesRequest)
testGetBlockByHashes = []common.Hash{
numberToHash(1),
numberToHash(2),
numberToHash(3),
numberToHash(4),
numberToHash(5),
}
testGetBlocksByHashesRequest = syncpb.MakeGetBlocksByHashesRequest(testGetBlockByHashes)
testGetBlocksByHashesRequestMsg = syncpb.MakeMessageFromRequest(testGetBlocksByHashesRequest)
)
func TestSyncStream_HandleGetBlocksByRequest(t *testing.T) {
st, remoteSt := makeTestSyncStream()
go st.run()
defer close(st.closeC)
req := testGetBlockRequestMsg
b, _ := protobuf.Marshal(req)
err := remoteSt.WriteBytes(b)
if err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
receivedBytes, _ := remoteSt.ReadBytes()
if err := checkBlocksResult(testGetBlockNumbers, receivedBytes); err != nil {
t.Fatal(err)
}
}
func TestSyncStream_HandleEpochStateRequest(t *testing.T) {
st, remoteSt := makeTestSyncStream()
go st.run()
defer close(st.closeC)
req := testEpochStateRequestMsg
b, _ := protobuf.Marshal(req)
err := remoteSt.WriteBytes(b)
if err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
receivedBytes, _ := remoteSt.ReadBytes()
if err := checkEpochStateResult(testEpoch, receivedBytes); err != nil {
t.Fatal(err)
}
}
func TestSyncStream_HandleCurrentBlockNumber(t *testing.T) {
st, remoteSt := makeTestSyncStream()
go st.run()
defer close(st.closeC)
req := testCurrentNumberRequestMsg
b, _ := protobuf.Marshal(req)
err := remoteSt.WriteBytes(b)
if err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
receivedBytes, _ := remoteSt.ReadBytes()
if err := checkBlockNumberResult(receivedBytes); err != nil {
t.Fatal(err)
}
}
func TestSyncStream_HandleGetBlockHashes(t *testing.T) {
st, remoteSt := makeTestSyncStream()
go st.run()
defer close(st.closeC)
req := testGetBlockHashesRequestMsg
b, _ := protobuf.Marshal(req)
err := remoteSt.WriteBytes(b)
if err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
receivedBytes, _ := remoteSt.ReadBytes()
if err := checkBlockHashesResult(receivedBytes, testGetBlockNumbers); err != nil {
t.Fatal(err)
}
}
func TestSyncStream_HandleGetBlocksByHashes(t *testing.T) {
st, remoteSt := makeTestSyncStream()
go st.run()
defer close(st.closeC)
req := testGetBlocksByHashesRequestMsg
b, _ := protobuf.Marshal(req)
err := remoteSt.WriteBytes(b)
if err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
receivedBytes, _ := remoteSt.ReadBytes()
if err := checkBlocksByHashesResult(receivedBytes, testGetBlockByHashes); err != nil {
t.Fatal(err)
}
}
func makeTestSyncStream() (*syncStream, *testRemoteBaseStream) {
localRaw, remoteRaw := makePairP2PStreams()
remote := newTestRemoteBaseStream(remoteRaw)
bs := sttypes.NewBaseStream(localRaw)
return &syncStream{
BaseStream: bs,
chain: &testChainHelper{},
protocol: makeTestProtocol(nil),
reqC: make(chan *syncpb.Request, 100),
respC: make(chan *syncpb.Response, 100),
closeC: make(chan struct{}),
closeStat: 0,
}, remote
}
type testP2PStream struct {
readBuf *bytes.Buffer
inC chan struct{}
identity string
writeHook func([]byte) (int, error)
}
func makePairP2PStreams() (*testP2PStream, *testP2PStream) {
buf1 := bytes.NewBuffer(nil)
buf2 := bytes.NewBuffer(nil)
st1 := &testP2PStream{
readBuf: buf1,
inC: make(chan struct{}, 1),
identity: "local",
}
st2 := &testP2PStream{
readBuf: buf2,
inC: make(chan struct{}, 1),
identity: "remote",
}
st1.writeHook = st2.receiveBytes
st2.writeHook = st1.receiveBytes
return st1, st2
}
func (st *testP2PStream) Read(b []byte) (n int, err error) {
<-st.inC
n, err = st.readBuf.Read(b)
if st.readBuf.Len() != 0 {
select {
case st.inC <- struct{}{}:
default:
}
}
return
}
func (st *testP2PStream) Write(b []byte) (n int, err error) {
return st.writeHook(b)
}
func (st *testP2PStream) receiveBytes(b []byte) (n int, err error) {
n, err = st.readBuf.Write(b)
select {
case st.inC <- struct{}{}:
default:
}
return
}
func (st *testP2PStream) Close() error { return nil }
func (st *testP2PStream) CloseRead() error { return nil }
func (st *testP2PStream) CloseWrite() error { return nil }
func (st *testP2PStream) Reset() error { return nil }
func (st *testP2PStream) SetDeadline(time.Time) error { return nil }
func (st *testP2PStream) SetReadDeadline(time.Time) error { return nil }
func (st *testP2PStream) SetWriteDeadline(time.Time) error { return nil }
func (st *testP2PStream) ID() string { return "" }
func (st *testP2PStream) Protocol() protocol.ID { return "" }
func (st *testP2PStream) SetProtocol(protocol.ID) {}
func (st *testP2PStream) Stat() libp2p_network.Stat { return libp2p_network.Stat{} }
func (st *testP2PStream) Conn() libp2p_network.Conn { return &fakeConn{} }
type testRemoteBaseStream struct {
base *sttypes.BaseStream
}
func newTestRemoteBaseStream(st *testP2PStream) *testRemoteBaseStream {
rst := &testRemoteBaseStream{
base: sttypes.NewBaseStream(st),
}
return rst
}
func (st *testRemoteBaseStream) ReadBytes() ([]byte, error) {
return st.base.ReadBytes()
}
func (st *testRemoteBaseStream) WriteBytes(b []byte) error {
return st.base.WriteBytes(b)
}
type fakeConn struct{}
func (conn *fakeConn) Close() error { return nil }
func (conn *fakeConn) LocalPeer() peer.ID { return "" }
func (conn *fakeConn) LocalPrivateKey() ic.PrivKey { return nil }
func (conn *fakeConn) RemotePeer() peer.ID { return "" }
func (conn *fakeConn) RemotePublicKey() ic.PubKey { return nil }
func (conn *fakeConn) LocalMultiaddr() ma.Multiaddr { return nil }
func (conn *fakeConn) RemoteMultiaddr() ma.Multiaddr { return nil }
func (conn *fakeConn) ID() string { return "" }
func (conn *fakeConn) NewStream(context.Context) (libp2p_network.Stream, error) { return nil, nil }
func (conn *fakeConn) GetStreams() []libp2p_network.Stream { return nil }
func (conn *fakeConn) Stat() libp2p_network.Stat { return libp2p_network.Stat{} }

@ -0,0 +1,112 @@
package sync
import (
"fmt"
"github.com/ethereum/go-ethereum/rlp"
protobuf "github.com/golang/protobuf/proto"
"github.com/harmony-one/harmony/block"
"github.com/harmony-one/harmony/p2p/stream/common/requestmanager"
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
"github.com/harmony-one/harmony/shard"
"github.com/pkg/errors"
)
var (
errUnknownReqType = errors.New("unknown request")
)
// syncResponse is the sync protocol response which implements sttypes.Response
type syncResponse struct {
pb *syncpb.Response
}
// ReqID return the request ID of the response
func (resp *syncResponse) ReqID() uint64 {
return resp.pb.ReqId
}
// GetProtobufMsg return the raw protobuf message
func (resp *syncResponse) GetProtobufMsg() protobuf.Message {
return resp.pb
}
func (resp *syncResponse) String() string {
return fmt.Sprintf("[SyncResponse %v]", resp.pb.String())
}
// EpochStateResult is the result for GetEpochStateQuery
type EpochStateResult struct {
Header *block.Header
State *shard.State
}
func epochStateResultFromResponse(resp sttypes.Response) (*EpochStateResult, error) {
sResp, ok := resp.(*syncResponse)
if !ok || sResp == nil {
return nil, errors.New("not sync response")
}
if errResp := sResp.pb.GetErrorResponse(); errResp != nil {
return nil, errors.New(errResp.Error)
}
gesResp := sResp.pb.GetGetEpochStateResponse()
if gesResp == nil {
return nil, errors.New("not GetEpochStateResponse")
}
var (
headerBytes = gesResp.HeaderBytes
ssBytes = gesResp.ShardState
header *block.Header
ss *shard.State
)
if len(headerBytes) > 0 {
if err := rlp.DecodeBytes(headerBytes, &header); err != nil {
return nil, err
}
}
if len(ssBytes) > 0 {
// here shard state is not encoded with legacy rules
if err := rlp.DecodeBytes(ssBytes, &ss); err != nil {
return nil, err
}
}
return &EpochStateResult{
Header: header,
State: ss,
}, nil
}
func (res *EpochStateResult) toMessage(rid uint64) (*syncpb.Message, error) {
headerBytes, err := rlp.EncodeToBytes(res.Header)
if err != nil {
return nil, err
}
// Shard state is not wrapped here, means no legacy shard state encoding rule as
// in shard.EncodeWrapper.
ssBytes, err := rlp.EncodeToBytes(res.State)
if err != nil {
return nil, err
}
return syncpb.MakeGetEpochStateResponseMessage(rid, headerBytes, ssBytes), nil
}
// Option is the additional option to do requests.
// Currently, two options are supported:
// 1. WithHighPriority - do the request in high priority.
// 2. WithBlacklist - do the request without the given stream ids as blacklist
// 3. WithWhitelist - do the request only with the given stream ids
type Option = requestmanager.RequestOption
var (
// WithHighPriority instruct the request manager to do the request with high
// priority
WithHighPriority = requestmanager.WithHighPriority
// WithBlacklist instruct the request manager not to assign the request to the
// given streamID
WithBlacklist = requestmanager.WithBlacklist
// WithWhitelist instruct the request manager only to assign the request to the
// given streamID
WithWhitelist = requestmanager.WithWhitelist
)
Loading…
Cancel
Save