Merge pull request #3560 from JackyWYX/requestmanager
[Stream] Added module requestmanagerpull/3578/head
commit
d97cf9b6f4
@ -0,0 +1,19 @@ |
|||||||
|
package requestmanager |
||||||
|
|
||||||
|
import "time" |
||||||
|
|
||||||
|
// TODO: determine the values in production environment
|
||||||
|
const ( |
||||||
|
// throttle to do request every 100 milliseconds
|
||||||
|
throttleInterval = 100 * time.Millisecond |
||||||
|
|
||||||
|
// number of request to be done in each throttle loop
|
||||||
|
throttleBatch = 16 |
||||||
|
|
||||||
|
// deliverTimeout is the timeout for a response delivery. If the response cannot be delivered
|
||||||
|
// within timeout because blocking of the channel, the response will be dropped.
|
||||||
|
deliverTimeout = 5 * time.Second |
||||||
|
|
||||||
|
// maxWaitingSize is the maximum requests that are in waiting list
|
||||||
|
maxWaitingSize = 1024 |
||||||
|
) |
@ -0,0 +1,25 @@ |
|||||||
|
package requestmanager |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
|
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
p2ptypes "github.com/harmony-one/harmony/p2p/types" |
||||||
|
) |
||||||
|
|
||||||
|
// Requester is the interface to do request
|
||||||
|
type Requester interface { |
||||||
|
DoRequest(ctx context.Context, request sttypes.Request, options ...RequestOption) (sttypes.Response, sttypes.StreamID, error) |
||||||
|
} |
||||||
|
|
||||||
|
// Deliverer is the interface to deliver a response
|
||||||
|
type Deliverer interface { |
||||||
|
DeliverResponse(stID sttypes.StreamID, resp sttypes.Response) |
||||||
|
} |
||||||
|
|
||||||
|
// RequestManager manages over the requests
|
||||||
|
type RequestManager interface { |
||||||
|
p2ptypes.LifeCycle |
||||||
|
Requester |
||||||
|
Deliverer |
||||||
|
} |
@ -0,0 +1,171 @@ |
|||||||
|
package requestmanager |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"strconv" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/event" |
||||||
|
"github.com/ethereum/go-ethereum/rlp" |
||||||
|
"github.com/harmony-one/harmony/p2p/stream/common/streammanager" |
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
) |
||||||
|
|
||||||
|
var testProtoID = sttypes.ProtoID("harmony/sync/unitest/0/1.0.0") |
||||||
|
|
||||||
|
type testStreamManager struct { |
||||||
|
newStreamFeed event.Feed |
||||||
|
rmStreamFeed event.Feed |
||||||
|
} |
||||||
|
|
||||||
|
func newTestStreamManager() *testStreamManager { |
||||||
|
return &testStreamManager{} |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *testStreamManager) addNewStream(st sttypes.Stream) { |
||||||
|
sm.newStreamFeed.Send(streammanager.EvtStreamAdded{Stream: st}) |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *testStreamManager) rmStream(stid sttypes.StreamID) { |
||||||
|
sm.rmStreamFeed.Send(streammanager.EvtStreamRemoved{ID: stid}) |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *testStreamManager) SubscribeAddStreamEvent(ch chan<- streammanager.EvtStreamAdded) event.Subscription { |
||||||
|
return sm.newStreamFeed.Subscribe(ch) |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *testStreamManager) SubscribeRemoveStreamEvent(ch chan<- streammanager.EvtStreamRemoved) event.Subscription { |
||||||
|
return sm.rmStreamFeed.Subscribe(ch) |
||||||
|
} |
||||||
|
|
||||||
|
type testStream struct { |
||||||
|
id sttypes.StreamID |
||||||
|
rm *requestManager |
||||||
|
deliver func(*testRequest) // use goroutine inside this function
|
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) ID() sttypes.StreamID { |
||||||
|
return st.id |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) ProtoID() sttypes.ProtoID { |
||||||
|
return testProtoID |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) WriteBytes(b []byte) error { |
||||||
|
req, err := decodeTestRequest(b) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if st.rm != nil && st.deliver != nil { |
||||||
|
st.deliver(req) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) ReadBytes() ([]byte, error) { |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) ProtoSpec() (sttypes.ProtoSpec, error) { |
||||||
|
return sttypes.ProtoIDToProtoSpec(testProtoID) |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) Close() error { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) ResetOnClose() error { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func makeStreamID(index int) sttypes.StreamID { |
||||||
|
return sttypes.StreamID(strconv.Itoa(index)) |
||||||
|
} |
||||||
|
|
||||||
|
type testRequest struct { |
||||||
|
reqID uint64 |
||||||
|
index uint64 |
||||||
|
} |
||||||
|
|
||||||
|
func makeTestRequest(index uint64) *testRequest { |
||||||
|
return &testRequest{ |
||||||
|
reqID: 0, |
||||||
|
index: index, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (req *testRequest) ReqID() uint64 { |
||||||
|
return req.reqID |
||||||
|
} |
||||||
|
|
||||||
|
func (req *testRequest) SetReqID(rid uint64) { |
||||||
|
req.reqID = rid |
||||||
|
} |
||||||
|
|
||||||
|
func (req *testRequest) String() string { |
||||||
|
return fmt.Sprintf("test request %v", req.index) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *testRequest) Encode() ([]byte, error) { |
||||||
|
return rlp.EncodeToBytes(struct { |
||||||
|
ReqID uint64 |
||||||
|
Index uint64 |
||||||
|
}{ |
||||||
|
ReqID: req.reqID, |
||||||
|
Index: req.index, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *testRequest) checkResponse(rawResp sttypes.Response) error { |
||||||
|
resp, ok := rawResp.(*testResponse) |
||||||
|
if !ok || resp == nil { |
||||||
|
return errors.New("not test Response") |
||||||
|
} |
||||||
|
if req.reqID != resp.reqID { |
||||||
|
return errors.New("request id not expected") |
||||||
|
} |
||||||
|
if req.index != resp.index { |
||||||
|
return errors.New("response id not expected") |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func decodeTestRequest(b []byte) (*testRequest, error) { |
||||||
|
type SerRequest struct { |
||||||
|
ReqID uint64 |
||||||
|
Index uint64 |
||||||
|
} |
||||||
|
var sr SerRequest |
||||||
|
if err := rlp.DecodeBytes(b, &sr); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return &testRequest{ |
||||||
|
reqID: sr.ReqID, |
||||||
|
index: sr.Index, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (req *testRequest) IsSupportedByProto(spec sttypes.ProtoSpec) bool { |
||||||
|
return true |
||||||
|
} |
||||||
|
|
||||||
|
func (req *testRequest) getResponse() *testResponse { |
||||||
|
return &testResponse{ |
||||||
|
reqID: req.reqID, |
||||||
|
index: req.index, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type testResponse struct { |
||||||
|
reqID uint64 |
||||||
|
index uint64 |
||||||
|
} |
||||||
|
|
||||||
|
func (tr *testResponse) ReqID() uint64 { |
||||||
|
return tr.reqID |
||||||
|
} |
||||||
|
|
||||||
|
func (tr *testResponse) String() string { |
||||||
|
return fmt.Sprintf("test response %v", tr.index) |
||||||
|
} |
@ -0,0 +1,39 @@ |
|||||||
|
package requestmanager |
||||||
|
|
||||||
|
import sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
|
||||||
|
// RequestOption is the additional instruction for requests.
|
||||||
|
// Currently, two options are supported:
|
||||||
|
// 1. WithHighPriority
|
||||||
|
// 2. WithBlacklist
|
||||||
|
// 3. WithWhitelist
|
||||||
|
type RequestOption func(*request) |
||||||
|
|
||||||
|
// WithHighPriority is the request option to do request with higher priority.
|
||||||
|
// High priority requests are done first.
|
||||||
|
func WithHighPriority() RequestOption { |
||||||
|
return func(req *request) { |
||||||
|
req.priority = reqPriorityHigh |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// WithBlacklist is the request option not to assign the request to the blacklisted
|
||||||
|
// stream ID.
|
||||||
|
func WithBlacklist(blacklist []sttypes.StreamID) RequestOption { |
||||||
|
return func(req *request) { |
||||||
|
for _, stid := range blacklist { |
||||||
|
req.addBlacklistedStream(stid) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// WithWhitelist is the request option to restrict the request to be assigned to the
|
||||||
|
// given stream IDs.
|
||||||
|
// If a request is not with this option, all streams will be allowed.
|
||||||
|
func WithWhitelist(whitelist []sttypes.StreamID) RequestOption { |
||||||
|
return func(req *request) { |
||||||
|
for _, stid := range whitelist { |
||||||
|
req.addWhiteListStream(stid) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,410 @@ |
|||||||
|
package requestmanager |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
"github.com/rs/zerolog" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/event" |
||||||
|
"github.com/harmony-one/harmony/internal/utils" |
||||||
|
"github.com/harmony-one/harmony/p2p/stream/common/streammanager" |
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
) |
||||||
|
|
||||||
|
// requestManager implements RequestManager. It is responsible for matching response
|
||||||
|
// with requests.
|
||||||
|
// TODO: each peer is able to have a queue of requests instead of one request at a time.
|
||||||
|
// TODO: add QoS evaluation for each stream
|
||||||
|
type requestManager struct { |
||||||
|
streams map[sttypes.StreamID]*stream // All streams
|
||||||
|
available map[sttypes.StreamID]struct{} // Streams that are available for request
|
||||||
|
pendings map[uint64]*request // requests that are sent but not received response
|
||||||
|
waitings requestQueue // double linked list of requests that are on the waiting list
|
||||||
|
|
||||||
|
// Stream events
|
||||||
|
newStreamC <-chan streammanager.EvtStreamAdded |
||||||
|
rmStreamC <-chan streammanager.EvtStreamRemoved |
||||||
|
// Request events
|
||||||
|
cancelReqC chan cancelReqData // request being canceled
|
||||||
|
deliveryC chan responseData |
||||||
|
newRequestC chan *request |
||||||
|
|
||||||
|
subs []event.Subscription |
||||||
|
logger zerolog.Logger |
||||||
|
stopC chan struct{} |
||||||
|
lock sync.Mutex |
||||||
|
} |
||||||
|
|
||||||
|
// NewRequestManager creates a new request manager
|
||||||
|
func NewRequestManager(sm streammanager.Subscriber) RequestManager { |
||||||
|
return newRequestManager(sm) |
||||||
|
} |
||||||
|
|
||||||
|
func newRequestManager(sm streammanager.Subscriber) *requestManager { |
||||||
|
// subscribe at initialize to prevent misuse of upper function which might cause
|
||||||
|
// the bootstrap peers are ignored
|
||||||
|
newStreamC := make(chan streammanager.EvtStreamAdded) |
||||||
|
rmStreamC := make(chan streammanager.EvtStreamRemoved) |
||||||
|
sub1 := sm.SubscribeAddStreamEvent(newStreamC) |
||||||
|
sub2 := sm.SubscribeRemoveStreamEvent(rmStreamC) |
||||||
|
|
||||||
|
logger := utils.Logger().With().Str("module", "request manager").Logger() |
||||||
|
|
||||||
|
return &requestManager{ |
||||||
|
streams: make(map[sttypes.StreamID]*stream), |
||||||
|
available: make(map[sttypes.StreamID]struct{}), |
||||||
|
pendings: make(map[uint64]*request), |
||||||
|
waitings: newRequestQueue(), |
||||||
|
|
||||||
|
newStreamC: newStreamC, |
||||||
|
rmStreamC: rmStreamC, |
||||||
|
cancelReqC: make(chan cancelReqData, 16), |
||||||
|
deliveryC: make(chan responseData, 128), |
||||||
|
newRequestC: make(chan *request, 128), |
||||||
|
|
||||||
|
subs: []event.Subscription{sub1, sub2}, |
||||||
|
logger: logger, |
||||||
|
stopC: make(chan struct{}), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) Start() { |
||||||
|
go rm.loop() |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) Close() { |
||||||
|
rm.stopC <- struct{}{} |
||||||
|
} |
||||||
|
|
||||||
|
// DoRequest do the given request with a stream picked randomly. Return the response, stream id that
|
||||||
|
// is responsible for response, delivery and error.
|
||||||
|
func (rm *requestManager) DoRequest(ctx context.Context, raw sttypes.Request, options ...RequestOption) (sttypes.Response, sttypes.StreamID, error) { |
||||||
|
resp := <-rm.doRequestAsync(ctx, raw, options...) |
||||||
|
return resp.resp, resp.stID, resp.err |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) doRequestAsync(ctx context.Context, raw sttypes.Request, options ...RequestOption) <-chan responseData { |
||||||
|
req := &request{ |
||||||
|
Request: raw, |
||||||
|
respC: make(chan responseData), |
||||||
|
doneC: make(chan struct{}), |
||||||
|
} |
||||||
|
for _, opt := range options { |
||||||
|
opt(req) |
||||||
|
} |
||||||
|
rm.newRequestC <- req |
||||||
|
|
||||||
|
go func() { |
||||||
|
select { |
||||||
|
case <-ctx.Done(): // canceled or timeout in upper function calls
|
||||||
|
rm.cancelReqC <- cancelReqData{ |
||||||
|
reqID: req.ReqID(), |
||||||
|
err: ctx.Err(), |
||||||
|
} |
||||||
|
case <-req.doneC: |
||||||
|
} |
||||||
|
}() |
||||||
|
return req.respC |
||||||
|
} |
||||||
|
|
||||||
|
// DeliverResponse delivers the response to the corresponding request.
|
||||||
|
// The function behaves non-block
|
||||||
|
func (rm *requestManager) DeliverResponse(stID sttypes.StreamID, resp sttypes.Response) { |
||||||
|
sd := responseData{ |
||||||
|
resp: resp, |
||||||
|
stID: stID, |
||||||
|
} |
||||||
|
go func() { |
||||||
|
select { |
||||||
|
case rm.deliveryC <- sd: |
||||||
|
case <-time.After(deliverTimeout): |
||||||
|
rm.logger.Error().Msg("WARNING: delivery timeout. Possible stuck in loop") |
||||||
|
} |
||||||
|
}() |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) loop() { |
||||||
|
var ( |
||||||
|
throttleC = make(chan struct{}, 1) // throttle the waiting requests periodically
|
||||||
|
ticker = time.NewTicker(throttleInterval) |
||||||
|
) |
||||||
|
throttle := func() { |
||||||
|
select { |
||||||
|
case throttleC <- struct{}{}: |
||||||
|
default: |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
for { |
||||||
|
select { |
||||||
|
case <-ticker.C: |
||||||
|
throttle() |
||||||
|
|
||||||
|
case <-throttleC: |
||||||
|
loop: |
||||||
|
for i := 0; i != throttleBatch; i++ { |
||||||
|
req, st := rm.getNextRequest() |
||||||
|
if req == nil { |
||||||
|
break loop |
||||||
|
} |
||||||
|
rm.addPendingRequest(req, st) |
||||||
|
b, err := req.Encode() |
||||||
|
if err != nil { |
||||||
|
rm.logger.Warn().Str("request", req.String()).Err(err). |
||||||
|
Msg("request encode error") |
||||||
|
} |
||||||
|
|
||||||
|
go func() { |
||||||
|
if err := st.WriteBytes(b); err != nil { |
||||||
|
rm.logger.Warn().Str("streamID", string(st.ID())).Err(err). |
||||||
|
Msg("write bytes") |
||||||
|
req.doneWithResponse(responseData{ |
||||||
|
stID: st.ID(), |
||||||
|
err: errors.Wrap(err, "write bytes"), |
||||||
|
}) |
||||||
|
} |
||||||
|
}() |
||||||
|
} |
||||||
|
|
||||||
|
case req := <-rm.newRequestC: |
||||||
|
added := rm.handleNewRequest(req) |
||||||
|
if added { |
||||||
|
throttle() |
||||||
|
} |
||||||
|
|
||||||
|
case data := <-rm.deliveryC: |
||||||
|
rm.handleDeliverData(data) |
||||||
|
|
||||||
|
case data := <-rm.cancelReqC: |
||||||
|
rm.handleCancelRequest(data) |
||||||
|
|
||||||
|
case evt := <-rm.newStreamC: |
||||||
|
rm.logger.Info().Str("streamID", string(evt.Stream.ID())).Msg("add new stream") |
||||||
|
rm.addNewStream(evt.Stream) |
||||||
|
|
||||||
|
case evt := <-rm.rmStreamC: |
||||||
|
rm.logger.Info().Str("streamID", string(evt.ID)).Msg("remove stream") |
||||||
|
rm.removeStream(evt.ID) |
||||||
|
|
||||||
|
case <-rm.stopC: |
||||||
|
rm.logger.Info().Msg("request manager stopped") |
||||||
|
rm.close() |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) handleNewRequest(req *request) bool { |
||||||
|
rm.lock.Lock() |
||||||
|
defer rm.lock.Unlock() |
||||||
|
|
||||||
|
err := rm.addRequestToWaitings(req, reqPriorityLow) |
||||||
|
if err != nil { |
||||||
|
rm.logger.Warn().Err(err).Msg("failed to add new request to waitings") |
||||||
|
req.doneWithResponse(responseData{ |
||||||
|
err: errors.Wrap(err, "failed to add new request to waitings"), |
||||||
|
}) |
||||||
|
return false |
||||||
|
} |
||||||
|
return true |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) handleDeliverData(data responseData) { |
||||||
|
rm.lock.Lock() |
||||||
|
defer rm.lock.Unlock() |
||||||
|
|
||||||
|
if err := rm.validateDelivery(data); err != nil { |
||||||
|
// if error happens in delivery, most likely it's a stale delivery. No action needed
|
||||||
|
// and return
|
||||||
|
rm.logger.Info().Err(err).Str("response", data.resp.String()).Msg("unable to validate deliver") |
||||||
|
return |
||||||
|
} |
||||||
|
// req and st is ensured not to be empty in validateDelivery
|
||||||
|
req := rm.pendings[data.resp.ReqID()] |
||||||
|
req.doneWithResponse(data) |
||||||
|
rm.removePendingRequest(req) |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) validateDelivery(data responseData) error { |
||||||
|
if data.err != nil { |
||||||
|
return data.err |
||||||
|
} |
||||||
|
st := rm.streams[data.stID] |
||||||
|
if st == nil { |
||||||
|
return fmt.Errorf("data delivered from dead stream: %v", data.stID) |
||||||
|
} |
||||||
|
req := rm.pendings[data.resp.ReqID()] |
||||||
|
if req == nil { |
||||||
|
return fmt.Errorf("stale p2p response delivery") |
||||||
|
} |
||||||
|
if req.owner == nil || req.owner.ID() != data.stID { |
||||||
|
return fmt.Errorf("unexpected delivery stream") |
||||||
|
} |
||||||
|
if st.req == nil || st.req.ReqID() != data.resp.ReqID() { |
||||||
|
// Possible when request is canceled
|
||||||
|
return fmt.Errorf("unexpected deliver request") |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) handleCancelRequest(data cancelReqData) { |
||||||
|
rm.lock.Lock() |
||||||
|
defer rm.lock.Unlock() |
||||||
|
|
||||||
|
req, ok := rm.pendings[data.reqID] |
||||||
|
if !ok { |
||||||
|
return |
||||||
|
} |
||||||
|
rm.removePendingRequest(req) |
||||||
|
|
||||||
|
var stid sttypes.StreamID |
||||||
|
if req.owner != nil { |
||||||
|
stid = req.owner.ID() |
||||||
|
} |
||||||
|
|
||||||
|
req.doneWithResponse(responseData{ |
||||||
|
resp: nil, |
||||||
|
stID: stid, |
||||||
|
err: data.err, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) getNextRequest() (*request, *stream) { |
||||||
|
rm.lock.Lock() |
||||||
|
defer rm.lock.Unlock() |
||||||
|
|
||||||
|
var req *request |
||||||
|
for { |
||||||
|
req = rm.waitings.Pop() |
||||||
|
if req == nil { |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
if !req.isDone() { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
st, err := rm.pickAvailableStream(req) |
||||||
|
if err != nil { |
||||||
|
rm.logger.Debug().Msg("No available streams.") |
||||||
|
rm.addRequestToWaitings(req, reqPriorityHigh) |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
return req, st |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) genReqID() uint64 { |
||||||
|
for { |
||||||
|
rid := sttypes.GenReqID() |
||||||
|
if _, ok := rm.pendings[rid]; !ok { |
||||||
|
return rid |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) addPendingRequest(req *request, st *stream) { |
||||||
|
rm.lock.Lock() |
||||||
|
defer rm.lock.Unlock() |
||||||
|
|
||||||
|
reqID := rm.genReqID() |
||||||
|
req.SetReqID(reqID) |
||||||
|
|
||||||
|
req.owner = st |
||||||
|
st.req = req |
||||||
|
|
||||||
|
delete(rm.available, st.ID()) |
||||||
|
rm.pendings[req.ReqID()] = req |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) removePendingRequest(req *request) { |
||||||
|
delete(rm.pendings, req.ReqID()) |
||||||
|
|
||||||
|
if st := req.owner; st != nil { |
||||||
|
st.clearPendingRequest() |
||||||
|
rm.available[st.ID()] = struct{}{} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) pickAvailableStream(req *request) (*stream, error) { |
||||||
|
for id := range rm.available { |
||||||
|
if !req.isStreamAllowed(id) { |
||||||
|
continue |
||||||
|
} |
||||||
|
st, ok := rm.streams[id] |
||||||
|
if !ok { |
||||||
|
return nil, errors.New("sanity error: available stream not registered") |
||||||
|
} |
||||||
|
if st.req != nil { |
||||||
|
return nil, errors.New("sanity error: available stream has pending requests") |
||||||
|
} |
||||||
|
spec, _ := st.ProtoSpec() |
||||||
|
if req.Request.IsSupportedByProto(spec) { |
||||||
|
return st, nil |
||||||
|
} |
||||||
|
} |
||||||
|
return nil, errors.New("no more available streams") |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) addNewStream(st sttypes.Stream) { |
||||||
|
rm.lock.Lock() |
||||||
|
defer rm.lock.Unlock() |
||||||
|
|
||||||
|
if _, ok := rm.streams[st.ID()]; !ok { |
||||||
|
rm.streams[st.ID()] = &stream{Stream: st} |
||||||
|
rm.available[st.ID()] = struct{}{} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// removeStream remove the stream from request manager, clear the pending request
|
||||||
|
// of the stream. Return whether a pending request is canceled in the stream,
|
||||||
|
func (rm *requestManager) removeStream(id sttypes.StreamID) { |
||||||
|
rm.lock.Lock() |
||||||
|
defer rm.lock.Unlock() |
||||||
|
|
||||||
|
st, ok := rm.streams[id] |
||||||
|
if !ok { |
||||||
|
return |
||||||
|
} |
||||||
|
delete(rm.available, id) |
||||||
|
delete(rm.streams, id) |
||||||
|
|
||||||
|
cleared := st.clearPendingRequest() |
||||||
|
if cleared != nil { |
||||||
|
cleared.doneWithResponse(responseData{ |
||||||
|
stID: id, |
||||||
|
err: errors.New("stream removed when doing request"), |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (rm *requestManager) close() { |
||||||
|
rm.lock.Lock() |
||||||
|
defer rm.lock.Unlock() |
||||||
|
|
||||||
|
for _, sub := range rm.subs { |
||||||
|
sub.Unsubscribe() |
||||||
|
} |
||||||
|
for _, req := range rm.pendings { |
||||||
|
req.doneWithResponse(responseData{err: ErrClosed}) |
||||||
|
} |
||||||
|
rm.pendings = make(map[uint64]*request) |
||||||
|
rm.available = make(map[sttypes.StreamID]struct{}) |
||||||
|
rm.streams = make(map[sttypes.StreamID]*stream) |
||||||
|
rm.waitings = newRequestQueue() |
||||||
|
close(rm.stopC) |
||||||
|
} |
||||||
|
|
||||||
|
type reqPriority int |
||||||
|
|
||||||
|
const ( |
||||||
|
reqPriorityLow reqPriority = iota |
||||||
|
reqPriorityHigh |
||||||
|
) |
||||||
|
|
||||||
|
func (rm *requestManager) addRequestToWaitings(req *request, priority reqPriority) error { |
||||||
|
return rm.waitings.Push(req, priority) |
||||||
|
} |
@ -0,0 +1,428 @@ |
|||||||
|
package requestmanager |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"sync" |
||||||
|
"sync/atomic" |
||||||
|
"testing" |
||||||
|
"time" |
||||||
|
|
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
defTestSleep = 50 * time.Millisecond |
||||||
|
) |
||||||
|
|
||||||
|
// Request is delivered right away as expected
|
||||||
|
func TestRequestManager_Request_Normal(t *testing.T) { |
||||||
|
delayF := makeDefaultDelayFunc(150 * time.Millisecond) |
||||||
|
respF := makeDefaultResponseFunc() |
||||||
|
ts := newTestSuite(delayF, respF, 3) |
||||||
|
ts.Start() |
||||||
|
defer ts.Close() |
||||||
|
|
||||||
|
req := makeTestRequest(100) |
||||||
|
ctx, _ := context.WithTimeout(context.Background(), 1*time.Second) |
||||||
|
res := <-ts.rm.doRequestAsync(ctx, req) |
||||||
|
|
||||||
|
if res.err != nil { |
||||||
|
t.Errorf("unexpected error: %v", res.err) |
||||||
|
return |
||||||
|
} |
||||||
|
if err := req.checkResponse(res.resp); err != nil { |
||||||
|
t.Error(err) |
||||||
|
} |
||||||
|
if res.stID == "" { |
||||||
|
t.Errorf("unexpected stid") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// The request is canceled by context
|
||||||
|
func TestRequestManager_Request_Cancel(t *testing.T) { |
||||||
|
delayF := makeDefaultDelayFunc(500 * time.Millisecond) |
||||||
|
respF := makeDefaultResponseFunc() |
||||||
|
ts := newTestSuite(delayF, respF, 3) |
||||||
|
ts.Start() |
||||||
|
defer ts.Close() |
||||||
|
|
||||||
|
req := makeTestRequest(100) |
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) |
||||||
|
resC := ts.rm.doRequestAsync(ctx, req) |
||||||
|
|
||||||
|
time.Sleep(defTestSleep) |
||||||
|
cancel() |
||||||
|
|
||||||
|
res := <-resC |
||||||
|
if res.err != context.Canceled { |
||||||
|
t.Errorf("unexpected error: %v", res.err) |
||||||
|
} |
||||||
|
if res.stID == "" { |
||||||
|
t.Errorf("unexpected canceled request should also have stid") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// error happens when adding request to waiting list
|
||||||
|
func TestRequestManager_NewStream(t *testing.T) { |
||||||
|
delayF := makeDefaultDelayFunc(500 * time.Millisecond) |
||||||
|
respF := makeDefaultResponseFunc() |
||||||
|
ts := newTestSuite(delayF, respF, 3) |
||||||
|
ts.Start() |
||||||
|
defer ts.Close() |
||||||
|
|
||||||
|
ts.sm.addNewStream(ts.makeTestStream(3)) |
||||||
|
|
||||||
|
time.Sleep(defTestSleep) |
||||||
|
|
||||||
|
ts.rm.lock.Lock() |
||||||
|
if len(ts.rm.streams) != 4 || len(ts.rm.available) != 4 { |
||||||
|
t.Errorf("unexpected stream size") |
||||||
|
} |
||||||
|
ts.rm.lock.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
// For request assigned to the stream being removed, the request will be rescheduled.
|
||||||
|
func TestRequestManager_RemoveStream(t *testing.T) { |
||||||
|
delayF := makeOnceBlockDelayFunc(150 * time.Millisecond) |
||||||
|
respF := makeDefaultResponseFunc() |
||||||
|
ts := newTestSuite(delayF, respF, 3) |
||||||
|
ts.Start() |
||||||
|
defer ts.Close() |
||||||
|
|
||||||
|
req := makeTestRequest(100) |
||||||
|
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) |
||||||
|
resC := ts.rm.doRequestAsync(ctx, req) |
||||||
|
time.Sleep(defTestSleep) |
||||||
|
|
||||||
|
// remove the stream which is responsible for the request
|
||||||
|
idToRemove := ts.pickOneOccupiedStream() |
||||||
|
ts.sm.rmStream(idToRemove) |
||||||
|
|
||||||
|
// the request is rescheduled thus there is supposed to be no errors
|
||||||
|
res := <-resC |
||||||
|
if res.err == nil { |
||||||
|
t.Errorf("unexpected error: %v", errors.New("stream removed when doing request")) |
||||||
|
} |
||||||
|
|
||||||
|
ts.rm.lock.Lock() |
||||||
|
if len(ts.rm.streams) != 2 || len(ts.rm.available) != 2 { |
||||||
|
t.Errorf("unexpected stream size") |
||||||
|
} |
||||||
|
ts.rm.lock.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
// stream delivers an unknown request ID
|
||||||
|
func TestRequestManager_UnknownDelivery(t *testing.T) { |
||||||
|
delayF := makeDefaultDelayFunc(150 * time.Millisecond) |
||||||
|
respF := func(req *testRequest) *testResponse { |
||||||
|
var rid uint64 |
||||||
|
for rid == req.reqID { |
||||||
|
rid++ |
||||||
|
} |
||||||
|
return &testResponse{ |
||||||
|
reqID: rid, |
||||||
|
index: 0, |
||||||
|
} |
||||||
|
} |
||||||
|
ts := newTestSuite(delayF, respF, 3) |
||||||
|
ts.Start() |
||||||
|
defer ts.Close() |
||||||
|
|
||||||
|
req := makeTestRequest(100) |
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) |
||||||
|
resC := ts.rm.doRequestAsync(ctx, req) |
||||||
|
time.Sleep(6 * time.Second) |
||||||
|
cancel() |
||||||
|
|
||||||
|
// Since the reqID is not delivered, the result is not delivered to the request
|
||||||
|
// and be canceled
|
||||||
|
res := <-resC |
||||||
|
if res.err != context.Canceled { |
||||||
|
t.Errorf("unexpected error: %v", res.err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// stream delivers a response for a canceled request
|
||||||
|
func TestRequestManager_StaleDelivery(t *testing.T) { |
||||||
|
delayF := makeDefaultDelayFunc(1 * time.Second) |
||||||
|
respF := makeDefaultResponseFunc() |
||||||
|
ts := newTestSuite(delayF, respF, 3) |
||||||
|
ts.Start() |
||||||
|
defer ts.Close() |
||||||
|
|
||||||
|
req := makeTestRequest(100) |
||||||
|
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) |
||||||
|
resC := ts.rm.doRequestAsync(ctx, req) |
||||||
|
time.Sleep(2 * time.Second) |
||||||
|
|
||||||
|
// Since the reqID is not delivered, the result is not delivered to the request
|
||||||
|
// and be canceled
|
||||||
|
res := <-resC |
||||||
|
if res.err != context.DeadlineExceeded { |
||||||
|
t.Errorf("unexpected error: %v", res.err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// closing request manager will also close all
|
||||||
|
func TestRequestManager_Close(t *testing.T) { |
||||||
|
delayF := makeDefaultDelayFunc(1 * time.Second) |
||||||
|
respF := makeDefaultResponseFunc() |
||||||
|
ts := newTestSuite(delayF, respF, 3) |
||||||
|
ts.Start() |
||||||
|
|
||||||
|
ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) |
||||||
|
resC := ts.rm.doRequestAsync(ctx, makeTestRequest(0)) |
||||||
|
time.Sleep(100 * time.Millisecond) |
||||||
|
ts.Close() |
||||||
|
|
||||||
|
// Since the reqID is not delivered, the result is not delivered to the request
|
||||||
|
// and be canceled
|
||||||
|
res := <-resC |
||||||
|
if assErr := assertError(res.err, errors.New("request manager module closed")); assErr != nil { |
||||||
|
t.Errorf("unexpected error: %v", assErr) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestRequestManager_Request_Blacklist(t *testing.T) { |
||||||
|
delayF := makeDefaultDelayFunc(150 * time.Millisecond) |
||||||
|
respF := makeDefaultResponseFunc() |
||||||
|
ts := newTestSuite(delayF, respF, 4) |
||||||
|
ts.Start() |
||||||
|
defer ts.Close() |
||||||
|
|
||||||
|
req := makeTestRequest(100) |
||||||
|
ctx, _ := context.WithTimeout(context.Background(), 1*time.Second) |
||||||
|
res := <-ts.rm.doRequestAsync(ctx, req, WithBlacklist([]sttypes.StreamID{ |
||||||
|
makeStreamID(0), |
||||||
|
makeStreamID(1), |
||||||
|
makeStreamID(2), |
||||||
|
})) |
||||||
|
|
||||||
|
if res.err != nil { |
||||||
|
t.Errorf("unexpected error: %v", res.err) |
||||||
|
return |
||||||
|
} |
||||||
|
if err := req.checkResponse(res.resp); err != nil { |
||||||
|
t.Error(err) |
||||||
|
} |
||||||
|
if res.stID != makeStreamID(3) { |
||||||
|
t.Errorf("unexpected stid") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestRequestManager_Request_Whitelist(t *testing.T) { |
||||||
|
delayF := makeDefaultDelayFunc(150 * time.Millisecond) |
||||||
|
respF := makeDefaultResponseFunc() |
||||||
|
ts := newTestSuite(delayF, respF, 4) |
||||||
|
ts.Start() |
||||||
|
defer ts.Close() |
||||||
|
|
||||||
|
req := makeTestRequest(100) |
||||||
|
ctx, _ := context.WithTimeout(context.Background(), 1*time.Second) |
||||||
|
res := <-ts.rm.doRequestAsync(ctx, req, WithWhitelist([]sttypes.StreamID{ |
||||||
|
makeStreamID(3), |
||||||
|
})) |
||||||
|
|
||||||
|
if res.err != nil { |
||||||
|
t.Errorf("unexpected error: %v", res.err) |
||||||
|
return |
||||||
|
} |
||||||
|
if err := req.checkResponse(res.resp); err != nil { |
||||||
|
t.Error(err) |
||||||
|
} |
||||||
|
if res.stID != makeStreamID(3) { |
||||||
|
t.Errorf("unexpected stid") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// test the race condition by spinning up a lot of goroutines
|
||||||
|
func TestRequestManager_Concurrency(t *testing.T) { |
||||||
|
var ( |
||||||
|
testDuration = 10 * time.Second |
||||||
|
numThreads = 25 |
||||||
|
) |
||||||
|
delayF := makeDefaultDelayFunc(100 * time.Millisecond) |
||||||
|
respF := makeDefaultResponseFunc() |
||||||
|
ts := newTestSuite(delayF, respF, 18) |
||||||
|
ts.Start() |
||||||
|
|
||||||
|
stopC := make(chan struct{}) |
||||||
|
var ( |
||||||
|
aErr atomic.Value |
||||||
|
numReqs uint64 |
||||||
|
wg sync.WaitGroup |
||||||
|
) |
||||||
|
wg.Add(numThreads) |
||||||
|
for i := 0; i != numThreads; i++ { |
||||||
|
go func() { |
||||||
|
defer wg.Done() |
||||||
|
for { |
||||||
|
resC := ts.rm.doRequestAsync(context.Background(), makeTestRequest(1000)) |
||||||
|
select { |
||||||
|
case res := <-resC: |
||||||
|
if res.err == nil { |
||||||
|
atomic.AddUint64(&numReqs, 1) |
||||||
|
continue |
||||||
|
} |
||||||
|
if res.err.Error() == "request manager module closed" { |
||||||
|
return |
||||||
|
} |
||||||
|
aErr.Store(res.err.Error()) |
||||||
|
case <-stopC: |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
}() |
||||||
|
} |
||||||
|
time.Sleep(testDuration) |
||||||
|
close(stopC) |
||||||
|
ts.Close() |
||||||
|
wg.Wait() |
||||||
|
|
||||||
|
if isNilErr := aErr.Load() == nil; !isNilErr { |
||||||
|
err := aErr.Load().(error) |
||||||
|
t.Errorf("unexpected error: %v", err) |
||||||
|
} |
||||||
|
num := atomic.LoadUint64(&numReqs) |
||||||
|
t.Logf("Mock processed requests: %v", num) |
||||||
|
} |
||||||
|
|
||||||
|
func TestGenReqID(t *testing.T) { |
||||||
|
retry := 100000 |
||||||
|
rm := &requestManager{ |
||||||
|
pendings: make(map[uint64]*request), |
||||||
|
} |
||||||
|
|
||||||
|
for i := 0; i != retry; i++ { |
||||||
|
rid := rm.genReqID() |
||||||
|
if _, ok := rm.pendings[rid]; ok { |
||||||
|
t.Errorf("rid collision") |
||||||
|
} |
||||||
|
rm.pendings[rid] = nil |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type testSuite struct { |
||||||
|
rm *requestManager |
||||||
|
sm *testStreamManager |
||||||
|
bootStreams []*testStream |
||||||
|
|
||||||
|
delayFunc delayFunc |
||||||
|
respFunc responseFunc |
||||||
|
|
||||||
|
ctx context.Context |
||||||
|
cancel func() |
||||||
|
} |
||||||
|
|
||||||
|
func newTestSuite(delayF delayFunc, respF responseFunc, numStreams int) *testSuite { |
||||||
|
sm := newTestStreamManager() |
||||||
|
rm := newRequestManager(sm) |
||||||
|
ctx, cancel := context.WithCancel(context.Background()) |
||||||
|
|
||||||
|
ts := &testSuite{ |
||||||
|
rm: rm, |
||||||
|
sm: sm, |
||||||
|
bootStreams: make([]*testStream, 0, numStreams), |
||||||
|
delayFunc: delayF, |
||||||
|
respFunc: respF, |
||||||
|
ctx: ctx, |
||||||
|
cancel: cancel, |
||||||
|
} |
||||||
|
for i := 0; i != numStreams; i++ { |
||||||
|
ts.bootStreams = append(ts.bootStreams, ts.makeTestStream(i)) |
||||||
|
} |
||||||
|
return ts |
||||||
|
} |
||||||
|
|
||||||
|
func (ts *testSuite) Start() { |
||||||
|
ts.rm.Start() |
||||||
|
for _, st := range ts.bootStreams { |
||||||
|
ts.sm.addNewStream(st) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (ts *testSuite) Close() { |
||||||
|
ts.rm.Close() |
||||||
|
ts.cancel() |
||||||
|
} |
||||||
|
|
||||||
|
func (ts *testSuite) pickOneOccupiedStream() sttypes.StreamID { |
||||||
|
ts.rm.lock.Lock() |
||||||
|
defer ts.rm.lock.Unlock() |
||||||
|
|
||||||
|
for _, req := range ts.rm.pendings { |
||||||
|
return req.owner.ID() |
||||||
|
} |
||||||
|
return "" |
||||||
|
} |
||||||
|
|
||||||
|
type ( |
||||||
|
// responseFunc is the function to compose a response
|
||||||
|
responseFunc func(request *testRequest) *testResponse |
||||||
|
|
||||||
|
// delayFunc is the function to determine the delay to deliver a response
|
||||||
|
delayFunc func() time.Duration |
||||||
|
) |
||||||
|
|
||||||
|
func makeDefaultResponseFunc() responseFunc { |
||||||
|
return func(request *testRequest) *testResponse { |
||||||
|
resp := &testResponse{ |
||||||
|
reqID: request.reqID, |
||||||
|
index: request.index, |
||||||
|
} |
||||||
|
return resp |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func checkResponseMessage(request sttypes.Request, response sttypes.Response) error { |
||||||
|
tReq, ok := request.(*testRequest) |
||||||
|
if !ok || tReq == nil { |
||||||
|
return errors.New("request not testRequest") |
||||||
|
} |
||||||
|
tResp, ok := response.(*testResponse) |
||||||
|
if !ok || tResp == nil { |
||||||
|
return errors.New("response not testResponse") |
||||||
|
} |
||||||
|
return tReq.checkResponse(tResp) |
||||||
|
} |
||||||
|
|
||||||
|
func makeDefaultDelayFunc(delay time.Duration) delayFunc { |
||||||
|
return func() time.Duration { |
||||||
|
return delay |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func makeOnceBlockDelayFunc(normalDelay time.Duration) delayFunc { |
||||||
|
// This usage of once is nasty. Avoid using once like this in production code.
|
||||||
|
var once sync.Once |
||||||
|
return func() time.Duration { |
||||||
|
var block bool |
||||||
|
once.Do(func() { |
||||||
|
block = true |
||||||
|
}) |
||||||
|
if block { |
||||||
|
return time.Hour |
||||||
|
} |
||||||
|
return normalDelay |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (ts *testSuite) makeTestStream(index int) *testStream { |
||||||
|
stid := makeStreamID(index) |
||||||
|
return &testStream{ |
||||||
|
id: stid, |
||||||
|
rm: ts.rm, |
||||||
|
deliver: func(req *testRequest) { |
||||||
|
delay := ts.delayFunc() |
||||||
|
resp := ts.respFunc(req) |
||||||
|
go func() { |
||||||
|
select { |
||||||
|
case <-ts.ctx.Done(): |
||||||
|
case <-time.After(delay): |
||||||
|
ts.rm.DeliverResponse(stid, resp) |
||||||
|
} |
||||||
|
}() |
||||||
|
}, |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,182 @@ |
|||||||
|
package requestmanager |
||||||
|
|
||||||
|
import ( |
||||||
|
"container/list" |
||||||
|
"sync" |
||||||
|
"sync/atomic" |
||||||
|
|
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
// ErrQueueFull is the error happens when the waiting queue is already full
|
||||||
|
ErrQueueFull = errors.New("waiting request queue already full") |
||||||
|
|
||||||
|
// ErrClosed is request error that the module is closed during request
|
||||||
|
ErrClosed = errors.New("request manager module closed") |
||||||
|
) |
||||||
|
|
||||||
|
// stream is the wrapped version of sttypes.Stream.
|
||||||
|
// TODO: enable stream handle multiple pending requests at the same time
|
||||||
|
type stream struct { |
||||||
|
sttypes.Stream |
||||||
|
req *request // currently one stream is dealing with one request
|
||||||
|
} |
||||||
|
|
||||||
|
// request is the wrapped request within module
|
||||||
|
type request struct { |
||||||
|
sttypes.Request // underlying request
|
||||||
|
// result field
|
||||||
|
respC chan responseData // channel to receive response from delivered message
|
||||||
|
// concurrency control
|
||||||
|
atmDone uint32 |
||||||
|
doneC chan struct{} |
||||||
|
// stream info
|
||||||
|
owner *stream // Current owner
|
||||||
|
// utils
|
||||||
|
lock sync.RWMutex |
||||||
|
raw *interface{} |
||||||
|
// options
|
||||||
|
priority reqPriority |
||||||
|
blacklist map[sttypes.StreamID]struct{} // banned streams
|
||||||
|
whitelist map[sttypes.StreamID]struct{} // allowed streams
|
||||||
|
} |
||||||
|
|
||||||
|
func (req *request) ReqID() uint64 { |
||||||
|
req.lock.RLock() |
||||||
|
defer req.lock.RUnlock() |
||||||
|
|
||||||
|
return req.Request.ReqID() |
||||||
|
} |
||||||
|
|
||||||
|
func (req *request) SetReqID(val uint64) { |
||||||
|
req.lock.Lock() |
||||||
|
defer req.lock.Unlock() |
||||||
|
|
||||||
|
req.Request.SetReqID(val) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *request) doneWithResponse(resp responseData) { |
||||||
|
notDone := atomic.CompareAndSwapUint32(&req.atmDone, 0, 1) |
||||||
|
if notDone { |
||||||
|
req.respC <- resp |
||||||
|
close(req.respC) |
||||||
|
close(req.doneC) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (req *request) isDone() bool { |
||||||
|
return atomic.LoadUint32(&req.atmDone) == 1 |
||||||
|
} |
||||||
|
|
||||||
|
func (req *request) isStreamAllowed(stid sttypes.StreamID) bool { |
||||||
|
return req.isStreamWhitelisted(stid) && !req.isStreamBlacklisted(stid) |
||||||
|
} |
||||||
|
|
||||||
|
func (req *request) addBlacklistedStream(stid sttypes.StreamID) { |
||||||
|
if req.blacklist == nil { |
||||||
|
req.blacklist = make(map[sttypes.StreamID]struct{}) |
||||||
|
} |
||||||
|
req.blacklist[stid] = struct{}{} |
||||||
|
} |
||||||
|
|
||||||
|
func (req *request) isStreamBlacklisted(stid sttypes.StreamID) bool { |
||||||
|
if req.blacklist == nil { |
||||||
|
return false |
||||||
|
} |
||||||
|
_, ok := req.blacklist[stid] |
||||||
|
return ok |
||||||
|
} |
||||||
|
|
||||||
|
func (req *request) addWhiteListStream(stid sttypes.StreamID) { |
||||||
|
if req.whitelist == nil { |
||||||
|
req.whitelist = make(map[sttypes.StreamID]struct{}) |
||||||
|
} |
||||||
|
req.whitelist[stid] = struct{}{} |
||||||
|
} |
||||||
|
|
||||||
|
func (req *request) isStreamWhitelisted(stid sttypes.StreamID) bool { |
||||||
|
if req.whitelist == nil { |
||||||
|
return true |
||||||
|
} |
||||||
|
_, ok := req.whitelist[stid] |
||||||
|
return ok |
||||||
|
} |
||||||
|
|
||||||
|
func (st *stream) clearPendingRequest() *request { |
||||||
|
req := st.req |
||||||
|
if req == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
st.req = nil |
||||||
|
return req |
||||||
|
} |
||||||
|
|
||||||
|
type cancelReqData struct { |
||||||
|
reqID uint64 |
||||||
|
err error |
||||||
|
} |
||||||
|
|
||||||
|
// responseData is the wrapped response for stream requests
|
||||||
|
type responseData struct { |
||||||
|
resp sttypes.Response |
||||||
|
stID sttypes.StreamID |
||||||
|
err error |
||||||
|
} |
||||||
|
|
||||||
|
// requestQueue is a wrapper of double linked list with Request as type
|
||||||
|
type requestQueue struct { |
||||||
|
reqsPHigh *list.List // high priority, currently defined by upper function calls
|
||||||
|
reqsPLow *list.List // low priority, applied to all normal requests
|
||||||
|
lock sync.Mutex |
||||||
|
} |
||||||
|
|
||||||
|
func newRequestQueue() requestQueue { |
||||||
|
return requestQueue{ |
||||||
|
reqsPHigh: list.New(), |
||||||
|
reqsPLow: list.New(), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Push add a new request to requestQueue.
|
||||||
|
func (q *requestQueue) Push(req *request, priority reqPriority) error { |
||||||
|
q.lock.Lock() |
||||||
|
defer q.lock.Unlock() |
||||||
|
|
||||||
|
if priority == reqPriorityHigh || req.priority == reqPriorityHigh { |
||||||
|
return pushRequestToList(q.reqsPHigh, req) |
||||||
|
} |
||||||
|
if priority == reqPriorityLow { |
||||||
|
return pushRequestToList(q.reqsPLow, req) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Pop will first pop the request from high priority, and then pop from low priority
|
||||||
|
func (q *requestQueue) Pop() *request { |
||||||
|
q.lock.Lock() |
||||||
|
defer q.lock.Unlock() |
||||||
|
|
||||||
|
if req := popRequestFromList(q.reqsPHigh); req != nil { |
||||||
|
return req |
||||||
|
} |
||||||
|
return popRequestFromList(q.reqsPLow) |
||||||
|
} |
||||||
|
|
||||||
|
func pushRequestToList(l *list.List, req *request) error { |
||||||
|
if l.Len() >= maxWaitingSize { |
||||||
|
return ErrQueueFull |
||||||
|
} |
||||||
|
l.PushBack(req) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func popRequestFromList(l *list.List) *request { |
||||||
|
elem := l.Front() |
||||||
|
if elem == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
l.Remove(elem) |
||||||
|
return elem.Value.(*request) |
||||||
|
} |
@ -0,0 +1,165 @@ |
|||||||
|
package requestmanager |
||||||
|
|
||||||
|
import ( |
||||||
|
"container/list" |
||||||
|
"fmt" |
||||||
|
"strings" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
func TestRequestQueue_Push(t *testing.T) { |
||||||
|
tests := []struct { |
||||||
|
initSize []int |
||||||
|
priority reqPriority |
||||||
|
expSize []int |
||||||
|
expErr error |
||||||
|
}{ |
||||||
|
{ |
||||||
|
initSize: []int{10, 10}, |
||||||
|
priority: reqPriorityHigh, |
||||||
|
expSize: []int{11, 10}, |
||||||
|
expErr: nil, |
||||||
|
}, |
||||||
|
{ |
||||||
|
initSize: []int{10, 10}, |
||||||
|
priority: reqPriorityLow, |
||||||
|
expSize: []int{10, 11}, |
||||||
|
expErr: nil, |
||||||
|
}, |
||||||
|
{ |
||||||
|
initSize: []int{maxWaitingSize, maxWaitingSize}, |
||||||
|
priority: reqPriorityLow, |
||||||
|
expErr: ErrQueueFull, |
||||||
|
}, |
||||||
|
{ |
||||||
|
initSize: []int{maxWaitingSize, maxWaitingSize}, |
||||||
|
priority: reqPriorityHigh, |
||||||
|
expErr: ErrQueueFull, |
||||||
|
}, |
||||||
|
} |
||||||
|
for i, test := range tests { |
||||||
|
q := makeTestRequestQueue(test.initSize) |
||||||
|
req := wrapRequestFromRaw(makeTestRequest(100)) |
||||||
|
|
||||||
|
err := q.Push(req, test.priority) |
||||||
|
if assErr := assertError(err, test.expErr); assErr != nil { |
||||||
|
t.Errorf("Test %v: %v", i, assErr) |
||||||
|
} |
||||||
|
if err != nil || test.expErr != nil { |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
if err := q.checkSizes(test.expSize); err != nil { |
||||||
|
t.Errorf("Test %v: %v", i, err) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestRequestQueue_Pop(t *testing.T) { |
||||||
|
tests := []struct { |
||||||
|
initSizes []int |
||||||
|
expNil bool |
||||||
|
expIndex uint64 |
||||||
|
expSizes []int |
||||||
|
}{ |
||||||
|
{ |
||||||
|
initSizes: []int{10, 10}, |
||||||
|
expNil: false, |
||||||
|
expIndex: 0, |
||||||
|
expSizes: []int{9, 10}, |
||||||
|
}, |
||||||
|
{ |
||||||
|
initSizes: []int{0, 10}, |
||||||
|
expNil: false, |
||||||
|
expIndex: 0, |
||||||
|
expSizes: []int{0, 9}, |
||||||
|
}, |
||||||
|
{ |
||||||
|
initSizes: []int{0, 0}, |
||||||
|
expNil: true, |
||||||
|
expSizes: []int{0, 0}, |
||||||
|
}, |
||||||
|
} |
||||||
|
for i, test := range tests { |
||||||
|
q := makeTestRequestQueue(test.initSizes) |
||||||
|
req := q.Pop() |
||||||
|
|
||||||
|
if err := q.checkSizes(test.expSizes); err != nil { |
||||||
|
t.Errorf("Test %v: %v", i, err) |
||||||
|
} |
||||||
|
if req == nil != (test.expNil) { |
||||||
|
t.Errorf("test %v: unpected nil", i) |
||||||
|
} |
||||||
|
if req == nil { |
||||||
|
continue |
||||||
|
} |
||||||
|
index := req.Request.(*testRequest).index |
||||||
|
if index != test.expIndex { |
||||||
|
t.Errorf("Test %v: unexpected index: %v / %v", i, index, test.expIndex) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func makeTestRequestQueue(sizes []int) requestQueue { |
||||||
|
if len(sizes) != 2 { |
||||||
|
panic("unexpected sizes") |
||||||
|
} |
||||||
|
q := newRequestQueue() |
||||||
|
|
||||||
|
index := 0 |
||||||
|
for i := 0; i != sizes[0]; i++ { |
||||||
|
q.reqsPHigh.PushBack(wrapRequestFromRaw(makeTestRequest(uint64(index)))) |
||||||
|
index++ |
||||||
|
} |
||||||
|
for i := 0; i != sizes[1]; i++ { |
||||||
|
q.reqsPLow.PushBack(wrapRequestFromRaw(makeTestRequest(uint64(index)))) |
||||||
|
index++ |
||||||
|
} |
||||||
|
return q |
||||||
|
} |
||||||
|
|
||||||
|
func wrapRequestFromRaw(raw *testRequest) *request { |
||||||
|
return &request{ |
||||||
|
Request: raw, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func getTestRequestFromElem(elem *list.Element) (*testRequest, error) { |
||||||
|
req, ok := elem.Value.(*request) |
||||||
|
if !ok { |
||||||
|
return nil, errors.New("unexpected type") |
||||||
|
} |
||||||
|
raw, ok := req.Request.(*testRequest) |
||||||
|
if !ok { |
||||||
|
return nil, errors.New("unexpected raw types") |
||||||
|
} |
||||||
|
return raw, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (q *requestQueue) checkSizes(sizes []int) error { |
||||||
|
if len(sizes) != 2 { |
||||||
|
panic("expect 2 sizes") |
||||||
|
} |
||||||
|
if q.reqsPHigh.Len() != sizes[0] { |
||||||
|
return fmt.Errorf("high priority %v / %v", q.reqsPHigh.Len(), sizes[0]) |
||||||
|
} |
||||||
|
if q.reqsPLow.Len() != sizes[1] { |
||||||
|
return fmt.Errorf("low priority %v / %v", q.reqsPLow.Len(), sizes[2]) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func assertError(got, exp error) error { |
||||||
|
if (got == nil) != (exp == nil) { |
||||||
|
return fmt.Errorf("unexpected error: %v / %v", got, exp) |
||||||
|
} |
||||||
|
if got == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
if !strings.Contains(got.Error(), exp.Error()) { |
||||||
|
return fmt.Errorf("unexpected error: %v / %v", got, exp) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
Loading…
Reference in new issue