[stream] added stream manager module (#3554)

* [stream] added a length bytes to the start of p2p base stream

* [stream] added stream manager module

Co-authored-by: Rongjian Lan <rongjian.lan@gmail.com>
pull/3565/head
Jacky Wang 4 years ago committed by GitHub
parent e12698cf44
commit 8adce8925a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 25
      p2p/stream/common/streammanager/config.go
  2. 28
      p2p/stream/common/streammanager/events.go
  3. 73
      p2p/stream/common/streammanager/events_test.go
  4. 50
      p2p/stream/common/streammanager/interface.go
  5. 203
      p2p/stream/common/streammanager/interface_test.go
  6. 416
      p2p/stream/common/streammanager/streammanager.go
  7. 243
      p2p/stream/common/streammanager/streammanager_test.go

@ -0,0 +1,25 @@
package streammanager
import "time"
const (
// checkInterval is the default interval for checking stream number. If the stream
// number is smaller than softLoCap, an active discover through DHT will be triggered.
checkInterval = 30 * time.Second
// discTimeout is the timeout for one batch of discovery
discTimeout = 10 * time.Second
// connectTimeout is the timeout for setting up a stream with a discovered peer
connectTimeout = 60 * time.Second
)
// Config is the config for stream manager
type Config struct {
// HardLoCap is low cap of stream number that immediately trigger discovery
HardLoCap int
// SoftLoCap is low cap of stream number that will trigger discovery during stream check
SoftLoCap int
// HiCap is the high cap of stream number
HiCap int
// DiscBatch is the size of each discovery
DiscBatch int
}

@ -0,0 +1,28 @@
package streammanager
import (
"github.com/ethereum/go-ethereum/event"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
)
// EvtStreamAdded is the event of adding a new stream
type (
EvtStreamAdded struct {
Stream sttypes.Stream
}
// EvtStreamRemoved is an event of stream removed
EvtStreamRemoved struct {
ID sttypes.StreamID
}
)
// SubscribeAddStreamEvent subscribe the add stream event
func (sm *streamManager) SubscribeAddStreamEvent(ch chan<- EvtStreamAdded) event.Subscription {
return sm.addStreamFeed.Subscribe(ch)
}
// SubscribeRemoveStreamEvent subscribe the remove stream event
func (sm *streamManager) SubscribeRemoveStreamEvent(ch chan<- EvtStreamRemoved) event.Subscription {
return sm.removeStreamFeed.Subscribe(ch)
}

@ -0,0 +1,73 @@
package streammanager
import (
"sync/atomic"
"testing"
"time"
)
func TestStreamManager_SubscribeAddStreamEvent(t *testing.T) {
sm := newTestStreamManager()
addStreamEvtC := make(chan EvtStreamAdded, 1)
sub := sm.SubscribeAddStreamEvent(addStreamEvtC)
defer sub.Unsubscribe()
stopC := make(chan struct{}, 1)
var numStreamAdded uint32
go func() {
for {
select {
case <-addStreamEvtC:
atomic.AddUint32(&numStreamAdded, 1)
case <-stopC:
return
}
}
}()
sm.Start()
time.Sleep(defTestWait)
close(stopC)
sm.Close()
if atomic.LoadUint32(&numStreamAdded) != 16 {
t.Errorf("numStreamAdded unexpected")
}
}
func TestStreamManager_SubscribeRemoveStreamEvent(t *testing.T) {
sm := newTestStreamManager()
rmStreamEvtC := make(chan EvtStreamRemoved, 1)
sub := sm.SubscribeRemoveStreamEvent(rmStreamEvtC)
defer sub.Unsubscribe()
stopC := make(chan struct{}, 1)
var numStreamRemoved uint32
go func() {
for {
select {
case <-rmStreamEvtC:
atomic.AddUint32(&numStreamRemoved, 1)
case <-stopC:
return
}
}
}()
sm.Start()
time.Sleep(defTestWait)
err := sm.RemoveStream(makeStreamID(1))
if err != nil {
t.Fatal(err)
}
time.Sleep(defTestWait)
close(stopC)
sm.Close()
if atomic.LoadUint32(&numStreamRemoved) != 1 {
t.Errorf("numStreamAdded unexpected")
}
}

@ -0,0 +1,50 @@
package streammanager
import (
"context"
"github.com/ethereum/go-ethereum/event"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
p2ptypes "github.com/harmony-one/harmony/p2p/types"
"github.com/libp2p/go-libp2p-core/network"
libp2p_peer "github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/protocol"
)
// StreamManager is the interface for streamManager
type StreamManager interface {
p2ptypes.LifeCycle
StreamOperator
Subscriber
StreamReader
}
// StreamOperator handles new stream or remove stream
type StreamOperator interface {
NewStream(stream sttypes.Stream) error
RemoveStream(stID sttypes.StreamID) error
}
// Subscriber is the interface to support stream event subscription
type Subscriber interface {
SubscribeAddStreamEvent(ch chan<- EvtStreamAdded) event.Subscription
SubscribeRemoveStreamEvent(ch chan<- EvtStreamRemoved) event.Subscription
}
// StreamReader is the interface to read stream in stream manager
type StreamReader interface {
GetStreams() []sttypes.Stream
GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool)
}
// host is the adapter interface of the libp2p host implementation.
// TODO: further adapt the host
type host interface {
ID() libp2p_peer.ID
NewStream(ctx context.Context, p libp2p_peer.ID, pids ...protocol.ID) (network.Stream, error)
}
// peerFinder is the adapter interface of discovery.Discovery
type peerFinder interface {
FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error)
}

