diff --git a/p2p/stream/common/requestmanager/requestmanager.go b/p2p/stream/common/requestmanager/requestmanager.go index 4e0d33c14..75356de41 100644 --- a/p2p/stream/common/requestmanager/requestmanager.go +++ b/p2p/stream/common/requestmanager/requestmanager.go @@ -26,6 +26,7 @@ type requestManager struct { waitings requestQueue // double linked list of requests that are on the waiting list // Stream events + sm streammanager.Reader newStreamC <-chan streammanager.EvtStreamAdded rmStreamC <-chan streammanager.EvtStreamRemoved // Request events @@ -40,11 +41,11 @@ type requestManager struct { } // NewRequestManager creates a new request manager -func NewRequestManager(sm streammanager.Subscriber) RequestManager { +func NewRequestManager(sm streammanager.ReaderSubscriber) RequestManager { return newRequestManager(sm) } -func newRequestManager(sm streammanager.Subscriber) *requestManager { +func newRequestManager(sm streammanager.ReaderSubscriber) *requestManager { // subscribe at initialize to prevent misuse of upper function which might cause // the bootstrap peers are ignored newStreamC := make(chan streammanager.EvtStreamAdded) @@ -60,6 +61,7 @@ func newRequestManager(sm streammanager.Subscriber) *requestManager { pendings: make(map[uint64]*request), waitings: newRequestQueue(), + sm: sm, newStreamC: newStreamC, rmStreamC: rmStreamC, cancelReqC: make(chan cancelReqData, 16), @@ -182,13 +184,11 @@ func (rm *requestManager) loop() { case data := <-rm.cancelReqC: rm.handleCancelRequest(data) - case evt := <-rm.newStreamC: - rm.logger.Info().Str("streamID", string(evt.Stream.ID())).Msg("add new stream") - rm.addNewStream(evt.Stream) + case <-rm.newStreamC: + rm.refreshStreams() - case evt := <-rm.rmStreamC: - rm.logger.Info().Str("streamID", string(evt.ID)).Msg("remove stream") - rm.removeStream(evt.ID) + case <-rm.rmStreamC: + rm.refreshStreams() case <-rm.stopC: rm.logger.Info().Msg("request manager stopped") @@ -349,26 +349,49 @@ func (rm *requestManager) pickAvailableStream(req *request) (*stream, error) { return nil, errors.New("no more available streams") } -func (rm *requestManager) addNewStream(st sttypes.Stream) { +func (rm *requestManager) refreshStreams() { rm.lock.Lock() defer rm.lock.Unlock() - if _, ok := rm.streams[st.ID()]; !ok { - rm.streams[st.ID()] = &stream{Stream: st} - rm.available[st.ID()] = struct{}{} + added, removed := checkStreamUpdates(rm.streams, rm.sm.GetStreams()) + + for _, st := range added { + rm.logger.Info().Str("streamID", string(st.ID())).Msg("add new stream") + rm.addNewStream(st) + } + for _, st := range removed { + rm.logger.Info().Str("streamID", string(st.ID())).Msg("remove stream") + rm.removeStream(st) } } -// removeStream remove the stream from request manager, clear the pending request -// of the stream. Return whether a pending request is canceled in the stream, -func (rm *requestManager) removeStream(id sttypes.StreamID) { - rm.lock.Lock() - defer rm.lock.Unlock() +func checkStreamUpdates(exists map[sttypes.StreamID]*stream, targets []sttypes.Stream) (added []sttypes.Stream, removed []*stream) { + targetM := make(map[sttypes.StreamID]sttypes.Stream) - st, ok := rm.streams[id] - if !ok { - return + for _, target := range targets { + id := target.ID() + targetM[id] = target + if _, ok := exists[id]; !ok { + added = append(added, target) + } + } + for id, exist := range exists { + if _, ok := targetM[id]; !ok { + removed = append(removed, exist) + } } + return +} + +func (rm *requestManager) addNewStream(st sttypes.Stream) { + rm.streams[st.ID()] = &stream{Stream: st} + rm.available[st.ID()] = struct{}{} +} + +// removeStream remove the stream from request manager, clear the pending request +// of the stream. +func (rm *requestManager) removeStream(st *stream) { + id := st.ID() delete(rm.available, id) delete(rm.streams, id) diff --git a/p2p/stream/common/streammanager/interface.go b/p2p/stream/common/streammanager/interface.go index 5c34488d8..e6659cf65 100644 --- a/p2p/stream/common/streammanager/interface.go +++ b/p2p/stream/common/streammanager/interface.go @@ -14,13 +14,19 @@ import ( // StreamManager is the interface for streamManager type StreamManager interface { p2ptypes.LifeCycle - StreamOperator + Operator Subscriber - StreamReader + Reader } -// StreamOperator handles new stream or remove stream -type StreamOperator interface { +// ReaderSubscriber reads stream and subscribe stream events +type ReaderSubscriber interface { + Reader + Subscriber +} + +// Operator handles new stream or remove stream +type Operator interface { NewStream(stream sttypes.Stream) error RemoveStream(stID sttypes.StreamID) error } @@ -31,8 +37,8 @@ type Subscriber interface { SubscribeRemoveStreamEvent(ch chan<- EvtStreamRemoved) event.Subscription } -// StreamReader is the interface to read stream in stream manager -type StreamReader interface { +// Reader is the interface to read stream in stream manager +type Reader interface { GetStreams() []sttypes.Stream GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) }