diff --git a/p2p/stream/common/requestmanager/interface_test.go b/p2p/stream/common/requestmanager/interface_test.go index 01e8c595a..57ffe3355 100644 --- a/p2p/stream/common/requestmanager/interface_test.go +++ b/p2p/stream/common/requestmanager/interface_test.go @@ -14,19 +14,25 @@ import ( var testProtoID = sttypes.ProtoID("harmony/sync/unitest/0/1.0.0") type testStreamManager struct { + streams map[sttypes.StreamID]sttypes.Stream + newStreamFeed event.Feed rmStreamFeed event.Feed } func newTestStreamManager() *testStreamManager { - return &testStreamManager{} + return &testStreamManager{ + streams: make(map[sttypes.StreamID]sttypes.Stream), + } } func (sm *testStreamManager) addNewStream(st sttypes.Stream) { + sm.streams[st.ID()] = st sm.newStreamFeed.Send(streammanager.EvtStreamAdded{Stream: st}) } func (sm *testStreamManager) rmStream(stid sttypes.StreamID) { + delete(sm.streams, stid) sm.rmStreamFeed.Send(streammanager.EvtStreamRemoved{ID: stid}) } @@ -38,6 +44,20 @@ func (sm *testStreamManager) SubscribeRemoveStreamEvent(ch chan<- streammanager. return sm.rmStreamFeed.Subscribe(ch) } +func (sm *testStreamManager) GetStreams() []sttypes.Stream { + sts := make([]sttypes.Stream, 0, len(sm.streams)) + + for _, st := range sm.streams { + sts = append(sts, st) + } + return sts +} + +func (sm *testStreamManager) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) { + st, exist := sm.streams[id] + return st, exist +} + type testStream struct { id sttypes.StreamID rm *requestManager @@ -79,6 +99,29 @@ func (st *testStream) ResetOnClose() error { return nil } +func makeDummyTestStreams(indexes []int) []sttypes.Stream { + sts := make([]sttypes.Stream, 0, len(indexes)) + + for _, index := range indexes { + sts = append(sts, &testStream{ + id: makeStreamID(index), + }) + } + return sts +} + +func makeDummyStreamSets(indexes []int) map[sttypes.StreamID]*stream { + m := make(map[sttypes.StreamID]*stream) + + for _, index := range indexes { + st := &testStream{ + id: makeStreamID(index), + } + m[st.ID()] = &stream{Stream: st} + } + return m +} + func makeStreamID(index int) sttypes.StreamID { return sttypes.StreamID(strconv.Itoa(index)) } diff --git a/p2p/stream/common/requestmanager/requestmanager_test.go b/p2p/stream/common/requestmanager/requestmanager_test.go index a81bdf202..01638f6b9 100644 --- a/p2p/stream/common/requestmanager/requestmanager_test.go +++ b/p2p/stream/common/requestmanager/requestmanager_test.go @@ -2,6 +2,7 @@ package requestmanager import ( "context" + "fmt" "sync" "sync/atomic" "testing" @@ -303,6 +304,95 @@ func TestGenReqID(t *testing.T) { } } +func TestCheckStreamUpdates(t *testing.T) { + tests := []struct { + exists map[sttypes.StreamID]*stream + targets []sttypes.Stream + expAddedIndexes []int + expRemovedIndexes []int + }{ + { + exists: makeDummyStreamSets([]int{1, 2, 3, 4, 5}), + targets: makeDummyTestStreams([]int{2, 3, 4, 5}), + expAddedIndexes: []int{}, + expRemovedIndexes: []int{1}, + }, + { + exists: makeDummyStreamSets([]int{1, 2, 3, 4, 5}), + targets: makeDummyTestStreams([]int{1, 2, 3, 4, 5, 6}), + expAddedIndexes: []int{6}, + expRemovedIndexes: []int{}, + }, + { + exists: makeDummyStreamSets([]int{}), + targets: makeDummyTestStreams([]int{}), + expAddedIndexes: []int{}, + expRemovedIndexes: []int{}, + }, + { + exists: makeDummyStreamSets([]int{}), + targets: makeDummyTestStreams([]int{1, 2, 3, 4, 5}), + expAddedIndexes: []int{1, 2, 3, 4, 5}, + expRemovedIndexes: []int{}, + }, + { + exists: makeDummyStreamSets([]int{1, 2, 3, 4, 5}), + targets: makeDummyTestStreams([]int{}), + expAddedIndexes: []int{}, + expRemovedIndexes: []int{1, 2, 3, 4, 5}, + }, + { + exists: makeDummyStreamSets([]int{1, 2, 3, 4, 5}), + targets: makeDummyTestStreams([]int{6, 7, 8, 9, 10}), + expAddedIndexes: []int{6, 7, 8, 9, 10}, + expRemovedIndexes: []int{1, 2, 3, 4, 5}, + }, + } + + for i, test := range tests { + added, removed := checkStreamUpdates(test.exists, test.targets) + + if err := checkStreamIDsEqual(added, test.expAddedIndexes); err != nil { + t.Errorf("Test %v: check added: %v", i, err) + } + if err := checkStreamIDsEqual2(removed, test.expRemovedIndexes); err != nil { + t.Errorf("Test %v: check removed: %v", i, err) + } + } +} + +func checkStreamIDsEqual(sts []sttypes.Stream, expIndexes []int) error { + if len(sts) != len(expIndexes) { + return fmt.Errorf("size not equal") + } + expM := make(map[sttypes.StreamID]struct{}) + for _, index := range expIndexes { + expM[makeStreamID(index)] = struct{}{} + } + for _, st := range sts { + if _, ok := expM[st.ID()]; !ok { + return fmt.Errorf("stream not exist in exp: %v", st.ID()) + } + } + return nil +} + +func checkStreamIDsEqual2(sts []*stream, expIndexes []int) error { + if len(sts) != len(expIndexes) { + return fmt.Errorf("size not equal") + } + expM := make(map[sttypes.StreamID]struct{}) + for _, index := range expIndexes { + expM[makeStreamID(index)] = struct{}{} + } + for _, st := range sts { + if _, ok := expM[st.ID()]; !ok { + return fmt.Errorf("stream not exist in exp: %v", st.ID()) + } + } + return nil +} + type testSuite struct { rm *requestManager sm *testStreamManager @@ -330,7 +420,8 @@ func newTestSuite(delayF delayFunc, respF responseFunc, numStreams int) *testSui cancel: cancel, } for i := 0; i != numStreams; i++ { - ts.bootStreams = append(ts.bootStreams, ts.makeTestStream(i)) + st := ts.makeTestStream(i) + ts.bootStreams = append(ts.bootStreams, st) } return ts }