diff --git a/Makefile b/Makefile index e7d6db4b4..776ecf5b7 100644 --- a/Makefile +++ b/Makefile @@ -156,4 +156,4 @@ go-vet: go vet ./... go-test: - go test ./... \ No newline at end of file + go test -vet=all -race ./... \ No newline at end of file diff --git a/consensus/checks.go b/consensus/checks.go index 219198881..ceaf9987b 100644 --- a/consensus/checks.go +++ b/consensus/checks.go @@ -57,9 +57,9 @@ func (consensus *Consensus) senderKeySanityChecks(msg *msg_pb.Message, senderKey func (consensus *Consensus) isRightBlockNumAndViewID(recvMsg *FBFTMessage, ) bool { - if recvMsg.ViewID != consensus.GetCurBlockViewID() || recvMsg.BlockNum != consensus.blockNum { + if recvMsg.ViewID != consensus.GetCurBlockViewID() || recvMsg.BlockNum != consensus.BlockNum() { consensus.getLogger().Debug(). - Uint64("blockNum", consensus.blockNum). + Uint64("blockNum", consensus.BlockNum()). Str("recvMsg", recvMsg.String()). Msg("BlockNum/viewID not match") return false @@ -103,12 +103,12 @@ func (consensus *Consensus) onAnnounceSanityChecks(recvMsg *FBFTMessage) bool { } func (consensus *Consensus) isRightBlockNumCheck(recvMsg *FBFTMessage) bool { - if recvMsg.BlockNum < consensus.blockNum { + if recvMsg.BlockNum < consensus.BlockNum() { consensus.getLogger().Debug(). Uint64("MsgBlockNum", recvMsg.BlockNum). Msg("Wrong BlockNum Received, ignoring!") return false - } else if recvMsg.BlockNum-consensus.blockNum > MaxBlockNumDiff { + } else if recvMsg.BlockNum-consensus.BlockNum() > MaxBlockNumDiff { consensus.getLogger().Debug(). Uint64("MsgBlockNum", recvMsg.BlockNum). Uint64("MaxBlockNumDiff", MaxBlockNumDiff). @@ -122,7 +122,7 @@ func (consensus *Consensus) newBlockSanityChecks( blockObj *types.Block, recvMsg *FBFTMessage, ) bool { if blockObj.NumberU64() != recvMsg.BlockNum || - recvMsg.BlockNum < consensus.blockNum { + recvMsg.BlockNum < consensus.BlockNum() { consensus.getLogger().Warn(). Uint64("MsgBlockNum", recvMsg.BlockNum). Uint64("blockNum", blockObj.NumberU64()). @@ -152,12 +152,12 @@ func (consensus *Consensus) onViewChangeSanityCheck(recvMsg *FBFTMessage) bool { Interface("SendPubKeys", recvMsg.SenderPubkeys). Msg("[onViewChangeSanityCheck]") - if consensus.blockNum > recvMsg.BlockNum { + if consensus.BlockNum() > recvMsg.BlockNum { consensus.getLogger().Debug(). Msg("[onViewChange] Message BlockNum Is Low") return false } - if consensus.blockNum < recvMsg.BlockNum { + if consensus.BlockNum() < recvMsg.BlockNum { consensus.getLogger().Warn(). Msg("[onViewChangeSanityCheck] MsgBlockNum is different from my BlockNumber") return false diff --git a/consensus/consensus.go b/consensus/consensus.go index 047744f34..bdc671e97 100644 --- a/consensus/consensus.go +++ b/consensus/consensus.go @@ -3,6 +3,7 @@ package consensus import ( "fmt" "sync" + "sync/atomic" "time" "github.com/harmony-one/harmony/crypto/bls" @@ -45,7 +46,7 @@ type Consensus struct { // FBFTLog stores the pbft messages and blocks during FBFT process FBFTLog *FBFTLog // phase: different phase of FBFT protocol: pre-prepare, prepare, commit, finish etc - phase FBFTPhase + phase *LockedFBFTPhase // current indicates what state a node is in current State // isBackup declarative the node is in backup mode @@ -131,7 +132,7 @@ type Consensus struct { // finality of previous consensus in the unit of milliseconds finality int64 // finalityCounter keep tracks of the finality time - finalityCounter int64 + finalityCounter atomic.Value //int64 dHelper *downloadHelper } @@ -168,6 +169,12 @@ func (consensus *Consensus) GetPublicKeys() multibls.PublicKeys { return consensus.priKey.GetPublicKeys() } +func (consensus *Consensus) GetLeaderPubKey() *bls_cosi.PublicKeyWrapper { + consensus.pubKeyLock.Lock() + defer consensus.pubKeyLock.Unlock() + return consensus.LeaderPubKey +} + func (consensus *Consensus) GetPrivateKeys() multibls.PrivateKeys { return consensus.priKey } @@ -197,6 +204,10 @@ func (consensus *Consensus) IsBackup() bool { return consensus.isBackup } +func (consensus *Consensus) BlockNum() uint64 { + return atomic.LoadUint64(&consensus.blockNum) +} + // New create a new Consensus record func New( host p2p.Host, shard uint32, leader p2p.Peer, multiBLSPriKey multibls.PrivateKeys, @@ -209,7 +220,7 @@ func New( consensus.BlockNumLowChan = make(chan struct{}, 1) // FBFT related consensus.FBFTLog = NewFBFTLog() - consensus.phase = FBFTAnnounce + consensus.phase = NewLockedFBFTPhase(FBFTAnnounce) consensus.current = State{mode: Normal} // FBFT timeout consensus.consensusTimeout = createTimeout() diff --git a/consensus/consensus_fbft.go b/consensus/consensus_fbft.go new file mode 100644 index 000000000..313abf061 --- /dev/null +++ b/consensus/consensus_fbft.go @@ -0,0 +1,30 @@ +package consensus + +import "sync" + +type LockedFBFTPhase struct { + mu sync.Mutex + phase FBFTPhase +} + +func NewLockedFBFTPhase(initialPhrase FBFTPhase) *LockedFBFTPhase { + return &LockedFBFTPhase{ + phase: initialPhrase, + } +} + +func (a *LockedFBFTPhase) Set(phrase FBFTPhase) { + a.mu.Lock() + a.phase = phrase + a.mu.Unlock() +} + +func (a *LockedFBFTPhase) Get() FBFTPhase { + a.mu.Lock() + defer a.mu.Unlock() + return a.phase +} + +func (a *LockedFBFTPhase) String() string { + return a.Get().String() +} diff --git a/consensus/consensus_fbft_test.go b/consensus/consensus_fbft_test.go new file mode 100644 index 000000000..a84cc3c83 --- /dev/null +++ b/consensus/consensus_fbft_test.go @@ -0,0 +1,15 @@ +package consensus + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLockedFBFTPhase(t *testing.T) { + s := NewLockedFBFTPhase(FBFTAnnounce) + require.Equal(t, FBFTAnnounce, s.Get()) + + s.Set(FBFTCommit) + require.Equal(t, FBFTCommit, s.Get()) +} diff --git a/consensus/consensus_service.go b/consensus/consensus_service.go index 2c59a311f..4b01fe07e 100644 --- a/consensus/consensus_service.go +++ b/consensus/consensus_service.go @@ -403,7 +403,9 @@ func (consensus *Consensus) UpdateConsensusInformation() Mode { consensus.getLogger().Info(). Str("leaderPubKey", leaderPubKey.Bytes.Hex()). Msg("[UpdateConsensusInformation] Most Recent LeaderPubKey Updated Based on BlockChain") + consensus.pubKeyLock.Lock() consensus.LeaderPubKey = leaderPubKey + consensus.pubKeyLock.Unlock() } } @@ -442,8 +444,11 @@ func (consensus *Consensus) UpdateConsensusInformation() Mode { // IsLeader check if the node is a leader or not by comparing the public key of // the node with the leader public key func (consensus *Consensus) IsLeader() bool { + consensus.pubKeyLock.Lock() + obj := consensus.LeaderPubKey.Object + consensus.pubKeyLock.Unlock() for _, key := range consensus.priKey { - if key.Pub.Object.IsEqual(consensus.LeaderPubKey.Object) { + if key.Pub.Object.IsEqual(obj) { return true } } @@ -469,13 +474,13 @@ func (consensus *Consensus) SetViewChangingID(viewID uint64) { // StartFinalityCount set the finality counter to current time func (consensus *Consensus) StartFinalityCount() { - consensus.finalityCounter = time.Now().UnixNano() + consensus.finalityCounter.Store(time.Now().UnixNano()) } // FinishFinalityCount calculate the current finality func (consensus *Consensus) FinishFinalityCount() { d := time.Now().UnixNano() - consensus.finality = (d - consensus.finalityCounter) / 1000000 + consensus.finality = (d - consensus.finalityCounter.Load().(int64)) / 1000000 consensusFinalityHistogram.Observe(float64(consensus.finality)) } @@ -491,7 +496,7 @@ func (consensus *Consensus) switchPhase(subject string, desired FBFTPhase) { Str("to:", desired.String()). Str("switchPhase:", subject) - consensus.phase = desired + consensus.phase.Set(desired) return } @@ -580,7 +585,7 @@ func (consensus *Consensus) NumSignaturesIncludedInBlock(block *types.Block) uin // getLogger returns logger for consensus contexts added func (consensus *Consensus) getLogger() *zerolog.Logger { logger := utils.Logger().With(). - Uint64("myBlock", consensus.blockNum). + Uint64("myBlock", consensus.BlockNum()). Uint64("myViewID", consensus.GetCurBlockViewID()). Str("phase", consensus.phase.String()). Str("mode", consensus.current.Mode().String()). diff --git a/consensus/consensus_test.go b/consensus/consensus_test.go index d2bc0fb0b..36f7160de 100644 --- a/consensus/consensus_test.go +++ b/consensus/consensus_test.go @@ -38,7 +38,7 @@ func TestConsensusInitialization(t *testing.T) { // FBFTLog assert.Equal(t, fbtLog, consensus.FBFTLog) - assert.Equal(t, FBFTAnnounce, consensus.phase) + assert.Equal(t, FBFTAnnounce, consensus.phase.Get()) // State / consensus.current assert.Equal(t, state.mode, consensus.current.mode) diff --git a/consensus/consensus_v2.go b/consensus/consensus_v2.go index fc9cd56bf..11c7b9da9 100644 --- a/consensus/consensus_v2.go +++ b/consensus/consensus_v2.go @@ -135,7 +135,7 @@ func (consensus *Consensus) finalCommit() { consensus.getLogger().Info(). Int64("NumCommits", numCommits). Msg("[finalCommit] Finalizing Consensus") - beforeCatchupNum := consensus.blockNum + beforeCatchupNum := consensus.BlockNum() leaderPriKey, err := consensus.GetConsensusLeaderPrivateKey() if err != nil { @@ -188,7 +188,7 @@ func (consensus *Consensus) finalCommit() { } else { consensus.getLogger().Info(). Hex("blockHash", curBlockHash[:]). - Uint64("blockNum", consensus.blockNum). + Uint64("blockNum", consensus.BlockNum()). Msg("[finalCommit] Sent Committed Message") } consensus.getLogger().Info().Msg("[finalCommit] Start consensus timer") @@ -203,7 +203,7 @@ func (consensus *Consensus) finalCommit() { p2p.ConstructMessage(msgToSend)) consensus.getLogger().Info(). Hex("blockHash", curBlockHash[:]). - Uint64("blockNum", consensus.blockNum). + Uint64("blockNum", consensus.BlockNum()). Hex("lastCommitSig", commitSigAndBitmap). Msg("[finalCommit] Queued Committed Message") } @@ -211,7 +211,7 @@ func (consensus *Consensus) finalCommit() { block.SetCurrentCommitSig(commitSigAndBitmap) err = consensus.commitBlock(block, FBFTMsg) - if err != nil || consensus.blockNum-beforeCatchupNum != 1 { + if err != nil || consensus.BlockNum()-beforeCatchupNum != 1 { consensus.getLogger().Err(err). Uint64("beforeCatchupBlockNum", beforeCatchupNum). Msg("[finalCommit] Leader failed to commit the confirmed block") @@ -264,7 +264,7 @@ func (consensus *Consensus) finalCommit() { // BlockCommitSigs returns the byte array of aggregated // commit signature and bitmap signed on the block func (consensus *Consensus) BlockCommitSigs(blockNum uint64) ([]byte, error) { - if consensus.blockNum <= 1 { + if consensus.BlockNum() <= 1 { return nil, nil } lastCommits, err := consensus.Blockchain.ReadCommitSig(blockNum) @@ -363,7 +363,7 @@ func (consensus *Consensus) Start( case <-consensus.syncReadyChan: consensus.getLogger().Info().Msg("[ConsensusMainLoop] syncReadyChan") consensus.mutex.Lock() - if consensus.blockNum < consensus.Blockchain.CurrentHeader().Number().Uint64()+1 { + if consensus.BlockNum() < consensus.Blockchain.CurrentHeader().Number().Uint64()+1 { consensus.SetBlockNum(consensus.Blockchain.CurrentHeader().Number().Uint64() + 1) consensus.SetViewIDs(consensus.Blockchain.CurrentHeader().ViewID().Uint64() + 1) mode := consensus.UpdateConsensusInformation() @@ -396,7 +396,7 @@ func (consensus *Consensus) Start( Uint64("MsgBlockNum", newBlock.NumberU64()). Msg("[ConsensusMainLoop] Received Proposed New Block!") - if newBlock.NumberU64() < consensus.blockNum { + if newBlock.NumberU64() < consensus.BlockNum() { consensus.getLogger().Warn().Uint64("newBlockNum", newBlock.NumberU64()). Msg("[ConsensusMainLoop] received old block, abort") continue @@ -443,7 +443,7 @@ func (consensus *Consensus) Close() error { // waitForCommit wait extra 2 seconds for commit phase to finish func (consensus *Consensus) waitForCommit() { - if consensus.Mode() != Normal || consensus.phase != FBFTCommit { + if consensus.Mode() != Normal || consensus.phase.Get() != FBFTCommit { return } // We only need to wait consensus is in normal commit phase @@ -569,7 +569,7 @@ func (consensus *Consensus) preCommitAndPropose(blk *types.Block) error { } else { consensus.getLogger().Info(). Str("blockHash", blk.Hash().Hex()). - Uint64("blockNum", consensus.blockNum). + Uint64("blockNum", consensus.BlockNum()). Hex("lastCommitSig", bareMinimumCommit). Msg("[preCommitAndPropose] Sent Committed Message") } @@ -621,7 +621,7 @@ func (consensus *Consensus) tryCatchup() error { if consensus.BlockVerifier == nil { return errors.New("consensus haven't finished initialization") } - initBN := consensus.blockNum + initBN := consensus.BlockNum() defer consensus.postCatchup(initBN) blks, msgs, err := consensus.getLastMileBlocksAndMsg(initBN) @@ -683,7 +683,9 @@ func (consensus *Consensus) commitBlock(blk *types.Block, committedMsg *FBFTMess func (consensus *Consensus) SetupForNewConsensus(blk *types.Block, committedMsg *FBFTMessage) { atomic.StoreUint64(&consensus.blockNum, blk.NumberU64()+1) consensus.SetCurBlockViewID(committedMsg.ViewID + 1) + consensus.pubKeyLock.Lock() consensus.LeaderPubKey = committedMsg.SenderPubkeys[0] + consensus.pubKeyLock.Unlock() // Update consensus keys at last so the change of leader status doesn't mess up normal flow if blk.IsLastBlockInEpoch() { consensus.SetMode(consensus.UpdateConsensusInformation()) @@ -693,15 +695,15 @@ func (consensus *Consensus) SetupForNewConsensus(blk *types.Block, committedMsg } func (consensus *Consensus) postCatchup(initBN uint64) { - if initBN < consensus.blockNum { + if initBN < consensus.BlockNum() { consensus.getLogger().Info(). Uint64("From", initBN). - Uint64("To", consensus.blockNum). + Uint64("To", consensus.BlockNum()). Msg("[TryCatchup] Caught up!") consensus.switchPhase("TryCatchup", FBFTAnnounce) } // catch up and skip from view change trap - if initBN < consensus.blockNum && consensus.IsViewChangingMode() { + if initBN < consensus.BlockNum() && consensus.IsViewChangingMode() { consensus.current.SetMode(Normal) consensus.consensusTimeout[timeoutViewChange].Stop() } diff --git a/consensus/construct.go b/consensus/construct.go index ebf72c1d9..bbee71203 100644 --- a/consensus/construct.go +++ b/consensus/construct.go @@ -30,7 +30,7 @@ func (consensus *Consensus) populateMessageFields( request *msg_pb.ConsensusRequest, blockHash []byte, ) *msg_pb.ConsensusRequest { request.ViewId = consensus.GetCurBlockViewID() - request.BlockNum = consensus.blockNum + request.BlockNum = consensus.BlockNum() request.ShardId = consensus.ShardID // 32 byte block hash request.BlockHash = blockHash diff --git a/consensus/construct_test.go b/consensus/construct_test.go index 9e3b54712..37b76045c 100644 --- a/consensus/construct_test.go +++ b/consensus/construct_test.go @@ -2,6 +2,7 @@ package consensus import ( "bytes" + "sync/atomic" "testing" "github.com/ethereum/go-ethereum/common" @@ -86,7 +87,7 @@ func TestConstructPreparedMessage(test *testing.T) { []*bls.PublicKeyWrapper{&leaderKeyWrapper}, leaderPriKey.Sign(message), common.BytesToHash(consensus.blockHash[:]), - consensus.blockNum, + consensus.BlockNum(), consensus.GetCurBlockViewID(), ) if _, err := consensus.Decider.AddNewVote( @@ -94,7 +95,7 @@ func TestConstructPreparedMessage(test *testing.T) { []*bls.PublicKeyWrapper{&validatorKeyWrapper}, validatorPriKey.Sign(message), common.BytesToHash(consensus.blockHash[:]), - consensus.blockNum, + consensus.BlockNum(), consensus.GetCurBlockViewID(), ); err != nil { test.Log(err) @@ -156,7 +157,7 @@ func TestConstructPrepareMessage(test *testing.T) { consensus.SetCurBlockViewID(2) consensus.blockHash = [32]byte{} copy(consensus.blockHash[:], []byte("random")) - consensus.blockNum = 1000 + atomic.StoreUint64(&consensus.blockNum, 1000) sig := priKeyWrapper1.Pri.SignHash(consensus.blockHash[:]) network, err := consensus.construct(msg_pb.MessageType_PREPARE, nil, []*bls.PrivateKeyWrapper{&priKeyWrapper1}) @@ -248,7 +249,7 @@ func TestConstructCommitMessage(test *testing.T) { consensus.SetCurBlockViewID(2) consensus.blockHash = [32]byte{} copy(consensus.blockHash[:], []byte("random")) - consensus.blockNum = 1000 + atomic.StoreUint64(&consensus.blockNum, 1000) sigPayload := []byte("payload") diff --git a/consensus/debug.go b/consensus/debug.go index fb97c1116..95a34cd86 100644 --- a/consensus/debug.go +++ b/consensus/debug.go @@ -1,26 +1,21 @@ package consensus // GetConsensusPhase returns the current phase of the consensus -func (c *Consensus) GetConsensusPhase() string { - return c.phase.String() +func (consensus *Consensus) GetConsensusPhase() string { + return consensus.phase.String() } // GetConsensusMode returns the current mode of the consensus -func (c *Consensus) GetConsensusMode() string { - return c.current.mode.String() +func (consensus *Consensus) GetConsensusMode() string { + return consensus.current.mode.String() } // GetCurBlockViewID returns the current view ID of the consensus -func (c *Consensus) GetCurBlockViewID() uint64 { - return c.current.GetCurBlockViewID() +func (consensus *Consensus) GetCurBlockViewID() uint64 { + return consensus.current.GetCurBlockViewID() } // GetViewChangingID returns the current view changing ID of the consensus -func (c *Consensus) GetViewChangingID() uint64 { - return c.current.GetViewChangingID() -} - -// GetBlockNum return the current blockNum of the consensus struct -func (c *Consensus) GetBlockNum() uint64 { - return c.blockNum +func (consensus *Consensus) GetViewChangingID() uint64 { + return consensus.current.GetViewChangingID() } diff --git a/consensus/double_sign.go b/consensus/double_sign.go index a899961ca..98ffe3309 100644 --- a/consensus/double_sign.go +++ b/consensus/double_sign.go @@ -137,7 +137,7 @@ func (consensus *Consensus) checkDoubleSign(recvMsg *FBFTMessage) bool { func (consensus *Consensus) couldThisBeADoubleSigner( recvMsg *FBFTMessage, ) bool { - num, hash := consensus.blockNum, recvMsg.BlockHash + num, hash := consensus.BlockNum(), recvMsg.BlockHash suspicious := !consensus.FBFTLog.HasMatchingAnnounce(num, hash) || !consensus.FBFTLog.HasMatchingPrepared(num, hash) if suspicious { diff --git a/consensus/leader.go b/consensus/leader.go index 335b59bbf..940227f68 100644 --- a/consensus/leader.go +++ b/consensus/leader.go @@ -78,7 +78,7 @@ func (consensus *Consensus) announce(block *types.Block) { } // Construct broadcast p2p message if err := consensus.msgSender.SendWithRetry( - consensus.blockNum, msg_pb.MessageType_ANNOUNCE, []nodeconfig.GroupID{ + consensus.BlockNum(), msg_pb.MessageType_ANNOUNCE, []nodeconfig.GroupID{ nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID)), }, p2p.ConstructMessage(msgToSend)); err != nil { consensus.getLogger().Warn(). @@ -99,7 +99,7 @@ func (consensus *Consensus) announce(block *types.Block) { func (consensus *Consensus) onPrepare(recvMsg *FBFTMessage) { // TODO(audit): make FBFT lookup using map instead of looping through all items. if !consensus.FBFTLog.HasMatchingViewAnnounce( - consensus.blockNum, consensus.GetCurBlockViewID(), recvMsg.BlockHash, + consensus.BlockNum(), consensus.GetCurBlockViewID(), recvMsg.BlockHash, ) { consensus.getLogger().Debug(). Uint64("MsgViewID", recvMsg.ViewID). diff --git a/consensus/threshold.go b/consensus/threshold.go index e1867a757..e4a7ff0a7 100644 --- a/consensus/threshold.go +++ b/consensus/threshold.go @@ -42,7 +42,7 @@ func (consensus *Consensus) didReachPrepareQuorum() error { if err := rlp.DecodeBytes(consensus.block, &blockObj); err != nil { consensus.getLogger().Warn(). Err(err). - Uint64("BlockNum", consensus.blockNum). + Uint64("BlockNum", consensus.BlockNum()). Msg("[didReachPrepareQuorum] Unparseable block data") return err } @@ -69,7 +69,7 @@ func (consensus *Consensus) didReachPrepareQuorum() error { } } if err := consensus.msgSender.SendWithRetry( - consensus.blockNum, + consensus.BlockNum(), msg_pb.MessageType_PREPARED, []nodeconfig.GroupID{ nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID)), }, @@ -79,7 +79,7 @@ func (consensus *Consensus) didReachPrepareQuorum() error { } else { consensus.getLogger().Info(). Hex("blockHash", consensus.blockHash[:]). - Uint64("blockNum", consensus.blockNum). + Uint64("blockNum", consensus.BlockNum()). Msg("[OnPrepare] Sent Prepared Message!!") } consensus.msgSender.StopRetry(msg_pb.MessageType_ANNOUNCE) diff --git a/consensus/validator.go b/consensus/validator.go index 8a67d0de7..cacbac109 100644 --- a/consensus/validator.go +++ b/consensus/validator.go @@ -175,7 +175,7 @@ func (consensus *Consensus) sendCommitMessages(blockObj *types.Block) { consensus.getLogger().Warn().Err(err).Msg("[sendCommitMessages] Cannot send commit message!!") } else { consensus.getLogger().Info(). - Uint64("blockNum", consensus.blockNum). + Uint64("blockNum", consensus.BlockNum()). Hex("blockHash", consensus.blockHash[:]). Msg("[sendCommitMessages] Sent Commit Message!!") } @@ -192,14 +192,14 @@ func (consensus *Consensus) onPrepared(recvMsg *FBFTMessage) { Uint64("MsgViewID", recvMsg.ViewID). Msg("[OnPrepared] Received prepared message") - if recvMsg.BlockNum < consensus.blockNum { + if recvMsg.BlockNum < consensus.BlockNum() { consensus.getLogger().Info().Uint64("MsgBlockNum", recvMsg.BlockNum). Msg("Wrong BlockNum Received, ignoring!") return } - if recvMsg.BlockNum > consensus.blockNum { + if recvMsg.BlockNum > consensus.BlockNum() { consensus.getLogger().Warn(). - Uint64("myBlockNum", consensus.blockNum). + Uint64("myBlockNum", consensus.BlockNum()). Uint64("MsgBlockNum", recvMsg.BlockNum). Hex("myBlockHash", consensus.blockHash[:]). Hex("MsgBlockHash", recvMsg.BlockHash[:]). @@ -245,10 +245,10 @@ func (consensus *Consensus) onPrepared(recvMsg *FBFTMessage) { } return } - if recvMsg.BlockNum > consensus.blockNum { + if recvMsg.BlockNum > consensus.BlockNum() { consensus.getLogger().Info(). Uint64("MsgBlockNum", recvMsg.BlockNum). - Uint64("blockNum", consensus.blockNum). + Uint64("blockNum", consensus.BlockNum()). Msg("[OnPrepared] Future Block Received, ignoring!!") return } @@ -274,12 +274,12 @@ func (consensus *Consensus) onPrepared(recvMsg *FBFTMessage) { if blockObj == nil { return } - curBlockNum := consensus.blockNum + curBlockNum := consensus.BlockNum() for _, committedMsg := range consensus.FBFTLog.GetNotVerifiedCommittedMessages(blockObj.NumberU64(), blockObj.Header().ViewID().Uint64(), blockObj.Hash()) { if committedMsg != nil { consensus.onCommitted(committedMsg) } - if curBlockNum < consensus.blockNum { + if curBlockNum < consensus.BlockNum() { consensus.getLogger().Info().Msg("[OnPrepared] Successfully caught up with committed message") break } @@ -297,16 +297,16 @@ func (consensus *Consensus) onCommitted(recvMsg *FBFTMessage) { Msg("[OnCommitted] Received committed message") // Ok to receive committed from last block since it could have more signatures - if recvMsg.BlockNum < consensus.blockNum-1 { + if recvMsg.BlockNum < consensus.BlockNum()-1 { consensus.getLogger().Info(). Uint64("MsgBlockNum", recvMsg.BlockNum). Msg("Wrong BlockNum Received, ignoring!") return } - if recvMsg.BlockNum > consensus.blockNum { + if recvMsg.BlockNum > consensus.BlockNum() { consensus.getLogger().Info(). - Uint64("myBlockNum", consensus.blockNum). + Uint64("myBlockNum", consensus.BlockNum()). Uint64("MsgBlockNum", recvMsg.BlockNum). Hex("myBlockHash", consensus.blockHash[:]). Hex("MsgBlockHash", recvMsg.BlockHash[:]). @@ -372,12 +372,12 @@ func (consensus *Consensus) onCommitted(recvMsg *FBFTMessage) { } } - initBn := consensus.blockNum + initBn := consensus.BlockNum() consensus.tryCatchup() - if recvMsg.BlockNum > consensus.blockNum { + if recvMsg.BlockNum > consensus.BlockNum() { consensus.getLogger().Info(). - Uint64("myBlockNum", consensus.blockNum). + Uint64("myBlockNum", consensus.BlockNum()). Uint64("MsgBlockNum", recvMsg.BlockNum). Hex("myBlockHash", consensus.blockHash[:]). Hex("MsgBlockHash", recvMsg.BlockHash[:]). @@ -395,7 +395,7 @@ func (consensus *Consensus) onCommitted(recvMsg *FBFTMessage) { consensus.getLogger().Debug().Msg("[OnCommitted] stop bootstrap timer only once") } - if initBn < consensus.blockNum { + if initBn < consensus.BlockNum() { consensus.getLogger().Info().Msg("[OnCommitted] Start consensus timer (new block added)") consensus.consensusTimeout[timeoutConsensus].Start() } diff --git a/consensus/view_change.go b/consensus/view_change.go index e83dac043..8ca60f826 100644 --- a/consensus/view_change.go +++ b/consensus/view_change.go @@ -264,7 +264,9 @@ func (consensus *Consensus) startViewChange() { // aganist the consensus.LeaderPubKey variable. // Ideally, we shall use another variable to keep track of the // leader pubkey in viewchange mode + consensus.pubKeyLock.Lock() consensus.LeaderPubKey = consensus.getNextLeaderKey(nextViewID) + consensus.pubKeyLock.Unlock() consensus.getLogger().Warn(). Uint64("nextViewID", nextViewID). @@ -285,7 +287,7 @@ func (consensus *Consensus) startViewChange() { if err := consensus.vc.InitPayload( consensus.FBFTLog, nextViewID, - consensus.blockNum, + consensus.BlockNum(), consensus.priKey, members); err != nil { consensus.getLogger().Error().Err(err).Msg("[startViewChange] Init Payload Error") @@ -299,7 +301,7 @@ func (consensus *Consensus) startViewChange() { } msgToSend := consensus.constructViewChangeMessage(&key) if err := consensus.msgSender.SendWithRetry( - consensus.blockNum, + consensus.BlockNum(), msg_pb.MessageType_VIEWCHANGE, []nodeconfig.GroupID{ nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID))}, @@ -325,7 +327,7 @@ func (consensus *Consensus) startNewView(viewID uint64, newLeaderPriKey *bls.Pri } if err := consensus.msgSender.SendWithRetry( - consensus.blockNum, + consensus.BlockNum(), msg_pb.MessageType_NEWVIEW, []nodeconfig.GroupID{ nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID))}, @@ -471,10 +473,10 @@ func (consensus *Consensus) onNewView(recvMsg *FBFTMessage) { Msg("[onNewView] Received NewView Message") // change view and leaderKey to keep in sync with network - if consensus.blockNum != recvMsg.BlockNum { + if consensus.BlockNum() != recvMsg.BlockNum { consensus.getLogger().Warn(). Uint64("MsgBlockNum", recvMsg.BlockNum). - Uint64("myBlockNum", consensus.blockNum). + Uint64("myBlockNum", consensus.BlockNum()). Msg("[onNewView] Invalid block number") return } @@ -551,7 +553,9 @@ func (consensus *Consensus) onNewView(recvMsg *FBFTMessage) { // newView message verified success, override my state consensus.SetViewIDs(recvMsg.ViewID) + consensus.pubKeyLock.Lock() consensus.LeaderPubKey = senderKey + consensus.pubKeyLock.Unlock() consensus.ResetViewChangeState() consensus.msgSender.StopRetry(msg_pb.MessageType_VIEWCHANGE) diff --git a/consensus/view_change_msg.go b/consensus/view_change_msg.go index db2dfd432..be1974105 100644 --- a/consensus/view_change_msg.go +++ b/consensus/view_change_msg.go @@ -24,7 +24,7 @@ func (consensus *Consensus) constructViewChangeMessage(priKey *bls.PrivateKeyWra Request: &msg_pb.Message_Viewchange{ Viewchange: &msg_pb.ViewChangeRequest{ ViewId: consensus.GetViewChangingID(), - BlockNum: consensus.blockNum, + BlockNum: consensus.BlockNum(), ShardId: consensus.ShardID, SenderPubkey: priKey.Pub.Bytes[:], LeaderPubkey: consensus.LeaderPubKey.Bytes[:], @@ -33,7 +33,7 @@ func (consensus *Consensus) constructViewChangeMessage(priKey *bls.PrivateKeyWra } preparedMsgs := consensus.FBFTLog.GetMessagesByTypeSeq( - msg_pb.MessageType_PREPARED, consensus.blockNum, + msg_pb.MessageType_PREPARED, consensus.BlockNum(), ) preparedMsg := consensus.FBFTLog.FindMessageByMaxViewID(preparedMsgs) @@ -107,7 +107,7 @@ func (consensus *Consensus) constructNewViewMessage(viewID uint64, priKey *bls.P Request: &msg_pb.Message_Viewchange{ Viewchange: &msg_pb.ViewChangeRequest{ ViewId: viewID, - BlockNum: consensus.blockNum, + BlockNum: consensus.BlockNum(), ShardId: consensus.ShardID, SenderPubkey: priKey.Pub.Bytes[:], }, diff --git a/consensus/view_change_test.go b/consensus/view_change_test.go index 20bf54f9e..fc80b6ccf 100644 --- a/consensus/view_change_test.go +++ b/consensus/view_change_test.go @@ -43,7 +43,7 @@ func TestPhaseSwitching(t *testing.T) { _, _, consensus, _, err := GenerateConsensusForTesting() assert.NoError(t, err) - assert.Equal(t, FBFTAnnounce, consensus.phase) // It's a new consensus, we should be at the FBFTAnnounce phase + assert.Equal(t, FBFTAnnounce, consensus.phase.Get()) // It's a new consensus, we should be at the FBFTAnnounce phase. switches := []phaseSwitch{ {start: FBFTAnnounce, end: FBFTPrepare}, @@ -73,10 +73,10 @@ func TestPhaseSwitching(t *testing.T) { func testPhaseGroupSwitching(t *testing.T, consensus *Consensus, phases []FBFTPhase, startPhase FBFTPhase, desiredPhase FBFTPhase) { for range phases { consensus.switchPhase("test", desiredPhase) - assert.Equal(t, desiredPhase, consensus.phase) + assert.Equal(t, desiredPhase, consensus.phase.Get()) } - assert.Equal(t, desiredPhase, consensus.phase) + assert.Equal(t, desiredPhase, consensus.phase.Get()) return } diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 5b9b9bdb4..c23d38b56 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -23,6 +23,7 @@ import ( "math/big" "math/rand" "os" + "sync/atomic" "testing" "time" @@ -69,9 +70,13 @@ type testBlockChain struct { chainHeadFeed *event.Feed } +func (bc *testBlockChain) SetGasLimit(value uint64) { + atomic.StoreUint64(&bc.gasLimit, value) +} + func (bc *testBlockChain) CurrentBlock() *types.Block { return types.NewBlock(blockfactory.NewTestHeader().With(). - GasLimit(bc.gasLimit). + GasLimit(atomic.LoadUint64(&bc.gasLimit)). Header(), nil, nil, nil, nil, nil) } @@ -162,12 +167,14 @@ func createBlockChain() *BlockChainImpl { return blockchain } -func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { - statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase())) - blockchain := &testBlockChain{statedb, 1e18, new(event.Feed)} +func setupTxPool(chain blockChain) (*TxPool, *ecdsa.PrivateKey) { + if chain == nil { + statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase())) + chain = &testBlockChain{statedb, 1e18, new(event.Feed)} + } key, _ := crypto.GenerateKey() - pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, dummyErrorSink) + pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, chain, dummyErrorSink) return pool, key } @@ -282,7 +289,7 @@ func TestStateChangeDuringTransactionPoolReset(t *testing.T) { func TestInvalidTransactions(t *testing.T) { t.Parallel() - pool, key := setupTxPool() + pool, key := setupTxPool(nil) defer pool.Stop() tx := transaction(0, 0, 100, key) @@ -324,8 +331,7 @@ func TestInvalidTransactions(t *testing.T) { func TestErrorSink(t *testing.T) { t.Parallel() - pool, key := setupTxPool() - pool.chain = createBlockChain() + pool, key := setupTxPool(createBlockChain()) defer pool.Stop() testTxErrorSink := types.NewTransactionErrorSink() @@ -396,8 +402,7 @@ func TestErrorSink(t *testing.T) { func TestCreateValidatorTransaction(t *testing.T) { t.Parallel() - pool, _ := setupTxPool() - pool.chain = createBlockChain() + pool, _ := setupTxPool(createBlockChain()) defer pool.Stop() fromKey, _ := crypto.GenerateKey() @@ -422,8 +427,7 @@ func TestCreateValidatorTransaction(t *testing.T) { func TestMixedTransactions(t *testing.T) { t.Parallel() - pool, _ := setupTxPool() - pool.chain = createBlockChain() + pool, _ := setupTxPool(createBlockChain()) defer pool.Stop() fromKey, _ := crypto.GenerateKey() @@ -457,7 +461,7 @@ func TestBlacklistedTransactions(t *testing.T) { // DO NOT parallelize, test will add accounts to tx pool config. // Create the pool - pool, _ := setupTxPool() + pool, _ := setupTxPool(nil) defer pool.Stop() // Create testing keys @@ -501,7 +505,7 @@ func TestBlacklistedTransactions(t *testing.T) { func TestTransactionQueue(t *testing.T) { t.Parallel() - pool, key := setupTxPool() + pool, key := setupTxPool(nil) defer pool.Stop() tx := transaction(0, 0, 100, key) @@ -528,7 +532,7 @@ func TestTransactionQueue(t *testing.T) { t.Error("expected transaction queue to be empty. is", len(pool.queue)) } - pool, key = setupTxPool() + pool, key = setupTxPool(nil) defer pool.Stop() tx1 := transaction(0, 0, 100, key) @@ -555,7 +559,7 @@ func TestTransactionQueue(t *testing.T) { func TestTransactionNegativeValue(t *testing.T) { t.Parallel() - pool, key := setupTxPool() + pool, key := setupTxPool(nil) defer pool.Stop() tx, _ := types.SignTx( @@ -569,9 +573,10 @@ func TestTransactionNegativeValue(t *testing.T) { } func TestTransactionChainFork(t *testing.T) { + t.Skip("This test doesn't work with race detector") t.Parallel() - pool, key := setupTxPool() + pool, key := setupTxPool(nil) defer pool.Stop() addr := crypto.PubkeyToAddress(key.PublicKey) @@ -600,18 +605,13 @@ func TestTransactionChainFork(t *testing.T) { func TestTransactionDoubleNonce(t *testing.T) { t.Parallel() - pool, key := setupTxPool() - defer pool.Stop() - + key, _ := crypto.GenerateKey() addr := crypto.PubkeyToAddress(key.PublicKey) - resetState := func() { - statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase())) - statedb.AddBalance(addr, big.NewInt(1000000000000000000)) - - pool.chain = &testBlockChain{statedb, 1000000, new(event.Feed)} - pool.lockedReset(nil, nil) - } - resetState() + statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase())) + statedb.AddBalance(addr, big.NewInt(1000000000000000000)) + pool, _ := setupTxPool(&testBlockChain{statedb, 1000000, new(event.Feed)}) + defer pool.Stop() + pool.lockedReset(nil, nil) signer := types.HomesteadSigner{} tx1, _ := types.SignTx( @@ -656,7 +656,7 @@ func TestTransactionDoubleNonce(t *testing.T) { func TestTransactionMissingNonce(t *testing.T) { t.Parallel() - pool, key := setupTxPool() + pool, key := setupTxPool(nil) defer pool.Stop() addr := crypto.PubkeyToAddress(key.PublicKey) @@ -680,7 +680,7 @@ func TestTransactionNonceRecovery(t *testing.T) { t.Parallel() const n = 10 - pool, key := setupTxPool() + pool, key := setupTxPool(nil) defer pool.Stop() addr := crypto.PubkeyToAddress(key.PublicKey) @@ -706,7 +706,7 @@ func TestTransactionDropping(t *testing.T) { t.Parallel() // Create a test account and fund it - pool, key := setupTxPool() + pool, key := setupTxPool(nil) defer pool.Stop() account, _ := deriveSender(transaction(0, 0, 0, key)) @@ -774,7 +774,7 @@ func TestTransactionDropping(t *testing.T) { t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 4) } // Reduce the block gas limit, check that invalidated transactions are dropped - pool.chain.(*testBlockChain).gasLimit = 100 + pool.chain.(*testBlockChain).SetGasLimit(100) pool.lockedReset(nil, nil) if _, ok := pool.pending[account].txs.items[tx0.Nonce()]; !ok { @@ -918,7 +918,7 @@ func TestTransactionQueueAccountLimiting(t *testing.T) { t.Parallel() // Create a test account and fund it - pool, key := setupTxPool() + pool, key := setupTxPool(nil) defer pool.Stop() account, _ := deriveSender(transaction(0, 0, 0, key)) @@ -1119,7 +1119,7 @@ func testTransactionLimitingEquivalency(t *testing.T, origin uint64) { t.Parallel() // Add a batch of transactions to a pool one by one - pool1, key1 := setupTxPool() + pool1, key1 := setupTxPool(nil) defer pool1.Stop() account1, _ := deriveSender(transaction(0, 0, 0, key1)) @@ -1131,7 +1131,7 @@ func testTransactionLimitingEquivalency(t *testing.T, origin uint64) { } } // Add a batch of transactions to a pool in one big batch - pool2, key2 := setupTxPool() + pool2, key2 := setupTxPool(nil) defer pool2.Stop() account2, _ := deriveSender(transaction(0, 0, 0, key2)) @@ -1523,7 +1523,7 @@ func BenchmarkPendingDemotion10000(b *testing.B) { benchmarkPendingDemotion(b, 1 func benchmarkPendingDemotion(b *testing.B, size int) { // Add a batch of transactions to a pool one by one - pool, key := setupTxPool() + pool, key := setupTxPool(nil) defer pool.Stop() account, _ := deriveSender(transaction(0, 0, 0, key)) @@ -1548,7 +1548,7 @@ func BenchmarkFuturePromotion10000(b *testing.B) { benchmarkFuturePromotion(b, 1 func benchmarkFuturePromotion(b *testing.B, size int) { // Add a batch of transactions to a pool one by one - pool, key := setupTxPool() + pool, key := setupTxPool(nil) defer pool.Stop() account, _ := deriveSender(transaction(0, 0, 0, key)) @@ -1568,7 +1568,7 @@ func benchmarkFuturePromotion(b *testing.B, size int) { // Benchmarks the speed of iterative transaction insertion. func BenchmarkPoolInsert(b *testing.B) { // Generate a batch of transactions to enqueue into the pool - pool, key := setupTxPool() + pool, key := setupTxPool(nil) defer pool.Stop() account, _ := deriveSender(transaction(0, 0, 0, key)) @@ -1592,7 +1592,7 @@ func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 1 func benchmarkPoolBatchInsert(b *testing.B, size int) { // Generate a batch of transactions to enqueue into the pool - pool, key := setupTxPool() + pool, key := setupTxPool(nil) defer pool.Stop() account, _ := deriveSender(transaction(0, 0, 0, key)) diff --git a/core/types/block.go b/core/types/block.go index f8a3f126a..b4debaa72 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -227,6 +227,7 @@ type Block struct { ReceivedAt time.Time ReceivedFrom interface{} + commitLock sync.Mutex // Commit Signatures/Bitmap commitSigAndBitmap []byte } @@ -264,11 +265,15 @@ func (b *Block) SetCurrentCommitSig(sigAndBitmap []byte) { Int("dstLen", len(b.header.LastCommitSignature())). Msg("SetCurrentCommitSig: sig size mismatch") } + b.commitLock.Lock() b.commitSigAndBitmap = sigAndBitmap + b.commitLock.Unlock() } // GetCurrentCommitSig get the commit group signature that signed on this block. func (b *Block) GetCurrentCommitSig() []byte { + b.commitLock.Lock() + defer b.commitLock.Unlock() return b.commitSigAndBitmap } diff --git a/internal/utils/keylocker/keylocker.go b/internal/utils/keylocker/keylocker.go new file mode 100644 index 000000000..46ef6413d --- /dev/null +++ b/internal/utils/keylocker/keylocker.go @@ -0,0 +1,30 @@ +package keylocker + +import "sync" + +type KeyLocker struct { + m sync.Map +} + +func New() *KeyLocker { + return &KeyLocker{} +} + +func (a *KeyLocker) Lock(key interface{}, f func() (interface{}, error)) (interface{}, error) { + mu := &sync.Mutex{} + for { + actual, _ := a.m.LoadOrStore(key, mu) + mu2 := actual.(*sync.Mutex) + mu.Lock() + if mu2 != mu { + // acquired someone else lock. + mu.Unlock() + continue + } + rs, err := f() + mu.Unlock() + a.m.Delete(key) + return rs, err + } + +} diff --git a/internal/utils/keylocker/keylocker_test.go b/internal/utils/keylocker/keylocker_test.go new file mode 100644 index 000000000..511049544 --- /dev/null +++ b/internal/utils/keylocker/keylocker_test.go @@ -0,0 +1,17 @@ +package keylocker + +import "testing" + +func TestKeyLocker_Sequential(t *testing.T) { + n := New() + key := 1 + v := make(chan struct{}, 1) + n.Lock(key, func() (interface{}, error) { + v <- struct{}{} + return nil, nil + }) + <-v // we know that goroutine really executed + n.Lock(key, func() (interface{}, error) { + return nil, nil + }) +} diff --git a/internal/utils/singleton_test.go b/internal/utils/singleton_test.go deleted file mode 100644 index 3d3325f0f..000000000 --- a/internal/utils/singleton_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package utils - -import ( - "sync" - "testing" - "time" -) - -var NumThreads = 20 - -func TestSingleton(t *testing.T) { - var wg sync.WaitGroup - - for i := 0; i < NumThreads; i++ { - wg.Add(1) - go func() { - defer wg.Done() - time.Sleep(time.Millisecond) - - }() - } - wg.Wait() - - n := 100 - for i := 0; i < NumThreads; i++ { - wg.Add(1) - go func() { - defer wg.Done() - n++ - time.Sleep(time.Millisecond) - }() - } - - wg.Wait() -} diff --git a/internal/utils/timer.go b/internal/utils/timer.go index 2e8a77667..3502d68ec 100644 --- a/internal/utils/timer.go +++ b/internal/utils/timer.go @@ -1,6 +1,7 @@ package utils import ( + "sync" "time" ) @@ -19,6 +20,7 @@ type Timeout struct { state TimeoutState d time.Duration start time.Time + mu sync.Mutex } // NewTimeout creates a new timeout class @@ -29,18 +31,24 @@ func NewTimeout(d time.Duration) *Timeout { // Start starts the timeout clock func (timeout *Timeout) Start() { + timeout.mu.Lock() timeout.state = Active timeout.start = time.Now() + timeout.mu.Unlock() } // Stop stops the timeout clock func (timeout *Timeout) Stop() { + timeout.mu.Lock() timeout.state = Inactive timeout.start = time.Now() + timeout.mu.Unlock() } // CheckExpire checks whether the timeout is reached/expired func (timeout *Timeout) CheckExpire() bool { + timeout.mu.Lock() + defer timeout.mu.Unlock() if timeout.state == Active && time.Since(timeout.start) > timeout.d { timeout.state = Expired } @@ -52,17 +60,23 @@ func (timeout *Timeout) CheckExpire() bool { // Duration returns the duration period of timeout func (timeout *Timeout) Duration() time.Duration { + timeout.mu.Lock() + defer timeout.mu.Unlock() return timeout.d } // SetDuration set new duration for the timer func (timeout *Timeout) SetDuration(nd time.Duration) { + timeout.mu.Lock() timeout.d = nd + timeout.mu.Unlock() } // IsActive checks whether timeout clock is active; // A timeout is active means it's not stopped caused by stop // and also not expired with time elapses longer than duration from start func (timeout *Timeout) IsActive() bool { + timeout.mu.Lock() + defer timeout.mu.Unlock() return timeout.state == Active } diff --git a/node/api.go b/node/api.go index 0388a3500..bf74bd7ec 100644 --- a/node/api.go +++ b/node/api.go @@ -125,7 +125,7 @@ func (node *Node) GetConsensusCurViewID() uint64 { // GetConsensusBlockNum returns the current block number of the consensus func (node *Node) GetConsensusBlockNum() uint64 { - return node.Consensus.GetBlockNum() + return node.Consensus.BlockNum() } // GetConsensusInternal returns consensus internal data diff --git a/node/node_newblock.go b/node/node_newblock.go index 9aec09838..d4cf2807e 100644 --- a/node/node_newblock.go +++ b/node/node_newblock.go @@ -134,7 +134,7 @@ func (node *Node) ProposeNewBlock(commitSigs chan []byte) (*types.Block, error) header := node.Worker.GetCurrentHeader() // Update worker's current header and // state data in preparation to propose/process new transactions - leaderKey := node.Consensus.LeaderPubKey + leaderKey := node.Consensus.GetLeaderPubKey() var ( coinbase = node.GetAddressForBLSKey(leaderKey.Object, header.Epoch()) beneficiary = coinbase diff --git a/p2p/host.go b/p2p/host.go index 1db22e92d..b7b9c3eca 100644 --- a/p2p/host.go +++ b/p2p/host.go @@ -190,8 +190,8 @@ func NewHost(cfg HostConfig) (Host, error) { priKey: key, discovery: disc, security: security, - onConnections: []ConnectCallback{}, - onDisconnects: []DisconnectCallback{}, + onConnections: ConnectCallbacks{}, + onDisconnects: DisconnectCallbacks{}, logger: &subLogger, ctx: ctx, cancel: cancel, @@ -218,8 +218,8 @@ type HostV2 struct { security security.Security logger *zerolog.Logger blocklist libp2p_pubsub.Blacklist - onConnections []ConnectCallback - onDisconnects []DisconnectCallback + onConnections ConnectCallbacks + onDisconnects DisconnectCallbacks ctx context.Context cancel func() } @@ -433,7 +433,7 @@ func (host *HostV2) ListenClose(net libp2p_network.Network, addr ma.Multiaddr) { func (host *HostV2) Connected(net libp2p_network.Network, conn libp2p_network.Conn) { host.logger.Info().Interface("node", conn.RemotePeer()).Msg("peer connected") - for _, function := range host.onConnections { + for _, function := range host.onConnections.GetAll() { if err := function(net, conn); err != nil { host.logger.Error().Err(err).Interface("node", conn.RemotePeer()).Msg("failed on peer connected callback") } @@ -444,7 +444,7 @@ func (host *HostV2) Connected(net libp2p_network.Network, conn libp2p_network.Co func (host *HostV2) Disconnected(net libp2p_network.Network, conn libp2p_network.Conn) { host.logger.Info().Interface("node", conn.RemotePeer()).Msg("peer disconnected") - for _, function := range host.onDisconnects { + for _, function := range host.onDisconnects.GetAll() { if err := function(conn); err != nil { host.logger.Error().Err(err).Interface("node", conn.RemotePeer()).Msg("failed on peer disconnected callback") } @@ -462,11 +462,11 @@ func (host *HostV2) ClosedStream(net libp2p_network.Network, stream libp2p_netwo } func (host *HostV2) SetConnectCallback(callback ConnectCallback) { - host.onConnections = append(host.onConnections, callback) + host.onConnections.Add(callback) } func (host *HostV2) SetDisconnectCallback(callback DisconnectCallback) { - host.onDisconnects = append(host.onDisconnects, callback) + host.onDisconnects.Add(callback) } // NamedTopic represents pubsub topic diff --git a/p2p/stream/protocols/sync/chain.go b/p2p/stream/protocols/sync/chain.go index eb7b28d5d..c898a5ed9 100644 --- a/p2p/stream/protocols/sync/chain.go +++ b/p2p/stream/protocols/sync/chain.go @@ -6,6 +6,7 @@ import ( "github.com/harmony-one/harmony/consensus/engine" "github.com/harmony-one/harmony/core/types" shardingconfig "github.com/harmony-one/harmony/internal/configs/sharding" + "github.com/harmony-one/harmony/internal/utils/keylocker" "github.com/pkg/errors" ) @@ -18,14 +19,16 @@ type chainHelper interface { } type chainHelperImpl struct { - chain engine.ChainReader - schedule shardingconfig.Schedule + chain engine.ChainReader + schedule shardingconfig.Schedule + keyLocker *keylocker.KeyLocker } func newChainHelper(chain engine.ChainReader, schedule shardingconfig.Schedule) *chainHelperImpl { return &chainHelperImpl{ - chain: chain, - schedule: schedule, + chain: chain, + schedule: schedule, + keyLocker: keylocker.New(), } } @@ -89,19 +92,25 @@ func (ch *chainHelperImpl) getBlocksByHashes(hs []common.Hash) ([]*types.Block, return blocks, nil } -var errBlockNotFound = errors.New("block not found") - func (ch *chainHelperImpl) getBlockWithSigByHeader(header *block.Header) (*types.Block, error) { - b := ch.chain.GetBlock(header.Hash(), header.Number().Uint64()) - if b == nil { - return nil, nil - } - commitSig, err := ch.getBlockSigAndBitmap(header) + rs, err := ch.keyLocker.Lock(header.Number().Uint64(), func() (interface{}, error) { + b := ch.chain.GetBlock(header.Hash(), header.Number().Uint64()) + if b == nil { + return nil, nil + } + commitSig, err := ch.getBlockSigAndBitmap(header) + if err != nil { + return nil, errors.New("missing commit signature") + } + b.SetCurrentCommitSig(commitSig) + return b, nil + }) + if err != nil { - return nil, errors.New("missing commit signature") + return nil, err } - b.SetCurrentCommitSig(commitSig) - return b, nil + + return rs.(*types.Block), nil } func (ch *chainHelperImpl) getBlockSigAndBitmap(header *block.Header) ([]byte, error) { diff --git a/p2p/stream/protocols/sync/protocol_test.go b/p2p/stream/protocols/sync/protocol_test.go index b16fdfc40..df6768195 100644 --- a/p2p/stream/protocols/sync/protocol_test.go +++ b/p2p/stream/protocols/sync/protocol_test.go @@ -2,6 +2,7 @@ package sync import ( "context" + "sync" "testing" "time" @@ -50,11 +51,12 @@ func TestProtocol_advertiseLoop(t *testing.T) { time.Sleep(150 * time.Millisecond) close(p.closeC) - if len(disc.advCnt) != len(p.supportedVersions()) { - t.Errorf("unexpected advertise topic count: %v / %v", len(disc.advCnt), + advCnt := disc.Extract() + if len(advCnt) != len(p.supportedVersions()) { + t.Errorf("unexpected advertise topic count: %v / %v", len(advCnt), len(p.supportedVersions())) } - for _, cnt := range disc.advCnt { + for _, cnt := range advCnt { if cnt < 1 { t.Errorf("unexpected discovery count: %v", cnt) } @@ -64,6 +66,7 @@ func TestProtocol_advertiseLoop(t *testing.T) { type testDiscovery struct { advCnt map[string]int sleep time.Duration + mu sync.Mutex } func newTestDiscovery(discInterval time.Duration) *testDiscovery { @@ -82,10 +85,20 @@ func (disc *testDiscovery) Close() error { } func (disc *testDiscovery) Advertise(ctx context.Context, ns string) (time.Duration, error) { + disc.mu.Lock() + defer disc.mu.Unlock() disc.advCnt[ns]++ return disc.sleep, nil } +func (disc *testDiscovery) Extract() map[string]int { + disc.mu.Lock() + defer disc.mu.Unlock() + var out map[string]int + out, disc.advCnt = disc.advCnt, make(map[string]int) + return out +} + func (disc *testDiscovery) FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) { return nil, nil } diff --git a/p2p/utils.go b/p2p/utils.go new file mode 100644 index 000000000..822963c1c --- /dev/null +++ b/p2p/utils.go @@ -0,0 +1,41 @@ +package p2p + +import "sync" + +type ConnectCallbacks struct { + cbs []ConnectCallback + mu sync.RWMutex +} + +func (a *ConnectCallbacks) Add(cb ConnectCallback) { + a.mu.Lock() + defer a.mu.Unlock() + a.cbs = append(a.cbs, cb) +} + +func (a *ConnectCallbacks) GetAll() []ConnectCallback { + a.mu.RLock() + defer a.mu.RUnlock() + out := make([]ConnectCallback, len(a.cbs)) + copy(out, a.cbs) + return out +} + +type DisconnectCallbacks struct { + cbs []DisconnectCallback + mu sync.RWMutex +} + +func (a *DisconnectCallbacks) Add(cb DisconnectCallback) { + a.mu.Lock() + defer a.mu.Unlock() + a.cbs = append(a.cbs, cb) +} + +func (a *DisconnectCallbacks) GetAll() []DisconnectCallback { + a.mu.RLock() + defer a.mu.RUnlock() + out := make([]DisconnectCallback, len(a.cbs)) + copy(out, a.cbs) + return out +} diff --git a/p2p/utils_test.go b/p2p/utils_test.go new file mode 100644 index 000000000..7d0ead3ea --- /dev/null +++ b/p2p/utils_test.go @@ -0,0 +1,37 @@ +package p2p + +import ( + "reflect" + "testing" + + libp2p_network "github.com/libp2p/go-libp2p-core/network" + "github.com/stretchr/testify/require" +) + +func TestConnectCallbacks(t *testing.T) { + cbs := ConnectCallbacks{} + fn := func(net libp2p_network.Network, conn libp2p_network.Conn) error { + return nil + } + + require.Equal(t, 0, len(cbs.GetAll())) + + cbs.Add(fn) + + require.Equal(t, 1, len(cbs.GetAll())) + require.Equal(t, reflect.ValueOf(fn).Pointer(), reflect.ValueOf(cbs.GetAll()[0]).Pointer()) +} + +func TestDisConnectCallbacks(t *testing.T) { + cbs := DisconnectCallbacks{} + fn := func(conn libp2p_network.Conn) error { + return nil + } + + require.Equal(t, 0, len(cbs.GetAll())) + + cbs.Add(fn) + + require.Equal(t, 1, len(cbs.GetAll())) + require.Equal(t, reflect.ValueOf(fn).Pointer(), reflect.ValueOf(cbs.GetAll()[0]).Pointer()) +} diff --git a/scripts/travis_go_checker.sh b/scripts/travis_go_checker.sh index 2ecadaa13..20da2bccc 100755 --- a/scripts/travis_go_checker.sh +++ b/scripts/travis_go_checker.sh @@ -82,7 +82,7 @@ fi echo "Running go test..." # Fix https://github.com/golang/go/issues/44129#issuecomment-788351567 go get -t ./... -if go test -v -count=1 ./... +if go test -v -count=1 -vet=all -race ./... then echo "go test succeeded." else