Merge branch 'main' into stream_syncprotocol_reviewfix

pull/3592/head
Rongjian Lan 4 years ago committed by GitHub
commit 07db2515a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      consensus/consensus.go
  2. 19
      consensus/consensus_v2.go
  3. 131
      consensus/downloader.go
  4. 11
      consensus/validator.go
  5. 45
      p2p/stream/common/requestmanager/interface_test.go
  6. 88
      p2p/stream/common/requestmanager/requestmanager.go
  7. 168
      p2p/stream/common/requestmanager/requestmanager_test.go
  8. 107
      p2p/stream/common/requestmanager/types.go
  9. 18
      p2p/stream/common/requestmanager/types_test.go
  10. 18
      p2p/stream/common/streammanager/interface.go

@ -130,6 +130,8 @@ type Consensus struct {
finality int64 finality int64
// finalityCounter keep tracks of the finality time // finalityCounter keep tracks of the finality time
finalityCounter int64 finalityCounter int64
dHelper *downloadHelper
} }
// SetCommitDelay sets the commit message delay. If set to non-zero, // SetCommitDelay sets the commit message delay. If set to non-zero,

@ -334,6 +334,8 @@ func (consensus *Consensus) Start(
break break
} }
} }
// TODO: Refactor this piece of code to consensus/downloader.go after DNS legacy sync is removed
case <-consensus.syncReadyChan: case <-consensus.syncReadyChan:
consensus.getLogger().Info().Msg("[ConsensusMainLoop] syncReadyChan") consensus.getLogger().Info().Msg("[ConsensusMainLoop] syncReadyChan")
consensus.mutex.Lock() consensus.mutex.Lock()
@ -352,6 +354,7 @@ func (consensus *Consensus) Start(
} }
consensus.mutex.Unlock() consensus.mutex.Unlock()
// TODO: Refactor this piece of code to consensus/downloader.go after DNS legacy sync is removed
case <-consensus.syncNotReadyChan: case <-consensus.syncNotReadyChan:
consensus.getLogger().Info().Msg("[ConsensusMainLoop] syncNotReadyChan") consensus.getLogger().Info().Msg("[ConsensusMainLoop] syncNotReadyChan")
consensus.SetBlockNum(consensus.Blockchain.CurrentHeader().Number().Uint64() + 1) consensus.SetBlockNum(consensus.Blockchain.CurrentHeader().Number().Uint64() + 1)
@ -467,13 +470,26 @@ func (consensus *Consensus) Start(
} }
consensus.getLogger().Info().Msg("[ConsensusMainLoop] Ended.") consensus.getLogger().Info().Msg("[ConsensusMainLoop] Ended.")
}() }()
if consensus.dHelper != nil {
consensus.dHelper.start()
}
} }
// Close close the consensus. If current is in normal commit phase, wait until the commit // Close close the consensus. If current is in normal commit phase, wait until the commit
// phase end. // phase end.
func (consensus *Consensus) Close() error { func (consensus *Consensus) Close() error {
if consensus.dHelper != nil {
consensus.dHelper.close()
}
consensus.waitForCommit()
return nil
}
// waitForCommit wait extra 2 seconds for commit phase to finish
func (consensus *Consensus) waitForCommit() {
if consensus.Mode() != Normal || consensus.phase != FBFTCommit { if consensus.Mode() != Normal || consensus.phase != FBFTCommit {
return nil return
} }
// We only need to wait consensus is in normal commit phase // We only need to wait consensus is in normal commit phase
utils.Logger().Warn().Str("phase", consensus.phase.String()).Msg("[shutdown] commit phase has to wait") utils.Logger().Warn().Str("phase", consensus.phase.String()).Msg("[shutdown] commit phase has to wait")
@ -483,7 +499,6 @@ func (consensus *Consensus) Close() error {
utils.Logger().Warn().Msg("[shutdown] wait for consensus finished") utils.Logger().Warn().Msg("[shutdown] wait for consensus finished")
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
} }
return nil
} }
// LastMileBlockIter is the iterator to iterate over the last mile blocks in consensus cache. // LastMileBlockIter is the iterator to iterate over the last mile blocks in consensus cache.