@ -0,0 +1,203 @@
package streammanager
import (
"context"
"errors"
"strconv"
"sync"
"sync/atomic"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
"github.com/libp2p/go-libp2p-core/network"
libp2p_peer "github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/protocol"
)
var _ StreamManager = &streamManager{}
var (
myPeerID = makePeerID(0)
testProtoID = sttypes.ProtoID("harmony/sync/unitest/0/1.0.0")
)
const (
defHardLoCap = 16 // discovery trigger immediately when size smaller than this number
defSoftLoCap = 32 // discovery trigger for routine check
defHiCap = 128 // Hard cap of the stream number
defDiscBatch = 16 // batch size for discovery
)
var defConfig = Config{
HardLoCap: defHardLoCap,
SoftLoCap: defSoftLoCap,
HiCap: defHiCap,
DiscBatch: defDiscBatch,
}
func newTestStreamManager() *streamManager {
pid := testProtoID
host := newTestHost()
pf := newTestPeerFinder(makeRemotePeers(100), emptyDelayFunc)
sm := newStreamManager(pid, host, pf, nil, defConfig)
host.sm = sm
return sm
}
type testStream struct {
id sttypes.StreamID
proto sttypes.ProtoID
closed bool
}
func newTestStream(id sttypes.StreamID, proto sttypes.ProtoID) *testStream {
return &testStream{id: id, proto: proto}
}
func (st *testStream) ID() sttypes.StreamID {
return st.id
}
func (st *testStream) ProtoID() sttypes.ProtoID {
return st.proto
}
func (st *testStream) WriteBytes([]byte) error {
return nil
}
func (st *testStream) ReadBytes() ([]byte, error) {
return nil, nil
}
func (st *testStream) Close() error {
if st.closed {
return errors.New("already closed")
}
st.closed = true
return nil
}
func (st *testStream) ResetOnClose() error {
if st.closed {
return errors.New("already closed")
}
st.closed = true
return nil
}
func (st *testStream) ProtoSpec() (sttypes.ProtoSpec, error) {
return sttypes.ProtoIDToProtoSpec(st.ProtoID())
}
type testHost struct {
sm *streamManager
streams map[sttypes.StreamID]*testStream
lock sync.Mutex
errHook streamErrorHook
}
type streamErrorHook func(id sttypes.StreamID, err error)
func newTestHost() *testHost {
return &testHost{
streams: make(map[sttypes.StreamID]*testStream),
}
}
func (h *testHost) ID() libp2p_peer.ID {
return myPeerID
}
// NewStream mock the upper function logic. When stream setup and running protocol, the
// upper code logic will call StreamManager to add new stream
func (h *testHost) NewStream(ctx context.Context, p libp2p_peer.ID, pids ...protocol.ID) (network.Stream, error) {
if len(pids) == 0 {
return nil, errors.New("nil protocol ids")
}
var err error
stid := sttypes.StreamID(p)
defer func() {
if err != nil && h.errHook != nil {
h.errHook(stid, err)
}
}()
st := newTestStream(stid, sttypes.ProtoID(pids[0]))
h.lock.Lock()
h.streams[stid] = st
h.lock.Unlock()
err = h.sm.NewStream(st)
return nil, err
}
func makeStreamID(index int) sttypes.StreamID {
return sttypes.StreamID(strconv.Itoa(index))
}
func makePeerID(index int) libp2p_peer.ID {
return libp2p_peer.ID(strconv.Itoa(index))
}
func makeRemotePeers(size int) []libp2p_peer.ID {
ids := make([]libp2p_peer.ID, 0, size)
for i := 1; i != size+1; i++ {
ids = append(ids, makePeerID(i))
}
return ids
}
type testPeerFinder struct {
peerIDs []libp2p_peer.ID
curIndex int32
fpHook delayFunc
}
type delayFunc func(id libp2p_peer.ID) <-chan struct{}
func emptyDelayFunc(id libp2p_peer.ID) <-chan struct{} {
c := make(chan struct{})
go func() {
c <- struct{}{}
}()
return c
}
func newTestPeerFinder(ids []libp2p_peer.ID, fpHook delayFunc) *testPeerFinder {
return &testPeerFinder{
peerIDs: ids,
curIndex: 0,
fpHook: fpHook,
}
}
func (pf *testPeerFinder) FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) {
if peerLimit > len(pf.peerIDs) {
peerLimit = len(pf.peerIDs)
}
resC := make(chan libp2p_peer.AddrInfo)
go func() {
defer close(resC)
for i := 0; i != peerLimit; i++ {
// hack to prevent race
curIndex := atomic.LoadInt32(&pf.curIndex)
pid := pf.peerIDs[curIndex]
select {
case <-ctx.Done():
return
case <-pf.fpHook(pid):
}
resC <- libp2p_peer.AddrInfo{ID: pid}
atomic.AddInt32(&pf.curIndex, 1)
if int(atomic.LoadInt32(&pf.curIndex)) == len(pf.peerIDs) {
pf.curIndex = 0
}
}
}()
return resC, nil
}

