diff --git a/api/service/syncing/syncing.go b/api/service/syncing/syncing.go index 3fca87d9e..9e5e8e9bb 100644 --- a/api/service/syncing/syncing.go +++ b/api/service/syncing/syncing.go @@ -62,6 +62,11 @@ func (peerConfig *SyncPeerConfig) GetClient() *downloader.Client { return peerConfig.client } +// IsEqual checks the equality between two sync peers +func (peerConfig *SyncPeerConfig) IsEqual(pc2 *SyncPeerConfig) bool { + return peerConfig.ip == pc2.ip && peerConfig.port == pc2.port +} + // SyncBlockTask is the task struct to sync a specific block. type SyncBlockTask struct { index int @@ -110,6 +115,13 @@ type SyncConfig struct { func (sc *SyncConfig) AddPeer(peer *SyncPeerConfig) { sc.mtx.Lock() defer sc.mtx.Unlock() + + // Ensure no duplicate peers + for _, p2 := range sc.peers { + if peer.IsEqual(p2) { + return + } + } sc.peers = append(sc.peers, peer) } @@ -279,6 +291,10 @@ func (peerConfig *SyncPeerConfig) GetBlocks(hashes [][]byte) ([][]byte, error) { // CreateSyncConfig creates SyncConfig for StateSync object. func (ss *StateSync) CreateSyncConfig(peers []p2p.Peer, isBeacon bool) error { + // sanity check to ensure no duplicate peers + if err := checkPeersDuplicity(peers); err != nil { + return err + } // limit the number of dns peers to connect randSeed := time.Now().UnixNano() peers = limitNumPeers(peers, randSeed) @@ -322,6 +338,23 @@ func (ss *StateSync) CreateSyncConfig(peers []p2p.Peer, isBeacon bool) error { return nil } +// checkPeersDuplicity checks whether there are duplicates in p2p.Peer +func checkPeersDuplicity(ps []p2p.Peer) error { + type peerDupID struct { + ip string + port string + } + m := make(map[peerDupID]struct{}) + for _, p := range ps { + dip := peerDupID{p.IP, p.Port} + if _, ok := m[dip]; ok { + return fmt.Errorf("duplicate peer [%v:%v]", p.IP, p.Port) + } + m[dip] = struct{}{} + } + return nil +} + // limitNumPeers limits number of peers to release some server end sources. func limitNumPeers(ps []p2p.Peer, randSeed int64) []p2p.Peer { targetSize := calcNumPeersWithBound(len(ps), NumPeersLowBound, numPeersHighBound) diff --git a/api/service/syncing/syncing_test.go b/api/service/syncing/syncing_test.go index bd98ec238..b157a45c1 100644 --- a/api/service/syncing/syncing_test.go +++ b/api/service/syncing/syncing_test.go @@ -1,8 +1,10 @@ package syncing import ( + "errors" "fmt" "reflect" + "strings" "testing" "github.com/harmony-one/harmony/api/service/syncing/downloader" @@ -10,6 +12,53 @@ import ( "github.com/stretchr/testify/assert" ) +func TestSyncPeerConfig_IsEqual(t *testing.T) { + tests := []struct { + p1, p2 *SyncPeerConfig + exp bool + }{ + { + p1: &SyncPeerConfig{ + ip: "0.0.0.1", + port: "1", + }, + p2: &SyncPeerConfig{ + ip: "0.0.0.1", + port: "2", + }, + exp: false, + }, + { + p1: &SyncPeerConfig{ + ip: "0.0.0.1", + port: "1", + }, + p2: &SyncPeerConfig{ + ip: "0.0.0.2", + port: "1", + }, + exp: false, + }, + { + p1: &SyncPeerConfig{ + ip: "0.0.0.1", + port: "1", + }, + p2: &SyncPeerConfig{ + ip: "0.0.0.1", + port: "1", + }, + exp: true, + }, + } + for i, test := range tests { + res := test.p1.IsEqual(test.p2) + if res != test.exp { + t.Errorf("Test %v: unexpected res %v / %v", i, res, test.exp) + } + } +} + // Simple test for IncorrectResponse func TestCreateTestSyncPeerConfig(t *testing.T) { client := &downloader.Client{} @@ -53,6 +102,32 @@ func TestCreateStateSync(t *testing.T) { } } +func TestCheckPeersDuplicity(t *testing.T) { + tests := []struct { + peers []p2p.Peer + expErr error + }{ + { + peers: makePeersForTest(100), + expErr: nil, + }, + { + peers: append(makePeersForTest(100), p2p.Peer{ + IP: makeTestPeerIP(0), + }), + expErr: errors.New("duplicate peer"), + }, + } + + for i, test := range tests { + err := checkPeersDuplicity(test.peers) + + if assErr := assertTestError(err, test.expErr); assErr != nil { + t.Errorf("Test %v: %v", i, assErr) + } + } +} + func TestLimitPeersWithBound(t *testing.T) { tests := []struct { size int @@ -76,7 +151,7 @@ func TestLimitPeersWithBound(t *testing.T) { if len(res) != test.expSize { t.Errorf("result size unexpected: %v / %v", len(res), test.expSize) } - if err := checkTestPeerDuplicity(res); err != nil { + if err := checkPeersDuplicity(res); err != nil { t.Error(err) } } @@ -97,24 +172,30 @@ func TestLimitPeersWithBound_random(t *testing.T) { func makePeersForTest(size int) []p2p.Peer { ps := make([]p2p.Peer, 0, size) for i := 0; i != size; i++ { - ps = append(ps, p2p.Peer{ - IP: makeTestPeerIP(i), - }) + ps = append(ps, makePeerForTest(i)) } return ps } -func checkTestPeerDuplicity(ps []p2p.Peer) error { - m := make(map[string]struct{}) - for _, p := range ps { - if _, ok := m[p.IP]; ok { - return fmt.Errorf("duplicate ip") - } - m[p.IP] = struct{}{} +func makePeerForTest(i interface{}) p2p.Peer { + return p2p.Peer{ + IP: makeTestPeerIP(i), } - return nil } func makeTestPeerIP(i interface{}) string { return fmt.Sprintf("%v", i) } + +func assertTestError(got, expect error) error { + if (got == nil) && (expect == nil) { + return nil + } + if (got == nil) != (expect == nil) { + return fmt.Errorf("unexpected error: [%v] / [%v]", got, expect) + } + if !strings.Contains(got.Error(), expect.Error()) { + return fmt.Errorf("unexpected error: [%v] / [%v]", got, expect) + } + return nil +}