The core protocol of WoopChain
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
woop/p2p/security/security.go

193 lines
4.3 KiB

package security
import (
"fmt"
"sync"
"sync/atomic"
"github.com/harmony-one/harmony/internal/utils"
libp2p_network "github.com/libp2p/go-libp2p/core/network"
ma "github.com/multiformats/go-multiaddr"
"github.com/pkg/errors"
)
// Security is a pair of hooks for admitting connections and for keeping
// bookkeeping consistent when connections go away. Manager in this package
// implements it.
type Security interface {
	// OnConnectCheck is handed a connection on net. It returns an error
	// only when the check itself could not be carried out.
	OnConnectCheck(net libp2p_network.Network, conn libp2p_network.Conn) error

	// OnDisconnectCheck is handed a connection that has gone away so the
	// implementation can update its state.
	OnDisconnectCheck(conn libp2p_network.Conn) error
}
// Manager limits inbound connectivity per remote IP address. It caps how
// many distinct peer IDs may be recorded for one IP (maxConnPerIP) and how
// many distinct IPs may be tracked at once (maxPeers). A limit of 0 disables
// the corresponding check.
type Manager struct {
	maxConnPerIP int   // max distinct peer IDs per IP; 0 = unlimited
	maxPeers     int64 // max distinct remote IPs tracked; 0 = unlimited

	// mutex serializes OnConnectCheck and OnDisconnectCheck so their
	// load-modify-store sequences on peers do not interleave.
	mutex sync.Mutex

	// peers maps a remote IP (string) to the peer IDs ([]string) recorded
	// for that IP.
	peers peerMap
}
// peerMap is a sync.Map that also tracks how many keys it holds, so Len is
// O(1) instead of requiring a full Range.
//
// The count stays exact as long as Store/Delete calls on the same key do not
// race each other; Manager guarantees this by holding its mutex around every
// mutation.
//
// NOTE(review): count is updated with the sync/atomic 64-bit functions. On
// 32-bit platforms those require 64-bit alignment, which is only guaranteed
// for the first word of an allocated struct; peerMap is embedded inside
// Manager after other fields. Consider atomic.Int64 if the module targets
// Go >= 1.19.
type peerMap struct {
	count int64
	peers sync.Map
}

// Len returns the number of keys currently stored.
func (pm *peerMap) Len() int64 {
	return atomic.LoadInt64(&pm.count)
}

// Store sets the value for key. The key count is incremented only when key
// was not already present.
func (pm *peerMap) Store(key, value any) {
	if _, loaded := pm.peers.LoadOrStore(key, value); loaded {
		// Existing key: replace the value, count unchanged.
		pm.peers.Store(key, value)
		return
	}
	atomic.AddInt64(&pm.count, 1)
}

// HasKey reports whether key is present, using a direct lookup rather than
// scanning every entry.
func (pm *peerMap) HasKey(key any) bool {
	_, ok := pm.peers.Load(key)
	return ok
}

// Delete removes key. The count is decremented only if key was actually
// present, so deleting a missing key cannot push Len out of sync.
func (pm *peerMap) Delete(key any) {
	if _, loaded := pm.peers.LoadAndDelete(key); loaded {
		atomic.AddInt64(&pm.count, -1)
	}
}

// Load returns the value stored for key and whether it was present.
func (pm *peerMap) Load(key any) (value any, ok bool) {
	return pm.peers.Load(key)
}

