diff --git a/p2p/stream/protocols/sync/client.go b/p2p/stream/protocols/sync/client.go index 85730ada6..fc204fd55 100644 --- a/p2p/stream/protocols/sync/client.go +++ b/p2p/stream/protocols/sync/client.go @@ -15,34 +15,6 @@ import ( "github.com/pkg/errors" ) -// FetchNodeData do fetch node data through sync stream protocol. -// Return the state node data as result, and error -func (p *Protocol) FetchNodeData(ctx context.Context, hashes []common.Hash, opts ...Option) (data [][]byte, stid sttypes.StreamID, err error) { - timer := p.doMetricClientRequest("fetchNodeData") - defer p.doMetricPostClientRequest("fetchNodeData", err, timer) - - if len(hashes) == 0 { - err = fmt.Errorf("zero hashes array requested") - return - } - if len(hashes) > GetNodeDataCap { - err = fmt.Errorf("number of node data hashes cap of %v", GetNodeDataCap) - return - } - - req := newGetNodeDataRequest(hashes) - resp, stid, err := p.rm.DoRequest(ctx, req, opts...) - if err != nil { - // At this point, error can be context canceled, context timed out, or waiting queue - // is already full. - return - } - - // Parse and return blocks - data, err = req.getNodeDataResponse(resp) - return -} - // GetBlocksByNumber do getBlocksByNumberRequest through sync stream protocol. // Return the block as result, target stream id, and error func (p *Protocol) GetBlocksByNumber(ctx context.Context, bns []uint64, opts ...Option) (blocks []*types.Block, stid sttypes.StreamID, err error) { @@ -163,6 +135,51 @@ func (p *Protocol) GetBlocksByHashes(ctx context.Context, hs []common.Hash, opts return } +// GetReceipts do getBlocksByHashesRequest through sync stream protocol. +func (p *Protocol) GetReceipts(ctx context.Context, hs []common.Hash, opts ...Option) (receipts []*types.Receipt, stid sttypes.StreamID, err error) { + timer := p.doMetricClientRequest("getReceipts") + defer p.doMetricPostClientRequest("getReceipts", err, timer) + + if len(hs) == 0 { + err = fmt.Errorf("zero receipt hashes requested") + return + } + if len(hs) > GetReceiptsCap { + err = fmt.Errorf("number of requested hashes exceed limit") + return + } + req := newGetReceiptsRequest(hs) + resp, stid, err := p.rm.DoRequest(ctx, req, opts...) + if err != nil { + return + } + receipts, err = req.getReceiptsFromResponse(resp) + return +} + +// GetNodeData do getNodeData through sync stream protocol. +// Return the state node data as result, and error +func (p *Protocol) GetNodeData(ctx context.Context, hs []common.Hash, opts ...Option) (data [][]byte, stid sttypes.StreamID, err error) { + timer := p.doMetricClientRequest("getNodeData") + defer p.doMetricPostClientRequest("getNodeData", err, timer) + + if len(hs) == 0 { + err = fmt.Errorf("zero node data hashes requested") + return + } + if len(hs) > GetNodeDataCap { + err = fmt.Errorf("number of requested hashes exceed limit") + return + } + req := newGetNodeDataRequest(hs) + resp, stid, err := p.rm.DoRequest(ctx, req, opts...) + if err != nil { + return + } + data, err = req.getNodeDataFromResponse(resp) + return +} + // getBlocksByNumberRequest is the request for get block by numbers which implements // sttypes.Request interface type getBlocksByNumberRequest struct { @@ -457,10 +474,10 @@ func (req *getNodeDataRequest) Encode() ([]byte, error) { return protobuf.Marshal(msg) } -func (req *getNodeDataRequest) getNodeDataResponse(resp sttypes.Response) ([][]byte, error) { +func (req *getNodeDataRequest) getNodeDataFromResponse(resp sttypes.Response) ([][]byte, error) { sResp, ok := resp.(*syncResponse) if !ok || sResp == nil { - return nil, errors.New("not sync response for node data") + return nil, errors.New("not sync response") } dataBytes, err := req.parseNodeDataBytes(sResp) if err != nil { @@ -473,11 +490,11 @@ func (req *getNodeDataRequest) parseNodeDataBytes(resp *syncResponse) ([][]byte, if errResp := resp.pb.GetErrorResponse(); errResp != nil { return nil, errors.New(errResp.Error) } - gbResp := resp.pb.GetGetNodeDataResponse() - if gbResp == nil { - return nil, errors.New("The response is not for GetNodeData") + ndResp := resp.pb.GetGetNodeDataResponse() + if ndResp == nil { + return nil, errors.New("response not GetNodeData") } - return gbResp.DataBytes, nil + return ndResp.DataBytes, nil } // getReceiptsRequest is the request for get receipts which implements @@ -521,25 +538,33 @@ func (req *getReceiptsRequest) Encode() ([]byte, error) { return protobuf.Marshal(msg) } -func (req *getReceiptsRequest) getReceiptsResponse(resp sttypes.Response) ([][]byte, error) { +func (req *getReceiptsRequest) getReceiptsFromResponse(resp sttypes.Response) ([]*types.Receipt, error) { sResp, ok := resp.(*syncResponse) if !ok || sResp == nil { - return nil, errors.New("not sync response for receipts") + return nil, errors.New("not sync response") } - receiptsBytes, err := req.parseGetReceiptsBytes(sResp) + receipts, err := req.parseGetReceiptsBytes(sResp) if err != nil { return nil, err } - return receiptsBytes, nil + return receipts, nil } -func (req *getReceiptsRequest) parseGetReceiptsBytes(resp *syncResponse) ([][]byte, error) { +func (req *getReceiptsRequest) parseGetReceiptsBytes(resp *syncResponse) ([]*types.Receipt, error) { if errResp := resp.pb.GetErrorResponse(); errResp != nil { return nil, errors.New(errResp.Error) } - gbResp := resp.pb.GetGetReceiptsResponse() - if gbResp == nil { - return nil, errors.New("The response is not for GetReceipts") + grResp := resp.pb.GetGetReceiptsResponse() + if grResp == nil { + return nil, errors.New("response not GetReceipts") + } + receipts := make([]*types.Receipt, 0, len(grResp.ReceiptsBytes)) + for _, rcptBytes := range grResp.ReceiptsBytes { + var receipt *types.Receipt + if err := rlp.DecodeBytes(rcptBytes, &receipt); err != nil { + return nil, errors.Wrap(err, "[GetReceiptsResponse]") + } + receipts = append(receipts, receipt) } - return gbResp.ReceiptsBytes, nil + return receipts, nil } diff --git a/p2p/stream/protocols/sync/client_test.go b/p2p/stream/protocols/sync/client_test.go index 3509915ce..9a675fa0d 100644 --- a/p2p/stream/protocols/sync/client_test.go +++ b/p2p/stream/protocols/sync/client_test.go @@ -22,6 +22,7 @@ import ( var ( _ sttypes.Request = &getBlocksByNumberRequest{} _ sttypes.Request = &getBlockNumberRequest{} + _ sttypes.Request = &getReceiptsRequest{} _ sttypes.Response = &syncResponse{&syncpb.Response{}} ) @@ -35,11 +36,20 @@ var ( ) var ( - testHeader = &block.Header{Header: headerV3.NewHeader()} - testBlock = types.NewBlockWithHeader(testHeader) - testHeaderBytes, _ = rlp.EncodeToBytes(testHeader) - testBlockBytes, _ = rlp.EncodeToBytes(testBlock) - testBlockResponse = syncpb.MakeGetBlocksByNumResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1)) + testHeader = &block.Header{Header: headerV3.NewHeader()} + testBlock = types.NewBlockWithHeader(testHeader) + testReceipt = &types.Receipt{ + Status: types.ReceiptStatusSuccessful, + CumulativeGasUsed: 0x888888888, + Logs: []*types.Log{}, + } + testNodeData = numberToHash(123456789).Bytes() + testHeaderBytes, _ = rlp.EncodeToBytes(testHeader) + testBlockBytes, _ = rlp.EncodeToBytes(testBlock) + testReceiptBytes, _ = rlp.EncodeToBytes(testReceipt) + testNodeDataBytes, _ = rlp.EncodeToBytes(testNodeData) + + testBlockResponse = syncpb.MakeGetBlocksByNumResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1)) testCurBlockNumber uint64 = 100 testBlockNumberResponse = syncpb.MakeGetBlockNumberResponse(0, testCurBlockNumber) @@ -49,6 +59,10 @@ var ( testBlocksByHashesResponse = syncpb.MakeGetBlocksByHashesResponse(0, [][]byte{testBlockBytes}, make([][]byte, 1)) + testReceiptResponse = syncpb.MakeGetReceiptsResponse(0, [][]byte{testReceiptBytes}) + + testNodeDataResponse = syncpb.MakeGetNodeDataResponse(0, [][]byte{testNodeDataBytes}) + testErrorResponse = syncpb.MakeErrorResponse(0, errors.New("test error")) ) @@ -289,6 +303,124 @@ func TestProtocol_GetBlocksByHashes(t *testing.T) { } } +func TestProtocol_GetReceipts(t *testing.T) { + tests := []struct { + getResponse getResponseFn + expErr error + expStID sttypes.StreamID + }{ + { + getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { + return &syncResponse{ + pb: testReceiptResponse, + }, makeTestStreamID(0) + }, + expErr: nil, + expStID: makeTestStreamID(0), + }, + { + getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { + return &syncResponse{ + pb: testBlockResponse, + }, makeTestStreamID(0) + }, + expErr: errors.New("response not GetReceipts"), + expStID: makeTestStreamID(0), + }, + { + getResponse: nil, + expErr: errors.New("get response error"), + expStID: "", + }, + { + getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { + return &syncResponse{ + pb: testErrorResponse, + }, makeTestStreamID(0) + }, + expErr: errors.New("test error"), + expStID: makeTestStreamID(0), + }, + } + + for i, test := range tests { + protocol := makeTestProtocol(test.getResponse) + receipts, stid, err := protocol.GetReceipts(context.Background(), []common.Hash{numberToHash(100)}) + + if assErr := assertError(err, test.expErr); assErr != nil { + t.Errorf("Test %v: %v", i, assErr) + continue + } + if stid != test.expStID { + t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID) + } + if test.expErr == nil { + if len(receipts) != 1 { + t.Errorf("Test %v: size not 1", i) + } + } + } +} + +func TestProtocol_GetNodeData(t *testing.T) { + tests := []struct { + getResponse getResponseFn + expErr error + expStID sttypes.StreamID + }{ + { + getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { + return &syncResponse{ + pb: testNodeDataResponse, + }, makeTestStreamID(0) + }, + expErr: nil, + expStID: makeTestStreamID(0), + }, + { + getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { + return &syncResponse{ + pb: testBlockResponse, + }, makeTestStreamID(0) + }, + expErr: errors.New("response not GetNodeData"), + expStID: makeTestStreamID(0), + }, + { + getResponse: nil, + expErr: errors.New("get response error"), + expStID: "", + }, + { + getResponse: func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) { + return &syncResponse{ + pb: testErrorResponse, + }, makeTestStreamID(0) + }, + expErr: errors.New("test error"), + expStID: makeTestStreamID(0), + }, + } + + for i, test := range tests { + protocol := makeTestProtocol(test.getResponse) + receipts, stid, err := protocol.GetNodeData(context.Background(), []common.Hash{numberToHash(100)}) + + if assErr := assertError(err, test.expErr); assErr != nil { + t.Errorf("Test %v: %v", i, assErr) + continue + } + if stid != test.expStID { + t.Errorf("Test %v: unexpected st id: %v / %v", i, stid, test.expStID) + } + if test.expErr == nil { + if len(receipts) != 1 { + t.Errorf("Test %v: size not 1", i) + } + } + } +} + type getResponseFn func(request sttypes.Request) (sttypes.Response, sttypes.StreamID) type testHostRequestManager struct {