commit
876e07892d
@ -0,0 +1,135 @@ |
||||
package sync |
||||
|
||||
import ( |
||||
"github.com/ethereum/go-ethereum/common" |
||||
"github.com/harmony-one/harmony/block" |
||||
"github.com/harmony-one/harmony/consensus/engine" |
||||
"github.com/harmony-one/harmony/core/types" |
||||
shardingconfig "github.com/harmony-one/harmony/internal/configs/sharding" |
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// chainHelper is the adapter for blockchain which is friendly to unit test.
|
||||
type chainHelper interface { |
||||
getCurrentBlockNumber() uint64 |
||||
getBlockHashes(bns []uint64) []common.Hash |
||||
getBlocksByNumber(bns []uint64) ([]*types.Block, error) |
||||
getBlocksByHashes(hs []common.Hash) ([]*types.Block, error) |
||||
} |
||||
|
||||
type chainHelperImpl struct { |
||||
chain engine.ChainReader |
||||
schedule shardingconfig.Schedule |
||||
} |
||||
|
||||
func newChainHelper(chain engine.ChainReader, schedule shardingconfig.Schedule) *chainHelperImpl { |
||||
return &chainHelperImpl{ |
||||
chain: chain, |
||||
schedule: schedule, |
||||
} |
||||
} |
||||
|
||||
func (ch *chainHelperImpl) getCurrentBlockNumber() uint64 { |
||||
return ch.chain.CurrentBlock().NumberU64() |
||||
} |
||||
|
||||
func (ch *chainHelperImpl) getBlockHashes(bns []uint64) []common.Hash { |
||||
hashes := make([]common.Hash, 0, len(bns)) |
||||
for _, bn := range bns { |
||||
var h common.Hash |
||||
header := ch.chain.GetHeaderByNumber(bn) |
||||
if header != nil { |
||||
h = header.Hash() |
||||
} |
||||
hashes = append(hashes, h) |
||||
} |
||||
return hashes |
||||
} |
||||
|
||||
func (ch *chainHelperImpl) getBlocksByNumber(bns []uint64) ([]*types.Block, error) { |
||||
var ( |
||||
blocks = make([]*types.Block, 0, len(bns)) |
||||
) |
||||
for _, bn := range bns { |
||||
var ( |
||||
block *types.Block |
||||
err error |
||||
) |
||||
header := ch.chain.GetHeaderByNumber(bn) |
||||
if header != nil { |
||||
block, err = ch.getBlockWithSigByHeader(header) |
||||
if err != nil { |
||||
return nil, errors.Wrapf(err, "get block %v at %v", header.Hash().String(), header.Number()) |
||||
} |
||||
} |
||||
blocks = append(blocks, block) |
||||
|
||||
} |
||||
return blocks, nil |
||||
} |
||||
|
||||
func (ch *chainHelperImpl) getBlocksByHashes(hs []common.Hash) ([]*types.Block, error) { |
||||
var ( |
||||
blocks = make([]*types.Block, 0, len(hs)) |
||||
) |
||||
for _, h := range hs { |
||||
var ( |
||||
block *types.Block |
||||
err error |
||||
) |
||||
header := ch.chain.GetHeaderByHash(h) |
||||
if header != nil { |
||||
block, err = ch.getBlockWithSigByHeader(header) |
||||
if err != nil { |
||||
return nil, errors.Wrapf(err, "get block %v at %v", header.Hash().String(), header.Number()) |
||||
} |
||||
} |
||||
blocks = append(blocks, block) |
||||
} |
||||
return blocks, nil |
||||
} |
||||
|
||||
var errBlockNotFound = errors.New("block not found") |
||||
|
||||
func (ch *chainHelperImpl) getBlockWithSigByHeader(header *block.Header) (*types.Block, error) { |
||||
b := ch.chain.GetBlock(header.Hash(), header.Number().Uint64()) |
||||
if b == nil { |
||||
return nil, nil |
||||
} |
||||
commitSig, err := ch.getBlockSigAndBitmap(header) |
||||
if err != nil { |
||||
return nil, errors.New("missing commit signature") |
||||
} |
||||
b.SetCurrentCommitSig(commitSig) |
||||
return b, nil |
||||
} |
||||
|
||||
func (ch *chainHelperImpl) getBlockSigAndBitmap(header *block.Header) ([]byte, error) { |
||||
sb := ch.getBlockSigFromNextBlock(header) |
||||
if len(sb) != 0 { |
||||
return sb, nil |
||||
} |
||||
// Note: some commit sig read from db is different from [nextHeader.sig, nextHeader.bitMap]
|
||||
// nextBlock data is better to be used.
|
||||
return ch.getBlockSigFromDB(header) |
||||
} |
||||
|
||||
func (ch *chainHelperImpl) getBlockSigFromNextBlock(header *block.Header) []byte { |
||||
nextBN := header.Number().Uint64() + 1 |
||||
nextHeader := ch.chain.GetHeaderByNumber(nextBN) |
||||
if nextHeader == nil { |
||||
return nil |
||||
} |
||||
|
||||
sigBytes := nextHeader.LastCommitSignature() |
||||
bitMap := nextHeader.LastCommitBitmap() |
||||
sb := make([]byte, len(sigBytes)+len(bitMap)) |
||||
copy(sb[:], sigBytes[:]) |
||||
copy(sb[len(sigBytes):], bitMap[:]) |
||||
|
||||
return sb |
||||
} |
||||
|
||||
func (ch *chainHelperImpl) getBlockSigFromDB(header *block.Header) ([]byte, error) { |
||||
return ch.chain.ReadCommitSig(header.Number().Uint64()) |
||||
} |
@ -0,0 +1,166 @@ |
||||
package sync |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"math/big" |
||||
|
||||
"github.com/ethereum/go-ethereum/common" |
||||
"github.com/ethereum/go-ethereum/rlp" |
||||
protobuf "github.com/golang/protobuf/proto" |
||||
"github.com/harmony-one/harmony/block" |
||||
"github.com/harmony-one/harmony/core/types" |
||||
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" |
||||
) |
||||
|
||||
type testChainHelper struct{} |
||||
|
||||
func (tch *testChainHelper) getCurrentBlockNumber() uint64 { |
||||
return 100 |
||||
} |
||||
|
||||
func (tch *testChainHelper) getBlocksByNumber(bns []uint64) ([]*types.Block, error) { |
||||
blocks := make([]*types.Block, 0, len(bns)) |
||||
for _, bn := range bns { |
||||
blocks = append(blocks, makeTestBlock(bn)) |
||||
} |
||||
return blocks, nil |
||||
} |
||||
|
||||
func (tch *testChainHelper) getBlockHashes(bns []uint64) []common.Hash { |
||||
hs := make([]common.Hash, 0, len(bns)) |
||||
for _, bn := range bns { |
||||
hs = append(hs, numberToHash(bn)) |
||||
} |
||||
return hs |
||||
} |
||||
|
||||
func (tch *testChainHelper) getBlocksByHashes(hs []common.Hash) ([]*types.Block, error) { |
||||
bs := make([]*types.Block, 0, len(hs)) |
||||
for _, h := range hs { |
||||
bn := hashToNumber(h) |
||||
bs = append(bs, makeTestBlock(bn)) |
||||
} |
||||
return bs, nil |
||||
} |
||||
|
||||
func numberToHash(bn uint64) common.Hash { |
||||
var h common.Hash |
||||
binary.LittleEndian.PutUint64(h[:], bn) |
||||
return h |
||||
} |
||||
|
||||
func hashToNumber(h common.Hash) uint64 { |
||||
return binary.LittleEndian.Uint64(h[:]) |
||||
} |
||||
|
||||
func checkBlocksResult(bns []uint64, b []byte) error { |
||||
var msg = &syncpb.Message{} |
||||
if err := protobuf.Unmarshal(b, msg); err != nil { |
||||
return err |
||||
} |
||||
gbResp, err := msg.GetBlocksByNumberResponse() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if len(gbResp.BlocksBytes) == 0 { |
||||
return errors.New("nil response from GetBlocksByNumber") |
||||
} |
||||
blocks, err := decodeBlocksBytes(gbResp.BlocksBytes) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if len(blocks) != len(bns) { |
||||
return errors.New("unexpected blocks number") |
||||
} |
||||
for i, bn := range bns { |
||||
blk := blocks[i] |
||||
if bn != blk.NumberU64() { |
||||
return errors.New("unexpected number of a block") |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func makeTestBlock(bn uint64) *types.Block { |
||||
header := testHeader.Copy() |
||||
header.SetNumber(big.NewInt(int64(bn))) |
||||
return types.NewBlockWithHeader(&block.Header{Header: header}) |
||||
} |
||||
|
||||
func decodeBlocksBytes(bbs [][]byte) ([]*types.Block, error) { |
||||
blocks := make([]*types.Block, 0, len(bbs)) |
||||
|
||||
for _, bb := range bbs { |
||||
var block *types.Block |
||||
if err := rlp.DecodeBytes(bb, &block); err != nil { |
||||
return nil, err |
||||
} |
||||
blocks = append(blocks, block) |
||||
} |
||||
return blocks, nil |
||||
} |
||||
|
||||
func checkBlockNumberResult(b []byte) error { |
||||
var msg = &syncpb.Message{} |
||||
if err := protobuf.Unmarshal(b, msg); err != nil { |
||||
return err |
||||
} |
||||
gnResp, err := msg.GetBlockNumberResponse() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if gnResp.Number != testCurBlockNumber { |
||||
return fmt.Errorf("unexpected block number: %v / %v", gnResp.Number, testCurBlockNumber) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func checkBlockHashesResult(b []byte, bns []uint64) error { |
||||
var msg = &syncpb.Message{} |
||||
if err := protobuf.Unmarshal(b, msg); err != nil { |
||||
return err |
||||
} |
||||
bhResp, err := msg.GetBlockHashesResponse() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
got := bhResp.Hashes |
||||
if len(got) != len(bns) { |
||||
return errors.New("unexpected size") |
||||
} |
||||
for i, bn := range bns { |
||||
expect := numberToHash(bn) |
||||
if !bytes.Equal(expect[:], got[i]) { |
||||
return errors.New("unexpected hash") |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func checkBlocksByHashesResult(b []byte, hs []common.Hash) error { |
||||
var msg = &syncpb.Message{} |
||||
if err := protobuf.Unmarshal(b, msg); err != nil { |
||||
return err |
||||
} |
||||
bhResp, err := msg.GetBlocksByHashesResponse() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if len(hs) != len(bhResp.BlocksBytes) { |
||||
return errors.New("unexpected size") |
||||
} |
||||
for i, h := range hs { |
||||
num := hashToNumber(h) |
||||
var blk *types.Block |
||||
if err := rlp.DecodeBytes(bhResp.BlocksBytes[i], &blk); err != nil { |
||||
return err |
||||
} |
||||
if blk.NumberU64() != num { |
||||
return fmt.Errorf("unexpected number %v != %v", blk.NumberU64(), num) |
||||
} |
||||
} |
||||
return nil |
||||
} |
@ -0,0 +1,344 @@ |
||||
package sync |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"github.com/ethereum/go-ethereum/common" |
||||
"github.com/ethereum/go-ethereum/rlp" |
||||
protobuf "github.com/golang/protobuf/proto" |
||||
"github.com/harmony-one/harmony/core/types" |
||||
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" |
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// GetBlocksByNumber do getBlocksByNumberRequest through sync stream protocol.
|
||||
// Return the block as result, target stream id, and error
|
||||
func (p *Protocol) GetBlocksByNumber(ctx context.Context, bns []uint64, opts ...Option) ([]*types.Block, sttypes.StreamID, error) { |
||||
if len(bns) > GetBlocksByNumAmountCap { |
||||
return nil, "", fmt.Errorf("number of blocks exceed cap of %v", GetBlocksByNumAmountCap) |
||||
} |
||||
|
||||
req := newGetBlocksByNumberRequest(bns) |
||||
resp, stid, err := p.rm.DoRequest(ctx, req, opts...) |
||||
if err != nil { |
||||
// At this point, error can be context canceled, context timed out, or waiting queue
|
||||
// is already full.
|
||||
return nil, stid, err |
||||
} |
||||
|
||||
// Parse and return blocks
|
||||
blocks, err := req.getBlocksFromResponse(resp) |
||||
if err != nil { |
||||
return nil, stid, err |
||||
} |
||||
return blocks, stid, nil |
||||
} |
||||
|
||||
// GetCurrentBlockNumber get the current block number from remote node
|
||||
func (p *Protocol) GetCurrentBlockNumber(ctx context.Context, opts ...Option) (uint64, sttypes.StreamID, error) { |
||||
req := newGetBlockNumberRequest() |
||||
|
||||
resp, stid, err := p.rm.DoRequest(ctx, req, opts...) |
||||
if err != nil { |
||||
return 0, stid, err |
||||
} |
||||
|
||||
bn, err := req.getNumberFromResponse(resp) |
||||
if err != nil { |
||||
return bn, stid, err |
||||
} |
||||
return bn, stid, nil |
||||
} |
||||
|
||||
// GetBlockHashes do getBlockHashesRequest through sync stream protocol.
|
||||
// Return the hash of the given block number. If a block is unknown, the hash will be emptyHash.
|
||||
func (p *Protocol) GetBlockHashes(ctx context.Context, bns []uint64, opts ...Option) ([]common.Hash, sttypes.StreamID, error) { |
||||
if len(bns) > GetBlockHashesAmountCap { |
||||
return nil, "", fmt.Errorf("number of requested numbers exceed limit") |
||||
} |
||||
|
||||
req := newGetBlockHashesRequest(bns) |
||||
resp, stid, err := p.rm.DoRequest(ctx, req, opts...) |
||||
if err != nil { |
||||
return nil, stid, err |
||||
} |
||||
hashes, err := req.getHashesFromResponse(resp) |
||||
if err != nil { |
||||
return nil, stid, err |
||||
} |
||||
return hashes, stid, nil |
||||
} |
||||
|
||||
// GetBlocksByHashes do getBlocksByHashesRequest through sync stream protocol.
|
||||
func (p *Protocol) GetBlocksByHashes(ctx context.Context, hs []common.Hash, opts ...Option) ([]*types.Block, sttypes.StreamID, error) { |
||||
if len(hs) > GetBlocksByHashesAmountCap { |
||||
return nil, "", fmt.Errorf("number of requested hashes exceed limit") |
||||
} |
||||
req := newGetBlocksByHashesRequest(hs) |
||||
resp, stid, err := p.rm.DoRequest(ctx, req, opts...) |
||||
if err != nil { |
||||
return nil, stid, err |
||||
} |
||||
blocks, err := req.getBlocksFromResponse(resp) |
||||
if err != nil { |
||||
return nil, stid, err |
||||
} |
||||
return blocks, stid, nil |
||||
} |
||||
|
||||
// getBlocksByNumberRequest is the request for get block by numbers which implements
|
||||
// sttypes.Request interface
|
||||
type getBlocksByNumberRequest struct { |
||||
bns []uint64 |
||||
pbReq *syncpb.Request |
||||
} |
||||
|
||||
func newGetBlocksByNumberRequest(bns []uint64) *getBlocksByNumberRequest { |
||||
pbReq := syncpb.MakeGetBlocksByNumRequest(bns) |
||||
return &getBlocksByNumberRequest{ |
||||
bns: bns, |
||||
pbReq: pbReq, |
||||
} |
||||
} |
||||
|
||||
func (req *getBlocksByNumberRequest) ReqID() uint64 { |
||||
return req.pbReq.GetReqId() |
||||
} |
||||
|
||||
func (req *getBlocksByNumberRequest) SetReqID(val uint64) { |
||||
req.pbReq.ReqId = val |
||||
} |
||||
|
||||
func (req *getBlocksByNumberRequest) String() string { |
||||
ss := make([]string, 0, len(req.bns)) |
||||
for _, bn := range req.bns { |
||||
ss = append(ss, strconv.Itoa(int(bn))) |
||||
} |
||||
bnsStr := strings.Join(ss, ",") |
||||
return fmt.Sprintf("REQUEST [GetBlockByNumber: %s]", bnsStr) |
||||
} |
||||
|
||||
func (req *getBlocksByNumberRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool { |
||||
return target.Version.GreaterThanOrEqual(MinVersion) |
||||
} |
||||
|
||||
func (req *getBlocksByNumberRequest) Encode() ([]byte, error) { |
||||
msg := syncpb.MakeMessageFromRequest(req.pbReq) |
||||
return protobuf.Marshal(msg) |
||||
} |
||||
|
||||
func (req *getBlocksByNumberRequest) getBlocksFromResponse(resp sttypes.Response) ([]*types.Block, error) { |
||||
sResp, ok := resp.(*syncResponse) |
||||
if !ok || sResp == nil { |
||||
return nil, errors.New("not sync response") |
||||
} |
||||
blockBytes, sigs, err := req.parseBlockBytesAndSigs(sResp) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
blocks := make([]*types.Block, 0, len(blockBytes)) |
||||
for i, bb := range blockBytes { |
||||
var block *types.Block |
||||
if err := rlp.DecodeBytes(bb, &block); err != nil { |
||||
return nil, errors.Wrap(err, "[GetBlocksByNumResponse]") |
||||
} |
||||
if block != nil { |
||||
block.SetCurrentCommitSig(sigs[i]) |
||||
} |
||||
blocks = append(blocks, block) |
||||
} |
||||
return blocks, nil |
||||
} |
||||
|
||||
func (req *getBlocksByNumberRequest) parseBlockBytesAndSigs(resp *syncResponse) ([][]byte, [][]byte, error) { |
||||
if errResp := resp.pb.GetErrorResponse(); errResp != nil { |
||||
return nil, nil, errors.New(errResp.Error) |
||||
} |
||||
gbResp := resp.pb.GetGetBlocksByNumResponse() |
||||
if gbResp == nil { |
||||
return nil, nil, errors.New("response not GetBlockByNumber") |
||||
} |
||||
if len(gbResp.BlocksBytes) != len(gbResp.CommitSig) { |
||||
return nil, nil, fmt.Errorf("commit sigs size not expected: %v / %v", |
||||
len(gbResp.CommitSig), len(gbResp.BlocksBytes)) |
||||
} |
||||
return gbResp.BlocksBytes, gbResp.CommitSig, nil |
||||
} |
||||
|
||||
type getBlockNumberRequest struct { |
||||
pbReq *syncpb.Request |
||||
} |
||||
|
||||
func newGetBlockNumberRequest() *getBlockNumberRequest { |
||||
pbReq := syncpb.MakeGetBlockNumberRequest() |
||||
return &getBlockNumberRequest{ |
||||
pbReq: pbReq, |
||||
} |
||||
} |
||||
|
||||
func (req *getBlockNumberRequest) ReqID() uint64 { |
||||
return req.pbReq.GetReqId() |
||||
} |
||||
|
||||
func (req *getBlockNumberRequest) SetReqID(val uint64) { |
||||
req.pbReq.ReqId = val |
||||
} |
||||
|
||||
func (req *getBlockNumberRequest) String() string { |
||||
return fmt.Sprintf("REQUEST [GetBlockNumber]") |
||||
} |
||||
|
||||
func (req *getBlockNumberRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool { |
||||
return target.Version.GreaterThanOrEqual(MinVersion) |
||||
} |
||||
|
||||
func (req *getBlockNumberRequest) Encode() ([]byte, error) { |
||||
msg := syncpb.MakeMessageFromRequest(req.pbReq) |
||||
return protobuf.Marshal(msg) |
||||
} |
||||
|
||||
func (req *getBlockNumberRequest) getNumberFromResponse(resp sttypes.Response) (uint64, error) { |
||||
sResp, ok := resp.(*syncResponse) |
||||
if !ok || sResp == nil { |
||||
return 0, errors.New("not sync response") |
||||
} |
||||
if errResp := sResp.pb.GetErrorResponse(); errResp != nil { |
||||
return 0, errors.New(errResp.Error) |
||||
} |
||||
gnResp := sResp.pb.GetGetBlockNumberResponse() |
||||
if gnResp == nil { |
||||
return 0, errors.New("response not GetBlockNumber") |
||||
} |
||||
return gnResp.Number, nil |
||||
} |
||||
|
||||
type getBlockHashesRequest struct { |
||||
bns []uint64 |
||||
pbReq *syncpb.Request |
||||
} |
||||
|
||||
func newGetBlockHashesRequest(bns []uint64) *getBlockHashesRequest { |
||||
pbReq := syncpb.MakeGetBlockHashesRequest(bns) |
||||
return &getBlockHashesRequest{ |
||||
bns: bns, |
||||
pbReq: pbReq, |
||||
} |
||||
} |
||||
|
||||
func (req *getBlockHashesRequest) ReqID() uint64 { |
||||
return req.pbReq.ReqId |
||||
} |
||||
|
||||
func (req *getBlockHashesRequest) SetReqID(val uint64) { |
||||
req.pbReq.ReqId = val |
||||
} |
||||
|
||||
func (req *getBlockHashesRequest) String() string { |
||||
ss := make([]string, 0, len(req.bns)) |
||||
for _, bn := range req.bns { |
||||
ss = append(ss, strconv.Itoa(int(bn))) |
||||
} |
||||
bnsStr := strings.Join(ss, ",") |
||||
return fmt.Sprintf("REQUEST [GetBlockHashes: %s]", bnsStr) |
||||
} |
||||
|
||||
func (req *getBlockHashesRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool { |
||||
return target.Version.GreaterThanOrEqual(MinVersion) |
||||
} |
||||
|
||||
func (req *getBlockHashesRequest) Encode() ([]byte, error) { |
||||
msg := syncpb.MakeMessageFromRequest(req.pbReq) |
||||
return protobuf.Marshal(msg) |
||||
} |
||||
|
||||
func (req *getBlockHashesRequest) getHashesFromResponse(resp sttypes.Response) ([]common.Hash, error) { |
||||
sResp, ok := resp.(*syncResponse) |
||||
if !ok || sResp == nil { |
||||
return nil, errors.New("not sync response") |
||||
} |
||||
if errResp := sResp.pb.GetErrorResponse(); errResp != nil { |
||||
return nil, errors.New(errResp.Error) |
||||
} |
||||
bhResp := sResp.pb.GetGetBlockHashesResponse() |
||||
if bhResp == nil { |
||||
return nil, errors.New("response not GetBlockHashes") |
||||
} |
||||
hashBytes := bhResp.Hashes |
||||
return bytesToHashes(hashBytes), nil |
||||
} |
||||
|
||||
type getBlocksByHashesRequest struct { |
||||
hashes []common.Hash |
||||
pbReq *syncpb.Request |
||||
} |
||||
|
||||
func newGetBlocksByHashesRequest(hashes []common.Hash) *getBlocksByHashesRequest { |
||||
pbReq := syncpb.MakeGetBlocksByHashesRequest(hashes) |
||||
return &getBlocksByHashesRequest{ |
||||
hashes: hashes, |
||||
pbReq: pbReq, |
||||
} |
||||
} |
||||
|
||||
func (req *getBlocksByHashesRequest) ReqID() uint64 { |
||||
return req.pbReq.GetReqId() |
||||
} |
||||
|
||||
func (req *getBlocksByHashesRequest) SetReqID(val uint64) { |
||||
req.pbReq.ReqId = val |
||||
} |
||||
|
||||
func (req *getBlocksByHashesRequest) String() string { |
||||
hashStrs := make([]string, 0, len(req.hashes)) |
||||
for _, h := range req.hashes { |
||||
hashStrs = append(hashStrs, fmt.Sprintf("%x", h[:])) |
||||
} |
||||
hStr := strings.Join(hashStrs, ", ") |
||||
return fmt.Sprintf("REQUEST [GetBlocksByHashes: %v]", hStr) |
||||
} |
||||
|
||||
func (req *getBlocksByHashesRequest) IsSupportedByProto(target sttypes.ProtoSpec) bool { |
||||
return target.Version.GreaterThanOrEqual(MinVersion) |
||||
} |
||||
|
||||
func (req *getBlocksByHashesRequest) Encode() ([]byte, error) { |
||||
msg := syncpb.MakeMessageFromRequest(req.pbReq) |
||||
return protobuf.Marshal(msg) |
||||
} |
||||
|
||||
func (req *getBlocksByHashesRequest) getBlocksFromResponse(resp sttypes.Response) ([]*types.Block, error) { |
||||
sResp, ok := resp.(*syncResponse) |
||||
if !ok || sResp == nil { |
||||
return nil, errors.New("not sync response") |
||||
} |
||||
if errResp := sResp.pb.GetErrorResponse(); errResp != nil { |
||||
return nil, errors.New(errResp.Error) |
||||
} |
||||
bhResp := sResp.pb.GetGetBlocksByHashesResponse() |
||||
if bhResp == nil { |
||||
return nil, errors.New("response not GetBlocksByHashes") |
||||
} |
||||
var ( |
||||
blockBytes = bhResp.BlocksBytes |
||||
sigs = bhResp.CommitSig |
||||
) |
||||
if len(blockBytes) != len(sigs) { |
||||
return nil, fmt.Errorf("sig size not expected: %v / %v", len(sigs), len(blockBytes)) |
||||
} |
||||
blocks := make([]*types.Block, 0, len(blockBytes)) |
||||
for i, bb := range blockBytes { |
||||
var block *types.Block |
||||
if err := rlp.DecodeBytes(bb, &block); err != nil { |
||||
return nil, errors.Wrap(err, "[GetBlocksByHashesResponse]") |
||||
} |
||||
if block != nil { |
||||
block.SetCurrentCommitSig(sigs[i]) |
||||
} |
||||
blocks = append(blocks, block) |
||||
} |
||||
return blocks, nil |
||||
} |
@ -0,0 +1,393 @@ |
||||
package sync |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"fmt" |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/ethereum/go-ethereum/common" |
||||
"github.com/ethereum/go-ethereum/event" |
||||
"github.com/ethereum/go-ethereum/rlp" |
||||
"github.com/harmony-one/harmony/block" |
||||
headerV3 "github.com/harmony-one/harmony/block/v3" |
||||
"github.com/harmony-one/harmony/core/types" |
||||
"github.com/harmony-one/harmony/p2p/stream/common/ratelimiter" |
||||
"github.com/harmony-one/harmony/p2p/stream/common/streammanager" |
||||
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" |
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
) |
||||
|
||||
var ( |
||||
_ sttypes.Request = &getBlocksByNumberRequest{} |
||||
_ sttypes.Request = &getBlockNumberRequest{} |
||||
_ sttypes.Response = &syncResponse{&syncpb.Response{}} |
||||
) |
||||
|
||||
var ( |
||||
initStreamIDs = []sttypes.StreamID{ |
||||
makeTestStreamID(0), |
||||
makeTestStreamID(1), |
||||
makeTestStreamID(2), |
||||
makeTestStreamID(3), |
||||
} |
||||
) |
||||
|
||||
var ( |
||||
testHeader = &block.Header{Header: headerV3.NewHeader()} |
||||
testBlock = types.NewBlockWithHeader(testHeader) |
||||
testHeaderBytes, _ = rlp.EncodeToBytes(testHeader) |
||||
testBlockBytes, _ = rlp.EncodeToBytes(testBlock) |
||||
testBlockResponse = syncpb.MakeGetBlocksByNumResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1)) |
||||
|
||||
testCurBlockNumber uint64 = 100 |
||||
testBlockNumberResponse = syncpb.MakeGetBlockNumberResponse(0, testCurBlockNumber) |
||||
|
||||
testHash = numberToHash(100) |
||||
testBlockHashesResponse = syncpb.MakeGetBlockHashesResponse(0, []common.Hash{testHash}) |
||||
|
||||
testBlocksByHashesResponse = syncpb.MakeGetBlocksByHashesResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1)) |
||||
|
||||
testErrorResponse = syncpb.MakeErrorResponse(0, errors.New("test error")) |
||||
) |
||||
|
||||
func TestProtocol_GetBlocksByNumber(t *testing.T) { |
||||
tests := []struct { |
||||
getResponse getResponseFn |
||||
expErr error |
||||
expStID sttypes.StreamID |
||||
}{ |
||||
{ |
||||
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||
return &syncResponse{ |
||||
pb: testBlockResponse, |
||||
}, makeTestStreamID(0) |
||||
}, |
||||
expErr: nil, |
||||
expStID: makeTestStreamID(0), |
||||
}, |
||||
{ |
||||
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||
return &syncResponse{ |
||||
pb: testBlockNumberResponse, |
||||
}, makeTestStreamID(0) |
||||
}, |
||||
expErr: errors.New("not GetBlockByNumber"), |
||||
expStID: makeTestStreamID(0), |
||||
}, |
||||
{ |
||||
getResponse: nil, |
||||
expErr: errors.New("get response error"), |
||||
expStID: "", |
||||
}, |
||||
{ |
||||
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||
return &syncResponse{ |
||||
pb: testErrorResponse, |
||||
}, makeTestStreamID(0) |
||||
}, |
||||
expErr: errors.New("test error"), |
||||
expStID: makeTestStreamID(0), |
||||
}, |
||||
} |
||||
|
||||
for i, test := range tests { |
||||
protocol := makeTestProtocol(test.getResponse) |
||||
blocks, stid, err := protocol.GetBlocksByNumber(context.Background(), []uint64{0}) |
||||
|
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
continue |
||||
} |
||||
if stid != test.expStID { |
||||
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID) |
||||
} |
||||
if test.expErr == nil && (len(blocks) == 0) { |
||||
t.Errorf("Test %v: zero blocks delivered", i) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestProtocol_GetCurrentBlockNumber(t *testing.T) { |
||||
tests := []struct { |
||||
getResponse getResponseFn |
||||
expErr error |
||||
expStID sttypes.StreamID |
||||
}{ |
||||
{ |
||||
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||
return &syncResponse{ |
||||
pb: testBlockNumberResponse, |
||||
}, makeTestStreamID(0) |
||||
}, |
||||
expErr: nil, |
||||
expStID: makeTestStreamID(0), |
||||
}, |
||||
{ |
||||
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||
return &syncResponse{ |
||||
pb: testBlockResponse, |
||||
}, makeTestStreamID(0) |
||||
}, |
||||
expErr: errors.New("not GetBlockNumber"), |
||||
expStID: makeTestStreamID(0), |
||||
}, |
||||
{ |
||||
getResponse: nil, |
||||
expErr: errors.New("get response error"), |
||||
expStID: "", |
||||
}, |
||||
{ |
||||
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||
return &syncResponse{ |
||||
pb: testErrorResponse, |
||||
}, makeTestStreamID(0) |
||||
}, |
||||
expErr: errors.New("test error"), |
||||
expStID: makeTestStreamID(0), |
||||
}, |
||||
} |
||||
|
||||
for i, test := range tests { |
||||
protocol := makeTestProtocol(test.getResponse) |
||||
res, stid, err := protocol.GetCurrentBlockNumber(context.Background()) |
||||
|
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
continue |
||||
} |
||||
if stid != test.expStID { |
||||
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID) |
||||
} |
||||
if test.expErr == nil { |
||||
if res != testCurBlockNumber { |
||||
t.Errorf("Test %v: block number not expected: %v / %v", i, res, testCurBlockNumber) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestProtocol_GetBlockHashes(t *testing.T) { |
||||
tests := []struct { |
||||
getResponse getResponseFn |
||||
expErr error |
||||
expStID sttypes.StreamID |
||||
}{ |
||||
{ |
||||
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||
return &syncResponse{ |
||||
pb: testBlockHashesResponse, |
||||
}, makeTestStreamID(0) |
||||
}, |
||||
expErr: nil, |
||||
expStID: makeTestStreamID(0), |
||||
}, |
||||
{ |
||||
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||
return &syncResponse{ |
||||
pb: testBlockResponse, |
||||
}, makeTestStreamID(0) |
||||
}, |
||||
expErr: errors.New("not GetBlockHashes"), |
||||
expStID: makeTestStreamID(0), |
||||
}, |
||||
{ |
||||
getResponse: nil, |
||||
expErr: errors.New("get response error"), |
||||
expStID: "", |
||||
}, |
||||
{ |
||||
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||
return &syncResponse{ |
||||
pb: testErrorResponse, |
||||
}, makeTestStreamID(0) |
||||
}, |
||||
expErr: errors.New("test error"), |
||||
expStID: makeTestStreamID(0), |
||||
}, |
||||
} |
||||
|
||||
for i, test := range tests { |
||||
protocol := makeTestProtocol(test.getResponse) |
||||
res, stid, err := protocol.GetBlockHashes(context.Background(), []uint64{100}) |
||||
|
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
continue |
||||
} |
||||
if stid != test.expStID { |
||||
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID) |
||||
} |
||||
if test.expErr == nil { |
||||
if len(res) != 1 { |
||||
t.Errorf("Test %v: size not 1", i) |
||||
} |
||||
if res[0] != testHash { |
||||
t.Errorf("Test %v: hash not expected", i) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestProtocol_GetBlocksByHashes(t *testing.T) { |
||||
tests := []struct { |
||||
getResponse getResponseFn |
||||
expErr error |
||||
expStID sttypes.StreamID |
||||
}{ |
||||
{ |
||||
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||
return &syncResponse{ |
||||
pb: testBlocksByHashesResponse, |
||||
}, makeTestStreamID(0) |
||||
}, |
||||
expErr: nil, |
||||
expStID: makeTestStreamID(0), |
||||
}, |
||||
{ |
||||
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||
return &syncResponse{ |
||||
pb: testBlockResponse, |
||||
}, makeTestStreamID(0) |
||||
}, |
||||
expErr: errors.New("not GetBlocksByHashes"), |
||||
expStID: makeTestStreamID(0), |
||||
}, |
||||
{ |
||||
getResponse: nil, |
||||
expErr: errors.New("get response error"), |
||||
expStID: "", |
||||
}, |
||||
{ |
||||
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { |
||||
return &syncResponse{ |
||||
pb: testErrorResponse, |
||||
}, makeTestStreamID(0) |
||||
}, |
||||
expErr: errors.New("test error"), |
||||
expStID: makeTestStreamID(0), |
||||
}, |
||||
} |
||||
|
||||
for i, test := range tests { |
||||
protocol := makeTestProtocol(test.getResponse) |
||||
blocks, stid, err := protocol.GetBlocksByHashes(context.Background(), []common.Hash{numberToHash(100)}) |
||||
|
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
continue |
||||
} |
||||
if stid != test.expStID { |
||||
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID) |
||||
} |
||||
if test.expErr == nil { |
||||
if len(blocks) != 1 { |
||||
t.Errorf("Test %v: size not 1", i) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
type getResponseFn func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) |
||||
|
||||
type testHostRequestManager struct { |
||||
getResponse getResponseFn |
||||
} |
||||
|
||||
func makeTestProtocol(f getResponseFn) *Protocol { |
||||
rm := &testHostRequestManager{f} |
||||
|
||||
streamIDs := make([]sttypes.StreamID, len(initStreamIDs)) |
||||
copy(streamIDs, initStreamIDs) |
||||
sm := &testStreamManager{streamIDs} |
||||
|
||||
rl := ratelimiter.NewRateLimiter(sm, 10, 10) |
||||
|
||||
return &Protocol{ |
||||
rm: rm, |
||||
rl: rl, |
||||
sm: sm, |
||||
} |
||||
} |
||||
|
||||
func (rm *testHostRequestManager) Start() {} |
||||
func (rm *testHostRequestManager) Close() {} |
||||
func (rm *testHostRequestManager) DeliverResponse(sttypes.StreamID, sttypes.Response) {} |
||||
|
||||
func (rm *testHostRequestManager) DoRequest(ctx context.Context, request sttypes.Request, opts ...Option) (sttypes.Response, sttypes.StreamID, error) { |
||||
if rm.getResponse == nil { |
||||
return nil, "", errors.New("get response error") |
||||
} |
||||
resp, stid := rm.getResponse(request) |
||||
return resp, stid, nil |
||||
} |
||||
|
||||
func makeTestStreamID(index int) sttypes.StreamID { |
||||
id := fmt.Sprintf("[test stream %v]", index) |
||||
return sttypes.StreamID(id) |
||||
} |
||||
|
||||
// mock stream manager
|
||||
type testStreamManager struct { |
||||
streamIDs []sttypes.StreamID |
||||
} |
||||
|
||||
func (sm *testStreamManager) Start() {} |
||||
func (sm *testStreamManager) Close() {} |
||||
func (sm *testStreamManager) SubscribeAddStreamEvent(chan<- streammanager.EvtStreamAdded) event.Subscription { |
||||
return nil |
||||
} |
||||
func (sm *testStreamManager) SubscribeRemoveStreamEvent(chan<- streammanager.EvtStreamRemoved) event.Subscription { |
||||
return nil |
||||
} |
||||
|
||||
func (sm *testStreamManager) NewStream(stream sttypes.Stream) error { |
||||
stid := stream.ID() |
||||
for _, id := range sm.streamIDs { |
||||
if id == stid { |
||||
return errors.New("stream already exist") |
||||
} |
||||
} |
||||
sm.streamIDs = append(sm.streamIDs, stid) |
||||
return nil |
||||
} |
||||
|
||||
func (sm *testStreamManager) RemoveStream(stID sttypes.StreamID) error { |
||||
for i, id := range sm.streamIDs { |
||||
if id == stID { |
||||
sm.streamIDs = append(sm.streamIDs[:i], sm.streamIDs[i+1:]...) |
||||
} |
||||
} |
||||
return errors.New("stream not exist") |
||||
} |
||||
|
||||
func (sm *testStreamManager) isStreamExist(stid sttypes.StreamID) bool { |
||||
for _, id := range sm.streamIDs { |
||||
if id == stid { |
||||
return true |
||||
} |
||||
} |
||||
return false |
||||
} |
||||
|
||||
func (sm *testStreamManager) GetStreams() []sttypes.Stream { |
||||
return nil |
||||
} |
||||
|
||||
func (sm *testStreamManager) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) { |
||||
return nil, false |
||||
} |
||||
|
||||
func assertError(got, expect error) error { |
||||
if (got == nil) != (expect == nil) { |
||||
return fmt.Errorf("unexpected error: %v / %v", got, expect) |
||||
} |
||||
if got == nil { |
||||
return nil |
||||
} |
||||
if !strings.Contains(got.Error(), expect.Error()) { |
||||
return fmt.Errorf("unexpected error: %v/ %v", got, expect) |
||||
} |
||||
return nil |
||||
} |
@ -0,0 +1,17 @@ |
||||
package sync |
||||
|
||||
import "time" |
||||
|
||||
const ( |
||||
// GetBlockHashesAmountCap is the cap of GetBlockHashes reqeust
|
||||
GetBlockHashesAmountCap = 50 |
||||
|
||||
// GetBlocksByNumAmountCap is the cap of request of a single GetBlocksByNum request
|
||||
GetBlocksByNumAmountCap = 10 |
||||
|
||||
// GetBlocksByHashesAmountCap is the cap of request of single GetBlocksByHashes request
|
||||
GetBlocksByHashesAmountCap = 10 |
||||
|
||||
// minAdvertiseInterval is the minimum advertise interval
|
||||
minAdvertiseInterval = 1 * time.Minute |
||||
) |
@ -0,0 +1,260 @@ |
||||
package sync |
||||
|
||||
import ( |
||||
"context" |
||||
"strconv" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/event" |
||||
"github.com/harmony-one/harmony/consensus/engine" |
||||
nodeconfig "github.com/harmony-one/harmony/internal/configs/node" |
||||
shardingconfig "github.com/harmony-one/harmony/internal/configs/sharding" |
||||
"github.com/harmony-one/harmony/internal/utils" |
||||
"github.com/harmony-one/harmony/p2p/discovery" |
||||
"github.com/harmony-one/harmony/p2p/stream/common/ratelimiter" |
||||
"github.com/harmony-one/harmony/p2p/stream/common/requestmanager" |
||||
"github.com/harmony-one/harmony/p2p/stream/common/streammanager" |
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
"github.com/hashicorp/go-version" |
||||
libp2p_host "github.com/libp2p/go-libp2p-core/host" |
||||
libp2p_network "github.com/libp2p/go-libp2p-core/network" |
||||
"github.com/rs/zerolog" |
||||
) |
||||
|
||||
const ( |
||||
// serviceSpecifier is the specifier for the service.
|
||||
serviceSpecifier = "sync" |
||||
) |
||||
|
||||
var ( |
||||
version100, _ = version.NewVersion("1.0.0") |
||||
|
||||
// MyVersion is the version of sync protocol of the local node
|
||||
MyVersion = version100 |
||||
|
||||
// MinVersion is the minimum version for matching function
|
||||
MinVersion = version100 |
||||
) |
||||
|
||||
type ( |
||||
// Protocol is the protocol for sync streaming
|
||||
Protocol struct { |
||||
chain engine.ChainReader // provide SYNC data
|
||||
schedule shardingconfig.Schedule // provide schedule information
|
||||
rl ratelimiter.RateLimiter // limit the incoming request rate
|
||||
sm streammanager.StreamManager // stream management
|
||||
rm requestmanager.RequestManager // deliver the response from stream
|
||||
disc discovery.Discovery |
||||
|
||||
config Config |
||||
logger zerolog.Logger |
||||
|
||||
ctx context.Context |
||||
cancel func() |
||||
closeC chan struct{} |
||||
} |
||||
|
||||
// Config is the sync protocol config
|
||||
Config struct { |
||||
Chain engine.ChainReader |
||||
Host libp2p_host.Host |
||||
Discovery discovery.Discovery |
||||
ShardID nodeconfig.ShardID |
||||
Network nodeconfig.NetworkType |
||||
|
||||
// stream manager config
|
||||
SmSoftLowCap int |
||||
SmHardLowCap int |
||||
SmHiCap int |
||||
DiscBatch int |
||||
} |
||||
) |
||||
|
||||
// NewProtocol creates a new sync protocol
|
||||
func NewProtocol(config Config) *Protocol { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
|
||||
sp := &Protocol{ |
||||
chain: config.Chain, |
||||
disc: config.Discovery, |
||||
config: config, |
||||
ctx: ctx, |
||||
cancel: cancel, |
||||
closeC: make(chan struct{}), |
||||
} |
||||
smConfig := streammanager.Config{ |
||||
SoftLoCap: config.SmSoftLowCap, |
||||
HardLoCap: config.SmHardLowCap, |
||||
HiCap: config.SmHiCap, |
||||
DiscBatch: config.DiscBatch, |
||||
} |
||||
sp.sm = streammanager.NewStreamManager(sp.ProtoID(), config.Host, config.Discovery, |
||||
sp.HandleStream, smConfig) |
||||
|
||||
sp.rl = ratelimiter.NewRateLimiter(sp.sm, 50, 10) |
||||
|
||||
sp.rm = requestmanager.NewRequestManager(sp.sm) |
||||
|
||||
sp.logger = utils.Logger().With().Str("Protocol", string(sp.ProtoID())).Logger() |
||||
return sp |
||||
} |
||||
|
||||
// Start starts the sync protocol
|
||||
func (p *Protocol) Start() { |
||||
p.sm.Start() |
||||
p.rm.Start() |
||||
p.rl.Start() |
||||
go p.advertiseLoop() |
||||
} |
||||
|
||||
// Close close the protocol
|
||||
func (p *Protocol) Close() { |
||||
p.rl.Close() |
||||
p.rm.Close() |
||||
p.sm.Close() |
||||
p.cancel() |
||||
close(p.closeC) |
||||
} |
||||
|
||||
// Specifier return the specifier for the protocol
|
||||
func (p *Protocol) Specifier() string { |
||||
return serviceSpecifier + "/" + strconv.Itoa(int(p.config.ShardID)) |
||||
} |
||||
|
||||
// ProtoID return the ProtoID of the sync protocol
|
||||
func (p *Protocol) ProtoID() sttypes.ProtoID { |
||||
return p.protoIDByVersion(MyVersion) |
||||
} |
||||
|
||||
// Version returns the sync protocol version
|
||||
func (p *Protocol) Version() *version.Version { |
||||
return MyVersion |
||||
} |
||||
|
||||
// Match checks the compatibility to the target protocol ID.
|
||||
func (p *Protocol) Match(targetID string) bool { |
||||
target, err := sttypes.ProtoIDToProtoSpec(sttypes.ProtoID(targetID)) |
||||
if err != nil { |
||||
return false |
||||
} |
||||
if target.Service != serviceSpecifier { |
||||
return false |
||||
} |
||||
if target.NetworkType != p.config.Network { |
||||
return false |
||||
} |
||||
if target.ShardID != p.config.ShardID { |
||||
return false |
||||
} |
||||
if target.Version.LessThan(MinVersion) { |
||||
return false |
||||
} |
||||
return true |
||||
} |
||||
|
||||
// HandleStream is the stream handle function being registered to libp2p.
|
||||
func (p *Protocol) HandleStream(raw libp2p_network.Stream) { |
||||
p.logger.Info().Str("stream", raw.ID()).Msg("handle new sync stream") |
||||
st := p.wrapStream(raw) |
||||
if err := p.sm.NewStream(st); err != nil { |
||||
// Possibly we have reach the hard limit of the stream
|
||||
p.logger.Warn().Err(err).Str("stream ID", string(st.ID())). |
||||
Msg("failed to add new stream") |
||||
return |
||||
} |
||||
st.run() |
||||
} |
||||
|
||||
func (p *Protocol) advertiseLoop() { |
||||
for { |
||||
sleep := p.advertise() |
||||
select { |
||||
case <-time.After(sleep): |
||||
case <-p.closeC: |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// advertise will advertise all compatible protocol versions for helping nodes running low
|
||||
// version
|
||||
func (p *Protocol) advertise() time.Duration { |
||||
var nextWait time.Duration |
||||
|
||||
pids := p.supportedProtoIDs() |
||||
for _, pid := range pids { |
||||
w, e := p.disc.Advertise(p.ctx, string(pid)) |
||||
if e != nil { |
||||
p.logger.Warn().Err(e).Str("protocol", string(pid)). |
||||
Msg("cannot advertise sync protocol") |
||||
continue |
||||
} |
||||
if nextWait == 0 || nextWait > w { |
||||
nextWait = w |
||||
} |
||||
} |
||||
if nextWait < minAdvertiseInterval { |
||||
nextWait = minAdvertiseInterval |
||||
} |
||||
return nextWait |
||||
} |
||||
|
||||
func (p *Protocol) supportedProtoIDs() []sttypes.ProtoID { |
||||
vs := p.supportedVersions() |
||||
|
||||
pids := make([]sttypes.ProtoID, 0, len(vs)) |
||||
for _, v := range vs { |
||||
pids = append(pids, p.protoIDByVersion(v)) |
||||
} |
||||
return pids |
||||
} |
||||
|
||||
func (p *Protocol) supportedVersions() []*version.Version { |
||||
return []*version.Version{version100} |
||||
} |
||||
|
||||
func (p *Protocol) protoIDByVersion(v *version.Version) sttypes.ProtoID { |
||||
spec := sttypes.ProtoSpec{ |
||||
Service: serviceSpecifier, |
||||
NetworkType: p.config.Network, |
||||
ShardID: p.config.ShardID, |
||||
Version: v, |
||||
} |
||||
return spec.ToProtoID() |
||||
} |
||||
|
||||
// RemoveStream removes the stream of the given stream ID
|
||||
func (p *Protocol) RemoveStream(stID sttypes.StreamID) { |
||||
if stID == "" { |
||||
return |
||||
} |
||||
st, exist := p.sm.GetStreamByID(stID) |
||||
if exist && st != nil { |
||||
st.Close() |
||||
} |
||||
} |
||||
|
||||
// NumStreams return the streams with minimum version.
|
||||
// Note: nodes with sync version smaller than minVersion is not counted.
|
||||
func (p *Protocol) NumStreams() int { |
||||
res := 0 |
||||
sts := p.sm.GetStreams() |
||||
|
||||
for _, st := range sts { |
||||
ps, _ := st.ProtoSpec() |
||||
if ps.Version.GreaterThanOrEqual(MinVersion) { |
||||
res++ |
||||
} |
||||
} |
||||
return res |
||||
} |
||||
|
||||
// GetStreamManager get the underlying stream manager for upper level stream operations
|
||||
func (p *Protocol) GetStreamManager() streammanager.StreamManager { |
||||
return p.sm |
||||
} |
||||
|
||||
// SubscribeAddStreamEvent subscribe the stream add event
|
||||
func (p *Protocol) SubscribeAddStreamEvent(ch chan<- streammanager.EvtStreamAdded) event.Subscription { |
||||
return p.sm.SubscribeAddStreamEvent(ch) |
||||
} |
@ -0,0 +1,95 @@ |
||||
package sync |
||||
|
||||
import ( |
||||
"context" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/libp2p/go-libp2p-core/discovery" |
||||
libp2p_peer "github.com/libp2p/go-libp2p-core/peer" |
||||
) |
||||
|
||||
func TestProtocol_Match(t *testing.T) { |
||||
tests := []struct { |
||||
targetID string |
||||
exp bool |
||||
}{ |
||||
{"harmony/sync/unitest/0/1.0.1", true}, |
||||
{"h123456", false}, |
||||
{"harmony/sync/unitest/0/0.9.9", false}, |
||||
{"harmony/epoch/unitest/0/1.0.1", false}, |
||||
{"harmony/sync/mainnet/0/1.0.1", false}, |
||||
{"harmony/sync/unitest/1/1.0.1", false}, |
||||
} |
||||
|
||||
for i, test := range tests { |
||||
p := &Protocol{ |
||||
config: Config{ |
||||
Network: "unitest", |
||||
ShardID: 0, |
||||
}, |
||||
} |
||||
|
||||
res := p.Match(test.targetID) |
||||
|
||||
if res != test.exp { |
||||
t.Errorf("Test %v: unexpected result %v / %v", i, res, test.exp) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestProtocol_advertiseLoop(t *testing.T) { |
||||
disc := newTestDiscovery(100 * time.Millisecond) |
||||
p := &Protocol{ |
||||
disc: disc, |
||||
closeC: make(chan struct{}), |
||||
} |
||||
|
||||
go p.advertiseLoop() |
||||
|
||||
time.Sleep(150 * time.Millisecond) |
||||
close(p.closeC) |
||||
|
||||
if len(disc.advCnt) != len(p.supportedVersions()) { |
||||
t.Errorf("unexpected advertise topic count: %v / %v", len(disc.advCnt), |
||||
len(p.supportedVersions())) |
||||
} |
||||
for _, cnt := range disc.advCnt { |
||||
if cnt < 1 { |
||||
t.Errorf("unexpected discovery count: %v", cnt) |
||||
} |
||||
} |
||||
} |
||||
|
||||
type testDiscovery struct { |
||||
advCnt map[string]int |
||||
sleep time.Duration |
||||
} |
||||
|
||||
func newTestDiscovery(discInterval time.Duration) *testDiscovery { |
||||
return &testDiscovery{ |
||||
advCnt: make(map[string]int), |
||||
sleep: discInterval, |
||||
} |
||||
} |
||||
|
||||
func (disc *testDiscovery) Start() error { |
||||
return nil |
||||
} |
||||
|
||||
func (disc *testDiscovery) Close() error { |
||||
return nil |
||||
} |
||||
|
||||
func (disc *testDiscovery) Advertise(ctx context.Context, ns string) (time.Duration, error) { |
||||
disc.advCnt[ns]++ |
||||
return disc.sleep, nil |
||||
} |
||||
|
||||
func (disc *testDiscovery) FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) { |
||||
return nil, nil |
||||
} |
||||
|
||||
func (disc *testDiscovery) GetRawDiscovery() discovery.Discovery { |
||||
return nil |
||||
} |
@ -0,0 +1,348 @@ |
||||
package sync |
||||
|
||||
import ( |
||||
"fmt" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/common" |
||||
"github.com/ethereum/go-ethereum/rlp" |
||||
protobuf "github.com/golang/protobuf/proto" |
||||
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" |
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
libp2p_network "github.com/libp2p/go-libp2p-core/network" |
||||
"github.com/pkg/errors" |
||||
"github.com/rs/zerolog" |
||||
) |
||||
|
||||
// syncStream is the structure for a stream running sync protocol.
|
||||
type syncStream struct { |
||||
// Basic stream
|
||||
*sttypes.BaseStream |
||||
|
||||
protocol *Protocol |
||||
chain chainHelper |
||||
|
||||
// pipeline channels
|
||||
reqC chan *syncpb.Request |
||||
respC chan *syncpb.Response |
||||
|
||||
// close related fields. Concurrent call of close is possible.
|
||||
closeC chan struct{} |
||||
closeStat uint32 |
||||
|
||||
logger zerolog.Logger |
||||
} |
||||
|
||||
// wrapStream wraps the raw libp2p stream to syncStream
|
||||
func (p *Protocol) wrapStream(raw libp2p_network.Stream) *syncStream { |
||||
bs := sttypes.NewBaseStream(raw) |
||||
logger := p.logger.With(). |
||||
Str("ID", string(bs.ID())). |
||||
Str("Remote Protocol", string(bs.ProtoID())). |
||||
Logger() |
||||
|
||||
return &syncStream{ |
||||
BaseStream: bs, |
||||
protocol: p, |
||||
chain: newChainHelper(p.chain, p.schedule), |
||||
reqC: make(chan *syncpb.Request, 100), |
||||
respC: make(chan *syncpb.Response, 100), |
||||
closeC: make(chan struct{}), |
||||
closeStat: 0, |
||||
logger: logger, |
||||
} |
||||
} |
||||
|
||||
func (st *syncStream) run() { |
||||
st.logger.Info().Str("StreamID", string(st.ID())).Msg("running sync protocol on stream") |
||||
defer st.logger.Info().Str("StreamID", string(st.ID())).Msg("end running sync protocol on stream") |
||||
|
||||
go st.handleReqLoop() |
||||
go st.handleRespLoop() |
||||
st.readMsgLoop() |
||||
} |
||||
|
||||
// readMsgLoop is the loop
|
||||
func (st *syncStream) readMsgLoop() { |
||||
for { |
||||
msg, err := st.readMsg() |
||||
if err != nil { |
||||
if err := st.Close(); err != nil { |
||||
st.logger.Err(err).Msg("failed to close sync stream") |
||||
} |
||||
return |
||||
} |
||||
st.deliverMsg(msg) |
||||
} |
||||
} |
||||
|
||||
// deliverMsg process the delivered message and forward to the corresponding channel
|
||||
func (st *syncStream) deliverMsg(msg protobuf.Message) { |
||||
syncMsg := msg.(*syncpb.Message) |
||||
if syncMsg == nil { |
||||
st.logger.Info().Str("message", msg.String()).Msg("received unexpected sync message") |
||||
return |
||||
} |
||||
if req := syncMsg.GetReq(); req != nil { |
||||
go func() { |
||||
select { |
||||
case st.reqC <- req: |
||||
case <-time.After(1 * time.Minute): |
||||
st.logger.Warn().Str("request", req.String()). |
||||
Msg("request handler severely jammed, message dropped") |
||||
} |
||||
}() |
||||
} |
||||
if resp := syncMsg.GetResp(); resp != nil { |
||||
go func() { |
||||
select { |
||||
case st.respC <- resp: |
||||
case <-time.After(1 * time.Minute): |
||||
st.logger.Warn().Str("response", resp.String()). |
||||
Msg("response handler severely jammed, message dropped") |
||||
} |
||||
}() |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (st *syncStream) handleReqLoop() { |
||||
for { |
||||
select { |
||||
case req := <-st.reqC: |
||||
st.protocol.rl.LimitRequest(st.ID()) |
||||
err := st.handleReq(req) |
||||
|
||||
if err != nil { |
||||
st.logger.Info().Err(err).Str("request", req.String()). |
||||
Msg("handle request error. Closing stream") |
||||
if err := st.Close(); err != nil { |
||||
st.logger.Err(err).Msg("failed to close sync stream") |
||||
} |
||||
return |
||||
} |
||||
|
||||
case <-st.closeC: |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (st *syncStream) handleRespLoop() { |
||||
for { |
||||
select { |
||||
case resp := <-st.respC: |
||||
st.handleResp(resp) |
||||
|
||||
case <-st.closeC: |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Close stops the stream handling and closes the underlying stream
|
||||
func (st *syncStream) Close() error { |
||||
notClosed := atomic.CompareAndSwapUint32(&st.closeStat, 0, 1) |
||||
if !notClosed { |
||||
// Already closed by another goroutine. Directly return
|
||||
return nil |
||||
} |
||||
if err := st.protocol.sm.RemoveStream(st.ID()); err != nil { |
||||
st.logger.Err(err).Str("stream ID", string(st.ID())). |
||||
Msg("failed to remove sync stream on close") |
||||
} |
||||
close(st.closeC) |
||||
return st.BaseStream.Close() |
||||
} |
||||
|
||||
// ResetOnClose reset the stream on close
|
||||
func (st *syncStream) ResetOnClose() error { |
||||
notClosed := atomic.CompareAndSwapUint32(&st.closeStat, 0, 1) |
||||
if !notClosed { |
||||
// Already closed by another goroutine. Directly return
|
||||
return nil |
||||
} |
||||
close(st.closeC) |
||||
return st.BaseStream.ResetOnClose() |
||||
} |
||||
|
||||
func (st *syncStream) handleReq(req *syncpb.Request) error { |
||||
if gnReq := req.GetGetBlockNumberRequest(); gnReq != nil { |
||||
return st.handleGetBlockNumberRequest(req.ReqId) |
||||
} |
||||
if ghReq := req.GetGetBlockHashesRequest(); ghReq != nil { |
||||
return st.handleGetBlockHashesRequest(req.ReqId, ghReq) |
||||
} |
||||
if bnReq := req.GetGetBlocksByNumRequest(); bnReq != nil { |
||||
return st.handleGetBlocksByNumRequest(req.ReqId, bnReq) |
||||
} |
||||
if bhReq := req.GetGetBlocksByHashesRequest(); bhReq != nil { |
||||
return st.handleGetBlocksByHashesRequest(req.ReqId, bhReq) |
||||
} |
||||
// unsupported request type
|
||||
resp := syncpb.MakeErrorResponseMessage(req.ReqId, errUnknownReqType) |
||||
return st.writeMsg(resp) |
||||
} |
||||
|
||||
func (st *syncStream) handleGetBlockNumberRequest(rid uint64) error { |
||||
resp := st.computeBlockNumberResp(rid) |
||||
if err := st.writeMsg(resp); err != nil { |
||||
return errors.Wrap(err, "[GetBlockNumber]: writeMsg") |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (st *syncStream) handleGetBlockHashesRequest(rid uint64, req *syncpb.GetBlockHashesRequest) error { |
||||
resp, err := st.computeGetBlockHashesResp(rid, req.Nums) |
||||
if err != nil { |
||||
resp = syncpb.MakeErrorResponseMessage(rid, err) |
||||
} |
||||
if writeErr := st.writeMsg(resp); writeErr != nil { |
||||
if err == nil { |
||||
err = writeErr |
||||
} else { |
||||
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr) |
||||
} |
||||
} |
||||
return errors.Wrap(err, "[GetBlockHashes]") |
||||
} |
||||
|
||||
func (st *syncStream) handleGetBlocksByNumRequest(rid uint64, req *syncpb.GetBlocksByNumRequest) error { |
||||
resp, err := st.computeRespFromBlockNumber(rid, req.Nums) |
||||
if resp == nil && err != nil { |
||||
resp = syncpb.MakeErrorResponseMessage(rid, err) |
||||
} |
||||
if writeErr := st.writeMsg(resp); writeErr != nil { |
||||
if err == nil { |
||||
err = writeErr |
||||
} else { |
||||
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr) |
||||
} |
||||
} |
||||
return errors.Wrap(err, "[GetBlocksByNumber]") |
||||
} |
||||
|
||||
func (st *syncStream) handleGetBlocksByHashesRequest(rid uint64, req *syncpb.GetBlocksByHashesRequest) error { |
||||
hashes := bytesToHashes(req.BlockHashes) |
||||
resp, err := st.computeRespFromBlockHashes(rid, hashes) |
||||
if resp == nil && err != nil { |
||||
resp = syncpb.MakeErrorResponseMessage(rid, err) |
||||
} |
||||
if writeErr := st.writeMsg(resp); writeErr != nil { |
||||
if err == nil { |
||||
err = writeErr |
||||
} else { |
||||
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr) |
||||
} |
||||
} |
||||
return errors.Wrap(err, "[GetBlocksByHashes]") |
||||
} |
||||
|
||||
func (st *syncStream) handleResp(resp *syncpb.Response) { |
||||
st.protocol.rm.DeliverResponse(st.ID(), &syncResponse{resp}) |
||||
} |
||||
|
||||
func (st *syncStream) readMsg() (*syncpb.Message, error) { |
||||
b, err := st.ReadBytes() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
var msg = &syncpb.Message{} |
||||
if err := protobuf.Unmarshal(b, msg); err != nil { |
||||
return nil, err |
||||
} |
||||
return msg, nil |
||||
} |
||||
|
||||
func (st *syncStream) writeMsg(msg *syncpb.Message) error { |
||||
b, err := protobuf.Marshal(msg) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return st.WriteBytes(b) |
||||
} |
||||
|
||||
func (st *syncStream) computeBlockNumberResp(rid uint64) *syncpb.Message { |
||||
bn := st.chain.getCurrentBlockNumber() |
||||
return syncpb.MakeGetBlockNumberResponseMessage(rid, bn) |
||||
} |
||||
|
||||
func (st syncStream) computeGetBlockHashesResp(rid uint64, bns []uint64) (*syncpb.Message, error) { |
||||
if len(bns) > GetBlockHashesAmountCap { |
||||
err := fmt.Errorf("GetBlockHashes amount exceed cap: %v>%v", len(bns), GetBlockHashesAmountCap) |
||||
return nil, err |
||||
} |
||||
hashes := st.chain.getBlockHashes(bns) |
||||
return syncpb.MakeGetBlockHashesResponseMessage(rid, hashes), nil |
||||
} |
||||
|
||||
func (st *syncStream) computeRespFromBlockNumber(rid uint64, bns []uint64) (*syncpb.Message, error) { |
||||
if len(bns) > GetBlocksByNumAmountCap { |
||||
err := fmt.Errorf("GetBlocksByNum amount exceed cap: %v>%v", len(bns), GetBlocksByNumAmountCap) |
||||
return nil, err |
||||
} |
||||
blocks, err := st.chain.getBlocksByNumber(bns) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var ( |
||||
blocksBytes = make([][]byte, 0, len(blocks)) |
||||
sigs = make([][]byte, 0, len(blocks)) |
||||
) |
||||
for _, block := range blocks { |
||||
bb, err := rlp.EncodeToBytes(block) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
blocksBytes = append(blocksBytes, bb) |
||||
|
||||
var sig []byte |
||||
if block != nil { |
||||
sig = block.GetCurrentCommitSig() |
||||
} |
||||
sigs = append(sigs, sig) |
||||
} |
||||
return syncpb.MakeGetBlocksByNumResponseMessage(rid, blocksBytes, sigs), nil |
||||
} |
||||
|
||||
func (st *syncStream) computeRespFromBlockHashes(rid uint64, hs []common.Hash) (*syncpb.Message, error) { |
||||
if len(hs) > GetBlocksByHashesAmountCap { |
||||
err := fmt.Errorf("GetBlockByHashes amount exceed cap: %v > %v", len(hs), GetBlocksByHashesAmountCap) |
||||
return nil, err |
||||
} |
||||
blocks, err := st.chain.getBlocksByHashes(hs) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var ( |
||||
blocksBytes = make([][]byte, 0, len(blocks)) |
||||
sigs = make([][]byte, 0, len(blocks)) |
||||
) |
||||
for _, block := range blocks { |
||||
bb, err := rlp.EncodeToBytes(block) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
blocksBytes = append(blocksBytes, bb) |
||||
|
||||
var sig []byte |
||||
if block != nil { |
||||
sig = block.GetCurrentCommitSig() |
||||
} |
||||
sigs = append(sigs, sig) |
||||
} |
||||
return syncpb.MakeGetBlocksByHashesResponseMessage(rid, blocksBytes, sigs), nil |
||||
} |
||||
|
||||
func bytesToHashes(bs [][]byte) []common.Hash { |
||||
hs := make([]common.Hash, 0, len(bs)) |
||||
for _, b := range bs { |
||||
var h common.Hash |
||||
copy(h[:], b) |
||||
hs = append(hs, h) |
||||
} |
||||
return hs |
||||
} |
@ -0,0 +1,242 @@ |
||||
package sync |
||||
|
||||
import ( |
||||
"bytes" |
||||
"context" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/common" |
||||
protobuf "github.com/golang/protobuf/proto" |
||||
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" |
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
ic "github.com/libp2p/go-libp2p-core/crypto" |
||||
libp2p_network "github.com/libp2p/go-libp2p-core/network" |
||||
"github.com/libp2p/go-libp2p-core/peer" |
||||
"github.com/libp2p/go-libp2p-core/protocol" |
||||
ma "github.com/multiformats/go-multiaddr" |
||||
) |
||||
|
||||
var _ sttypes.Protocol = &Protocol{} |
||||
|
||||
var ( |
||||
testGetBlockNumbers = []uint64{1, 2, 3, 4, 5} |
||||
testGetBlockRequest = syncpb.MakeGetBlocksByNumRequest(testGetBlockNumbers) |
||||
testGetBlockRequestMsg = syncpb.MakeMessageFromRequest(testGetBlockRequest) |
||||
|
||||
testCurrentNumberRequest = syncpb.MakeGetBlockNumberRequest() |
||||
testCurrentNumberRequestMsg = syncpb.MakeMessageFromRequest(testCurrentNumberRequest) |
||||
|
||||
testGetBlockHashNums = []uint64{1, 2, 3, 4, 5} |
||||
testGetBlockHashesRequest = syncpb.MakeGetBlockHashesRequest(testGetBlockHashNums) |
||||
testGetBlockHashesRequestMsg = syncpb.MakeMessageFromRequest(testGetBlockHashesRequest) |
||||
|
||||
testGetBlockByHashes = []common.Hash{ |
||||
numberToHash(1), |
||||
numberToHash(2), |
||||
numberToHash(3), |
||||
numberToHash(4), |
||||
numberToHash(5), |
||||
} |
||||
testGetBlocksByHashesRequest = syncpb.MakeGetBlocksByHashesRequest(testGetBlockByHashes) |
||||
testGetBlocksByHashesRequestMsg = syncpb.MakeMessageFromRequest(testGetBlocksByHashesRequest) |
||||
) |
||||
|
||||
func TestSyncStream_HandleGetBlocksByRequest(t *testing.T) { |
||||
st, remoteSt := makeTestSyncStream() |
||||
|
||||
go st.run() |
||||
defer close(st.closeC) |
||||
|
||||
req := testGetBlockRequestMsg |
||||
b, _ := protobuf.Marshal(req) |
||||
err := remoteSt.WriteBytes(b) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
time.Sleep(200 * time.Millisecond) |
||||
receivedBytes, _ := remoteSt.ReadBytes() |
||||
|
||||
if err := checkBlocksResult(testGetBlockNumbers, receivedBytes); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
} |
||||
|
||||
func TestSyncStream_HandleCurrentBlockNumber(t *testing.T) { |
||||
st, remoteSt := makeTestSyncStream() |
||||
|
||||
go st.run() |
||||
defer close(st.closeC) |
||||
|
||||
req := testCurrentNumberRequestMsg |
||||
b, _ := protobuf.Marshal(req) |
||||
err := remoteSt.WriteBytes(b) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
time.Sleep(200 * time.Millisecond) |
||||
receivedBytes, _ := remoteSt.ReadBytes() |
||||
|
||||
if err := checkBlockNumberResult(receivedBytes); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
} |
||||
|
||||
func TestSyncStream_HandleGetBlockHashes(t *testing.T) { |
||||
st, remoteSt := makeTestSyncStream() |
||||
|
||||
go st.run() |
||||
defer close(st.closeC) |
||||
|
||||
req := testGetBlockHashesRequestMsg |
||||
b, _ := protobuf.Marshal(req) |
||||
err := remoteSt.WriteBytes(b) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
time.Sleep(200 * time.Millisecond) |
||||
receivedBytes, _ := remoteSt.ReadBytes() |
||||
|
||||
if err := checkBlockHashesResult(receivedBytes, testGetBlockNumbers); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
} |
||||
|
||||
func TestSyncStream_HandleGetBlocksByHashes(t *testing.T) { |
||||
st, remoteSt := makeTestSyncStream() |
||||
|
||||
go st.run() |
||||
defer close(st.closeC) |
||||
|
||||
req := testGetBlocksByHashesRequestMsg |
||||
b, _ := protobuf.Marshal(req) |
||||
err := remoteSt.WriteBytes(b) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
time.Sleep(200 * time.Millisecond) |
||||
receivedBytes, _ := remoteSt.ReadBytes() |
||||
|
||||
if err := checkBlocksByHashesResult(receivedBytes, testGetBlockByHashes); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
} |
||||
|
||||
func makeTestSyncStream() (*syncStream, *testRemoteBaseStream) { |
||||
localRaw, remoteRaw := makePairP2PStreams() |
||||
remote := newTestRemoteBaseStream(remoteRaw) |
||||
|
||||
bs := sttypes.NewBaseStream(localRaw) |
||||
|
||||
return &syncStream{ |
||||
BaseStream: bs, |
||||
chain: &testChainHelper{}, |
||||
protocol: makeTestProtocol(nil), |
||||
reqC: make(chan *syncpb.Request, 100), |
||||
respC: make(chan *syncpb.Response, 100), |
||||
closeC: make(chan struct{}), |
||||
closeStat: 0, |
||||
}, remote |
||||
} |
||||
|
||||
type testP2PStream struct { |
||||
readBuf *bytes.Buffer |
||||
inC chan struct{} |
||||
identity string |
||||
|
||||
writeHook func([]byte) (int, error) |
||||
} |
||||
|
||||
func makePairP2PStreams() (*testP2PStream, *testP2PStream) { |
||||
buf1 := bytes.NewBuffer(nil) |
||||
buf2 := bytes.NewBuffer(nil) |
||||
|
||||
st1 := &testP2PStream{ |
||||
readBuf: buf1, |
||||
inC: make(chan struct{}, 1), |
||||
identity: "local", |
||||
} |
||||
st2 := &testP2PStream{ |
||||
readBuf: buf2, |
||||
inC: make(chan struct{}, 1), |
||||
identity: "remote", |
||||
} |
||||
st1.writeHook = st2.receiveBytes |
||||
st2.writeHook = st1.receiveBytes |
||||
return st1, st2 |
||||
} |
||||
|
||||
func (st *testP2PStream) Read(b []byte) (n int, err error) { |
||||
<-st.inC |
||||
n, err = st.readBuf.Read(b) |
||||
if st.readBuf.Len() != 0 { |
||||
select { |
||||
case st.inC <- struct{}{}: |
||||
default: |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (st *testP2PStream) Write(b []byte) (n int, err error) { |
||||
return st.writeHook(b) |
||||
} |
||||
|
||||
func (st *testP2PStream) receiveBytes(b []byte) (n int, err error) { |
||||
n, err = st.readBuf.Write(b) |
||||
select { |
||||
case st.inC <- struct{}{}: |
||||
default: |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (st *testP2PStream) Close() error { return nil } |
||||
func (st *testP2PStream) CloseRead() error { return nil } |
||||
func (st *testP2PStream) CloseWrite() error { return nil } |
||||
func (st *testP2PStream) Reset() error { return nil } |
||||
func (st *testP2PStream) SetDeadline(time.Time) error { return nil } |
||||
func (st *testP2PStream) SetReadDeadline(time.Time) error { return nil } |
||||
func (st *testP2PStream) SetWriteDeadline(time.Time) error { return nil } |
||||
func (st *testP2PStream) ID() string { return "" } |
||||
func (st *testP2PStream) Protocol() protocol.ID { return "" } |
||||
func (st *testP2PStream) SetProtocol(protocol.ID) {} |
||||
func (st *testP2PStream) Stat() libp2p_network.Stat { return libp2p_network.Stat{} } |
||||
func (st *testP2PStream) Conn() libp2p_network.Conn { return &fakeConn{} } |
||||
|
||||
type testRemoteBaseStream struct { |
||||
base *sttypes.BaseStream |
||||
} |
||||
|
||||
func newTestRemoteBaseStream(st *testP2PStream) *testRemoteBaseStream { |
||||
rst := &testRemoteBaseStream{ |
||||
base: sttypes.NewBaseStream(st), |
||||
} |
||||
return rst |
||||
} |
||||
|
||||
func (st *testRemoteBaseStream) ReadBytes() ([]byte, error) { |
||||
return st.base.ReadBytes() |
||||
} |
||||
|
||||
func (st *testRemoteBaseStream) WriteBytes(b []byte) error { |
||||
return st.base.WriteBytes(b) |
||||
} |
||||
|
||||
type fakeConn struct{} |
||||
|
||||
func (conn *fakeConn) Close() error { return nil } |
||||
func (conn *fakeConn) LocalPeer() peer.ID { return "" } |
||||
func (conn *fakeConn) LocalPrivateKey() ic.PrivKey { return nil } |
||||
func (conn *fakeConn) RemotePeer() peer.ID { return "" } |
||||
func (conn *fakeConn) RemotePublicKey() ic.PubKey { return nil } |
||||
func (conn *fakeConn) LocalMultiaddr() ma.Multiaddr { return nil } |
||||
func (conn *fakeConn) RemoteMultiaddr() ma.Multiaddr { return nil } |
||||
func (conn *fakeConn) ID() string { return "" } |
||||
func (conn *fakeConn) NewStream(context.Context) (libp2p_network.Stream, error) { return nil, nil } |
||||
func (conn *fakeConn) GetStreams() []libp2p_network.Stream { return nil } |
||||
func (conn *fakeConn) Stat() libp2p_network.Stat { return libp2p_network.Stat{} } |
@ -0,0 +1,52 @@ |
||||
package sync |
||||
|
||||
import ( |
||||
"fmt" |
||||
|
||||
protobuf "github.com/golang/protobuf/proto" |
||||
"github.com/harmony-one/harmony/p2p/stream/common/requestmanager" |
||||
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" |
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
var ( |
||||
errUnknownReqType = errors.New("unknown request") |
||||
) |
||||
|
||||
// syncResponse is the sync protocol response which implements sttypes.Response
|
||||
type syncResponse struct { |
||||
pb *syncpb.Response |
||||
} |
||||
|
||||
// ReqID return the request ID of the response
|
||||
func (resp *syncResponse) ReqID() uint64 { |
||||
return resp.pb.ReqId |
||||
} |
||||
|
||||
// GetProtobufMsg return the raw protobuf message
|
||||
func (resp *syncResponse) GetProtobufMsg() protobuf.Message { |
||||
return resp.pb |
||||
} |
||||
|
||||
func (resp *syncResponse) String() string { |
||||
return fmt.Sprintf("[SyncResponse %v]", resp.pb.String()) |
||||
} |
||||
|
||||
// Option is the additional option to do requests.
|
||||
// Currently, two options are supported:
|
||||
// 1. WithHighPriority - do the request in high priority.
|
||||
// 2. WithBlacklist - do the request without the given stream ids as blacklist
|
||||
// 3. WithWhitelist - do the request only with the given stream ids
|
||||
type Option = requestmanager.RequestOption |
||||
|
||||
var ( |
||||
// WithHighPriority instruct the request manager to do the request with high
|
||||
// priority
|
||||
WithHighPriority = requestmanager.WithHighPriority |
||||
// WithBlacklist instruct the request manager not to assign the request to the
|
||||
// given streamID
|
||||
WithBlacklist = requestmanager.WithBlacklist |
||||
// WithWhitelist instruct the request manager only to assign the request to the
|
||||
// given streamID
|
||||
WithWhitelist = requestmanager.WithWhitelist |
||||
) |
Loading…
Reference in new issue