[stream] added module request manager to manage over client side requests

pull/3560/head
Jacky Wang 4 years ago
parent e08c2ce41e
commit fbed5bcf3b
No known key found for this signature in database
GPG Key ID: 1085CE5F4FF5842C
  1. 19
      p2p/stream/common/requestmanager/config.go
  2. 25
      p2p/stream/common/requestmanager/interface.go
  3. 171
      p2p/stream/common/requestmanager/interface_test.go
  4. 39
      p2p/stream/common/requestmanager/options.go
  5. 410
      p2p/stream/common/requestmanager/requestmanager.go
  6. 428
      p2p/stream/common/requestmanager/requestmanager_test.go
  7. 182
      p2p/stream/common/requestmanager/types.go
  8. 165
      p2p/stream/common/requestmanager/types_test.go

@ -0,0 +1,19 @@
package requestmanager
import "time"
// TODO: determine the values in production environment
const (
// throttle to do request every 100 milliseconds
throttleInterval = 100 * time.Millisecond
// number of request to be done in each throttle loop
throttleBatch = 16
// deliverTimeout is the timeout for a response delivery. If the response cannot be delivered
// within timeout because blocking of the channel, the response will be dropped.
deliverTimeout = 5 * time.Second
// maxWaitingSize is the maximum requests that are in waiting list
maxWaitingSize = 1024
)

@ -0,0 +1,25 @@
package requestmanager
import (
"context"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
p2ptypes "github.com/harmony-one/harmony/p2p/types"
)
// Requester is the interface to do request
type Requester interface {
DoRequest(ctx context.Context, request sttypes.Request, options ...RequestOption) (sttypes.Response, sttypes.StreamID, error)
}
// Deliverer is the interface to deliver a response
type Deliverer interface {
DeliverResponse(stID sttypes.StreamID, resp sttypes.Response)
}
// RequestManager manages over the requests
type RequestManager interface {
p2ptypes.LifeCycle
Requester
Deliverer
}

@ -0,0 +1,171 @@
package requestmanager
import (
"errors"
"fmt"
"strconv"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/rlp"
"github.com/harmony-one/harmony/p2p/stream/common/streammanager"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
)
var testProtoID = sttypes.ProtoID("harmony/sync/unitest/0/1.0.0")
type testStreamManager struct {
newStreamFeed event.Feed
rmStreamFeed event.Feed
}
func newTestStreamManager() *testStreamManager {
return &testStreamManager{}
}
func (sm *testStreamManager) addNewStream(st sttypes.Stream) {
sm.newStreamFeed.Send(streammanager.EvtStreamAdded{Stream: st})
}
func (sm *testStreamManager) rmStream(stid sttypes.StreamID) {
sm.rmStreamFeed.Send(streammanager.EvtStreamRemoved{ID: stid})
}
func (sm *testStreamManager) SubscribeAddStreamEvent(ch chan<- streammanager.EvtStreamAdded) event.Subscription {
return sm.newStreamFeed.Subscribe(ch)
}
func (sm *testStreamManager) SubscribeRemoveStreamEvent(ch chan<- streammanager.EvtStreamRemoved) event.Subscription {
return sm.rmStreamFeed.Subscribe(ch)
}
type testStream struct {
id sttypes.StreamID
rm *requestManager
deliver func(*testRequest) // use goroutine inside this function
}
func (st *testStream) ID() sttypes.StreamID {
return st.id
}
func (st *testStream) ProtoID() sttypes.ProtoID {
return testProtoID
}
func (st *testStream) WriteBytes(b []byte) error {
req, err := decodeTestRequest(b)
if err != nil {
return err
}
if st.rm != nil && st.deliver != nil {
st.deliver(req)
}
return nil
}
func (st *testStream) ReadBytes() ([]byte, error) {
return nil, nil
}
func (st *testStream) ProtoSpec() (sttypes.ProtoSpec, error) {
return sttypes.ProtoIDToProtoSpec(testProtoID)
}
func (st *testStream) Close() error {
return nil
}
func (st *testStream) ResetOnClose() error {
return nil
}
func makeStreamID(index int) sttypes.StreamID {
return sttypes.StreamID(strconv.Itoa(index))
}
type testRequest struct {
reqID uint64
index uint64
}
func makeTestRequest(index uint64) *testRequest {
return &testRequest{
reqID: 0,
index: index,
}
}
func (req *testRequest) ReqID() uint64 {
return req.reqID
}
func (req *testRequest) SetReqID(rid uint64) {
req.reqID = rid
}
func (req *testRequest) String() string {
return fmt.Sprintf("test request %v", req.index)
}
func (req *testRequest) Encode() ([]byte, error) {
return rlp.EncodeToBytes(struct {
ReqID uint64
Index uint64
}{
ReqID: req.reqID,
Index: req.index,
})
}
func (req *testRequest) checkResponse(rawResp sttypes.Response) error {
resp, ok := rawResp.(*testResponse)
if !ok || resp == nil {
return errors.New("not test Response")
}
if req.reqID != resp.reqID {
return errors.New("request id not expected")
}
if req.index != resp.index {
return errors.New("response id not expected")
}
return nil
}
func decodeTestRequest(b []byte) (*testRequest, error) {
type SerRequest struct {
ReqID uint64
Index uint64
}
var sr SerRequest
if err := rlp.DecodeBytes(b, &sr); err != nil {
return nil, err
}
return &testRequest{
reqID: sr.ReqID,
index: sr.Index,
}, nil
}
func (req *testRequest) IsSupportedByProto(spec sttypes.ProtoSpec) bool {
return true
}
func (req *testRequest) getResponse() *testResponse {
return &testResponse{
reqID: req.reqID,
index: req.index,
}
}
type testResponse struct {
reqID uint64
index uint64
}
func (tr *testResponse) ReqID() uint64 {
return tr.reqID
}
func (tr *testResponse) String() string {
return fmt.Sprintf("test response %v", tr.index)
}