@ -0,0 +1,131 @@
package consensus
import (
"github.com/ethereum/go-ethereum/event"
"github.com/harmony-one/harmony/core/types"
"github.com/pkg/errors"
)
// downloader is the adapter interface for downloader.Downloader, which is used for
// 1. Subscribe download finished event to help syncing to the latest block.
// 2. Trigger the downloader to start working
type downloader interface {
SubscribeDownloadFinished(ch chan struct{}) event.Subscription
SubscribeDownloadStarted(ch chan struct{}) event.Subscription
DownloadAsync()
}
// Set downloader set the downloader of the shard to consensus
// TODO: It will be better to move this to consensus.New and register consensus as a service
func (consensus *Consensus) SetDownloader(d downloader) {
consensus.dHelper = newDownloadHelper(consensus, d)
}
type downloadHelper struct {
d downloader
c *Consensus
startedCh chan struct{}
finishedCh chan struct{}
startedSub event.Subscription
finishedSub event.Subscription
}
func newDownloadHelper(c *Consensus, d downloader) *downloadHelper {
startedCh := make(chan struct{}, 1)
startedSub := d.SubscribeDownloadStarted(startedCh)
finishedCh := make(chan struct{}, 1)
finishedSub := d.SubscribeDownloadFinished(finishedCh)
return &downloadHelper{
c: c,
d: d,
startedCh: startedCh,
finishedCh: finishedCh,
startedSub: startedSub,
finishedSub: finishedSub,
}
}
func (dh *downloadHelper) start() {
go dh.downloadStartedLoop()
go dh.downloadFinishedLoop()
}
func (dh *downloadHelper) close() {
dh.startedSub.Unsubscribe()
dh.finishedSub.Unsubscribe()
}
func (dh *downloadHelper) downloadStartedLoop() {
for {
select {
case <-dh.startedCh:
dh.c.BlocksNotSynchronized()
case err := <-dh.startedSub.Err():
dh.c.getLogger().Info().Err(err).Msg("consensus download finished loop closed")
return
}
}
}
func (dh *downloadHelper) downloadFinishedLoop() {
for {
select {
case <-dh.finishedCh:
err := dh.c.addConsensusLastMile()
if err != nil {
dh.c.getLogger().Error().Err(err).Msg("add last mile failed")
}
dh.c.BlocksSynchronized()
case err := <-dh.finishedSub.Err():
dh.c.getLogger().Info().Err(err).Msg("consensus download finished loop closed")
return
}
}
}
func (consensus *Consensus) addConsensusLastMile() error {
curBN := consensus.Blockchain.CurrentBlock().NumberU64()
blockIter, err := consensus.GetLastMileBlockIter(curBN + 1)
if err != nil {
return err
}
for {
block := blockIter.Next()
if block == nil {
break
}
if _, err := consensus.Blockchain.InsertChain(types.Blocks{block}, true); err != nil {
return errors.Wrap(err, "failed to InsertChain")
}
}
return nil
}
func (consensus *Consensus) spinUpStateSync() {
if consensus.dHelper != nil {
consensus.dHelper.d.DownloadAsync()
consensus.current.SetMode(Syncing)
for _, v := range consensus.consensusTimeout {
v.Stop()
}
} else {
consensus.spinLegacyStateSync()
}
}
func (consensus *Consensus) spinLegacyStateSync() {
select {
case consensus.BlockNumLowChan <- struct{}{}:
consensus.current.SetMode(Syncing)
for _, v := range consensus.consensusTimeout {
v.Stop()
}
default:
}
}

@ -402,14 +402,3 @@ func (consensus *Consensus) broadcastConsensusP2pMessages(p2pMsgs []*NetworkMess
} }
return nil return nil
} }
func (consensus *Consensus) spinUpStateSync() {
select {
case consensus.BlockNumLowChan <- struct{}{}:
consensus.current.SetMode(Syncing)
for _, v := range consensus.consensusTimeout {
v.Stop()
}
default:
}
}

