[stream] added stream manager module (#3554)
* [stream] added a length bytes to the start of p2p base stream * [stream] added stream manager module Co-authored-by: Rongjian Lan <rongjian.lan@gmail.com>pull/3565/head
parent
e12698cf44
commit
8adce8925a
@ -0,0 +1,25 @@ |
|||||||
|
package streammanager |
||||||
|
|
||||||
|
import "time" |
||||||
|
|
||||||
|
const ( |
||||||
|
// checkInterval is the default interval for checking stream number. If the stream
|
||||||
|
// number is smaller than softLoCap, an active discover through DHT will be triggered.
|
||||||
|
checkInterval = 30 * time.Second |
||||||
|
// discTimeout is the timeout for one batch of discovery
|
||||||
|
discTimeout = 10 * time.Second |
||||||
|
// connectTimeout is the timeout for setting up a stream with a discovered peer
|
||||||
|
connectTimeout = 60 * time.Second |
||||||
|
) |
||||||
|
|
||||||
|
// Config is the config for stream manager
|
||||||
|
type Config struct { |
||||||
|
// HardLoCap is low cap of stream number that immediately trigger discovery
|
||||||
|
HardLoCap int |
||||||
|
// SoftLoCap is low cap of stream number that will trigger discovery during stream check
|
||||||
|
SoftLoCap int |
||||||
|
// HiCap is the high cap of stream number
|
||||||
|
HiCap int |
||||||
|
// DiscBatch is the size of each discovery
|
||||||
|
DiscBatch int |
||||||
|
} |
@ -0,0 +1,28 @@ |
|||||||
|
package streammanager |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/ethereum/go-ethereum/event" |
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
) |
||||||
|
|
||||||
|
// EvtStreamAdded is the event of adding a new stream
|
||||||
|
type ( |
||||||
|
EvtStreamAdded struct { |
||||||
|
Stream sttypes.Stream |
||||||
|
} |
||||||
|
|
||||||
|
// EvtStreamRemoved is an event of stream removed
|
||||||
|
EvtStreamRemoved struct { |
||||||
|
ID sttypes.StreamID |
||||||
|
} |
||||||
|
) |
||||||
|
|
||||||
|
// SubscribeAddStreamEvent subscribe the add stream event
|
||||||
|
func (sm *streamManager) SubscribeAddStreamEvent(ch chan<- EvtStreamAdded) event.Subscription { |
||||||
|
return sm.addStreamFeed.Subscribe(ch) |
||||||
|
} |
||||||
|
|
||||||
|
// SubscribeRemoveStreamEvent subscribe the remove stream event
|
||||||
|
func (sm *streamManager) SubscribeRemoveStreamEvent(ch chan<- EvtStreamRemoved) event.Subscription { |
||||||
|
return sm.removeStreamFeed.Subscribe(ch) |
||||||
|
} |
@ -0,0 +1,73 @@ |
|||||||
|
package streammanager |
||||||
|
|
||||||
|
import ( |
||||||
|
"sync/atomic" |
||||||
|
"testing" |
||||||
|
"time" |
||||||
|
) |
||||||
|
|
||||||
|
func TestStreamManager_SubscribeAddStreamEvent(t *testing.T) { |
||||||
|
sm := newTestStreamManager() |
||||||
|
|
||||||
|
addStreamEvtC := make(chan EvtStreamAdded, 1) |
||||||
|
sub := sm.SubscribeAddStreamEvent(addStreamEvtC) |
||||||
|
defer sub.Unsubscribe() |
||||||
|
stopC := make(chan struct{}, 1) |
||||||
|
|
||||||
|
var numStreamAdded uint32 |
||||||
|
go func() { |
||||||
|
for { |
||||||
|
select { |
||||||
|
case <-addStreamEvtC: |
||||||
|
atomic.AddUint32(&numStreamAdded, 1) |
||||||
|
case <-stopC: |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
sm.Start() |
||||||
|
time.Sleep(defTestWait) |
||||||
|
close(stopC) |
||||||
|
sm.Close() |
||||||
|
|
||||||
|
if atomic.LoadUint32(&numStreamAdded) != 16 { |
||||||
|
t.Errorf("numStreamAdded unexpected") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestStreamManager_SubscribeRemoveStreamEvent(t *testing.T) { |
||||||
|
sm := newTestStreamManager() |
||||||
|
|
||||||
|
rmStreamEvtC := make(chan EvtStreamRemoved, 1) |
||||||
|
sub := sm.SubscribeRemoveStreamEvent(rmStreamEvtC) |
||||||
|
defer sub.Unsubscribe() |
||||||
|
stopC := make(chan struct{}, 1) |
||||||
|
|
||||||
|
var numStreamRemoved uint32 |
||||||
|
go func() { |
||||||
|
for { |
||||||
|
select { |
||||||
|
case <-rmStreamEvtC: |
||||||
|
atomic.AddUint32(&numStreamRemoved, 1) |
||||||
|
case <-stopC: |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
sm.Start() |
||||||
|
time.Sleep(defTestWait) |
||||||
|
|
||||||
|
err := sm.RemoveStream(makeStreamID(1)) |
||||||
|
if err != nil { |
||||||
|
t.Fatal(err) |
||||||
|
} |
||||||
|
time.Sleep(defTestWait) |
||||||
|
close(stopC) |
||||||
|
sm.Close() |
||||||
|
|
||||||
|
if atomic.LoadUint32(&numStreamRemoved) != 1 { |
||||||
|
t.Errorf("numStreamAdded unexpected") |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,50 @@ |
|||||||
|
package streammanager |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/event" |
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
p2ptypes "github.com/harmony-one/harmony/p2p/types" |
||||||
|
"github.com/libp2p/go-libp2p-core/network" |
||||||
|
libp2p_peer "github.com/libp2p/go-libp2p-core/peer" |
||||||
|
"github.com/libp2p/go-libp2p-core/protocol" |
||||||
|
) |
||||||
|
|
||||||
|
// StreamManager is the interface for streamManager
|
||||||
|
type StreamManager interface { |
||||||
|
p2ptypes.LifeCycle |
||||||
|
StreamOperator |
||||||
|
Subscriber |
||||||
|
StreamReader |
||||||
|
} |
||||||
|
|
||||||
|
// StreamOperator handles new stream or remove stream
|
||||||
|
type StreamOperator interface { |
||||||
|
NewStream(stream sttypes.Stream) error |
||||||
|
RemoveStream(stID sttypes.StreamID) error |
||||||
|
} |
||||||
|
|
||||||
|
// Subscriber is the interface to support stream event subscription
|
||||||
|
type Subscriber interface { |
||||||
|
SubscribeAddStreamEvent(ch chan<- EvtStreamAdded) event.Subscription |
||||||
|
SubscribeRemoveStreamEvent(ch chan<- EvtStreamRemoved) event.Subscription |
||||||
|
} |
||||||
|
|
||||||
|
// StreamReader is the interface to read stream in stream manager
|
||||||
|
type StreamReader interface { |
||||||
|
GetStreams() []sttypes.Stream |
||||||
|
GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) |
||||||
|
} |
||||||
|
|
||||||
|
// host is the adapter interface of the libp2p host implementation.
|
||||||
|
// TODO: further adapt the host
|
||||||
|
type host interface { |
||||||
|
ID() libp2p_peer.ID |
||||||
|
NewStream(ctx context.Context, p libp2p_peer.ID, pids ...protocol.ID) (network.Stream, error) |
||||||
|
} |
||||||
|
|
||||||
|
// peerFinder is the adapter interface of discovery.Discovery
|
||||||
|
type peerFinder interface { |
||||||
|
FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) |
||||||
|
} |
@ -0,0 +1,203 @@ |
|||||||
|
package streammanager |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"errors" |
||||||
|
"strconv" |
||||||
|
"sync" |
||||||
|
"sync/atomic" |
||||||
|
|
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
"github.com/libp2p/go-libp2p-core/network" |
||||||
|
libp2p_peer "github.com/libp2p/go-libp2p-core/peer" |
||||||
|
"github.com/libp2p/go-libp2p-core/protocol" |
||||||
|
) |
||||||
|
|
||||||
|
var _ StreamManager = &streamManager{} |
||||||
|
|
||||||
|
var ( |
||||||
|
myPeerID = makePeerID(0) |
||||||
|
testProtoID = sttypes.ProtoID("harmony/sync/unitest/0/1.0.0") |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
defHardLoCap = 16 // discovery trigger immediately when size smaller than this number
|
||||||
|
defSoftLoCap = 32 // discovery trigger for routine check
|
||||||
|
defHiCap = 128 // Hard cap of the stream number
|
||||||
|
defDiscBatch = 16 // batch size for discovery
|
||||||
|
) |
||||||
|
|
||||||
|
var defConfig = Config{ |
||||||
|
HardLoCap: defHardLoCap, |
||||||
|
SoftLoCap: defSoftLoCap, |
||||||
|
HiCap: defHiCap, |
||||||
|
DiscBatch: defDiscBatch, |
||||||
|
} |
||||||
|
|
||||||
|
func newTestStreamManager() *streamManager { |
||||||
|
pid := testProtoID |
||||||
|
host := newTestHost() |
||||||
|
pf := newTestPeerFinder(makeRemotePeers(100), emptyDelayFunc) |
||||||
|
|
||||||
|
sm := newStreamManager(pid, host, pf, nil, defConfig) |
||||||
|
host.sm = sm |
||||||
|
return sm |
||||||
|
} |
||||||
|
|
||||||
|
type testStream struct { |
||||||
|
id sttypes.StreamID |
||||||
|
proto sttypes.ProtoID |
||||||
|
closed bool |
||||||
|
} |
||||||
|
|
||||||
|
func newTestStream(id sttypes.StreamID, proto sttypes.ProtoID) *testStream { |
||||||
|
return &testStream{id: id, proto: proto} |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) ID() sttypes.StreamID { |
||||||
|
return st.id |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) ProtoID() sttypes.ProtoID { |
||||||
|
return st.proto |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) WriteBytes([]byte) error { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) ReadBytes() ([]byte, error) { |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) Close() error { |
||||||
|
if st.closed { |
||||||
|
return errors.New("already closed") |
||||||
|
} |
||||||
|
st.closed = true |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) ResetOnClose() error { |
||||||
|
if st.closed { |
||||||
|
return errors.New("already closed") |
||||||
|
} |
||||||
|
st.closed = true |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (st *testStream) ProtoSpec() (sttypes.ProtoSpec, error) { |
||||||
|
return sttypes.ProtoIDToProtoSpec(st.ProtoID()) |
||||||
|
} |
||||||
|
|
||||||
|
type testHost struct { |
||||||
|
sm *streamManager |
||||||
|
streams map[sttypes.StreamID]*testStream |
||||||
|
lock sync.Mutex |
||||||
|
|
||||||
|
errHook streamErrorHook |
||||||
|
} |
||||||
|
|
||||||
|
type streamErrorHook func(id sttypes.StreamID, err error) |
||||||
|
|
||||||
|
func newTestHost() *testHost { |
||||||
|
return &testHost{ |
||||||
|
streams: make(map[sttypes.StreamID]*testStream), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (h *testHost) ID() libp2p_peer.ID { |
||||||
|
return myPeerID |
||||||
|
} |
||||||
|
|
||||||
|
// NewStream mock the upper function logic. When stream setup and running protocol, the
|
||||||
|
// upper code logic will call StreamManager to add new stream
|
||||||
|
func (h *testHost) NewStream(ctx context.Context, p libp2p_peer.ID, pids ...protocol.ID) (network.Stream, error) { |
||||||
|
if len(pids) == 0 { |
||||||
|
return nil, errors.New("nil protocol ids") |
||||||
|
} |
||||||
|
var err error |
||||||
|
stid := sttypes.StreamID(p) |
||||||
|
defer func() { |
||||||
|
if err != nil && h.errHook != nil { |
||||||
|
h.errHook(stid, err) |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
st := newTestStream(stid, sttypes.ProtoID(pids[0])) |
||||||
|
h.lock.Lock() |
||||||
|
h.streams[stid] = st |
||||||
|
h.lock.Unlock() |
||||||
|
|
||||||
|
err = h.sm.NewStream(st) |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
func makeStreamID(index int) sttypes.StreamID { |
||||||
|
return sttypes.StreamID(strconv.Itoa(index)) |
||||||
|
} |
||||||
|
|
||||||
|
func makePeerID(index int) libp2p_peer.ID { |
||||||
|
return libp2p_peer.ID(strconv.Itoa(index)) |
||||||
|
} |
||||||
|
|
||||||
|
func makeRemotePeers(size int) []libp2p_peer.ID { |
||||||
|
ids := make([]libp2p_peer.ID, 0, size) |
||||||
|
for i := 1; i != size+1; i++ { |
||||||
|
ids = append(ids, makePeerID(i)) |
||||||
|
} |
||||||
|
return ids |
||||||
|
} |
||||||
|
|
||||||
|
type testPeerFinder struct { |
||||||
|
peerIDs []libp2p_peer.ID |
||||||
|
curIndex int32 |
||||||
|
fpHook delayFunc |
||||||
|
} |
||||||
|
|
||||||
|
type delayFunc func(id libp2p_peer.ID) <-chan struct{} |
||||||
|
|
||||||
|
func emptyDelayFunc(id libp2p_peer.ID) <-chan struct{} { |
||||||
|
c := make(chan struct{}) |
||||||
|
go func() { |
||||||
|
c <- struct{}{} |
||||||
|
}() |
||||||
|
return c |
||||||
|
} |
||||||
|
|
||||||
|
func newTestPeerFinder(ids []libp2p_peer.ID, fpHook delayFunc) *testPeerFinder { |
||||||
|
return &testPeerFinder{ |
||||||
|
peerIDs: ids, |
||||||
|
curIndex: 0, |
||||||
|
fpHook: fpHook, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (pf *testPeerFinder) FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) { |
||||||
|
if peerLimit > len(pf.peerIDs) { |
||||||
|
peerLimit = len(pf.peerIDs) |
||||||
|
} |
||||||
|
resC := make(chan libp2p_peer.AddrInfo) |
||||||
|
|
||||||
|
go func() { |
||||||
|
defer close(resC) |
||||||
|
|
||||||
|
for i := 0; i != peerLimit; i++ { |
||||||
|
// hack to prevent race
|
||||||
|
curIndex := atomic.LoadInt32(&pf.curIndex) |
||||||
|
pid := pf.peerIDs[curIndex] |
||||||
|
select { |
||||||
|
case <-ctx.Done(): |
||||||
|
return |
||||||
|
case <-pf.fpHook(pid): |
||||||
|
} |
||||||
|
resC <- libp2p_peer.AddrInfo{ID: pid} |
||||||
|
atomic.AddInt32(&pf.curIndex, 1) |
||||||
|
if int(atomic.LoadInt32(&pf.curIndex)) == len(pf.peerIDs) { |
||||||
|
pf.curIndex = 0 |
||||||
|
} |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
return resC, nil |
||||||
|
} |
@ -0,0 +1,416 @@ |
|||||||
|
package streammanager |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/event" |
||||||
|
"github.com/harmony-one/harmony/internal/utils" |
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
"github.com/libp2p/go-libp2p-core/network" |
||||||
|
libp2p_peer "github.com/libp2p/go-libp2p-core/peer" |
||||||
|
"github.com/libp2p/go-libp2p-core/protocol" |
||||||
|
"github.com/pkg/errors" |
||||||
|
"github.com/rs/zerolog" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
// ErrStreamAlreadyRemoved is the error that a stream has already been removed
|
||||||
|
ErrStreamAlreadyRemoved = errors.New("stream already removed") |
||||||
|
) |
||||||
|
|
||||||
|
// streamManager is the implementation of StreamManager. It manages streams on
|
||||||
|
// one single protocol. It does the following job:
|
||||||
|
// 1. add a new stream.
|
||||||
|
// 2. closes a stream.
|
||||||
|
// 3. discover and connect new streams when the number of streams is below threshold.
|
||||||
|
// 4. emit stream events to inform other modules.
|
||||||
|
// 5. reset all streams on close.
|
||||||
|
type streamManager struct { |
||||||
|
// streamManager only manages streams on one protocol.
|
||||||
|
myProtoID sttypes.ProtoID |
||||||
|
myProtoSpec sttypes.ProtoSpec |
||||||
|
config Config |
||||||
|
// streams is the map of peer ID to stream
|
||||||
|
// Note that it could happen that remote node does not share exactly the same
|
||||||
|
// protocol ID (e.g. different version)
|
||||||
|
streams *streamSet |
||||||
|
// libp2p utilities
|
||||||
|
host host |
||||||
|
pf peerFinder |
||||||
|
handleStream func(stream network.Stream) |
||||||
|
// incoming task channels
|
||||||
|
addStreamCh chan addStreamTask |
||||||
|
rmStreamCh chan rmStreamTask |
||||||
|
stopCh chan stopTask |
||||||
|
discCh chan discTask |
||||||
|
curTask interface{} |
||||||
|
// utils
|
||||||
|
addStreamFeed event.Feed |
||||||
|
removeStreamFeed event.Feed |
||||||
|
logger zerolog.Logger |
||||||
|
ctx context.Context |
||||||
|
cancel func() |
||||||
|
} |
||||||
|
|
||||||
|
// NewStreamManager creates a new stream manager for the given proto ID
|
||||||
|
func NewStreamManager(pid sttypes.ProtoID, host host, pf peerFinder, handleStream func(network.Stream), c Config) StreamManager { |
||||||
|
return newStreamManager(pid, host, pf, handleStream, c) |
||||||
|
} |
||||||
|
|
||||||
|
// newStreamManager creates a new stream manager
|
||||||
|
func newStreamManager(pid sttypes.ProtoID, host host, pf peerFinder, handleStream func(network.Stream), c Config) *streamManager { |
||||||
|
ctx, cancel := context.WithCancel(context.Background()) |
||||||
|
|
||||||
|
logger := utils.Logger().With().Str("module", "stream manager"). |
||||||
|
Str("protocol ID", string(pid)).Logger() |
||||||
|
|
||||||
|
protoSpec, _ := sttypes.ProtoIDToProtoSpec(pid) |
||||||
|
|
||||||
|
return &streamManager{ |
||||||
|
myProtoID: pid, |
||||||
|
myProtoSpec: protoSpec, |
||||||
|
config: c, |
||||||
|
streams: newStreamSet(), |
||||||
|
host: host, |
||||||
|
pf: pf, |
||||||
|
handleStream: handleStream, |
||||||
|
addStreamCh: make(chan addStreamTask), |
||||||
|
rmStreamCh: make(chan rmStreamTask), |
||||||
|
stopCh: make(chan stopTask), |
||||||
|
discCh: make(chan discTask, 1), // discCh is a buffered channel to avoid overuse of goroutine
|
||||||
|
|
||||||
|
logger: logger, |
||||||
|
ctx: ctx, |
||||||
|
cancel: cancel, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Start starts the stream manager
|
||||||
|
func (sm *streamManager) Start() { |
||||||
|
go sm.loop() |
||||||
|
} |
||||||
|
|
||||||
|
// Close close the stream manager
|
||||||
|
func (sm *streamManager) Close() { |
||||||
|
task := stopTask{done: make(chan struct{})} |
||||||
|
sm.stopCh <- task |
||||||
|
|
||||||
|
<-task.done |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *streamManager) loop() { |
||||||
|
var ( |
||||||
|
discTicker = time.NewTicker(checkInterval) |
||||||
|
discCtx context.Context |
||||||
|
discCancel func() |
||||||
|
) |
||||||
|
// bootstrap discovery
|
||||||
|
sm.discCh <- discTask{} |
||||||
|
|
||||||
|
for { |
||||||
|
select { |
||||||
|
case <-discTicker.C: |
||||||
|
if !sm.softHaveEnoughStreams() { |
||||||
|
sm.discCh <- discTask{} |
||||||
|
} |
||||||
|
|
||||||
|
case <-sm.discCh: |
||||||
|
// cancel last discovery
|
||||||
|
if discCancel != nil { |
||||||
|
discCancel() |
||||||
|
} |
||||||
|
discCtx, discCancel = context.WithCancel(sm.ctx) |
||||||
|
go func() { |
||||||
|
err := sm.discoverAndSetupStream(discCtx) |
||||||
|
if err != nil { |
||||||
|
sm.logger.Err(err) |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
case addStream := <-sm.addStreamCh: |
||||||
|
err := sm.handleAddStream(addStream.st) |
||||||
|
addStream.errC <- err |
||||||
|
|
||||||
|
case rmStream := <-sm.rmStreamCh: |
||||||
|
err := sm.handleRemoveStream(rmStream.id) |
||||||
|
rmStream.errC <- err |
||||||
|
|
||||||
|
case stop := <-sm.stopCh: |
||||||
|
sm.cancel() |
||||||
|
sm.removeAllStreamOnClose() |
||||||
|
stop.done <- struct{}{} |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// NewStream handles a new stream from stream handler protocol
|
||||||
|
func (sm *streamManager) NewStream(stream sttypes.Stream) error { |
||||||
|
if err := sm.sanityCheckStream(stream); err != nil { |
||||||
|
return errors.Wrap(err, "stream sanity check failed") |
||||||
|
} |
||||||
|
task := addStreamTask{ |
||||||
|
st: stream, |
||||||
|
errC: make(chan error), |
||||||
|
} |
||||||
|
sm.addStreamCh <- task |
||||||
|
return <-task.errC |
||||||
|
} |
||||||
|
|
||||||
|
// RemoveStream close and remove a stream from stream manager
|
||||||
|
func (sm *streamManager) RemoveStream(stID sttypes.StreamID) error { |
||||||
|
task := rmStreamTask{ |
||||||
|
id: stID, |
||||||
|
errC: make(chan error), |
||||||
|
} |
||||||
|
sm.rmStreamCh <- task |
||||||
|
return <-task.errC |
||||||
|
} |
||||||
|
|
||||||
|
// GetStreams return the streams.
|
||||||
|
func (sm *streamManager) GetStreams() []sttypes.Stream { |
||||||
|
return sm.streams.getStreams() |
||||||
|
} |
||||||
|
|
||||||
|
// GetStreamByID return the stream with the given id.
|
||||||
|
func (sm *streamManager) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) { |
||||||
|
return sm.streams.get(id) |
||||||
|
} |
||||||
|
|
||||||
|
type ( |
||||||
|
addStreamTask struct { |
||||||
|
st sttypes.Stream |
||||||
|
errC chan error |
||||||
|
} |
||||||
|
|
||||||
|
rmStreamTask struct { |
||||||
|
id sttypes.StreamID |
||||||
|
errC chan error |
||||||
|
} |
||||||
|
|
||||||
|
discTask struct{} |
||||||
|
|
||||||
|
stopTask struct { |
||||||
|
done chan struct{} |
||||||
|
} |
||||||
|
) |
||||||
|
|
||||||
|
// sanity checks the service, network and shard ID
|
||||||
|
func (sm *streamManager) sanityCheckStream(st sttypes.Stream) error { |
||||||
|
mySpec := sm.myProtoSpec |
||||||
|
rmSpec, err := st.ProtoSpec() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if mySpec.Service != rmSpec.Service { |
||||||
|
return fmt.Errorf("unexpected service: %v/%v", rmSpec.Service, mySpec.Service) |
||||||
|
} |
||||||
|
if mySpec.NetworkType != rmSpec.NetworkType { |
||||||
|
return fmt.Errorf("unexpected network: %v/%v", rmSpec.NetworkType, mySpec.NetworkType) |
||||||
|
} |
||||||
|
if mySpec.ShardID != rmSpec.ShardID { |
||||||
|
return fmt.Errorf("unexpected shard ID: %v/%v", rmSpec.ShardID, mySpec.ShardID) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *streamManager) handleAddStream(st sttypes.Stream) error { |
||||||
|
id := st.ID() |
||||||
|
if sm.streams.size() >= sm.config.HiCap { |
||||||
|
return errors.New("too many streams") |
||||||
|
} |
||||||
|
if _, ok := sm.streams.get(id); ok { |
||||||
|
return errors.New("stream already exist") |
||||||
|
} |
||||||
|
|
||||||
|
sm.streams.addStream(st) |
||||||
|
|
||||||
|
sm.addStreamFeed.Send(EvtStreamAdded{st}) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *streamManager) handleRemoveStream(id sttypes.StreamID) error { |
||||||
|
st, ok := sm.streams.get(id) |
||||||
|
if !ok { |
||||||
|
return ErrStreamAlreadyRemoved |
||||||
|
} |
||||||
|
|
||||||
|
sm.streams.deleteStream(st) |
||||||
|
// if stream number is smaller than HardLoCap, spin up the discover
|
||||||
|
if !sm.hardHaveEnoughStream() { |
||||||
|
select { |
||||||
|
case sm.discCh <- discTask{}: |
||||||
|
default: |
||||||
|
} |
||||||
|
} |
||||||
|
sm.removeStreamFeed.Send(EvtStreamRemoved{id}) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *streamManager) removeAllStreamOnClose() { |
||||||
|
var wg sync.WaitGroup |
||||||
|
|
||||||
|
for _, st := range sm.streams.slice() { |
||||||
|
wg.Add(1) |
||||||
|
go func(st sttypes.Stream) { |
||||||
|
defer wg.Done() |
||||||
|
err := st.ResetOnClose() |
||||||
|
if err != nil { |
||||||
|
sm.logger.Warn().Err(err).Str("stream ID", string(st.ID())). |
||||||
|
Msg("failed to close stream") |
||||||
|
} |
||||||
|
}(st) |
||||||
|
} |
||||||
|
wg.Wait() |
||||||
|
|
||||||
|
// Be nice. after close, the field is still accessible to prevent potential panics
|
||||||
|
sm.streams = newStreamSet() |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *streamManager) discoverAndSetupStream(discCtx context.Context) error { |
||||||
|
peers, err := sm.discover(discCtx) |
||||||
|
if err != nil { |
||||||
|
return errors.Wrap(err, "failed to discover") |
||||||
|
} |
||||||
|
for peer := range peers { |
||||||
|
if peer.ID == sm.host.ID() { |
||||||
|
continue |
||||||
|
} |
||||||
|
go func(pid libp2p_peer.ID) { |
||||||
|
// The ctx here is using the module context instead of discover context
|
||||||
|
err := sm.setupStreamWithPeer(sm.ctx, pid) |
||||||
|
if err != nil { |
||||||
|
sm.logger.Warn().Err(err).Str("peerID", string(pid)).Msg("failed to setup stream with peer") |
||||||
|
return |
||||||
|
} |
||||||
|
}(peer.ID) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *streamManager) discover(ctx context.Context) (<-chan libp2p_peer.AddrInfo, error) { |
||||||
|
protoID := string(sm.myProtoID) |
||||||
|
discBatch := sm.config.DiscBatch |
||||||
|
if sm.config.HiCap-sm.streams.size() < sm.config.DiscBatch { |
||||||
|
discBatch = sm.config.HiCap - sm.streams.size() |
||||||
|
} |
||||||
|
if discBatch < 0 { |
||||||
|
return nil, nil |
||||||
|
} |
||||||
|
|
||||||
|
ctx, _ = context.WithTimeout(ctx, discTimeout) |
||||||
|
return sm.pf.FindPeers(ctx, protoID, discBatch) |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *streamManager) setupStreamWithPeer(ctx context.Context, pid libp2p_peer.ID) error { |
||||||
|
nCtx, cancel := context.WithTimeout(ctx, connectTimeout) |
||||||
|
defer cancel() |
||||||
|
|
||||||
|
st, err := sm.host.NewStream(nCtx, pid, protocol.ID(sm.myProtoID)) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if sm.handleStream != nil { |
||||||
|
go sm.handleStream(st) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *streamManager) softHaveEnoughStreams() bool { |
||||||
|
availStreams := sm.streams.numStreamsWithMinProtoSpec(sm.myProtoSpec) |
||||||
|
return availStreams >= sm.config.SoftLoCap |
||||||
|
} |
||||||
|
|
||||||
|
func (sm *streamManager) hardHaveEnoughStream() bool { |
||||||
|
availStreams := sm.streams.numStreamsWithMinProtoSpec(sm.myProtoSpec) |
||||||
|
return availStreams >= sm.config.HardLoCap |
||||||
|
} |
||||||
|
|
||||||
|
// streamSet is the concurrency safe stream set.
|
||||||
|
type streamSet struct { |
||||||
|
streams map[sttypes.StreamID]sttypes.Stream |
||||||
|
numByProto map[sttypes.ProtoSpec]int |
||||||
|
lock sync.RWMutex |
||||||
|
} |
||||||
|
|
||||||
|
func newStreamSet() *streamSet { |
||||||
|
return &streamSet{ |
||||||
|
streams: make(map[sttypes.StreamID]sttypes.Stream), |
||||||
|
numByProto: make(map[sttypes.ProtoSpec]int), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (ss *streamSet) size() int { |
||||||
|
ss.lock.RLock() |
||||||
|
defer ss.lock.RUnlock() |
||||||
|
|
||||||
|
return len(ss.streams) |
||||||
|
} |
||||||
|
|
||||||
|
func (ss *streamSet) get(id sttypes.StreamID) (sttypes.Stream, bool) { |
||||||
|
ss.lock.RLock() |
||||||
|
defer ss.lock.RUnlock() |
||||||
|
|
||||||
|
st, ok := ss.streams[id] |
||||||
|
return st, ok |
||||||
|
} |
||||||
|
|
||||||
|
func (ss *streamSet) addStream(st sttypes.Stream) { |
||||||
|
ss.lock.Lock() |
||||||
|
defer ss.lock.Unlock() |
||||||
|
|
||||||
|
ss.streams[st.ID()] = st |
||||||
|
spec, _ := st.ProtoSpec() |
||||||
|
ss.numByProto[spec]++ |
||||||
|
} |
||||||
|
|
||||||
|
func (ss *streamSet) deleteStream(st sttypes.Stream) { |
||||||
|
ss.lock.Lock() |
||||||
|
defer ss.lock.Unlock() |
||||||
|
|
||||||
|
delete(ss.streams, st.ID()) |
||||||
|
|
||||||
|
spec, _ := st.ProtoSpec() |
||||||
|
ss.numByProto[spec]-- |
||||||
|
if ss.numByProto[spec] == 0 { |
||||||
|
delete(ss.numByProto, spec) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (ss *streamSet) slice() []sttypes.Stream { |
||||||
|
ss.lock.RLock() |
||||||
|
defer ss.lock.RUnlock() |
||||||
|
|
||||||
|
sts := make([]sttypes.Stream, 0, len(ss.streams)) |
||||||
|
for _, st := range ss.streams { |
||||||
|
sts = append(sts, st) |
||||||
|
} |
||||||
|
return sts |
||||||
|
} |
||||||
|
|
||||||
|
func (ss *streamSet) getStreams() []sttypes.Stream { |
||||||
|
ss.lock.RLock() |
||||||
|
defer ss.lock.RUnlock() |
||||||
|
|
||||||
|
res := make([]sttypes.Stream, 0, len(ss.streams)) |
||||||
|
for _, st := range ss.streams { |
||||||
|
res = append(res, st) |
||||||
|
} |
||||||
|
return res |
||||||
|
} |
||||||
|
|
||||||
|
func (ss *streamSet) numStreamsWithMinProtoSpec(minSpec sttypes.ProtoSpec) int { |
||||||
|
ss.lock.RLock() |
||||||
|
defer ss.lock.RUnlock() |
||||||
|
|
||||||
|
var res int |
||||||
|
for spec, num := range ss.numByProto { |
||||||
|
if !spec.Version.LessThan(minSpec.Version) { |
||||||
|
res += num |
||||||
|
} |
||||||
|
} |
||||||
|
return res |
||||||
|
} |
@ -0,0 +1,243 @@ |
|||||||
|
package streammanager |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
"testing" |
||||||
|
"time" |
||||||
|
|
||||||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types" |
||||||
|
libp2p_peer "github.com/libp2p/go-libp2p-core/peer" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
defTestWait = 100 * time.Millisecond |
||||||
|
) |
||||||
|
|
||||||
|
// When started, discover will be run at bootstrap
|
||||||
|
func TestStreamManager_BootstrapDisc(t *testing.T) { |
||||||
|
sm := newTestStreamManager() |
||||||
|
sm.host.(*testHost).errHook = func(id sttypes.StreamID, err error) { |
||||||
|
t.Errorf("%s stream error: %v", id, err) |
||||||
|
} |
||||||
|
|
||||||
|
// After bootstrap, stream manager shall discover streams and setup connection
|
||||||
|
// Note host will mock the upper code logic to call sm.NewStream in this case
|
||||||
|
sm.Start() |
||||||
|
time.Sleep(defTestWait) |
||||||
|
if gotSize := sm.streams.size(); gotSize != sm.config.DiscBatch { |
||||||
|
t.Errorf("unexpected stream size: %v != %v", gotSize, sm.config.DiscBatch) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// After close, all stream shall be closed and removed
|
||||||
|
func TestStreamManager_Close(t *testing.T) { |
||||||
|
sm := newTestStreamManager() |
||||||
|
// Bootstrap
|
||||||
|
sm.Start() |
||||||
|
time.Sleep(defTestWait) |
||||||
|
// Close stream manager, all stream shall be closed and removed
|
||||||
|
closeDone := make(chan struct{}) |
||||||
|
go func() { |
||||||
|
sm.Close() |
||||||
|
closeDone <- struct{}{} |
||||||
|
}() |
||||||
|
select { |
||||||
|
case <-time.After(defTestWait): |
||||||
|
t.Errorf("still not closed") |
||||||
|
case <-closeDone: |
||||||
|
} |
||||||
|
// Check stream been removed from stream manager and all streams to be closed
|
||||||
|
if sm.streams.size() != 0 { |
||||||
|
t.Errorf("after close, stream not removed from stream manager") |
||||||
|
} |
||||||
|
host := sm.host.(*testHost) |
||||||
|
for _, st := range host.streams { |
||||||
|
if !st.closed { |
||||||
|
t.Errorf("after close, stream still not closed") |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Close shall terminate the current discover at once
|
||||||
|
func TestStreamManager_CloseDisc(t *testing.T) { |
||||||
|
sm := newTestStreamManager() |
||||||
|
// discover will be blocked forever
|
||||||
|
sm.pf.(*testPeerFinder).fpHook = func(id libp2p_peer.ID) <-chan struct{} { |
||||||
|
select {} |
||||||
|
} |
||||||
|
sm.Start() |
||||||
|
time.Sleep(defTestWait) |
||||||
|
// Close stream manager, all stream shall be closed and removed
|
||||||
|
closeDone := make(chan struct{}) |
||||||
|
go func() { |
||||||
|
sm.Close() |
||||||
|
closeDone <- struct{}{} |
||||||
|
}() |
||||||
|
select { |
||||||
|
case <-time.After(defTestWait): |
||||||
|
t.Errorf("close shall unblock the current discovery") |
||||||
|
case <-closeDone: |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Each time discTicker ticks, it will cancel last discovery, and start a new one
|
||||||
|
func TestStreamManager_refreshDisc(t *testing.T) { |
||||||
|
sm := newTestStreamManager() |
||||||
|
// discover will be blocked for the first time but good for second time
|
||||||
|
var once sync.Once |
||||||
|
sm.pf.(*testPeerFinder).fpHook = func(id libp2p_peer.ID) <-chan struct{} { |
||||||
|
var sendSig = true |
||||||
|
once.Do(func() { |
||||||
|
sendSig = false |
||||||
|
}) |
||||||
|
c := make(chan struct{}, 1) |
||||||
|
if sendSig { |
||||||
|
c <- struct{}{} |
||||||
|
} |
||||||
|
return c |
||||||
|
} |
||||||
|
sm.Start() |
||||||
|
time.Sleep(defTestWait) |
||||||
|
|
||||||
|
sm.discCh <- struct{}{} |
||||||
|
time.Sleep(defTestWait) |
||||||
|
|
||||||
|
// We shall now have non-zero streams setup
|
||||||
|
if sm.streams.size() == 0 { |
||||||
|
t.Errorf("stream size still zero after refresh") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestStreamManager_HandleNewStream(t *testing.T) { |
||||||
|
tests := []struct { |
||||||
|
stream sttypes.Stream |
||||||
|
expSize int |
||||||
|
expErr error |
||||||
|
}{ |
||||||
|
{ |
||||||
|
stream: newTestStream(makeStreamID(100), testProtoID), |
||||||
|
expSize: defDiscBatch + 1, |
||||||
|
expErr: nil, |
||||||
|
}, |
||||||
|
{ |
||||||
|
stream: newTestStream(makeStreamID(1), testProtoID), |
||||||
|
expSize: defDiscBatch, |
||||||
|
expErr: errors.New("stream already exist"), |
||||||
|
}, |
||||||
|
} |
||||||
|
for i, test := range tests { |
||||||
|
sm := newTestStreamManager() |
||||||
|
sm.Start() |
||||||
|
time.Sleep(defTestWait) |
||||||
|
|
||||||
|
err := sm.NewStream(test.stream) |
||||||
|
if assErr := assertError(err, test.expErr); assErr != nil { |
||||||
|
t.Errorf("Test %v: %v", i, assErr) |
||||||
|
} |
||||||
|
|
||||||
|
if sm.streams.size() != test.expSize { |
||||||
|
t.Errorf("Test %v: unexpected stream size: %v / %v", i, sm.streams.size(), |
||||||
|
test.expSize) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestStreamManager_HandleRemoveStream(t *testing.T) { |
||||||
|
tests := []struct { |
||||||
|
id sttypes.StreamID |
||||||
|
expSize int |
||||||
|
expErr error |
||||||
|
}{ |
||||||
|
{ |
||||||
|
id: makeStreamID(1), |
||||||
|
expSize: defDiscBatch - 1, |
||||||
|
expErr: nil, |
||||||
|
}, |
||||||
|
{ |
||||||
|
id: makeStreamID(100), |
||||||
|
expSize: defDiscBatch, |
||||||
|
expErr: errors.New("stream already removed"), |
||||||
|
}, |
||||||
|
} |
||||||
|
for i, test := range tests { |
||||||
|
sm := newTestStreamManager() |
||||||
|
sm.Start() |
||||||
|
time.Sleep(defTestWait) |
||||||
|
|
||||||
|
err := sm.RemoveStream(test.id) |
||||||
|
if assErr := assertError(err, test.expErr); assErr != nil { |
||||||
|
t.Errorf("Test %v: %v", i, assErr) |
||||||
|
} |
||||||
|
|
||||||
|
if sm.streams.size() != test.expSize { |
||||||
|
t.Errorf("Test %v: unexpected stream size: %v / %v", i, sm.streams.size(), |
||||||
|
test.expSize) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// When number of streams is smaller than hard low limit, discover will be triggered
|
||||||
|
func TestStreamManager_HandleRemoveStream_Disc(t *testing.T) { |
||||||
|
sm := newTestStreamManager() |
||||||
|
sm.Start() |
||||||
|
time.Sleep(defTestWait) |
||||||
|
|
||||||
|
// Remove DiscBatch - HardLoCap + 1 streams
|
||||||
|
num := 0 |
||||||
|
for _, st := range sm.streams.slice() { |
||||||
|
if err := sm.RemoveStream(st.ID()); err != nil { |
||||||
|
t.Error(err) |
||||||
|
} |
||||||
|
num++ |
||||||
|
if num == sm.config.DiscBatch-sm.config.HardLoCap+1 { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Last remove stream will also trigger discover
|
||||||
|
time.Sleep(defTestWait) |
||||||
|
if sm.streams.size() != sm.config.HardLoCap+sm.config.DiscBatch-1 { |
||||||
|
t.Errorf("unexpected stream number %v / %v", sm.streams.size(), sm.config.HardLoCap+sm.config.DiscBatch-1) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestStreamSet_numStreamsWithMinProtoID(t *testing.T) { |
||||||
|
var ( |
||||||
|
pid1 = testProtoID |
||||||
|
numPid1 = 5 |
||||||
|
|
||||||
|
pid2 = sttypes.ProtoID("harmony/sync/unitest/0/1.0.1") |
||||||
|
numPid2 = 10 |
||||||
|
) |
||||||
|
|
||||||
|
ss := newStreamSet() |
||||||
|
|
||||||
|
for i := 0; i != numPid1; i++ { |
||||||
|
ss.addStream(newTestStream(makeStreamID(i), pid1)) |
||||||
|
} |
||||||
|
for i := 0; i != numPid2; i++ { |
||||||
|
ss.addStream(newTestStream(makeStreamID(i), pid2)) |
||||||
|
} |
||||||
|
|
||||||
|
minSpec, _ := sttypes.ProtoIDToProtoSpec(pid2) |
||||||
|
num := ss.numStreamsWithMinProtoSpec(minSpec) |
||||||
|
if num != numPid2 { |
||||||
|
t.Errorf("unexpected result: %v/%v", num, numPid2) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func assertError(got, exp error) error { |
||||||
|
if (got == nil) != (exp == nil) { |
||||||
|
return fmt.Errorf("unexpected error: %v / %v", got, exp) |
||||||
|
} |
||||||
|
if got == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
if !strings.Contains(got.Error(), exp.Error()) { |
||||||
|
return fmt.Errorf("unexpected error: %v / %v", got, exp) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
Loading…
Reference in new issue