@ -0,0 +1,39 @@
package requestmanager
import sttypes "github.com/harmony-one/harmony/p2p/stream/types"
// RequestOption is the additional instruction for requests.
// Currently, two options are supported:
// 1. WithHighPriority
// 2. WithBlacklist
// 3. WithWhitelist
type RequestOption func(*request)
// WithHighPriority is the request option to do request with higher priority.
// High priority requests are done first.
func WithHighPriority() RequestOption {
return func(req *request) {
req.priority = reqPriorityHigh
}
}
// WithBlacklist is the request option not to assign the request to the blacklisted
// stream ID.
func WithBlacklist(blacklist []sttypes.StreamID) RequestOption {
return func(req *request) {
for _, stid := range blacklist {
req.addBlacklistedStream(stid)
}
}
}
// WithWhitelist is the request option to restrict the request to be assigned to the
// given stream IDs.
// If a request is not with this option, all streams will be allowed.
func WithWhitelist(whitelist []sttypes.StreamID) RequestOption {
return func(req *request) {
for _, stid := range whitelist {
req.addWhiteListStream(stid)
}
}
}

@ -0,0 +1,410 @@
package requestmanager
import (
"context"
"fmt"
"sync"
"time"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/ethereum/go-ethereum/event"
"github.com/harmony-one/harmony/internal/utils"
"github.com/harmony-one/harmony/p2p/stream/common/streammanager"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
)
// requestManager implements RequestManager. It is responsible for matching response
// with requests.
// TODO: each peer is able to have a queue of requests instead of one request at a time.
// TODO: add QoS evaluation for each stream
type requestManager struct {
streams map[sttypes.StreamID]*stream // All streams
available map[sttypes.StreamID]struct{} // Streams that are available for request
pendings map[uint64]*request // requests that are sent but not received response
waitings requestQueue // double linked list of requests that are on the waiting list
// Stream events
newStreamC <-chan streammanager.EvtStreamAdded
rmStreamC <-chan streammanager.EvtStreamRemoved
// Request events
cancelReqC chan cancelReqData // request being canceled
deliveryC chan responseData
newRequestC chan *request
subs []event.Subscription
logger zerolog.Logger
stopC chan struct{}
lock sync.Mutex
}
// NewRequestManager creates a new request manager
func NewRequestManager(sm streammanager.Subscriber) RequestManager {
return newRequestManager(sm)
}
func newRequestManager(sm streammanager.Subscriber) *requestManager {
// subscribe at initialize to prevent misuse of upper function which might cause
// the bootstrap peers are ignored
newStreamC := make(chan streammanager.EvtStreamAdded)
rmStreamC := make(chan streammanager.EvtStreamRemoved)
sub1 := sm.SubscribeAddStreamEvent(newStreamC)
sub2 := sm.SubscribeRemoveStreamEvent(rmStreamC)
logger := utils.Logger().With().Str("module", "request manager").Logger()
return &requestManager{
streams: make(map[sttypes.StreamID]*stream),
available: make(map[sttypes.StreamID]struct{}),
pendings: make(map[uint64]*request),
waitings: newRequestQueue(),
newStreamC: newStreamC,
rmStreamC: rmStreamC,
cancelReqC: make(chan cancelReqData, 16),
deliveryC: make(chan responseData, 128),
newRequestC: make(chan *request, 128),
subs: []event.Subscription{sub1, sub2},
logger: logger,
stopC: make(chan struct{}),
}
}
func (rm *requestManager) Start() {
go rm.loop()
}
func (rm *requestManager) Close() {
rm.stopC <- struct{}{}
}
// DoRequest do the given request with a stream picked randomly. Return the response, stream id that
// is responsible for response, delivery and error.
func (rm *requestManager) DoRequest(ctx context.Context, raw sttypes.Request, options ...RequestOption) (sttypes.Response, sttypes.StreamID, error) {
resp := <-rm.doRequestAsync(ctx, raw, options...)
return resp.resp, resp.stID, resp.err
}
func (rm *requestManager) doRequestAsync(ctx context.Context, raw sttypes.Request, options ...RequestOption) <-chan responseData {
req := &request{
Request: raw,
respC: make(chan responseData),
doneC: make(chan struct{}),
}
for _, opt := range options {
opt(req)
}
rm.newRequestC <- req
go func() {
select {
case <-ctx.Done(): // canceled or timeout in upper function calls
rm.cancelReqC <- cancelReqData{
reqID: req.ReqID(),
err: ctx.Err(),
}
case <-req.doneC:
}
}()
return req.respC
}
// DeliverResponse delivers the response to the corresponding request.
// The function behaves non-block
func (rm *requestManager) DeliverResponse(stID sttypes.StreamID, resp sttypes.Response) {
sd := responseData{
resp: resp,
stID: stID,
}
go func() {
select {
case rm.deliveryC <- sd:
case <-time.After(deliverTimeout):
rm.logger.Error().Msg("WARNING: delivery timeout. Possible stuck in loop")
}
}()
}
func (rm *requestManager) loop() {
var (
throttleC = make(chan struct{}, 1) // throttle the waiting requests periodically
ticker = time.NewTicker(throttleInterval)
)
throttle := func() {
select {
case throttleC <- struct{}{}:
default:
}
}
for {
select {
case <-ticker.C:
throttle()
case <-throttleC:
loop:
for i := 0; i != throttleBatch; i++ {
req, st := rm.getNextRequest()
if req == nil {
break loop
}
rm.addPendingRequest(req, st)
b, err := req.Encode()
if err != nil {
rm.logger.Warn().Str("request", req.String()).Err(err).
Msg("request encode error")
}
go func() {
if err := st.WriteBytes(b); err != nil {
rm.logger.Warn().Str("streamID", string(st.ID())).Err(err).
Msg("write bytes")
req.doneWithResponse(responseData{
stID: st.ID(),
err: errors.Wrap(err, "write bytes"),
})
}
}()
}
case req := <-rm.newRequestC:
added := rm.handleNewRequest(req)
if added {
throttle()
}
case data := <-rm.deliveryC:
rm.handleDeliverData(data)
case data := <-rm.cancelReqC:
rm.handleCancelRequest(data)
case evt := <-rm.newStreamC:
rm.logger.Info().Str("streamID", string(evt.Stream.ID())).Msg("add new stream")
rm.addNewStream(evt.Stream)
case evt := <-rm.rmStreamC:
rm.logger.Info().Str("streamID", string(evt.ID)).Msg("remove stream")
rm.removeStream(evt.ID)
case <-rm.stopC:
rm.logger.Info().Msg("request manager stopped")
rm.close()
return
}
}
}
func (rm *requestManager) handleNewRequest(req *request) bool {
rm.lock.Lock()
defer rm.lock.Unlock()
err := rm.addRequestToWaitings(req, reqPriorityLow)
if err != nil {
rm.logger.Warn().Err(err).Msg("failed to add new request to waitings")
req.doneWithResponse(responseData{
err: errors.Wrap(err, "failed to add new request to waitings"),
})
return false
}
return true
}
func (rm *requestManager) handleDeliverData(data responseData) {
rm.lock.Lock()
defer rm.lock.Unlock()
if err := rm.validateDelivery(data); err != nil {
// if error happens in delivery, most likely it's a stale delivery. No action needed
// and return
rm.logger.Info().Err(err).Str("response", data.resp.String()).Msg("unable to validate deliver")
return
}
// req and st is ensured not to be empty in validateDelivery
req := rm.pendings[data.resp.ReqID()]
req.doneWithResponse(data)
rm.removePendingRequest(req)
}
func (rm *requestManager) validateDelivery(data responseData) error {
if data.err != nil {
return data.err
}
st := rm.streams[data.stID]
if st == nil {
return fmt.Errorf("data delivered from dead stream: %v", data.stID)
}
req := rm.pendings[data.resp.ReqID()]
if req == nil {
return fmt.Errorf("stale p2p response delivery")
}
if req.owner == nil || req.owner.ID() != data.stID {
return fmt.Errorf("unexpected delivery stream")
}
if st.req == nil || st.req.ReqID() != data.resp.ReqID() {
// Possible when request is canceled
return fmt.Errorf("unexpected deliver request")
}
return nil
}
func (rm *requestManager) handleCancelRequest(data cancelReqData) {
rm.lock.Lock()
defer rm.lock.Unlock()
req, ok := rm.pendings[data.reqID]
if !ok {
return
}
rm.removePendingRequest(req)
var stid sttypes.StreamID
if req.owner != nil {
stid = req.owner.ID()
}
req.doneWithResponse(responseData{
resp: nil,
stID: stid,
err: data.err,
})
}
func (rm *requestManager) getNextRequest() (*request, *stream) {
rm.lock.Lock()
defer rm.lock.Unlock()
var req *request
for {
req = rm.waitings.Pop()
if req == nil {
return nil, nil
}
if !req.isDone() {
break
}
}
st, err := rm.pickAvailableStream(req)
if err != nil {
rm.logger.Debug().Msg("No available streams.")
rm.addRequestToWaitings(req, reqPriorityHigh)
return nil, nil
}
return req, st
}
func (rm *requestManager) genReqID() uint64 {
for {
rid := sttypes.GenReqID()
if _, ok := rm.pendings[rid]; !ok {
return rid
}
}
}
func (rm *requestManager) addPendingRequest(req *request, st *stream) {
rm.lock.Lock()
defer rm.lock.Unlock()
reqID := rm.genReqID()
req.SetReqID(reqID)
req.owner = st
st.req = req
delete(rm.available, st.ID())
rm.pendings[req.ReqID()] = req
}
func (rm *requestManager) removePendingRequest(req *request) {
delete(rm.pendings, req.ReqID())
if st := req.owner; st != nil {
st.clearPendingRequest()
rm.available[st.ID()] = struct{}{}
}
}
func (rm *requestManager) pickAvailableStream(req *request) (*stream, error) {
for id := range rm.available {
if !req.isStreamAllowed(id) {
continue
}
st, ok := rm.streams[id]
if !ok {
return nil, errors.New("sanity error: available stream not registered")
}
if st.req != nil {
return nil, errors.New("sanity error: available stream has pending requests")
}
spec, _ := st.ProtoSpec()
if req.Request.IsSupportedByProto(spec) {
return st, nil
}
}
return nil, errors.New("no more available streams")
}
func (rm *requestManager) addNewStream(st sttypes.Stream) {
rm.lock.Lock()
defer rm.lock.Unlock()
if _, ok := rm.streams[st.ID()]; !ok {
rm.streams[st.ID()] = &stream{Stream: st}
rm.available[st.ID()] = struct{}{}
}
}
// removeStream remove the stream from request manager, clear the pending request
// of the stream. Return whether a pending request is canceled in the stream,
func (rm *requestManager) removeStream(id sttypes.StreamID) {
rm.lock.Lock()
defer rm.lock.Unlock()
st, ok := rm.streams[id]
if !ok {
return
}
delete(rm.available, id)
delete(rm.streams, id)
cleared := st.clearPendingRequest()
if cleared != nil {
cleared.doneWithResponse(responseData{
stID: id,
err: errors.New("stream removed when doing request"),
})
}
}
func (rm *requestManager) close() {
rm.lock.Lock()
defer rm.lock.Unlock()
for _, sub := range rm.subs {
sub.Unsubscribe()
}
for _, req := range rm.pendings {
req.doneWithResponse(responseData{err: ErrClosed})
}
rm.pendings = make(map[uint64]*request)
rm.available = make(map[sttypes.StreamID]struct{})
rm.streams = make(map[sttypes.StreamID]*stream)
rm.waitings = newRequestQueue()
close(rm.stopC)
}
type reqPriority int
const (
reqPriorityLow reqPriority = iota
reqPriorityHigh
)
func (rm *requestManager) addRequestToWaitings(req *request, priority reqPriority) error {
return rm.waitings.Push(req, priority)
}

