|
|
|
package security
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"sync"
|
|
|
|
"sync/atomic"
|
|
|
|
|
|
|
|
"github.com/harmony-one/harmony/internal/utils"
|
|
|
|
libp2p_network "github.com/libp2p/go-libp2p/core/network"
|
|
|
|
ma "github.com/multiformats/go-multiaddr"
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
)
|
|
|
|
|
|
|
|
type Security interface {
|
|
|
|
OnConnectCheck(net libp2p_network.Network, conn libp2p_network.Conn) error
|
|
|
|
OnDisconnectCheck(conn libp2p_network.Conn) error
|
|
|
|
}
|
|
|
|
|
|
|
|
// Manager is a Security implementation that limits connections both per
// remote IP and by the total number of distinct remote IPs it tracks.
// Construct it with NewManager.
type Manager struct {
	// maxConnPerIP caps how many distinct peer IDs may be recorded for a
	// single remote IP; 0 disables the cap (NewManager rejects negatives).
	maxConnPerIP int

	// maxPeers caps how many distinct remote IPs (keys of peers) may be
	// tracked at once; 0 disables the cap (NewManager rejects negatives).
	maxPeers int64

	// mutex serializes OnConnectCheck and OnDisconnectCheck so their
	// load-modify-store sequences on peers do not interleave.
	mutex sync.Mutex

	peers peerMap // All the connected nodes, key is the Peer's IP, value is the peer's ID array
}
|
|
|
|
|
|
|
|
// peerMap is a sync.Map wrapper that also keeps a count of the distinct keys
// it holds, so the size can be read without iterating. Manager stores remote
// IPs as keys and []string peer-ID lists as values; keys and values are
// untyped, so callers type-assert what they Load.
//
// Each method is individually safe for concurrent use. A Store racing with a
// Delete of the same key can still leave count off by one, so callers that mix
// them concurrently should serialize them (Manager does so under its mutex).
type peerMap struct {
	count int64    // number of distinct keys currently stored; accessed atomically
	peers sync.Map // underlying key/value storage
}

// Len returns the number of distinct keys currently stored.
func (peerMap *peerMap) Len() int64 {
	return atomic.LoadInt64(&peerMap.count)
}

// Store sets the value for key. The key count grows only when key was not
// already present; overwriting an existing key leaves Len unchanged.
func (peerMap *peerMap) Store(key, value interface{}) {
	// LoadOrStore inserts and reports prior presence in one O(1) step, so a
	// newly inserted key is counted exactly once.
	if _, loaded := peerMap.peers.LoadOrStore(key, value); loaded {
		// Key already existed: LoadOrStore kept the old value, so overwrite it.
		peerMap.peers.Store(key, value)
		return
	}
	atomic.AddInt64(&peerMap.count, 1)
}

// HasKey reports whether key is currently stored.
func (peerMap *peerMap) HasKey(key interface{}) bool {
	// A direct lookup; iterating the whole map with Range is O(n).
	_, ok := peerMap.peers.Load(key)
	return ok
}

// Delete removes key. The key count shrinks only if key was actually present,
// so deleting a missing key is a no-op and Len never drifts below the true size.
func (peerMap *peerMap) Delete(key interface{}) {
	if _, loaded := peerMap.peers.LoadAndDelete(key); loaded {
		atomic.AddInt64(&peerMap.count, -1)
	}
}

// Load returns the value stored for key and whether it was present.
func (peerMap *peerMap) Load(key interface{}) (value interface{}, ok bool) {
	return peerMap.peers.Load(key)
}

