parent
e08c2ce41e
commit
fbed5bcf3b
@ -0,0 +1,19 @@ |
||||
package requestmanager |
||||
|
||||
import "time" |
||||
|
||||
// TODO: determine the values in production environment
|
||||
const ( |
||||
// throttle to do request every 100 milliseconds
|
||||
throttleInterval = 100 * time.Millisecond |
||||
|
||||
// number of request to be done in each throttle loop
|
||||
throttleBatch = 16 |
||||
|
||||
// deliverTimeout is the timeout for a response delivery. If the response cannot be delivered
|
||||
// within timeout because blocking of the channel, the response will be dropped.
|
||||
deliverTimeout = 5 * time.Second |
||||
|
||||
// maxWaitingSize is the maximum requests that are in waiting list
|
||||
maxWaitingSize = 1024 |
||||
) |
@ -0,0 +1,25 @@ |
||||
package requestmanager |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
p2ptypes "github.com/harmony-one/harmony/p2p/types" |
||||
) |
||||
|
||||
// Requester is the interface to do request
|
||||
type Requester interface { |
||||
DoRequest(ctx context.Context, request sttypes.Request, options ...RequestOption) (sttypes.Response, sttypes.StreamID, error) |
||||
} |
||||
|
||||
// Deliverer is the interface to deliver a response
|
||||
type Deliverer interface { |
||||
DeliverResponse(stID sttypes.StreamID, resp sttypes.Response) |
||||
} |
||||
|
||||
// RequestManager manages over the requests
|
||||
type RequestManager interface { |
||||
p2ptypes.LifeCycle |
||||
Requester |
||||
Deliverer |
||||
} |
@ -0,0 +1,171 @@ |
||||
package requestmanager |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"strconv" |
||||
|
||||
"github.com/ethereum/go-ethereum/event" |
||||
"github.com/ethereum/go-ethereum/rlp" |
||||
"github.com/harmony-one/harmony/p2p/stream/common/streammanager" |
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
) |
||||
|
||||
var testProtoID = sttypes.ProtoID("harmony/sync/unitest/0/1.0.0") |
||||
|
||||
type testStreamManager struct { |
||||
newStreamFeed event.Feed |
||||
rmStreamFeed event.Feed |
||||
} |
||||
|
||||
func newTestStreamManager() *testStreamManager { |
||||
return &testStreamManager{} |
||||
} |
||||
|
||||
func (sm *testStreamManager) addNewStream(st sttypes.Stream) { |
||||
sm.newStreamFeed.Send(streammanager.EvtStreamAdded{Stream: st}) |
||||
} |
||||
|
||||
func (sm *testStreamManager) rmStream(stid sttypes.StreamID) { |
||||
sm.rmStreamFeed.Send(streammanager.EvtStreamRemoved{ID: stid}) |
||||
} |
||||
|
||||
func (sm *testStreamManager) SubscribeAddStreamEvent(ch chan<- streammanager.EvtStreamAdded) event.Subscription { |
||||
return sm.newStreamFeed.Subscribe(ch) |
||||
} |
||||
|
||||
func (sm *testStreamManager) SubscribeRemoveStreamEvent(ch chan<- streammanager.EvtStreamRemoved) event.Subscription { |
||||
return sm.rmStreamFeed.Subscribe(ch) |
||||
} |
||||
|
||||
type testStream struct { |
||||
id sttypes.StreamID |
||||
rm *requestManager |
||||
deliver func(*testRequest) // use goroutine inside this function
|
||||
} |
||||
|
||||
func (st *testStream) ID() sttypes.StreamID { |
||||
return st.id |
||||
} |
||||
|
||||
func (st *testStream) ProtoID() sttypes.ProtoID { |
||||
return testProtoID |
||||
} |
||||
|
||||
func (st *testStream) WriteBytes(b []byte) error { |
||||
req, err := decodeTestRequest(b) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if st.rm != nil && st.deliver != nil { |
||||
st.deliver(req) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (st *testStream) ReadBytes() ([]byte, error) { |
||||
return nil, nil |
||||
} |
||||
|
||||
func (st *testStream) ProtoSpec() (sttypes.ProtoSpec, error) { |
||||
return sttypes.ProtoIDToProtoSpec(testProtoID) |
||||
} |
||||
|
||||
func (st *testStream) Close() error { |
||||
return nil |
||||
} |
||||
|
||||
func (st *testStream) ResetOnClose() error { |
||||
return nil |
||||
} |
||||
|
||||
func makeStreamID(index int) sttypes.StreamID { |
||||
return sttypes.StreamID(strconv.Itoa(index)) |
||||
} |
||||
|
||||
type testRequest struct { |
||||
reqID uint64 |
||||
index uint64 |
||||
} |
||||
|
||||
func makeTestRequest(index uint64) *testRequest { |
||||
return &testRequest{ |
||||
reqID: 0, |
||||
index: index, |
||||
} |
||||
} |
||||
|
||||
func (req *testRequest) ReqID() uint64 { |
||||
return req.reqID |
||||
} |
||||
|
||||
func (req *testRequest) SetReqID(rid uint64) { |
||||
req.reqID = rid |
||||
} |
||||
|
||||
func (req *testRequest) String() string { |
||||
return fmt.Sprintf("test request %v", req.index) |
||||
} |
||||
|
||||
func (req *testRequest) Encode() ([]byte, error) { |
||||
return rlp.EncodeToBytes(struct { |
||||
ReqID uint64 |
||||
Index uint64 |
||||
}{ |
||||
ReqID: req.reqID, |
||||
Index: req.index, |
||||
}) |
||||
} |
||||
|
||||
func (req *testRequest) checkResponse(rawResp sttypes.Response) error { |
||||
resp, ok := rawResp.(*testResponse) |
||||
if !ok || resp == nil { |
||||
return errors.New("not test Response") |
||||
} |
||||
if req.reqID != resp.reqID { |
||||
return errors.New("request id not expected") |
||||
} |
||||
if req.index != resp.index { |
||||
return errors.New("response id not expected") |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func decodeTestRequest(b []byte) (*testRequest, error) { |
||||
type SerRequest struct { |
||||
ReqID uint64 |
||||
Index uint64 |
||||
} |
||||
var sr SerRequest |
||||
if err := rlp.DecodeBytes(b, &sr); err != nil { |
||||
return nil, err |
||||
} |
||||
return &testRequest{ |
||||
reqID: sr.ReqID, |
||||
index: sr.Index, |
||||
}, nil |
||||
} |
||||
|
||||
func (req *testRequest) IsSupportedByProto(spec sttypes.ProtoSpec) bool { |
||||
return true |
||||
} |
||||
|
||||
func (req *testRequest) getResponse() *testResponse { |
||||
return &testResponse{ |
||||
reqID: req.reqID, |
||||
index: req.index, |
||||
} |
||||
} |
||||
|
||||
type testResponse struct { |
||||
reqID uint64 |
||||
index uint64 |
||||
} |
||||
|
||||
func (tr *testResponse) ReqID() uint64 { |
||||
return tr.reqID |
||||
} |
||||
|
||||
func (tr *testResponse) String() string { |
||||
return fmt.Sprintf("test response %v", tr.index) |
||||
} |
@ -0,0 +1,39 @@ |
||||
package requestmanager |
||||
|
||||
import sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
|
||||
// RequestOption is the additional instruction for requests.
|
||||
// Currently, two options are supported:
|
||||
// 1. WithHighPriority
|
||||
// 2. WithBlacklist
|
||||
// 3. WithWhitelist
|
||||
type RequestOption func(*request) |
||||
|
||||
// WithHighPriority is the request option to do request with higher priority.
|
||||
// High priority requests are done first.
|
||||
func WithHighPriority() RequestOption { |
||||
return func(req *request) { |
||||
req.priority = reqPriorityHigh |
||||
} |
||||
} |
||||
|
||||
// WithBlacklist is the request option not to assign the request to the blacklisted
|
||||
// stream ID.
|
||||
func WithBlacklist(blacklist []sttypes.StreamID) RequestOption { |
||||
return func(req *request) { |
||||
for _, stid := range blacklist { |
||||
req.addBlacklistedStream(stid) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// WithWhitelist is the request option to restrict the request to be assigned to the
|
||||
// given stream IDs.
|
||||
// If a request is not with this option, all streams will be allowed.
|
||||
func WithWhitelist(whitelist []sttypes.StreamID) RequestOption { |
||||
return func(req *request) { |
||||
for _, stid := range whitelist { |
||||
req.addWhiteListStream(stid) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,410 @@ |
||||
package requestmanager |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/pkg/errors" |
||||
"github.com/rs/zerolog" |
||||
|
||||
"github.com/ethereum/go-ethereum/event" |
||||
"github.com/harmony-one/harmony/internal/utils" |
||||
"github.com/harmony-one/harmony/p2p/stream/common/streammanager" |
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
) |
||||
|
||||
// requestManager implements RequestManager. It is responsible for matching response
|
||||
// with requests.
|
||||
// TODO: each peer is able to have a queue of requests instead of one request at a time.
|
||||
// TODO: add QoS evaluation for each stream
|
||||
type requestManager struct { |
||||
streams map[sttypes.StreamID]*stream // All streams
|
||||
available map[sttypes.StreamID]struct{} // Streams that are available for request
|
||||
pendings map[uint64]*request // requests that are sent but not received response
|
||||
waitings requestQueue // double linked list of requests that are on the waiting list
|
||||
|
||||
// Stream events
|
||||
newStreamC <-chan streammanager.EvtStreamAdded |
||||
rmStreamC <-chan streammanager.EvtStreamRemoved |
||||
// Request events
|
||||
cancelReqC chan cancelReqData // request being canceled
|
||||
deliveryC chan responseData |
||||
newRequestC chan *request |
||||
|
||||
subs []event.Subscription |
||||
logger zerolog.Logger |
||||
stopC chan struct{} |
||||
lock sync.Mutex |
||||
} |
||||
|
||||
// NewRequestManager creates a new request manager
|
||||
func NewRequestManager(sm streammanager.Subscriber) RequestManager { |
||||
return newRequestManager(sm) |
||||
} |
||||
|
||||
func newRequestManager(sm streammanager.Subscriber) *requestManager { |
||||
// subscribe at initialize to prevent misuse of upper function which might cause
|
||||
// the bootstrap peers are ignored
|
||||
newStreamC := make(chan streammanager.EvtStreamAdded) |
||||
rmStreamC := make(chan streammanager.EvtStreamRemoved) |
||||
sub1 := sm.SubscribeAddStreamEvent(newStreamC) |
||||
sub2 := sm.SubscribeRemoveStreamEvent(rmStreamC) |
||||
|
||||
logger := utils.Logger().With().Str("module", "request manager").Logger() |
||||
|
||||
return &requestManager{ |
||||
streams: make(map[sttypes.StreamID]*stream), |
||||
available: make(map[sttypes.StreamID]struct{}), |
||||
pendings: make(map[uint64]*request), |
||||
waitings: newRequestQueue(), |
||||
|
||||
newStreamC: newStreamC, |
||||
rmStreamC: rmStreamC, |
||||
cancelReqC: make(chan cancelReqData, 16), |
||||
deliveryC: make(chan responseData, 128), |
||||
newRequestC: make(chan *request, 128), |
||||
|
||||
subs: []event.Subscription{sub1, sub2}, |
||||
logger: logger, |
||||
stopC: make(chan struct{}), |
||||
} |
||||
} |
||||
|
||||
func (rm *requestManager) Start() { |
||||
go rm.loop() |
||||
} |
||||
|
||||
func (rm *requestManager) Close() { |
||||
rm.stopC <- struct{}{} |
||||
} |
||||
|
||||
// DoRequest do the given request with a stream picked randomly. Return the response, stream id that
|
||||
// is responsible for response, delivery and error.
|
||||
func (rm *requestManager) DoRequest(ctx context.Context, raw sttypes.Request, options ...RequestOption) (sttypes.Response, sttypes.StreamID, error) { |
||||
resp := <-rm.doRequestAsync(ctx, raw, options...) |
||||
return resp.resp, resp.stID, resp.err |
||||
} |
||||
|
||||
func (rm *requestManager) doRequestAsync(ctx context.Context, raw sttypes.Request, options ...RequestOption) <-chan responseData { |
||||
req := &request{ |
||||
Request: raw, |
||||
respC: make(chan responseData), |
||||
doneC: make(chan struct{}), |
||||
} |
||||
for _, opt := range options { |
||||
opt(req) |
||||
} |
||||
rm.newRequestC <- req |
||||
|
||||
go func() { |
||||
select { |
||||
case <-ctx.Done(): // canceled or timeout in upper function calls
|
||||
rm.cancelReqC <- cancelReqData{ |
||||
reqID: req.ReqID(), |
||||
err: ctx.Err(), |
||||
} |
||||
case <-req.doneC: |
||||
} |
||||
}() |
||||
return req.respC |
||||
} |
||||
|
||||
// DeliverResponse delivers the response to the corresponding request.
|
||||
// The function behaves non-block
|
||||
func (rm *requestManager) DeliverResponse(stID sttypes.StreamID, resp sttypes.Response) { |
||||
sd := responseData{ |
||||
resp: resp, |
||||
stID: stID, |
||||
} |
||||
go func() { |
||||
select { |
||||
case rm.deliveryC <- sd: |
||||
case <-time.After(deliverTimeout): |
||||
rm.logger.Error().Msg("WARNING: delivery timeout. Possible stuck in loop") |
||||
} |
||||
}() |
||||
} |
||||
|
||||
func (rm *requestManager) loop() { |
||||
var ( |
||||
throttleC = make(chan struct{}, 1) // throttle the waiting requests periodically
|
||||
ticker = time.NewTicker(throttleInterval) |
||||
) |
||||
throttle := func() { |
||||
select { |
||||
case throttleC <- struct{}{}: |
||||
default: |
||||
} |
||||
} |
||||
|
||||
for { |
||||
select { |
||||
case <-ticker.C: |
||||
throttle() |
||||
|
||||
case <-throttleC: |
||||
loop: |
||||
for i := 0; i != throttleBatch; i++ { |
||||
req, st := rm.getNextRequest() |
||||
if req == nil { |
||||
break loop |
||||
} |
||||
rm.addPendingRequest(req, st) |
||||
b, err := req.Encode() |
||||
if err != nil { |
||||
rm.logger.Warn().Str("request", req.String()).Err(err). |
||||
Msg("request encode error") |
||||
} |
||||
|
||||
go func() { |
||||
if err := st.WriteBytes(b); err != nil { |
||||
rm.logger.Warn().Str("streamID", string(st.ID())).Err(err). |
||||
Msg("write bytes") |
||||
req.doneWithResponse(responseData{ |
||||
stID: st.ID(), |
||||
err: errors.Wrap(err, "write bytes"), |
||||
}) |
||||
} |
||||
}() |
||||
} |
||||
|
||||
case req := <-rm.newRequestC: |
||||
added := rm.handleNewRequest(req) |
||||
if added { |
||||
throttle() |
||||
} |
||||
|
||||
case data := <-rm.deliveryC: |
||||
rm.handleDeliverData(data) |
||||
|
||||
case data := <-rm.cancelReqC: |
||||
rm.handleCancelRequest(data) |
||||
|
||||
case evt := <-rm.newStreamC: |
||||
rm.logger.Info().Str("streamID", string(evt.Stream.ID())).Msg("add new stream") |
||||
rm.addNewStream(evt.Stream) |
||||
|
||||
case evt := <-rm.rmStreamC: |
||||
rm.logger.Info().Str("streamID", string(evt.ID)).Msg("remove stream") |
||||
rm.removeStream(evt.ID) |
||||
|
||||
case <-rm.stopC: |
||||
rm.logger.Info().Msg("request manager stopped") |
||||
rm.close() |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (rm *requestManager) handleNewRequest(req *request) bool { |
||||
rm.lock.Lock() |
||||
defer rm.lock.Unlock() |
||||
|
||||
err := rm.addRequestToWaitings(req, reqPriorityLow) |
||||
if err != nil { |
||||
rm.logger.Warn().Err(err).Msg("failed to add new request to waitings") |
||||
req.doneWithResponse(responseData{ |
||||
err: errors.Wrap(err, "failed to add new request to waitings"), |
||||
}) |
||||
return false |
||||
} |
||||
return true |
||||
} |
||||
|
||||
func (rm *requestManager) handleDeliverData(data responseData) { |
||||
rm.lock.Lock() |
||||
defer rm.lock.Unlock() |
||||
|
||||
if err := rm.validateDelivery(data); err != nil { |
||||
// if error happens in delivery, most likely it's a stale delivery. No action needed
|
||||
// and return
|
||||
rm.logger.Info().Err(err).Str("response", data.resp.String()).Msg("unable to validate deliver") |
||||
return |
||||
} |
||||
// req and st is ensured not to be empty in validateDelivery
|
||||
req := rm.pendings[data.resp.ReqID()] |
||||
req.doneWithResponse(data) |
||||
rm.removePendingRequest(req) |
||||
} |
||||
|
||||
func (rm *requestManager) validateDelivery(data responseData) error { |
||||
if data.err != nil { |
||||
return data.err |
||||
} |
||||
st := rm.streams[data.stID] |
||||
if st == nil { |
||||
return fmt.Errorf("data delivered from dead stream: %v", data.stID) |
||||
} |
||||
req := rm.pendings[data.resp.ReqID()] |
||||
if req == nil { |
||||
return fmt.Errorf("stale p2p response delivery") |
||||
} |
||||
if req.owner == nil || req.owner.ID() != data.stID { |
||||
return fmt.Errorf("unexpected delivery stream") |
||||
} |
||||
if st.req == nil || st.req.ReqID() != data.resp.ReqID() { |
||||
// Possible when request is canceled
|
||||
return fmt.Errorf("unexpected deliver request") |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (rm *requestManager) handleCancelRequest(data cancelReqData) { |
||||
rm.lock.Lock() |
||||
defer rm.lock.Unlock() |
||||
|
||||
req, ok := rm.pendings[data.reqID] |
||||
if !ok { |
||||
return |
||||
} |
||||
rm.removePendingRequest(req) |
||||
|
||||
var stid sttypes.StreamID |
||||
if req.owner != nil { |
||||
stid = req.owner.ID() |
||||
} |
||||
|
||||
req.doneWithResponse(responseData{ |
||||
resp: nil, |
||||
stID: stid, |
||||
err: data.err, |
||||
}) |
||||
} |
||||
|
||||
func (rm *requestManager) getNextRequest() (*request, *stream) { |
||||
rm.lock.Lock() |
||||
defer rm.lock.Unlock() |
||||
|
||||
var req *request |
||||
for { |
||||
req = rm.waitings.Pop() |
||||
if req == nil { |
||||
return nil, nil |
||||
} |
||||
if !req.isDone() { |
||||
break |
||||
} |
||||
} |
||||
|
||||
st, err := rm.pickAvailableStream(req) |
||||
if err != nil { |
||||
rm.logger.Debug().Msg("No available streams.") |
||||
rm.addRequestToWaitings(req, reqPriorityHigh) |
||||
return nil, nil |
||||
} |
||||
return req, st |
||||
} |
||||
|
||||
func (rm *requestManager) genReqID() uint64 { |
||||
for { |
||||
rid := sttypes.GenReqID() |
||||
if _, ok := rm.pendings[rid]; !ok { |
||||
return rid |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (rm *requestManager) addPendingRequest(req *request, st *stream) { |
||||
rm.lock.Lock() |
||||
defer rm.lock.Unlock() |
||||
|
||||
reqID := rm.genReqID() |
||||
req.SetReqID(reqID) |
||||
|
||||
req.owner = st |
||||
st.req = req |
||||
|
||||
delete(rm.available, st.ID()) |
||||
rm.pendings[req.ReqID()] = req |
||||
} |
||||
|
||||
func (rm *requestManager) removePendingRequest(req *request) { |
||||
delete(rm.pendings, req.ReqID()) |
||||
|
||||
if st := req.owner; st != nil { |
||||
st.clearPendingRequest() |
||||
rm.available[st.ID()] = struct{}{} |
||||
} |
||||
} |
||||
|
||||
func (rm *requestManager) pickAvailableStream(req *request) (*stream, error) { |
||||
for id := range rm.available { |
||||
if !req.isStreamAllowed(id) { |
||||
continue |
||||
} |
||||
st, ok := rm.streams[id] |
||||
if !ok { |
||||
return nil, errors.New("sanity error: available stream not registered") |
||||
} |
||||
if st.req != nil { |
||||
return nil, errors.New("sanity error: available stream has pending requests") |
||||
} |
||||
spec, _ := st.ProtoSpec() |
||||
if req.Request.IsSupportedByProto(spec) { |
||||
return st, nil |
||||
} |
||||
} |
||||
return nil, errors.New("no more available streams") |
||||
} |
||||
|
||||
func (rm *requestManager) addNewStream(st sttypes.Stream) { |
||||
rm.lock.Lock() |
||||
defer rm.lock.Unlock() |
||||
|
||||
if _, ok := rm.streams[st.ID()]; !ok { |
||||
rm.streams[st.ID()] = &stream{Stream: st} |
||||
rm.available[st.ID()] = struct{}{} |
||||
} |
||||
} |
||||
|
||||
// removeStream remove the stream from request manager, clear the pending request
|
||||
// of the stream. Return whether a pending request is canceled in the stream,
|
||||
func (rm *requestManager) removeStream(id sttypes.StreamID) { |
||||
rm.lock.Lock() |
||||
defer rm.lock.Unlock() |
||||
|
||||
st, ok := rm.streams[id] |
||||
if !ok { |
||||
return |
||||
} |
||||
delete(rm.available, id) |
||||
delete(rm.streams, id) |
||||
|
||||
cleared := st.clearPendingRequest() |
||||
if cleared != nil { |
||||
cleared.doneWithResponse(responseData{ |
||||
stID: id, |
||||
err: errors.New("stream removed when doing request"), |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func (rm *requestManager) close() { |
||||
rm.lock.Lock() |
||||
defer rm.lock.Unlock() |
||||
|
||||
for _, sub := range rm.subs { |
||||
sub.Unsubscribe() |
||||
} |
||||
for _, req := range rm.pendings { |
||||
req.doneWithResponse(responseData{err: ErrClosed}) |
||||
} |
||||
rm.pendings = make(map[uint64]*request) |
||||
rm.available = make(map[sttypes.StreamID]struct{}) |
||||
rm.streams = make(map[sttypes.StreamID]*stream) |
||||
rm.waitings = newRequestQueue() |
||||
close(rm.stopC) |
||||
} |
||||
|
||||
type reqPriority int |
||||
|
||||
const ( |
||||
reqPriorityLow reqPriority = iota |
||||
reqPriorityHigh |
||||
) |
||||
|
||||
func (rm *requestManager) addRequestToWaitings(req *request, priority reqPriority) error { |
||||
return rm.waitings.Push(req, priority) |
||||
} |
@ -0,0 +1,428 @@ |
||||
package requestmanager |
||||
|
||||
import ( |
||||
"context" |
||||
"sync" |
||||
"sync/atomic" |
||||
"testing" |
||||
"time" |
||||
|
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
var ( |
||||
defTestSleep = 50 * time.Millisecond |
||||
) |
||||
|
||||
// Request is delivered right away as expected
|
||||
func TestRequestManager_Request_Normal(t *testing.T) { |
||||
delayF := makeDefaultDelayFunc(150 * time.Millisecond) |
||||
respF := makeDefaultResponseFunc() |
||||
ts := newTestSuite(delayF, respF, 3) |
||||
ts.Start() |
||||
defer ts.Close() |
||||
|
||||
req := makeTestRequest(100) |
||||
ctx, _ := context.WithTimeout(context.Background(), 1*time.Second) |
||||
res := <-ts.rm.doRequestAsync(ctx, req) |
||||
|
||||
if res.err != nil { |
||||
t.Errorf("unexpected error: %v", res.err) |
||||
return |
||||
} |
||||
if err := req.checkResponse(res.resp); err != nil { |
||||
t.Error(err) |
||||
} |
||||
if res.stID == "" { |
||||
t.Errorf("unexpected stid") |
||||
} |
||||
} |
||||
|
||||
// The request is canceled by context
|
||||
func TestRequestManager_Request_Cancel(t *testing.T) { |
||||
delayF := makeDefaultDelayFunc(500 * time.Millisecond) |
||||
respF := makeDefaultResponseFunc() |
||||
ts := newTestSuite(delayF, respF, 3) |
||||
ts.Start() |
||||
defer ts.Close() |
||||
|
||||
req := makeTestRequest(100) |
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) |
||||
resC := ts.rm.doRequestAsync(ctx, req) |
||||
|
||||
time.Sleep(defTestSleep) |
||||
cancel() |
||||
|
||||
res := <-resC |
||||
if res.err != context.Canceled { |
||||
t.Errorf("unexpected error: %v", res.err) |
||||
} |
||||
if res.stID == "" { |
||||
t.Errorf("unexpected canceled request should also have stid") |
||||
} |
||||
} |
||||
|
||||
// error happens when adding request to waiting list
|
||||
func TestRequestManager_NewStream(t *testing.T) { |
||||
delayF := makeDefaultDelayFunc(500 * time.Millisecond) |
||||
respF := makeDefaultResponseFunc() |
||||
ts := newTestSuite(delayF, respF, 3) |
||||
ts.Start() |
||||
defer ts.Close() |
||||
|
||||
ts.sm.addNewStream(ts.makeTestStream(3)) |
||||
|
||||
time.Sleep(defTestSleep) |
||||
|
||||
ts.rm.lock.Lock() |
||||
if len(ts.rm.streams) != 4 || len(ts.rm.available) != 4 { |
||||
t.Errorf("unexpected stream size") |
||||
} |
||||
ts.rm.lock.Unlock() |
||||
} |
||||
|
||||
// For request assigned to the stream being removed, the request will be rescheduled.
|
||||
func TestRequestManager_RemoveStream(t *testing.T) { |
||||
delayF := makeOnceBlockDelayFunc(150 * time.Millisecond) |
||||
respF := makeDefaultResponseFunc() |
||||
ts := newTestSuite(delayF, respF, 3) |
||||
ts.Start() |
||||
defer ts.Close() |
||||
|
||||
req := makeTestRequest(100) |
||||
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) |
||||
resC := ts.rm.doRequestAsync(ctx, req) |
||||
time.Sleep(defTestSleep) |
||||
|
||||
// remove the stream which is responsible for the request
|
||||
idToRemove := ts.pickOneOccupiedStream() |
||||
ts.sm.rmStream(idToRemove) |
||||
|
||||
// the request is rescheduled thus there is supposed to be no errors
|
||||
res := <-resC |
||||
if res.err == nil { |
||||
t.Errorf("unexpected error: %v", errors.New("stream removed when doing request")) |
||||
} |
||||
|
||||
ts.rm.lock.Lock() |
||||
if len(ts.rm.streams) != 2 || len(ts.rm.available) != 2 { |
||||
t.Errorf("unexpected stream size") |
||||
} |
||||
ts.rm.lock.Unlock() |
||||
} |
||||
|
||||
// stream delivers an unknown request ID
|
||||
func TestRequestManager_UnknownDelivery(t *testing.T) { |
||||
delayF := makeDefaultDelayFunc(150 * time.Millisecond) |
||||
respF := func(req *testRequest) *testResponse { |
||||
var rid uint64 |
||||
for rid == req.reqID { |
||||
rid++ |
||||
} |
||||
return &testResponse{ |
||||
reqID: rid, |
||||
index: 0, |
||||
} |
||||
} |
||||
ts := newTestSuite(delayF, respF, 3) |
||||
ts.Start() |
||||
defer ts.Close() |
||||
|
||||
req := makeTestRequest(100) |
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) |
||||
resC := ts.rm.doRequestAsync(ctx, req) |
||||
time.Sleep(6 * time.Second) |
||||
cancel() |
||||
|
||||
// Since the reqID is not delivered, the result is not delivered to the request
|
||||
// and be canceled
|
||||
res := <-resC |
||||
if res.err != context.Canceled { |
||||
t.Errorf("unexpected error: %v", res.err) |
||||
} |
||||
} |
||||
|
||||
// stream delivers a response for a canceled request
|
||||
func TestRequestManager_StaleDelivery(t *testing.T) { |
||||
delayF := makeDefaultDelayFunc(1 * time.Second) |
||||
respF := makeDefaultResponseFunc() |
||||
ts := newTestSuite(delayF, respF, 3) |
||||
ts.Start() |
||||
defer ts.Close() |
||||
|
||||
req := makeTestRequest(100) |
||||
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) |
||||
resC := ts.rm.doRequestAsync(ctx, req) |
||||
time.Sleep(2 * time.Second) |
||||
|
||||
// Since the reqID is not delivered, the result is not delivered to the request
|
||||
// and be canceled
|
||||
res := <-resC |
||||
if res.err != context.DeadlineExceeded { |
||||
t.Errorf("unexpected error: %v", res.err) |
||||
} |
||||
} |
||||
|
||||
// closing request manager will also close all
|
||||
func TestRequestManager_Close(t *testing.T) { |
||||
delayF := makeDefaultDelayFunc(1 * time.Second) |
||||
respF := makeDefaultResponseFunc() |
||||
ts := newTestSuite(delayF, respF, 3) |
||||
ts.Start() |
||||
|
||||
ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) |
||||
resC := ts.rm.doRequestAsync(ctx, makeTestRequest(0)) |
||||
time.Sleep(100 * time.Millisecond) |
||||
ts.Close() |
||||
|
||||
// Since the reqID is not delivered, the result is not delivered to the request
|
||||
// and be canceled
|
||||
res := <-resC |
||||
if assErr := assertError(res.err, errors.New("request manager module closed")); assErr != nil { |
||||
t.Errorf("unexpected error: %v", assErr) |
||||
} |
||||
} |
||||
|
||||
func TestRequestManager_Request_Blacklist(t *testing.T) { |
||||
delayF := makeDefaultDelayFunc(150 * time.Millisecond) |
||||
respF := makeDefaultResponseFunc() |
||||
ts := newTestSuite(delayF, respF, 4) |
||||
ts.Start() |
||||
defer ts.Close() |
||||
|
||||
req := makeTestRequest(100) |
||||
ctx, _ := context.WithTimeout(context.Background(), 1*time.Second) |
||||
res := <-ts.rm.doRequestAsync(ctx, req, WithBlacklist([]sttypes.StreamID{ |
||||
makeStreamID(0), |
||||
makeStreamID(1), |
||||
makeStreamID(2), |
||||
})) |
||||
|
||||
if res.err != nil { |
||||
t.Errorf("unexpected error: %v", res.err) |
||||
return |
||||
} |
||||
if err := req.checkResponse(res.resp); err != nil { |
||||
t.Error(err) |
||||
} |
||||
if res.stID != makeStreamID(3) { |
||||
t.Errorf("unexpected stid") |
||||
} |
||||
} |
||||
|
||||
func TestRequestManager_Request_Whitelist(t *testing.T) { |
||||
delayF := makeDefaultDelayFunc(150 * time.Millisecond) |
||||
respF := makeDefaultResponseFunc() |
||||
ts := newTestSuite(delayF, respF, 4) |
||||
ts.Start() |
||||
defer ts.Close() |
||||
|
||||
req := makeTestRequest(100) |
||||
ctx, _ := context.WithTimeout(context.Background(), 1*time.Second) |
||||
res := <-ts.rm.doRequestAsync(ctx, req, WithWhitelist([]sttypes.StreamID{ |
||||
makeStreamID(3), |
||||
})) |
||||
|
||||
if res.err != nil { |
||||
t.Errorf("unexpected error: %v", res.err) |
||||
return |
||||
} |
||||
if err := req.checkResponse(res.resp); err != nil { |
||||
t.Error(err) |
||||
} |
||||
if res.stID != makeStreamID(3) { |
||||
t.Errorf("unexpected stid") |
||||
} |
||||
} |
||||
|
||||
// test the race condition by spinning up a lot of goroutines
|
||||
func TestRequestManager_Concurrency(t *testing.T) { |
||||
var ( |
||||
testDuration = 10 * time.Second |
||||
numThreads = 25 |
||||
) |
||||
delayF := makeDefaultDelayFunc(100 * time.Millisecond) |
||||
respF := makeDefaultResponseFunc() |
||||
ts := newTestSuite(delayF, respF, 18) |
||||
ts.Start() |
||||
|
||||
stopC := make(chan struct{}) |
||||
var ( |
||||
aErr atomic.Value |
||||
numReqs uint64 |
||||
wg sync.WaitGroup |
||||
) |
||||
wg.Add(numThreads) |
||||
for i := 0; i != numThreads; i++ { |
||||
go func() { |
||||
defer wg.Done() |
||||
for { |
||||
resC := ts.rm.doRequestAsync(context.Background(), makeTestRequest(1000)) |
||||
select { |
||||
case res := <-resC: |
||||
if res.err == nil { |
||||
atomic.AddUint64(&numReqs, 1) |
||||
continue |
||||
} |
||||
if res.err.Error() == "request manager module closed" { |
||||
return |
||||
} |
||||
aErr.Store(res.err.Error()) |
||||
case <-stopC: |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
} |
||||
time.Sleep(testDuration) |
||||
close(stopC) |
||||
ts.Close() |
||||
wg.Wait() |
||||
|
||||
if isNilErr := aErr.Load() == nil; !isNilErr { |
||||
err := aErr.Load().(error) |
||||
t.Errorf("unexpected error: %v", err) |
||||
} |
||||
num := atomic.LoadUint64(&numReqs) |
||||
t.Logf("Mock processed requests: %v", num) |
||||
} |
||||
|
||||
func TestGenReqID(t *testing.T) { |
||||
retry := 100000 |
||||
rm := &requestManager{ |
||||
pendings: make(map[uint64]*request), |
||||
} |
||||
|
||||
for i := 0; i != retry; i++ { |
||||
rid := rm.genReqID() |
||||
if _, ok := rm.pendings[rid]; ok { |
||||
t.Errorf("rid collision") |
||||
} |
||||
rm.pendings[rid] = nil |
||||
} |
||||
} |
||||
|
||||
type testSuite struct { |
||||
rm *requestManager |
||||
sm *testStreamManager |
||||
bootStreams []*testStream |
||||
|
||||
delayFunc delayFunc |
||||
respFunc responseFunc |
||||
|
||||
ctx context.Context |
||||
cancel func() |
||||
} |
||||
|
||||
func newTestSuite(delayF delayFunc, respF responseFunc, numStreams int) *testSuite { |
||||
sm := newTestStreamManager() |
||||
rm := newRequestManager(sm) |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
|
||||
ts := &testSuite{ |
||||
rm: rm, |
||||
sm: sm, |
||||
bootStreams: make([]*testStream, 0, numStreams), |
||||
delayFunc: delayF, |
||||
respFunc: respF, |
||||
ctx: ctx, |
||||
cancel: cancel, |
||||
} |
||||
for i := 0; i != numStreams; i++ { |
||||
ts.bootStreams = append(ts.bootStreams, ts.makeTestStream(i)) |
||||
} |
||||
return ts |
||||
} |
||||
|
||||
func (ts *testSuite) Start() { |
||||
ts.rm.Start() |
||||
for _, st := range ts.bootStreams { |
||||
ts.sm.addNewStream(st) |
||||
} |
||||
} |
||||
|
||||
func (ts *testSuite) Close() { |
||||
ts.rm.Close() |
||||
ts.cancel() |
||||
} |
||||
|
||||
func (ts *testSuite) pickOneOccupiedStream() sttypes.StreamID { |
||||
ts.rm.lock.Lock() |
||||
defer ts.rm.lock.Unlock() |
||||
|
||||
for _, req := range ts.rm.pendings { |
||||
return req.owner.ID() |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
type ( |
||||
// responseFunc is the function to compose a response
|
||||
responseFunc func(request *testRequest) *testResponse |
||||
|
||||
// delayFunc is the function to determine the delay to deliver a response
|
||||
delayFunc func() time.Duration |
||||
) |
||||
|
||||
func makeDefaultResponseFunc() responseFunc { |
||||
return func(request *testRequest) *testResponse { |
||||
resp := &testResponse{ |
||||
reqID: request.reqID, |
||||
index: request.index, |
||||
} |
||||
return resp |
||||
} |
||||
} |
||||
|
||||
func checkResponseMessage(request sttypes.Request, response sttypes.Response) error { |
||||
tReq, ok := request.(*testRequest) |
||||
if !ok || tReq == nil { |
||||
return errors.New("request not testRequest") |
||||
} |
||||
tResp, ok := response.(*testResponse) |
||||
if !ok || tResp == nil { |
||||
return errors.New("response not testResponse") |
||||
} |
||||
return tReq.checkResponse(tResp) |
||||
} |
||||
|
||||
func makeDefaultDelayFunc(delay time.Duration) delayFunc { |
||||
return func() time.Duration { |
||||
return delay |
||||
} |
||||
} |
||||
|
||||
func makeOnceBlockDelayFunc(normalDelay time.Duration) delayFunc { |
||||
// This usage of once is nasty. Avoid using once like this in production code.
|
||||
var once sync.Once |
||||
return func() time.Duration { |
||||
var block bool |
||||
once.Do(func() { |
||||
block = true |
||||
}) |
||||
if block { |
||||
return time.Hour |
||||
} |
||||
return normalDelay |
||||
} |
||||
} |
||||
|
||||
func (ts *testSuite) makeTestStream(index int) *testStream { |
||||
stid := makeStreamID(index) |
||||
return &testStream{ |
||||
id: stid, |
||||
rm: ts.rm, |
||||
deliver: func(req *testRequest) { |
||||
delay := ts.delayFunc() |
||||
resp := ts.respFunc(req) |
||||
go func() { |
||||
select { |
||||
case <-ts.ctx.Done(): |
||||
case <-time.After(delay): |
||||
ts.rm.DeliverResponse(stid, resp) |
||||
} |
||||
}() |
||||
}, |
||||
} |
||||
} |
@ -0,0 +1,182 @@ |
||||
package requestmanager |
||||
|
||||
import ( |
||||
"container/list" |
||||
"sync" |
||||
"sync/atomic" |
||||
|
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
var ( |
||||
// ErrQueueFull is the error happens when the waiting queue is already full
|
||||
ErrQueueFull = errors.New("waiting request queue already full") |
||||
|
||||
// ErrClosed is request error that the module is closed during request
|
||||
ErrClosed = errors.New("request manager module closed") |
||||
) |
||||
|
||||
// stream is the wrapped version of sttypes.Stream.
|
||||
// TODO: enable stream handle multiple pending requests at the same time
|
||||
type stream struct { |
||||
sttypes.Stream |
||||
req *request // currently one stream is dealing with one request
|
||||
} |
||||
|
||||
// request is the wrapped request within module
|
||||
type request struct { |
||||
sttypes.Request // underlying request
|
||||
// result field
|
||||
respC chan responseData // channel to receive response from delivered message
|
||||
// concurrency control
|
||||
atmDone uint32 |
||||
doneC chan struct{} |
||||
// stream info
|
||||
owner *stream // Current owner
|
||||
// utils
|
||||
lock sync.RWMutex |
||||
raw *interface{} |
||||
// options
|
||||
priority reqPriority |
||||
blacklist map[sttypes.StreamID]struct{} // banned streams
|
||||
whitelist map[sttypes.StreamID]struct{} // allowed streams
|
||||
} |
||||
|
||||
func (req *request) ReqID() uint64 { |
||||
req.lock.RLock() |
||||
defer req.lock.RUnlock() |
||||
|
||||
return req.Request.ReqID() |
||||
} |
||||
|
||||
func (req *request) SetReqID(val uint64) { |
||||
req.lock.Lock() |
||||
defer req.lock.Unlock() |
||||
|
||||
req.Request.SetReqID(val) |
||||
} |
||||
|
||||
func (req *request) doneWithResponse(resp responseData) { |
||||
notDone := atomic.CompareAndSwapUint32(&req.atmDone, 0, 1) |
||||
if notDone { |
||||
req.respC <- resp |
||||
close(req.respC) |
||||
close(req.doneC) |
||||
} |
||||
} |
||||
|
||||
func (req *request) isDone() bool { |
||||
return atomic.LoadUint32(&req.atmDone) == 1 |
||||
} |
||||
|
||||
func (req *request) isStreamAllowed(stid sttypes.StreamID) bool { |
||||
return req.isStreamWhitelisted(stid) && !req.isStreamBlacklisted(stid) |
||||
} |
||||
|
||||
func (req *request) addBlacklistedStream(stid sttypes.StreamID) { |
||||
if req.blacklist == nil { |
||||
req.blacklist = make(map[sttypes.StreamID]struct{}) |
||||
} |
||||
req.blacklist[stid] = struct{}{} |
||||
} |
||||
|
||||
func (req *request) isStreamBlacklisted(stid sttypes.StreamID) bool { |
||||
if req.blacklist == nil { |
||||
return false |
||||
} |
||||
_, ok := req.blacklist[stid] |
||||
return ok |
||||
} |
||||
|
||||
func (req *request) addWhiteListStream(stid sttypes.StreamID) { |
||||
if req.whitelist == nil { |
||||
req.whitelist = make(map[sttypes.StreamID]struct{}) |
||||
} |
||||
req.whitelist[stid] = struct{}{} |
||||
} |
||||
|
||||
func (req *request) isStreamWhitelisted(stid sttypes.StreamID) bool { |
||||
if req.whitelist == nil { |
||||
return true |
||||
} |
||||
_, ok := req.whitelist[stid] |
||||
return ok |
||||
} |
||||
|
||||
func (st *stream) clearPendingRequest() *request { |
||||
req := st.req |
||||
if req == nil { |
||||
return nil |
||||
} |
||||
st.req = nil |
||||
return req |
||||
} |
||||
|
||||
type cancelReqData struct { |
||||
reqID uint64 |
||||
err error |
||||
} |
||||
|
||||
// responseData is the wrapped response for stream requests
|
||||
type responseData struct { |
||||
resp sttypes.Response |
||||
stID sttypes.StreamID |
||||
err error |
||||
} |
||||
|
||||
// requestQueue is a wrapper of double linked list with Request as type
|
||||
type requestQueue struct { |
||||
reqsPHigh *list.List // high priority, currently defined by upper function calls
|
||||
reqsPLow *list.List // low priority, applied to all normal requests
|
||||
lock sync.Mutex |
||||
} |
||||
|
||||
func newRequestQueue() requestQueue { |
||||
return requestQueue{ |
||||
reqsPHigh: list.New(), |
||||
reqsPLow: list.New(), |
||||
} |
||||
} |
||||
|
||||
// Push add a new request to requestQueue.
|
||||
func (q *requestQueue) Push(req *request, priority reqPriority) error { |
||||
q.lock.Lock() |
||||
defer q.lock.Unlock() |
||||
|
||||
if priority == reqPriorityHigh || req.priority == reqPriorityHigh { |
||||
return pushRequestToList(q.reqsPHigh, req) |
||||
} |
||||
if priority == reqPriorityLow { |
||||
return pushRequestToList(q.reqsPLow, req) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Pop will first pop the request from high priority, and then pop from low priority
|
||||
func (q *requestQueue) Pop() *request { |
||||
q.lock.Lock() |
||||
defer q.lock.Unlock() |
||||
|
||||
if req := popRequestFromList(q.reqsPHigh); req != nil { |
||||
return req |
||||
} |
||||
return popRequestFromList(q.reqsPLow) |
||||
} |
||||
|
||||
func pushRequestToList(l *list.List, req *request) error { |
||||
if l.Len() >= maxWaitingSize { |
||||
return ErrQueueFull |
||||
} |
||||
l.PushBack(req) |
||||
return nil |
||||
} |
||||
|
||||
func popRequestFromList(l *list.List) *request { |
||||
elem := l.Front() |
||||
if elem == nil { |
||||
return nil |
||||
} |
||||
l.Remove(elem) |
||||
return elem.Value.(*request) |
||||
} |
@ -0,0 +1,165 @@ |
||||
package requestmanager |
||||
|
||||
import ( |
||||
"container/list" |
||||
"fmt" |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
func TestRequestQueue_Push(t *testing.T) { |
||||
tests := []struct { |
||||
initSize []int |
||||
priority reqPriority |
||||
expSize []int |
||||
expErr error |
||||
}{ |
||||
{ |
||||
initSize: []int{10, 10}, |
||||
priority: reqPriorityHigh, |
||||
expSize: []int{11, 10}, |
||||
expErr: nil, |
||||
}, |
||||
{ |
||||
initSize: []int{10, 10}, |
||||
priority: reqPriorityLow, |
||||
expSize: []int{10, 11}, |
||||
expErr: nil, |
||||
}, |
||||
{ |
||||
initSize: []int{maxWaitingSize, maxWaitingSize}, |
||||
priority: reqPriorityLow, |
||||
expErr: ErrQueueFull, |
||||
}, |
||||
{ |
||||
initSize: []int{maxWaitingSize, maxWaitingSize}, |
||||
priority: reqPriorityHigh, |
||||
expErr: ErrQueueFull, |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
q := makeTestRequestQueue(test.initSize) |
||||
req := wrapRequestFromRaw(makeTestRequest(100)) |
||||
|
||||
err := q.Push(req, test.priority) |
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
} |
||||
if err != nil || test.expErr != nil { |
||||
continue |
||||
} |
||||
|
||||
if err := q.checkSizes(test.expSize); err != nil { |
||||
t.Errorf("Test %v: %v", i, err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestRequestQueue_Pop(t *testing.T) { |
||||
tests := []struct { |
||||
initSizes []int |
||||
expNil bool |
||||
expIndex uint64 |
||||
expSizes []int |
||||
}{ |
||||
{ |
||||
initSizes: []int{10, 10}, |
||||
expNil: false, |
||||
expIndex: 0, |
||||
expSizes: []int{9, 10}, |
||||
}, |
||||
{ |
||||
initSizes: []int{0, 10}, |
||||
expNil: false, |
||||
expIndex: 0, |
||||
expSizes: []int{0, 9}, |
||||
}, |
||||
{ |
||||
initSizes: []int{0, 0}, |
||||
expNil: true, |
||||
expSizes: []int{0, 0}, |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
q := makeTestRequestQueue(test.initSizes) |
||||
req := q.Pop() |
||||
|
||||
if err := q.checkSizes(test.expSizes); err != nil { |
||||
t.Errorf("Test %v: %v", i, err) |
||||
} |
||||
if req == nil != (test.expNil) { |
||||
t.Errorf("test %v: unpected nil", i) |
||||
} |
||||
if req == nil { |
||||
continue |
||||
} |
||||
index := req.Request.(*testRequest).index |
||||
if index != test.expIndex { |
||||
t.Errorf("Test %v: unexpected index: %v / %v", i, index, test.expIndex) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func makeTestRequestQueue(sizes []int) requestQueue { |
||||
if len(sizes) != 2 { |
||||
panic("unexpected sizes") |
||||
} |
||||
q := newRequestQueue() |
||||
|
||||
index := 0 |
||||
for i := 0; i != sizes[0]; i++ { |
||||
q.reqsPHigh.PushBack(wrapRequestFromRaw(makeTestRequest(uint64(index)))) |
||||
index++ |
||||
} |
||||
for i := 0; i != sizes[1]; i++ { |
||||
q.reqsPLow.PushBack(wrapRequestFromRaw(makeTestRequest(uint64(index)))) |
||||
index++ |
||||
} |
||||
return q |
||||
} |
||||
|
||||
func wrapRequestFromRaw(raw *testRequest) *request { |
||||
return &request{ |
||||
Request: raw, |
||||
} |
||||
} |
||||
|
||||
func getTestRequestFromElem(elem *list.Element) (*testRequest, error) { |
||||
req, ok := elem.Value.(*request) |
||||
if !ok { |
||||
return nil, errors.New("unexpected type") |
||||
} |
||||
raw, ok := req.Request.(*testRequest) |
||||
if !ok { |
||||
return nil, errors.New("unexpected raw types") |
||||
} |
||||
return raw, nil |
||||
} |
||||
|
||||
func (q *requestQueue) checkSizes(sizes []int) error { |
||||
if len(sizes) != 2 { |
||||
panic("expect 2 sizes") |
||||
} |
||||
if q.reqsPHigh.Len() != sizes[0] { |
||||
return fmt.Errorf("high priority %v / %v", q.reqsPHigh.Len(), sizes[0]) |
||||
} |
||||
if q.reqsPLow.Len() != sizes[1] { |
||||
return fmt.Errorf("low priority %v / %v", q.reqsPLow.Len(), sizes[2]) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func assertError(got, exp error) error { |
||||
if (got == nil) != (exp == nil) { |
||||
return fmt.Errorf("unexpected error: %v / %v", got, exp) |
||||
} |
||||
if got == nil { |
||||
return nil |
||||
} |
||||
if !strings.Contains(got.Error(), exp.Error()) { |
||||
return fmt.Errorf("unexpected error: %v / %v", got, exp) |
||||
} |
||||
return nil |
||||
} |
Loading…
Reference in new issue