You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
634 lines
18 KiB
634 lines
18 KiB
package sync
|
|
|
|
import (
|
|
"fmt"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
|
|
"github.com/ethereum/go-ethereum/common"
|
|
"github.com/ethereum/go-ethereum/rlp"
|
|
protobuf "github.com/golang/protobuf/proto"
|
|
"github.com/harmony-one/harmony/p2p/stream/protocols/sync/message"
|
|
syncpb "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message"
|
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
|
|
libp2p_network "github.com/libp2p/go-libp2p/core/network"
|
|
"github.com/pkg/errors"
|
|
"github.com/rs/zerolog"
|
|
)
|
|
|
|
// syncStream is the structure for a stream running sync protocol.
|
|
type syncStream struct {
|
|
// Basic stream
|
|
*sttypes.BaseStream
|
|
|
|
protocol *Protocol
|
|
chain chainHelper
|
|
|
|
// pipeline channels
|
|
reqC chan *syncpb.Request
|
|
respC chan *syncpb.Response
|
|
|
|
// close related fields. Concurrent call of close is possible.
|
|
closeC chan struct{}
|
|
closeStat uint32
|
|
|
|
logger zerolog.Logger
|
|
}
|
|
|
|
// wrapStream wraps the raw libp2p stream to syncStream
|
|
func (p *Protocol) wrapStream(raw libp2p_network.Stream) *syncStream {
|
|
bs := sttypes.NewBaseStream(raw)
|
|
logger := p.logger.With().
|
|
Str("ID", string(bs.ID())).
|
|
Str("Remote Protocol", string(bs.ProtoID())).
|
|
Logger()
|
|
|
|
return &syncStream{
|
|
BaseStream: bs,
|
|
protocol: p,
|
|
chain: newChainHelper(p.chain, p.schedule),
|
|
reqC: make(chan *syncpb.Request, 100),
|
|
respC: make(chan *syncpb.Response, 100),
|
|
closeC: make(chan struct{}),
|
|
closeStat: 0,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (st *syncStream) run() {
|
|
st.logger.Info().Str("StreamID", string(st.ID())).Msg("running sync protocol on stream")
|
|
defer st.logger.Info().Str("StreamID", string(st.ID())).Msg("end running sync protocol on stream")
|
|
|
|
go st.handleReqLoop()
|
|
go st.handleRespLoop()
|
|
st.readMsgLoop()
|
|
}
|
|
|
|
// readMsgLoop is the loop
|
|
func (st *syncStream) readMsgLoop() {
|
|
for {
|
|
msg, err := st.readMsg()
|
|
if err != nil {
|
|
if err := st.Close(); err != nil {
|
|
st.logger.Err(err).Msg("failed to close sync stream")
|
|
}
|
|
return
|
|
}
|
|
st.deliverMsg(msg)
|
|
}
|
|
}
|
|
|
|
// deliverMsg process the delivered message and forward to the corresponding channel
|
|
func (st *syncStream) deliverMsg(msg protobuf.Message) {
|
|
syncMsg := msg.(*syncpb.Message)
|
|
if syncMsg == nil {
|
|
st.logger.Info().Str("message", msg.String()).Msg("received unexpected sync message")
|
|
return
|
|
}
|
|
if req := syncMsg.GetReq(); req != nil {
|
|
go func() {
|
|
select {
|
|
case st.reqC <- req:
|
|
case <-time.After(1 * time.Minute):
|
|
st.logger.Warn().Str("request", req.String()).
|
|
Msg("request handler severely jammed, message dropped")
|
|
}
|
|
}()
|
|
}
|
|
if resp := syncMsg.GetResp(); resp != nil {
|
|
go func() {
|
|
select {
|
|
case st.respC <- resp:
|
|
case <-time.After(1 * time.Minute):
|
|
st.logger.Warn().Str("response", resp.String()).
|
|
Msg("response handler severely jammed, message dropped")
|
|
}
|
|
}()
|
|
}
|
|
return
|
|
}
|
|
|
|
func (st *syncStream) handleReqLoop() {
|
|
for {
|
|
select {
|
|
case req := <-st.reqC:
|
|
st.protocol.rl.LimitRequest(st.ID())
|
|
err := st.handleReq(req)
|
|
|
|
if err != nil {
|
|
st.logger.Info().Err(err).Str("request", req.String()).
|
|
Msg("handle request error. Closing stream")
|
|
if err := st.Close(); err != nil {
|
|
st.logger.Err(err).Msg("failed to close sync stream")
|
|
}
|
|
return
|
|
}
|
|
|
|
case <-st.closeC:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (st *syncStream) handleRespLoop() {
|
|
for {
|
|
select {
|
|
case resp := <-st.respC:
|
|
st.handleResp(resp)
|
|
|
|
case <-st.closeC:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Close stops the stream handling and closes the underlying stream
|
|
func (st *syncStream) Close() error {
|
|
notClosed := atomic.CompareAndSwapUint32(&st.closeStat, 0, 1)
|
|
if !notClosed {
|
|
// Already closed by another goroutine. Directly return
|
|
return nil
|
|
}
|
|
if err := st.protocol.sm.RemoveStream(st.ID()); err != nil {
|
|
st.logger.Err(err).Str("stream ID", string(st.ID())).
|
|
Msg("failed to remove sync stream on close")
|
|
}
|
|
close(st.closeC)
|
|
return st.BaseStream.Close()
|
|
}
|
|
|
|
// CloseOnExit reset the stream on exiting node
|
|
func (st *syncStream) CloseOnExit() error {
|
|
notClosed := atomic.CompareAndSwapUint32(&st.closeStat, 0, 1)
|
|
if !notClosed {
|
|
// Already closed by another goroutine. Directly return
|
|
return nil
|
|
}
|
|
close(st.closeC)
|
|
return st.BaseStream.CloseOnExit()
|
|
}
|
|
|
|
func (st *syncStream) handleReq(req *syncpb.Request) error {
|
|
if gnReq := req.GetGetBlockNumberRequest(); gnReq != nil {
|
|
return st.handleGetBlockNumberRequest(req.ReqId)
|
|
}
|
|
if ghReq := req.GetGetBlockHashesRequest(); ghReq != nil {
|
|
return st.handleGetBlockHashesRequest(req.ReqId, ghReq)
|
|
}
|
|
if bnReq := req.GetGetBlocksByNumRequest(); bnReq != nil {
|
|
return st.handleGetBlocksByNumRequest(req.ReqId, bnReq)
|
|
}
|
|
if bhReq := req.GetGetBlocksByHashesRequest(); bhReq != nil {
|
|
return st.handleGetBlocksByHashesRequest(req.ReqId, bhReq)
|
|
}
|
|
if ndReq := req.GetGetNodeDataRequest(); ndReq != nil {
|
|
return st.handleGetNodeDataRequest(req.ReqId, ndReq)
|
|
}
|
|
if rReq := req.GetGetReceiptsRequest(); rReq != nil {
|
|
return st.handleGetReceiptsRequest(req.ReqId, rReq)
|
|
}
|
|
if ndReq := req.GetGetAccountRangeRequest(); ndReq != nil {
|
|
return st.handleGetAccountRangeRequest(req.ReqId, ndReq)
|
|
}
|
|
if ndReq := req.GetGetStorageRangesRequest(); ndReq != nil {
|
|
return st.handleGetStorageRangesRequest(req.ReqId, ndReq)
|
|
}
|
|
if ndReq := req.GetGetByteCodesRequest(); ndReq != nil {
|
|
return st.handleGetByteCodesRequest(req.ReqId, ndReq)
|
|
}
|
|
if ndReq := req.GetGetTrieNodesRequest(); ndReq != nil {
|
|
return st.handleGetTrieNodesRequest(req.ReqId, ndReq)
|
|
}
|
|
// unsupported request type
|
|
return st.handleUnknownRequest(req.ReqId)
|
|
}
|
|
|
|
func (st *syncStream) handleGetBlockNumberRequest(rid uint64) error {
|
|
serverRequestCounterVec.With(prometheus.Labels{
|
|
"topic": string(st.ProtoID()),
|
|
"request_type": "getBlockNumber",
|
|
}).Inc()
|
|
|
|
resp := st.computeBlockNumberResp(rid)
|
|
if err := st.writeMsg(resp); err != nil {
|
|
return errors.Wrap(err, "[GetBlockNumber]: writeMsg")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (st *syncStream) handleGetBlockHashesRequest(rid uint64, req *syncpb.GetBlockHashesRequest) error {
|
|
serverRequestCounterVec.With(prometheus.Labels{
|
|
"topic": string(st.ProtoID()),
|
|
"request_type": "getBlockHashes",
|
|
}).Inc()
|
|
|
|
resp, err := st.computeGetBlockHashesResp(rid, req.Nums)
|
|
if err != nil {
|
|
resp = syncpb.MakeErrorResponseMessage(rid, err)
|
|
}
|
|
if writeErr := st.writeMsg(resp); writeErr != nil {
|
|
if err == nil {
|
|
err = writeErr
|
|
} else {
|
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
|
|
}
|
|
}
|
|
return errors.Wrap(err, "[GetBlockHashes]")
|
|
}
|
|
|
|
func (st *syncStream) handleGetBlocksByNumRequest(rid uint64, req *syncpb.GetBlocksByNumRequest) error {
|
|
serverRequestCounterVec.With(prometheus.Labels{
|
|
"topic": string(st.ProtoID()),
|
|
"request_type": "getBlocksByNumber",
|
|
}).Inc()
|
|
|
|
resp, err := st.computeRespFromBlockNumber(rid, req.Nums)
|
|
if resp == nil && err != nil {
|
|
resp = syncpb.MakeErrorResponseMessage(rid, err)
|
|
}
|
|
if writeErr := st.writeMsg(resp); writeErr != nil {
|
|
if err == nil {
|
|
err = writeErr
|
|
} else {
|
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
|
|
}
|
|
}
|
|
return errors.Wrap(err, "[GetBlocksByNumber]")
|
|
}
|
|
|
|
func (st *syncStream) handleGetBlocksByHashesRequest(rid uint64, req *syncpb.GetBlocksByHashesRequest) error {
|
|
serverRequestCounterVec.With(prometheus.Labels{
|
|
"topic": string(st.ProtoID()),
|
|
"request_type": "getBlocksByHashes",
|
|
}).Inc()
|
|
|
|
hashes := bytesToHashes(req.BlockHashes)
|
|
resp, err := st.computeRespFromBlockHashes(rid, hashes)
|
|
if resp == nil && err != nil {
|
|
resp = syncpb.MakeErrorResponseMessage(rid, err)
|
|
}
|
|
if writeErr := st.writeMsg(resp); writeErr != nil {
|
|
if err == nil {
|
|
err = writeErr
|
|
} else {
|
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
|
|
}
|
|
}
|
|
return errors.Wrap(err, "[GetBlocksByHashes]")
|
|
}
|
|
|
|
func (st *syncStream) handleGetNodeDataRequest(rid uint64, req *syncpb.GetNodeDataRequest) error {
|
|
serverRequestCounterVec.With(prometheus.Labels{
|
|
"topic": string(st.ProtoID()),
|
|
"request_type": "getNodeData",
|
|
}).Inc()
|
|
|
|
hashes := bytesToHashes(req.NodeHashes)
|
|
resp, err := st.computeGetNodeData(rid, hashes)
|
|
if resp == nil && err != nil {
|
|
resp = syncpb.MakeErrorResponseMessage(rid, err)
|
|
}
|
|
if writeErr := st.writeMsg(resp); writeErr != nil {
|
|
if err == nil {
|
|
err = writeErr
|
|
} else {
|
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
|
|
}
|
|
}
|
|
return errors.Wrap(err, "[GetNodeData]")
|
|
}
|
|
|
|
func (st *syncStream) handleGetReceiptsRequest(rid uint64, req *syncpb.GetReceiptsRequest) error {
|
|
serverRequestCounterVec.With(prometheus.Labels{
|
|
"topic": string(st.ProtoID()),
|
|
"request_type": "getReceipts",
|
|
}).Inc()
|
|
|
|
hashes := bytesToHashes(req.BlockHashes)
|
|
resp, err := st.computeGetReceipts(rid, hashes)
|
|
if resp == nil && err != nil {
|
|
resp = syncpb.MakeErrorResponseMessage(rid, err)
|
|
}
|
|
if writeErr := st.writeMsg(resp); writeErr != nil {
|
|
if err == nil {
|
|
err = writeErr
|
|
} else {
|
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
|
|
}
|
|
}
|
|
return errors.Wrap(err, "[GetReceipts]")
|
|
}
|
|
|
|
func (st *syncStream) handleGetAccountRangeRequest(rid uint64, req *syncpb.GetAccountRangeRequest) error {
|
|
serverRequestCounterVec.With(prometheus.Labels{
|
|
"topic": string(st.ProtoID()),
|
|
"request_type": "getAccountRangeRequest",
|
|
}).Inc()
|
|
|
|
root := common.BytesToHash(req.Root)
|
|
origin := common.BytesToHash(req.Origin)
|
|
limit := common.BytesToHash(req.Limit)
|
|
resp, err := st.computeGetAccountRangeRequest(rid, root, origin, limit, req.Bytes)
|
|
if resp == nil && err != nil {
|
|
resp = syncpb.MakeErrorResponseMessage(rid, err)
|
|
}
|
|
if writeErr := st.writeMsg(resp); writeErr != nil {
|
|
if err == nil {
|
|
err = writeErr
|
|
} else {
|
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
|
|
}
|
|
}
|
|
return errors.Wrap(err, "[GetAccountRange]")
|
|
}
|
|
|
|
func (st *syncStream) handleGetStorageRangesRequest(rid uint64, req *syncpb.GetStorageRangesRequest) error {
|
|
serverRequestCounterVec.With(prometheus.Labels{
|
|
"topic": string(st.ProtoID()),
|
|
"request_type": "getStorageRangesRequest",
|
|
}).Inc()
|
|
|
|
root := common.BytesToHash(req.Root)
|
|
accounts := bytesToHashes(req.Accounts)
|
|
origin := common.BytesToHash(req.Origin)
|
|
limit := common.BytesToHash(req.Limit)
|
|
resp, err := st.computeGetStorageRangesRequest(rid, root, accounts, origin, limit, req.Bytes)
|
|
if resp == nil && err != nil {
|
|
resp = syncpb.MakeErrorResponseMessage(rid, err)
|
|
}
|
|
if writeErr := st.writeMsg(resp); writeErr != nil {
|
|
if err == nil {
|
|
err = writeErr
|
|
} else {
|
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
|
|
}
|
|
}
|
|
return errors.Wrap(err, "[GetStorageRanges]")
|
|
}
|
|
|
|
func (st *syncStream) handleGetByteCodesRequest(rid uint64, req *syncpb.GetByteCodesRequest) error {
|
|
serverRequestCounterVec.With(prometheus.Labels{
|
|
"topic": string(st.ProtoID()),
|
|
"request_type": "getByteCodesRequest",
|
|
}).Inc()
|
|
|
|
hashes := bytesToHashes(req.Hashes)
|
|
resp, err := st.computeGetByteCodesRequest(rid, hashes, req.Bytes)
|
|
if resp == nil && err != nil {
|
|
resp = syncpb.MakeErrorResponseMessage(rid, err)
|
|
}
|
|
if writeErr := st.writeMsg(resp); writeErr != nil {
|
|
if err == nil {
|
|
err = writeErr
|
|
} else {
|
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
|
|
}
|
|
}
|
|
return errors.Wrap(err, "[GetByteCodes]")
|
|
}
|
|
|
|
func (st *syncStream) handleGetTrieNodesRequest(rid uint64, req *syncpb.GetTrieNodesRequest) error {
|
|
serverRequestCounterVec.With(prometheus.Labels{
|
|
"topic": string(st.ProtoID()),
|
|
"request_type": "getTrieNodesRequest",
|
|
}).Inc()
|
|
|
|
root := common.BytesToHash(req.Root)
|
|
resp, err := st.computeGetTrieNodesRequest(rid, root, req.Paths, req.Bytes)
|
|
if resp == nil && err != nil {
|
|
resp = syncpb.MakeErrorResponseMessage(rid, err)
|
|
}
|
|
if writeErr := st.writeMsg(resp); writeErr != nil {
|
|
if err == nil {
|
|
err = writeErr
|
|
} else {
|
|
err = fmt.Errorf("%v; [writeMsg] %v", err.Error(), writeErr)
|
|
}
|
|
}
|
|
return errors.Wrap(err, "[GetTrieNodes]")
|
|
}
|
|
|
|
func (st *syncStream) handleUnknownRequest(rid uint64) error {
|
|
serverRequestCounterVec.With(prometheus.Labels{
|
|
"topic": string(st.ProtoID()),
|
|
"request_type": "unknown",
|
|
}).Inc()
|
|
resp := syncpb.MakeErrorResponseMessage(rid, errUnknownReqType)
|
|
return st.writeMsg(resp)
|
|
}
|
|
|
|
func (st *syncStream) handleResp(resp *syncpb.Response) {
|
|
st.protocol.rm.DeliverResponse(st.ID(), &syncResponse{resp})
|
|
}
|
|
|
|
func (st *syncStream) readMsg() (*syncpb.Message, error) {
|
|
b, err := st.ReadBytes()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var msg = &syncpb.Message{}
|
|
if err := protobuf.Unmarshal(b, msg); err != nil {
|
|
return nil, err
|
|
}
|
|
return msg, nil
|
|
}
|
|
|
|
func (st *syncStream) writeMsg(msg *syncpb.Message) error {
|
|
b, err := protobuf.Marshal(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return st.WriteBytes(b)
|
|
}
|
|
|
|
func (st *syncStream) computeBlockNumberResp(rid uint64) *syncpb.Message {
|
|
bn := st.chain.getCurrentBlockNumber()
|
|
return syncpb.MakeGetBlockNumberResponseMessage(rid, bn)
|
|
}
|
|
|
|
func (st *syncStream) computeGetBlockHashesResp(rid uint64, bns []uint64) (*syncpb.Message, error) {
|
|
if len(bns) > GetBlockHashesAmountCap {
|
|
err := fmt.Errorf("GetBlockHashes amount exceed cap: %v>%v", len(bns), GetBlockHashesAmountCap)
|
|
return nil, err
|
|
}
|
|
hashes := st.chain.getBlockHashes(bns)
|
|
return syncpb.MakeGetBlockHashesResponseMessage(rid, hashes), nil
|
|
}
|
|
|
|
func (st *syncStream) computeRespFromBlockNumber(rid uint64, bns []uint64) (*syncpb.Message, error) {
|
|
if len(bns) > GetBlocksByNumAmountCap {
|
|
err := fmt.Errorf("GetBlocksByNum amount exceed cap: %v>%v", len(bns), GetBlocksByNumAmountCap)
|
|
return nil, err
|
|
}
|
|
blocks, err := st.chain.getBlocksByNumber(bns)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var (
|
|
blocksBytes = make([][]byte, 0, len(blocks))
|
|
sigs = make([][]byte, 0, len(blocks))
|
|
)
|
|
for _, block := range blocks {
|
|
bb, err := rlp.EncodeToBytes(block)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
blocksBytes = append(blocksBytes, bb)
|
|
|
|
var sig []byte
|
|
if block != nil {
|
|
sig = block.GetCurrentCommitSig()
|
|
}
|
|
sigs = append(sigs, sig)
|
|
}
|
|
return syncpb.MakeGetBlocksByNumResponseMessage(rid, blocksBytes, sigs), nil
|
|
}
|
|
|
|
func (st *syncStream) computeRespFromBlockHashes(rid uint64, hs []common.Hash) (*syncpb.Message, error) {
|
|
if len(hs) > GetBlocksByHashesAmountCap {
|
|
err := fmt.Errorf("GetBlockByHashes amount exceed cap: %v > %v", len(hs), GetBlocksByHashesAmountCap)
|
|
return nil, err
|
|
}
|
|
blocks, err := st.chain.getBlocksByHashes(hs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var (
|
|
blocksBytes = make([][]byte, 0, len(blocks))
|
|
sigs = make([][]byte, 0, len(blocks))
|
|
)
|
|
for _, block := range blocks {
|
|
bb, err := rlp.EncodeToBytes(block)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
blocksBytes = append(blocksBytes, bb)
|
|
|
|
var sig []byte
|
|
if block != nil {
|
|
sig = block.GetCurrentCommitSig()
|
|
}
|
|
sigs = append(sigs, sig)
|
|
}
|
|
return syncpb.MakeGetBlocksByHashesResponseMessage(rid, blocksBytes, sigs), nil
|
|
}
|
|
|
|
func (st *syncStream) computeGetNodeData(rid uint64, hs []common.Hash) (*syncpb.Message, error) {
|
|
if len(hs) > GetNodeDataCap {
|
|
err := fmt.Errorf("GetNodeData amount exceed cap: %v > %v", len(hs), GetNodeDataCap)
|
|
return nil, err
|
|
}
|
|
data, err := st.chain.getNodeData(hs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return syncpb.MakeGetNodeDataResponseMessage(rid, data), nil
|
|
}
|
|
|
|
func (st *syncStream) computeGetReceipts(rid uint64, hs []common.Hash) (*syncpb.Message, error) {
|
|
if len(hs) > GetReceiptsCap {
|
|
err := fmt.Errorf("GetReceipts amount exceed cap: %v > %v", len(hs), GetReceiptsCap)
|
|
return nil, err
|
|
}
|
|
receipts, err := st.chain.getReceipts(hs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var normalizedReceipts = make(map[uint64]*syncpb.Receipts, len(receipts))
|
|
for i, blkReceipts := range receipts {
|
|
normalizedReceipts[uint64(i)] = &syncpb.Receipts{
|
|
ReceiptBytes: make([][]byte, 0),
|
|
}
|
|
for _, receipt := range blkReceipts {
|
|
receiptBytes, err := rlp.EncodeToBytes(receipt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
normalizedReceipts[uint64(i)].ReceiptBytes = append(normalizedReceipts[uint64(i)].ReceiptBytes, receiptBytes)
|
|
}
|
|
}
|
|
return syncpb.MakeGetReceiptsResponseMessage(rid, normalizedReceipts), nil
|
|
}
|
|
|
|
func (st *syncStream) computeGetAccountRangeRequest(rid uint64, root common.Hash, origin common.Hash, limit common.Hash, bytes uint64) (*syncpb.Message, error) {
|
|
if bytes == 0 {
|
|
return nil, fmt.Errorf("zero account ranges bytes requested")
|
|
}
|
|
if bytes > softResponseLimit {
|
|
return nil, fmt.Errorf("requested bytes exceed limit")
|
|
}
|
|
accounts, proof, err := st.chain.getAccountRange(root, origin, limit, bytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return syncpb.MakeGetAccountRangeResponseMessage(rid, accounts, proof), nil
|
|
}
|
|
|
|
func (st *syncStream) computeGetStorageRangesRequest(rid uint64, root common.Hash, accounts []common.Hash, origin common.Hash, limit common.Hash, bytes uint64) (*syncpb.Message, error) {
|
|
if bytes == 0 {
|
|
return nil, fmt.Errorf("zero storage ranges bytes requested")
|
|
}
|
|
if bytes > softResponseLimit {
|
|
return nil, fmt.Errorf("requested bytes exceed limit")
|
|
}
|
|
if len(accounts) > GetStorageRangesRequestCap {
|
|
err := fmt.Errorf("GetStorageRangesRequest amount exceed cap: %v > %v", len(accounts), GetStorageRangesRequestCap)
|
|
return nil, err
|
|
}
|
|
slots, proofs, err := st.chain.getStorageRanges(root, accounts, origin, limit, bytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return syncpb.MakeGetStorageRangesResponseMessage(rid, slots, proofs), nil
|
|
}
|
|
|
|
func (st *syncStream) computeGetByteCodesRequest(rid uint64, hs []common.Hash, bytes uint64) (*syncpb.Message, error) {
|
|
if bytes == 0 {
|
|
return nil, fmt.Errorf("zero byte code bytes requested")
|
|
}
|
|
if bytes > softResponseLimit {
|
|
return nil, fmt.Errorf("requested bytes exceed limit")
|
|
}
|
|
if len(hs) > GetByteCodesRequestCap {
|
|
err := fmt.Errorf("GetByteCodesRequest amount exceed cap: %v > %v", len(hs), GetByteCodesRequestCap)
|
|
return nil, err
|
|
}
|
|
codes, err := st.chain.getByteCodes(hs, bytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return syncpb.MakeGetByteCodesResponseMessage(rid, codes), nil
|
|
}
|
|
|
|
func (st *syncStream) computeGetTrieNodesRequest(rid uint64, root common.Hash, paths []*message.TrieNodePathSet, bytes uint64) (*syncpb.Message, error) {
|
|
if bytes == 0 {
|
|
return nil, fmt.Errorf("zero trie node bytes requested")
|
|
}
|
|
if bytes > softResponseLimit {
|
|
return nil, fmt.Errorf("requested bytes exceed limit")
|
|
}
|
|
if len(paths) > GetTrieNodesRequestCap {
|
|
err := fmt.Errorf("GetTrieNodesRequest amount exceed cap: %v > %v", len(paths), GetTrieNodesRequestCap)
|
|
return nil, err
|
|
}
|
|
nodes, err := st.chain.getTrieNodes(root, paths, bytes, time.Now())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return syncpb.MakeGetTrieNodesResponseMessage(rid, nodes), nil
|
|
}
|
|
|
|
func bytesToHashes(bs [][]byte) []common.Hash {
|
|
hs := make([]common.Hash, 0, len(bs))
|
|
for _, b := range bs {
|
|
var h common.Hash
|
|
copy(h[:], b)
|
|
hs = append(hs, h)
|
|
}
|
|
return hs
|
|
}
|
|
|