[stream] added stream manager module (#3554)
* [stream] added a length bytes to the start of p2p base stream * [stream] added stream manager module Co-authored-by: Rongjian Lan <rongjian.lan@gmail.com>pull/3565/head
parent
e12698cf44
commit
8adce8925a
@ -0,0 +1,25 @@ |
||||
package streammanager |
||||
|
||||
import "time" |
||||
|
||||
const ( |
||||
// checkInterval is the default interval for checking stream number. If the stream
|
||||
// number is smaller than softLoCap, an active discover through DHT will be triggered.
|
||||
checkInterval = 30 * time.Second |
||||
// discTimeout is the timeout for one batch of discovery
|
||||
discTimeout = 10 * time.Second |
||||
// connectTimeout is the timeout for setting up a stream with a discovered peer
|
||||
connectTimeout = 60 * time.Second |
||||
) |
||||
|
||||
// Config is the config for stream manager
|
||||
type Config struct { |
||||
// HardLoCap is low cap of stream number that immediately trigger discovery
|
||||
HardLoCap int |
||||
// SoftLoCap is low cap of stream number that will trigger discovery during stream check
|
||||
SoftLoCap int |
||||
// HiCap is the high cap of stream number
|
||||
HiCap int |
||||
// DiscBatch is the size of each discovery
|
||||
DiscBatch int |
||||
} |
@ -0,0 +1,28 @@ |
||||
package streammanager |
||||
|
||||
import ( |
||||
"github.com/ethereum/go-ethereum/event" |
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
) |
||||
|
||||
// EvtStreamAdded is the event of adding a new stream
|
||||
type ( |
||||
EvtStreamAdded struct { |
||||
Stream sttypes.Stream |
||||
} |
||||
|
||||
// EvtStreamRemoved is an event of stream removed
|
||||
EvtStreamRemoved struct { |
||||
ID sttypes.StreamID |
||||
} |
||||
) |
||||
|
||||
// SubscribeAddStreamEvent subscribe the add stream event
|
||||
func (sm *streamManager) SubscribeAddStreamEvent(ch chan<- EvtStreamAdded) event.Subscription { |
||||
return sm.addStreamFeed.Subscribe(ch) |
||||
} |
||||
|
||||
// SubscribeRemoveStreamEvent subscribe the remove stream event
|
||||
func (sm *streamManager) SubscribeRemoveStreamEvent(ch chan<- EvtStreamRemoved) event.Subscription { |
||||
return sm.removeStreamFeed.Subscribe(ch) |
||||
} |
@ -0,0 +1,73 @@ |
||||
package streammanager |
||||
|
||||
import ( |
||||
"sync/atomic" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestStreamManager_SubscribeAddStreamEvent(t *testing.T) { |
||||
sm := newTestStreamManager() |
||||
|
||||
addStreamEvtC := make(chan EvtStreamAdded, 1) |
||||
sub := sm.SubscribeAddStreamEvent(addStreamEvtC) |
||||
defer sub.Unsubscribe() |
||||
stopC := make(chan struct{}, 1) |
||||
|
||||
var numStreamAdded uint32 |
||||
go func() { |
||||
for { |
||||
select { |
||||
case <-addStreamEvtC: |
||||
atomic.AddUint32(&numStreamAdded, 1) |
||||
case <-stopC: |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
|
||||
sm.Start() |
||||
time.Sleep(defTestWait) |
||||
close(stopC) |
||||
sm.Close() |
||||
|
||||
if atomic.LoadUint32(&numStreamAdded) != 16 { |
||||
t.Errorf("numStreamAdded unexpected") |
||||
} |
||||
} |
||||
|
||||
func TestStreamManager_SubscribeRemoveStreamEvent(t *testing.T) { |
||||
sm := newTestStreamManager() |
||||
|
||||
rmStreamEvtC := make(chan EvtStreamRemoved, 1) |
||||
sub := sm.SubscribeRemoveStreamEvent(rmStreamEvtC) |
||||
defer sub.Unsubscribe() |
||||
stopC := make(chan struct{}, 1) |
||||
|
||||
var numStreamRemoved uint32 |
||||
go func() { |
||||
for { |
||||
select { |
||||
case <-rmStreamEvtC: |
||||
atomic.AddUint32(&numStreamRemoved, 1) |
||||
case <-stopC: |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
|
||||
sm.Start() |
||||
time.Sleep(defTestWait) |
||||
|
||||
err := sm.RemoveStream(makeStreamID(1)) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
time.Sleep(defTestWait) |
||||
close(stopC) |
||||
sm.Close() |
||||
|
||||
if atomic.LoadUint32(&numStreamRemoved) != 1 { |
||||
t.Errorf("numStreamAdded unexpected") |
||||
} |
||||
} |
@ -0,0 +1,50 @@ |
||||
package streammanager |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
"github.com/ethereum/go-ethereum/event" |
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
p2ptypes "github.com/harmony-one/harmony/p2p/types" |
||||
"github.com/libp2p/go-libp2p-core/network" |
||||
libp2p_peer "github.com/libp2p/go-libp2p-core/peer" |
||||
"github.com/libp2p/go-libp2p-core/protocol" |
||||
) |
||||
|
||||
// StreamManager is the interface for streamManager
|
||||
type StreamManager interface { |
||||
p2ptypes.LifeCycle |
||||
StreamOperator |
||||
Subscriber |
||||
StreamReader |
||||
} |
||||
|
||||
// StreamOperator handles new stream or remove stream
|
||||
type StreamOperator interface { |
||||
NewStream(stream sttypes.Stream) error |
||||
RemoveStream(stID sttypes.StreamID) error |
||||
} |
||||
|
||||
// Subscriber is the interface to support stream event subscription
|
||||
type Subscriber interface { |
||||
SubscribeAddStreamEvent(ch chan<- EvtStreamAdded) event.Subscription |
||||
SubscribeRemoveStreamEvent(ch chan<- EvtStreamRemoved) event.Subscription |
||||
} |
||||
|
||||
// StreamReader is the interface to read stream in stream manager
|
||||
type StreamReader interface { |
||||
GetStreams() []sttypes.Stream |
||||
GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) |
||||
} |
||||
|
||||
// host is the adapter interface of the libp2p host implementation.
|
||||
// TODO: further adapt the host
|
||||
type host interface { |
||||
ID() libp2p_peer.ID |
||||
NewStream(ctx context.Context, p libp2p_peer.ID, pids ...protocol.ID) (network.Stream, error) |
||||
} |
||||
|
||||
// peerFinder is the adapter interface of discovery.Discovery
|
||||
type peerFinder interface { |
||||
FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) |
||||
} |
@ -0,0 +1,203 @@ |
||||
package streammanager |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"strconv" |
||||
"sync" |
||||
"sync/atomic" |
||||
|
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
"github.com/libp2p/go-libp2p-core/network" |
||||
libp2p_peer "github.com/libp2p/go-libp2p-core/peer" |
||||
"github.com/libp2p/go-libp2p-core/protocol" |
||||
) |
||||
|
||||
var _ StreamManager = &streamManager{} |
||||
|
||||
var ( |
||||
myPeerID = makePeerID(0) |
||||
testProtoID = sttypes.ProtoID("harmony/sync/unitest/0/1.0.0") |
||||
) |
||||
|
||||
const ( |
||||
defHardLoCap = 16 // discovery trigger immediately when size smaller than this number
|
||||
defSoftLoCap = 32 // discovery trigger for routine check
|
||||
defHiCap = 128 // Hard cap of the stream number
|
||||
defDiscBatch = 16 // batch size for discovery
|
||||
) |
||||
|
||||
var defConfig = Config{ |
||||
HardLoCap: defHardLoCap, |
||||
SoftLoCap: defSoftLoCap, |
||||
HiCap: defHiCap, |
||||
DiscBatch: defDiscBatch, |
||||
} |
||||
|
||||
func newTestStreamManager() *streamManager { |
||||
pid := testProtoID |
||||
host := newTestHost() |
||||
pf := newTestPeerFinder(makeRemotePeers(100), emptyDelayFunc) |
||||
|
||||
sm := newStreamManager(pid, host, pf, nil, defConfig) |
||||
host.sm = sm |
||||
return sm |
||||
} |
||||
|
||||
type testStream struct { |
||||
id sttypes.StreamID |
||||
proto sttypes.ProtoID |
||||
closed bool |
||||
} |
||||
|
||||
func newTestStream(id sttypes.StreamID, proto sttypes.ProtoID) *testStream { |
||||
return &testStream{id: id, proto: proto} |
||||
} |
||||
|
||||
func (st *testStream) ID() sttypes.StreamID { |
||||
return st.id |
||||
} |
||||
|
||||
func (st *testStream) ProtoID() sttypes.ProtoID { |
||||
return st.proto |
||||
} |
||||
|
||||
func (st *testStream) WriteBytes([]byte) error { |
||||
return nil |
||||
} |
||||
|
||||
func (st *testStream) ReadBytes() ([]byte, error) { |
||||
return nil, nil |
||||
} |
||||
|
||||
func (st *testStream) Close() error { |
||||
if st.closed { |
||||
return errors.New("already closed") |
||||
} |
||||
st.closed = true |
||||
return nil |
||||
} |
||||
|
||||
func (st *testStream) ResetOnClose() error { |
||||
if st.closed { |
||||
return errors.New("already closed") |
||||
} |
||||
st.closed = true |
||||
return nil |
||||
} |
||||
|
||||
func (st *testStream) ProtoSpec() (sttypes.ProtoSpec, error) { |
||||
return sttypes.ProtoIDToProtoSpec(st.ProtoID()) |
||||
} |
||||
|
||||
type testHost struct { |
||||
sm *streamManager |
||||
streams map[sttypes.StreamID]*testStream |
||||
lock sync.Mutex |
||||
|
||||
errHook streamErrorHook |
||||
} |
||||
|
||||
type streamErrorHook func(id sttypes.StreamID, err error) |
||||
|
||||
func newTestHost() *testHost { |
||||
return &testHost{ |
||||
streams: make(map[sttypes.StreamID]*testStream), |
||||
} |
||||
} |
||||
|
||||
func (h *testHost) ID() libp2p_peer.ID { |
||||
return myPeerID |
||||
} |
||||
|
||||
// NewStream mock the upper function logic. When stream setup and running protocol, the
|
||||
// upper code logic will call StreamManager to add new stream
|
||||
func (h *testHost) NewStream(ctx context.Context, p libp2p_peer.ID, pids ...protocol.ID) (network.Stream, error) { |
||||
if len(pids) == 0 { |
||||
return nil, errors.New("nil protocol ids") |
||||
} |
||||
var err error |
||||
stid := sttypes.StreamID(p) |
||||
defer func() { |
||||
if err != nil && h.errHook != nil { |
||||
h.errHook(stid, err) |
||||
} |
||||
}() |
||||
|
||||
st := newTestStream(stid, sttypes.ProtoID(pids[0])) |
||||
h.lock.Lock() |
||||
h.streams[stid] = st |
||||
h.lock.Unlock() |
||||
|
||||
err = h.sm.NewStream(st) |
||||
return nil, err |
||||
} |
||||
|
||||
func makeStreamID(index int) sttypes.StreamID { |
||||
return sttypes.StreamID(strconv.Itoa(index)) |
||||
} |
||||
|
||||
func makePeerID(index int) libp2p_peer.ID { |
||||
return libp2p_peer.ID(strconv.Itoa(index)) |
||||
} |
||||
|
||||
func makeRemotePeers(size int) []libp2p_peer.ID { |
||||
ids := make([]libp2p_peer.ID, 0, size) |
||||
for i := 1; i != size+1; i++ { |
||||
ids = append(ids, makePeerID(i)) |
||||
} |
||||
return ids |
||||
} |
||||
|
||||
type testPeerFinder struct { |
||||
peerIDs []libp2p_peer.ID |
||||
curIndex int32 |
||||
fpHook delayFunc |
||||
} |
||||
|
||||
type delayFunc func(id libp2p_peer.ID) <-chan struct{} |
||||
|
||||
func emptyDelayFunc(id libp2p_peer.ID) <-chan struct{} { |
||||
c := make(chan struct{}) |
||||
go func() { |
||||
c <- struct{}{} |
||||
}() |
||||
return c |
||||
} |
||||
|
||||
func newTestPeerFinder(ids []libp2p_peer.ID, fpHook delayFunc) *testPeerFinder { |
||||
return &testPeerFinder{ |
||||
peerIDs: ids, |
||||
curIndex: 0, |
||||
fpHook: fpHook, |
||||
} |
||||
} |
||||
|
||||
func (pf *testPeerFinder) FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) { |
||||
if peerLimit > len(pf.peerIDs) { |
||||
peerLimit = len(pf.peerIDs) |
||||
} |
||||
resC := make(chan libp2p_peer.AddrInfo) |
||||
|
||||
go func() { |
||||
defer close(resC) |
||||
|
||||
for i := 0; i != peerLimit; i++ { |
||||
// hack to prevent race
|
||||
curIndex := atomic.LoadInt32(&pf.curIndex) |
||||
pid := pf.peerIDs[curIndex] |
||||
select { |
||||
case <-ctx.Done(): |
||||
return |
||||
case <-pf.fpHook(pid): |
||||
} |
||||
resC <- libp2p_peer.AddrInfo{ID: pid} |
||||
atomic.AddInt32(&pf.curIndex, 1) |
||||
if int(atomic.LoadInt32(&pf.curIndex)) == len(pf.peerIDs) { |
||||
pf.curIndex = 0 |
||||
} |
||||
} |
||||
}() |
||||
|
||||
return resC, nil |
||||
} |
@ -0,0 +1,416 @@ |
||||
package streammanager |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/ethereum/go-ethereum/event" |
||||
"github.com/harmony-one/harmony/internal/utils" |
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
"github.com/libp2p/go-libp2p-core/network" |
||||
libp2p_peer "github.com/libp2p/go-libp2p-core/peer" |
||||
"github.com/libp2p/go-libp2p-core/protocol" |
||||
"github.com/pkg/errors" |
||||
"github.com/rs/zerolog" |
||||
) |
||||
|
||||
var ( |
||||
// ErrStreamAlreadyRemoved is the error that a stream has already been removed
|
||||
ErrStreamAlreadyRemoved = errors.New("stream already removed") |
||||
) |
||||
|
||||
// streamManager is the implementation of StreamManager. It manages streams on
|
||||
// one single protocol. It does the following job:
|
||||
// 1. add a new stream.
|
||||
// 2. closes a stream.
|
||||
// 3. discover and connect new streams when the number of streams is below threshold.
|
||||
// 4. emit stream events to inform other modules.
|
||||
// 5. reset all streams on close.
|
||||
type streamManager struct { |
||||
// streamManager only manages streams on one protocol.
|
||||
myProtoID sttypes.ProtoID |
||||
myProtoSpec sttypes.ProtoSpec |
||||
config Config |
||||
// streams is the map of peer ID to stream
|
||||
// Note that it could happen that remote node does not share exactly the same
|
||||
// protocol ID (e.g. different version)
|
||||
streams *streamSet |
||||
// libp2p utilities
|
||||
host host |
||||
pf peerFinder |
||||
handleStream func(stream network.Stream) |
||||
// incoming task channels
|
||||
addStreamCh chan addStreamTask |
||||
rmStreamCh chan rmStreamTask |
||||
stopCh chan stopTask |
||||
discCh chan discTask |
||||
curTask interface{} |
||||
// utils
|
||||
addStreamFeed event.Feed |
||||
removeStreamFeed event.Feed |
||||
logger zerolog.Logger |
||||
ctx context.Context |
||||
cancel func() |
||||
} |
||||
|
||||
// NewStreamManager creates a new stream manager for the given proto ID
|
||||
func NewStreamManager(pid sttypes.ProtoID, host host, pf peerFinder, handleStream func(network.Stream), c Config) StreamManager { |
||||
return newStreamManager(pid, host, pf, handleStream, c) |
||||
} |
||||
|
||||
// newStreamManager creates a new stream manager
|
||||
func newStreamManager(pid sttypes.ProtoID, host host, pf peerFinder, handleStream func(network.Stream), c Config) *streamManager { |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
|
||||
logger := utils.Logger().With().Str("module", "stream manager"). |
||||
Str("protocol ID", string(pid)).Logger() |
||||
|
||||
protoSpec, _ := sttypes.ProtoIDToProtoSpec(pid) |
||||
|
||||
return &streamManager{ |
||||
myProtoID: pid, |
||||
myProtoSpec: protoSpec, |
||||
config: c, |
||||
streams: newStreamSet(), |
||||
host: host, |
||||
pf: pf, |
||||
handleStream: handleStream, |
||||
addStreamCh: make(chan addStreamTask), |
||||
rmStreamCh: make(chan rmStreamTask), |
||||
stopCh: make(chan stopTask), |
||||
discCh: make(chan discTask, 1), // discCh is a buffered channel to avoid overuse of goroutine
|
||||
|
||||
logger: logger, |
||||
ctx: ctx, |
||||
cancel: cancel, |
||||
} |
||||
} |
||||
|
||||
// Start starts the stream manager
|
||||
func (sm *streamManager) Start() { |
||||
go sm.loop() |
||||
} |
||||
|
||||
// Close close the stream manager
|
||||
func (sm *streamManager) Close() { |
||||
task := stopTask{done: make(chan struct{})} |
||||
sm.stopCh <- task |
||||
|
||||
<-task.done |
||||
} |
||||
|
||||
func (sm *streamManager) loop() { |
||||
var ( |
||||
discTicker = time.NewTicker(checkInterval) |
||||
discCtx context.Context |
||||
discCancel func() |
||||
) |
||||
// bootstrap discovery
|
||||
sm.discCh <- discTask{} |
||||
|
||||
for { |
||||
select { |
||||
case <-discTicker.C: |
||||
if !sm.softHaveEnoughStreams() { |
||||
sm.discCh <- discTask{} |
||||
} |
||||
|
||||
case <-sm.discCh: |
||||
// cancel last discovery
|
||||
if discCancel != nil { |
||||
discCancel() |
||||
} |
||||
discCtx, discCancel = context.WithCancel(sm.ctx) |
||||
go func() { |
||||
err := sm.discoverAndSetupStream(discCtx) |
||||
if err != nil { |
||||
sm.logger.Err(err) |
||||
} |
||||
}() |
||||
|
||||
case addStream := <-sm.addStreamCh: |
||||
err := sm.handleAddStream(addStream.st) |
||||
addStream.errC <- err |
||||
|
||||
case rmStream := <-sm.rmStreamCh: |
||||
err := sm.handleRemoveStream(rmStream.id) |
||||
rmStream.errC <- err |
||||
|
||||
case stop := <-sm.stopCh: |
||||
sm.cancel() |
||||
sm.removeAllStreamOnClose() |
||||
stop.done <- struct{}{} |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// NewStream handles a new stream from stream handler protocol
|
||||
func (sm *streamManager) NewStream(stream sttypes.Stream) error { |
||||
if err := sm.sanityCheckStream(stream); err != nil { |
||||
return errors.Wrap(err, "stream sanity check failed") |
||||
} |
||||
task := addStreamTask{ |
||||
st: stream, |
||||
errC: make(chan error), |
||||
} |
||||
sm.addStreamCh <- task |
||||
return <-task.errC |
||||
} |
||||
|
||||
// RemoveStream close and remove a stream from stream manager
|
||||
func (sm *streamManager) RemoveStream(stID sttypes.StreamID) error { |
||||
task := rmStreamTask{ |
||||
id: stID, |
||||
errC: make(chan error), |
||||
} |
||||
sm.rmStreamCh <- task |
||||
return <-task.errC |
||||
} |
||||
|
||||
// GetStreams return the streams.
|
||||
func (sm *streamManager) GetStreams() []sttypes.Stream { |
||||
return sm.streams.getStreams() |
||||
} |
||||
|
||||
// GetStreamByID return the stream with the given id.
|
||||
func (sm *streamManager) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) { |
||||
return sm.streams.get(id) |
||||
} |
||||
|
||||
type ( |
||||
addStreamTask struct { |
||||
st sttypes.Stream |
||||
errC chan error |
||||
} |
||||
|
||||
rmStreamTask struct { |
||||
id sttypes.StreamID |
||||
errC chan error |
||||
} |
||||
|
||||
discTask struct{} |
||||
|
||||
stopTask struct { |
||||
done chan struct{} |
||||
} |
||||
) |
||||
|
||||
// sanity checks the service, network and shard ID
|
||||
func (sm *streamManager) sanityCheckStream(st sttypes.Stream) error { |
||||
mySpec := sm.myProtoSpec |
||||
rmSpec, err := st.ProtoSpec() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if mySpec.Service != rmSpec.Service { |
||||
return fmt.Errorf("unexpected service: %v/%v", rmSpec.Service, mySpec.Service) |
||||
} |
||||
if mySpec.NetworkType != rmSpec.NetworkType { |
||||
return fmt.Errorf("unexpected network: %v/%v", rmSpec.NetworkType, mySpec.NetworkType) |
||||
} |
||||
if mySpec.ShardID != rmSpec.ShardID { |
||||
return fmt.Errorf("unexpected shard ID: %v/%v", rmSpec.ShardID, mySpec.ShardID) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (sm *streamManager) handleAddStream(st sttypes.Stream) error { |
||||
id := st.ID() |
||||
if sm.streams.size() >= sm.config.HiCap { |
||||
return errors.New("too many streams") |
||||
} |
||||
if _, ok := sm.streams.get(id); ok { |
||||
return errors.New("stream already exist") |
||||
} |
||||
|
||||
sm.streams.addStream(st) |
||||
|
||||
sm.addStreamFeed.Send(EvtStreamAdded{st}) |
||||
return nil |
||||
} |
||||
|
||||
func (sm *streamManager) handleRemoveStream(id sttypes.StreamID) error { |
||||
st, ok := sm.streams.get(id) |
||||
if !ok { |
||||
return ErrStreamAlreadyRemoved |
||||
} |
||||
|
||||
sm.streams.deleteStream(st) |
||||
// if stream number is smaller than HardLoCap, spin up the discover
|
||||
if !sm.hardHaveEnoughStream() { |
||||
select { |
||||
case sm.discCh <- discTask{}: |
||||
default: |
||||
} |
||||
} |
||||
sm.removeStreamFeed.Send(EvtStreamRemoved{id}) |
||||
return nil |
||||
} |
||||
|
||||
func (sm *streamManager) removeAllStreamOnClose() { |
||||
var wg sync.WaitGroup |
||||
|
||||
for _, st := range sm.streams.slice() { |
||||
wg.Add(1) |
||||
go func(st sttypes.Stream) { |
||||
defer wg.Done() |
||||
err := st.ResetOnClose() |
||||
if err != nil { |
||||
sm.logger.Warn().Err(err).Str("stream ID", string(st.ID())). |
||||
Msg("failed to close stream") |
||||
} |
||||
}(st) |
||||
} |
||||
wg.Wait() |
||||
|
||||
// Be nice. after close, the field is still accessible to prevent potential panics
|
||||
sm.streams = newStreamSet() |
||||
} |
||||
|
||||
func (sm *streamManager) discoverAndSetupStream(discCtx context.Context) error { |
||||
peers, err := sm.discover(discCtx) |
||||
if err != nil { |
||||
return errors.Wrap(err, "failed to discover") |
||||
} |
||||
for peer := range peers { |
||||
if peer.ID == sm.host.ID() { |
||||
continue |
||||
} |
||||
go func(pid libp2p_peer.ID) { |
||||
// The ctx here is using the module context instead of discover context
|
||||
err := sm.setupStreamWithPeer(sm.ctx, pid) |
||||
if err != nil { |
||||
sm.logger.Warn().Err(err).Str("peerID", string(pid)).Msg("failed to setup stream with peer") |
||||
return |
||||
} |
||||
}(peer.ID) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (sm *streamManager) discover(ctx context.Context) (<-chan libp2p_peer.AddrInfo, error) { |
||||
protoID := string(sm.myProtoID) |
||||
discBatch := sm.config.DiscBatch |
||||
if sm.config.HiCap-sm.streams.size() < sm.config.DiscBatch { |
||||
discBatch = sm.config.HiCap - sm.streams.size() |
||||
} |
||||
if discBatch < 0 { |
||||
return nil, nil |
||||
} |
||||
|
||||
ctx, _ = context.WithTimeout(ctx, discTimeout) |
||||
return sm.pf.FindPeers(ctx, protoID, discBatch) |
||||
} |
||||
|
||||
func (sm *streamManager) setupStreamWithPeer(ctx context.Context, pid libp2p_peer.ID) error { |
||||
nCtx, cancel := context.WithTimeout(ctx, connectTimeout) |
||||
defer cancel() |
||||
|
||||
st, err := sm.host.NewStream(nCtx, pid, protocol.ID(sm.myProtoID)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if sm.handleStream != nil { |
||||
go sm.handleStream(st) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (sm *streamManager) softHaveEnoughStreams() bool { |
||||
availStreams := sm.streams.numStreamsWithMinProtoSpec(sm.myProtoSpec) |
||||
return availStreams >= sm.config.SoftLoCap |
||||
} |
||||
|
||||
func (sm *streamManager) hardHaveEnoughStream() bool { |
||||
availStreams := sm.streams.numStreamsWithMinProtoSpec(sm.myProtoSpec) |
||||
return availStreams >= sm.config.HardLoCap |
||||
} |
||||
|
||||
// streamSet is the concurrency safe stream set.
|
||||
type streamSet struct { |
||||
streams map[sttypes.StreamID]sttypes.Stream |
||||
numByProto map[sttypes.ProtoSpec]int |
||||
lock sync.RWMutex |
||||
} |
||||
|
||||
func newStreamSet() *streamSet { |
||||
return &streamSet{ |
||||
streams: make(map[sttypes.StreamID]sttypes.Stream), |
||||
numByProto: make(map[sttypes.ProtoSpec]int), |
||||
} |
||||
} |
||||
|
||||
func (ss *streamSet) size() int { |
||||
ss.lock.RLock() |
||||
defer ss.lock.RUnlock() |
||||
|
||||
return len(ss.streams) |
||||
} |
||||
|
||||
func (ss *streamSet) get(id sttypes.StreamID) (sttypes.Stream, bool) { |
||||
ss.lock.RLock() |
||||
defer ss.lock.RUnlock() |
||||
|
||||
st, ok := ss.streams[id] |
||||
return st, ok |
||||
} |
||||
|
||||
func (ss *streamSet) addStream(st sttypes.Stream) { |
||||
ss.lock.Lock() |
||||
defer ss.lock.Unlock() |
||||
|
||||
ss.streams[st.ID()] = st |
||||
spec, _ := st.ProtoSpec() |
||||
ss.numByProto[spec]++ |
||||
} |
||||
|
||||
func (ss *streamSet) deleteStream(st sttypes.Stream) { |
||||
ss.lock.Lock() |
||||
defer ss.lock.Unlock() |
||||
|
||||
delete(ss.streams, st.ID()) |
||||
|
||||
spec, _ := st.ProtoSpec() |
||||
ss.numByProto[spec]-- |
||||
if ss.numByProto[spec] == 0 { |
||||
delete(ss.numByProto, spec) |
||||
} |
||||
} |
||||
|
||||
func (ss *streamSet) slice() []sttypes.Stream { |
||||
ss.lock.RLock() |
||||
defer ss.lock.RUnlock() |
||||
|
||||
sts := make([]sttypes.Stream, 0, len(ss.streams)) |
||||
for _, st := range ss.streams { |
||||
sts = append(sts, st) |
||||
} |
||||
return sts |
||||
} |
||||
|
||||
func (ss *streamSet) getStreams() []sttypes.Stream { |
||||
ss.lock.RLock() |
||||
defer ss.lock.RUnlock() |
||||
|
||||
res := make([]sttypes.Stream, 0, len(ss.streams)) |
||||
for _, st := range ss.streams { |
||||
res = append(res, st) |
||||
} |
||||
return res |
||||
} |
||||
|
||||
func (ss *streamSet) numStreamsWithMinProtoSpec(minSpec sttypes.ProtoSpec) int { |
||||
ss.lock.RLock() |
||||
defer ss.lock.RUnlock() |
||||
|
||||
var res int |
||||
for spec, num := range ss.numByProto { |
||||
if !spec.Version.LessThan(minSpec.Version) { |
||||
res += num |
||||
} |
||||
} |
||||
return res |
||||
} |
@ -0,0 +1,243 @@ |
||||
package streammanager |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"strings" |
||||
"sync" |
||||
"testing" |
||||
"time" |
||||
|
||||
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||
libp2p_peer "github.com/libp2p/go-libp2p-core/peer" |
||||
) |
||||
|
||||
const ( |
||||
defTestWait = 100 * time.Millisecond |
||||
) |
||||
|
||||
// When started, discover will be run at bootstrap
|
||||
func TestStreamManager_BootstrapDisc(t *testing.T) { |
||||
sm := newTestStreamManager() |
||||
sm.host.(*testHost).errHook = func(id sttypes.StreamID, err error) { |
||||
t.Errorf("%s stream error: %v", id, err) |
||||
} |
||||
|
||||
// After bootstrap, stream manager shall discover streams and setup connection
|
||||
// Note host will mock the upper code logic to call sm.NewStream in this case
|
||||
sm.Start() |
||||
time.Sleep(defTestWait) |
||||
if gotSize := sm.streams.size(); gotSize != sm.config.DiscBatch { |
||||
t.Errorf("unexpected stream size: %v != %v", gotSize, sm.config.DiscBatch) |
||||
} |
||||
} |
||||
|
||||
// After close, all stream shall be closed and removed
|
||||
func TestStreamManager_Close(t *testing.T) { |
||||
sm := newTestStreamManager() |
||||
// Bootstrap
|
||||
sm.Start() |
||||
time.Sleep(defTestWait) |
||||
// Close stream manager, all stream shall be closed and removed
|
||||
closeDone := make(chan struct{}) |
||||
go func() { |
||||
sm.Close() |
||||
closeDone <- struct{}{} |
||||
}() |
||||
select { |
||||
case <-time.After(defTestWait): |
||||
t.Errorf("still not closed") |
||||
case <-closeDone: |
||||
} |
||||
// Check stream been removed from stream manager and all streams to be closed
|
||||
if sm.streams.size() != 0 { |
||||
t.Errorf("after close, stream not removed from stream manager") |
||||
} |
||||
host := sm.host.(*testHost) |
||||
for _, st := range host.streams { |
||||
if !st.closed { |
||||
t.Errorf("after close, stream still not closed") |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Close shall terminate the current discover at once
|
||||
func TestStreamManager_CloseDisc(t *testing.T) { |
||||
sm := newTestStreamManager() |
||||
// discover will be blocked forever
|
||||
sm.pf.(*testPeerFinder).fpHook = func(id libp2p_peer.ID) <-chan struct{} { |
||||
select {} |
||||
} |
||||
sm.Start() |
||||
time.Sleep(defTestWait) |
||||
// Close stream manager, all stream shall be closed and removed
|
||||
closeDone := make(chan struct{}) |
||||
go func() { |
||||
sm.Close() |
||||
closeDone <- struct{}{} |
||||
}() |
||||
select { |
||||
case <-time.After(defTestWait): |
||||
t.Errorf("close shall unblock the current discovery") |
||||
case <-closeDone: |
||||
} |
||||
} |
||||
|
||||
// Each time discTicker ticks, it will cancel last discovery, and start a new one
|
||||
func TestStreamManager_refreshDisc(t *testing.T) { |
||||
sm := newTestStreamManager() |
||||
// discover will be blocked for the first time but good for second time
|
||||
var once sync.Once |
||||
sm.pf.(*testPeerFinder).fpHook = func(id libp2p_peer.ID) <-chan struct{} { |
||||
var sendSig = true |
||||
once.Do(func() { |
||||
sendSig = false |
||||
}) |
||||
c := make(chan struct{}, 1) |
||||
if sendSig { |
||||
c <- struct{}{} |
||||
} |
||||
return c |
||||
} |
||||
sm.Start() |
||||
time.Sleep(defTestWait) |
||||
|
||||
sm.discCh <- struct{}{} |
||||
time.Sleep(defTestWait) |
||||
|
||||
// We shall now have non-zero streams setup
|
||||
if sm.streams.size() == 0 { |
||||
t.Errorf("stream size still zero after refresh") |
||||
} |
||||
} |
||||
|
||||
func TestStreamManager_HandleNewStream(t *testing.T) { |
||||
tests := []struct { |
||||
stream sttypes.Stream |
||||
expSize int |
||||
expErr error |
||||
}{ |
||||
{ |
||||
stream: newTestStream(makeStreamID(100), testProtoID), |
||||
expSize: defDiscBatch + 1, |
||||
expErr: nil, |
||||
}, |
||||
{ |
||||
stream: newTestStream(makeStreamID(1), testProtoID), |
||||
expSize: defDiscBatch, |
||||
expErr: errors.New("stream already exist"), |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
sm := newTestStreamManager() |
||||
sm.Start() |
||||
time.Sleep(defTestWait) |
||||
|
||||
err := sm.NewStream(test.stream) |
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
} |
||||
|
||||
if sm.streams.size() != test.expSize { |
||||
t.Errorf("Test %v: unexpected stream size: %v / %v", i, sm.streams.size(), |
||||
test.expSize) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestStreamManager_HandleRemoveStream(t *testing.T) { |
||||
tests := []struct { |
||||
id sttypes.StreamID |
||||
expSize int |
||||
expErr error |
||||
}{ |
||||
{ |
||||
id: makeStreamID(1), |
||||
expSize: defDiscBatch - 1, |
||||
expErr: nil, |
||||
}, |
||||
{ |
||||
id: makeStreamID(100), |
||||
expSize: defDiscBatch, |
||||
expErr: errors.New("stream already removed"), |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
sm := newTestStreamManager() |
||||
sm.Start() |
||||
time.Sleep(defTestWait) |
||||
|
||||
err := sm.RemoveStream(test.id) |
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
} |
||||
|
||||
if sm.streams.size() != test.expSize { |
||||
t.Errorf("Test %v: unexpected stream size: %v / %v", i, sm.streams.size(), |
||||
test.expSize) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// When number of streams is smaller than hard low limit, discover will be triggered
|
||||
func TestStreamManager_HandleRemoveStream_Disc(t *testing.T) { |
||||
sm := newTestStreamManager() |
||||
sm.Start() |
||||
time.Sleep(defTestWait) |
||||
|
||||
// Remove DiscBatch - HardLoCap + 1 streams
|
||||
num := 0 |
||||
for _, st := range sm.streams.slice() { |
||||
if err := sm.RemoveStream(st.ID()); err != nil { |
||||
t.Error(err) |
||||
} |
||||
num++ |
||||
if num == sm.config.DiscBatch-sm.config.HardLoCap+1 { |
||||
break |
||||
} |
||||
} |
||||
|
||||
// Last remove stream will also trigger discover
|
||||
time.Sleep(defTestWait) |
||||
if sm.streams.size() != sm.config.HardLoCap+sm.config.DiscBatch-1 { |
||||
t.Errorf("unexpected stream number %v / %v", sm.streams.size(), sm.config.HardLoCap+sm.config.DiscBatch-1) |
||||
} |
||||
} |
||||
|
||||
func TestStreamSet_numStreamsWithMinProtoID(t *testing.T) { |
||||
var ( |
||||
pid1 = testProtoID |
||||
numPid1 = 5 |
||||
|
||||
pid2 = sttypes.ProtoID("harmony/sync/unitest/0/1.0.1") |
||||
numPid2 = 10 |
||||
) |
||||
|
||||
ss := newStreamSet() |
||||
|
||||
for i := 0; i != numPid1; i++ { |
||||
ss.addStream(newTestStream(makeStreamID(i), pid1)) |
||||
} |
||||
for i := 0; i != numPid2; i++ { |
||||
ss.addStream(newTestStream(makeStreamID(i), pid2)) |
||||
} |
||||
|
||||
minSpec, _ := sttypes.ProtoIDToProtoSpec(pid2) |
||||
num := ss.numStreamsWithMinProtoSpec(minSpec) |
||||
if num != numPid2 { |
||||
t.Errorf("unexpected result: %v/%v", num, numPid2) |
||||
} |
||||
} |
||||
|
||||
func assertError(got, exp error) error { |
||||
if (got == nil) != (exp == nil) { |
||||
return fmt.Errorf("unexpected error: %v / %v", got, exp) |
||||
} |
||||
if got == nil { |
||||
return nil |
||||
} |
||||
if !strings.Contains(got.Error(), exp.Error()) { |
||||
return fmt.Errorf("unexpected error: %v / %v", got, exp) |
||||
} |
||||
return nil |
||||
} |
Loading…
Reference in new issue