@ -14,19 +14,25 @@ import (
var testProtoID = sttypes.ProtoID("harmony/sync/unitest/0/1.0.0") var testProtoID = sttypes.ProtoID("harmony/sync/unitest/0/1.0.0")
type testStreamManager struct { type testStreamManager struct {
streams map[sttypes.StreamID]sttypes.Stream
newStreamFeed event.Feed newStreamFeed event.Feed
rmStreamFeed event.Feed rmStreamFeed event.Feed
} }
func newTestStreamManager() *testStreamManager { func newTestStreamManager() *testStreamManager {
return &testStreamManager{} return &testStreamManager{
streams: make(map[sttypes.StreamID]sttypes.Stream),
}
} }
func (sm *testStreamManager) addNewStream(st sttypes.Stream) { func (sm *testStreamManager) addNewStream(st sttypes.Stream) {
sm.streams[st.ID()] = st
sm.newStreamFeed.Send(streammanager.EvtStreamAdded{Stream: st}) sm.newStreamFeed.Send(streammanager.EvtStreamAdded{Stream: st})
} }
func (sm *testStreamManager) rmStream(stid sttypes.StreamID) { func (sm *testStreamManager) rmStream(stid sttypes.StreamID) {
delete(sm.streams, stid)
sm.rmStreamFeed.Send(streammanager.EvtStreamRemoved{ID: stid}) sm.rmStreamFeed.Send(streammanager.EvtStreamRemoved{ID: stid})
} }
@ -38,6 +44,20 @@ func (sm *testStreamManager) SubscribeRemoveStreamEvent(ch chan<- streammanager.
return sm.rmStreamFeed.Subscribe(ch) return sm.rmStreamFeed.Subscribe(ch)
} }
func (sm *testStreamManager) GetStreams() []sttypes.Stream {
sts := make([]sttypes.Stream, 0, len(sm.streams))
for _, st := range sm.streams {
sts = append(sts, st)
}
return sts
}
func (sm *testStreamManager) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) {
st, exist := sm.streams[id]
return st, exist
}
type testStream struct { type testStream struct {
id sttypes.StreamID id sttypes.StreamID
rm *requestManager rm *requestManager
@ -79,6 +99,29 @@ func (st *testStream) ResetOnClose() error {
return nil return nil
} }
func makeDummyTestStreams(indexes []int) []sttypes.Stream {
sts := make([]sttypes.Stream, 0, len(indexes))
for _, index := range indexes {
sts = append(sts, &testStream{
id: makeStreamID(index),
})
}
return sts
}
func makeDummyStreamSets(indexes []int) map[sttypes.StreamID]*stream {
m := make(map[sttypes.StreamID]*stream)
for _, index := range indexes {
st := &testStream{
id: makeStreamID(index),
}
m[st.ID()] = &stream{Stream: st}
}
return m
}
func makeStreamID(index int) sttypes.StreamID { func makeStreamID(index int) sttypes.StreamID {
return sttypes.StreamID(strconv.Itoa(index)) return sttypes.StreamID(strconv.Itoa(index))
} }

@ -23,9 +23,10 @@ type requestManager struct {
streams map[sttypes.StreamID]*stream // All streams streams map[sttypes.StreamID]*stream // All streams
available map[sttypes.StreamID]struct{} // Streams that are available for request available map[sttypes.StreamID]struct{} // Streams that are available for request
pendings map[uint64]*request // requests that are sent but not received response pendings map[uint64]*request // requests that are sent but not received response
waitings requestQueue // double linked list of requests that are on the waiting list waitings requestQueues // double linked list of requests that are on the waiting list
// Stream events // Stream events
sm streammanager.Reader
newStreamC <-chan streammanager.EvtStreamAdded newStreamC <-chan streammanager.EvtStreamAdded
rmStreamC <-chan streammanager.EvtStreamRemoved rmStreamC <-chan streammanager.EvtStreamRemoved
// Request events // Request events
@ -40,11 +41,11 @@ type requestManager struct {
} }
// NewRequestManager creates a new request manager // NewRequestManager creates a new request manager
func NewRequestManager(sm streammanager.Subscriber) RequestManager { func NewRequestManager(sm streammanager.ReaderSubscriber) RequestManager {
return newRequestManager(sm) return newRequestManager(sm)
} }
func newRequestManager(sm streammanager.Subscriber) *requestManager { func newRequestManager(sm streammanager.ReaderSubscriber) *requestManager {
// subscribe at initialize to prevent misuse of upper function which might cause // subscribe at initialize to prevent misuse of upper function which might cause
// the bootstrap peers are ignored // the bootstrap peers are ignored
newStreamC := make(chan streammanager.EvtStreamAdded) newStreamC := make(chan streammanager.EvtStreamAdded)
@ -58,8 +59,9 @@ func newRequestManager(sm streammanager.Subscriber) *requestManager {
streams: make(map[sttypes.StreamID]*stream), streams: make(map[sttypes.StreamID]*stream),
available: make(map[sttypes.StreamID]struct{}), available: make(map[sttypes.StreamID]struct{}),
pendings: make(map[uint64]*request), pendings: make(map[uint64]*request),
waitings: newRequestQueue(), waitings: newRequestQueues(),
sm: sm,
newStreamC: newStreamC, newStreamC: newStreamC,
rmStreamC: rmStreamC, rmStreamC: rmStreamC,
cancelReqC: make(chan cancelReqData, 16), cancelReqC: make(chan cancelReqData, 16),
@ -102,8 +104,8 @@ func (rm *requestManager) doRequestAsync(ctx context.Context, raw sttypes.Reques
select { select {
case <-ctx.Done(): // canceled or timeout in upper function calls case <-ctx.Done(): // canceled or timeout in upper function calls
rm.cancelReqC <- cancelReqData{ rm.cancelReqC <- cancelReqData{
reqID: req.ReqID(), req: req,
err: ctx.Err(), err: ctx.Err(),
} }
case <-req.doneC: case <-req.doneC:
} }
@ -182,13 +184,11 @@ func (rm *requestManager) loop() {
case data := <-rm.cancelReqC: case data := <-rm.cancelReqC:
rm.handleCancelRequest(data) rm.handleCancelRequest(data)
case evt := <-rm.newStreamC: case <-rm.newStreamC:
rm.logger.Info().Str("streamID", string(evt.Stream.ID())).Msg("add new stream") rm.refreshStreams()
rm.addNewStream(evt.Stream)
case evt := <-rm.rmStreamC: case <-rm.rmStreamC:
rm.logger.Info().Str("streamID", string(evt.ID)).Msg("remove stream") rm.refreshStreams()
rm.removeStream(evt.ID)
case <-rm.stopC: case <-rm.stopC:
rm.logger.Info().Msg("request manager stopped") rm.logger.Info().Msg("request manager stopped")
@ -255,21 +255,20 @@ func (rm *requestManager) handleCancelRequest(data cancelReqData) {
rm.lock.Lock() rm.lock.Lock()
defer rm.lock.Unlock() defer rm.lock.Unlock()
req, ok := rm.pendings[data.reqID] var (
if !ok { req = data.req
return err = data.err
} )
rm.waitings.Remove(req)
rm.removePendingRequest(req) rm.removePendingRequest(req)
var stid sttypes.StreamID var stid sttypes.StreamID
if req.owner != nil { if req.owner != nil {
stid = req.owner.ID() stid = req.owner.ID()
} }
req.doneWithResponse(responseData{ req.doneWithResponse(responseData{
resp: nil, resp: nil,
stID: stid, stID: stid,
err: data.err, err: err,
}) })
} }
@ -290,7 +289,7 @@ func (rm *requestManager) getNextRequest() (*request, *stream) {
st, err := rm.pickAvailableStream(req) st, err := rm.pickAvailableStream(req)
if err != nil { if err != nil {
rm.logger.Debug().Msg("No available streams.") rm.logger.Debug().Err(err).Str("request", req.String()).Msg("Pick available streams.")
rm.addRequestToWaitings(req, reqPriorityHigh) rm.addRequestToWaitings(req, reqPriorityHigh)
return nil, nil return nil, nil
} }
@ -349,26 +348,49 @@ func (rm *requestManager) pickAvailableStream(req *request) (*stream, error) {
return nil, errors.New("no more available streams") return nil, errors.New("no more available streams")
} }
func (rm *requestManager) addNewStream(st sttypes.Stream) { func (rm *requestManager) refreshStreams() {
rm.lock.Lock() rm.lock.Lock()
defer rm.lock.Unlock() defer rm.lock.Unlock()
if _, ok := rm.streams[st.ID()]; !ok { added, removed := checkStreamUpdates(rm.streams, rm.sm.GetStreams())
rm.streams[st.ID()] = &stream{Stream: st}
rm.available[st.ID()] = struct{}{} for _, st := range added {
rm.logger.Info().Str("streamID", string(st.ID())).Msg("add new stream")
rm.addNewStream(st)
}
for _, st := range removed {
rm.logger.Info().Str("streamID", string(st.ID())).Msg("remove stream")
rm.removeStream(st)
} }
} }
// removeStream remove the stream from request manager, clear the pending request func checkStreamUpdates(exists map[sttypes.StreamID]*stream, targets []sttypes.Stream) (added []sttypes.Stream, removed []*stream) {
// of the stream. Return whether a pending request is canceled in the stream, targetM := make(map[sttypes.StreamID]sttypes.Stream)
func (rm *requestManager) removeStream(id sttypes.StreamID) {
rm.lock.Lock()
defer rm.lock.Unlock()
st, ok := rm.streams[id] for _, target := range targets {
if !ok { id := target.ID()
return targetM[id] = target
if _, ok := exists[id]; !ok {
added = append(added, target)
}
} }
for id, exist := range exists {
if _, ok := targetM[id]; !ok {
removed = append(removed, exist)
}
}
return
}
func (rm *requestManager) addNewStream(st sttypes.Stream) {
rm.streams[st.ID()] = &stream{Stream: st}
rm.available[st.ID()] = struct{}{}
}
// removeStream remove the stream from request manager, clear the pending request
// of the stream.
func (rm *requestManager) removeStream(st *stream) {
id := st.ID()
delete(rm.available, id) delete(rm.available, id)
delete(rm.streams, id) delete(rm.streams, id)
@ -394,7 +416,7 @@ func (rm *requestManager) close() {
rm.pendings = make(map[uint64]*request) rm.pendings = make(map[uint64]*request)
rm.available = make(map[sttypes.StreamID]struct{}) rm.available = make(map[sttypes.StreamID]struct{})
rm.streams = make(map[sttypes.StreamID]*stream) rm.streams = make(map[sttypes.StreamID]*stream)
rm.waitings = newRequestQueue() rm.waitings = newRequestQueues()
close(rm.stopC) close(rm.stopC)
} }

@ -2,6 +2,7 @@ package requestmanager
import ( import (
"context" "context"
"fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
@ -132,7 +133,7 @@ func TestRequestManager_UnknownDelivery(t *testing.T) {
req := makeTestRequest(100) req := makeTestRequest(100)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
resC := ts.rm.doRequestAsync(ctx, req) resC := ts.rm.doRequestAsync(ctx, req)
time.Sleep(6 * time.Second) time.Sleep(2 * time.Second)
cancel() cancel()
// Since the reqID is not delivered, the result is not delivered to the request // Since the reqID is not delivered, the result is not delivered to the request
@ -164,6 +165,79 @@ func TestRequestManager_StaleDelivery(t *testing.T) {
} }
} }
// TestRequestManager_cancelWaitings test the scenario of request being canceled
// while still in waitings. In order to do this,
// 1. Set number of streams to 1
// 2. Occupy the stream with a request, and block
// 3. Do the second request. This request will be in waitings.
// 4. Cancel the second request. Request shall be removed from waitings.
// 5. Unblock the first request
// 6. Request 1 finished, request 2 canceled
func TestRequestManager_cancelWaitings(t *testing.T) {
req1 := makeTestRequest(1)
req2 := makeTestRequest(2)
var req1Block sync.Mutex
req1Block.Lock()
unblockReq1 := func() { req1Block.Unlock() }
delayF := makeDefaultDelayFunc(150 * time.Millisecond)
respF := func(req *testRequest) *testResponse {
if req.index == req1.index {
req1Block.Lock()
}
return makeDefaultResponseFunc()(req)
}
ts := newTestSuite(delayF, respF, 1)
ts.Start()
defer ts.Close()
ctx1, _ := context.WithTimeout(context.Background(), 1*time.Second)
ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second)
resC1 := ts.rm.doRequestAsync(ctx1, req1)
resC2 := ts.rm.doRequestAsync(ctx2, req2)
cancel2()
unblockReq1()
var (
res1 responseData
res2 responseData
)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
select {
case res1 = <-resC1:
case <-time.After(1 * time.Second):
t.Errorf("req1 timed out")
}
}()
go func() {
defer wg.Done()
select {
case res2 = <-resC2:
case <-time.After(1 * time.Second):
t.Errorf("req2 timed out")
}
}()
wg.Wait()
if res1.err != nil {
t.Errorf("request 1 shall return nil error")
}
if res2.err != context.Canceled {
t.Errorf("request 2 shall be canceled")
}
if ts.rm.waitings.reqsPLow.len() != 0 || ts.rm.waitings.reqsPHigh.len() != 0 {
t.Errorf("waitings shall be clean")
}
}
// closing request manager will also close all // closing request manager will also close all
func TestRequestManager_Close(t *testing.T) { func TestRequestManager_Close(t *testing.T) {
delayF := makeDefaultDelayFunc(1 * time.Second) delayF := makeDefaultDelayFunc(1 * time.Second)
@ -303,6 +377,95 @@ func TestGenReqID(t *testing.T) {
} }
} }
func TestCheckStreamUpdates(t *testing.T) {
tests := []struct {
exists map[sttypes.StreamID]*stream
targets []sttypes.Stream
expAddedIndexes []int
expRemovedIndexes []int
}{
{
exists: makeDummyStreamSets([]int{1, 2, 3, 4, 5}),
targets: makeDummyTestStreams([]int{2, 3, 4, 5}),
expAddedIndexes: []int{},
expRemovedIndexes: []int{1},
},
{
exists: makeDummyStreamSets([]int{1, 2, 3, 4, 5}),
targets: makeDummyTestStreams([]int{1, 2, 3, 4, 5, 6}),
expAddedIndexes: []int{6},
expRemovedIndexes: []int{},
},
{
exists: makeDummyStreamSets([]int{}),
targets: makeDummyTestStreams([]int{}),
expAddedIndexes: []int{},
expRemovedIndexes: []int{},
},
{
exists: makeDummyStreamSets([]int{}),
targets: makeDummyTestStreams([]int{1, 2, 3, 4, 5}),
expAddedIndexes: []int{1, 2, 3, 4, 5},
expRemovedIndexes: []int{},
},
{
exists: makeDummyStreamSets([]int{1, 2, 3, 4, 5}),
targets: makeDummyTestStreams([]int{}),
expAddedIndexes: []int{},
expRemovedIndexes: []int{1, 2, 3, 4, 5},
},
{
exists: makeDummyStreamSets([]int{1, 2, 3, 4, 5}),
targets: makeDummyTestStreams([]int{6, 7, 8, 9, 10}),
expAddedIndexes: []int{6, 7, 8, 9, 10},
expRemovedIndexes: []int{1, 2, 3, 4, 5},
},
}
for i, test := range tests {
added, removed := checkStreamUpdates(test.exists, test.targets)
if err := checkStreamIDsEqual(added, test.expAddedIndexes); err != nil {
t.Errorf("Test %v: check added: %v", i, err)
}
if err := checkStreamIDsEqual2(removed, test.expRemovedIndexes); err != nil {
t.Errorf("Test %v: check removed: %v", i, err)
}
}
}
func checkStreamIDsEqual(sts []sttypes.Stream, expIndexes []int) error {
if len(sts) != len(expIndexes) {
return fmt.Errorf("size not equal")
}
expM := make(map[sttypes.StreamID]struct{})
for _, index := range expIndexes {
expM[makeStreamID(index)] = struct{}{}
}
for _, st := range sts {
if _, ok := expM[st.ID()]; !ok {
return fmt.Errorf("stream not exist in exp: %v", st.ID())
}
}
return nil
}
func checkStreamIDsEqual2(sts []*stream, expIndexes []int) error {
if len(sts) != len(expIndexes) {
return fmt.Errorf("size not equal")
}
expM := make(map[sttypes.StreamID]struct{})
for _, index := range expIndexes {
expM[makeStreamID(index)] = struct{}{}
}
for _, st := range sts {
if _, ok := expM[st.ID()]; !ok {
return fmt.Errorf("stream not exist in exp: %v", st.ID())
}
}
return nil
}
type testSuite struct { type testSuite struct {
rm *requestManager rm *requestManager
sm *testStreamManager sm *testStreamManager
@ -330,7 +493,8 @@ func newTestSuite(delayF delayFunc, respF responseFunc, numStreams int) *testSui
cancel: cancel, cancel: cancel,
} }
for i := 0; i != numStreams; i++ { for i := 0; i != numStreams; i++ {
ts.bootStreams = append(ts.bootStreams, ts.makeTestStream(i)) st := ts.makeTestStream(i)
ts.bootStreams = append(ts.bootStreams, st)
} }
return ts return ts
} }