// Range calls f for each key/value pair, stopping early if f returns false.
// It has the same consistency guarantees as sync.Map.Range.
func (pm *peerMap) Range(f func(key, value any) bool) {
	pm.peers.Range(f)
}
// NewManager returns a Manager that allows at most maxConnPerIP distinct
// peer IDs per remote IP and at most maxPeers distinct remote IPs. Passing 0
// for either disables that limit. A negative value is treated as a
// programming error and panics.
func NewManager(maxConnPerIP int, maxPeers int64) *Manager {
	switch {
	case maxConnPerIP < 0:
		panic("maximum connections per IP must not be negative")
	case maxPeers < 0:
		panic("maximum peers must not be negative")
	}

	m := new(Manager)
	m.maxConnPerIP = maxConnPerIP
	m.maxPeers = maxPeers
	return m
}
// OnConnectCheck records conn's remote peer ID under its remote IP and
// enforces the Manager's limits. It closes the remote peer via net.ClosePeer,
// returns ClosePeer's result, and records nothing when either:
//   - maxConnPerIP > 0 and recording the peer would put more than
//     maxConnPerIP distinct peer IDs on this IP, or
//   - maxPeers > 0, the IP is not yet tracked, and maxPeers IPs already are.
//
// It returns an error if the remote IP cannot be determined or the value
// stored for the IP is not a []string.
func (m *Manager) OnConnectCheck(net libp2p_network.Network, conn libp2p_network.Conn) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	remoteIP, err := getRemoteIP(conn)
	if err != nil {
		return errors.Wrap(err, "failed on get remote ip")
	}

	// ipKnown is reused for the maxPeers check below instead of a second
	// lookup; the mutex keeps the map unchanged in between.
	var peers []string
	value, ipKnown := m.peers.Load(remoteIP)
	if ipKnown {
		var ok bool
		if peers, ok = value.([]string); !ok {
			return errors.New("peers info type err")
		}
	}

	// Avoid adding the same peer ID twice. The full slice expression caps
	// capacity so append always copies: the slice currently stored in the
	// map is never written through, even if this connection is rejected.
	peerID := conn.RemotePeer().String()
	if _, found := find(peers, peerID); !found {
		peers = append(peers[:len(peers):len(peers)], peerID)
	}

	if m.maxConnPerIP > 0 && len(peers) > m.maxConnPerIP {
		utils.Logger().Warn().
			Int("len(peers)", len(peers)).
			Int("maxConnPerIP", m.maxConnPerIP).
			Msgf("too many connections from %s, closing", remoteIP)
		return net.ClosePeer(conn.RemotePeer())
	}

	// Only limit a new IP; an already-tracked IP opening another connection
	// is let through.
	currentPeerCount := m.peers.Len()
	if m.maxPeers > 0 && currentPeerCount >= m.maxPeers && !ipKnown {
		utils.Logger().Warn().
			Int64("connected peers", currentPeerCount).
			Str("new peer", remoteIP).
			Msg("too many peers, closing")
		return net.ClosePeer(conn.RemotePeer())
	}

	m.peers.Store(remoteIP, peers)
	return nil
}
// OnDisconnectCheck removes conn's remote peer ID from the list recorded for
// its remote IP. It drops the IP entry entirely once no peer IDs remain.
// Unknown IPs and unknown peer IDs are ignored. It returns an error if the
// remote IP cannot be determined or the stored value is not a []string.
//
// NOTE(review): OnConnectCheck records each peer ID at most once per IP. If
// this hook fires once per closed connection, closing any one of several
// connections from the same peer/IP removes the peer ID. Confirm whether
// callers account for multiple connections per peer.
func (m *Manager) OnDisconnectCheck(conn libp2p_network.Conn) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	ip, err := getRemoteIP(conn)
	if err != nil {
		return errors.Wrap(err, "failed on get ip")
	}
	value, ok := m.peers.Load(ip)
	if !ok {
		return nil
	}
	peers, ok := value.([]string)
	if !ok {
		return errors.New("peers info type err")
	}

	peerID := conn.RemotePeer().String()
	index, found := find(peers, peerID)
	if !found {
		return nil
	}
	if len(peers) == 1 {
		m.peers.Delete(ip)
		return nil
	}

	// Build a fresh slice instead of shifting elements in place. The old
	// slice is still the value held by the map (and possibly by readers that
	// obtained it via Load/Range), so mutating it would expose a
	// half-updated list.
	remaining := make([]string, 0, len(peers)-1)
	remaining = append(remaining, peers[:index]...)
	remaining = append(remaining, peers[index+1:]...)
	m.peers.Store(ip, remaining)
	return nil
}
// find reports the index of the first element of slice equal to val and
// true, or -1 and false when val does not occur (including a nil slice).
func find(slice []string, val string) (int, bool) {
	idx := -1
	for i := 0; i < len(slice) && idx < 0; i++ {
		if slice[i] == val {
			idx = i
		}
	}
	return idx, idx >= 0
}
func getRemoteIP(conn libp2p_network.Conn) (string, error) {
for _, protocol := range conn.RemoteMultiaddr().Protocols() {
switch protocol.Code {
case ma.P_IP4:
ip, err := conn.RemoteMultiaddr().ValueForProtocol(ma.P_IP4)
if err != nil {
return "", errors.Wrap(err, "failed on get ipv4 addr")
}
return ip, nil
case ma.P_IP6:
ip, err := conn.RemoteMultiaddr().ValueForProtocol(ma.P_IP6)
if err != nil {
return "", errors.Wrap(err, "failed on get ipv6 addr")
}
return ip, nil
}
}
return "", errors.New(fmt.Sprintf("failed on get remote peer ip from addr: %s", conn.RemoteMultiaddr().String()))
}