@ -0,0 +1,416 @@
package streammanager
import (
"context"
"fmt"
"sync"
"time"
"github.com/ethereum/go-ethereum/event"
"github.com/harmony-one/harmony/internal/utils"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
"github.com/libp2p/go-libp2p-core/network"
libp2p_peer "github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/protocol"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
var (
// ErrStreamAlreadyRemoved is the error that a stream has already been removed
ErrStreamAlreadyRemoved = errors.New("stream already removed")
)
// streamManager is the implementation of StreamManager. It manages streams on
// one single protocol. It does the following job:
// 1. add a new stream.
// 2. closes a stream.
// 3. discover and connect new streams when the number of streams is below threshold.
// 4. emit stream events to inform other modules.
// 5. reset all streams on close.
type streamManager struct {
// streamManager only manages streams on one protocol.
myProtoID sttypes.ProtoID
myProtoSpec sttypes.ProtoSpec
config Config
// streams is the map of peer ID to stream
// Note that it could happen that remote node does not share exactly the same
// protocol ID (e.g. different version)
streams *streamSet
// libp2p utilities
host host
pf peerFinder
handleStream func(stream network.Stream)
// incoming task channels
addStreamCh chan addStreamTask
rmStreamCh chan rmStreamTask
stopCh chan stopTask
discCh chan discTask
curTask interface{}
// utils
addStreamFeed event.Feed
removeStreamFeed event.Feed
logger zerolog.Logger
ctx context.Context
cancel func()
}
// NewStreamManager creates a new stream manager for the given proto ID
func NewStreamManager(pid sttypes.ProtoID, host host, pf peerFinder, handleStream func(network.Stream), c Config) StreamManager {
return newStreamManager(pid, host, pf, handleStream, c)
}
// newStreamManager creates a new stream manager
func newStreamManager(pid sttypes.ProtoID, host host, pf peerFinder, handleStream func(network.Stream), c Config) *streamManager {
ctx, cancel := context.WithCancel(context.Background())
logger := utils.Logger().With().Str("module", "stream manager").
Str("protocol ID", string(pid)).Logger()
protoSpec, _ := sttypes.ProtoIDToProtoSpec(pid)
return &streamManager{
myProtoID: pid,
myProtoSpec: protoSpec,
config: c,
streams: newStreamSet(),
host: host,
pf: pf,
handleStream: handleStream,
addStreamCh: make(chan addStreamTask),
rmStreamCh: make(chan rmStreamTask),
stopCh: make(chan stopTask),
discCh: make(chan discTask, 1), // discCh is a buffered channel to avoid overuse of goroutine
logger: logger,
ctx: ctx,
cancel: cancel,
}
}
// Start starts the stream manager
func (sm *streamManager) Start() {
go sm.loop()
}
// Close close the stream manager
func (sm *streamManager) Close() {
task := stopTask{done: make(chan struct{})}
sm.stopCh <- task
<-task.done
}
func (sm *streamManager) loop() {
var (
discTicker = time.NewTicker(checkInterval)
discCtx context.Context
discCancel func()
)
// bootstrap discovery
sm.discCh <- discTask{}
for {
select {
case <-discTicker.C:
if !sm.softHaveEnoughStreams() {
sm.discCh <- discTask{}
}
case <-sm.discCh:
// cancel last discovery
if discCancel != nil {
discCancel()
}
discCtx, discCancel = context.WithCancel(sm.ctx)
go func() {
err := sm.discoverAndSetupStream(discCtx)
if err != nil {
sm.logger.Err(err)
}
}()
case addStream := <-sm.addStreamCh:
err := sm.handleAddStream(addStream.st)
addStream.errC <- err
case rmStream := <-sm.rmStreamCh:
err := sm.handleRemoveStream(rmStream.id)
rmStream.errC <- err
case stop := <-sm.stopCh:
sm.cancel()
sm.removeAllStreamOnClose()
stop.done <- struct{}{}
return
}
}
}
// NewStream handles a new stream from stream handler protocol
func (sm *streamManager) NewStream(stream sttypes.Stream) error {
if err := sm.sanityCheckStream(stream); err != nil {
return errors.Wrap(err, "stream sanity check failed")
}
task := addStreamTask{
st: stream,
errC: make(chan error),
}
sm.addStreamCh <- task
return <-task.errC
}
// RemoveStream close and remove a stream from stream manager
func (sm *streamManager) RemoveStream(stID sttypes.StreamID) error {
task := rmStreamTask{
id: stID,
errC: make(chan error),
}
sm.rmStreamCh <- task
return <-task.errC
}
// GetStreams return the streams.
func (sm *streamManager) GetStreams() []sttypes.Stream {
return sm.streams.getStreams()
}
// GetStreamByID return the stream with the given id.
func (sm *streamManager) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) {
return sm.streams.get(id)
}
type (
addStreamTask struct {
st sttypes.Stream
errC chan error
}
rmStreamTask struct {
id sttypes.StreamID
errC chan error
}
discTask struct{}
stopTask struct {
done chan struct{}
}
)
// sanity checks the service, network and shard ID
func (sm *streamManager) sanityCheckStream(st sttypes.Stream) error {
mySpec := sm.myProtoSpec
rmSpec, err := st.ProtoSpec()
if err != nil {
return err
}
if mySpec.Service != rmSpec.Service {
return fmt.Errorf("unexpected service: %v/%v", rmSpec.Service, mySpec.Service)
}
if mySpec.NetworkType != rmSpec.NetworkType {
return fmt.Errorf("unexpected network: %v/%v", rmSpec.NetworkType, mySpec.NetworkType)
}
if mySpec.ShardID != rmSpec.ShardID {
return fmt.Errorf("unexpected shard ID: %v/%v", rmSpec.ShardID, mySpec.ShardID)
}
return nil
}
func (sm *streamManager) handleAddStream(st sttypes.Stream) error {
id := st.ID()
if sm.streams.size() >= sm.config.HiCap {
return errors.New("too many streams")
}
if _, ok := sm.streams.get(id); ok {
return errors.New("stream already exist")
}
sm.streams.addStream(st)
sm.addStreamFeed.Send(EvtStreamAdded{st})
return nil
}
func (sm *streamManager) handleRemoveStream(id sttypes.StreamID) error {
st, ok := sm.streams.get(id)
if !ok {
return ErrStreamAlreadyRemoved
}
sm.streams.deleteStream(st)
// if stream number is smaller than HardLoCap, spin up the discover
if !sm.hardHaveEnoughStream() {
select {
case sm.discCh <- discTask{}:
default:
}
}
sm.removeStreamFeed.Send(EvtStreamRemoved{id})
return nil
}
func (sm *streamManager) removeAllStreamOnClose() {
var wg sync.WaitGroup
for _, st := range sm.streams.slice() {
wg.Add(1)
go func(st sttypes.Stream) {
defer wg.Done()
err := st.ResetOnClose()
if err != nil {
sm.logger.Warn().Err(err).Str("stream ID", string(st.ID())).
Msg("failed to close stream")
}
}(st)
}
wg.Wait()
// Be nice. after close, the field is still accessible to prevent potential panics
sm.streams = newStreamSet()
}
func (sm *streamManager) discoverAndSetupStream(discCtx context.Context) error {
peers, err := sm.discover(discCtx)
if err != nil {
return errors.Wrap(err, "failed to discover")
}
for peer := range peers {
if peer.ID == sm.host.ID() {
continue
}
go func(pid libp2p_peer.ID) {
// The ctx here is using the module context instead of discover context
err := sm.setupStreamWithPeer(sm.ctx, pid)
if err != nil {
sm.logger.Warn().Err(err).Str("peerID", string(pid)).Msg("failed to setup stream with peer")
return
}
}(peer.ID)
}
return nil
}
func (sm *streamManager) discover(ctx context.Context) (<-chan libp2p_peer.AddrInfo, error) {
protoID := string(sm.myProtoID)
discBatch := sm.config.DiscBatch
if sm.config.HiCap-sm.streams.size() < sm.config.DiscBatch {
discBatch = sm.config.HiCap - sm.streams.size()
}
if discBatch < 0 {
return nil, nil
}
ctx, _ = context.WithTimeout(ctx, discTimeout)
return sm.pf.FindPeers(ctx, protoID, discBatch)
}
func (sm *streamManager) setupStreamWithPeer(ctx context.Context, pid libp2p_peer.ID) error {
nCtx, cancel := context.WithTimeout(ctx, connectTimeout)
defer cancel()
st, err := sm.host.NewStream(nCtx, pid, protocol.ID(sm.myProtoID))
if err != nil {
return err
}
if sm.handleStream != nil {
go sm.handleStream(st)
}
return nil
}
func (sm *streamManager) softHaveEnoughStreams() bool {
availStreams := sm.streams.numStreamsWithMinProtoSpec(sm.myProtoSpec)
return availStreams >= sm.config.SoftLoCap
}
func (sm *streamManager) hardHaveEnoughStream() bool {
availStreams := sm.streams.numStreamsWithMinProtoSpec(sm.myProtoSpec)
return availStreams >= sm.config.HardLoCap
}
// streamSet is the concurrency safe stream set.
type streamSet struct {
streams map[sttypes.StreamID]sttypes.Stream
numByProto map[sttypes.ProtoSpec]int
lock sync.RWMutex
}
func newStreamSet() *streamSet {
return &streamSet{
streams: make(map[sttypes.StreamID]sttypes.Stream),
numByProto: make(map[sttypes.ProtoSpec]int),
}
}
func (ss *streamSet) size() int {
ss.lock.RLock()
defer ss.lock.RUnlock()
return len(ss.streams)
}
func (ss *streamSet) get(id sttypes.StreamID) (sttypes.Stream, bool) {
ss.lock.RLock()
defer ss.lock.RUnlock()
st, ok := ss.streams[id]
return st, ok
}
func (ss *streamSet) addStream(st sttypes.Stream) {
ss.lock.Lock()
defer ss.lock.Unlock()
ss.streams[st.ID()] = st
spec, _ := st.ProtoSpec()
ss.numByProto[spec]++
}
func (ss *streamSet) deleteStream(st sttypes.Stream) {
ss.lock.Lock()
defer ss.lock.Unlock()
delete(ss.streams, st.ID())
spec, _ := st.ProtoSpec()
ss.numByProto[spec]--
if ss.numByProto[spec] == 0 {
delete(ss.numByProto, spec)
}
}
func (ss *streamSet) slice() []sttypes.Stream {
ss.lock.RLock()
defer ss.lock.RUnlock()
sts := make([]sttypes.Stream, 0, len(ss.streams))
for _, st := range ss.streams {
sts = append(sts, st)
}
return sts
}
func (ss *streamSet) getStreams() []sttypes.Stream {
ss.lock.RLock()
defer ss.lock.RUnlock()
res := make([]sttypes.Stream, 0, len(ss.streams))
for _, st := range ss.streams {
res = append(res, st)
}
return res
}
func (ss *streamSet) numStreamsWithMinProtoSpec(minSpec sttypes.ProtoSpec) int {
ss.lock.RLock()
defer ss.lock.RUnlock()
var res int
for spec, num := range ss.numByProto {
if !spec.Version.LessThan(minSpec.Version) {
res += num
}
}
return res
}

