You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
127 lines
2.6 KiB
127 lines
2.6 KiB
package security
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
|
|
libp2p_network "github.com/libp2p/go-libp2p-core/network"
|
|
ma "github.com/multiformats/go-multiaddr"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
type Security interface {
|
|
OnConnectCheck(net libp2p_network.Network, conn libp2p_network.Conn) error
|
|
OnDisconnectCheck(conn libp2p_network.Conn) error
|
|
}
|
|
|
|
// Manager tracks which peer IDs are connected from each remote IP and uses
// that record to limit how many peers a single IP may hold.
type Manager struct {
	// maxConnPerIP is the largest number of distinct peer IDs tolerated for one
	// IP; OnConnectCheck closes the connecting peer once the count exceeds it.
	maxConnPerIP int

	// mutex is held for the whole of OnConnectCheck and OnDisconnectCheck so
	// the load/modify/store sequences on peers do not interleave.
	mutex sync.Mutex

	// peers maps a remote IP (string) to the []string of peer IDs currently
	// recorded for that IP.
	peers sync.Map
}
|
|
|
|
func NewManager(maxConnPerIP int) *Manager {
|
|
return &Manager{
|
|
maxConnPerIP: maxConnPerIP,
|
|
}
|
|
}
|
|
|
|
func (m *Manager) OnConnectCheck(net libp2p_network.Network, conn libp2p_network.Conn) error {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
ip, err := getIP(conn)
|
|
if err != nil {
|
|
return errors.Wrap(err, "failed on get ip")
|
|
}
|
|
|
|
value, ok := m.peers.Load(ip)
|
|
if !ok {
|
|
value = []string{}
|
|
}
|
|
|
|
peers, ok := value.([]string)
|
|
if !ok {
|
|
return errors.New("peers info type err")
|
|
}
|
|
|
|
// avoid add repeatedly
|
|
peerID := conn.RemotePeer().String()
|
|
_, ok = find(peers, peerID)
|
|
if !ok {
|
|
peers = append(peers, peerID)
|
|
m.peers.Store(ip, peers)
|
|
}
|
|
|
|
if len(peers) > m.maxConnPerIP {
|
|
if err := net.ClosePeer(conn.RemotePeer()); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) OnDisconnectCheck(conn libp2p_network.Conn) error {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
ip, err := getIP(conn)
|
|
if err != nil {
|
|
return errors.Wrap(err, "failed on get ip")
|
|
}
|
|
|
|
value, ok := m.peers.Load(ip)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
peers, ok := value.([]string)
|
|
if !ok {
|
|
return errors.New("peers info type err")
|
|
}
|
|
|
|
peerID := conn.RemotePeer().String()
|
|
index, ok := find(peers, peerID)
|
|
if ok {
|
|
peers = append(peers[:index], peers[index+1:]...)
|
|
m.peers.Store(ip, peers)
|
|
if len(peers) == 0 {
|
|
m.peers.Delete(ip)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// find reports the position of the first element of slice equal to val.
// It returns (-1, false) when no element matches.
func find(slice []string, val string) (int, bool) {
	for idx := 0; idx < len(slice); idx++ {
		if slice[idx] == val {
			return idx, true
		}
	}

	return -1, false
}
|
|
|
|
func getIP(conn libp2p_network.Conn) (string, error) {
|
|
for _, protocol := range conn.RemoteMultiaddr().Protocols() {
|
|
switch protocol.Code {
|
|
case ma.P_IP4:
|
|
ip, err := conn.RemoteMultiaddr().ValueForProtocol(ma.P_IP4)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, "failed on get ipv4 addr")
|
|
}
|
|
return ip, nil
|
|
case ma.P_IP6:
|
|
ip, err := conn.RemoteMultiaddr().ValueForProtocol(ma.P_IP6)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, "failed on get ipv6 addr")
|
|
}
|
|
return ip, nil
|
|
}
|
|
}
|
|
|
|
return "", errors.New(fmt.Sprintf("failed on get remote peer ip from addr: %s", conn.RemoteMultiaddr().String()))
|
|
}
|
|
|