@ -114,8 +114,8 @@ func (st *stream) clearPendingRequest() *request {
} }
type cancelReqData struct { type cancelReqData struct {
reqID uint64 req *request
err error err error
} }
// responseData is the wrapped response for stream requests // responseData is the wrapped response for stream requests
@ -125,58 +125,97 @@ type responseData struct {
err error err error
} }
// requestQueue is a wrapper of double linked list with Request as type // requestQueues is a wrapper of double linked list with Request as type
type requestQueue struct { type requestQueues struct {
reqsPHigh *list.List // high priority, currently defined by upper function calls reqsPHigh *requestQueue // high priority, currently defined by upper function calls
reqsPLow *list.List // low priority, applied to all normal requests reqsPLow *requestQueue // low priority, applied to all normal requests
lock sync.Mutex
} }
func newRequestQueue() requestQueue { func newRequestQueues() requestQueues {
return requestQueue{ return requestQueues{
reqsPHigh: list.New(), reqsPHigh: newRequestQueue(),
reqsPLow: list.New(), reqsPLow: newRequestQueue(),
} }
} }
// Push add a new request to requestQueue. // Push add a new request to requestQueues.
func (q *requestQueue) Push(req *request, priority reqPriority) error { func (q *requestQueues) Push(req *request, priority reqPriority) error {
q.lock.Lock()
defer q.lock.Unlock()
if priority == reqPriorityHigh || req.priority == reqPriorityHigh { if priority == reqPriorityHigh || req.priority == reqPriorityHigh {
return pushRequestToList(q.reqsPHigh, req) return q.reqsPHigh.push(req)
} }
if priority == reqPriorityLow { return q.reqsPLow.push(req)
return pushRequestToList(q.reqsPLow, req)
}
return nil
} }
// Pop will first pop the request from high priority, and then pop from low priority // Pop will first pop the request from high priority, and then pop from low priority
func (q *requestQueue) Pop() *request { func (q *requestQueues) Pop() *request {
q.lock.Lock() if req := q.reqsPHigh.pop(); req != nil {
defer q.lock.Unlock()
if req := popRequestFromList(q.reqsPHigh); req != nil {
return req return req
} }
return popRequestFromList(q.reqsPLow) return q.reqsPLow.pop()
}
func (q *requestQueues) Remove(req *request) {
q.reqsPHigh.remove(req)
q.reqsPLow.remove(req)
}
// requestQueue is a thread safe request double linked list
type requestQueue struct {
l *list.List
elemM map[*request]*list.Element // Yes, pointer as map key
lock sync.Mutex
}
func newRequestQueue() *requestQueue {
return &requestQueue{
l: list.New(),
elemM: make(map[*request]*list.Element),
}
} }
func pushRequestToList(l *list.List, req *request) error { func (rl *requestQueue) push(req *request) error {
if l.Len() >= maxWaitingSize { rl.lock.Lock()
defer rl.lock.Unlock()
if rl.l.Len() >= maxWaitingSize {
return ErrQueueFull return ErrQueueFull
} }
l.PushBack(req) elem := rl.l.PushBack(req)
rl.elemM[req] = elem
return nil return nil
} }
func popRequestFromList(l *list.List) *request { func (rl *requestQueue) pop() *request {
elem := l.Front() rl.lock.Lock()
defer rl.lock.Unlock()
elem := rl.l.Front()
if elem == nil { if elem == nil {
return nil return nil
} }
l.Remove(elem) rl.l.Remove(elem)
return elem.Value.(*request)
req := elem.Value.(*request)
delete(rl.elemM, req)
return req
}
func (rl *requestQueue) remove(req *request) {
rl.lock.Lock()
defer rl.lock.Unlock()
elem := rl.elemM[req]
if elem == nil {
// Already removed
return
}
rl.l.Remove(elem)
delete(rl.elemM, req)
}
func (rl *requestQueue) len() int {
rl.lock.Lock()
defer rl.lock.Unlock()
return rl.l.Len()
} }