@ -0,0 +1,428 @@
package requestmanager
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
"github.com/pkg/errors"
)
var (
defTestSleep = 50 * time.Millisecond
)
// Request is delivered right away as expected
func TestRequestManager_Request_Normal(t *testing.T) {
delayF := makeDefaultDelayFunc(150 * time.Millisecond)
respF := makeDefaultResponseFunc()
ts := newTestSuite(delayF, respF, 3)
ts.Start()
defer ts.Close()
req := makeTestRequest(100)
ctx, _ := context.WithTimeout(context.Background(), 1*time.Second)
res := <-ts.rm.doRequestAsync(ctx, req)
if res.err != nil {
t.Errorf("unexpected error: %v", res.err)
return
}
if err := req.checkResponse(res.resp); err != nil {
t.Error(err)
}
if res.stID == "" {
t.Errorf("unexpected stid")
}
}
// The request is canceled by context
func TestRequestManager_Request_Cancel(t *testing.T) {
delayF := makeDefaultDelayFunc(500 * time.Millisecond)
respF := makeDefaultResponseFunc()
ts := newTestSuite(delayF, respF, 3)
ts.Start()
defer ts.Close()
req := makeTestRequest(100)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
resC := ts.rm.doRequestAsync(ctx, req)
time.Sleep(defTestSleep)
cancel()
res := <-resC
if res.err != context.Canceled {
t.Errorf("unexpected error: %v", res.err)
}
if res.stID == "" {
t.Errorf("unexpected canceled request should also have stid")
}
}
// error happens when adding request to waiting list
func TestRequestManager_NewStream(t *testing.T) {
delayF := makeDefaultDelayFunc(500 * time.Millisecond)
respF := makeDefaultResponseFunc()
ts := newTestSuite(delayF, respF, 3)
ts.Start()
defer ts.Close()
ts.sm.addNewStream(ts.makeTestStream(3))
time.Sleep(defTestSleep)
ts.rm.lock.Lock()
if len(ts.rm.streams) != 4 || len(ts.rm.available) != 4 {
t.Errorf("unexpected stream size")
}
ts.rm.lock.Unlock()
}
// For request assigned to the stream being removed, the request will be rescheduled.
func TestRequestManager_RemoveStream(t *testing.T) {
delayF := makeOnceBlockDelayFunc(150 * time.Millisecond)
respF := makeDefaultResponseFunc()
ts := newTestSuite(delayF, respF, 3)
ts.Start()
defer ts.Close()
req := makeTestRequest(100)
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
resC := ts.rm.doRequestAsync(ctx, req)
time.Sleep(defTestSleep)
// remove the stream which is responsible for the request
idToRemove := ts.pickOneOccupiedStream()
ts.sm.rmStream(idToRemove)
// the request is rescheduled thus there is supposed to be no errors
res := <-resC
if res.err == nil {
t.Errorf("unexpected error: %v", errors.New("stream removed when doing request"))
}
ts.rm.lock.Lock()
if len(ts.rm.streams) != 2 || len(ts.rm.available) != 2 {
t.Errorf("unexpected stream size")
}
ts.rm.lock.Unlock()
}
// stream delivers an unknown request ID
func TestRequestManager_UnknownDelivery(t *testing.T) {
delayF := makeDefaultDelayFunc(150 * time.Millisecond)
respF := func(req *testRequest) *testResponse {
var rid uint64
for rid == req.reqID {
rid++
}
return &testResponse{
reqID: rid,
index: 0,
}
}
ts := newTestSuite(delayF, respF, 3)
ts.Start()
defer ts.Close()
req := makeTestRequest(100)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
resC := ts.rm.doRequestAsync(ctx, req)
time.Sleep(6 * time.Second)
cancel()
// Since the reqID is not delivered, the result is not delivered to the request
// and be canceled
res := <-resC
if res.err != context.Canceled {
t.Errorf("unexpected error: %v", res.err)
}
}
// stream delivers a response for a canceled request
func TestRequestManager_StaleDelivery(t *testing.T) {
delayF := makeDefaultDelayFunc(1 * time.Second)
respF := makeDefaultResponseFunc()
ts := newTestSuite(delayF, respF, 3)
ts.Start()
defer ts.Close()
req := makeTestRequest(100)
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
resC := ts.rm.doRequestAsync(ctx, req)
time.Sleep(2 * time.Second)
// Since the reqID is not delivered, the result is not delivered to the request
// and be canceled
res := <-resC
if res.err != context.DeadlineExceeded {
t.Errorf("unexpected error: %v", res.err)
}
}
// closing request manager will also close all
func TestRequestManager_Close(t *testing.T) {
delayF := makeDefaultDelayFunc(1 * time.Second)
respF := makeDefaultResponseFunc()
ts := newTestSuite(delayF, respF, 3)
ts.Start()
ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
resC := ts.rm.doRequestAsync(ctx, makeTestRequest(0))
time.Sleep(100 * time.Millisecond)
ts.Close()
// Since the reqID is not delivered, the result is not delivered to the request
// and be canceled
res := <-resC
if assErr := assertError(res.err, errors.New("request manager module closed")); assErr != nil {
t.Errorf("unexpected error: %v", assErr)
}
}
func TestRequestManager_Request_Blacklist(t *testing.T) {
delayF := makeDefaultDelayFunc(150 * time.Millisecond)
respF := makeDefaultResponseFunc()
ts := newTestSuite(delayF, respF, 4)
ts.Start()
defer ts.Close()
req := makeTestRequest(100)
ctx, _ := context.WithTimeout(context.Background(), 1*time.Second)
res := <-ts.rm.doRequestAsync(ctx, req, WithBlacklist([]sttypes.StreamID{
makeStreamID(0),
makeStreamID(1),
makeStreamID(2),
}))
if res.err != nil {
t.Errorf("unexpected error: %v", res.err)
return
}
if err := req.checkResponse(res.resp); err != nil {
t.Error(err)
}
if res.stID != makeStreamID(3) {
t.Errorf("unexpected stid")
}
}
func TestRequestManager_Request_Whitelist(t *testing.T) {
delayF := makeDefaultDelayFunc(150 * time.Millisecond)
respF := makeDefaultResponseFunc()
ts := newTestSuite(delayF, respF, 4)
ts.Start()
defer ts.Close()
req := makeTestRequest(100)
ctx, _ := context.WithTimeout(context.Background(), 1*time.Second)
res := <-ts.rm.doRequestAsync(ctx, req, WithWhitelist([]sttypes.StreamID{
makeStreamID(3),
}))
if res.err != nil {
t.Errorf("unexpected error: %v", res.err)
return
}
if err := req.checkResponse(res.resp); err != nil {
t.Error(err)
}
if res.stID != makeStreamID(3) {
t.Errorf("unexpected stid")
}
}
// test the race condition by spinning up a lot of goroutines
func TestRequestManager_Concurrency(t *testing.T) {
var (
testDuration = 10 * time.Second
numThreads = 25
)
delayF := makeDefaultDelayFunc(100 * time.Millisecond)
respF := makeDefaultResponseFunc()
ts := newTestSuite(delayF, respF, 18)
ts.Start()
stopC := make(chan struct{})
var (
aErr atomic.Value
numReqs uint64
wg sync.WaitGroup
)
wg.Add(numThreads)
for i := 0; i != numThreads; i++ {
go func() {
defer wg.Done()
for {
resC := ts.rm.doRequestAsync(context.Background(), makeTestRequest(1000))
select {
case res := <-resC:
if res.err == nil {
atomic.AddUint64(&numReqs, 1)
continue
}
if res.err.Error() == "request manager module closed" {
return
}
aErr.Store(res.err.Error())
case <-stopC:
return
}
}
}()
}
time.Sleep(testDuration)
close(stopC)
ts.Close()
wg.Wait()
if isNilErr := aErr.Load() == nil; !isNilErr {
err := aErr.Load().(error)
t.Errorf("unexpected error: %v", err)
}
num := atomic.LoadUint64(&numReqs)
t.Logf("Mock processed requests: %v", num)
}
func TestGenReqID(t *testing.T) {
retry := 100000
rm := &requestManager{
pendings: make(map[uint64]*request),
}
for i := 0; i != retry; i++ {
rid := rm.genReqID()
if _, ok := rm.pendings[rid]; ok {
t.Errorf("rid collision")
}
rm.pendings[rid] = nil
}
}
type testSuite struct {
rm *requestManager
sm *testStreamManager
bootStreams []*testStream
delayFunc delayFunc
respFunc responseFunc
ctx context.Context
cancel func()
}
func newTestSuite(delayF delayFunc, respF responseFunc, numStreams int) *testSuite {
sm := newTestStreamManager()
rm := newRequestManager(sm)
ctx, cancel := context.WithCancel(context.Background())
ts := &testSuite{
rm: rm,
sm: sm,
bootStreams: make([]*testStream, 0, numStreams),
delayFunc: delayF,
respFunc: respF,
ctx: ctx,
cancel: cancel,
}
for i := 0; i != numStreams; i++ {
ts.bootStreams = append(ts.bootStreams, ts.makeTestStream(i))
}
return ts
}
func (ts *testSuite) Start() {
ts.rm.Start()
for _, st := range ts.bootStreams {
ts.sm.addNewStream(st)
}
}
func (ts *testSuite) Close() {
ts.rm.Close()
ts.cancel()
}
func (ts *testSuite) pickOneOccupiedStream() sttypes.StreamID {
ts.rm.lock.Lock()
defer ts.rm.lock.Unlock()
for _, req := range ts.rm.pendings {
return req.owner.ID()
}
return ""
}
type (
// responseFunc is the function to compose a response
responseFunc func(request *testRequest) *testResponse
// delayFunc is the function to determine the delay to deliver a response
delayFunc func() time.Duration
)
func makeDefaultResponseFunc() responseFunc {
return func(request *testRequest) *testResponse {
resp := &testResponse{
reqID: request.reqID,
index: request.index,
}
return resp
}
}
func checkResponseMessage(request sttypes.Request, response sttypes.Response) error {
tReq, ok := request.(*testRequest)
if !ok || tReq == nil {
return errors.New("request not testRequest")
}
tResp, ok := response.(*testResponse)
if !ok || tResp == nil {
return errors.New("response not testResponse")
}
return tReq.checkResponse(tResp)
}
func makeDefaultDelayFunc(delay time.Duration) delayFunc {
return func() time.Duration {
return delay
}
}
func makeOnceBlockDelayFunc(normalDelay time.Duration) delayFunc {
// This usage of once is nasty. Avoid using once like this in production code.
var once sync.Once
return func() time.Duration {
var block bool
once.Do(func() {
block = true
})
if block {
return time.Hour
}
return normalDelay
}
}
func (ts *testSuite) makeTestStream(index int) *testStream {
stid := makeStreamID(index)
return &testStream{
id: stid,
rm: ts.rm,
deliver: func(req *testRequest) {
delay := ts.delayFunc()
resp := ts.respFunc(req)
go func() {
select {
case <-ts.ctx.Done():
case <-time.After(delay):
ts.rm.DeliverResponse(stid, resp)
}
}()
},
}
}

