Fix race errors. (#4184)

* Fix data races.

* Fix block num.

* Pub key lock.

* Fixed race errors.

* Fix flag.

* Fix comments.

* Fix comments.

* Fix type.

* Fix race errors in tests.

Co-authored-by: Konstantin <k.potapov@softpro.com>
pull/4219/head
Konstantin 2 years ago committed by GitHub
parent 4eabc120b1
commit 06de7dcd6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      Makefile
  2. 14
      consensus/checks.go
  3. 17
      consensus/consensus.go
  4. 30
      consensus/consensus_fbft.go
  5. 15
      consensus/consensus_fbft_test.go
  6. 15
      consensus/consensus_service.go
  7. 2
      consensus/consensus_test.go
  8. 28
      consensus/consensus_v2.go
  9. 2
      consensus/construct.go
  10. 9
      consensus/construct_test.go
  11. 21
      consensus/debug.go
  12. 2
      consensus/double_sign.go
  13. 4
      consensus/leader.go
  14. 6
      consensus/threshold.go
  15. 30
      consensus/validator.go
  16. 14
      consensus/view_change.go
  17. 6
      consensus/view_change_msg.go
  18. 6
      consensus/view_change_test.go
  19. 70
      core/tx_pool_test.go
  20. 5
      core/types/block.go
  21. 30
      internal/utils/keylocker/keylocker.go
  22. 17
      internal/utils/keylocker/keylocker_test.go
  23. 35
      internal/utils/singleton_test.go
  24. 14
      internal/utils/timer.go
  25. 2
      node/api.go
  26. 2
      node/node_newblock.go
  27. 16
      p2p/host.go
  28. 13
      p2p/stream/protocols/sync/chain.go
  29. 19
      p2p/stream/protocols/sync/protocol_test.go
  30. 41
      p2p/utils.go
  31. 37
      p2p/utils_test.go
  32. 2
      scripts/travis_go_checker.sh

@ -156,4 +156,4 @@ go-vet:
go vet ./... go vet ./...
go-test: go-test:
go test ./... go test -vet=all -race ./...

@ -57,9 +57,9 @@ func (consensus *Consensus) senderKeySanityChecks(msg *msg_pb.Message, senderKey
func (consensus *Consensus) isRightBlockNumAndViewID(recvMsg *FBFTMessage, func (consensus *Consensus) isRightBlockNumAndViewID(recvMsg *FBFTMessage,
) bool { ) bool {
if recvMsg.ViewID != consensus.GetCurBlockViewID() || recvMsg.BlockNum != consensus.blockNum { if recvMsg.ViewID != consensus.GetCurBlockViewID() || recvMsg.BlockNum != consensus.BlockNum() {
consensus.getLogger().Debug(). consensus.getLogger().Debug().
Uint64("blockNum", consensus.blockNum). Uint64("blockNum", consensus.BlockNum()).
Str("recvMsg", recvMsg.String()). Str("recvMsg", recvMsg.String()).
Msg("BlockNum/viewID not match") Msg("BlockNum/viewID not match")
return false return false
@ -103,12 +103,12 @@ func (consensus *Consensus) onAnnounceSanityChecks(recvMsg *FBFTMessage) bool {
} }
func (consensus *Consensus) isRightBlockNumCheck(recvMsg *FBFTMessage) bool { func (consensus *Consensus) isRightBlockNumCheck(recvMsg *FBFTMessage) bool {
if recvMsg.BlockNum < consensus.blockNum { if recvMsg.BlockNum < consensus.BlockNum() {
consensus.getLogger().Debug(). consensus.getLogger().Debug().
Uint64("MsgBlockNum", recvMsg.BlockNum). Uint64("MsgBlockNum", recvMsg.BlockNum).
Msg("Wrong BlockNum Received, ignoring!") Msg("Wrong BlockNum Received, ignoring!")
return false return false
} else if recvMsg.BlockNum-consensus.blockNum > MaxBlockNumDiff { } else if recvMsg.BlockNum-consensus.BlockNum() > MaxBlockNumDiff {
consensus.getLogger().Debug(). consensus.getLogger().Debug().
Uint64("MsgBlockNum", recvMsg.BlockNum). Uint64("MsgBlockNum", recvMsg.BlockNum).
Uint64("MaxBlockNumDiff", MaxBlockNumDiff). Uint64("MaxBlockNumDiff", MaxBlockNumDiff).
@ -122,7 +122,7 @@ func (consensus *Consensus) newBlockSanityChecks(
blockObj *types.Block, recvMsg *FBFTMessage, blockObj *types.Block, recvMsg *FBFTMessage,
) bool { ) bool {
if blockObj.NumberU64() != recvMsg.BlockNum || if blockObj.NumberU64() != recvMsg.BlockNum ||
recvMsg.BlockNum < consensus.blockNum { recvMsg.BlockNum < consensus.BlockNum() {
consensus.getLogger().Warn(). consensus.getLogger().Warn().
Uint64("MsgBlockNum", recvMsg.BlockNum). Uint64("MsgBlockNum", recvMsg.BlockNum).
Uint64("blockNum", blockObj.NumberU64()). Uint64("blockNum", blockObj.NumberU64()).
@ -152,12 +152,12 @@ func (consensus *Consensus) onViewChangeSanityCheck(recvMsg *FBFTMessage) bool {
Interface("SendPubKeys", recvMsg.SenderPubkeys). Interface("SendPubKeys", recvMsg.SenderPubkeys).
Msg("[onViewChangeSanityCheck]") Msg("[onViewChangeSanityCheck]")
if consensus.blockNum > recvMsg.BlockNum { if consensus.BlockNum() > recvMsg.BlockNum {
consensus.getLogger().Debug(). consensus.getLogger().Debug().
Msg("[onViewChange] Message BlockNum Is Low") Msg("[onViewChange] Message BlockNum Is Low")
return false return false
} }
if consensus.blockNum < recvMsg.BlockNum { if consensus.BlockNum() < recvMsg.BlockNum {
consensus.getLogger().Warn(). consensus.getLogger().Warn().
Msg("[onViewChangeSanityCheck] MsgBlockNum is different from my BlockNumber") Msg("[onViewChangeSanityCheck] MsgBlockNum is different from my BlockNumber")
return false return false

@ -3,6 +3,7 @@ package consensus
import ( import (
"fmt" "fmt"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/harmony-one/harmony/crypto/bls" "github.com/harmony-one/harmony/crypto/bls"
@ -45,7 +46,7 @@ type Consensus struct {
// FBFTLog stores the pbft messages and blocks during FBFT process // FBFTLog stores the pbft messages and blocks during FBFT process
FBFTLog *FBFTLog FBFTLog *FBFTLog
// phase: different phase of FBFT protocol: pre-prepare, prepare, commit, finish etc // phase: different phase of FBFT protocol: pre-prepare, prepare, commit, finish etc
phase FBFTPhase phase *LockedFBFTPhase
// current indicates what state a node is in // current indicates what state a node is in
current State current State
// isBackup declarative the node is in backup mode // isBackup declarative the node is in backup mode
@ -131,7 +132,7 @@ type Consensus struct {
// finality of previous consensus in the unit of milliseconds // finality of previous consensus in the unit of milliseconds
finality int64 finality int64
// finalityCounter keep tracks of the finality time // finalityCounter keep tracks of the finality time
finalityCounter int64 finalityCounter atomic.Value //int64
dHelper *downloadHelper dHelper *downloadHelper
} }
@ -168,6 +169,12 @@ func (consensus *Consensus) GetPublicKeys() multibls.PublicKeys {
return consensus.priKey.GetPublicKeys() return consensus.priKey.GetPublicKeys()
} }
func (consensus *Consensus) GetLeaderPubKey() *bls_cosi.PublicKeyWrapper {
consensus.pubKeyLock.Lock()
defer consensus.pubKeyLock.Unlock()
return consensus.LeaderPubKey
}
func (consensus *Consensus) GetPrivateKeys() multibls.PrivateKeys { func (consensus *Consensus) GetPrivateKeys() multibls.PrivateKeys {
return consensus.priKey return consensus.priKey
} }
@ -197,6 +204,10 @@ func (consensus *Consensus) IsBackup() bool {
return consensus.isBackup return consensus.isBackup
} }
func (consensus *Consensus) BlockNum() uint64 {
return atomic.LoadUint64(&consensus.blockNum)
}
// New create a new Consensus record // New create a new Consensus record
func New( func New(
host p2p.Host, shard uint32, leader p2p.Peer, multiBLSPriKey multibls.PrivateKeys, host p2p.Host, shard uint32, leader p2p.Peer, multiBLSPriKey multibls.PrivateKeys,
@ -209,7 +220,7 @@ func New(
consensus.BlockNumLowChan = make(chan struct{}, 1) consensus.BlockNumLowChan = make(chan struct{}, 1)
// FBFT related // FBFT related
consensus.FBFTLog = NewFBFTLog() consensus.FBFTLog = NewFBFTLog()
consensus.phase = FBFTAnnounce consensus.phase = NewLockedFBFTPhase(FBFTAnnounce)
consensus.current = State{mode: Normal} consensus.current = State{mode: Normal}
// FBFT timeout // FBFT timeout
consensus.consensusTimeout = createTimeout() consensus.consensusTimeout = createTimeout()

@ -0,0 +1,30 @@
package consensus
import "sync"
type LockedFBFTPhase struct {
mu sync.Mutex
phase FBFTPhase
}
func NewLockedFBFTPhase(initialPhrase FBFTPhase) *LockedFBFTPhase {
return &LockedFBFTPhase{
phase: initialPhrase,
}
}
func (a *LockedFBFTPhase) Set(phrase FBFTPhase) {
a.mu.Lock()
a.phase = phrase
a.mu.Unlock()
}
func (a *LockedFBFTPhase) Get() FBFTPhase {
a.mu.Lock()
defer a.mu.Unlock()
return a.phase
}
func (a *LockedFBFTPhase) String() string {
return a.Get().String()
}

@ -0,0 +1,15 @@
package consensus
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestLockedFBFTPhase(t *testing.T) {
s := NewLockedFBFTPhase(FBFTAnnounce)
require.Equal(t, FBFTAnnounce, s.Get())
s.Set(FBFTCommit)
require.Equal(t, FBFTCommit, s.Get())
}

@ -403,7 +403,9 @@ func (consensus *Consensus) UpdateConsensusInformation() Mode {
consensus.getLogger().Info(). consensus.getLogger().Info().
Str("leaderPubKey", leaderPubKey.Bytes.Hex()). Str("leaderPubKey", leaderPubKey.Bytes.Hex()).
Msg("[UpdateConsensusInformation] Most Recent LeaderPubKey Updated Based on BlockChain") Msg("[UpdateConsensusInformation] Most Recent LeaderPubKey Updated Based on BlockChain")
consensus.pubKeyLock.Lock()
consensus.LeaderPubKey = leaderPubKey consensus.LeaderPubKey = leaderPubKey
consensus.pubKeyLock.Unlock()
} }
} }
@ -442,8 +444,11 @@ func (consensus *Consensus) UpdateConsensusInformation() Mode {
// IsLeader check if the node is a leader or not by comparing the public key of // IsLeader check if the node is a leader or not by comparing the public key of
// the node with the leader public key // the node with the leader public key
func (consensus *Consensus) IsLeader() bool { func (consensus *Consensus) IsLeader() bool {
consensus.pubKeyLock.Lock()
obj := consensus.LeaderPubKey.Object
consensus.pubKeyLock.Unlock()
for _, key := range consensus.priKey { for _, key := range consensus.priKey {
if key.Pub.Object.IsEqual(consensus.LeaderPubKey.Object) { if key.Pub.Object.IsEqual(obj) {
return true return true
} }
} }
@ -469,13 +474,13 @@ func (consensus *Consensus) SetViewChangingID(viewID uint64) {
// StartFinalityCount set the finality counter to current time // StartFinalityCount set the finality counter to current time
func (consensus *Consensus) StartFinalityCount() { func (consensus *Consensus) StartFinalityCount() {
consensus.finalityCounter = time.Now().UnixNano() consensus.finalityCounter.Store(time.Now().UnixNano())
} }
// FinishFinalityCount calculate the current finality // FinishFinalityCount calculate the current finality
func (consensus *Consensus) FinishFinalityCount() { func (consensus *Consensus) FinishFinalityCount() {
d := time.Now().UnixNano() d := time.Now().UnixNano()
consensus.finality = (d - consensus.finalityCounter) / 1000000 consensus.finality = (d - consensus.finalityCounter.Load().(int64)) / 1000000
consensusFinalityHistogram.Observe(float64(consensus.finality)) consensusFinalityHistogram.Observe(float64(consensus.finality))
} }
@ -491,7 +496,7 @@ func (consensus *Consensus) switchPhase(subject string, desired FBFTPhase) {
Str("to:", desired.String()). Str("to:", desired.String()).
Str("switchPhase:", subject) Str("switchPhase:", subject)
consensus.phase = desired consensus.phase.Set(desired)
return return
} }
@ -580,7 +585,7 @@ func (consensus *Consensus) NumSignaturesIncludedInBlock(block *types.Block) uin
// getLogger returns logger for consensus contexts added // getLogger returns logger for consensus contexts added
func (consensus *Consensus) getLogger() *zerolog.Logger { func (consensus *Consensus) getLogger() *zerolog.Logger {
logger := utils.Logger().With(). logger := utils.Logger().With().
Uint64("myBlock", consensus.blockNum). Uint64("myBlock", consensus.BlockNum()).
Uint64("myViewID", consensus.GetCurBlockViewID()). Uint64("myViewID", consensus.GetCurBlockViewID()).
Str("phase", consensus.phase.String()). Str("phase", consensus.phase.String()).
Str("mode", consensus.current.Mode().String()). Str("mode", consensus.current.Mode().String()).

@ -38,7 +38,7 @@ func TestConsensusInitialization(t *testing.T) {
// FBFTLog // FBFTLog
assert.Equal(t, fbtLog, consensus.FBFTLog) assert.Equal(t, fbtLog, consensus.FBFTLog)
assert.Equal(t, FBFTAnnounce, consensus.phase) assert.Equal(t, FBFTAnnounce, consensus.phase.Get())
// State / consensus.current // State / consensus.current
assert.Equal(t, state.mode, consensus.current.mode) assert.Equal(t, state.mode, consensus.current.mode)

@ -135,7 +135,7 @@ func (consensus *Consensus) finalCommit() {
consensus.getLogger().Info(). consensus.getLogger().Info().
Int64("NumCommits", numCommits). Int64("NumCommits", numCommits).
Msg("[finalCommit] Finalizing Consensus") Msg("[finalCommit] Finalizing Consensus")
beforeCatchupNum := consensus.blockNum beforeCatchupNum := consensus.BlockNum()
leaderPriKey, err := consensus.GetConsensusLeaderPrivateKey() leaderPriKey, err := consensus.GetConsensusLeaderPrivateKey()
if err != nil { if err != nil {
@ -188,7 +188,7 @@ func (consensus *Consensus) finalCommit() {
} else { } else {
consensus.getLogger().Info(). consensus.getLogger().Info().
Hex("blockHash", curBlockHash[:]). Hex("blockHash", curBlockHash[:]).
Uint64("blockNum", consensus.blockNum). Uint64("blockNum", consensus.BlockNum()).
Msg("[finalCommit] Sent Committed Message") Msg("[finalCommit] Sent Committed Message")
} }
consensus.getLogger().Info().Msg("[finalCommit] Start consensus timer") consensus.getLogger().Info().Msg("[finalCommit] Start consensus timer")
@ -203,7 +203,7 @@ func (consensus *Consensus) finalCommit() {
p2p.ConstructMessage(msgToSend)) p2p.ConstructMessage(msgToSend))
consensus.getLogger().Info(). consensus.getLogger().Info().
Hex("blockHash", curBlockHash[:]). Hex("blockHash", curBlockHash[:]).
Uint64("blockNum", consensus.blockNum). Uint64("blockNum", consensus.BlockNum()).
Hex("lastCommitSig", commitSigAndBitmap). Hex("lastCommitSig", commitSigAndBitmap).
Msg("[finalCommit] Queued Committed Message") Msg("[finalCommit] Queued Committed Message")
} }
@ -211,7 +211,7 @@ func (consensus *Consensus) finalCommit() {
block.SetCurrentCommitSig(commitSigAndBitmap) block.SetCurrentCommitSig(commitSigAndBitmap)
err = consensus.commitBlock(block, FBFTMsg) err = consensus.commitBlock(block, FBFTMsg)
if err != nil || consensus.blockNum-beforeCatchupNum != 1 { if err != nil || consensus.BlockNum()-beforeCatchupNum != 1 {
consensus.getLogger().Err(err). consensus.getLogger().Err(err).
Uint64("beforeCatchupBlockNum", beforeCatchupNum). Uint64("beforeCatchupBlockNum", beforeCatchupNum).
Msg("[finalCommit] Leader failed to commit the confirmed block") Msg("[finalCommit] Leader failed to commit the confirmed block")
@ -264,7 +264,7 @@ func (consensus *Consensus) finalCommit() {
// BlockCommitSigs returns the byte array of aggregated // BlockCommitSigs returns the byte array of aggregated
// commit signature and bitmap signed on the block // commit signature and bitmap signed on the block
func (consensus *Consensus) BlockCommitSigs(blockNum uint64) ([]byte, error) { func (consensus *Consensus) BlockCommitSigs(blockNum uint64) ([]byte, error) {
if consensus.blockNum <= 1 { if consensus.BlockNum() <= 1 {
return nil, nil return nil, nil
} }
lastCommits, err := consensus.Blockchain.ReadCommitSig(blockNum) lastCommits, err := consensus.Blockchain.ReadCommitSig(blockNum)
@ -363,7 +363,7 @@ func (consensus *Consensus) Start(
case <-consensus.syncReadyChan: case <-consensus.syncReadyChan:
consensus.getLogger().Info().Msg("[ConsensusMainLoop] syncReadyChan") consensus.getLogger().Info().Msg("[ConsensusMainLoop] syncReadyChan")
consensus.mutex.Lock() consensus.mutex.Lock()
if consensus.blockNum < consensus.Blockchain.CurrentHeader().Number().Uint64()+1 { if consensus.BlockNum() < consensus.Blockchain.CurrentHeader().Number().Uint64()+1 {
consensus.SetBlockNum(consensus.Blockchain.CurrentHeader().Number().Uint64() + 1) consensus.SetBlockNum(consensus.Blockchain.CurrentHeader().Number().Uint64() + 1)
consensus.SetViewIDs(consensus.Blockchain.CurrentHeader().ViewID().Uint64() + 1) consensus.SetViewIDs(consensus.Blockchain.CurrentHeader().ViewID().Uint64() + 1)
mode := consensus.UpdateConsensusInformation() mode := consensus.UpdateConsensusInformation()
@ -396,7 +396,7 @@ func (consensus *Consensus) Start(
Uint64("MsgBlockNum", newBlock.NumberU64()). Uint64("MsgBlockNum", newBlock.NumberU64()).
Msg("[ConsensusMainLoop] Received Proposed New Block!") Msg("[ConsensusMainLoop] Received Proposed New Block!")
if newBlock.NumberU64() < consensus.blockNum { if newBlock.NumberU64() < consensus.BlockNum() {
consensus.getLogger().Warn().Uint64("newBlockNum", newBlock.NumberU64()). consensus.getLogger().Warn().Uint64("newBlockNum", newBlock.NumberU64()).
Msg("[ConsensusMainLoop] received old block, abort") Msg("[ConsensusMainLoop] received old block, abort")
continue continue
@ -443,7 +443,7 @@ func (consensus *Consensus) Close() error {
// waitForCommit wait extra 2 seconds for commit phase to finish // waitForCommit wait extra 2 seconds for commit phase to finish
func (consensus *Consensus) waitForCommit() { func (consensus *Consensus) waitForCommit() {
if consensus.Mode() != Normal || consensus.phase != FBFTCommit { if consensus.Mode() != Normal || consensus.phase.Get() != FBFTCommit {
return return
} }
// We only need to wait consensus is in normal commit phase // We only need to wait consensus is in normal commit phase
@ -569,7 +569,7 @@ func (consensus *Consensus) preCommitAndPropose(blk *types.Block) error {
} else { } else {
consensus.getLogger().Info(). consensus.getLogger().Info().
Str("blockHash", blk.Hash().Hex()). Str("blockHash", blk.Hash().Hex()).
Uint64("blockNum", consensus.blockNum). Uint64("blockNum", consensus.BlockNum()).
Hex("lastCommitSig", bareMinimumCommit). Hex("lastCommitSig", bareMinimumCommit).
Msg("[preCommitAndPropose] Sent Committed Message") Msg("[preCommitAndPropose] Sent Committed Message")
} }
@ -621,7 +621,7 @@ func (consensus *Consensus) tryCatchup() error {
if consensus.BlockVerifier == nil { if consensus.BlockVerifier == nil {
return errors.New("consensus haven't finished initialization") return errors.New("consensus haven't finished initialization")
} }
initBN := consensus.blockNum initBN := consensus.BlockNum()
defer consensus.postCatchup(initBN) defer consensus.postCatchup(initBN)
blks, msgs, err := consensus.getLastMileBlocksAndMsg(initBN) blks, msgs, err := consensus.getLastMileBlocksAndMsg(initBN)
@ -683,7 +683,9 @@ func (consensus *Consensus) commitBlock(blk *types.Block, committedMsg *FBFTMess
func (consensus *Consensus) SetupForNewConsensus(blk *types.Block, committedMsg *FBFTMessage) { func (consensus *Consensus) SetupForNewConsensus(blk *types.Block, committedMsg *FBFTMessage) {
atomic.StoreUint64(&consensus.blockNum, blk.NumberU64()+1) atomic.StoreUint64(&consensus.blockNum, blk.NumberU64()+1)
consensus.SetCurBlockViewID(committedMsg.ViewID + 1) consensus.SetCurBlockViewID(committedMsg.ViewID + 1)
consensus.pubKeyLock.Lock()
consensus.LeaderPubKey = committedMsg.SenderPubkeys[0] consensus.LeaderPubKey = committedMsg.SenderPubkeys[0]
consensus.pubKeyLock.Unlock()
// Update consensus keys at last so the change of leader status doesn't mess up normal flow // Update consensus keys at last so the change of leader status doesn't mess up normal flow
if blk.IsLastBlockInEpoch() { if blk.IsLastBlockInEpoch() {
consensus.SetMode(consensus.UpdateConsensusInformation()) consensus.SetMode(consensus.UpdateConsensusInformation())
@ -693,15 +695,15 @@ func (consensus *Consensus) SetupForNewConsensus(blk *types.Block, committedMsg
} }
func (consensus *Consensus) postCatchup(initBN uint64) { func (consensus *Consensus) postCatchup(initBN uint64) {
if initBN < consensus.blockNum { if initBN < consensus.BlockNum() {
consensus.getLogger().Info(). consensus.getLogger().Info().
Uint64("From", initBN). Uint64("From", initBN).
Uint64("To", consensus.blockNum). Uint64("To", consensus.BlockNum()).
Msg("[TryCatchup] Caught up!") Msg("[TryCatchup] Caught up!")
consensus.switchPhase("TryCatchup", FBFTAnnounce) consensus.switchPhase("TryCatchup", FBFTAnnounce)
} }
// catch up and skip from view change trap // catch up and skip from view change trap
if initBN < consensus.blockNum && consensus.IsViewChangingMode() { if initBN < consensus.BlockNum() && consensus.IsViewChangingMode() {
consensus.current.SetMode(Normal) consensus.current.SetMode(Normal)
consensus.consensusTimeout[timeoutViewChange].Stop() consensus.consensusTimeout[timeoutViewChange].Stop()
} }

@ -30,7 +30,7 @@ func (consensus *Consensus) populateMessageFields(
request *msg_pb.ConsensusRequest, blockHash []byte, request *msg_pb.ConsensusRequest, blockHash []byte,
) *msg_pb.ConsensusRequest { ) *msg_pb.ConsensusRequest {
request.ViewId = consensus.GetCurBlockViewID() request.ViewId = consensus.GetCurBlockViewID()
request.BlockNum = consensus.blockNum request.BlockNum = consensus.BlockNum()
request.ShardId = consensus.ShardID request.ShardId = consensus.ShardID
// 32 byte block hash // 32 byte block hash
request.BlockHash = blockHash request.BlockHash = blockHash

@ -2,6 +2,7 @@ package consensus
import ( import (
"bytes" "bytes"
"sync/atomic"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -86,7 +87,7 @@ func TestConstructPreparedMessage(test *testing.T) {
[]*bls.PublicKeyWrapper{&leaderKeyWrapper}, []*bls.PublicKeyWrapper{&leaderKeyWrapper},
leaderPriKey.Sign(message), leaderPriKey.Sign(message),
common.BytesToHash(consensus.blockHash[:]), common.BytesToHash(consensus.blockHash[:]),
consensus.blockNum, consensus.BlockNum(),
consensus.GetCurBlockViewID(), consensus.GetCurBlockViewID(),
) )
if _, err := consensus.Decider.AddNewVote( if _, err := consensus.Decider.AddNewVote(
@ -94,7 +95,7 @@ func TestConstructPreparedMessage(test *testing.T) {
[]*bls.PublicKeyWrapper{&validatorKeyWrapper}, []*bls.PublicKeyWrapper{&validatorKeyWrapper},
validatorPriKey.Sign(message), validatorPriKey.Sign(message),
common.BytesToHash(consensus.blockHash[:]), common.BytesToHash(consensus.blockHash[:]),
consensus.blockNum, consensus.BlockNum(),
consensus.GetCurBlockViewID(), consensus.GetCurBlockViewID(),
); err != nil { ); err != nil {
test.Log(err) test.Log(err)
@ -156,7 +157,7 @@ func TestConstructPrepareMessage(test *testing.T) {
consensus.SetCurBlockViewID(2) consensus.SetCurBlockViewID(2)
consensus.blockHash = [32]byte{} consensus.blockHash = [32]byte{}
copy(consensus.blockHash[:], []byte("random")) copy(consensus.blockHash[:], []byte("random"))
consensus.blockNum = 1000 atomic.StoreUint64(&consensus.blockNum, 1000)
sig := priKeyWrapper1.Pri.SignHash(consensus.blockHash[:]) sig := priKeyWrapper1.Pri.SignHash(consensus.blockHash[:])
network, err := consensus.construct(msg_pb.MessageType_PREPARE, nil, []*bls.PrivateKeyWrapper{&priKeyWrapper1}) network, err := consensus.construct(msg_pb.MessageType_PREPARE, nil, []*bls.PrivateKeyWrapper{&priKeyWrapper1})
@ -248,7 +249,7 @@ func TestConstructCommitMessage(test *testing.T) {
consensus.SetCurBlockViewID(2) consensus.SetCurBlockViewID(2)
consensus.blockHash = [32]byte{} consensus.blockHash = [32]byte{}
copy(consensus.blockHash[:], []byte("random")) copy(consensus.blockHash[:], []byte("random"))
consensus.blockNum = 1000 atomic.StoreUint64(&consensus.blockNum, 1000)
sigPayload := []byte("payload") sigPayload := []byte("payload")

@ -1,26 +1,21 @@
package consensus package consensus
// GetConsensusPhase returns the current phase of the consensus // GetConsensusPhase returns the current phase of the consensus
func (c *Consensus) GetConsensusPhase() string { func (consensus *Consensus) GetConsensusPhase() string {
return c.phase.String() return consensus.phase.String()
} }
// GetConsensusMode returns the current mode of the consensus // GetConsensusMode returns the current mode of the consensus
func (c *Consensus) GetConsensusMode() string { func (consensus *Consensus) GetConsensusMode() string {
return c.current.mode.String() return consensus.current.mode.String()
} }
// GetCurBlockViewID returns the current view ID of the consensus // GetCurBlockViewID returns the current view ID of the consensus
func (c *Consensus) GetCurBlockViewID() uint64 { func (consensus *Consensus) GetCurBlockViewID() uint64 {
return c.current.GetCurBlockViewID() return consensus.current.GetCurBlockViewID()
} }
// GetViewChangingID returns the current view changing ID of the consensus // GetViewChangingID returns the current view changing ID of the consensus
func (c *Consensus) GetViewChangingID() uint64 { func (consensus *Consensus) GetViewChangingID() uint64 {
return c.current.GetViewChangingID() return consensus.current.GetViewChangingID()
}
// GetBlockNum return the current blockNum of the consensus struct
func (c *Consensus) GetBlockNum() uint64 {
return c.blockNum
} }

@ -137,7 +137,7 @@ func (consensus *Consensus) checkDoubleSign(recvMsg *FBFTMessage) bool {
func (consensus *Consensus) couldThisBeADoubleSigner( func (consensus *Consensus) couldThisBeADoubleSigner(
recvMsg *FBFTMessage, recvMsg *FBFTMessage,
) bool { ) bool {
num, hash := consensus.blockNum, recvMsg.BlockHash num, hash := consensus.BlockNum(), recvMsg.BlockHash
suspicious := !consensus.FBFTLog.HasMatchingAnnounce(num, hash) || suspicious := !consensus.FBFTLog.HasMatchingAnnounce(num, hash) ||
!consensus.FBFTLog.HasMatchingPrepared(num, hash) !consensus.FBFTLog.HasMatchingPrepared(num, hash)
if suspicious { if suspicious {

@ -78,7 +78,7 @@ func (consensus *Consensus) announce(block *types.Block) {
} }
// Construct broadcast p2p message // Construct broadcast p2p message
if err := consensus.msgSender.SendWithRetry( if err := consensus.msgSender.SendWithRetry(
consensus.blockNum, msg_pb.MessageType_ANNOUNCE, []nodeconfig.GroupID{ consensus.BlockNum(), msg_pb.MessageType_ANNOUNCE, []nodeconfig.GroupID{
nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID)), nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID)),
}, p2p.ConstructMessage(msgToSend)); err != nil { }, p2p.ConstructMessage(msgToSend)); err != nil {
consensus.getLogger().Warn(). consensus.getLogger().Warn().
@ -99,7 +99,7 @@ func (consensus *Consensus) announce(block *types.Block) {
func (consensus *Consensus) onPrepare(recvMsg *FBFTMessage) { func (consensus *Consensus) onPrepare(recvMsg *FBFTMessage) {
// TODO(audit): make FBFT lookup using map instead of looping through all items. // TODO(audit): make FBFT lookup using map instead of looping through all items.
if !consensus.FBFTLog.HasMatchingViewAnnounce( if !consensus.FBFTLog.HasMatchingViewAnnounce(
consensus.blockNum, consensus.GetCurBlockViewID(), recvMsg.BlockHash, consensus.BlockNum(), consensus.GetCurBlockViewID(), recvMsg.BlockHash,
) { ) {
consensus.getLogger().Debug(). consensus.getLogger().Debug().
Uint64("MsgViewID", recvMsg.ViewID). Uint64("MsgViewID", recvMsg.ViewID).

@ -42,7 +42,7 @@ func (consensus *Consensus) didReachPrepareQuorum() error {
if err := rlp.DecodeBytes(consensus.block, &blockObj); err != nil { if err := rlp.DecodeBytes(consensus.block, &blockObj); err != nil {
consensus.getLogger().Warn(). consensus.getLogger().Warn().
Err(err). Err(err).
Uint64("BlockNum", consensus.blockNum). Uint64("BlockNum", consensus.BlockNum()).
Msg("[didReachPrepareQuorum] Unparseable block data") Msg("[didReachPrepareQuorum] Unparseable block data")
return err return err
} }
@ -69,7 +69,7 @@ func (consensus *Consensus) didReachPrepareQuorum() error {
} }
} }
if err := consensus.msgSender.SendWithRetry( if err := consensus.msgSender.SendWithRetry(
consensus.blockNum, consensus.BlockNum(),
msg_pb.MessageType_PREPARED, []nodeconfig.GroupID{ msg_pb.MessageType_PREPARED, []nodeconfig.GroupID{
nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID)), nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID)),
}, },
@ -79,7 +79,7 @@ func (consensus *Consensus) didReachPrepareQuorum() error {
} else { } else {
consensus.getLogger().Info(). consensus.getLogger().Info().
Hex("blockHash", consensus.blockHash[:]). Hex("blockHash", consensus.blockHash[:]).
Uint64("blockNum", consensus.blockNum). Uint64("blockNum", consensus.BlockNum()).
Msg("[OnPrepare] Sent Prepared Message!!") Msg("[OnPrepare] Sent Prepared Message!!")
} }
consensus.msgSender.StopRetry(msg_pb.MessageType_ANNOUNCE) consensus.msgSender.StopRetry(msg_pb.MessageType_ANNOUNCE)

@ -175,7 +175,7 @@ func (consensus *Consensus) sendCommitMessages(blockObj *types.Block) {
consensus.getLogger().Warn().Err(err).Msg("[sendCommitMessages] Cannot send commit message!!") consensus.getLogger().Warn().Err(err).Msg("[sendCommitMessages] Cannot send commit message!!")
} else { } else {
consensus.getLogger().Info(). consensus.getLogger().Info().
Uint64("blockNum", consensus.blockNum). Uint64("blockNum", consensus.BlockNum()).
Hex("blockHash", consensus.blockHash[:]). Hex("blockHash", consensus.blockHash[:]).
Msg("[sendCommitMessages] Sent Commit Message!!") Msg("[sendCommitMessages] Sent Commit Message!!")
} }
@ -192,14 +192,14 @@ func (consensus *Consensus) onPrepared(recvMsg *FBFTMessage) {
Uint64("MsgViewID", recvMsg.ViewID). Uint64("MsgViewID", recvMsg.ViewID).
Msg("[OnPrepared] Received prepared message") Msg("[OnPrepared] Received prepared message")
if recvMsg.BlockNum < consensus.blockNum { if recvMsg.BlockNum < consensus.BlockNum() {
consensus.getLogger().Info().Uint64("MsgBlockNum", recvMsg.BlockNum). consensus.getLogger().Info().Uint64("MsgBlockNum", recvMsg.BlockNum).
Msg("Wrong BlockNum Received, ignoring!") Msg("Wrong BlockNum Received, ignoring!")
return return
} }
if recvMsg.BlockNum > consensus.blockNum { if recvMsg.BlockNum > consensus.BlockNum() {
consensus.getLogger().Warn(). consensus.getLogger().Warn().
Uint64("myBlockNum", consensus.blockNum). Uint64("myBlockNum", consensus.BlockNum()).
Uint64("MsgBlockNum", recvMsg.BlockNum). Uint64("MsgBlockNum", recvMsg.BlockNum).
Hex("myBlockHash", consensus.blockHash[:]). Hex("myBlockHash", consensus.blockHash[:]).
Hex("MsgBlockHash", recvMsg.BlockHash[:]). Hex("MsgBlockHash", recvMsg.BlockHash[:]).
@ -245,10 +245,10 @@ func (consensus *Consensus) onPrepared(recvMsg *FBFTMessage) {
} }
return return
} }
if recvMsg.BlockNum > consensus.blockNum { if recvMsg.BlockNum > consensus.BlockNum() {
consensus.getLogger().Info(). consensus.getLogger().Info().
Uint64("MsgBlockNum", recvMsg.BlockNum). Uint64("MsgBlockNum", recvMsg.BlockNum).
Uint64("blockNum", consensus.blockNum). Uint64("blockNum", consensus.BlockNum()).
Msg("[OnPrepared] Future Block Received, ignoring!!") Msg("[OnPrepared] Future Block Received, ignoring!!")
return return
} }
@ -274,12 +274,12 @@ func (consensus *Consensus) onPrepared(recvMsg *FBFTMessage) {
if blockObj == nil { if blockObj == nil {
return return
} }
curBlockNum := consensus.blockNum curBlockNum := consensus.BlockNum()
for _, committedMsg := range consensus.FBFTLog.GetNotVerifiedCommittedMessages(blockObj.NumberU64(), blockObj.Header().ViewID().Uint64(), blockObj.Hash()) { for _, committedMsg := range consensus.FBFTLog.GetNotVerifiedCommittedMessages(blockObj.NumberU64(), blockObj.Header().ViewID().Uint64(), blockObj.Hash()) {
if committedMsg != nil { if committedMsg != nil {
consensus.onCommitted(committedMsg) consensus.onCommitted(committedMsg)
} }
if curBlockNum < consensus.blockNum { if curBlockNum < consensus.BlockNum() {
consensus.getLogger().Info().Msg("[OnPrepared] Successfully caught up with committed message") consensus.getLogger().Info().Msg("[OnPrepared] Successfully caught up with committed message")
break break
} }
@ -297,16 +297,16 @@ func (consensus *Consensus) onCommitted(recvMsg *FBFTMessage) {
Msg("[OnCommitted] Received committed message") Msg("[OnCommitted] Received committed message")
// Ok to receive committed from last block since it could have more signatures // Ok to receive committed from last block since it could have more signatures
if recvMsg.BlockNum < consensus.blockNum-1 { if recvMsg.BlockNum < consensus.BlockNum()-1 {
consensus.getLogger().Info(). consensus.getLogger().Info().
Uint64("MsgBlockNum", recvMsg.BlockNum). Uint64("MsgBlockNum", recvMsg.BlockNum).
Msg("Wrong BlockNum Received, ignoring!") Msg("Wrong BlockNum Received, ignoring!")
return return
} }
if recvMsg.BlockNum > consensus.blockNum { if recvMsg.BlockNum > consensus.BlockNum() {
consensus.getLogger().Info(). consensus.getLogger().Info().
Uint64("myBlockNum", consensus.blockNum). Uint64("myBlockNum", consensus.BlockNum()).
Uint64("MsgBlockNum", recvMsg.BlockNum). Uint64("MsgBlockNum", recvMsg.BlockNum).
Hex("myBlockHash", consensus.blockHash[:]). Hex("myBlockHash", consensus.blockHash[:]).
Hex("MsgBlockHash", recvMsg.BlockHash[:]). Hex("MsgBlockHash", recvMsg.BlockHash[:]).
@ -372,12 +372,12 @@ func (consensus *Consensus) onCommitted(recvMsg *FBFTMessage) {
} }
} }
initBn := consensus.blockNum initBn := consensus.BlockNum()
consensus.tryCatchup() consensus.tryCatchup()
if recvMsg.BlockNum > consensus.blockNum { if recvMsg.BlockNum > consensus.BlockNum() {
consensus.getLogger().Info(). consensus.getLogger().Info().
Uint64("myBlockNum", consensus.blockNum). Uint64("myBlockNum", consensus.BlockNum()).
Uint64("MsgBlockNum", recvMsg.BlockNum). Uint64("MsgBlockNum", recvMsg.BlockNum).
Hex("myBlockHash", consensus.blockHash[:]). Hex("myBlockHash", consensus.blockHash[:]).
Hex("MsgBlockHash", recvMsg.BlockHash[:]). Hex("MsgBlockHash", recvMsg.BlockHash[:]).
@ -395,7 +395,7 @@ func (consensus *Consensus) onCommitted(recvMsg *FBFTMessage) {
consensus.getLogger().Debug().Msg("[OnCommitted] stop bootstrap timer only once") consensus.getLogger().Debug().Msg("[OnCommitted] stop bootstrap timer only once")
} }
if initBn < consensus.blockNum { if initBn < consensus.BlockNum() {
consensus.getLogger().Info().Msg("[OnCommitted] Start consensus timer (new block added)") consensus.getLogger().Info().Msg("[OnCommitted] Start consensus timer (new block added)")
consensus.consensusTimeout[timeoutConsensus].Start() consensus.consensusTimeout[timeoutConsensus].Start()
} }

@ -264,7 +264,9 @@ func (consensus *Consensus) startViewChange() {
// aganist the consensus.LeaderPubKey variable. // aganist the consensus.LeaderPubKey variable.
// Ideally, we shall use another variable to keep track of the // Ideally, we shall use another variable to keep track of the
// leader pubkey in viewchange mode // leader pubkey in viewchange mode
consensus.pubKeyLock.Lock()
consensus.LeaderPubKey = consensus.getNextLeaderKey(nextViewID) consensus.LeaderPubKey = consensus.getNextLeaderKey(nextViewID)
consensus.pubKeyLock.Unlock()
consensus.getLogger().Warn(). consensus.getLogger().Warn().
Uint64("nextViewID", nextViewID). Uint64("nextViewID", nextViewID).
@ -285,7 +287,7 @@ func (consensus *Consensus) startViewChange() {
if err := consensus.vc.InitPayload( if err := consensus.vc.InitPayload(
consensus.FBFTLog, consensus.FBFTLog,
nextViewID, nextViewID,
consensus.blockNum, consensus.BlockNum(),
consensus.priKey, consensus.priKey,
members); err != nil { members); err != nil {
consensus.getLogger().Error().Err(err).Msg("[startViewChange] Init Payload Error") consensus.getLogger().Error().Err(err).Msg("[startViewChange] Init Payload Error")
@ -299,7 +301,7 @@ func (consensus *Consensus) startViewChange() {
} }
msgToSend := consensus.constructViewChangeMessage(&key) msgToSend := consensus.constructViewChangeMessage(&key)
if err := consensus.msgSender.SendWithRetry( if err := consensus.msgSender.SendWithRetry(
consensus.blockNum, consensus.BlockNum(),
msg_pb.MessageType_VIEWCHANGE, msg_pb.MessageType_VIEWCHANGE,
[]nodeconfig.GroupID{ []nodeconfig.GroupID{
nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID))}, nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID))},
@ -325,7 +327,7 @@ func (consensus *Consensus) startNewView(viewID uint64, newLeaderPriKey *bls.Pri
} }
if err := consensus.msgSender.SendWithRetry( if err := consensus.msgSender.SendWithRetry(
consensus.blockNum, consensus.BlockNum(),
msg_pb.MessageType_NEWVIEW, msg_pb.MessageType_NEWVIEW,
[]nodeconfig.GroupID{ []nodeconfig.GroupID{
nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID))}, nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID))},
@ -471,10 +473,10 @@ func (consensus *Consensus) onNewView(recvMsg *FBFTMessage) {
Msg("[onNewView] Received NewView Message") Msg("[onNewView] Received NewView Message")
// change view and leaderKey to keep in sync with network // change view and leaderKey to keep in sync with network
if consensus.blockNum != recvMsg.BlockNum { if consensus.BlockNum() != recvMsg.BlockNum {
consensus.getLogger().Warn(). consensus.getLogger().Warn().
Uint64("MsgBlockNum", recvMsg.BlockNum). Uint64("MsgBlockNum", recvMsg.BlockNum).
Uint64("myBlockNum", consensus.blockNum). Uint64("myBlockNum", consensus.BlockNum()).
Msg("[onNewView] Invalid block number") Msg("[onNewView] Invalid block number")
return return
} }
@ -551,7 +553,9 @@ func (consensus *Consensus) onNewView(recvMsg *FBFTMessage) {
// newView message verified success, override my state // newView message verified success, override my state
consensus.SetViewIDs(recvMsg.ViewID) consensus.SetViewIDs(recvMsg.ViewID)
consensus.pubKeyLock.Lock()
consensus.LeaderPubKey = senderKey consensus.LeaderPubKey = senderKey
consensus.pubKeyLock.Unlock()
consensus.ResetViewChangeState() consensus.ResetViewChangeState()
consensus.msgSender.StopRetry(msg_pb.MessageType_VIEWCHANGE) consensus.msgSender.StopRetry(msg_pb.MessageType_VIEWCHANGE)

@ -24,7 +24,7 @@ func (consensus *Consensus) constructViewChangeMessage(priKey *bls.PrivateKeyWra
Request: &msg_pb.Message_Viewchange{ Request: &msg_pb.Message_Viewchange{
Viewchange: &msg_pb.ViewChangeRequest{ Viewchange: &msg_pb.ViewChangeRequest{
ViewId: consensus.GetViewChangingID(), ViewId: consensus.GetViewChangingID(),
BlockNum: consensus.blockNum, BlockNum: consensus.BlockNum(),
ShardId: consensus.ShardID, ShardId: consensus.ShardID,
SenderPubkey: priKey.Pub.Bytes[:], SenderPubkey: priKey.Pub.Bytes[:],
LeaderPubkey: consensus.LeaderPubKey.Bytes[:], LeaderPubkey: consensus.LeaderPubKey.Bytes[:],
@ -33,7 +33,7 @@ func (consensus *Consensus) constructViewChangeMessage(priKey *bls.PrivateKeyWra
} }
preparedMsgs := consensus.FBFTLog.GetMessagesByTypeSeq( preparedMsgs := consensus.FBFTLog.GetMessagesByTypeSeq(
msg_pb.MessageType_PREPARED, consensus.blockNum, msg_pb.MessageType_PREPARED, consensus.BlockNum(),
) )
preparedMsg := consensus.FBFTLog.FindMessageByMaxViewID(preparedMsgs) preparedMsg := consensus.FBFTLog.FindMessageByMaxViewID(preparedMsgs)
@ -107,7 +107,7 @@ func (consensus *Consensus) constructNewViewMessage(viewID uint64, priKey *bls.P
Request: &msg_pb.Message_Viewchange{ Request: &msg_pb.Message_Viewchange{
Viewchange: &msg_pb.ViewChangeRequest{ Viewchange: &msg_pb.ViewChangeRequest{
ViewId: viewID, ViewId: viewID,
BlockNum: consensus.blockNum, BlockNum: consensus.BlockNum(),
ShardId: consensus.ShardID, ShardId: consensus.ShardID,
SenderPubkey: priKey.Pub.Bytes[:], SenderPubkey: priKey.Pub.Bytes[:],
}, },

@ -43,7 +43,7 @@ func TestPhaseSwitching(t *testing.T) {
_, _, consensus, _, err := GenerateConsensusForTesting() _, _, consensus, _, err := GenerateConsensusForTesting()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, FBFTAnnounce, consensus.phase) // It's a new consensus, we should be at the FBFTAnnounce phase assert.Equal(t, FBFTAnnounce, consensus.phase.Get()) // It's a new consensus, we should be at the FBFTAnnounce phase.
switches := []phaseSwitch{ switches := []phaseSwitch{
{start: FBFTAnnounce, end: FBFTPrepare}, {start: FBFTAnnounce, end: FBFTPrepare},
@ -73,10 +73,10 @@ func TestPhaseSwitching(t *testing.T) {
func testPhaseGroupSwitching(t *testing.T, consensus *Consensus, phases []FBFTPhase, startPhase FBFTPhase, desiredPhase FBFTPhase) { func testPhaseGroupSwitching(t *testing.T, consensus *Consensus, phases []FBFTPhase, startPhase FBFTPhase, desiredPhase FBFTPhase) {
for range phases { for range phases {
consensus.switchPhase("test", desiredPhase) consensus.switchPhase("test", desiredPhase)
assert.Equal(t, desiredPhase, consensus.phase) assert.Equal(t, desiredPhase, consensus.phase.Get())
} }
assert.Equal(t, desiredPhase, consensus.phase) assert.Equal(t, desiredPhase, consensus.phase.Get())
return return
} }

@ -23,6 +23,7 @@ import (
"math/big" "math/big"
"math/rand" "math/rand"
"os" "os"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -69,9 +70,13 @@ type testBlockChain struct {
chainHeadFeed *event.Feed chainHeadFeed *event.Feed
} }
func (bc *testBlockChain) SetGasLimit(value uint64) {
atomic.StoreUint64(&bc.gasLimit, value)
}
func (bc *testBlockChain) CurrentBlock() *types.Block { func (bc *testBlockChain) CurrentBlock() *types.Block {
return types.NewBlock(blockfactory.NewTestHeader().With(). return types.NewBlock(blockfactory.NewTestHeader().With().
GasLimit(bc.gasLimit). GasLimit(atomic.LoadUint64(&bc.gasLimit)).
Header(), nil, nil, nil, nil, nil) Header(), nil, nil, nil, nil, nil)
} }
@ -162,12 +167,14 @@ func createBlockChain() *BlockChainImpl {
return blockchain return blockchain
} }
func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { func setupTxPool(chain blockChain) (*TxPool, *ecdsa.PrivateKey) {
if chain == nil {
statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase())) statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()))
blockchain := &testBlockChain{statedb, 1e18, new(event.Feed)} chain = &testBlockChain{statedb, 1e18, new(event.Feed)}
}
key, _ := crypto.GenerateKey() key, _ := crypto.GenerateKey()
pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, dummyErrorSink) pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, chain, dummyErrorSink)
return pool, key return pool, key
} }
@ -282,7 +289,7 @@ func TestStateChangeDuringTransactionPoolReset(t *testing.T) {
func TestInvalidTransactions(t *testing.T) { func TestInvalidTransactions(t *testing.T) {
t.Parallel() t.Parallel()
pool, key := setupTxPool() pool, key := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
tx := transaction(0, 0, 100, key) tx := transaction(0, 0, 100, key)
@ -324,8 +331,7 @@ func TestInvalidTransactions(t *testing.T) {
func TestErrorSink(t *testing.T) { func TestErrorSink(t *testing.T) {
t.Parallel() t.Parallel()
pool, key := setupTxPool() pool, key := setupTxPool(createBlockChain())
pool.chain = createBlockChain()
defer pool.Stop() defer pool.Stop()
testTxErrorSink := types.NewTransactionErrorSink() testTxErrorSink := types.NewTransactionErrorSink()
@ -396,8 +402,7 @@ func TestErrorSink(t *testing.T) {
func TestCreateValidatorTransaction(t *testing.T) { func TestCreateValidatorTransaction(t *testing.T) {
t.Parallel() t.Parallel()
pool, _ := setupTxPool() pool, _ := setupTxPool(createBlockChain())
pool.chain = createBlockChain()
defer pool.Stop() defer pool.Stop()
fromKey, _ := crypto.GenerateKey() fromKey, _ := crypto.GenerateKey()
@ -422,8 +427,7 @@ func TestCreateValidatorTransaction(t *testing.T) {
func TestMixedTransactions(t *testing.T) { func TestMixedTransactions(t *testing.T) {
t.Parallel() t.Parallel()
pool, _ := setupTxPool() pool, _ := setupTxPool(createBlockChain())
pool.chain = createBlockChain()
defer pool.Stop() defer pool.Stop()
fromKey, _ := crypto.GenerateKey() fromKey, _ := crypto.GenerateKey()
@ -457,7 +461,7 @@ func TestBlacklistedTransactions(t *testing.T) {
// DO NOT parallelize, test will add accounts to tx pool config. // DO NOT parallelize, test will add accounts to tx pool config.
// Create the pool // Create the pool
pool, _ := setupTxPool() pool, _ := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
// Create testing keys // Create testing keys
@ -501,7 +505,7 @@ func TestBlacklistedTransactions(t *testing.T) {
func TestTransactionQueue(t *testing.T) { func TestTransactionQueue(t *testing.T) {
t.Parallel() t.Parallel()
pool, key := setupTxPool() pool, key := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
tx := transaction(0, 0, 100, key) tx := transaction(0, 0, 100, key)
@ -528,7 +532,7 @@ func TestTransactionQueue(t *testing.T) {
t.Error("expected transaction queue to be empty. is", len(pool.queue)) t.Error("expected transaction queue to be empty. is", len(pool.queue))
} }
pool, key = setupTxPool() pool, key = setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
tx1 := transaction(0, 0, 100, key) tx1 := transaction(0, 0, 100, key)
@ -555,7 +559,7 @@ func TestTransactionQueue(t *testing.T) {
func TestTransactionNegativeValue(t *testing.T) { func TestTransactionNegativeValue(t *testing.T) {
t.Parallel() t.Parallel()
pool, key := setupTxPool() pool, key := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
tx, _ := types.SignTx( tx, _ := types.SignTx(
@ -569,9 +573,10 @@ func TestTransactionNegativeValue(t *testing.T) {
} }
func TestTransactionChainFork(t *testing.T) { func TestTransactionChainFork(t *testing.T) {
t.Skip("This test doesn't work with race detector")
t.Parallel() t.Parallel()
pool, key := setupTxPool() pool, key := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
addr := crypto.PubkeyToAddress(key.PublicKey) addr := crypto.PubkeyToAddress(key.PublicKey)
@ -600,18 +605,13 @@ func TestTransactionChainFork(t *testing.T) {
func TestTransactionDoubleNonce(t *testing.T) { func TestTransactionDoubleNonce(t *testing.T) {
t.Parallel() t.Parallel()
pool, key := setupTxPool() key, _ := crypto.GenerateKey()
defer pool.Stop()
addr := crypto.PubkeyToAddress(key.PublicKey) addr := crypto.PubkeyToAddress(key.PublicKey)
resetState := func() {
statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase())) statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()))
statedb.AddBalance(addr, big.NewInt(1000000000000000000)) statedb.AddBalance(addr, big.NewInt(1000000000000000000))
pool, _ := setupTxPool(&testBlockChain{statedb, 1000000, new(event.Feed)})
pool.chain = &testBlockChain{statedb, 1000000, new(event.Feed)} defer pool.Stop()
pool.lockedReset(nil, nil) pool.lockedReset(nil, nil)
}
resetState()
signer := types.HomesteadSigner{} signer := types.HomesteadSigner{}
tx1, _ := types.SignTx( tx1, _ := types.SignTx(
@ -656,7 +656,7 @@ func TestTransactionDoubleNonce(t *testing.T) {
func TestTransactionMissingNonce(t *testing.T) { func TestTransactionMissingNonce(t *testing.T) {
t.Parallel() t.Parallel()
pool, key := setupTxPool() pool, key := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
addr := crypto.PubkeyToAddress(key.PublicKey) addr := crypto.PubkeyToAddress(key.PublicKey)
@ -680,7 +680,7 @@ func TestTransactionNonceRecovery(t *testing.T) {
t.Parallel() t.Parallel()
const n = 10 const n = 10
pool, key := setupTxPool() pool, key := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
addr := crypto.PubkeyToAddress(key.PublicKey) addr := crypto.PubkeyToAddress(key.PublicKey)
@ -706,7 +706,7 @@ func TestTransactionDropping(t *testing.T) {
t.Parallel() t.Parallel()
// Create a test account and fund it // Create a test account and fund it
pool, key := setupTxPool() pool, key := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
account, _ := deriveSender(transaction(0, 0, 0, key)) account, _ := deriveSender(transaction(0, 0, 0, key))
@ -774,7 +774,7 @@ func TestTransactionDropping(t *testing.T) {
t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 4) t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 4)
} }
// Reduce the block gas limit, check that invalidated transactions are dropped // Reduce the block gas limit, check that invalidated transactions are dropped
pool.chain.(*testBlockChain).gasLimit = 100 pool.chain.(*testBlockChain).SetGasLimit(100)
pool.lockedReset(nil, nil) pool.lockedReset(nil, nil)
if _, ok := pool.pending[account].txs.items[tx0.Nonce()]; !ok { if _, ok := pool.pending[account].txs.items[tx0.Nonce()]; !ok {
@ -918,7 +918,7 @@ func TestTransactionQueueAccountLimiting(t *testing.T) {
t.Parallel() t.Parallel()
// Create a test account and fund it // Create a test account and fund it
pool, key := setupTxPool() pool, key := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
account, _ := deriveSender(transaction(0, 0, 0, key)) account, _ := deriveSender(transaction(0, 0, 0, key))
@ -1119,7 +1119,7 @@ func testTransactionLimitingEquivalency(t *testing.T, origin uint64) {
t.Parallel() t.Parallel()
// Add a batch of transactions to a pool one by one // Add a batch of transactions to a pool one by one
pool1, key1 := setupTxPool() pool1, key1 := setupTxPool(nil)
defer pool1.Stop() defer pool1.Stop()
account1, _ := deriveSender(transaction(0, 0, 0, key1)) account1, _ := deriveSender(transaction(0, 0, 0, key1))
@ -1131,7 +1131,7 @@ func testTransactionLimitingEquivalency(t *testing.T, origin uint64) {
} }
} }
// Add a batch of transactions to a pool in one big batch // Add a batch of transactions to a pool in one big batch
pool2, key2 := setupTxPool() pool2, key2 := setupTxPool(nil)
defer pool2.Stop() defer pool2.Stop()
account2, _ := deriveSender(transaction(0, 0, 0, key2)) account2, _ := deriveSender(transaction(0, 0, 0, key2))
@ -1523,7 +1523,7 @@ func BenchmarkPendingDemotion10000(b *testing.B) { benchmarkPendingDemotion(b, 1
func benchmarkPendingDemotion(b *testing.B, size int) { func benchmarkPendingDemotion(b *testing.B, size int) {
// Add a batch of transactions to a pool one by one // Add a batch of transactions to a pool one by one
pool, key := setupTxPool() pool, key := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
account, _ := deriveSender(transaction(0, 0, 0, key)) account, _ := deriveSender(transaction(0, 0, 0, key))
@ -1548,7 +1548,7 @@ func BenchmarkFuturePromotion10000(b *testing.B) { benchmarkFuturePromotion(b, 1
func benchmarkFuturePromotion(b *testing.B, size int) { func benchmarkFuturePromotion(b *testing.B, size int) {
// Add a batch of transactions to a pool one by one // Add a batch of transactions to a pool one by one
pool, key := setupTxPool() pool, key := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
account, _ := deriveSender(transaction(0, 0, 0, key)) account, _ := deriveSender(transaction(0, 0, 0, key))
@ -1568,7 +1568,7 @@ func benchmarkFuturePromotion(b *testing.B, size int) {
// Benchmarks the speed of iterative transaction insertion. // Benchmarks the speed of iterative transaction insertion.
func BenchmarkPoolInsert(b *testing.B) { func BenchmarkPoolInsert(b *testing.B) {
// Generate a batch of transactions to enqueue into the pool // Generate a batch of transactions to enqueue into the pool
pool, key := setupTxPool() pool, key := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
account, _ := deriveSender(transaction(0, 0, 0, key)) account, _ := deriveSender(transaction(0, 0, 0, key))
@ -1592,7 +1592,7 @@ func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 1
func benchmarkPoolBatchInsert(b *testing.B, size int) { func benchmarkPoolBatchInsert(b *testing.B, size int) {
// Generate a batch of transactions to enqueue into the pool // Generate a batch of transactions to enqueue into the pool
pool, key := setupTxPool() pool, key := setupTxPool(nil)
defer pool.Stop() defer pool.Stop()
account, _ := deriveSender(transaction(0, 0, 0, key)) account, _ := deriveSender(transaction(0, 0, 0, key))

@ -227,6 +227,7 @@ type Block struct {
ReceivedAt time.Time ReceivedAt time.Time
ReceivedFrom interface{} ReceivedFrom interface{}
commitLock sync.Mutex
// Commit Signatures/Bitmap // Commit Signatures/Bitmap
commitSigAndBitmap []byte commitSigAndBitmap []byte
} }
@ -264,11 +265,15 @@ func (b *Block) SetCurrentCommitSig(sigAndBitmap []byte) {
Int("dstLen", len(b.header.LastCommitSignature())). Int("dstLen", len(b.header.LastCommitSignature())).
Msg("SetCurrentCommitSig: sig size mismatch") Msg("SetCurrentCommitSig: sig size mismatch")
} }
b.commitLock.Lock()
b.commitSigAndBitmap = sigAndBitmap b.commitSigAndBitmap = sigAndBitmap
b.commitLock.Unlock()
} }
// GetCurrentCommitSig get the commit group signature that signed on this block. // GetCurrentCommitSig get the commit group signature that signed on this block.
func (b *Block) GetCurrentCommitSig() []byte { func (b *Block) GetCurrentCommitSig() []byte {
b.commitLock.Lock()
defer b.commitLock.Unlock()
return b.commitSigAndBitmap return b.commitSigAndBitmap
} }

@ -0,0 +1,30 @@
package keylocker
import "sync"
type KeyLocker struct {
m sync.Map
}
func New() *KeyLocker {
return &KeyLocker{}
}
func (a *KeyLocker) Lock(key interface{}, f func() (interface{}, error)) (interface{}, error) {
mu := &sync.Mutex{}
for {
actual, _ := a.m.LoadOrStore(key, mu)
mu2 := actual.(*sync.Mutex)
mu.Lock()
if mu2 != mu {
// acquired someone else lock.
mu.Unlock()
continue
}
rs, err := f()
mu.Unlock()
a.m.Delete(key)
return rs, err
}
}

@ -0,0 +1,17 @@
package keylocker
import "testing"
func TestKeyLocker_Sequential(t *testing.T) {
n := New()
key := 1
v := make(chan struct{}, 1)
n.Lock(key, func() (interface{}, error) {
v <- struct{}{}
return nil, nil
})
<-v // we know that goroutine really executed
n.Lock(key, func() (interface{}, error) {
return nil, nil
})
}

@ -1,35 +0,0 @@
package utils
import (
"sync"
"testing"
"time"
)
var NumThreads = 20
func TestSingleton(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < NumThreads; i++ {
wg.Add(1)
go func() {
defer wg.Done()
time.Sleep(time.Millisecond)
}()
}
wg.Wait()
n := 100
for i := 0; i < NumThreads; i++ {
wg.Add(1)
go func() {
defer wg.Done()
n++
time.Sleep(time.Millisecond)
}()
}
wg.Wait()
}

@ -1,6 +1,7 @@
package utils package utils
import ( import (
"sync"
"time" "time"
) )
@ -19,6 +20,7 @@ type Timeout struct {
state TimeoutState state TimeoutState
d time.Duration d time.Duration
start time.Time start time.Time
mu sync.Mutex
} }
// NewTimeout creates a new timeout class // NewTimeout creates a new timeout class
@ -29,18 +31,24 @@ func NewTimeout(d time.Duration) *Timeout {
// Start starts the timeout clock // Start starts the timeout clock
func (timeout *Timeout) Start() { func (timeout *Timeout) Start() {
timeout.mu.Lock()
timeout.state = Active timeout.state = Active
timeout.start = time.Now() timeout.start = time.Now()
timeout.mu.Unlock()
} }
// Stop stops the timeout clock // Stop stops the timeout clock
func (timeout *Timeout) Stop() { func (timeout *Timeout) Stop() {
timeout.mu.Lock()
timeout.state = Inactive timeout.state = Inactive
timeout.start = time.Now() timeout.start = time.Now()
timeout.mu.Unlock()
} }
// CheckExpire checks whether the timeout is reached/expired // CheckExpire checks whether the timeout is reached/expired
func (timeout *Timeout) CheckExpire() bool { func (timeout *Timeout) CheckExpire() bool {
timeout.mu.Lock()
defer timeout.mu.Unlock()
if timeout.state == Active && time.Since(timeout.start) > timeout.d { if timeout.state == Active && time.Since(timeout.start) > timeout.d {
timeout.state = Expired timeout.state = Expired
} }
@ -52,17 +60,23 @@ func (timeout *Timeout) CheckExpire() bool {
// Duration returns the duration period of timeout // Duration returns the duration period of timeout
func (timeout *Timeout) Duration() time.Duration { func (timeout *Timeout) Duration() time.Duration {
timeout.mu.Lock()
defer timeout.mu.Unlock()
return timeout.d return timeout.d
} }
// SetDuration set new duration for the timer // SetDuration set new duration for the timer
func (timeout *Timeout) SetDuration(nd time.Duration) { func (timeout *Timeout) SetDuration(nd time.Duration) {
timeout.mu.Lock()
timeout.d = nd timeout.d = nd
timeout.mu.Unlock()
} }
// IsActive checks whether timeout clock is active; // IsActive checks whether timeout clock is active;
// A timeout is active means it's not stopped caused by stop // A timeout is active means it's not stopped caused by stop
// and also not expired with time elapses longer than duration from start // and also not expired with time elapses longer than duration from start
func (timeout *Timeout) IsActive() bool { func (timeout *Timeout) IsActive() bool {
timeout.mu.Lock()
defer timeout.mu.Unlock()
return timeout.state == Active return timeout.state == Active
} }

@ -125,7 +125,7 @@ func (node *Node) GetConsensusCurViewID() uint64 {
// GetConsensusBlockNum returns the current block number of the consensus // GetConsensusBlockNum returns the current block number of the consensus
func (node *Node) GetConsensusBlockNum() uint64 { func (node *Node) GetConsensusBlockNum() uint64 {
return node.Consensus.GetBlockNum() return node.Consensus.BlockNum()
} }
// GetConsensusInternal returns consensus internal data // GetConsensusInternal returns consensus internal data

@ -134,7 +134,7 @@ func (node *Node) ProposeNewBlock(commitSigs chan []byte) (*types.Block, error)
header := node.Worker.GetCurrentHeader() header := node.Worker.GetCurrentHeader()
// Update worker's current header and // Update worker's current header and
// state data in preparation to propose/process new transactions // state data in preparation to propose/process new transactions
leaderKey := node.Consensus.LeaderPubKey leaderKey := node.Consensus.GetLeaderPubKey()
var ( var (
coinbase = node.GetAddressForBLSKey(leaderKey.Object, header.Epoch()) coinbase = node.GetAddressForBLSKey(leaderKey.Object, header.Epoch())
beneficiary = coinbase beneficiary = coinbase

@ -190,8 +190,8 @@ func NewHost(cfg HostConfig) (Host, error) {
priKey: key, priKey: key,
discovery: disc, discovery: disc,
security: security, security: security,
onConnections: []ConnectCallback{}, onConnections: ConnectCallbacks{},
onDisconnects: []DisconnectCallback{}, onDisconnects: DisconnectCallbacks{},
logger: &subLogger, logger: &subLogger,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
@ -218,8 +218,8 @@ type HostV2 struct {
security security.Security security security.Security
logger *zerolog.Logger logger *zerolog.Logger
blocklist libp2p_pubsub.Blacklist blocklist libp2p_pubsub.Blacklist
onConnections []ConnectCallback onConnections ConnectCallbacks
onDisconnects []DisconnectCallback onDisconnects DisconnectCallbacks
ctx context.Context ctx context.Context
cancel func() cancel func()
} }
@ -433,7 +433,7 @@ func (host *HostV2) ListenClose(net libp2p_network.Network, addr ma.Multiaddr) {
func (host *HostV2) Connected(net libp2p_network.Network, conn libp2p_network.Conn) { func (host *HostV2) Connected(net libp2p_network.Network, conn libp2p_network.Conn) {
host.logger.Info().Interface("node", conn.RemotePeer()).Msg("peer connected") host.logger.Info().Interface("node", conn.RemotePeer()).Msg("peer connected")
for _, function := range host.onConnections { for _, function := range host.onConnections.GetAll() {
if err := function(net, conn); err != nil { if err := function(net, conn); err != nil {
host.logger.Error().Err(err).Interface("node", conn.RemotePeer()).Msg("failed on peer connected callback") host.logger.Error().Err(err).Interface("node", conn.RemotePeer()).Msg("failed on peer connected callback")
} }
@ -444,7 +444,7 @@ func (host *HostV2) Connected(net libp2p_network.Network, conn libp2p_network.Co
func (host *HostV2) Disconnected(net libp2p_network.Network, conn libp2p_network.Conn) { func (host *HostV2) Disconnected(net libp2p_network.Network, conn libp2p_network.Conn) {
host.logger.Info().Interface("node", conn.RemotePeer()).Msg("peer disconnected") host.logger.Info().Interface("node", conn.RemotePeer()).Msg("peer disconnected")
for _, function := range host.onDisconnects { for _, function := range host.onDisconnects.GetAll() {
if err := function(conn); err != nil { if err := function(conn); err != nil {
host.logger.Error().Err(err).Interface("node", conn.RemotePeer()).Msg("failed on peer disconnected callback") host.logger.Error().Err(err).Interface("node", conn.RemotePeer()).Msg("failed on peer disconnected callback")
} }
@ -462,11 +462,11 @@ func (host *HostV2) ClosedStream(net libp2p_network.Network, stream libp2p_netwo
} }
func (host *HostV2) SetConnectCallback(callback ConnectCallback) { func (host *HostV2) SetConnectCallback(callback ConnectCallback) {
host.onConnections = append(host.onConnections, callback) host.onConnections.Add(callback)
} }
func (host *HostV2) SetDisconnectCallback(callback DisconnectCallback) { func (host *HostV2) SetDisconnectCallback(callback DisconnectCallback) {
host.onDisconnects = append(host.onDisconnects, callback) host.onDisconnects.Add(callback)
} }
// NamedTopic represents pubsub topic // NamedTopic represents pubsub topic

@ -6,6 +6,7 @@ import (
"github.com/harmony-one/harmony/consensus/engine" "github.com/harmony-one/harmony/consensus/engine"
"github.com/harmony-one/harmony/core/types" "github.com/harmony-one/harmony/core/types"
shardingconfig "github.com/harmony-one/harmony/internal/configs/sharding" shardingconfig "github.com/harmony-one/harmony/internal/configs/sharding"
"github.com/harmony-one/harmony/internal/utils/keylocker"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -20,12 +21,14 @@ type chainHelper interface {
type chainHelperImpl struct { type chainHelperImpl struct {
chain engine.ChainReader chain engine.ChainReader
schedule shardingconfig.Schedule schedule shardingconfig.Schedule
keyLocker *keylocker.KeyLocker
} }
func newChainHelper(chain engine.ChainReader, schedule shardingconfig.Schedule) *chainHelperImpl { func newChainHelper(chain engine.ChainReader, schedule shardingconfig.Schedule) *chainHelperImpl {
return &chainHelperImpl{ return &chainHelperImpl{
chain: chain, chain: chain,
schedule: schedule, schedule: schedule,
keyLocker: keylocker.New(),
} }
} }
@ -89,9 +92,8 @@ func (ch *chainHelperImpl) getBlocksByHashes(hs []common.Hash) ([]*types.Block,
return blocks, nil return blocks, nil
} }
var errBlockNotFound = errors.New("block not found")
func (ch *chainHelperImpl) getBlockWithSigByHeader(header *block.Header) (*types.Block, error) { func (ch *chainHelperImpl) getBlockWithSigByHeader(header *block.Header) (*types.Block, error) {
rs, err := ch.keyLocker.Lock(header.Number().Uint64(), func() (interface{}, error) {
b := ch.chain.GetBlock(header.Hash(), header.Number().Uint64()) b := ch.chain.GetBlock(header.Hash(), header.Number().Uint64())
if b == nil { if b == nil {
return nil, nil return nil, nil
@ -102,6 +104,13 @@ func (ch *chainHelperImpl) getBlockWithSigByHeader(header *block.Header) (*types
} }
b.SetCurrentCommitSig(commitSig) b.SetCurrentCommitSig(commitSig)
return b, nil return b, nil
})
if err != nil {
return nil, err
}
return rs.(*types.Block), nil
} }
func (ch *chainHelperImpl) getBlockSigAndBitmap(header *block.Header) ([]byte, error) { func (ch *chainHelperImpl) getBlockSigAndBitmap(header *block.Header) ([]byte, error) {

@ -2,6 +2,7 @@ package sync
import ( import (
"context" "context"
"sync"
"testing" "testing"
"time" "time"
@ -50,11 +51,12 @@ func TestProtocol_advertiseLoop(t *testing.T) {
time.Sleep(150 * time.Millisecond) time.Sleep(150 * time.Millisecond)
close(p.closeC) close(p.closeC)
if len(disc.advCnt) != len(p.supportedVersions()) { advCnt := disc.Extract()
t.Errorf("unexpected advertise topic count: %v / %v", len(disc.advCnt), if len(advCnt) != len(p.supportedVersions()) {
t.Errorf("unexpected advertise topic count: %v / %v", len(advCnt),
len(p.supportedVersions())) len(p.supportedVersions()))
} }
for _, cnt := range disc.advCnt { for _, cnt := range advCnt {
if cnt < 1 { if cnt < 1 {
t.Errorf("unexpected discovery count: %v", cnt) t.Errorf("unexpected discovery count: %v", cnt)
} }
@ -64,6 +66,7 @@ func TestProtocol_advertiseLoop(t *testing.T) {
type testDiscovery struct { type testDiscovery struct {
advCnt map[string]int advCnt map[string]int
sleep time.Duration sleep time.Duration
mu sync.Mutex
} }
func newTestDiscovery(discInterval time.Duration) *testDiscovery { func newTestDiscovery(discInterval time.Duration) *testDiscovery {
@ -82,10 +85,20 @@ func (disc *testDiscovery) Close() error {
} }
func (disc *testDiscovery) Advertise(ctx context.Context, ns string) (time.Duration, error) { func (disc *testDiscovery) Advertise(ctx context.Context, ns string) (time.Duration, error) {
disc.mu.Lock()
defer disc.mu.Unlock()
disc.advCnt[ns]++ disc.advCnt[ns]++
return disc.sleep, nil return disc.sleep, nil
} }
func (disc *testDiscovery) Extract() map[string]int {
disc.mu.Lock()
defer disc.mu.Unlock()
var out map[string]int
out, disc.advCnt = disc.advCnt, make(map[string]int)
return out
}
func (disc *testDiscovery) FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) { func (disc *testDiscovery) FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) {
return nil, nil return nil, nil
} }

@ -0,0 +1,41 @@
package p2p
import "sync"
type ConnectCallbacks struct {
cbs []ConnectCallback
mu sync.RWMutex
}
func (a *ConnectCallbacks) Add(cb ConnectCallback) {
a.mu.Lock()
defer a.mu.Unlock()
a.cbs = append(a.cbs, cb)
}
func (a *ConnectCallbacks) GetAll() []ConnectCallback {
a.mu.RLock()
defer a.mu.RUnlock()
out := make([]ConnectCallback, len(a.cbs))
copy(out, a.cbs)
return out
}
type DisconnectCallbacks struct {
cbs []DisconnectCallback
mu sync.RWMutex
}
func (a *DisconnectCallbacks) Add(cb DisconnectCallback) {
a.mu.Lock()
defer a.mu.Unlock()
a.cbs = append(a.cbs, cb)
}
func (a *DisconnectCallbacks) GetAll() []DisconnectCallback {
a.mu.RLock()
defer a.mu.RUnlock()
out := make([]DisconnectCallback, len(a.cbs))
copy(out, a.cbs)
return out
}

@ -0,0 +1,37 @@
package p2p
import (
"reflect"
"testing"
libp2p_network "github.com/libp2p/go-libp2p-core/network"
"github.com/stretchr/testify/require"
)
func TestConnectCallbacks(t *testing.T) {
cbs := ConnectCallbacks{}
fn := func(net libp2p_network.Network, conn libp2p_network.Conn) error {
return nil
}
require.Equal(t, 0, len(cbs.GetAll()))
cbs.Add(fn)
require.Equal(t, 1, len(cbs.GetAll()))
require.Equal(t, reflect.ValueOf(fn).Pointer(), reflect.ValueOf(cbs.GetAll()[0]).Pointer())
}
func TestDisConnectCallbacks(t *testing.T) {
cbs := DisconnectCallbacks{}
fn := func(conn libp2p_network.Conn) error {
return nil
}
require.Equal(t, 0, len(cbs.GetAll()))
cbs.Add(fn)
require.Equal(t, 1, len(cbs.GetAll()))
require.Equal(t, reflect.ValueOf(fn).Pointer(), reflect.ValueOf(cbs.GetAll()[0]).Pointer())
}

@ -82,7 +82,7 @@ fi
echo "Running go test..." echo "Running go test..."
# Fix https://github.com/golang/go/issues/44129#issuecomment-788351567 # Fix https://github.com/golang/go/issues/44129#issuecomment-788351567
go get -t ./... go get -t ./...
if go test -v -count=1 ./... if go test -v -count=1 -vet=all -race ./...
then then
echo "go test succeeded." echo "go test succeeded."
else else

Loading…
Cancel
Save