You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
417 lines
10 KiB
417 lines
10 KiB
4 years ago
|
package streammanager
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/ethereum/go-ethereum/event"
|
||
|
"github.com/harmony-one/harmony/internal/utils"
|
||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
|
||
|
"github.com/libp2p/go-libp2p-core/network"
|
||
|
libp2p_peer "github.com/libp2p/go-libp2p-core/peer"
|
||
|
"github.com/libp2p/go-libp2p-core/protocol"
|
||
|
"github.com/pkg/errors"
|
||
|
"github.com/rs/zerolog"
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
// ErrStreamAlreadyRemoved is the error that a stream has already been removed
|
||
|
ErrStreamAlreadyRemoved = errors.New("stream already removed")
|
||
|
)
|
||
|
|
||
|
// streamManager is the implementation of StreamManager. It manages streams on
|
||
|
// one single protocol. It does the following job:
|
||
|
// 1. add a new stream.
|
||
|
// 2. closes a stream.
|
||
|
// 3. discover and connect new streams when the number of streams is below threshold.
|
||
|
// 4. emit stream events to inform other modules.
|
||
|
// 5. reset all streams on close.
|
||
|
type streamManager struct {
|
||
|
// streamManager only manages streams on one protocol.
|
||
|
myProtoID sttypes.ProtoID
|
||
|
myProtoSpec sttypes.ProtoSpec
|
||
|
config Config
|
||
|
// streams is the map of peer ID to stream
|
||
|
// Note that it could happen that remote node does not share exactly the same
|
||
|
// protocol ID (e.g. different version)
|
||
|
streams *streamSet
|
||
|
// libp2p utilities
|
||
|
host host
|
||
|
pf peerFinder
|
||
|
handleStream func(stream network.Stream)
|
||
|
// incoming task channels
|
||
|
addStreamCh chan addStreamTask
|
||
|
rmStreamCh chan rmStreamTask
|
||
|
stopCh chan stopTask
|
||
|
discCh chan discTask
|
||
|
curTask interface{}
|
||
|
// utils
|
||
|
addStreamFeed event.Feed
|
||
|
removeStreamFeed event.Feed
|
||
|
logger zerolog.Logger
|
||
|
ctx context.Context
|
||
|
cancel func()
|
||
|
}
|
||
|
|
||
|
// NewStreamManager creates a new stream manager for the given proto ID
|
||
|
func NewStreamManager(pid sttypes.ProtoID, host host, pf peerFinder, handleStream func(network.Stream), c Config) StreamManager {
|
||
|
return newStreamManager(pid, host, pf, handleStream, c)
|
||
|
}
|
||
|
|
||
|
// newStreamManager creates a new stream manager
|
||
|
func newStreamManager(pid sttypes.ProtoID, host host, pf peerFinder, handleStream func(network.Stream), c Config) *streamManager {
|
||
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
|
||
|
logger := utils.Logger().With().Str("module", "stream manager").
|
||
|
Str("protocol ID", string(pid)).Logger()
|
||
|
|
||
|
protoSpec, _ := sttypes.ProtoIDToProtoSpec(pid)
|
||
|
|
||
|
return &streamManager{
|
||
|
myProtoID: pid,
|
||
|
myProtoSpec: protoSpec,
|
||
|
config: c,
|
||
|
streams: newStreamSet(),
|
||
|
host: host,
|
||
|
pf: pf,
|
||
|
handleStream: handleStream,
|
||
|
addStreamCh: make(chan addStreamTask),
|
||
|
rmStreamCh: make(chan rmStreamTask),
|
||
|
stopCh: make(chan stopTask),
|
||
|
discCh: make(chan discTask, 1), // discCh is a buffered channel to avoid overuse of goroutine
|
||
|
|
||
|
logger: logger,
|
||
|
ctx: ctx,
|
||
|
cancel: cancel,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Start starts the stream manager
|
||
|
func (sm *streamManager) Start() {
|
||
|
go sm.loop()
|
||
|
}
|
||
|
|
||
|
// Close close the stream manager
|
||
|
func (sm *streamManager) Close() {
|
||
|
task := stopTask{done: make(chan struct{})}
|
||
|
sm.stopCh <- task
|
||
|
|
||
|
<-task.done
|
||
|
}
|
||
|
|
||
|
func (sm *streamManager) loop() {
|
||
|
var (
|
||
|
discTicker = time.NewTicker(checkInterval)
|
||
|
discCtx context.Context
|
||
|
discCancel func()
|
||
|
)
|
||
|
// bootstrap discovery
|
||
|
sm.discCh <- discTask{}
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
case <-discTicker.C:
|
||
|
if !sm.softHaveEnoughStreams() {
|
||
|
sm.discCh <- discTask{}
|
||
|
}
|
||
|
|
||
|
case <-sm.discCh:
|
||
|
// cancel last discovery
|
||
|
if discCancel != nil {
|
||
|
discCancel()
|
||
|
}
|
||
|
discCtx, discCancel = context.WithCancel(sm.ctx)
|
||
|
go func() {
|
||
|
err := sm.discoverAndSetupStream(discCtx)
|
||
|
if err != nil {
|
||
|
sm.logger.Err(err)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
case addStream := <-sm.addStreamCh:
|
||
|
err := sm.handleAddStream(addStream.st)
|
||
|
addStream.errC <- err
|
||
|
|
||
|
case rmStream := <-sm.rmStreamCh:
|
||
|
err := sm.handleRemoveStream(rmStream.id)
|
||
|
rmStream.errC <- err
|
||
|
|
||
|
case stop := <-sm.stopCh:
|
||
|
sm.cancel()
|
||
|
sm.removeAllStreamOnClose()
|
||
|
stop.done <- struct{}{}
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// NewStream handles a new stream from stream handler protocol
|
||
|
func (sm *streamManager) NewStream(stream sttypes.Stream) error {
|
||
|
if err := sm.sanityCheckStream(stream); err != nil {
|
||
|
return errors.Wrap(err, "stream sanity check failed")
|
||
|
}
|
||
|
task := addStreamTask{
|
||
|
st: stream,
|
||
|
errC: make(chan error),
|
||
|
}
|
||
|
sm.addStreamCh <- task
|
||
|
return <-task.errC
|
||
|
}
|
||
|
|
||
|
// RemoveStream close and remove a stream from stream manager
|
||
|
func (sm *streamManager) RemoveStream(stID sttypes.StreamID) error {
|
||
|
task := rmStreamTask{
|
||
|
id: stID,
|
||
|
errC: make(chan error),
|
||
|
}
|
||
|
sm.rmStreamCh <- task
|
||
|
return <-task.errC
|
||
|
}
|
||
|
|
||
|
// GetStreams return the streams.
|
||
|
func (sm *streamManager) GetStreams() []sttypes.Stream {
|
||
|
return sm.streams.getStreams()
|
||
|
}
|
||
|
|
||
|
// GetStreamByID return the stream with the given id.
|
||
|
func (sm *streamManager) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) {
|
||
|
return sm.streams.get(id)
|
||
|
}
|
||
|
|
||
|
type (
|
||
|
addStreamTask struct {
|
||
|
st sttypes.Stream
|
||
|
errC chan error
|
||
|
}
|
||
|
|
||
|
rmStreamTask struct {
|
||
|
id sttypes.StreamID
|
||
|
errC chan error
|
||
|
}
|
||
|
|
||
|
discTask struct{}
|
||
|
|
||
|
stopTask struct {
|
||
|
done chan struct{}
|
||
|
}
|
||
|
)
|
||
|
|
||
|
// sanity checks the service, network and shard ID
|
||
|
func (sm *streamManager) sanityCheckStream(st sttypes.Stream) error {
|
||
|
mySpec := sm.myProtoSpec
|
||
|
rmSpec, err := st.ProtoSpec()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if mySpec.Service != rmSpec.Service {
|
||
|
return fmt.Errorf("unexpected service: %v/%v", rmSpec.Service, mySpec.Service)
|
||
|
}
|
||
|
if mySpec.NetworkType != rmSpec.NetworkType {
|
||
|
return fmt.Errorf("unexpected network: %v/%v", rmSpec.NetworkType, mySpec.NetworkType)
|
||
|
}
|
||
|
if mySpec.ShardID != rmSpec.ShardID {
|
||
|
return fmt.Errorf("unexpected shard ID: %v/%v", rmSpec.ShardID, mySpec.ShardID)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (sm *streamManager) handleAddStream(st sttypes.Stream) error {
|
||
|
id := st.ID()
|
||
|
if sm.streams.size() >= sm.config.HiCap {
|
||
|
return errors.New("too many streams")
|
||
|
}
|
||
|
if _, ok := sm.streams.get(id); ok {
|
||
|
return errors.New("stream already exist")
|
||
|
}
|
||
|
|
||
|
sm.streams.addStream(st)
|
||
|
|
||
|
sm.addStreamFeed.Send(EvtStreamAdded{st})
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (sm *streamManager) handleRemoveStream(id sttypes.StreamID) error {
|
||
|
st, ok := sm.streams.get(id)
|
||
|
if !ok {
|
||
|
return ErrStreamAlreadyRemoved
|
||
|
}
|
||
|
|
||
|
sm.streams.deleteStream(st)
|
||
|
// if stream number is smaller than HardLoCap, spin up the discover
|
||
|
if !sm.hardHaveEnoughStream() {
|
||
|
select {
|
||
|
case sm.discCh <- discTask{}:
|
||
|
default:
|
||
|
}
|
||
|
}
|
||
|
sm.removeStreamFeed.Send(EvtStreamRemoved{id})
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (sm *streamManager) removeAllStreamOnClose() {
|
||
|
var wg sync.WaitGroup
|
||
|
|
||
|
for _, st := range sm.streams.slice() {
|
||
|
wg.Add(1)
|
||
|
go func(st sttypes.Stream) {
|
||
|
defer wg.Done()
|
||
|
err := st.ResetOnClose()
|
||
|
if err != nil {
|
||
|
sm.logger.Warn().Err(err).Str("stream ID", string(st.ID())).
|
||
|
Msg("failed to close stream")
|
||
|
}
|
||
|
}(st)
|
||
|
}
|
||
|
wg.Wait()
|
||
|
|
||
|
// Be nice. after close, the field is still accessible to prevent potential panics
|
||
|
sm.streams = newStreamSet()
|
||
|
}
|
||
|
|
||
|
func (sm *streamManager) discoverAndSetupStream(discCtx context.Context) error {
|
||
|
peers, err := sm.discover(discCtx)
|
||
|
if err != nil {
|
||
|
return errors.Wrap(err, "failed to discover")
|
||
|
}
|
||
|
for peer := range peers {
|
||
|
if peer.ID == sm.host.ID() {
|
||
|
continue
|
||
|
}
|
||
|
go func(pid libp2p_peer.ID) {
|
||
|
// The ctx here is using the module context instead of discover context
|
||
|
err := sm.setupStreamWithPeer(sm.ctx, pid)
|
||
|
if err != nil {
|
||
|
sm.logger.Warn().Err(err).Str("peerID", string(pid)).Msg("failed to setup stream with peer")
|
||
|
return
|
||
|
}
|
||
|
}(peer.ID)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (sm *streamManager) discover(ctx context.Context) (<-chan libp2p_peer.AddrInfo, error) {
|
||
|
protoID := string(sm.myProtoID)
|
||
|
discBatch := sm.config.DiscBatch
|
||
|
if sm.config.HiCap-sm.streams.size() < sm.config.DiscBatch {
|
||
|
discBatch = sm.config.HiCap - sm.streams.size()
|
||
|
}
|
||
|
if discBatch < 0 {
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
ctx, _ = context.WithTimeout(ctx, discTimeout)
|
||
|
return sm.pf.FindPeers(ctx, protoID, discBatch)
|
||
|
}
|
||
|
|
||
|
func (sm *streamManager) setupStreamWithPeer(ctx context.Context, pid libp2p_peer.ID) error {
|
||
|
nCtx, cancel := context.WithTimeout(ctx, connectTimeout)
|
||
|
defer cancel()
|
||
|
|
||
|
st, err := sm.host.NewStream(nCtx, pid, protocol.ID(sm.myProtoID))
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if sm.handleStream != nil {
|
||
|
go sm.handleStream(st)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (sm *streamManager) softHaveEnoughStreams() bool {
|
||
|
availStreams := sm.streams.numStreamsWithMinProtoSpec(sm.myProtoSpec)
|
||
|
return availStreams >= sm.config.SoftLoCap
|
||
|
}
|
||
|
|
||
|
func (sm *streamManager) hardHaveEnoughStream() bool {
|
||
|
availStreams := sm.streams.numStreamsWithMinProtoSpec(sm.myProtoSpec)
|
||
|
return availStreams >= sm.config.HardLoCap
|
||
|
}
|
||
|
|
||
|
// streamSet is the concurrency safe stream set.
|
||
|
type streamSet struct {
|
||
|
streams map[sttypes.StreamID]sttypes.Stream
|
||
|
numByProto map[sttypes.ProtoSpec]int
|
||
|
lock sync.RWMutex
|
||
|
}
|
||
|
|
||
|
func newStreamSet() *streamSet {
|
||
|
return &streamSet{
|
||
|
streams: make(map[sttypes.StreamID]sttypes.Stream),
|
||
|
numByProto: make(map[sttypes.ProtoSpec]int),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (ss *streamSet) size() int {
|
||
|
ss.lock.RLock()
|
||
|
defer ss.lock.RUnlock()
|
||
|
|
||
|
return len(ss.streams)
|
||
|
}
|
||
|
|
||
|
func (ss *streamSet) get(id sttypes.StreamID) (sttypes.Stream, bool) {
|
||
|
ss.lock.RLock()
|
||
|
defer ss.lock.RUnlock()
|
||
|
|
||
|
st, ok := ss.streams[id]
|
||
|
return st, ok
|
||
|
}
|
||
|
|
||
|
func (ss *streamSet) addStream(st sttypes.Stream) {
|
||
|
ss.lock.Lock()
|
||
|
defer ss.lock.Unlock()
|
||
|
|
||
|
ss.streams[st.ID()] = st
|
||
|
spec, _ := st.ProtoSpec()
|
||
|
ss.numByProto[spec]++
|
||
|
}
|
||
|
|
||
|
func (ss *streamSet) deleteStream(st sttypes.Stream) {
|
||
|
ss.lock.Lock()
|
||
|
defer ss.lock.Unlock()
|
||
|
|
||
|
delete(ss.streams, st.ID())
|
||
|
|
||
|
spec, _ := st.ProtoSpec()
|
||
|
ss.numByProto[spec]--
|
||
|
if ss.numByProto[spec] == 0 {
|
||
|
delete(ss.numByProto, spec)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (ss *streamSet) slice() []sttypes.Stream {
|
||
|
ss.lock.RLock()
|
||
|
defer ss.lock.RUnlock()
|
||
|
|
||
|
sts := make([]sttypes.Stream, 0, len(ss.streams))
|
||
|
for _, st := range ss.streams {
|
||
|
sts = append(sts, st)
|
||
|
}
|
||
|
return sts
|
||
|
}
|
||
|
|
||
|
func (ss *streamSet) getStreams() []sttypes.Stream {
|
||
|
ss.lock.RLock()
|
||
|
defer ss.lock.RUnlock()
|
||
|
|
||
|
res := make([]sttypes.Stream, 0, len(ss.streams))
|
||
|
for _, st := range ss.streams {
|
||
|
res = append(res, st)
|
||
|
}
|
||
|
return res
|
||
|
}
|
||
|
|
||
|
func (ss *streamSet) numStreamsWithMinProtoSpec(minSpec sttypes.ProtoSpec) int {
|
||
|
ss.lock.RLock()
|
||
|
defer ss.lock.RUnlock()
|
||
|
|
||
|
var res int
|
||
|
for spec, num := range ss.numByProto {
|
||
|
if !spec.Version.LessThan(minSpec.Version) {
|
||
|
res += num
|
||
|
}
|
||
|
}
|
||
|
return res
|
||
|
}
|