package security import ( "fmt" "sync" "github.com/harmony-one/harmony/internal/utils" libp2p_network "github.com/libp2p/go-libp2p/core/network" ma "github.com/multiformats/go-multiaddr" "github.com/pkg/errors" ) type Security interface { OnConnectCheck(net libp2p_network.Network, conn libp2p_network.Conn) error OnDisconnectCheck(conn libp2p_network.Conn) error } type Manager struct { maxConnPerIP int maxPeers int mutex sync.Mutex peers *peerMap // All the connected nodes, key is the Peer's IP, value is the peer's ID array } type peerMap struct { peers map[string][]string } func newPeersMap() *peerMap { return &peerMap{ peers: make(map[string][]string), } } func (peerMap *peerMap) Len() int { return len(peerMap.peers) } func (peerMap *peerMap) Store(key string, value []string) { peerMap.peers[key] = value } func (peerMap *peerMap) HasKey(key string) bool { _, ok := peerMap.peers[key] return ok } func (peerMap *peerMap) Delete(key string) { delete(peerMap.peers, key) } func (peerMap *peerMap) Load(key string) (value []string, ok bool) { value, ok = peerMap.peers[key] return value, ok } func (peerMap *peerMap) Range(f func(key string, value []string) bool) { for key, value := range peerMap.peers { if !f(key, value) { break } } } func NewManager(maxConnPerIP int, maxPeers int) *Manager { if maxConnPerIP < 0 { panic("maximum connections per IP must not be negative") } if maxPeers < 0 { panic("maximum peers must not be negative") } return &Manager{ maxConnPerIP: maxConnPerIP, maxPeers: maxPeers, peers: newPeersMap(), } } func (m *Manager) RangePeers(f func(key string, value []string) bool) { m.mutex.Lock() defer m.mutex.Unlock() m.peers.Range(f) } func (m *Manager) OnConnectCheck(net libp2p_network.Network, conn libp2p_network.Conn) error { m.mutex.Lock() defer m.mutex.Unlock() remoteIp, err := getRemoteIP(conn) if err != nil { return errors.Wrap(err, "failed on get remote ip") } peers, _ := m.peers.Load(remoteIp) // avoid add repeatedly peerID := conn.RemotePeer().String() _, ok := find(peers, peerID) if !ok { peers = append(peers, peerID) } if m.maxConnPerIP > 0 && len(peers) > m.maxConnPerIP { utils.Logger().Warn(). Int("len(peers)", len(peers)). Int("maxConnPerIP", m.maxConnPerIP). Msgf("too many connections from %s, closing", remoteIp) return net.ClosePeer(conn.RemotePeer()) } currentPeerCount := m.peers.Len() // only limit addition if it's a new peer and not an existing peer with new connection if m.maxPeers > 0 && currentPeerCount >= m.maxPeers && !m.peers.HasKey(remoteIp) { utils.Logger().Warn(). Int("connected peers", currentPeerCount). Str("new peer", remoteIp). Msg("too many peers, closing") return net.ClosePeer(conn.RemotePeer()) } m.peers.Store(remoteIp, peers) return nil } func (m *Manager) OnDisconnectCheck(conn libp2p_network.Conn) error { m.mutex.Lock() defer m.mutex.Unlock() ip, err := getRemoteIP(conn) if err != nil { return errors.Wrap(err, "failed on get ip") } peers, ok := m.peers.Load(ip) if !ok { return nil } peerID := conn.RemotePeer().String() index, ok := find(peers, peerID) if ok { peers = append(peers[:index], peers[index+1:]...) if len(peers) == 0 { m.peers.Delete(ip) } else { m.peers.Store(ip, peers) } } return nil } func find(slice []string, val string) (int, bool) { for i, item := range slice { if item == val { return i, true } } return -1, false } func getRemoteIP(conn libp2p_network.Conn) (string, error) { for _, protocol := range conn.RemoteMultiaddr().Protocols() { switch protocol.Code { case ma.P_IP4: ip, err := conn.RemoteMultiaddr().ValueForProtocol(ma.P_IP4) if err != nil { return "", errors.Wrap(err, "failed on get ipv4 addr") } return ip, nil case ma.P_IP6: ip, err := conn.RemoteMultiaddr().ValueForProtocol(ma.P_IP6) if err != nil { return "", errors.Wrap(err, "failed on get ipv6 addr") } return ip, nil } } return "", errors.New(fmt.Sprintf("failed on get remote peer ip from addr: %s", conn.RemoteMultiaddr().String())) }