From 72f74c08cb3df6b259a8b3125511e70b90dea6b8 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Mon, 15 Mar 2021 23:08:30 -0700 Subject: [PATCH] [stream] request can be canceled from waitings in request manager --- .../common/requestmanager/requestmanager.go | 23 ++-- .../requestmanager/requestmanager_test.go | 75 +++++++++++- p2p/stream/common/requestmanager/types.go | 107 ++++++++++++------ .../common/requestmanager/types_test.go | 18 +-- 4 files changed, 167 insertions(+), 56 deletions(-) diff --git a/p2p/stream/common/requestmanager/requestmanager.go b/p2p/stream/common/requestmanager/requestmanager.go index 75356de41..2cac5c51d 100644 --- a/p2p/stream/common/requestmanager/requestmanager.go +++ b/p2p/stream/common/requestmanager/requestmanager.go @@ -23,7 +23,7 @@ type requestManager struct { streams map[sttypes.StreamID]*stream // All streams available map[sttypes.StreamID]struct{} // Streams that are available for request pendings map[uint64]*request // requests that are sent but not received response - waitings requestQueue // double linked list of requests that are on the waiting list + waitings requestQueues // double linked list of requests that are on the waiting list // Stream events sm streammanager.Reader @@ -59,7 +59,7 @@ func newRequestManager(sm streammanager.ReaderSubscriber) *requestManager { streams: make(map[sttypes.StreamID]*stream), available: make(map[sttypes.StreamID]struct{}), pendings: make(map[uint64]*request), - waitings: newRequestQueue(), + waitings: newRequestQueues(), sm: sm, newStreamC: newStreamC, @@ -104,8 +104,8 @@ func (rm *requestManager) doRequestAsync(ctx context.Context, raw sttypes.Reques select { case <-ctx.Done(): // canceled or timeout in upper function calls rm.cancelReqC <- cancelReqData{ - reqID: req.ReqID(), - err: ctx.Err(), + req: req, + err: ctx.Err(), } case <-req.doneC: } @@ -255,21 +255,20 @@ func (rm *requestManager) handleCancelRequest(data cancelReqData) { rm.lock.Lock() defer rm.lock.Unlock() - req, ok := rm.pendings[data.reqID] - if !ok { - return - } + var ( + req = data.req + err = data.err + ) + rm.waitings.Remove(req) rm.removePendingRequest(req) - var stid sttypes.StreamID if req.owner != nil { stid = req.owner.ID() } - req.doneWithResponse(responseData{ resp: nil, stID: stid, - err: data.err, + err: err, }) } @@ -417,7 +416,7 @@ func (rm *requestManager) close() { rm.pendings = make(map[uint64]*request) rm.available = make(map[sttypes.StreamID]struct{}) rm.streams = make(map[sttypes.StreamID]*stream) - rm.waitings = newRequestQueue() + rm.waitings = newRequestQueues() close(rm.stopC) } diff --git a/p2p/stream/common/requestmanager/requestmanager_test.go b/p2p/stream/common/requestmanager/requestmanager_test.go index 01638f6b9..63804d0d5 100644 --- a/p2p/stream/common/requestmanager/requestmanager_test.go +++ b/p2p/stream/common/requestmanager/requestmanager_test.go @@ -133,7 +133,7 @@ func TestRequestManager_UnknownDelivery(t *testing.T) { req := makeTestRequest(100) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) resC := ts.rm.doRequestAsync(ctx, req) - time.Sleep(6 * time.Second) + time.Sleep(2 * time.Second) cancel() // Since the reqID is not delivered, the result is not delivered to the request @@ -165,6 +165,79 @@ func TestRequestManager_StaleDelivery(t *testing.T) { } } +// TestRequestManager_cancelWaitings test the scenario of request being canceled +// while still in waitings. In order to do this, +// 1. Set number of streams to 1 +// 2. Occupy the stream with a request, and block +// 3. Do the second request. This request will be in waitings. +// 4. Cancel the second request. Request shall be removed from waitings. +// 5. Unblock the first request +// 6. Request 1 finished, request 2 canceled +func TestRequestManager_cancelWaitings(t *testing.T) { + req1 := makeTestRequest(1) + req2 := makeTestRequest(2) + + var req1Block sync.Mutex + req1Block.Lock() + unblockReq1 := func() { req1Block.Unlock() } + + delayF := makeDefaultDelayFunc(150 * time.Millisecond) + respF := func(req *testRequest) *testResponse { + if req.index == req1.index { + req1Block.Lock() + } + return makeDefaultResponseFunc()(req) + } + ts := newTestSuite(delayF, respF, 1) + ts.Start() + defer ts.Close() + + ctx1, _ := context.WithTimeout(context.Background(), 1*time.Second) + ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second) + resC1 := ts.rm.doRequestAsync(ctx1, req1) + resC2 := ts.rm.doRequestAsync(ctx2, req2) + + cancel2() + unblockReq1() + + var ( + res1 responseData + res2 responseData + ) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + + select { + case res1 = <-resC1: + case <-time.After(1 * time.Second): + t.Errorf("req1 timed out") + } + }() + go func() { + defer wg.Done() + + select { + case res2 = <-resC2: + case <-time.After(1 * time.Second): + t.Errorf("req2 timed out") + } + }() + wg.Wait() + + if res1.err != nil { + t.Errorf("request 1 shall return nil error") + } + if res2.err != context.Canceled { + t.Errorf("request 2 shall be canceled") + } + if ts.rm.waitings.reqsPLow.len() != 0 || ts.rm.waitings.reqsPHigh.len() != 0 { + t.Errorf("waitings shall be clean") + } +} + // closing request manager will also close all func TestRequestManager_Close(t *testing.T) { delayF := makeDefaultDelayFunc(1 * time.Second) diff --git a/p2p/stream/common/requestmanager/types.go b/p2p/stream/common/requestmanager/types.go index 7cf29c821..c73488e1d 100644 --- a/p2p/stream/common/requestmanager/types.go +++ b/p2p/stream/common/requestmanager/types.go @@ -114,8 +114,8 @@ func (st *stream) clearPendingRequest() *request { } type cancelReqData struct { - reqID uint64 - err error + req *request + err error } // responseData is the wrapped response for stream requests @@ -125,58 +125,97 @@ type responseData struct { err error } -// requestQueue is a wrapper of double linked list with Request as type -type requestQueue struct { - reqsPHigh *list.List // high priority, currently defined by upper function calls - reqsPLow *list.List // low priority, applied to all normal requests - lock sync.Mutex +// requestQueues is a wrapper of double linked list with Request as type +type requestQueues struct { + reqsPHigh *requestQueue // high priority, currently defined by upper function calls + reqsPLow *requestQueue // low priority, applied to all normal requests } -func newRequestQueue() requestQueue { - return requestQueue{ - reqsPHigh: list.New(), - reqsPLow: list.New(), +func newRequestQueues() requestQueues { + return requestQueues{ + reqsPHigh: newRequestQueue(), + reqsPLow: newRequestQueue(), } } -// Push add a new request to requestQueue. -func (q *requestQueue) Push(req *request, priority reqPriority) error { - q.lock.Lock() - defer q.lock.Unlock() - +// Push add a new request to requestQueues. +func (q *requestQueues) Push(req *request, priority reqPriority) error { if priority == reqPriorityHigh || req.priority == reqPriorityHigh { - return pushRequestToList(q.reqsPHigh, req) + return q.reqsPHigh.push(req) } - if priority == reqPriorityLow { - return pushRequestToList(q.reqsPLow, req) - } - return nil + return q.reqsPLow.push(req) } // Pop will first pop the request from high priority, and then pop from low priority -func (q *requestQueue) Pop() *request { - q.lock.Lock() - defer q.lock.Unlock() - - if req := popRequestFromList(q.reqsPHigh); req != nil { +func (q *requestQueues) Pop() *request { + if req := q.reqsPHigh.pop(); req != nil { return req } - return popRequestFromList(q.reqsPLow) + return q.reqsPLow.pop() +} + +func (q *requestQueues) Remove(req *request) { + q.reqsPHigh.remove(req) + q.reqsPLow.remove(req) +} + +// requestQueue is a thread safe request double linked list +type requestQueue struct { + l *list.List + elemM map[*request]*list.Element // Yes, pointer as map key + lock sync.Mutex +} + +func newRequestQueue() *requestQueue { + return &requestQueue{ + l: list.New(), + elemM: make(map[*request]*list.Element), + } } -func pushRequestToList(l *list.List, req *request) error { - if l.Len() >= maxWaitingSize { +func (rl *requestQueue) push(req *request) error { + rl.lock.Lock() + defer rl.lock.Unlock() + + if rl.l.Len() >= maxWaitingSize { return ErrQueueFull } - l.PushBack(req) + elem := rl.l.PushBack(req) + rl.elemM[req] = elem return nil } -func popRequestFromList(l *list.List) *request { - elem := l.Front() +func (rl *requestQueue) pop() *request { + rl.lock.Lock() + defer rl.lock.Unlock() + + elem := rl.l.Front() if elem == nil { return nil } - l.Remove(elem) - return elem.Value.(*request) + rl.l.Remove(elem) + + req := elem.Value.(*request) + delete(rl.elemM, req) + return req +} + +func (rl *requestQueue) remove(req *request) { + rl.lock.Lock() + defer rl.lock.Unlock() + + elem := rl.elemM[req] + if elem == nil { + // Already removed + return + } + rl.l.Remove(elem) + delete(rl.elemM, req) +} + +func (rl *requestQueue) len() int { + rl.lock.Lock() + defer rl.lock.Unlock() + + return rl.l.Len() } diff --git a/p2p/stream/common/requestmanager/types_test.go b/p2p/stream/common/requestmanager/types_test.go index f98c93323..f19c5f0ca 100644 --- a/p2p/stream/common/requestmanager/types_test.go +++ b/p2p/stream/common/requestmanager/types_test.go @@ -102,19 +102,19 @@ func TestRequestQueue_Pop(t *testing.T) { } } -func makeTestRequestQueue(sizes []int) requestQueue { +func makeTestRequestQueue(sizes []int) requestQueues { if len(sizes) != 2 { panic("unexpected sizes") } - q := newRequestQueue() + q := newRequestQueues() index := 0 for i := 0; i != sizes[0]; i++ { - q.reqsPHigh.PushBack(wrapRequestFromRaw(makeTestRequest(uint64(index)))) + q.reqsPHigh.push(wrapRequestFromRaw(makeTestRequest(uint64(index)))) index++ } for i := 0; i != sizes[1]; i++ { - q.reqsPLow.PushBack(wrapRequestFromRaw(makeTestRequest(uint64(index)))) + q.reqsPLow.push(wrapRequestFromRaw(makeTestRequest(uint64(index)))) index++ } return q @@ -138,15 +138,15 @@ func getTestRequestFromElem(elem *list.Element) (*testRequest, error) { return raw, nil } -func (q *requestQueue) checkSizes(sizes []int) error { +func (q *requestQueues) checkSizes(sizes []int) error { if len(sizes) != 2 { panic("expect 2 sizes") } - if q.reqsPHigh.Len() != sizes[0] { - return fmt.Errorf("high priority %v / %v", q.reqsPHigh.Len(), sizes[0]) + if q.reqsPHigh.len() != sizes[0] { + return fmt.Errorf("high priority %v / %v", q.reqsPHigh.len(), sizes[0]) } - if q.reqsPLow.Len() != sizes[1] { - return fmt.Errorf("low priority %v / %v", q.reqsPLow.Len(), sizes[2]) + if q.reqsPLow.len() != sizes[1] { + return fmt.Errorf("low priority %v / %v", q.reqsPLow.len(), sizes[2]) } return nil }