add tests for client new functionalities and fix the receipt and node data functions in protocol

pull/4452/head
“GheisMohammadi” 1 year ago
parent ba6e516072
commit 7f131d84ce
No known key found for this signature in database
GPG Key ID: 15073AED3829FE90
  1. 111
      p2p/stream/protocols/sync/client.go
  2. 132
      p2p/stream/protocols/sync/client_test.go

@ -15,34 +15,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// FetchNodeData do fetch node data through sync stream protocol.
// Return the state node data as result, and error
func (p *Protocol) FetchNodeData(ctx context.Context, hashes []common.Hash, opts ...Option) (data [][]byte, stid sttypes.StreamID, err error) {
timer := p.doMetricClientRequest("fetchNodeData")
defer p.doMetricPostClientRequest("fetchNodeData", err, timer)
if len(hashes) == 0 {
err = fmt.Errorf("zero hashes array requested")
return
}
if len(hashes) > GetNodeDataCap {
err = fmt.Errorf("number of node data hashes cap of %v", GetNodeDataCap)
return
}
req := newGetNodeDataRequest(hashes)
resp, stid, err := p.rm.DoRequest(ctx, req, opts...)
if err != nil {
// At this point, error can be context canceled, context timed out, or waiting queue
// is already full.
return
}
// Parse and return blocks
data, err = req.getNodeDataResponse(resp)
return
}
// GetBlocksByNumber do getBlocksByNumberRequest through sync stream protocol. // GetBlocksByNumber do getBlocksByNumberRequest through sync stream protocol.
// Return the block as result, target stream id, and error // Return the block as result, target stream id, and error
func (p *Protocol) GetBlocksByNumber(ctx context.Context, bns []uint64, opts ...Option) (blocks []*types.Block, stid sttypes.StreamID, err error) { func (p *Protocol) GetBlocksByNumber(ctx context.Context, bns []uint64, opts ...Option) (blocks []*types.Block, stid sttypes.StreamID, err error) {
@ -163,6 +135,51 @@ func (p *Protocol) GetBlocksByHashes(ctx context.Context, hs []common.Hash, opts
return return
} }
// GetReceipts do getBlocksByHashesRequest through sync stream protocol.
func (p *Protocol) GetReceipts(ctx context.Context, hs []common.Hash, opts ...Option) (receipts []*types.Receipt, stid sttypes.StreamID, err error) {
timer := p.doMetricClientRequest("getReceipts")
defer p.doMetricPostClientRequest("getReceipts", err, timer)
if len(hs) == 0 {
err = fmt.Errorf("zero receipt hashes requested")
return
}
if len(hs) > GetReceiptsCap {
err = fmt.Errorf("number of requested hashes exceed limit")
return
}
req := newGetReceiptsRequest(hs)
resp, stid, err := p.rm.DoRequest(ctx, req, opts...)
if err != nil {
return
}
receipts, err = req.getReceiptsFromResponse(resp)
return
}
// GetNodeData do getNodeData through sync stream protocol.
// Return the state node data as result, and error
func (p *Protocol) GetNodeData(ctx context.Context, hs []common.Hash, opts ...Option) (data [][]byte, stid sttypes.StreamID, err error) {
timer := p.doMetricClientRequest("getNodeData")
defer p.doMetricPostClientRequest("getNodeData", err, timer)
if len(hs) == 0 {
err = fmt.Errorf("zero node data hashes requested")
return
}
if len(hs) > GetNodeDataCap {
err = fmt.Errorf("number of requested hashes exceed limit")
return
}
req := newGetNodeDataRequest(hs)
resp, stid, err := p.rm.DoRequest(ctx, req, opts...)
if err != nil {
return
}
data, err = req.getNodeDataFromResponse(resp)
return
}
// getBlocksByNumberRequest is the request for get block by numbers which implements // getBlocksByNumberRequest is the request for get block by numbers which implements
// sttypes.Request interface // sttypes.Request interface
type getBlocksByNumberRequest struct { type getBlocksByNumberRequest struct {
@ -457,10 +474,10 @@ func (req *getNodeDataRequest) Encode() ([]byte, error) {
return protobuf.Marshal(msg) return protobuf.Marshal(msg)
} }
func (req *getNodeDataRequest) getNodeDataResponse(resp sttypes.Response) ([][]byte, error) { func (req *getNodeDataRequest) getNodeDataFromResponse(resp sttypes.Response) ([][]byte, error) {
sResp, ok := resp.(*syncResponse) sResp, ok := resp.(*syncResponse)
if !ok || sResp == nil { if !ok || sResp == nil {
return nil, errors.New("not sync response for node data") return nil, errors.New("not sync response")
} }
dataBytes, err := req.parseNodeDataBytes(sResp) dataBytes, err := req.parseNodeDataBytes(sResp)
if err != nil { if err != nil {
@ -473,11 +490,11 @@ func (req *getNodeDataRequest) parseNodeDataBytes(resp *syncResponse) ([][]byte,
if errResp := resp.pb.GetErrorResponse(); errResp != nil { if errResp := resp.pb.GetErrorResponse(); errResp != nil {
return nil, errors.New(errResp.Error) return nil, errors.New(errResp.Error)
} }
gbResp := resp.pb.GetGetNodeDataResponse() ndResp := resp.pb.GetGetNodeDataResponse()
if gbResp == nil { if ndResp == nil {
return nil, errors.New("The response is not for GetNodeData") return nil, errors.New("response not GetNodeData")
} }
return gbResp.DataBytes, nil return ndResp.DataBytes, nil
} }
// getReceiptsRequest is the request for get receipts which implements // getReceiptsRequest is the request for get receipts which implements
@ -521,25 +538,33 @@ func (req *getReceiptsRequest) Encode() ([]byte, error) {
return protobuf.Marshal(msg) return protobuf.Marshal(msg)
} }
func (req *getReceiptsRequest) getReceiptsResponse(resp sttypes.Response) ([][]byte, error) { func (req *getReceiptsRequest) getReceiptsFromResponse(resp sttypes.Response) ([]*types.Receipt, error) {
sResp, ok := resp.(*syncResponse) sResp, ok := resp.(*syncResponse)
if !ok || sResp == nil { if !ok || sResp == nil {
return nil, errors.New("not sync response for receipts") return nil, errors.New("not sync response")
} }
receiptsBytes, err := req.parseGetReceiptsBytes(sResp) receipts, err := req.parseGetReceiptsBytes(sResp)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return receiptsBytes, nil return receipts, nil
} }
func (req *getReceiptsRequest) parseGetReceiptsBytes(resp *syncResponse) ([][]byte, error) { func (req *getReceiptsRequest) parseGetReceiptsBytes(resp *syncResponse) ([]*types.Receipt, error) {
if errResp := resp.pb.GetErrorResponse(); errResp != nil { if errResp := resp.pb.GetErrorResponse(); errResp != nil {
return nil, errors.New(errResp.Error) return nil, errors.New(errResp.Error)
} }
gbResp := resp.pb.GetGetReceiptsResponse() grResp := resp.pb.GetGetReceiptsResponse()
if gbResp == nil { if grResp == nil {
return nil, errors.New("The response is not for GetReceipts") return nil, errors.New("response not GetReceipts")
}
receipts := make([]*types.Receipt, 0, len(grResp.ReceiptsBytes))
for _, rcptBytes := range grResp.ReceiptsBytes {
var receipt *types.Receipt
if err := rlp.DecodeBytes(rcptBytes, &receipt); err != nil {
return nil, errors.Wrap(err, "[GetReceiptsResponse]")
}
receipts = append(receipts, receipt)
} }
return gbResp.ReceiptsBytes, nil return receipts, nil
} }

@ -22,6 +22,7 @@ import (
var ( var (
_ sttypes.Request = &getBlocksByNumberRequest{} _ sttypes.Request = &getBlocksByNumberRequest{}
_ sttypes.Request = &getBlockNumberRequest{} _ sttypes.Request = &getBlockNumberRequest{}
_ sttypes.Request = &getReceiptsRequest{}
_ sttypes.Response = &syncResponse{&syncpb.Response{}} _ sttypes.Response = &syncResponse{&syncpb.Response{}}
) )
@ -37,8 +38,17 @@ var (
var ( var (
testHeader = &block.Header{Header: headerV3.NewHeader()} testHeader = &block.Header{Header: headerV3.NewHeader()}
testBlock = types.NewBlockWithHeader(testHeader) testBlock = types.NewBlockWithHeader(testHeader)
testReceipt = &types.Receipt{
Status: types.ReceiptStatusSuccessful,
CumulativeGasUsed: 0x888888888,
Logs: []*types.Log{},
}
testNodeData = numberToHash(123456789).Bytes()
testHeaderBytes, _ = rlp.EncodeToBytes(testHeader) testHeaderBytes, _ = rlp.EncodeToBytes(testHeader)
testBlockBytes, _ = rlp.EncodeToBytes(testBlock) testBlockBytes, _ = rlp.EncodeToBytes(testBlock)
testReceiptBytes, _ = rlp.EncodeToBytes(testReceipt)
testNodeDataBytes, _ = rlp.EncodeToBytes(testNodeData)
testBlockResponse = syncpb.MakeGetBlocksByNumResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1)) testBlockResponse = syncpb.MakeGetBlocksByNumResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1))
testCurBlockNumber uint64 = 100 testCurBlockNumber uint64 = 100
@ -49,6 +59,10 @@ var (
testBlocksByHashesResponse = syncpb.MakeGetBlocksByHashesResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1)) testBlocksByHashesResponse = syncpb.MakeGetBlocksByHashesResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1))
testReceiptResponse = syncpb.MakeGetReceiptsResponse(0, [][]byte{testReceiptBytes})
testNodeDataResponse = syncpb.MakeGetNodeDataResponse(0, [][]byte{testNodeDataBytes})
testErrorResponse = syncpb.MakeErrorResponse(0, errors.New("test error")) testErrorResponse = syncpb.MakeErrorResponse(0, errors.New("test error"))
) )
@ -289,6 +303,124 @@ func TestProtocol_GetBlocksByHashes(t *testing.T) {
} }
} }
func TestProtocol_GetReceipts(t *testing.T) {
tests := []struct {
getResponse getResponseFn
expErr error
expStID sttypes.StreamID
}{
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testReceiptResponse,
}, makeTestStreamID(0)
},
expErr: nil,
expStID: makeTestStreamID(0),
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testBlockResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("response not GetReceipts"),
expStID: makeTestStreamID(0),
},
{
getResponse: nil,
expErr: errors.New("get response error"),
expStID: "",
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testErrorResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("test error"),
expStID: makeTestStreamID(0),
},
}
for i, test := range tests {
protocol := makeTestProtocol(test.getResponse)
receipts, stid, err := protocol.GetReceipts(context.Background(), []common.Hash{numberToHash(100)})
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
continue
}
if stid != test.expStID {
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
}
if test.expErr == nil {
if len(receipts) != 1 {
t.Errorf("Test %v: size not 1", i)
}
}
}
}
func TestProtocol_GetNodeData(t *testing.T) {
tests := []struct {
getResponse getResponseFn
expErr error
expStID sttypes.StreamID
}{
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testNodeDataResponse,
}, makeTestStreamID(0)
},
expErr: nil,
expStID: makeTestStreamID(0),
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testBlockResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("response not GetNodeData"),
expStID: makeTestStreamID(0),
},
{
getResponse: nil,
expErr: errors.New("get response error"),
expStID: "",
},
{
getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) {
return &syncResponse{
pb: testErrorResponse,
}, makeTestStreamID(0)
},
expErr: errors.New("test error"),
expStID: makeTestStreamID(0),
},
}
for i, test := range tests {
protocol := makeTestProtocol(test.getResponse)
receipts, stid, err := protocol.GetNodeData(context.Background(), []common.Hash{numberToHash(100)})
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
continue
}
if stid != test.expStID {
t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID)
}
if test.expErr == nil {
if len(receipts) != 1 {
t.Errorf("Test %v: size not 1", i)
}
}
}
}
type getResponseFn func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) type getResponseFn func(request sttypes.Request) (sttypes.Response, sttypes.StreamID)
type testHostRequestManager struct { type testHostRequestManager struct {

Loading…
Cancel
Save