@ -0,0 +1,182 @@
package requestmanager
import (
"container/list"
"sync"
"sync/atomic"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
"github.com/pkg/errors"
)
var (
// ErrQueueFull is the error happens when the waiting queue is already full
ErrQueueFull = errors.New("waiting request queue already full")
// ErrClosed is request error that the module is closed during request
ErrClosed = errors.New("request manager module closed")
)
// stream is the wrapped version of sttypes.Stream.
// TODO: enable stream handle multiple pending requests at the same time
type stream struct {
sttypes.Stream
req *request // currently one stream is dealing with one request
}
// request is the wrapped request within module
type request struct {
sttypes.Request // underlying request
// result field
respC chan responseData // channel to receive response from delivered message
// concurrency control
atmDone uint32
doneC chan struct{}
// stream info
owner *stream // Current owner
// utils
lock sync.RWMutex
raw *interface{}
// options
priority reqPriority
blacklist map[sttypes.StreamID]struct{} // banned streams
whitelist map[sttypes.StreamID]struct{} // allowed streams
}
func (req *request) ReqID() uint64 {
req.lock.RLock()
defer req.lock.RUnlock()
return req.Request.ReqID()
}
func (req *request) SetReqID(val uint64) {
req.lock.Lock()
defer req.lock.Unlock()
req.Request.SetReqID(val)
}
func (req *request) doneWithResponse(resp responseData) {
notDone := atomic.CompareAndSwapUint32(&req.atmDone, 0, 1)
if notDone {
req.respC <- resp
close(req.respC)
close(req.doneC)
}
}
func (req *request) isDone() bool {
return atomic.LoadUint32(&req.atmDone) == 1
}
func (req *request) isStreamAllowed(stid sttypes.StreamID) bool {
return req.isStreamWhitelisted(stid) && !req.isStreamBlacklisted(stid)
}
func (req *request) addBlacklistedStream(stid sttypes.StreamID) {
if req.blacklist == nil {
req.blacklist = make(map[sttypes.StreamID]struct{})
}
req.blacklist[stid] = struct{}{}
}
func (req *request) isStreamBlacklisted(stid sttypes.StreamID) bool {
if req.blacklist == nil {
return false
}
_, ok := req.blacklist[stid]
return ok
}
func (req *request) addWhiteListStream(stid sttypes.StreamID) {
if req.whitelist == nil {
req.whitelist = make(map[sttypes.StreamID]struct{})
}
req.whitelist[stid] = struct{}{}
}
func (req *request) isStreamWhitelisted(stid sttypes.StreamID) bool {
if req.whitelist == nil {
return true
}
_, ok := req.whitelist[stid]
return ok
}
func (st *stream) clearPendingRequest() *request {
req := st.req
if req == nil {
return nil
}
st.req = nil
return req
}
type cancelReqData struct {
reqID uint64
err error
}
// responseData is the wrapped response for stream requests
type responseData struct {
resp sttypes.Response
stID sttypes.StreamID
err error
}
// requestQueue is a wrapper of double linked list with Request as type
type requestQueue struct {
reqsPHigh *list.List // high priority, currently defined by upper function calls
reqsPLow *list.List // low priority, applied to all normal requests
lock sync.Mutex
}
func newRequestQueue() requestQueue {
return requestQueue{
reqsPHigh: list.New(),
reqsPLow: list.New(),
}
}
// Push add a new request to requestQueue.
func (q *requestQueue) Push(req *request, priority reqPriority) error {
q.lock.Lock()
defer q.lock.Unlock()
if priority == reqPriorityHigh || req.priority == reqPriorityHigh {
return pushRequestToList(q.reqsPHigh, req)
}
if priority == reqPriorityLow {
return pushRequestToList(q.reqsPLow, req)
}
return nil
}
// Pop will first pop the request from high priority, and then pop from low priority
func (q *requestQueue) Pop() *request {
q.lock.Lock()
defer q.lock.Unlock()
if req := popRequestFromList(q.reqsPHigh); req != nil {
return req
}
return popRequestFromList(q.reqsPLow)
}
func pushRequestToList(l *list.List, req *request) error {
if l.Len() >= maxWaitingSize {
return ErrQueueFull
}
l.PushBack(req)
return nil
}
func popRequestFromList(l *list.List) *request {
elem := l.Front()
if elem == nil {
return nil
}
l.Remove(elem)
return elem.Value.(*request)
}