// Range calls f for each key/value pair, stopping early if f returns false.
// It has the same consistency guarantees as sync.Map.Range.
func (peerMap *peerMap) Range(f func(key, value any) bool) {
	peerMap.peers.Range(f)
}
|
|
|
|
|
|
|
|
func NewManager(maxConnPerIP int, maxPeers int64) *Manager {
|
|
|
|
if maxConnPerIP < 0 {
|
|
|
|
panic("maximum connections per IP must not be negative")
|
|
|
|
}
|
|
|
|
if maxPeers < 0 {
|
|
|
|
panic("maximum peers must not be negative")
|
|
|
|
}
|
|
|
|
return &Manager{
|
|
|
|
maxConnPerIP: maxConnPerIP,
|
|
|
|
maxPeers: maxPeers,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Manager) OnConnectCheck(net libp2p_network.Network, conn libp2p_network.Conn) error {
|
|
|
|
m.mutex.Lock()
|
|
|
|
defer m.mutex.Unlock()
|
|
|
|
|
|
|
|
remoteIp, err := getRemoteIP(conn)
|
|
|
|
if err != nil {
|
|
|
|
return errors.Wrap(err, "failed on get remote ip")
|
|
|
|
}
|
|
|
|
|
|
|
|
value, ok := m.peers.Load(remoteIp)
|
|
|
|
if !ok {
|
|
|
|
value = []string{}
|
|
|
|
}
|
|
|
|
|
|
|
|
peers, ok := value.([]string)
|
|
|
|
if !ok {
|
|
|
|
return errors.New("peers info type err")
|
|
|
|
}
|
|
|
|
|
|
|
|
// avoid add repeatedly
|
|
|
|
peerID := conn.RemotePeer().String()
|
|
|
|
_, ok = find(peers, peerID)
|
|
|
|
if !ok {
|
|
|
|
peers = append(peers, peerID)
|
|
|
|
}
|
|
|
|
|
|
|
|
if m.maxConnPerIP > 0 && len(peers) > m.maxConnPerIP {
|
|
|
|
utils.Logger().Warn().
|
|
|
|
Int("len(peers)", len(peers)).
|
|
|
|
Int("maxConnPerIP", m.maxConnPerIP).
|
|
|
|
Msgf("too many connections from %s, closing", remoteIp)
|
|
|
|
return net.ClosePeer(conn.RemotePeer())
|
|
|
|
}
|
|
|
|
|
|
|
|
currentPeerCount := m.peers.Len()
|
|
|
|
// only limit addition if it's a new peer and not an existing peer with new connection
|
|
|
|
if m.maxPeers > 0 && currentPeerCount >= m.maxPeers && !m.peers.HasKey(remoteIp) {
|
|
|
|
utils.Logger().Warn().
|
|
|
|
Int64("connected peers", currentPeerCount).
|
|
|
|
Str("new peer", remoteIp).
|
|
|
|
Msg("too many peers, closing")
|
|
|
|
return net.ClosePeer(conn.RemotePeer())
|
|
|
|
}
|
|
|
|
m.peers.Store(remoteIp, peers)
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Manager) OnDisconnectCheck(conn libp2p_network.Conn) error {
|
|
|
|
m.mutex.Lock()
|
|
|
|
defer m.mutex.Unlock()
|
|
|
|
|
|
|
|
ip, err := getRemoteIP(conn)
|
|
|
|
if err != nil {
|
|
|
|
return errors.Wrap(err, "failed on get ip")
|
|
|
|
}
|
|
|
|
|
|
|
|
value, ok := m.peers.Load(ip)
|
|
|
|
if !ok {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
peers, ok := value.([]string)
|
|
|
|
if !ok {
|
|
|
|
return errors.New("peers info type err")
|
|
|
|
}
|
|
|
|
|
|
|
|
peerID := conn.RemotePeer().String()
|
|
|
|
index, ok := find(peers, peerID)
|
|
|
|
if ok {
|
|
|
|
peers = append(peers[:index], peers[index+1:]...)
|
|
|
|
if len(peers) == 0 {
|
|
|
|
m.peers.Delete(ip)
|
|
|
|
} else {
|
|
|
|
m.peers.Store(ip, peers)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// find returns the index of the first element of slice equal to val and
// true. When val does not occur (including for an empty or nil slice) it
// returns -1 and false.
func find(slice []string, val string) (int, bool) {
	pos := -1
	for j := 0; j < len(slice) && pos < 0; j++ {
		if slice[j] == val {
			pos = j
		}
	}
	return pos, pos >= 0
}
|
|
|
|
|
|
|
|
func getRemoteIP(conn libp2p_network.Conn) (string, error) {
|
|
|
|
for _, protocol := range conn.RemoteMultiaddr().Protocols() {
|
|
|
|
switch protocol.Code {
|
|
|
|
case ma.P_IP4:
|
|
|
|
ip, err := conn.RemoteMultiaddr().ValueForProtocol(ma.P_IP4)
|
|
|
|
if err != nil {
|
|
|
|
return "", errors.Wrap(err, "failed on get ipv4 addr")
|
|
|
|
}
|
|
|
|
return ip, nil
|
|
|
|
case ma.P_IP6:
|
|
|
|
ip, err := conn.RemoteMultiaddr().ValueForProtocol(ma.P_IP6)
|
|
|
|
if err != nil {
|
|
|
|
return "", errors.Wrap(err, "failed on get ipv6 addr")
|
|
|
|
}
|
|
|
|
return ip, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return "", errors.New(fmt.Sprintf("failed on get remote peer ip from addr: %s", conn.RemoteMultiaddr().String()))
|
|
|
|
}
|