@ -102,19 +102,19 @@ func TestRequestQueue_Pop(t *testing.T) {
} }
} }
func makeTestRequestQueue(sizes []int) requestQueue { func makeTestRequestQueue(sizes []int) requestQueues {
if len(sizes) != 2 { if len(sizes) != 2 {
panic("unexpected sizes") panic("unexpected sizes")
} }
q := newRequestQueue() q := newRequestQueues()
index := 0 index := 0
for i := 0; i != sizes[0]; i++ { for i := 0; i != sizes[0]; i++ {
q.reqsPHigh.PushBack(wrapRequestFromRaw(makeTestRequest(uint64(index)))) q.reqsPHigh.push(wrapRequestFromRaw(makeTestRequest(uint64(index))))
index++ index++
} }
for i := 0; i != sizes[1]; i++ { for i := 0; i != sizes[1]; i++ {
q.reqsPLow.PushBack(wrapRequestFromRaw(makeTestRequest(uint64(index)))) q.reqsPLow.push(wrapRequestFromRaw(makeTestRequest(uint64(index))))
index++ index++
} }
return q return q
@ -138,15 +138,15 @@ func getTestRequestFromElem(elem *list.Element) (*testRequest, error) {
return raw, nil return raw, nil
} }
func (q *requestQueue) checkSizes(sizes []int) error { func (q *requestQueues) checkSizes(sizes []int) error {
if len(sizes) != 2 { if len(sizes) != 2 {
panic("expect 2 sizes") panic("expect 2 sizes")
} }
if q.reqsPHigh.Len() != sizes[0] { if q.reqsPHigh.len() != sizes[0] {
return fmt.Errorf("high priority %v / %v", q.reqsPHigh.Len(), sizes[0]) return fmt.Errorf("high priority %v / %v", q.reqsPHigh.len(), sizes[0])
} }
if q.reqsPLow.Len() != sizes[1] { if q.reqsPLow.len() != sizes[1] {
return fmt.Errorf("low priority %v / %v", q.reqsPLow.Len(), sizes[2]) return fmt.Errorf("low priority %v / %v", q.reqsPLow.len(), sizes[2])
} }
return nil return nil
} }

@ -14,13 +14,19 @@ import (
// StreamManager is the interface for streamManager // StreamManager is the interface for streamManager
type StreamManager interface { type StreamManager interface {
p2ptypes.LifeCycle p2ptypes.LifeCycle
StreamOperator Operator
Subscriber Subscriber
StreamReader Reader
} }
// StreamOperator handles new stream or remove stream // ReaderSubscriber reads stream and subscribe stream events
type StreamOperator interface { type ReaderSubscriber interface {
Reader
Subscriber
}
// Operator handles new stream or remove stream
type Operator interface {
NewStream(stream sttypes.Stream) error NewStream(stream sttypes.Stream) error
RemoveStream(stID sttypes.StreamID) error RemoveStream(stID sttypes.StreamID) error
} }
@ -31,8 +37,8 @@ type Subscriber interface {
SubscribeRemoveStreamEvent(ch chan<- EvtStreamRemoved) event.Subscription SubscribeRemoveStreamEvent(ch chan<- EvtStreamRemoved) event.Subscription
} }
// StreamReader is the interface to read stream in stream manager // Reader is the interface to read stream in stream manager
type StreamReader interface { type Reader interface {
GetStreams() []sttypes.Stream GetStreams() []sttypes.Stream
GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool)
} }

Loading…
Cancel
Save