@ -0,0 +1,165 @@
package requestmanager
import (
"container/list"
"fmt"
"strings"
"testing"
"github.com/pkg/errors"
)
func TestRequestQueue_Push(t *testing.T) {
tests := []struct {
initSize []int
priority reqPriority
expSize []int
expErr error
}{
{
initSize: []int{10, 10},
priority: reqPriorityHigh,
expSize: []int{11, 10},
expErr: nil,
},
{
initSize: []int{10, 10},
priority: reqPriorityLow,
expSize: []int{10, 11},
expErr: nil,
},
{
initSize: []int{maxWaitingSize, maxWaitingSize},
priority: reqPriorityLow,
expErr: ErrQueueFull,
},
{
initSize: []int{maxWaitingSize, maxWaitingSize},
priority: reqPriorityHigh,
expErr: ErrQueueFull,
},
}
for i, test := range tests {
q := makeTestRequestQueue(test.initSize)
req := wrapRequestFromRaw(makeTestRequest(100))
err := q.Push(req, test.priority)
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
}
if err != nil || test.expErr != nil {
continue
}
if err := q.checkSizes(test.expSize); err != nil {
t.Errorf("Test %v: %v", i, err)
}
}
}
func TestRequestQueue_Pop(t *testing.T) {
tests := []struct {
initSizes []int
expNil bool
expIndex uint64
expSizes []int
}{
{
initSizes: []int{10, 10},
expNil: false,
expIndex: 0,
expSizes: []int{9, 10},
},
{
initSizes: []int{0, 10},
expNil: false,
expIndex: 0,
expSizes: []int{0, 9},
},
{
initSizes: []int{0, 0},
expNil: true,
expSizes: []int{0, 0},
},
}
for i, test := range tests {
q := makeTestRequestQueue(test.initSizes)
req := q.Pop()
if err := q.checkSizes(test.expSizes); err != nil {
t.Errorf("Test %v: %v", i, err)
}
if req == nil != (test.expNil) {
t.Errorf("test %v: unpected nil", i)
}
if req == nil {
continue
}
index := req.Request.(*testRequest).index
if index != test.expIndex {
t.Errorf("Test %v: unexpected index: %v / %v", i, index, test.expIndex)
}
}
}
func makeTestRequestQueue(sizes []int) requestQueue {
if len(sizes) != 2 {
panic("unexpected sizes")
}
q := newRequestQueue()
index := 0
for i := 0; i != sizes[0]; i++ {
q.reqsPHigh.PushBack(wrapRequestFromRaw(makeTestRequest(uint64(index))))
index++
}
for i := 0; i != sizes[1]; i++ {
q.reqsPLow.PushBack(wrapRequestFromRaw(makeTestRequest(uint64(index))))
index++
}
return q
}
func wrapRequestFromRaw(raw *testRequest) *request {
return &request{
Request: raw,
}
}
func getTestRequestFromElem(elem *list.Element) (*testRequest, error) {
req, ok := elem.Value.(*request)
if !ok {
return nil, errors.New("unexpected type")
}
raw, ok := req.Request.(*testRequest)
if !ok {
return nil, errors.New("unexpected raw types")
}
return raw, nil
}
func (q *requestQueue) checkSizes(sizes []int) error {
if len(sizes) != 2 {
panic("expect 2 sizes")
}
if q.reqsPHigh.Len() != sizes[0] {
return fmt.Errorf("high priority %v / %v", q.reqsPHigh.Len(), sizes[0])
}
if q.reqsPLow.Len() != sizes[1] {
return fmt.Errorf("low priority %v / %v", q.reqsPLow.Len(), sizes[2])
}
return nil
}
func assertError(got, exp error) error {
if (got == nil) != (exp == nil) {
return fmt.Errorf("unexpected error: %v / %v", got, exp)
}
if got == nil {
return nil
}
if !strings.Contains(got.Error(), exp.Error()) {
return fmt.Errorf("unexpected error: %v / %v", got, exp)
}
return nil
}
Loading…
Cancel
Save