Passed peerID to handlers.

pull/4455/head
frozen 1 year ago
parent 525b20ae20
commit 0e4568253a
No known key found for this signature in database
GPG Key ID: 5391C63E79B03EDE
  1. 2
      api/service/stagedstreamsync/stage_heads.go
  2. 17
      consensus/consensus.go
  3. 5
      consensus/consensus_test.go
  4. 26
      consensus/consensus_v2.go
  5. 8
      consensus/fbft_log.go
  6. 4
      consensus/fbft_log_test.go
  7. 13
      consensus/validator.go
  8. 4
      consensus/view_change_construct.go
  9. 2
      consensus/view_change_msg.go
  10. 33
      internal/utils/blockedpeers/manager.go
  11. 26
      internal/utils/blockedpeers/manager_test.go
  12. 6
      internal/utils/lrucache/lrucache.go
  13. 4
      internal/utils/timer.go
  14. 10
      node/node.go
  15. 20
      p2p/host.go
  16. 29
      p2p/security/security.go
  17. 5
      p2p/security/security_test.go

@ -53,7 +53,7 @@ func (heads *StageHeads) Exec(ctx context.Context, firstCycle bool, invalidBlock
maxHeight := s.state.status.targetBN maxHeight := s.state.status.targetBN
maxBlocksPerSyncCycle := uint64(1024) // TODO: should be in config -> s.state.MaxBlocksPerSyncCycle maxBlocksPerSyncCycle := uint64(1024) // TODO: should be in config -> s.state.MaxBlocksPerSyncCycle
currentHeight := heads.configs.bc.CurrentBlock().NumberU64() currentHeight := heads.configs.bc.CurrentHeader().NumberU64()
s.state.currentCycle.TargetHeight = maxHeight s.state.currentCycle.TargetHeight = maxHeight
targetHeight := uint64(0) targetHeight := uint64(0)
if errV := CreateView(ctx, heads.configs.db, tx, func(etx kv.Tx) (err error) { if errV := CreateView(ctx, heads.configs.db, tx, func(etx kv.Tx) (err error) {

@ -94,8 +94,6 @@ type Consensus struct {
// The post-consensus job func passed from Node object // The post-consensus job func passed from Node object
// Called when consensus on a new block is done // Called when consensus on a new block is done
PostConsensusJob func(*types.Block) error PostConsensusJob func(*types.Block) error
// The verifier func passed from Node object
BlockVerifier VerifyBlockFunc
// verified block to state sync broadcast // verified block to state sync broadcast
VerifiedNewBlock chan *types.Block VerifiedNewBlock chan *types.Block
// will trigger state syncing when blockNum is low // will trigger state syncing when blockNum is low
@ -171,12 +169,12 @@ func (consensus *Consensus) Beaconchain() core.BlockChain {
} }
// VerifyBlock is a function used to verify the block and keep trace of verified blocks. // VerifyBlock is a function used to verify the block and keep trace of verified blocks.
func (consensus *Consensus) verifyBlock(block *types.Block) error { func (FBFTLog *FBFTLog) verifyBlock(block *types.Block) error {
if !consensus.fBFTLog.IsBlockVerified(block.Hash()) { if !FBFTLog.IsBlockVerified(block.Hash()) {
if err := consensus.BlockVerifier(block); err != nil { if err := FBFTLog.BlockVerify(block); err != nil {
return errors.Errorf("Block verification failed: %s", err) return errors.Errorf("Block verification failed: %s", err)
} }
consensus.fBFTLog.MarkBlockVerified(block) FBFTLog.MarkBlockVerified(block)
} }
return nil return nil
} }
@ -304,12 +302,7 @@ func New(
consensus.RndChannel = make(chan [vdfAndSeedSize]byte) consensus.RndChannel = make(chan [vdfAndSeedSize]byte)
consensus.IgnoreViewIDCheck = abool.NewBool(false) consensus.IgnoreViewIDCheck = abool.NewBool(false)
// Make Sure Verifier is not null // Make Sure Verifier is not null
consensus.vc = newViewChange() consensus.vc = newViewChange(consensus.FBFTLog.BlockVerify)
// TODO: reference to blockchain/beaconchain should be removed.
verifier := VerifyNewBlock(registry.GetWebHooks(), consensus.Blockchain(), consensus.Beaconchain())
consensus.BlockVerifier = verifier
consensus.vc.verifyBlock = consensus.verifyBlock
// init prometheus metrics // init prometheus metrics
initMetrics() initMetrics()
consensus.AddPubkeyMetrics() consensus.AddPubkeyMetrics()

@ -22,6 +22,7 @@ func TestConsensusInitialization(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
messageSender := &MessageSender{host: host, retryTimes: int(phaseDuration.Seconds()) / RetryIntervalInSec} messageSender := &MessageSender{host: host, retryTimes: int(phaseDuration.Seconds()) / RetryIntervalInSec}
fbtLog := NewFBFTLog(consensus.FBFTLog.verifyNewBlock)
state := State{mode: Normal} state := State{mode: Normal}
timeouts := createTimeout() timeouts := createTimeout()
@ -36,6 +37,10 @@ func TestConsensusInitialization(t *testing.T) {
assert.IsType(t, make(chan struct{}), consensus.BlockNumLowChan) assert.IsType(t, make(chan struct{}), consensus.BlockNumLowChan)
// FBFTLog // FBFTLog
assert.Equal(t, fbtLog.blocks, consensus.FBFTLog.blocks)
assert.Equal(t, fbtLog.messages, consensus.FBFTLog.messages)
assert.Equal(t, len(fbtLog.verifiedBlocks), 0)
assert.Equal(t, fbtLog.verifiedBlocks, consensus.FBFTLog.verifiedBlocks)
assert.NotNil(t, consensus.FBFTLog()) assert.NotNil(t, consensus.FBFTLog())
assert.Equal(t, FBFTAnnounce, consensus.phase) assert.Equal(t, FBFTAnnounce, consensus.phase)

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/hex" "encoding/hex"
libp2p_peer "github.com/libp2p/go-libp2p/core/peer"
"math/big" "math/big"
"sync/atomic" "sync/atomic"
"time" "time"
@ -55,7 +56,7 @@ func (consensus *Consensus) isViewChangingMode() bool {
} }
// HandleMessageUpdate will update the consensus state according to received message // HandleMessageUpdate will update the consensus state according to received message
func (consensus *Consensus) HandleMessageUpdate(ctx context.Context, msg *msg_pb.Message, senderKey *bls.SerializedPublicKey) error { func (consensus *Consensus) HandleMessageUpdate(ctx context.Context, peer libp2p_peer.ID, msg *msg_pb.Message, senderKey *bls.SerializedPublicKey) error {
consensus.mutex.Lock() consensus.mutex.Lock()
defer consensus.mutex.Unlock() defer consensus.mutex.Unlock()
// when node is in ViewChanging mode, it still accepts normal messages into FBFTLog // when node is in ViewChanging mode, it still accepts normal messages into FBFTLog
@ -393,11 +394,12 @@ func (consensus *Consensus) tick() {
// the bootstrap timer will be stopped once consensus is reached or view change // the bootstrap timer will be stopped once consensus is reached or view change
// is succeeded // is succeeded
if k != timeoutBootstrap { if k != timeoutBootstrap {
consensus.getLogger().Debug(). if v.Stop() { // prevent useless logs
Str("k", k.String()). consensus.getLogger().Debug().
Str("Mode", consensus.current.Mode().String()). Str("k", k.String()).
Msg("[ConsensusMainLoop] consensusTimeout stopped!!!") Str("Mode", consensus.current.Mode().String()).
v.Stop() Msg("[ConsensusMainLoop] consensusTimeout stopped!!!")
}
continue continue
} }
} }
@ -453,7 +455,6 @@ func (consensus *Consensus) BlockChannel(newBlock *types.Block) {
type LastMileBlockIter struct { type LastMileBlockIter struct {
blockCandidates []*types.Block blockCandidates []*types.Block
fbftLog *FBFTLog fbftLog *FBFTLog
verify func(*types.Block) error
curIndex int curIndex int
logger *zerolog.Logger logger *zerolog.Logger
} }
@ -468,9 +469,6 @@ func (consensus *Consensus) GetLastMileBlockIter(bnStart uint64, cb func(iter *L
// GetLastMileBlockIter get the iterator of the last mile blocks starting from number bnStart // GetLastMileBlockIter get the iterator of the last mile blocks starting from number bnStart
func (consensus *Consensus) getLastMileBlockIter(bnStart uint64, cb func(iter *LastMileBlockIter) error) error { func (consensus *Consensus) getLastMileBlockIter(bnStart uint64, cb func(iter *LastMileBlockIter) error) error {
if consensus.BlockVerifier == nil {
return errors.New("consensus haven't initialized yet")
}
blocks, _, err := consensus.getLastMileBlocksAndMsg(bnStart) blocks, _, err := consensus.getLastMileBlocksAndMsg(bnStart)
if err != nil { if err != nil {
return err return err
@ -478,7 +476,6 @@ func (consensus *Consensus) getLastMileBlockIter(bnStart uint64, cb func(iter *L
return cb(&LastMileBlockIter{ return cb(&LastMileBlockIter{
blockCandidates: blocks, blockCandidates: blocks,
fbftLog: consensus.fBFTLog, fbftLog: consensus.fBFTLog,
verify: consensus.BlockVerifier,
curIndex: 0, curIndex: 0,
logger: consensus.getLogger(), logger: consensus.getLogger(),
}) })
@ -493,7 +490,7 @@ func (iter *LastMileBlockIter) Next() *types.Block {
iter.curIndex++ iter.curIndex++
if !iter.fbftLog.IsBlockVerified(block.Hash()) { if !iter.fbftLog.IsBlockVerified(block.Hash()) {
if err := iter.verify(block); err != nil { if err := iter.fbftLog.BlockVerify(block); err != nil {
iter.logger.Debug().Err(err).Msg("block verification failed in consensus last mile block") iter.logger.Debug().Err(err).Msg("block verification failed in consensus last mile block")
return nil return nil
} }
@ -620,9 +617,6 @@ func (consensus *Consensus) verifyLastCommitSig(lastCommitSig []byte, blk *types
// tryCatchup add the last mile block in PBFT log memory cache to blockchain. // tryCatchup add the last mile block in PBFT log memory cache to blockchain.
func (consensus *Consensus) tryCatchup() error { func (consensus *Consensus) tryCatchup() error {
// TODO: change this to a more systematic symbol // TODO: change this to a more systematic symbol
if consensus.BlockVerifier == nil {
return errors.New("consensus haven't finished initialization")
}
initBN := consensus.getBlockNum() initBN := consensus.getBlockNum()
defer consensus.postCatchup(initBN) defer consensus.postCatchup(initBN)
@ -637,7 +631,7 @@ func (consensus *Consensus) tryCatchup() error {
} }
blk.SetCurrentCommitSig(msg.Payload) blk.SetCurrentCommitSig(msg.Payload)
if err := consensus.verifyBlock(blk); err != nil { if err := consensus.FBFTLog.verifyBlock(blk); err != nil {
consensus.getLogger().Err(err).Msg("[TryCatchup] failed block verifier") consensus.getLogger().Err(err).Msg("[TryCatchup] failed block verifier")
return err return err
} }

@ -113,14 +113,16 @@ type FBFTLog struct {
blocks map[common.Hash]*types.Block // store blocks received in FBFT blocks map[common.Hash]*types.Block // store blocks received in FBFT
verifiedBlocks map[common.Hash]struct{} // store block hashes for blocks that has already been verified verifiedBlocks map[common.Hash]struct{} // store block hashes for blocks that has already been verified
messages map[fbftMsgID]*FBFTMessage // store messages received in FBFT messages map[fbftMsgID]*FBFTMessage // store messages received in FBFT
verifyNewBlock func(*types.Block) error // block verification function
} }
// NewFBFTLog returns new instance of FBFTLog // NewFBFTLog returns new instance of FBFTLog
func NewFBFTLog() *FBFTLog { func NewFBFTLog(verifyNewBlock func(*types.Block) error) *FBFTLog {
pbftLog := FBFTLog{ pbftLog := FBFTLog{
blocks: make(map[common.Hash]*types.Block), blocks: make(map[common.Hash]*types.Block),
messages: make(map[fbftMsgID]*FBFTMessage), messages: make(map[fbftMsgID]*FBFTMessage),
verifiedBlocks: make(map[common.Hash]struct{}), verifiedBlocks: make(map[common.Hash]struct{}),
verifyNewBlock: verifyNewBlock,
} }
return &pbftLog return &pbftLog
} }
@ -130,6 +132,10 @@ func (log *FBFTLog) AddBlock(block *types.Block) {
log.blocks[block.Hash()] = block log.blocks[block.Hash()] = block
} }
func (log *FBFTLog) BlockVerify(block *types.Block) error {
return log.verifyNewBlock(block)
}
// MarkBlockVerified marks the block as verified // MarkBlockVerified marks the block as verified
func (log *FBFTLog) MarkBlockVerified(block *types.Block) { func (log *FBFTLog) MarkBlockVerified(block *types.Block) {
log.verifiedBlocks[block.Hash()] = struct{}{} log.verifiedBlocks[block.Hash()] = struct{}{}

@ -65,7 +65,7 @@ func TestGetMessagesByTypeSeqViewHash(t *testing.T) {
ViewID: 3, ViewID: 3,
BlockHash: [32]byte{01, 02}, BlockHash: [32]byte{01, 02},
} }
log := NewFBFTLog() log := NewFBFTLog(nil)
log.AddVerifiedMessage(&pbftMsg) log.AddVerifiedMessage(&pbftMsg)
found := log.GetMessagesByTypeSeqViewHash( found := log.GetMessagesByTypeSeqViewHash(
@ -90,7 +90,7 @@ func TestHasMatchingAnnounce(t *testing.T) {
ViewID: 3, ViewID: 3,
BlockHash: [32]byte{01, 02}, BlockHash: [32]byte{01, 02},
} }
log := NewFBFTLog() log := NewFBFTLog(nil)
log.AddVerifiedMessage(&pbftMsg) log.AddVerifiedMessage(&pbftMsg)
found := log.HasMatchingViewAnnounce(2, 3, [32]byte{01, 02}) found := log.HasMatchingViewAnnounce(2, 3, [32]byte{01, 02})
if !found { if !found {

@ -63,6 +63,11 @@ func (consensus *Consensus) onAnnounce(msg *msg_pb.Message) {
go func() { go func() {
// Best effort check, no need to error out. // Best effort check, no need to error out.
_, err := consensus.ValidateNewBlock(recvMsg) _, err := consensus.ValidateNewBlock(recvMsg)
if err != nil {
// maybe ban sender
consensus.getLogger().Error().
Err(err).Msgf("[Announce] Failed to validate block")
}
if err == nil { if err == nil {
consensus.GetLogger().Info(). consensus.GetLogger().Info().
Msg("[Announce] Block verified") Msg("[Announce] Block verified")
@ -76,6 +81,7 @@ func (consensus *Consensus) ValidateNewBlock(recvMsg *FBFTMessage) (*types.Block
defer consensus.mutex.Unlock() defer consensus.mutex.Unlock()
return consensus.validateNewBlock(recvMsg) return consensus.validateNewBlock(recvMsg)
} }
func (consensus *Consensus) validateNewBlock(recvMsg *FBFTMessage) (*types.Block, error) { func (consensus *Consensus) validateNewBlock(recvMsg *FBFTMessage) (*types.Block, error) {
if consensus.fBFTLog.IsBlockVerified(recvMsg.BlockHash) { if consensus.fBFTLog.IsBlockVerified(recvMsg.BlockHash) {
var blockObj *types.Block var blockObj *types.Block
@ -125,12 +131,7 @@ func (consensus *Consensus) validateNewBlock(recvMsg *FBFTMessage) (*types.Block
Hex("blockHash", recvMsg.BlockHash[:]). Hex("blockHash", recvMsg.BlockHash[:]).
Msg("[validateNewBlock] Prepared message and block added") Msg("[validateNewBlock] Prepared message and block added")
if consensus.BlockVerifier == nil { if err := consensus.FBFTLog.verifyBlock(&blockObj); err != nil {
consensus.getLogger().Debug().Msg("[validateNewBlock] consensus received message before init. Ignoring")
return nil, errors.New("nil block verifier")
}
if err := consensus.verifyBlock(&blockObj); err != nil {
consensus.getLogger().Error().Err(err).Msg("[validateNewBlock] Block verification failed") consensus.getLogger().Error().Err(err).Msg("[validateNewBlock] Block verification failed")
return nil, errors.Errorf("Block verification failed: %s", err.Error()) return nil, errors.Errorf("Block verification failed: %s", err.Error())
} }

@ -51,9 +51,11 @@ type viewChange struct {
} }
// newViewChange returns a new viewChange object // newViewChange returns a new viewChange object
func newViewChange() *viewChange { func newViewChange(verifyBlock VerifyBlockFunc) *viewChange {
vc := viewChange{} vc := viewChange{}
vc.Reset() vc.Reset()
vc.verifyBlock = verifyBlock
return &vc return &vc
} }

@ -45,7 +45,7 @@ func (consensus *Consensus) constructViewChangeMessage(priKey *bls.PrivateKeyWra
Interface("preparedMsg", preparedMsg). Interface("preparedMsg", preparedMsg).
Msg("[constructViewChangeMessage] found prepared msg") Msg("[constructViewChangeMessage] found prepared msg")
if block != nil { if block != nil {
if err := consensus.verifyBlock(block); err == nil { if err := consensus.FBFTLog.verifyBlock(block); err == nil {
tmpEncoded, err := rlp.EncodeToBytes(block) tmpEncoded, err := rlp.EncodeToBytes(block)
if err != nil { if err != nil {
consensus.getLogger().Err(err).Msg("[constructViewChangeMessage] Failed encoding block") consensus.getLogger().Err(err).Msg("[constructViewChangeMessage] Failed encoding block")

@ -0,0 +1,33 @@
package blockedpeers
import (
"github.com/harmony-one/harmony/internal/utils/lrucache"
libp2p_peer "github.com/libp2p/go-libp2p/core/peer"
"time"
)
type Manager struct {
internal *lrucache.Cache[libp2p_peer.ID, time.Time]
}
func NewManager(size int) *Manager {
return &Manager{
internal: lrucache.NewCache[libp2p_peer.ID, time.Time](size),
}
}
func (m *Manager) IsBanned(key libp2p_peer.ID, now time.Time) bool {
future, ok := m.internal.Get(key)
if ok {
return future.After(now) // future > now
}
return ok
}
func (m *Manager) Ban(key libp2p_peer.ID, future time.Time) {
m.internal.Set(key, future)
}
func (m *Manager) Contains(key libp2p_peer.ID) bool {
return m.internal.Contains(key)
}

@ -0,0 +1,26 @@
package blockedpeers
import (
libp2p_peer "github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"
"testing"
"time"
)
func TestNewManager(t *testing.T) {
var (
peer1 libp2p_peer.ID = "peer1"
now = time.Now()
m = NewManager(4)
)
t.Run("check_empty", func(t *testing.T) {
require.False(t, m.IsBanned(peer1, now), "peer1 should not be banned")
})
t.Run("ban_peer1", func(t *testing.T) {
m.Ban(peer1, now.Add(2*time.Second))
require.True(t, m.IsBanned(peer1, now), "peer1 should be banned")
require.False(t, m.IsBanned(peer1, now.Add(3*time.Second)), "peer1 should not be banned after 3 seconds")
})
}

@ -25,3 +25,9 @@ func (c *Cache[K, V]) Get(key K) (V, bool) {
func (c *Cache[K, V]) Set(key K, value V) { func (c *Cache[K, V]) Set(key K, value V) {
c.cache.Add(key, value) c.cache.Add(key, value)
} }
// Contains checks if a key is in the cache, without updating the
// recent-ness or deleting it for being stale.
func (c *Cache[K, V]) Contains(key K) bool {
return c.cache.Contains(key)
}

@ -34,9 +34,11 @@ func (timeout *Timeout) Start() {
} }
// Stop stops the timeout clock // Stop stops the timeout clock
func (timeout *Timeout) Stop() { func (timeout *Timeout) Stop() (stopped bool) {
stopped = timeout.state != Inactive
timeout.state = Inactive timeout.state = Inactive
timeout.start = time.Now() timeout.start = time.Now()
return stopped
} }
// Expired checks whether the timeout is reached/expired // Expired checks whether the timeout is reached/expired

@ -559,7 +559,7 @@ func (node *Node) validateNodeMessage(ctx context.Context, payload []byte) (
// validate shardID // validate shardID
// validate public key size // validate public key size
// verify message signature // verify message signature
func validateShardBoundMessage(consensus *consensus.Consensus, nodeConfig *nodeconfig.ConfigType, payload []byte, func validateShardBoundMessage(consensus *consensus.Consensus, peer libp2p_peer.ID, nodeConfig *nodeconfig.ConfigType, payload []byte,
) (*msg_pb.Message, *bls.SerializedPublicKey, bool, error) { ) (*msg_pb.Message, *bls.SerializedPublicKey, bool, error) {
var ( var (
m msg_pb.Message m msg_pb.Message
@ -740,6 +740,7 @@ func (node *Node) StartPubSub() error {
// p2p consensus message handler function // p2p consensus message handler function
type p2pHandlerConsensus func( type p2pHandlerConsensus func(
ctx context.Context, ctx context.Context,
peer libp2p_peer.ID,
msg *msg_pb.Message, msg *msg_pb.Message,
key *bls.SerializedPublicKey, key *bls.SerializedPublicKey,
) error ) error
@ -753,6 +754,7 @@ func (node *Node) StartPubSub() error {
// interface pass to p2p message validator // interface pass to p2p message validator
type validated struct { type validated struct {
peerID libp2p_peer.ID
consensusBound bool consensusBound bool
handleC p2pHandlerConsensus handleC p2pHandlerConsensus
handleCArg *msg_pb.Message handleCArg *msg_pb.Message
@ -810,7 +812,7 @@ func (node *Node) StartPubSub() error {
// validate consensus message // validate consensus message
validMsg, senderPubKey, ignore, err := validateShardBoundMessage( validMsg, senderPubKey, ignore, err := validateShardBoundMessage(
node.Consensus, node.NodeConfig, openBox[proto.MessageCategoryBytes:], node.Consensus, peer, node.NodeConfig, openBox[proto.MessageCategoryBytes:],
) )
if err != nil { if err != nil {
@ -824,6 +826,7 @@ func (node *Node) StartPubSub() error {
} }
msg.ValidatorData = validated{ msg.ValidatorData = validated{
peerID: peer,
consensusBound: true, consensusBound: true,
handleC: node.Consensus.HandleMessageUpdate, handleC: node.Consensus.HandleMessageUpdate,
handleCArg: validMsg, handleCArg: validMsg,
@ -854,6 +857,7 @@ func (node *Node) StartPubSub() error {
} }
} }
msg.ValidatorData = validated{ msg.ValidatorData = validated{
peerID: peer,
consensusBound: false, consensusBound: false,
handleE: node.HandleNodeMessage, handleE: node.HandleNodeMessage,
handleEArg: validMsg, handleEArg: validMsg,
@ -905,7 +909,7 @@ func (node *Node) StartPubSub() error {
errChan <- withError{err, nil} errChan <- withError{err, nil}
} }
} else { } else {
if err := msg.handleC(ctx, msg.handleCArg, msg.senderPubKey); err != nil { if err := msg.handleC(ctx, msg.peerID, msg.handleCArg, msg.senderPubKey); err != nil {
errChan <- withError{err, msg.senderPubKey} errChan <- withError{err, msg.senderPubKey}
} }
} }

@ -11,6 +11,13 @@ import (
"sync" "sync"
"time" "time"
"github.com/harmony-one/bls/ffi/go/bls"
nodeconfig "github.com/harmony-one/harmony/internal/configs/node"
"github.com/harmony-one/harmony/internal/utils"
"github.com/harmony-one/harmony/internal/utils/blockedpeers"
"github.com/harmony-one/harmony/p2p/discovery"
"github.com/harmony-one/harmony/p2p/security"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
"github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p"
dht "github.com/libp2p/go-libp2p-kad-dht" dht "github.com/libp2p/go-libp2p-kad-dht"
libp2p_pubsub "github.com/libp2p/go-libp2p-pubsub" libp2p_pubsub "github.com/libp2p/go-libp2p-pubsub"
@ -24,19 +31,11 @@ import (
"github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/routing" "github.com/libp2p/go-libp2p/core/routing"
"github.com/libp2p/go-libp2p/p2p/net/connmgr" "github.com/libp2p/go-libp2p/p2p/net/connmgr"
"github.com/libp2p/go-libp2p/p2p/security/noise" "github.com/libp2p/go-libp2p/p2p/security/noise"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/harmony-one/bls/ffi/go/bls"
nodeconfig "github.com/harmony-one/harmony/internal/configs/node"
"github.com/harmony-one/harmony/internal/utils"
"github.com/harmony-one/harmony/p2p/discovery"
"github.com/harmony-one/harmony/p2p/security"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
) )
type ConnectCallback func(net libp2p_network.Network, conn libp2p_network.Conn) error type ConnectCallback func(net libp2p_network.Network, conn libp2p_network.Conn) error
@ -254,7 +253,8 @@ func NewHost(cfg HostConfig) (Host, error) {
self.PeerID = p2pHost.ID() self.PeerID = p2pHost.ID()
subLogger := utils.Logger().With().Str("hostID", p2pHost.ID().Pretty()).Logger() subLogger := utils.Logger().With().Str("hostID", p2pHost.ID().Pretty()).Logger()
security := security.NewManager(cfg.MaxConnPerIP, int(cfg.MaxPeers)) banned := blockedpeers.NewManager(1024)
security := security.NewManager(cfg.MaxConnPerIP, int(cfg.MaxPeers, banned))
// has to save the private key for host // has to save the private key for host
h := &HostV2{ h := &HostV2{
h: p2pHost, h: p2pHost,
@ -269,6 +269,7 @@ func NewHost(cfg HostConfig) (Host, error) {
logger: &subLogger, logger: &subLogger,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
banned: banned,
} }
utils.Logger().Info(). utils.Logger().Info().
@ -323,6 +324,7 @@ type HostV2 struct {
onDisconnects DisconnectCallbacks onDisconnects DisconnectCallbacks
ctx context.Context ctx context.Context
cancel func() cancel func()
banned *blockedpeers.Manager
} }
// PubSub .. // PubSub ..

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/harmony-one/harmony/internal/utils" "github.com/harmony-one/harmony/internal/utils/blockedpeers"
libp2p_network "github.com/libp2p/go-libp2p/core/network" libp2p_network "github.com/libp2p/go-libp2p/core/network"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -15,14 +15,6 @@ type Security interface {
OnDisconnectCheck(conn libp2p_network.Conn) error OnDisconnectCheck(conn libp2p_network.Conn) error
} }
type Manager struct {
maxConnPerIP int
maxPeers int
mutex sync.Mutex
peers *peerMap // All the connected nodes, key is the Peer's IP, value is the peer's ID array
}
type peerMap struct { type peerMap struct {
peers map[string][]string peers map[string][]string
} }
@ -63,7 +55,16 @@ func (peerMap *peerMap) Range(f func(key string, value []string) bool) {
} }
} }
func NewManager(maxConnPerIP int, maxPeers int) *Manager { type Manager struct {
maxConnPerIP int
maxPeers int64
mutex sync.Mutex
peers peerMap // All the connected nodes, key is the Peer's IP, value is the peer's ID array
banned *blockedpeers.Manager
}
func NewManager(maxConnPerIP int, maxPeers int64, banned *blockedpeers.Manager) *Manager {
if maxConnPerIP < 0 { if maxConnPerIP < 0 {
panic("maximum connections per IP must not be negative") panic("maximum connections per IP must not be negative")
} }
@ -74,6 +75,7 @@ func NewManager(maxConnPerIP int, maxPeers int) *Manager {
maxConnPerIP: maxConnPerIP, maxConnPerIP: maxConnPerIP,
maxPeers: maxPeers, maxPeers: maxPeers,
peers: newPeersMap(), peers: newPeersMap(),
banned: banned,
} }
} }
@ -118,6 +120,13 @@ func (m *Manager) OnConnectCheck(net libp2p_network.Network, conn libp2p_network
Msg("too many peers, closing") Msg("too many peers, closing")
return net.ClosePeer(conn.RemotePeer()) return net.ClosePeer(conn.RemotePeer())
} }
if m.banned.IsBanned(conn.RemotePeer(), time.Now()) {
utils.Logger().Warn().
Str("new peer", remoteIp).
Msg("peer is banned, closing")
return net.ClosePeer(conn.RemotePeer())
}
m.peers.Store(remoteIp, peers) m.peers.Store(remoteIp, peers)
return nil return nil
} }

@ -3,6 +3,7 @@ package security
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/harmony-one/harmony/internal/utils/blockedpeers"
"testing" "testing"
"time" "time"
@ -58,7 +59,7 @@ func TestManager_OnConnectCheck(t *testing.T) {
defer h1.Close() defer h1.Close()
fakeHost := &fakeHost{} fakeHost := &fakeHost{}
security := NewManager(2, 1) security := NewManager(2, 1, blockedpeers.NewManager(4))
h1.Network().Notify(fakeHost) h1.Network().Notify(fakeHost)
fakeHost.SetConnectCallback(security.OnConnectCheck) fakeHost.SetConnectCallback(security.OnConnectCheck)
fakeHost.SetDisconnectCallback(security.OnDisconnectCheck) fakeHost.SetDisconnectCallback(security.OnDisconnectCheck)
@ -100,7 +101,7 @@ func TestManager_OnDisconnectCheck(t *testing.T) {
defer h1.Close() defer h1.Close()
fakeHost := &fakeHost{} fakeHost := &fakeHost{}
security := NewManager(2, 0) security := NewManager(2, 0, blockedpeers.NewManager(4))
h1.Network().Notify(fakeHost) h1.Network().Notify(fakeHost)
fakeHost.SetConnectCallback(security.OnConnectCheck) fakeHost.SetConnectCallback(security.OnConnectCheck)
fakeHost.SetDisconnectCallback(security.OnDisconnectCheck) fakeHost.SetDisconnectCallback(security.OnDisconnectCheck)

Loading…
Cancel
Save