@ -0,0 +1,243 @@
package streammanager
import (
"errors"
"fmt"
"strings"
"sync"
"testing"
"time"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
libp2p_peer "github.com/libp2p/go-libp2p-core/peer"
)
const (
defTestWait = 100 * time.Millisecond
)
// When started, discover will be run at bootstrap
func TestStreamManager_BootstrapDisc(t *testing.T) {
sm := newTestStreamManager()
sm.host.(*testHost).errHook = func(id sttypes.StreamID, err error) {
t.Errorf("%s stream error: %v", id, err)
}
// After bootstrap, stream manager shall discover streams and setup connection
// Note host will mock the upper code logic to call sm.NewStream in this case
sm.Start()
time.Sleep(defTestWait)
if gotSize := sm.streams.size(); gotSize != sm.config.DiscBatch {
t.Errorf("unexpected stream size: %v != %v", gotSize, sm.config.DiscBatch)
}
}
// After close, all stream shall be closed and removed
func TestStreamManager_Close(t *testing.T) {
sm := newTestStreamManager()
// Bootstrap
sm.Start()
time.Sleep(defTestWait)
// Close stream manager, all stream shall be closed and removed
closeDone := make(chan struct{})
go func() {
sm.Close()
closeDone <- struct{}{}
}()
select {
case <-time.After(defTestWait):
t.Errorf("still not closed")
case <-closeDone:
}
// Check stream been removed from stream manager and all streams to be closed
if sm.streams.size() != 0 {
t.Errorf("after close, stream not removed from stream manager")
}
host := sm.host.(*testHost)
for _, st := range host.streams {
if !st.closed {
t.Errorf("after close, stream still not closed")
}
}
}
// Close shall terminate the current discover at once
func TestStreamManager_CloseDisc(t *testing.T) {
sm := newTestStreamManager()
// discover will be blocked forever
sm.pf.(*testPeerFinder).fpHook = func(id libp2p_peer.ID) <-chan struct{} {
select {}
}
sm.Start()
time.Sleep(defTestWait)
// Close stream manager, all stream shall be closed and removed
closeDone := make(chan struct{})
go func() {
sm.Close()
closeDone <- struct{}{}
}()
select {
case <-time.After(defTestWait):
t.Errorf("close shall unblock the current discovery")
case <-closeDone:
}
}
// Each time discTicker ticks, it will cancel last discovery, and start a new one
func TestStreamManager_refreshDisc(t *testing.T) {
sm := newTestStreamManager()
// discover will be blocked for the first time but good for second time
var once sync.Once
sm.pf.(*testPeerFinder).fpHook = func(id libp2p_peer.ID) <-chan struct{} {
var sendSig = true
once.Do(func() {
sendSig = false
})
c := make(chan struct{}, 1)
if sendSig {
c <- struct{}{}
}
return c
}
sm.Start()
time.Sleep(defTestWait)
sm.discCh <- struct{}{}
time.Sleep(defTestWait)
// We shall now have non-zero streams setup
if sm.streams.size() == 0 {
t.Errorf("stream size still zero after refresh")
}
}
func TestStreamManager_HandleNewStream(t *testing.T) {
tests := []struct {
stream sttypes.Stream
expSize int
expErr error
}{
{
stream: newTestStream(makeStreamID(100), testProtoID),
expSize: defDiscBatch + 1,
expErr: nil,
},
{
stream: newTestStream(makeStreamID(1), testProtoID),
expSize: defDiscBatch,
expErr: errors.New("stream already exist"),
},
}
for i, test := range tests {
sm := newTestStreamManager()
sm.Start()
time.Sleep(defTestWait)
err := sm.NewStream(test.stream)
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
}
if sm.streams.size() != test.expSize {
t.Errorf("Test %v: unexpected stream size: %v / %v", i, sm.streams.size(),
test.expSize)
}
}
}
func TestStreamManager_HandleRemoveStream(t *testing.T) {
tests := []struct {
id sttypes.StreamID
expSize int
expErr error
}{
{
id: makeStreamID(1),
expSize: defDiscBatch - 1,
expErr: nil,
},
{
id: makeStreamID(100),
expSize: defDiscBatch,
expErr: errors.New("stream already removed"),
},
}
for i, test := range tests {
sm := newTestStreamManager()
sm.Start()
time.Sleep(defTestWait)
err := sm.RemoveStream(test.id)
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
}
if sm.streams.size() != test.expSize {
t.Errorf("Test %v: unexpected stream size: %v / %v", i, sm.streams.size(),
test.expSize)
}
}
}
// When number of streams is smaller than hard low limit, discover will be triggered
func TestStreamManager_HandleRemoveStream_Disc(t *testing.T) {
sm := newTestStreamManager()
sm.Start()
time.Sleep(defTestWait)
// Remove DiscBatch - HardLoCap + 1 streams
num := 0
for _, st := range sm.streams.slice() {
if err := sm.RemoveStream(st.ID()); err != nil {
t.Error(err)
}
num++
if num == sm.config.DiscBatch-sm.config.HardLoCap+1 {
break
}
}
// Last remove stream will also trigger discover
time.Sleep(defTestWait)
if sm.streams.size() != sm.config.HardLoCap+sm.config.DiscBatch-1 {
t.Errorf("unexpected stream number %v / %v", sm.streams.size(), sm.config.HardLoCap+sm.config.DiscBatch-1)
}
}
func TestStreamSet_numStreamsWithMinProtoID(t *testing.T) {
var (
pid1 = testProtoID
numPid1 = 5
pid2 = sttypes.ProtoID("harmony/sync/unitest/0/1.0.1")
numPid2 = 10
)
ss := newStreamSet()
for i := 0; i != numPid1; i++ {
ss.addStream(newTestStream(makeStreamID(i), pid1))
}
for i := 0; i != numPid2; i++ {
ss.addStream(newTestStream(makeStreamID(i), pid2))
}
minSpec, _ := sttypes.ProtoIDToProtoSpec(pid2)
num := ss.numStreamsWithMinProtoSpec(minSpec)
if num != numPid2 {
t.Errorf("unexpected result: %v/%v", num, numPid2)
}
}
func assertError(got, exp error) error {
if (got == nil) != (exp == nil) {
return fmt.Errorf("unexpected error: %v / %v", got, exp)
}
if got == nil {
return nil
}
if !strings.Contains(got.Error(), exp.Error()) {
return fmt.Errorf("unexpected error: %v / %v", got, exp)
}
return nil
}
Loading…
Cancel
Save